diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index a7b325ca2ba..c6bf8c79e78 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -73,7 +73,9 @@ def is_applicable_for_batch_dim( >= real_batch_dim.prefill_req_count + real_batch_dim.decode_req_count ) - def is_valid(self, max_requests: int, max_sequence_length: int) -> bool: + def is_valid( + self, max_requests: int, max_sequence_length: int, num_speculative_tokens: int + ) -> bool: """ Checks if the batch dimension is valid based on resource constraints. @@ -92,11 +94,17 @@ def is_valid(self, max_requests: int, max_sequence_length: int) -> bool: return False # Check if token count is sufficient for requests - if self.token_count < self.prefill_req_count + self.decode_req_count: + if self.token_count < self.prefill_req_count + self.decode_req_count * ( + num_speculative_tokens + 1 + ): return False # Check if the prefill requests are shorter than the max sequence length - if self.token_count > self.prefill_req_count * max_sequence_length + self.decode_req_count: + if ( + self.token_count + > self.prefill_req_count * max_sequence_length + + self.decode_req_count * (num_speculative_tokens + 1) + ): return False return True @@ -308,6 +316,7 @@ def generate_cuda_graph_batch_dimensions_list( max_tokens: int, max_sequence_length: int, use_cuda_graphs_for_non_decode_steps: bool, + num_speculative_tokens: int = 0, ) -> Tuple[List[InferenceBatchDimensions], Optional[List[int]]]: """ Generate CUDA graph batch dimensions. @@ -344,6 +353,7 @@ def generate_cuda_graph_batch_dimensions_list( max_tokens: Maximum total tokens max_sequence_length: Maximum sequence length use_cuda_graphs_for_non_decode_steps: Whether to use CUDA graphs for non-decode steps + num_speculative_tokens: Number of speculative tokens Returns: Tuple containing: @@ -355,7 +365,7 @@ def generate_cuda_graph_batch_dimensions_list( def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int) -> None: """Helper to create and append batch dimension to list only if it's valid.""" batch_dim = InferenceBatchDimensions(token_count, prefill_req_count, decode_req_count) - if batch_dim.is_valid(max_requests, max_sequence_length): + if batch_dim.is_valid(max_requests, max_sequence_length, num_speculative_tokens): cuda_graph_batch_dimensions_list.append(batch_dim) # Cuda graph token-counts @@ -372,9 +382,10 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): cuda_graph_max_tokens = max_tokens - assert cuda_graph_max_tokens == max_requests, ( - f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests " - f"({max_requests}). This is required for correctly syncing EP ranks: " + assert cuda_graph_max_tokens == max_requests * (num_speculative_tokens + 1), ( + f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests *" + f"(num_speculative_tokens + 1) ({max_requests * (num_speculative_tokens + 1)}). " + "This is required for correctly syncing EP ranks: " f"prefill and decode graph pools must have the same token count granularity." ) @@ -395,8 +406,9 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ) # Calculate separate token counts for decode-only graphs. - # Decode graphs can be more conservative since each request uses exactly 1 token. - cuda_graph_max_tokens_decode = min(cuda_graph_max_tokens, max_requests) + cuda_graph_max_tokens_decode = min( + cuda_graph_max_tokens, max_requests * (num_speculative_tokens + 1) + ) cuda_graph_decode_token_counts = ( CUDAGraphBatchDimensionBuilder._calculate_cuda_graph_token_counts( tp_size=tp_size, @@ -415,20 +427,29 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): # decode only # Use decode-specific token counts for decode-only graphs for size in cuda_graph_decode_token_counts: + decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) + token_count = decode_req_count * (num_speculative_tokens + 1) + token_count = token_count // tp_size * tp_size add_if_valid( - token_count=min(size, max_requests), - prefill_req_count=0, - decode_req_count=min(size, max_requests), + token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count ) else: # Mixed prefill and decode mode # Create prefill and mixed dimensions with full token counts for size in cuda_graph_prefill_token_counts: + assert size % tp_size == 0 + prefill_req_count = min(cuda_graph_mixed_prefill_request_count, max_requests) + decode_req_count = max( + 0, + min( + (size - prefill_req_count) // (num_speculative_tokens + 1), + max_requests - prefill_req_count, + ), + ) add_if_valid( token_count=size, - prefill_req_count=min(cuda_graph_mixed_prefill_request_count, max_requests), - decode_req_count=min(size, max_requests) - - min(cuda_graph_mixed_prefill_request_count, max_requests), + prefill_req_count=prefill_req_count, + decode_req_count=decode_req_count, ) # We need to ensure the prefill requests are shorter than the max sequence length, # considering the one decode token is used for prefill request construction @@ -445,16 +466,21 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int # Create decode-only dimensions with optimized token counts for size in cuda_graph_decode_token_counts: + decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) + token_count = decode_req_count * (num_speculative_tokens + 1) + token_count = token_count // tp_size * tp_size add_if_valid( - token_count=min(size, max_requests), - prefill_req_count=0, - decode_req_count=min(size, max_requests), + token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count ) # Remove duplicates and sort by prefill token count cuda_graph_batch_dimensions_list = list(set(cuda_graph_batch_dimensions_list)) cuda_graph_batch_dimensions_list.sort( - key=lambda x: ((x.token_count - x.decode_req_count), x.decode_req_count), reverse=True + key=lambda x: ( + (x.token_count - x.decode_req_count * (num_speculative_tokens + 1)), + x.decode_req_count, + ), + reverse=True, ) # Collect actual token counts from batch dimensions, then unique and sort diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index ac7fb85c57b..bc0770d450d 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -228,6 +228,9 @@ class InferenceConfig: enable_chunked_prefill: bool = False """Whether to enable chunked prefill.""" + num_speculative_tokens: int = 0 + """The number of speculative tokens to generate for decode steps.""" + enable_prefix_caching: bool = False """Whether to enable prefix caching for KV cache block sharing.""" diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index bacaf882944..34a19cf0394 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -259,12 +259,17 @@ def update( self.cu_seqlens = self._cu_seqlens_buffer[: padded_prefill_count + 1] if padded_decode_count > 0 and padded_prefill_count > 0: - self._device_decode_prefill_buffer[0] = real_decode_count + self._device_decode_prefill_buffer[0] = cu_seqlens[real_decode_count] # This describes the number of items in the prefill tensor relative to the # decode tensor. If chunked prefill is present, it is included in the # "prefill" part of the main split. - self._device_decode_prefill_buffer[1] = regular_prefill_count + ( - 1 if has_chunked_prefill_req else 0 + self._device_decode_prefill_buffer[1] = ( + cu_seqlens[ + real_decode_count + + regular_prefill_count + + (1 if has_chunked_prefill_req else 0) + ] + - cu_seqlens[real_decode_count] ) self.device_decode_prefill = self._device_decode_prefill_buffer diff --git a/megatron/core/inference/contexts/attention_context/mha_metadata.py b/megatron/core/inference/contexts/attention_context/mha_metadata.py index 1b6e8020275..07f8a349b51 100644 --- a/megatron/core/inference/contexts/attention_context/mha_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mha_metadata.py @@ -41,6 +41,7 @@ def update( request_to_kv_block_ids: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, + num_speculative_tokens: int = 0, ): """ Args: @@ -49,6 +50,7 @@ def update( request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) batch_dimensions: Configuration object containing real batch settings padded_batch_dimensions: Configuration object containing padded batch settings + num_speculative_tokens: Number of speculative tokens """ # Extract values from configs real_batch_size = batch_dimensions.req_count @@ -99,7 +101,7 @@ def update( ) if padded_batch_dimensions.prefill_req_count == 0: - self._max_seqlen_q = 1 + self._max_seqlen_q = num_speculative_tokens + 1 else: # Make sure we will launch the prefill kernel for prefill graphs self._max_seqlen_q = max(2, padded_batch_dimensions.token_count) @@ -150,6 +152,7 @@ def update( request_to_kv_block_ids: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, + num_speculative_tokens: int = 0, ): """ Args: @@ -158,6 +161,7 @@ def update( request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) batch_dimensions: Configuration object containing real batch settings padded_batch_dimensions: Configuration object containing padded batch settings + num_speculative_tokens: Number of speculative tokens """ super().update( request_query_lengths, @@ -165,6 +169,7 @@ def update( request_to_kv_block_ids, batch_dimensions, padded_batch_dimensions, + num_speculative_tokens, ) def reset(self): @@ -183,6 +188,7 @@ def update( request_to_kv_block_ids: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, + num_speculative_tokens: int = 0, ): """ Args: @@ -191,6 +197,7 @@ def update( request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) batch_dimensions: Configuration object containing real batch settings padded_batch_dimensions: Configuration object containing padded batch settings + num_speculative_tokens: Number of speculative tokens """ super().update( request_query_lengths, @@ -198,10 +205,11 @@ def update( request_to_kv_block_ids, batch_dimensions, padded_batch_dimensions, + num_speculative_tokens, ) if len(self.state_data["query_lengths"]) > 0: self.state_data["max_seqlen_q"] = torch.max(self.state_data["query_lengths"]).item() self.state_data["max_seqlen_k"] = torch.max(self.state_data["kv_seq_lengths"]).item() else: - self.state_data["max_seqlen_q"] = 1 + self.state_data["max_seqlen_q"] = num_speculative_tokens + 1 self.state_data["max_seqlen_k"] = 1 diff --git a/megatron/core/inference/contexts/dynamic_block_allocator.py b/megatron/core/inference/contexts/dynamic_block_allocator.py index abfb7278b14..5bbf7001094 100644 --- a/megatron/core/inference/contexts/dynamic_block_allocator.py +++ b/megatron/core/inference/contexts/dynamic_block_allocator.py @@ -85,19 +85,39 @@ def get_total_used(self): def get_active_used(self): """Compute number of active blocks used.""" - return ( - self.context.request_kv_block_counts[ - self.context.paused_request_count : self.context.total_request_count - ] - .sum() - .item() - ) + if not self.enable_prefix_caching: + return ( + self.context.request_kv_block_counts[ + self.context.paused_request_count : self.context.total_request_count + ] + .sum() + .item() + ) + + active_start = self.context.paused_request_count + active_end = self.context.total_request_count + if active_end > active_start: + active_rows = self.context.request_to_kv_block_ids[active_start:active_end] + valid_ids = active_rows[active_rows >= 0] + if valid_ids.numel() > 0: + return int(torch.unique(valid_ids).numel()) + return 0 def get_paused_used(self): """Compute number of paused blocks used.""" - return ( - self.context.request_kv_block_counts[: self.context.paused_request_count].sum().item() - ) + if not self.enable_prefix_caching: + return ( + self.context.request_kv_block_counts[: self.context.paused_request_count] + .sum() + .item() + ) + + if self.context.paused_request_count > 0: + paused_rows = self.context.request_to_kv_block_ids[: self.context.paused_request_count] + valid_ids = paused_rows[paused_rows >= 0] + if valid_ids.numel() > 0: + return int(torch.unique(valid_ids).numel()) + return 0 def get_active_avail(self): """Compute number of active blocks available.""" diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index dd7af272546..28559943481 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -279,6 +279,12 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC else: self.num_attention_heads_per_partition = 1 + self.num_speculative_tokens = inference_config.num_speculative_tokens + assert self.num_speculative_tokens < inference_config.block_size_tokens, ( + f"num_speculative_tokens ({self.num_speculative_tokens}) must be < " + f"block_size_tokens ({inference_config.block_size_tokens})" + ) + # Cache the PP group we should use for PP collectives inside the context. # If the model provides a pg_collection with a pp group, prefer it. # Otherwise: @@ -360,6 +366,15 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC math.prod(self.mamba_ssm_states_shape) * self.mamba_ssm_states_dtype.itemsize ) mamba_states_memory_per_request *= self.num_mamba_layers + if self.num_speculative_tokens > 0: + # Add memory for intermediate conv and SSM states + intermediate_memory_per_request = ( + math.prod(self.mamba_conv_states_shape) * self.mamba_conv_states_dtype.itemsize + + math.prod(self.mamba_ssm_states_shape) * self.mamba_ssm_states_dtype.itemsize + ) + intermediate_memory_per_request *= self.num_mamba_layers + intermediate_memory_per_request *= self.num_speculative_tokens + 1 + mamba_states_memory_per_request += intermediate_memory_per_request # Unified memory and general tensor management. self.unified_memory_level = inference_config.unified_memory_level @@ -532,12 +547,13 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=tp_size, num_cuda_graphs=inference_config.num_cuda_graphs, - cuda_graph_max_tokens=self.max_requests, + cuda_graph_max_tokens=self.max_requests * (self.num_speculative_tokens + 1), cuda_graph_mixed_prefill_request_count=inference_config.cuda_graph_mixed_prefill_count, max_requests=self.max_requests, max_tokens=self.max_tokens, max_sequence_length=self.max_sequence_length, use_cuda_graphs_for_non_decode_steps=self.use_cuda_graphs_for_non_decode_steps, + num_speculative_tokens=self.num_speculative_tokens, ) ) @@ -608,6 +624,9 @@ def _allocate_memory_buffer(self): def _allocate_mamba_states(self): """Allocate Mamba states for hybrid models.""" if self.is_hybrid_model: + self.mamba_metadata = MambaMetadata( + max_requests=self.max_requests, max_tokens=self.max_tokens + ) self.mamba_conv_states = torch.empty( (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, dtype=self.mamba_conv_states_dtype, @@ -618,6 +637,27 @@ def _allocate_mamba_states(self): dtype=self.mamba_ssm_states_dtype, device=torch.cuda.current_device(), ) + if self.num_speculative_tokens > 0: + self.mamba_intermediate_conv_states = torch.empty( + ( + self.num_mamba_layers, + self.max_requests, + self.num_speculative_tokens + 1, + *self.mamba_conv_states_shape, + ), + dtype=self.mamba_conv_states_dtype, + device=torch.cuda.current_device(), + ) + self.mamba_intermediate_ssm_states = torch.empty( + ( + self.num_mamba_layers, + self.max_requests, + self.num_speculative_tokens + 1, + *self.mamba_ssm_states_shape, + ), + dtype=self.mamba_ssm_states_dtype, + device=torch.cuda.current_device(), + ) if ( self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD and not self._uses_torch_memory_saver @@ -631,6 +671,19 @@ def _allocate_mamba_states(self): self._offloadable_cpu_backups["mamba_ssm_states"] = torch.empty_like( self.mamba_ssm_states, device="cpu" ).pin_memory() + if self.num_speculative_tokens > 0: + self._offloadable_tensor_names.add("mamba_intermediate_conv_states") + self._offloadable_cpu_backups["mamba_intermediate_conv_states"] = ( + torch.empty_like( + self.mamba_intermediate_conv_states, device="cpu" + ).pin_memory() + ) + self._offloadable_tensor_names.add("mamba_intermediate_ssm_states") + self._offloadable_cpu_backups["mamba_intermediate_ssm_states"] = ( + torch.empty_like( + self.mamba_intermediate_ssm_states, device="cpu" + ).pin_memory() + ) else: self.mamba_metadata = None @@ -655,6 +708,8 @@ def initialize_all_tensors(self) -> None: ) # request_query_lengths is the input prompt tokens length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) self.request_query_lengths = torch.empty_like(self.request_ids) + # True only for a new request , then after a forward pass it is set to False + self.request_in_prefill_status_tensor = torch.empty_like(self.request_ids) # request_output_lengths is len(input_prompt_tokens) + num_tokens_to_generate self.request_output_lengths = torch.empty_like(self.request_ids) # request_kv_length_offsets is the same as query length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) @@ -949,13 +1004,19 @@ def key_value_cache(self, layer_number: int) -> Tuple[Tensor, Optional[Tensor], self.active_attn_metadata["mha_metadata"].state_data["block_table"], ) - def mamba_states_cache(self, layer_number: int) -> Tuple[Tensor, Tensor]: + def mamba_states_cache( + self, layer_number: int, intermediate: bool = False + ) -> Tuple[Tensor, Tensor]: """Returns the Mamba state tensors for the given layer.""" assert self.is_hybrid_model, "Only hybrid models have Mamba state tensors" mamba_layer_number = self.layer_map[layer_number - 1] - conv_state = self.mamba_conv_states[mamba_layer_number] - ssm_state = self.mamba_ssm_states[mamba_layer_number] + if intermediate: + conv_state = self.mamba_intermediate_conv_states[mamba_layer_number] + ssm_state = self.mamba_intermediate_ssm_states[mamba_layer_number] + else: + conv_state = self.mamba_conv_states[mamba_layer_number] + ssm_state = self.mamba_ssm_states[mamba_layer_number] return (conv_state, ssm_state) @@ -1146,6 +1207,7 @@ def add_dummy_requests_parallel( self.request_ids[request_slice] = request_ids_tensor self.request_query_lengths[request_slice] = lengths_tensor + self.request_in_prefill_status_tensor[request_slice] = 1 self.request_output_lengths[request_slice] = lengths_tensor + tokens_to_generate_tensor self.request_kv_length_offsets[request_slice] = 0 self.request_kv_block_counts[request_slice] = block_counts @@ -1226,11 +1288,15 @@ def add_dummy_requests_for_cudagraph_capture( Adds dummy requests to reflect the number of prefill and decode requests in the graph config. These are using during cuda graph captures. """ - prefill_tokens = graph_dimensions.token_count - graph_dimensions.decode_req_count + prefill_tokens = graph_dimensions.token_count - ( + graph_dimensions.decode_req_count * (self.num_speculative_tokens + 1) + ) # Pre-construct shared objects (safe due to deep copy in DynamicInferenceRequest.__post_init__) shared_sampling_params = SamplingParams(num_tokens_to_generate=1, termination_id=-1) - shared_decode_tokens = torch.zeros(1, dtype=torch.long, device=torch.cuda.current_device()) + shared_decode_tokens = torch.zeros( + self.num_speculative_tokens + 1, dtype=torch.long, device=torch.cuda.current_device() + ) decode_requests = [ DynamicInferenceRequest( @@ -1292,31 +1358,44 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: pass can run without error. """ - smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) + smallest_cuda_graph_dimensions = min( + [x for x in self.cuda_graph_batch_dimensions_list if x.prefill_req_count == 0] + ) # the smallest cuda graph is decode only. assert smallest_cuda_graph_dimensions.prefill_req_count == 0 N = smallest_cuda_graph_dimensions.decode_req_count + tokens_per_request = self.num_speculative_tokens + 1 + T = smallest_cuda_graph_dimensions.token_count # N * tokens_per_request dummy_block_idx = self.block_allocator.dummy_block_idx - # 1. Request counts and token count (decode-only: 1 token per request). + # 1. Request counts and token count. + # With speculative decoding each decode request has (num_speculative_tokens + 1) tokens. self.total_request_count = N - self.active_token_count = N + self.active_token_count = T self.num_prefill_requests = 0 # 2. Per-request state consumed by mha_metadata.update(). - self.request_query_lengths[0:N].fill_(1) + self.request_query_lengths[0:N].fill_(tokens_per_request) self.request_kv_length_offsets[0:N].fill_(0) self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx # 3. Token-level state consumed by the triton KV append kernel. - self.token_to_block_idx[0:N] = dummy_block_idx - self.token_to_local_position_within_kv_block[0:N] = 0 + self.token_to_block_idx[0:T] = dummy_block_idx + self.token_to_local_position_within_kv_block[0:T] = ( + torch.arange(T, device=self.token_to_block_idx.device) % tokens_per_request + ) if self.is_hybrid_model: # 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models. - self.token_to_request_idx[0:N] = torch.arange( - 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype + self.token_to_request_idx[0:T] = torch.repeat_interleave( + torch.arange( + 0, + N, + device=self.token_to_request_idx.device, + dtype=self.token_to_request_idx.dtype, + ), + tokens_per_request, ) # 5. Mamba state: allocate slots for dummy requests. @@ -1333,8 +1412,10 @@ def initialize_attention_state( """Initialize attention state so that every layer can use it. Args: - construct_graph_dimensions (Optional[InferenceBatchDimensions]): The graph config to use for constructing the cuda graphs. - is_expert_parallel_dummy_cuda_graph_step (bool): Whether this is a dummy expert model parallel step. + construct_graph_dimensions (Optional[InferenceBatchDimensions]): + The graph config to use for constructing the cuda graphs. + is_expert_parallel_dummy_cuda_graph_step (bool): + Whether this is a dummy expert model parallel step. Return: None. """ @@ -1381,16 +1462,22 @@ def initialize_attention_state( if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph else: - padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): - padded_token_count = min( - self.max_tokens, - self.max_requests, - self.round_up_tokens(self.active_token_count), - ) - padded_decode_req_count = padded_token_count + if self.num_speculative_tokens > 0: + padded_decode_req_count = min( + self.max_requests, self.round_up_requests(self.num_decode_requests) + ) + padded_token_count = padded_decode_req_count * (self.num_speculative_tokens + 1) + else: + padded_token_count = min( + self.max_tokens, + self.max_requests, + self.round_up_tokens(self.active_token_count), + ) + padded_decode_req_count = padded_token_count padded_prefill_req_count = 0 else: + padded_token_count = self.round_up_tokens(self.active_token_count) target_padding_req_count = min( self.max_requests, self.round_up_requests(self.total_request_count - self.paused_request_count), @@ -1449,6 +1536,7 @@ def initialize_attention_state( request_to_kv_block_ids=request_to_kv_block_ids_view, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, + num_speculative_tokens=self.num_speculative_tokens, ) if self.is_hybrid_model: @@ -1484,6 +1572,7 @@ def reset_tensors(self) -> None: self.request_last_kv_block_id.fill_(-1) self.request_last_kv_block_offset.fill_(0) self.request_to_kv_block_ids.fill_(-1) + self.request_in_prefill_status_tensor.fill_(-1) # Reset request metadata. for metadata_tensor in self.request_metadata.values(): @@ -1518,6 +1607,7 @@ def reset_metadata(self) -> None: self.padded_active_token_count = 0 self.padded_active_request_count = 0 self.paused_tokens = None + self.paused_speculative_tokens = None # Reset attention, mamba, and block allocator state. self.reset_attention_state() @@ -1566,6 +1656,9 @@ def current_input_and_position_ids( (Tuple[Tensor, Tensor]) Flattened active input and position IDs. """ num_tokens = num_warmup_tokens or self.padded_active_token_count + assert num_tokens >= self.padded_batch_dimensions.decode_req_count * ( + self.num_speculative_tokens + 1 + ) return ( self.token_to_input_ids[:num_tokens].unsqueeze(0), self.token_to_pos_ids[:num_tokens].unsqueeze(0), @@ -1580,7 +1673,6 @@ def last_token_logits(self, logits: Tensor) -> Tensor: Return: (Tensor) Last token logits. """ - # todo: @lmcafee, remove these asserts? assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1" assert logits.size(1) == self.padded_active_token_count, ( @@ -1733,7 +1825,9 @@ def _find_matching_prefix_blocks( return matched_blocks, parent_hash def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] = None) -> None: - """Add request to context. At this stage, we assume that the request is valid and can be added, as the checks are done in the schedule function. + """ + Add request to context. At this stage, we assume that the request is valid and can be added, + as the checks are done in the schedule function. Args: req (DynamicInferenceRequest): Request to add. @@ -1796,9 +1890,9 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] # no need to update count, as it is already here if is_chunked_prefill: current_id = self.total_request_count - 1 - self.active_token_count -= ( - 1 # Overwrite the last token, which is the useless token from chunked prefill - ) + # Overwrite the last token, which is the useless token from chunked prefill + chunked_prefill_offset = 1 + self.num_speculative_tokens + self.active_token_count -= chunked_prefill_offset assert ( self.request_ids[current_id] == req.request_id ), "Continuation current_id mismatch" @@ -1832,6 +1926,7 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] # Handle length and block assignments. self.request_query_lengths[current_id] = effective_chunk_length + self.request_in_prefill_status_tensor[current_id] = 1 self.request_output_lengths[current_id] = ( req.finished_chunk_token_count + chunk_length @@ -1923,16 +2018,22 @@ def _register_range(start: int, end: int): self.total_request_count += 0 if req.finished_chunk_token_count > 0 else 1 self.num_prefill_requests += 1 - def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): + def _move_book_keeping_tensors( + self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens=None + ): """ Move all the relevent booking tensors with src idxs to dst idxs """ self.request_kv_length_offsets[dst_idxs] = self.request_kv_length_offsets[src_idxs] + self.request_in_prefill_status_tensor[dst_idxs] = self.request_in_prefill_status_tensor[ + src_idxs + ] self.request_query_lengths[dst_idxs] = self.request_query_lengths[src_idxs] self.request_output_lengths[dst_idxs] = self.request_output_lengths[src_idxs] self.request_ids[dst_idxs] = self.request_ids[src_idxs] - next_tokens[dst_idxs] = next_tokens[src_idxs] - + next_tokens[dst_idxs] = next_tokens[src_idxs] # num tokens sames as num samples + if new_speculative_tokens is not None: + new_speculative_tokens[:, dst_idxs] = new_speculative_tokens[:, src_idxs] self.request_to_kv_block_ids[dst_idxs] = self.request_to_kv_block_ids[src_idxs] self.request_kv_block_counts[dst_idxs] = self.request_kv_block_counts[src_idxs] self.request_last_kv_block_id[dst_idxs] = self.request_last_kv_block_id[src_idxs] @@ -1946,12 +2047,15 @@ def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): self.mamba_metadata.request_to_mamba_state_idx[src_idxs] ) - def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): + def _swap_book_keeping_tensors( + self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens=None + ): """ Swaps all the relevent booking tensors with src idxs to dst idxs """ tensor_swap(self.request_kv_length_offsets, src_idxs, dst_idxs) tensor_swap(self.request_query_lengths, src_idxs, dst_idxs) + tensor_swap(self.request_in_prefill_status_tensor, src_idxs, dst_idxs) tensor_swap(self.request_output_lengths, src_idxs, dst_idxs) tensor_swap(self.request_ids, src_idxs, dst_idxs) tensor_swap(next_tokens, src_idxs, dst_idxs) @@ -1960,6 +2064,11 @@ def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): tensor_swap(self.request_last_kv_block_id, src_idxs, dst_idxs) tensor_swap(self.request_last_kv_block_offset, src_idxs, dst_idxs) + if new_speculative_tokens is not None: + # new_speculative_tokens has request dimension as second dimension, + # so swap on transposed view + tensor_swap(new_speculative_tokens.t(), src_idxs, dst_idxs) + for metadata_tensor in self.request_metadata.values(): tensor_swap(metadata_tensor, src_idxs, dst_idxs) @@ -2003,10 +2112,7 @@ def release_memory_blocks_from_request_indexes(self, request_indexes) -> None: self.mamba_metadata.free_slots(request_indexes) def resume_paused_requests( - self, - active_request_count: int, - newly_paused_request_ids: torch.Tensor, - next_tokens: torch.Tensor, + self, active_request_count: int, newly_paused_request_ids: torch.Tensor ) -> tuple[int, torch.Tensor]: """Resume as many paused requests as we have space for in the active buffer. @@ -2024,55 +2130,60 @@ def resume_paused_requests( resume_request_count = 0 if self.paused_request_count > 0: active_block_count_avail = self.block_allocator.get_active_avail() + # Clone not needed: flip() makes a copy. paused_block_counts = self.request_kv_block_counts[: self.paused_request_count] # Flip counts before cumsum, since paused requests are resumed from # the right-most index, so we must count resumed blocks starting from # the right side. paused_block_counts = paused_block_counts.flip(dims=[0]) - # Add +1 to all block counts, since any time a paused request is - # resumed, it will be starting a new memory block. For background, - # pausing happens after a request has generated the final token of a - # memory block (i.e., token 256 of that block), which means the very - # next token (whenever that request gets unpaused) will be in a new - # block. So, when we resume a paused request, we have to account for - # the fact that it will need an extra block beyond the ones that it - # has already used. - paused_block_counts += 1 # +1 for newly added block + + # Check which paused requests will actually need a new block upon resuming + offsets = self.request_last_kv_block_offset[: self.paused_request_count] + needs_new_block = ( + offsets >= self.block_size_tokens - 1 - self.num_speculative_tokens + ).to(paused_block_counts.dtype) + needs_new_block = needs_new_block.flip(dims=[0]) + + # Add +1 ONLY to the block counts of requests that finished their previous memory block + paused_block_counts += needs_new_block paused_block_counts_cumsum = paused_block_counts.cumsum(dim=0) resume_request_count = min( torch.nonzero(paused_block_counts_cumsum <= active_block_count_avail).numel(), self.block_allocator.total_avail, ) + # Constrain resumptions by the maximum allowed active requests and tokens + max_allowed_active = min( + self.max_requests, self.max_tokens // (self.num_speculative_tokens + 1) + ) + allowed_to_resume = max(0, max_allowed_active - active_request_count) + resume_request_count = min(resume_request_count, allowed_to_resume) + self.paused_request_count -= resume_request_count active_request_count += resume_request_count # Resume requests by assigning blocks and updating bookkeeping tensors. if resume_request_count > 0: - assert torch.all( - self.request_last_kv_block_offset[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] - == self.block_size_tokens - 1 - ), "The request_last_kv_block_offset should be 0 for the requests that just got resumed this step." + resume_start = self.paused_request_count + resume_end = self.paused_request_count + resume_request_count - assert resume_request_count <= self.block_allocator.total_avail - block_ids = self.block_allocator.allocate_memory_blocks(resume_request_count) - row_idx = torch.arange( - self.paused_request_count, - self.paused_request_count + resume_request_count, - device=torch.cuda.current_device(), - ) - col_idx = self.request_kv_block_counts[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] - self.request_to_kv_block_ids[row_idx, col_idx] = block_ids - self.request_kv_block_counts[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] += 1 - self.request_last_kv_block_id[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] = block_ids + # Check which resumed requests actually need a new block + offsets = self.request_last_kv_block_offset[resume_start:resume_end] + needs_new_block = offsets >= (self.block_size_tokens - 1 - self.num_speculative_tokens) + num_new_blocks = needs_new_block.sum().item() + + if num_new_blocks > 0: + assert num_new_blocks <= self.block_allocator.total_avail + block_ids = self.block_allocator.allocate_memory_blocks(num_new_blocks) + + # Apply updates only to the requests that required a new block + relative_row_idx = torch.nonzero(needs_new_block).squeeze(1) + row_idx = resume_start + relative_row_idx + col_idx = self.request_kv_block_counts[row_idx] + + self.request_to_kv_block_ids[row_idx, col_idx] = block_ids + self.request_kv_block_counts[row_idx] += 1 + self.request_last_kv_block_id[row_idx] = block_ids # Remove resumed requests from newly_paused_request_ids. We do this by # truncating the end of newly_paused_request_ids, which works because we @@ -2085,7 +2196,10 @@ def resume_paused_requests( return active_request_count, newly_paused_request_ids def evict_overflow_paused_requests( - self, active_request_count: int, next_tokens: torch.Tensor + self, + active_request_count: int, + next_tokens: torch.Tensor, + new_speculative_tokens: Optional[torch.Tensor] = None, ) -> Optional[tuple[torch.Tensor, torch.Tensor]]: """Evict requests that overflow the paused buffer. @@ -2140,6 +2254,8 @@ def evict_overflow_paused_requests( evict_request_idxs = torch.arange( evict_start_idx, evict_end_idx, device=torch.cuda.current_device() ) + # Clone needed: subsequent release_memory_blocks_from_request_indexes and + # _swap_book_keeping_tensors calls mutate self.request_ids in place. evict_request_ids = self.request_ids[evict_start_idx:evict_end_idx].clone() # Release memory. @@ -2174,7 +2290,10 @@ def evict_overflow_paused_requests( # Swap evicted and active requests. self._swap_book_keeping_tensors( - src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens + src_idxs=src_idxs, + dst_idxs=dst_idxs, + next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # Update tracking vars. @@ -2191,7 +2310,12 @@ def evict_overflow_paused_requests( return evict_request_ids - def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> Tensor: + def update_requests( + self, + active_requests_mask: Tensor, + new_tokens: Tensor, + new_speculative_tokens: Tensor = None, + ) -> Tensor: """Update context state after calling engine.step(). This method is responsible for: @@ -2224,8 +2348,11 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T 8. We make relevant changes to the token bookkeeping tensors Args: - active_requests_mask (Tensor): 1D Mask tensor marking active requests. - new_tokens (Tensor): Newly sampled tokens, with one token per active request. + active_requests_mask (Tensor): 1D Mask tensor marking active requests. (Active request length) + new_tokens (Tensor): Newly sampled tokens, with one token per active request. (Active request length) + new_speculative_tokens (Tensor): Newly sampled speculative tokens, + with num_speculative tokens per active request. + (num_speculative_tokens, active_request_length) Return: (Tensor) Newly paused request IDs. @@ -2236,6 +2363,10 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T # finished_request_count are requests that have reached the termination criterion self.num_prefill_requests = 0 # all turns to decode + # All request that were in prefill become decode requests + self.request_in_prefill_status_tensor[self.request_in_prefill_status_tensor == 1] = ( + 0 # TODO : Check how this works with chunked prefill + ) if self.chunked_prefill_request_id != -1: active_requests_mask[-1] = ( 1 # must keep this, next iteration will add a new chunk to it @@ -2276,6 +2407,10 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T if self.paused_request_count != 0: assert self.paused_tokens is not None next_tokens = torch.cat((self.paused_tokens, new_tokens)) + if new_speculative_tokens is not None and self.paused_speculative_tokens is not None: + new_speculative_tokens = torch.cat( + (self.paused_speculative_tokens, new_speculative_tokens), dim=1 + ) else: next_tokens = new_tokens @@ -2306,6 +2441,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T src_idxs=active_idxs_on_right, dst_idxs=finished_idxs_on_left, next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # Reset chunk ids for recently moved requests. @@ -2323,7 +2459,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T self.paused_request_count : (active_request_count + self.paused_request_count) ] active_requests_requiring_new_block = ( - num_tokens_in_last_block == self.block_size_tokens - 1 + num_tokens_in_last_block >= self.block_size_tokens - 1 - self.num_speculative_tokens ).byte() if self.chunked_prefill_request_id != -1: @@ -2331,6 +2467,13 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T active_requests_requiring_new_block[ self.get_index_of_chunked_prefill_request() - self.paused_request_count ] = 0 # chunked prefill should not be paused + else: + max_allowed_active = min( + self.max_requests, self.max_tokens // (self.num_speculative_tokens + 1) + ) + if active_request_count > max_allowed_active: + # Force-pause excess requests in a decode-only batch + active_requests_requiring_new_block[max_allowed_active:] = 1 active_requests_requiring_new_block_count = ( (active_requests_requiring_new_block == 1).sum().item() @@ -2371,7 +2514,10 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T dst_idxs = torch.cat((active_request_ids_on_left, paused_requests_idxs_on_right)) src_idxs = torch.cat((paused_requests_idxs_on_right, active_request_ids_on_left)) self._move_book_keeping_tensors( - src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens + src_idxs=src_idxs, + dst_idxs=dst_idxs, + next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) self.paused_request_count += active_requests_requiring_new_block_count @@ -2380,17 +2526,30 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T # 6. Now that we have the requests in following order [Paused, Active, Finished] # We determine how many requests we can resume and resume them + # For multi-token generation: store previous block IDs BEFORE resume allocates new blocks. + # This allows us to know which block tokens should go to if they don't cross the boundary. + # After resume_paused_requests, request_last_kv_block_id will be updated to the NEW block + # for resumed requests, but we need the OLD block for tokens that don't cross. + prev_last_block_ids = None + if self.num_speculative_tokens > 0: + # Clone needed: resume_paused_requests mutates request_last_kv_block_id + # (assigns new block IDs), but we need the old values later to determine + # which block tokens should go to when they don't cross a block boundary. + prev_last_block_ids = self.request_last_kv_block_id.clone() + # 6.a. First, resume temporarily paused requests. active_request_count, newly_paused_request_ids = self.resume_paused_requests( - active_request_count, newly_paused_request_ids, next_tokens + active_request_count, newly_paused_request_ids ) # 6.b. Evict requests that overflow the paused buffer. - evict_request_ids = self.evict_overflow_paused_requests(active_request_count, next_tokens) + evict_request_ids = self.evict_overflow_paused_requests( + active_request_count, next_tokens, new_speculative_tokens + ) # 6.c. Resume any additional requests. active_request_count, newly_paused_request_ids = self.resume_paused_requests( - active_request_count, newly_paused_request_ids, next_tokens + active_request_count, newly_paused_request_ids ) assert active_request_count > 0, "active_request_count == %d." % active_request_count @@ -2402,50 +2561,159 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]), dst_idxs=torch.tensor([self.total_request_count - 1]), next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # 7. We make changes to the request book keeping tesnsors and setup the tokens for next iteration assert self.total_request_count == active_request_count + self.paused_request_count - # All these active requests are in decode phase, so they need only 1 token per request - self.active_token_count = active_request_count - # Always the first section of token input ids are only used. - self.token_to_input_ids[: self.active_token_count] = next_tokens[ - self.paused_request_count : self.total_request_count - ] - if self.paused_request_count > 0: - self.paused_tokens = next_tokens[: self.paused_request_count] - - # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_) - # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) + # Clone needed: next_tokens is a shared buffer that will be overwritten in + # the next iteration; paused_tokens must persist independently. + self.paused_tokens = next_tokens[: self.paused_request_count].clone() + if new_speculative_tokens is not None: + # Clone needed: same reason as paused_tokens above. + self.paused_speculative_tokens = new_speculative_tokens[ + :, : self.paused_request_count + ].clone() + + # add_ and fill_ calls seems to work as intended with sliced indexing + # (i.e. x[3:5].add(...) or x[3:5].fill_) but when another tensor is used + # for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) self.request_kv_length_offsets[self.paused_request_count : self.total_request_count].add_( self.request_query_lengths[self.paused_request_count : self.total_request_count] ) - self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_(1) - self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ + + num_generated_tokens = 1 + self.num_speculative_tokens + self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_( + num_generated_tokens + ) + + # Clone needed: old_offsets is reused later to compute raw_positions + # for block-boundary detection. The write-back on the next line overwrites the + # underlying tensor, so without clone the boundary-crossing logic would see the + # new offsets instead of the pre-update values. + old_offsets = self.request_last_kv_block_offset[ self.paused_request_count : self.total_request_count - ] + ].clone() self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] = ( - self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] - + 1 + old_offsets + num_generated_tokens ) % self.block_size_tokens - # 8. We make relevant changes to the token bookkeeping tensors + self.active_token_count = active_request_count * num_generated_tokens + sampled_tokens = next_tokens[self.paused_request_count : self.total_request_count] + + if self.num_speculative_tokens > 0: + # new_speculative_tokens has shape [num_spec_tokens, num_requests], + # slice the request dimension (dim 1) + sampled_speculative_tokens = new_speculative_tokens[ + :, self.paused_request_count : self.total_request_count + ] + # This will become [sampled, spec1, spec2, sampled, spec1, spec2 ...] + # For every request we will have the sampled token followed by the + # speculative tokens (i.e next indices) + next_tokens = torch.vstack( + [sampled_tokens.unsqueeze(0), sampled_speculative_tokens] + ).T.reshape(-1) + else: + next_tokens = sampled_tokens + + self.token_to_input_ids[: self.active_token_count] = next_tokens + + # Req kv length offsets : [0, 5, 10 ... ] + # For num spec tokens = 2 , this will become [0, 1, 2, 5, 6, 7 10, 11, 12 ...] + self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ + self.paused_request_count : self.total_request_count + ].repeat_interleave(num_generated_tokens) + torch.arange( + num_generated_tokens, device=torch.cuda.current_device() + ).repeat( + active_request_count + ) + # + # Token to request idx : [0, 0, 0, 1, 1, 1, 2, 2, 2 ...] self.token_to_request_idx[: self.active_token_count] = torch.arange( self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() - ) - self.token_to_position_in_request[: self.active_token_count] = ( - self.request_kv_length_offsets[self.paused_request_count : self.total_request_count] + ).repeat_interleave(num_generated_tokens) + + self.token_to_position_in_request[: self.active_token_count] = self.token_to_pos_ids[ + : self.active_token_count + ] + + self.token_to_local_position_within_kv_block[: self.active_token_count] = ( + self.token_to_pos_ids[: self.active_token_count] % self.block_size_tokens ) - self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ + current_block_ids = self.request_last_kv_block_id[ self.paused_request_count : self.total_request_count ] - self.token_to_local_position_within_kv_block[: self.active_token_count] = ( - self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] + + # raw positions shape : [active_request_count, num_generated_tokens] + # e.g block size 6, old_offsets = [1,5,2] , num_generated_tokens = 3 + # raw_positions = [[1, 2, 3], [5, 6, 7], [2, 3, 4]] + # crosses_boundary = [[False, False, False], [False, True, True], [False, False, False]] + raw_positions = ( + old_offsets[:, None] + + 1 # Offset by 1 because old_offsets points to the LAST token + + torch.arange(num_generated_tokens, device=torch.cuda.current_device())[None, :] ) + # + # A token crosses to the next block if its raw_position >= block_size + crosses_boundary = raw_positions >= self.block_size_tokens + + if not crosses_boundary.any() or self.num_speculative_tokens == 0: + # Fast path: no tokens cross block boundary, all use current block + self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ + self.paused_request_count : self.total_request_count + ].repeat_interleave(num_generated_tokens) + else: + + # Some tokens cross to the next block (this happens for resumed requests) + # + # When a request is paused and resumed: + # 1. It was paused because remaining_space < num_tokens_per_step + # 2. A NEW block is allocated in resume_paused_requests + # 3. request_last_kv_block_id is updated to the NEW block + # 4. The old offset is preserved (wasn't reset) + # + # So for resumed requests: + # - Tokens before the boundary (raw_pos < block_size): go to PREVIOUS block + # - Tokens at/after the boundary (raw_pos >= block_size): go to CURRENT (new) block + # + # For non-resumed requests (no boundary crossing): all go to current block + # + # We use prev_last_block_ids which was stored BEFORE resume_paused_requests + # was called, so it contains the OLD block IDs before new blocks were allocated. + + # Get previous block IDs (stored before resume_paused_requests) + prev_block_ids = prev_last_block_ids[ + self.paused_request_count : self.total_request_count + ] # [active_count] + + # For each request, check if ANY token crosses (i.e., request was resumed) + request_has_crossing = crosses_boundary.any(dim=1) # [active_count] + + # Build block_idx: [active_count, N] + # Start with current (new) block for all + # Lets say current block ids is [a1, a2 , a3] and num generated_tokens is 3 + # This will be [[a1, a1, a1], [a2, a2, a2], [a3, a3, a3]] + # No clone needed: expand() returns a read-only view, and downstream + # torch.where() and .flatten() both return new tensors without in-place mutation. + block_idx = current_block_ids[:, None].expand( + -1, num_generated_tokens + ) # [active_count, N] + + # For requests that have crossing, tokens BEFORE boundary use prev block + # crosses_boundary is False for tokens before boundary + # So: where request_has_crossing AND NOT crosses_boundary, use prev_block + use_prev_block = request_has_crossing[:, None] & ~crosses_boundary # [active_count, N] + + # Apply previous block IDs where needed + prev_block_ids_expanded = prev_block_ids[:, None].expand(-1, num_generated_tokens) + block_idx = torch.where(use_prev_block, prev_block_ids_expanded, block_idx) + + # Convert back to 1d tensor + self.token_to_block_idx[: self.active_token_count] = block_idx.flatten() return { "newly_paused_request_ids": newly_paused_request_ids, diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index 60ca06819e7..a9b2445b5e5 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -389,14 +389,12 @@ def start(self): return if request_hashes: - self._update_rank_hashes(next_data_parallel_rank_identity, request_hashes) + self._update_rank_hashes(next_identity, request_hashes) if self.schedule_records is not None: self.schedule_records.append( { "request_id": request_id, - "rank_index": self.identity_to_rank_index[ - next_data_parallel_rank_identity - ], + "rank_index": self.identity_to_rank_index[next_identity], "num_hashes": len(request_hashes), } ) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 88e0f31b7b6..102aaa716a7 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -202,15 +202,28 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen # Initialization options. self.controller = controller self.context = context + + self.num_speculative_tokens = inference_config.num_speculative_tokens + self.materialize_only_last_token_logits = ( + inference_config.materialize_only_last_token_logits + ) + + assert self.num_speculative_tokens >= 0, "Number of speculative tokens must be non-negative" + + if self.num_speculative_tokens > 0: + assert ( + self.num_speculative_tokens <= self.controller.num_mtp_heads + ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" + assert ( + not self.materialize_only_last_token_logits + ), "materialize_only_last_token_logits must be False when num_speculative_tokens > 0" + self.track_paused_request_events = inference_config.track_paused_request_events self.track_generated_token_events = inference_config.track_generated_token_events self.enable_chunked_prefill = inference_config.enable_chunked_prefill self.metrics_writer = inference_config.metrics_writer self.logging_step_interval = inference_config.logging_step_interval self.unified_memory_level = inference_config.unified_memory_level - self.materialize_only_last_token_logits = ( - inference_config.materialize_only_last_token_logits - ) self.cuda_graph_impl = model_config.cuda_graph_impl self.cuda_graph_scope = model_config.cuda_graph_scope # Initialize engine. @@ -283,6 +296,11 @@ def reset(self) -> None: self.resume_request_ids = None + # Speculative decoding acceptance tracking. + self._spec_tokens_proposed = 0 + self._spec_tokens_accepted = 0 + self._spec_steps = 0 + # Prefix caching coordination state. self._prefix_coordination_waits = 0 @@ -927,6 +945,7 @@ def post_process_requests( evict_request_ids: torch.Tensor, step_time: float, sample: torch.Tensor, + accepted_tokens: torch.Tensor, log_probs: torch.Tensor, top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None, routing_indices_per_request: Optional[Dict[int, torch.Tensor]] = None, @@ -939,7 +958,8 @@ def post_process_requests( finished_request_ids (torch.Tensor): A list of finished request ids evict_request_ids (torch.Tensor): A list of evicted request ids. step_time (float): The latency of the last step - sample: (torch.Tensor): The newly generated tokens for each request + sample: Tensor: The newly generated token for each request + accepted_tokens: Tensor: The additional accepted tokens for each request log_probs: (List): Log probs for each request top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to list of (top_n_logprobs, top_n_indices) tuples. @@ -970,49 +990,100 @@ def post_process_requests( blocks_hashed_active = blocks_allocated blocks_ref_count = None - for req_idx, (request_id, token, request_log_probs) in enumerate( - zip(request_ids.tolist(), sample.tolist(), log_probs_iter) + # When accepted_tokens is None (no speculative decoding), use repeat([]) to provide + # empty lists for each request, so the zip produces the correct number of iterations + accepted_tokens_iter = repeat([]) if accepted_tokens is None else accepted_tokens.tolist() + + if self.num_speculative_tokens > 0 and accepted_tokens is not None: + self._spec_steps += 1 + + for req_idx, (request_id, tokens, accepted_tokens_list, request_log_probs) in enumerate( + zip(request_ids.tolist(), sample.tolist(), accepted_tokens_iter, log_probs_iter) ): + + # Ensure tokens is always a list for consistent handling + if not isinstance(tokens, list): + tokens = [tokens] + request: DynamicInferenceRequest = self.get_request(request_id) + + if self.num_speculative_tokens > 0: + accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) + + # Track acceptance statistics for logging (decode requests only). + # Prefill requests don't propose speculative tokens, so including + # them would inflate the proposed count and deflate the rate. + # A request in its first generation step (empty generated_tokens) + # was in prefill this step. + if len(request.generated_tokens) > 0: + self._spec_tokens_proposed += self.num_speculative_tokens + self._spec_tokens_accepted += len(accepted_tokens) + + # The order `accepted_tokens + tokens` is correct here. + # `accepted_tokens` contains the sequence of + # successfully verified draft tokens. `tokens` (from `sample`) is the + # brand new token generated by the target model based on that accepted prefix. + # Therefore, the newly sampled token must go at the end of the sequence. + tokens = accepted_tokens + tokens + + num_stop_word_trim = 0 if request_id != self.context.chunked_prefill_request_id: # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) + # If the request already has more tokens, then we only append as much as is necessary + if ( + len(request.generated_tokens) + len(tokens) + >= request.sampling_params.num_tokens_to_generate + ): + tokens = tokens[ + : request.sampling_params.num_tokens_to_generate + - len(request.generated_tokens) + ] if request_id not in self.stop_word_being_finished_ids: is_first_token = len(request.generated_tokens) == 0 - request.generated_tokens.append(token) + request.generated_tokens += tokens + first_token_event = None if self.track_generated_token_events: - if block_allocator.enable_prefix_caching: - event_generated_token = request.add_event_generated_token( - token, - blocks_total=block_allocator.total_count, - blocks_hashed_total=blocks_allocated, - blocks_hashed_active=blocks_hashed_active, - blocks_ref_count=blocks_ref_count, - ) - else: - event_generated_token = request.add_event_generated_token( - token, - blocks_total=block_allocator.total_count, - blocks_hashed_total=blocks_allocated, - blocks_hashed_active=blocks_hashed_active, - ) + for token in tokens: + if block_allocator.enable_prefix_caching: + event = request.add_event_generated_token( + token, + blocks_total=block_allocator.total_count, + blocks_hashed_total=blocks_allocated, + blocks_hashed_active=blocks_hashed_active, + blocks_ref_count=blocks_ref_count, + ) + else: + event = request.add_event_generated_token( + token, + blocks_total=block_allocator.total_count, + blocks_hashed_total=blocks_allocated, + blocks_hashed_active=blocks_hashed_active, + ) + if first_token_event is None: + first_token_event = event if is_first_token: - if self.track_generated_token_events: - first_token_event = event_generated_token - else: + if not self.track_generated_token_events: first_token_event = DynamicInferenceEvent( type=DynamicInferenceEventType.GENERATED_TOKEN, - payload={"token_id": token}, + payload={"token_id": tokens[0]}, ) request.ttft = ( first_token_event.timestamp - request.event_add_engine.timestamp ) if request.tpot is None: request.tpot = [] - request.tpot.append(step_time) - - # Check for stop words (after token is appended) - stop_word_hit = self._check_stop_words_for_request_post_append(request) + per_token_step_time = step_time / len(tokens) + request.tpot.extend([per_token_step_time] * len(tokens)) + + # Check for stop words (after token is appended). + # With speculative decoding, a stop word may end before the last + # appended token. The check truncates generated_tokens in-place and + # returns how many trailing tokens were removed so we can also trim + # the corresponding log probs below. + stop_word_hit, num_stop_word_trim = self._check_stop_words_for_request_post_append( + request + ) if request_id in finished_request_ids: # Request finished by normal means (termination_id, max_length, or stop word from previous step) @@ -1037,6 +1108,14 @@ def post_process_requests( # Additionally, chunked prefill request do not finish. active_request_ids.append(request_id) + # When a stop word was found mid-speculative-batch, trim log probs + # and top_n_logprobs to match the truncated generated_tokens. + if num_stop_word_trim > 0: + if request_log_probs is not None: + request_log_probs = request_log_probs[:-num_stop_word_trim] + if top_n_logprobs is not None and req_idx in top_n_logprobs: + top_n_logprobs[req_idx] = top_n_logprobs[req_idx][:-num_stop_word_trim] + # Process log_probs if available (unified for both regular and chunked prefill) if request_log_probs is not None: # Initialize lists if they don't exist @@ -1060,8 +1139,16 @@ def post_process_requests( # Handle skip_prompt_log_probs during prefill # If skip_prompt_log_probs is True and we have multiple log probs (prefill), - # only process the last one (first generated token) - if request.sampling_params.skip_prompt_log_probs and len(request_log_probs) > 1: + # only process the last one (first generated token). + # With speculative decoding, decode steps also produce multiple log probs + # (one per accepted token + new sample), so we must check that this is + # actually a prefill step (no generated log probs accumulated yet). + is_prefill_log_probs = len(request.generated_log_probs) == 0 + if ( + request.sampling_params.skip_prompt_log_probs + and len(request_log_probs) > 1 + and is_prefill_log_probs + ): # Only append the last log prob (first generated token) to generated_log_probs request.generated_log_probs.append(request_log_probs[-1]) else: @@ -1177,33 +1264,50 @@ def _get_and_clear_stop_word_finished_ids(self, active_request_ids: list[int]) - self.stop_word_finished_request_ids -= result return result - def _check_stop_words_for_request_post_append(self, request: DynamicInferenceRequest) -> bool: + def _check_stop_words_for_request_post_append( + self, request: DynamicInferenceRequest + ) -> Tuple[bool, int]: """Check if a request should stop due to stop words (after token is appended). This method is called from post_process_requests after the token has already - been appended to request.generated_tokens. + been appended to request.generated_tokens. In the speculative decoding case, + multiple tokens may have been appended at once. If a stop word is found in the + middle of the speculative tokens, the trailing tokens after the stop word are + truncated from generated_tokens. + + With speculative decoding, multiple tokens are appended at once. The stop word + may end before the last appended token, leaving extra tokens that must be + trimmed. When this happens, generated_tokens is truncated in-place and the + number of trimmed tokens is returned so the caller can also trim log probs. Args: request: The request to check. Returns: - bool: True if the generated sequence ends with a stop word, False otherwise. + Tuple of (stop_word_hit, num_tokens_trimmed): + stop_word_hit: True if the generated sequence contains a stop word. + num_tokens_trimmed: Number of tokens removed from the end of + generated_tokens (0 when the stop word is at the very end + or when no stop word was found). """ - # Check if request has stop words configured if request.stop_word_ids is None or len(request.stop_word_ids) == 0: - return False + return False, 0 generated_tokens = request.generated_tokens - # Check if the sequence ends with any stop word for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: - # Check if the last stop_len tokens match the stop word - if generated_tokens[-stop_len:] == stop_word_ids: - return True - - return False + # Check the last stop_len tokens shifting by 1 up to num_speculative_tokens. + # Speculative decoding can append multiple tokens at once, so the stop + # word might end at any position within the newly appended tokens. + for i in range(self.num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + if i > 0: + request.generated_tokens = request.generated_tokens[:-i] + return True, i + return False, 0 def get_prefix_coordination_metrics(self) -> dict: """Return prefix caching coordination metrics. @@ -1469,6 +1573,7 @@ async def async_bookkeep( newly_paused_request_ids = step_result.get("newly_paused_request_ids") evict_request_ids = step_result.get("evict_request_ids") sample = step_result["sample"] + accepted_tokens = step_result["accepted_tokens"] log_probs = step_result["log_probs"] top_n_logprobs = step_result.get("top_n_logprobs", None) routing_indices_per_request = step_result.get("routing_indices_per_request", None) @@ -1486,6 +1591,7 @@ async def async_bookkeep( evict_request_ids, step_time, sample, + accepted_tokens, log_probs, top_n_logprobs, routing_indices_per_request, @@ -1553,6 +1659,14 @@ async def async_bookkeep( else: metrics[f'inference/{key}'] = value + # Add speculative decoding acceptance metrics. + if self.num_speculative_tokens > 0 and self._spec_tokens_proposed > 0: + acceptance_rate = self._spec_tokens_accepted / self._spec_tokens_proposed + metrics['inference/spec_decode_acceptance_rate'] = float(acceptance_rate * 100.0) + metrics['inference/spec_decode_tokens_proposed'] = int(self._spec_tokens_proposed) + metrics['inference/spec_decode_tokens_accepted'] = int(self._spec_tokens_accepted) + metrics['inference/spec_decode_num_steps'] = int(self._spec_steps) + if HAVE_WANDB and self.metrics_writer.__name__ == "wandb": self.metrics_writer.log(metrics, commit=True) else: @@ -1602,10 +1716,24 @@ async def async_bookkeep( mem["reserved_bytes.all.current"] / (1024**3), ) ) + if self.num_speculative_tokens > 0 and self._spec_tokens_proposed > 0: + spec_rate = self._spec_tokens_accepted / self._spec_tokens_proposed * 100.0 + output_str += " ... spec: accept %.1f%% (%d/%d in %d steps)" % ( + spec_rate, + self._spec_tokens_accepted, + self._spec_tokens_proposed, + self._spec_steps, + ) if context_state["is_decode_only"]: output_str = f"\033[94m{output_str}\033[0m" logging.info(output_str) + # Reset speculative decoding accumulators after both wandb and console logging. + if self.num_speculative_tokens > 0: + self._spec_tokens_proposed = 0 + self._spec_tokens_accepted = 0 + self._spec_steps = 0 + return { "active_request_ids": active_request_ids, "finished_request_records": finished_request_records, diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 7e50f58e3e6..e94cd29692d 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -6,7 +6,7 @@ import functools import inspect from collections import defaultdict -from typing import Any, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union import torch import torch.nn.functional as F @@ -62,6 +62,7 @@ def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, token self.model_config = self.inference_wrapped_model.model.config inference_config = self.inference_wrapped_model.inference_context.config self.tokenizer = tokenizer + self.num_speculative_tokens = inference_config.num_speculative_tokens pg_collection = inference_config.pg_collection if pg_collection is not None: @@ -80,11 +81,19 @@ def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, token self.vocab_size = unwrapped_model.vocab_size self.sampling_rng = torch.Generator(device=torch.cuda.current_device()) + self.num_mtp_heads = self._get_mtp_num_heads() self.sampling_rng.manual_seed(self.model_config.inference_sampling_seed) if self.inference_wrapped_model.inference_context.is_dynamic_batching(): self._init_dynamic_sampling_tensors() + def _get_mtp_num_heads(self) -> int: + """Get the number of MTP layers from the model config.""" + model = self.inference_wrapped_model.model + if hasattr(model, 'config') and hasattr(model.config, 'mtp_num_layers'): + return model.config.mtp_num_layers or 0 + return 0 + def set_stop_word_finished_ids_callback(self, callback): """Set a callback to get request IDs that should be marked as finished due to stop words. @@ -109,6 +118,12 @@ def _init_dynamic_sampling_tensors(self): self._sampling_backend = "torch" self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) + # Speculative tokens tensor will be allocated later when num_speculative_tokens is set by the engine + self._accepted_tokens_per_request = None + # MTP tensor will be allocated later when num_speculative_tokens is set by the engine + self._sampled_mtp_tokens_cuda = None + # Last accepted sequence indices for serial MTP computation + self._last_accepted_seq_indices = None # Keep track of request metadata. self._request_metadata: Dict[str, Tensor] = {} @@ -122,7 +137,25 @@ def _init_dynamic_sampling_tensors(self): # Used for inefficient torch sampling. if self._sampling_backend == "torch": - self._torch_sampling_buckets: Iterator[Tuple] = [] + self._torch_sampling_buckets: List[Tuple] = [] + + self._init_mtp_sampling_tensor() + + def _init_mtp_sampling_tensor(self): + """Initialize the MTP sampling tensor after num_speculative_tokens is set.""" + if self.num_speculative_tokens is not None and self.num_speculative_tokens > 0: + context = self.inference_wrapped_model.inference_context + max_requests = context.max_requests + device = torch.cuda.current_device() + self._sampled_mtp_tokens_cuda = torch.empty( + [self.num_speculative_tokens, max_requests], dtype=torch.int64, device=device + ) + self._accepted_tokens_per_request = ( + torch.ones( + [max_requests, self.num_speculative_tokens], dtype=torch.int64, device=device + ) + * -1 + ) def tokenize_prompt(self, prompt: str, add_BOS: bool = False) -> List[int]: """Utility to tokenize the input prompts. @@ -275,6 +308,8 @@ def modify_logits_for_top_p_filtering(logits, top_p): # in the original implementation: # https://github.com/ari-holtzman/degen/blob/master/gen.py # and I guess it is needed so keeping it for now. + # Clone needed: filter_[:, 1:] and filter_[:, :-1] are overlapping views; + # without clone, each write would corrupt the next read during the shift. filter_[:, 1:] = filter_[:, :-1].clone() # Make sure we at least have one token to select from. filter_[..., 0] = 0 @@ -287,6 +322,8 @@ def modify_logits_for_top_p_filtering(logits, top_p): if top_k == 1: sampled_logits = torch.argmax(last_token_logits, dim=-1) else: + # Clone needed: .div_() and masked_fill_() below modify in-place, + # which would mutate the caller's tensor without this clone. last_token_logits = last_token_logits.clone() if temperature != 1.0: last_token_logits.div_(temperature) @@ -589,13 +626,18 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) logits = self.inference_wrapped_model.run_one_forward_step( {"tokens": input_ids, "position_ids": position_ids, "attention_mask": None} ) + # logits shape: [1, seq_len, vocab_size] + + # Note: When speculative decoding is active (num_speculative_tokens > 0), + # the model skips MTP computation during the forward pass. MTP logits + # will be computed serially after verification to ensure they are + # conditioned on verified tokens only. if self.model_is_pipeline_parallel: - logits_seq_len = ( - active_request_count - if context.config.materialize_only_last_token_logits - else input_ids.shape[1] - ) + if context.config.materialize_only_last_token_logits: + logits_seq_len = active_request_count + else: + logits_seq_len = input_ids.shape[1] logits_shape = [1, logits_seq_len, self.vocab_size] if is_pipeline_last_stage(self.pp_group): @@ -618,26 +660,529 @@ def _dynamic_step_sample_bookkeeping(self): if self._sampling_backend == "torch": # Bucketize the core sampling parameters. # Doing so via list comprehension is orders of magnitude faster than via torch. - bucket_map = {} + bucket_map = defaultdict(list) # Shorthands for the dictionary comprehension. temp = self._request_metadata["temperature"][active_request_slice].tolist() top_k = self._request_metadata["top_k"][active_request_slice].tolist() top_p = self._request_metadata["top_p"][active_request_slice].tolist() - for i, (t, k, p) in enumerate(zip(temp, top_k, top_p)): - h = (t, k, p) - bucket = bucket_map.get(h, None) - if bucket is None: - bucket_map[h] = ([i], i) - else: - bucket[0].append(i) + for request_index, (t, k, p) in enumerate(zip(temp, top_k, top_p)): + sampling_params = (t, k, p) + bucket_map[sampling_params].append(request_index) + + # Just unpack the key directly! + self._torch_sampling_buckets = [ + (indices, *sampling_params) for sampling_params, indices in bucket_map.items() + ] + + def _rewind_kv_cache(self): + """Update the KV cache bookkeeping for speculative decoding. + + After forward pass with speculative tokens, some tokens may be rejected. + This function "rewinds" the KV cache bookkeeping to reflect only the accepted tokens. + + When speculative tokens are rejected, we need to: + 1. Update request_kv_length_offsets (total sequence length) + 2. Update request_last_kv_block_offset (position within last block) + 3. If rewinding crosses a block boundary: + - Reduce request_kv_block_counts + - Update request_last_kv_block_id to point to the previous block + - Clear the entry in request_to_kv_block_ids for the released block + - Release the block back to the allocator + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_request_slice = slice(context.paused_request_count, context.total_request_count) + + # Get the accepted token counts for each request + # Note: _accepted_token_counts is indexed from 0 to active_request_count-1 + accepted_tokens_per_request = self._accepted_token_counts_per_request[:active_request_count] + + # Number of tokens to rewind (rejected speculative tokens) + num_tokens_to_rewind = self.num_speculative_tokens - accepted_tokens_per_request + + # For prefill requests, no speculative tokens were forwarded through the model, + # so there is nothing to rewind. + request_in_prefill_status = context.request_in_prefill_status_tensor[active_request_slice] + num_tokens_to_rewind[request_in_prefill_status == 1] = 0 + + # Save the original offset BEFORE modifying to correctly detect block boundary crossing + original_offset = context.request_last_kv_block_offset[active_request_slice].clone() + + # Check which requests need to rewind to a previous block BEFORE modifying + # A request crosses back to a previous block if: original_offset - num_tokens_to_rewind < 0 + remove_allocated_blocks_mask = (original_offset - num_tokens_to_rewind) < 0 + + # Update the offsets + context.request_last_kv_block_offset[active_request_slice] = ( + original_offset - num_tokens_to_rewind + ) % context.block_size_tokens + + context.request_kv_length_offsets[active_request_slice] = ( + context.request_kv_length_offsets[active_request_slice] - num_tokens_to_rewind + ) + + # No need to update request_query_lengths (It will be set correctly in the next iteration) + + # For requests that crossed back to a previous block, we need to: + # 1. Reduce the block count by 1 + # 2. Get the block ID to release (current request_last_kv_block_id) + # 3. Update request_last_kv_block_id to point to the previous block + # 4. Clear the entry in request_to_kv_block_ids for the released block + # 5. Release the block back to the allocator + if remove_allocated_blocks_mask.any(): + # Get indices of requests that need to release a block (relative to active requests) + requests_needing_release = torch.nonzero(remove_allocated_blocks_mask, as_tuple=True)[0] + # Convert to absolute indices in the context tensors + absolute_indices = requests_needing_release + context.paused_request_count + + # No clone needed: advanced (fancy) indexing with a tensor already returns + # a copy, not a view. + blocks_to_release = context.request_last_kv_block_id[absolute_indices] + + # Reduce block counts for requests that crossed back + context.request_kv_block_counts[absolute_indices] -= 1 + + # Get the new block counts after decrement + new_block_counts = context.request_kv_block_counts[absolute_indices] + + # Update request_last_kv_block_id to point to the previous block + # and clear the released block entry in request_to_kv_block_ids + # Vectorized implementation using advanced indexing: + # Note: new_block_counts is guaranteed to be > 0 for all requests here, since + # crossing back to a previous block implies the request had at least 2 blocks. + + # Update request_last_kv_block_id to point to the previous block (at index new_count - 1) + context.request_last_kv_block_id[absolute_indices] = context.request_to_kv_block_ids[ + absolute_indices, new_block_counts - 1 + ] + + # Clear the released block entry (at index new_count, which was the old last block) + context.request_to_kv_block_ids[absolute_indices, new_block_counts] = -1 + + # Release the blocks back to the allocator + context.block_allocator.release_memory_blocks(blocks_to_release) + + # Mamba speculative rewind state update + if context.is_hybrid_model: + active_mamba_indices = context.mamba_metadata.request_to_mamba_state_idx[ + active_request_slice + ] + is_decode_mask = context.request_in_prefill_status_tensor[active_request_slice] == 0 + decode_mamba_indices = active_mamba_indices[is_decode_mask] + accepted_tokens_per_decode_request = accepted_tokens_per_request[is_decode_mask] + + if decode_mamba_indices.numel() > 0: + context.mamba_conv_states[:, decode_mamba_indices] = ( + context.mamba_intermediate_conv_states[ + :, decode_mamba_indices, accepted_tokens_per_decode_request + ] + ) + context.mamba_ssm_states[:, decode_mamba_indices] = ( + context.mamba_intermediate_ssm_states[ + :, decode_mamba_indices, accepted_tokens_per_decode_request + ] + ) + + def _sample_from_logits_2d(self, logits_2d: Tensor) -> Tensor: + """Sample tokens from 2D logits using existing sampling parameters. + + Args: + logits_2d (Tensor): Logits of shape [num_requests, vocab_size]. + + Returns: + Tensor: Sampled tokens of shape [num_requests]. + """ + spec_token_list = [] + indices_list = [] + for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: + request_indices_tensor = torch.tensor( + request_indices, device=logits_2d.device, dtype=torch.long + ) + spec_token_list.append( + self._torch_sampling_func(logits_2d[request_indices_tensor, :], temp, top_k, top_p) + ) + indices_list.append(request_indices_tensor) + + spec_tokens = torch.empty(logits_2d.shape[0], device=logits_2d.device, dtype=torch.int64) + for tokens, indices in zip(spec_token_list, indices_list): + spec_tokens[indices] = tokens + return spec_tokens + + def _compute_serial_mtp_and_sample(self): + """Compute MTP logits serially after verification and sample speculative tokens. + + This ensures that MTP predictions are always conditioned on verified tokens. + Each MTP depth receives the correctly sampled token from the previous depth + (or the base token for depth 0) rather than stale speculative tokens from + the previous step. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_slice = slice(context.paused_request_count, context.total_request_count) + + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + + # On non-last pipeline stages, the model won't have decoder hidden states. + has_mtp = is_pipeline_last_stage(self.pp_group) and hasattr( + unwrapped_model, '_decoder_hidden_states_cache' + ) + + if has_mtp: + # Get decoder hidden states at last accepted positions. + hidden_states = unwrapped_model._decoder_hidden_states_cache + last_accepted_hidden = hidden_states[self._last_accepted_seq_indices, :, :] + # Shape: [active_request_count, 1, hidden_size] + else: + last_accepted_hidden = None + + # Compute position IDs for the next tokens. + # After rewind, request_kv_length_offsets has been adjusted. The actual + # KV cache length is: adjusted_offset + (1 + num_speculative_tokens). + # The next position to predict starts at that cache length. + adjusted_offsets = context.request_kv_length_offsets[active_slice] + base_position = adjusted_offsets + (1 + self.num_speculative_tokens) + + # Start with the freshly sampled base token. + next_token_ids = self._sampled_tokens_cuda[:active_request_count].clone() + current_hidden = last_accepted_hidden if has_mtp else None + + num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) + for depth in range(num_depths): + position_ids = (base_position + depth).unsqueeze(0) # [1, active_request_count] + token_ids = next_token_ids.unsqueeze(0) # [1, active_request_count] + + mtp_logits_2d = None + if has_mtp: + current_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( + hidden_states=current_hidden, + next_token_ids=token_ids, + position_ids=position_ids, + depth=depth, + ) + # mtp_logits: [active_request_count, 1, vocab_size] + mtp_logits_2d = mtp_logits.squeeze(1) # [active_request_count, vocab_size] + + # Broadcast MTP logits across pipeline stages. + if self.model_is_pipeline_parallel: + mtp_logits_2d = broadcast_from_last_pipeline_stage( + [active_request_count, self.vocab_size], + dtype=self.model_config.params_dtype, + tensor=mtp_logits_2d, + pp_group=self.pp_group, + ) + + # Sample speculative token using the same sampling parameters. + spec_tokens = self._sample_from_logits_2d(mtp_logits_2d) + self._sampled_mtp_tokens_cuda[depth, :active_request_count] = spec_tokens + + # Use sampled token as input for the next depth. + next_token_ids = spec_tokens + + # Clean up cached hidden states. + if has_mtp: + del unwrapped_model._decoder_hidden_states_cache + + def _get_required_logit_indices( + self, + request_in_prefill_status_tensor: Tensor, + request_query_lengths: Tensor, + num_decode_requests: int, + num_prefill_requests: int, + device: torch.device, + ) -> Tensor: + """Get indices into the logits tensor for tokens that need sampling. + + For decode requests, all tokens (base + speculative) are needed. + For prefill requests, only the last token logits are needed. + Decode requests will always be on the left, followed by prefill requests. + + Example with 5 requests (2 spec tokens): + Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] + Request to prefill [ 0 | 0 | 0 | 1 | 1 ] + Request query lengths [ 3 | 3 | 3 | 2 | 4 ] + OUTPUT : required_logit_indices [ 0 1 2 | 3 4 5 | 6 7 8 | 10 | 14 ] + + Returns: + Tensor: Indices into the sequence dimension of the logits tensor. + """ + decode_request_indices = torch.arange( + num_decode_requests * (self.num_speculative_tokens + 1), device=device + ) + prefill_request_indices = ( + request_query_lengths.cumsum(dim=0)[request_in_prefill_status_tensor == 1] - 1 + ) # Last token indices for prefill requests + required_logit_indices = torch.cat([decode_request_indices, prefill_request_indices]) + assert ( + len(required_logit_indices) + == num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests + ), ( + f"Expected length of required_logit_indices to be " + f"num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, " + f"but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} " + f"and num_prefill_requests {num_prefill_requests}" + ) + return required_logit_indices + + def _sample_speculative_logits( + self, + required_logits: Tensor, + required_mtp_logits: Tensor, + request_in_prefill_status_tensor: Tensor, + ) -> tuple: + """Sample tokens from logits and MTP logits using sampling buckets. + + For torch sampling buckets: [request_indices, temp, top_k, top_p] + + Example with 5 requests: + token_to_request_idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ] + required_logits : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ] # Shape [11, vocab_size] + + Sampling buckets: [[[0,2], temp1, top_k1, top_p1], [[1], temp3, top_k3, top_p3], [[3, 4], temp2, top_k2, top_p2]] + + Final output tokens : [a5s a6s a7s c6s c7s c8s b3s b4s b5s d2s e4s] # Shape [11] + (Rearranged from sampling bucket order back to input order using token_order) + + Returns: + tuple: (output_tokens, mtp_output_tokens, repeats) where output_tokens has shape + [total_required_tokens] and mtp_output_tokens has shape + [num_speculative_tokens, total_required_tokens]. + """ + repeats = torch.where( + request_in_prefill_status_tensor == 0, 1 + self.num_speculative_tokens, 1 + ) + token_to_request_index = torch.repeat_interleave( + torch.arange( + len(request_in_prefill_status_tensor), + device=request_in_prefill_status_tensor.device, + ), + repeats, + ) + + output_tokens_jumbled_list = [] + mtp_output_tokens_jumbled_list = [] + token_order_list = [] + + has_mtp_logits = required_mtp_logits is not None + + for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: + request_indices_tensor = torch.tensor( + request_indices, device=token_to_request_index.device + ) + required_indices = torch.where( + torch.isin(token_to_request_index, request_indices_tensor) + )[0] + output_tokens_jumbled_list.append( + self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) + ) + if has_mtp_logits: + mtp_logits_slice = required_mtp_logits[:, required_indices, :] + num_spec, num_reqs, vocab = mtp_logits_slice.shape + sampled_mtp = self._torch_sampling_func( + mtp_logits_slice.reshape(num_spec * num_reqs, vocab), temp, top_k, top_p + ) + mtp_output_tokens_jumbled_list.append(sampled_mtp.reshape(num_spec, num_reqs)) + token_order_list.append(required_indices) + + output_tokens_jumbled = torch.cat(output_tokens_jumbled_list, dim=0) + output_tokens = torch.empty( + len(output_tokens_jumbled), + device=output_tokens_jumbled.device, + dtype=output_tokens_jumbled.dtype, + ) + token_order = torch.cat(token_order_list, dim=0) + # Rearrange output tokens from sampling_bucket request order back to input ids order + output_tokens[token_order] = output_tokens_jumbled + + mtp_output_tokens = None + if has_mtp_logits: + mtp_output_tokens_jumbled = torch.cat( + mtp_output_tokens_jumbled_list, dim=1 + ) # Shape [num_speculative_tokens, total_tokens] + mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) + mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled + + return output_tokens, mtp_output_tokens, repeats + + def _verify_speculative_tokens( + self, + output_tokens: Tensor, + input_tokens_required: Tensor, + request_in_prefill_status_tensor: Tensor, + repeats: Tensor, + num_decode_requests: int, + num_prefill_requests: int, + active_request_count: int, + ) -> tuple: + """Verify speculative tokens against input tokens and compute acceptance. + + Creates an accepted tokens mask where: + - For prefill requests, the token is always accepted. + - For decode requests, the first token (base token) is always accepted, then we compare + sampled tokens with input tokens and accept consecutive matches. + Then finds the index of the last accepted token per request. + + Example (assume 1, 2, and 0 spec tokens are accepted in the first 3 decode requests): + input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 + Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] + Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] + Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + Last one indices [ 1 | 5 | 6 | 9 | 10 ] + + Returns: + tuple: (last_one_indices, accepted_tokens_mask, input_tokens_required) where + last_one_indices contains the index of the last accepted token per request. + """ + if input_tokens_required.ndim == 2: + assert ( + input_tokens_required.shape[0] == 1 + ), f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" + input_tokens_required = input_tokens_required.squeeze(0) + + # Initialize mask with False to prevent boundary bleed + accepted_tokens_mask = torch.zeros_like(input_tokens_required, dtype=torch.bool) + + # Make all prefill tokens accepted + token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) + accepted_tokens_mask[token_to_prefill_idx == 1] = True - # Store the buckets and their equivalence class representatives. - self._torch_sampling_buckets = ( - (indices, temp[rep], top_k[rep], top_p[rep]) for indices, rep in bucket_map.values() + # Safe decode token verification without cross-batch boundary contamination + decode_mask_2d = None + if num_decode_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + + decode_inputs = input_tokens_required[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1 + ) + decode_outputs = output_tokens[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1 ) + # Shift outputs right by 1 *within* each request to align sampled tokens with input targets + decode_outputs_shifted = decode_outputs.roll(1, dims=1) + decode_mask_2d = decode_inputs == decode_outputs_shifted + # The first token (base token) is always accepted + decode_mask_2d[:, 0] = True + # Enforce consecutive acceptance: cummin propagates False to the right + decode_mask_2d = decode_mask_2d.cummin(dim=1).values + accepted_tokens_mask[:decode_len] = decode_mask_2d.flatten() + + last_one_indices = torch.full( + (active_request_count,), -1, device=input_tokens_required.device + ) + + if num_decode_requests > 0: + # Summing the consecutive mask gives the count; subtract 1 for the local index + local_last_indices = decode_mask_2d.sum(dim=1) - 1 + row_offsets = torch.arange(num_decode_requests, device=last_one_indices.device) * ( + self.num_speculative_tokens + 1 + ) + last_one_indices[:num_decode_requests] = row_offsets + local_last_indices + + if num_prefill_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + prefill_valid = ( + torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len + ) + last_one_indices[num_decode_requests:] = prefill_valid + + return last_one_indices, accepted_tokens_mask, input_tokens_required + + def _dynamic_step_sample_logits_and_verify_tokens( + self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor + ): + """ + Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests + + # Get the logit indices for tokens that need sampling. + required_logit_indices = self._get_required_logit_indices( + request_in_prefill_status_tensor, + request_query_lengths, + num_decode_requests, + num_prefill_requests, + logits.device, + ) + + required_logits = logits.squeeze(0)[ + required_logit_indices, : + ] # Shape [num_required, vocab_size] + required_mtp_logits = None + if mtp_logits is not None: + required_mtp_logits = mtp_logits[ + :, required_logit_indices, : + ] # Shape [num_speculative_tokens, num_required, vocab_size] + + # Sample tokens from logits (and MTP logits if provided). + output_tokens, mtp_output_tokens, repeats = self._sample_speculative_logits( + required_logits, required_mtp_logits, request_in_prefill_status_tensor + ) + + # Verify speculative tokens against input tokens. + input_tokens_required = input_ids[0, required_logit_indices] + last_one_indices, accepted_tokens_mask, input_tokens_required = ( + self._verify_speculative_tokens( + output_tokens, + input_tokens_required, + request_in_prefill_status_tensor, + repeats, + num_decode_requests, + num_prefill_requests, + active_request_count, + ) + ) + + # Store the final sampled tokens for the next forward pass. + final_sampled_tokens = output_tokens[last_one_indices] + self._sampled_tokens_cuda[: len(final_sampled_tokens)] = final_sampled_tokens + + # Store MTP tokens if they were computed inline (non-serial path). + if mtp_output_tokens is not None: + self._sampled_mtp_tokens_cuda[:, : len(final_sampled_tokens)] = mtp_output_tokens[ + :, last_one_indices + ] + + # Store the last accepted positions in the packed sequence for serial + # MTP computation after verification. + self._last_accepted_seq_indices = required_logit_indices[last_one_indices] + + # Extract accepted tokens and counts for decode requests. + # For prefill it is always set to 1. For decode, the first token is always accepted, + # then we compare with input tokens and accept the next tokens if its a match. + # + # Example (continuing from above): + # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] + # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only decode requests (prefill defaults to -1) + # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 + input_tokens_required[accepted_tokens_mask == 0] = -1 # Mask out non-accepted tokens + input_tokens_decode_mode = input_tokens_required[ + : num_decode_requests * (self.num_speculative_tokens + 1) + ] + input_tokens_reshaped = input_tokens_decode_mode.reshape( + -1, self.num_speculative_tokens + 1 + ) # shape: [num_decode_requests, num_speculative_tokens + 1] + + # Skip the first token of every decode request (i.e a5, b3, c6) + accepted_tokens = input_tokens_reshaped[:, 1:] + self._accepted_tokens_per_request[: accepted_tokens.shape[0], :] = accepted_tokens + self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum( + dim=1 + ) + def _dynamic_step_sample_logits(self, logits: Tensor): """Sample tokens from logits for dynamic batching. @@ -652,24 +1197,31 @@ def _dynamic_step_sample_logits(self, logits: Tensor): if context.config.materialize_only_last_token_logits: # When materialize_only_last_token_logits is true, last_token_logits is # already called in the forward pass of GPT. - last_token_logits = logits.squeeze(0) + required_token_logits = logits.squeeze(0) else: - last_token_logits = context.last_token_logits(logits) + # todo : Should do verification here and get approrpiate las token logits + required_token_logits = context.last_token_logits(logits) if self._sampling_backend == "torch": # Concatenate the outputs once to prevent repeated small writes. token_list = [] indices_list = [] + # e.g torch sample buckets will be + # i.e (for all unique comibnation of t, topk, topk what are the associated + # requests indices (based on the active slices) + # [ [req at index 0, req at index 2], t1, topk1, topp1 ]] + # [ [req at index 1, req at index 3, req at index 4] , t2, topk2, topp2] for indices, temp, top_k, top_p in self._torch_sampling_buckets: token_list.append( - self._torch_sampling_func(last_token_logits[indices, :], temp, top_k, top_p) + self._torch_sampling_func(required_token_logits[indices, :], temp, top_k, top_p) ) indices_list.append(torch.tensor(indices)) # Single write to the output tensor. sampled_tokens = torch.cat(token_list, dim=0) sampled_indices = torch.cat(indices_list, dim=0) + self._sampled_tokens_cuda[sampled_indices] = sampled_tokens def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: @@ -757,6 +1309,201 @@ def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: only_last_token_logits=context.config.materialize_only_last_token_logits, ) + def _dynamic_step_calculate_log_probs_speculative( + self, logits: Tensor + ) -> Tuple[List[List[float]], Tensor]: + """Calculate log probs from logits for speculative decoding. + + For decode requests, computes log probs for each accepted speculative token + and the newly sampled token using the main model logits. For prefill requests, + handles prompt log probs the same way as non-speculative decoding. + + The main model logits at position j predict the token at position j+1. So: + - log_prob(accepted_token[j]) comes from logits at position j + - log_prob(newly_sampled_token) comes from logits at position accepted_count + + Args: + logits (Tensor): The main model logits [1, seq_len, vocab_size]. + + Returns: + Tuple of (log_probs_list, log_probs_tensor): + log_probs_list: List of lists, one per active request, containing + log probs for the tokens emitted in this step. + log_probs_tensor: Full log_softmax tensor for top-n computation. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests + + logits_squeezed = logits.squeeze(0).float() + log_probs_tensor = F.log_softmax(logits_squeezed[: context.active_token_count], dim=-1) + + log_probs_list_decode = [] + + if num_decode_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + decode_log_probs = log_probs_tensor[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1, -1 + ) + accepted_counts = self._accepted_token_counts_per_request[:num_decode_requests] + + # Build a [num_decode, num_spec+1] token ID matrix for gathering. + # Columns 0..num_spec-1 hold accepted speculative tokens (clamped to 0 + # where rejected, since those positions will be masked out). + # At column accepted_count[i], place the newly sampled token. + gather_tokens = torch.zeros( + num_decode_requests, + self.num_speculative_tokens + 1, + device=logits.device, + dtype=torch.long, + ) + gather_tokens[:, : self.num_speculative_tokens] = self._accepted_tokens_per_request[ + :num_decode_requests + ].clamp(min=0) + gather_tokens[ + torch.arange(num_decode_requests, device=logits.device), accepted_counts + ] = self._sampled_tokens_cuda[:num_decode_requests] + + # Gather: [num_decode, num_spec+1] + gathered_log_probs = decode_log_probs.gather(2, gather_tokens.unsqueeze(-1)).squeeze(-1) + + log_probs_list_decode = [ + gathered_log_probs[i, : accepted_counts[i].item() + 1].tolist() + for i in range(num_decode_requests) + ] + + log_probs_list_prefill = [] + if num_prefill_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + prefill_log_probs = log_probs_tensor[decode_len:] + + prefill_token_ids = context.token_to_input_ids[ + decode_len : context.active_token_count + ].roll(-1, 0) + prefill_query_lengths = request_query_lengths[request_in_prefill_status_tensor == 1] + new_token_idx = prefill_query_lengths.cumsum(0) - 1 + prefill_new_tokens = self._sampled_tokens_cuda[num_decode_requests:active_request_count] + prefill_token_ids[new_token_idx] = prefill_new_tokens + + prefill_token_count = context.active_token_count - decode_len + seq_idx = torch.arange(prefill_token_count, device=logits.device) + selected_log_probs = prefill_log_probs[seq_idx, prefill_token_ids] + + prefill_log_probs_split = selected_log_probs.cpu().split( + prefill_query_lengths.tolist(), dim=0 + ) + log_probs_list_prefill = [lp.tolist() for lp in prefill_log_probs_split] + + log_probs_list = log_probs_list_decode + log_probs_list_prefill + + return log_probs_list, log_probs_tensor + + def _dynamic_step_calculate_top_n_logprobs_speculative( + self, log_probs_tensor: Tensor + ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: + """Calculate top-n log probs for speculative decoding. + + For decode requests, computes top-n at each position that produced an + emitted token (accepted speculative positions + the newly sampled position). + For prefill requests, behaves identically to the non-speculative path. + + Args: + log_probs_tensor (Tensor): Pre-computed log_softmax tensor from + _dynamic_step_calculate_log_probs_speculative. + + Returns: + A dictionary mapping request_idx to list of (top_n_values, top_n_indices) + tuples, one per emitted token position. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_request_slice = slice(context.paused_request_count, context.total_request_count) + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests + + top_n_results = {} + + if num_decode_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + decode_log_probs = log_probs_tensor[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1, -1 + ) + accepted_counts = self._accepted_token_counts_per_request[:num_decode_requests] + top_n_per_request = self._request_metadata["top_n_logprobs"][active_request_slice][ + :num_decode_requests + ] + max_top_n = int(top_n_per_request.max().item()) + + if max_top_n > 0: + + # Single batched topk on GPU: [num_decode, num_spec+1, max_top_n] + topk_results = torch.topk(decode_log_probs, k=max_top_n, dim=-1) + + # Single CPU transfer instead of O(num_decode * num_spec) transfers + topk_values_cpu = topk_results.values.cpu() + topk_indices_cpu = topk_results.indices.cpu() + + for i in range(num_decode_requests): + top_n = int(top_n_per_request[i].item()) + if top_n > 0: + num_valid = accepted_counts[i].item() + 1 + top_n_results[i] = [ + (topk_values_cpu[i, j, :top_n], topk_indices_cpu[i, j, :top_n]) + for j in range(num_valid) + ] + + if num_prefill_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + prefill_log_probs = log_probs_tensor[decode_len:] + prefill_query_lengths = request_query_lengths[request_in_prefill_status_tensor == 1] + prefill_log_probs_per_request = prefill_log_probs.split( + prefill_query_lengths.tolist(), dim=0 + ) + + for i in range(num_prefill_requests): + req_idx = num_decode_requests + i + top_n = int( + self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() + ) + if top_n > 0: + request_lp = prefill_log_probs_per_request[i] + skip_prompt = bool( + self._request_metadata["skip_prompt_log_probs"][req_idx].item() + ) + + if skip_prompt and request_lp.size(0) > 1: + top_n_logits = torch.topk(request_lp[-1], k=top_n) + top_n_results[req_idx] = [ + (top_n_logits.values.cpu(), top_n_logits.indices.cpu()) + ] + else: + top_n_logits = torch.topk(request_lp, k=top_n, dim=-1) + top_n_values_cpu = top_n_logits.values.cpu() + top_n_indices_cpu = top_n_logits.indices.cpu() + top_n_results[req_idx] = [ + (top_n_values_cpu[t], top_n_indices_cpu[t]) + for t in range(request_lp.size(0)) + ] + + return top_n_results if top_n_results else None + def _dynamic_step_calculate_top_n_logprobs( self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: @@ -853,7 +1600,9 @@ def dummy_forward(self): model_config = get_model_config(unwrapped_model) if model_config.transformer_impl == "inference_optimized": context.maybe_initialize_symmetric_memory() - return self.inference_wrapped_model.dummy_forward() + self.inference_wrapped_model.dummy_forward() + self._dummy_serial_mtp_forward() + return # attempt to use cuda-graph if possible input_ids, position_ids = self._dynamic_step_context_init(is_dummy_forward=True) @@ -869,9 +1618,78 @@ def dummy_forward(self): # fallback to eager dummy forward self.inference_wrapped_model.dummy_forward() + # Disable MoE padding for MTP computation + if self.model_config.moe_pad_experts_for_cuda_graph_inference: + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + set_decode_expert_padding(unwrapped_model, False) + + # When speculative decoding is active, the real EP ranks perform serial + # MTP forward passes after the main forward pass. MTP layers may contain + # MoE sublayers (inherited from the decoder spec), which require EP + # all-to-all collectives. The dummy rank must participate in these + # collectives to avoid a hang. + self._dummy_serial_mtp_forward() + # clear the context of any temporary state from the dummy forward context.reset() + def _dummy_serial_mtp_forward(self): + """Run dummy MTP forward passes to participate in EP collectives. + + When speculative decoding is active and MTP layers contain MoE sublayers + (inherited from the decoder layer spec), each serial MTP step triggers + EP all-to-all collectives. The dummy EP rank must issue matching + collective calls so the real ranks do not hang. + + This mirrors the structure of ``_compute_serial_mtp_and_sample``: + - On the last PP stage (where MTP resides): run ``compute_mtp_single_step`` + with dummy tensors so the MoE all-to-all is executed. + - When PP > 1: participate in the ``broadcast_from_last_pipeline_stage`` + that the real ranks also perform. + """ + if self.num_speculative_tokens == 0 or self.num_mtp_heads == 0: + return + if self.model_config.expert_model_parallel_size <= 1: + return + + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + + is_last_stage = is_pipeline_last_stage(self.pp_group) + has_mtp = is_last_stage and hasattr(unwrapped_model, 'mtp') + if not has_mtp and not self.model_is_pipeline_parallel: + # No MTP on this rank and no PP broadcast to participate in. + return + + device = torch.cuda.current_device() + dtype = self.model_config.params_dtype + hidden_size = self.model_config.hidden_size + num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) + + dummy_hidden = None + if has_mtp: + # Minimal dummy tensors — just enough to drive the MTP layer forward + # so that the MoE all-to-all collectives are issued. + dummy_hidden = torch.zeros((1, 1, hidden_size), device=device, dtype=dtype) + dummy_token_ids = torch.zeros((1, 1), device=device, dtype=torch.long) + dummy_position_ids = torch.zeros((1, 1), device=device, dtype=torch.long) + + for depth in range(num_depths): + mtp_logits_2d = None + if has_mtp: + dummy_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( + hidden_states=dummy_hidden, + next_token_ids=dummy_token_ids, + position_ids=dummy_position_ids, + depth=depth, + ) + mtp_logits_2d = mtp_logits.squeeze(1) # [1, vocab_size] + + # Match the PP broadcast that real ranks do in _compute_serial_mtp_and_sample. + if self.model_is_pipeline_parallel: + broadcast_from_last_pipeline_stage( + [1, self.vocab_size], dtype=dtype, tensor=mtp_logits_2d, pp_group=self.pp_group + ) + def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: """Update the dynamic inference context after sampling. @@ -894,7 +1712,13 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: # Active sequence lengths. active_request_ids = context.request_ids[active_request_slice].long() active_sequence_lengths = context.get_active_sequence_lengths() - active_sequence_lengths += 1 # Account for the token we just generated + + if self.num_speculative_tokens > 0: + active_sequence_lengths += ( + self._accepted_token_counts_per_request[:active_request_count] + 1 + ) + else: + active_sequence_lengths += 1 max_sequence_lengths = context.get_max_sequence_lengths() # Request finished if termination_id or length >= max_sequence_length. @@ -918,11 +1742,19 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: ) finished_request_ids = context.request_ids[finished_idxs] - # New sample gets updated in update_requests, so we pass in a clone + # Clone needed: update_requests mutates next_tokens in-place via tensor_swap, + # which would corrupt the reused _sampled_tokens_cuda buffer. new_sample_copy = self._sampled_tokens_cuda[:active_request_count].clone() # Update requests. - update_result = context.update_requests(active_request_mask, new_sample_copy) + # _sampled_mtp_tokens_cuda has shape [num_speculative_tokens, max_requests] + if self.num_speculative_tokens > 0: + sampled_mtp_tokens_cuda = self._sampled_mtp_tokens_cuda[:, :active_request_count] + else: + sampled_mtp_tokens_cuda = None + update_result = context.update_requests( + active_request_mask, new_sample_copy, sampled_mtp_tokens_cuda + ) return { "active_request_ids": active_request_ids, @@ -966,6 +1798,8 @@ async def async_generate_output_tokens_dynamic_batch( if config.moe_enable_routing_replay: RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + # Forward pass produces only base logits. When speculative decoding is + # active, MTP logits are computed serially after verification. logits = self._dynamic_step_forward_logits(input_ids, position_ids) # Collect routing indices per request (must be done before context transitions) @@ -979,19 +1813,44 @@ async def async_generate_output_tokens_dynamic_batch( # Todo [Siddharth]: Can we condition the sleep on a cuda event? # NOTE [TDE]: This will be moved once CPU and GPU methods are separated. await asyncio.sleep(0) - return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() + self._dynamic_step_sample_bookkeeping() - self._dynamic_step_sample_logits(logits) + + if self.num_speculative_tokens > 0: + # Phase 1: Verify speculative tokens using base logits only. + # MTP logits are NOT passed here; they will be computed serially. + self._dynamic_step_sample_logits_and_verify_tokens(logits, None, input_ids) + # Phase 2: Rewind KV cache for rejected tokens. + self._rewind_kv_cache() + + # Disable MoE padding for MTP computation + if self.model_config.moe_pad_experts_for_cuda_graph_inference: + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + set_decode_expert_padding(unwrapped_model, False) + + # Phase 3: Compute MTP serially with correct (verified) inputs. + self._compute_serial_mtp_and_sample() + else: + self._dynamic_step_sample_logits(logits) log_probs = None top_n_logprobs = None if return_log_probs or return_top_n_logprobs: - log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) - if return_top_n_logprobs: - top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( - logits, log_probs_tensor + if self.num_speculative_tokens > 0: + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs_speculative( + logits ) + if return_top_n_logprobs: + top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs_speculative( + log_probs_tensor + ) + else: + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) + if return_top_n_logprobs: + top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( + logits, log_probs_tensor + ) if skip_bookkeeping: request_bookkeeping = {} @@ -999,12 +1858,22 @@ async def async_generate_output_tokens_dynamic_batch( request_bookkeeping = self._dynamic_step_context_bookkeeping() ret = { - "sample": self._sampled_tokens_cuda[:active_request_count], + # Clone needed: _sampled_tokens_cuda is a reused buffer overwritten each step. + "sample": self._sampled_tokens_cuda[:active_request_count].clone(), + "accepted_tokens": ( + # Clone needed: .fill_(-1) on line 1480 would corrupt the returned value. + self._accepted_tokens_per_request.clone() + if self.num_speculative_tokens > 0 + else None + ), "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, "routing_indices_per_request": routing_indices_per_request, "cuda_graph_request_count": cuda_graph_request_count, } + if self.num_speculative_tokens > 0: + self._accepted_tokens_per_request.fill_(-1) + self._accepted_token_counts_per_request.fill_(0) ret.update(request_bookkeeping) return ret diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 05a7e8f60bb..0e560f939f2 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -240,7 +240,15 @@ def get_rotary_seq_len( # by the tp and cp size. return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) elif inference_context is not None: - rotary_seq_len = inference_context.max_sequence_length + # For dynamic batching, use the max of context's max_sequence_length and the actual + # input size to ensure rotary embeddings cover CUDA graph warmup token counts + context_max_seq_len = inference_context.max_sequence_length + input_seq_len = 0 + if transformer_input is not None: + input_seq_len = transformer_input.size(0) + elif transformer is not None and transformer.input_tensor is not None: + input_seq_len = transformer.input_tensor.size(0) + rotary_seq_len = max(context_max_seq_len, input_seq_len) else: if transformer is not None and transformer.input_tensor is not None: rotary_seq_len = transformer.input_tensor.size(0) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 2a3b3b3c69d..2de628f1f8e 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -27,6 +27,7 @@ from megatron.core.transformer.enums import CudaGraphScope, ModelType from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, + compute_mtp_inference_logits, mtp_on_this_rank, process_mtp_loss, ) @@ -589,17 +590,26 @@ def _postprocess( if in_inference_mode: assert runtime_gather_output, "Inference must always gather TP logits" + # Check if speculative decoding is active. When it is, MTP must be + # computed *after* verification so that it is conditioned on verified + # tokens rather than stale speculative tokens from the previous step. + is_spec_decode = ( + in_inference_mode + and hasattr(inference_context, 'num_speculative_tokens') + and inference_context.num_speculative_tokens > 0 + ) + # logits and loss output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess: + if mtp_in_postprocess and not is_spec_decode: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, hidden_states=hidden_states, attention_mask=attention_mask, - inference_params=inference_params, + inference_params=None, # MTP layers don't use KV cache rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, @@ -613,20 +623,38 @@ def _postprocess( return hidden_states if self.config.mtp_num_layers: - hidden_states = process_mtp_loss( - hidden_states=hidden_states, - labels=labels, - loss_mask=loss_mask, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - is_training=self.training, - compute_language_model_loss=self.compute_language_model_loss, - config=self.config, - cp_group=self.pg_collection.cp, - packed_seq_params=packed_seq_params, - scale_logits_fn=self._scale_logits if self.config.use_mup else None, - ) + assert self.config.mtp_num_layers > 0 + # The new process_mtp_loss function doesn't handle mtp_logits_cache, + # so we manually generate and cache MTP logits when in inference mode. + if in_inference_mode: + if is_spec_decode: + # Cache decoder hidden states for serial MTP computation + # after speculative token verification. + self._decoder_hidden_states_cache = hidden_states + else: + hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( + hidden_states=hidden_states, + mtp_num_layers=self.config.mtp_num_layers, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + else: + # In training/eval, use the utility function for processing MTP loss/scaling. + hidden_states = process_mtp_loss( + hidden_states=hidden_states, + labels=labels, + loss_mask=loss_mask, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + is_training=self.training, + compute_language_model_loss=self.compute_language_model_loss, + config=self.config, + cp_group=self.pg_collection.cp, + packed_seq_params=packed_seq_params, + scale_logits_fn=self._scale_logits if self.config.use_mup else None, + ) sequence_parallel_override = False if in_inference_mode and inference_context.config.materialize_only_last_token_logits: @@ -643,12 +671,10 @@ def _postprocess( self.output_layer.sequence_parallel = False sequence_parallel_override = True - # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden - # state ([B, H]) → unsqueeze back to [B, 1, H] - # (so that the output layer, which expects S×B×H, receives only the final token) - hidden_states = inference_context.last_token_logits( - hidden_states.squeeze(1).unsqueeze(0) - ).unsqueeze(1) + # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, + # then back to [S’, B, H] for the output layer. + reshaped = hidden_states.squeeze(1).unsqueeze(0) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output @@ -686,6 +712,49 @@ def _postprocess( return loss + @torch.inference_mode() + def compute_mtp_single_step( + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + depth: int, + runtime_gather_output: bool = True, + ) -> tuple: + """Compute a single MTP depth for speculative decoding. + + This is called after speculative token verification to compute MTP + predictions conditioned on verified tokens only. + + Args: + hidden_states (Tensor): Hidden states at last accepted positions [N, 1, H]. + next_token_ids (Tensor): Correct next token IDs [1, N]. + position_ids (Tensor): Position IDs for the next tokens [1, N]. + depth (int): MTP depth index (0-indexed). + runtime_gather_output (bool): Whether to gather output across TP. + + Returns: + tuple: (new_hidden_states [N, 1, H], logits [N, 1, vocab_size]). + """ + layer_idx = 0 if self.mtp.mtp_use_repeated_layer else depth + mtp_hidden = self.mtp.layers[layer_idx].forward_single_position( + hidden_states=hidden_states, + next_token_ids=next_token_ids, + position_ids=position_ids, + embedding=self.embedding, + ) + + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + logits, _ = self.output_layer( + mtp_hidden, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + logits = self._scale_logits(logits) + + return mtp_hidden, logits + def build_schedule_plan( self, input_ids: Tensor, diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index edafe3db3e5..65c371ef0fd 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -3,6 +3,7 @@ import logging from typing import Literal, Optional +import torch from torch import Tensor from megatron.core import tensor_parallel @@ -19,6 +20,7 @@ from megatron.core.transformer.enums import ModelType from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, + compute_mtp_inference_logits, mtp_on_this_rank, process_mtp_loss, ) @@ -386,8 +388,16 @@ def forward( if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - # TODO(helenn/MCore inference): enable MTP inference. - mtp_forward_ran = self.mtp_process and self.training and inference_context is None + # Check if speculative decoding is active. When it is, MTP must be + # computed *after* verification so that it is conditioned on verified + # tokens rather than stale speculative tokens from the previous step. + is_spec_decode = ( + in_inference_mode + and hasattr(inference_context, 'num_speculative_tokens') + and inference_context.num_speculative_tokens > 0 + ) + + mtp_forward_ran = self.mtp_process and not is_spec_decode if mtp_forward_ran: hidden_states = self.mtp( input_ids=input_ids, @@ -403,22 +413,38 @@ def forward( if not self.post_process: return hidden_states - if self.config.mtp_num_layers is not None and mtp_forward_ran: - hidden_states = process_mtp_loss( - hidden_states=hidden_states, - labels=labels, - loss_mask=loss_mask, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - is_training=self.training, - compute_language_model_loss=self.compute_language_model_loss, - config=self.config, - cp_group=self.pg_collection.cp, - packed_seq_params=packed_seq_params, - scale_logits_fn=self._scale_logits if self.config.use_mup else None, - ) - + if self.config.mtp_num_layers is not None and (mtp_forward_ran or is_spec_decode): + assert self.config.mtp_num_layers > 0 + # The new process_mtp_loss function doesn't handle mtp_logits_cache, + # so we manually generate and cache MTP logits when in inference mode. + if in_inference_mode: + if is_spec_decode: + # Cache decoder hidden states for serial MTP computation + # after speculative token verification. + self._decoder_hidden_states_cache = hidden_states + else: + hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( + hidden_states=hidden_states, + mtp_num_layers=self.config.mtp_num_layers, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + else: + hidden_states = process_mtp_loss( + hidden_states=hidden_states, + labels=labels, + loss_mask=loss_mask, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + is_training=self.training, + compute_language_model_loss=self.compute_language_model_loss, + config=self.config, + cp_group=self.pg_collection.cp, + packed_seq_params=packed_seq_params, + scale_logits_fn=self._scale_logits if self.config.use_mup else None, + ) sequence_parallel_override = False if in_inference_mode and inference_context.config.materialize_only_last_token_logits: if inference_context.is_static_batching(): @@ -434,12 +460,10 @@ def forward( self.output_layer.sequence_parallel = False sequence_parallel_override = True - # Reshape [B, 1, H] to [1, B, H] → extract each sample's true last‐token hidden - # state ([B, H]) → unsqueeze back to [B, 1, H] - # (so that the output layer, which expects S×B×H, receives only the final token) - hidden_states = inference_context.last_token_logits( - hidden_states.squeeze(1).unsqueeze(0) - ).unsqueeze(1) + # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, + # then back to [S', B, H] for the output layer. + reshaped = hidden_states.squeeze(1).unsqueeze(0) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output @@ -462,3 +486,46 @@ def forward( loss = self.compute_language_model_loss(labels, logits) return loss + + @torch.inference_mode() + def compute_mtp_single_step( + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + depth: int, + runtime_gather_output: bool = True, + ) -> tuple: + """Compute a single MTP depth for speculative decoding. + + This is called after speculative token verification to compute MTP + predictions conditioned on verified tokens only. + + Args: + hidden_states (Tensor): Hidden states at last accepted positions [N, 1, H]. + next_token_ids (Tensor): Correct next token IDs [1, N]. + position_ids (Tensor): Position IDs for the next tokens [1, N]. + depth (int): MTP depth index (0-indexed). + runtime_gather_output (bool): Whether to gather output across TP. + + Returns: + tuple: (new_hidden_states [N, 1, H], logits [N, 1, vocab_size]). + """ + layer_idx = 0 if self.mtp.mtp_use_repeated_layer else depth + mtp_hidden = self.mtp.layers[layer_idx].forward_single_position( + hidden_states=hidden_states, + next_token_ids=next_token_ids, + position_ids=position_ids, + embedding=self.embedding, + ) + + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + logits, _ = self.output_layer( + mtp_hidden, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + logits = self._scale_logits(logits) + + return mtp_hidden, logits diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 6c2395ded94..30b1a28bd71 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -25,6 +25,8 @@ ) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update +from megatron.core.ssm.ops.mamba_ssm import selective_state_update from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import MegatronModule @@ -45,16 +47,11 @@ from .mamba_context_parallel import MambaContextParallel try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + from causal_conv1d import causal_conv1d_fn from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states + except ImportError: causal_conv1d_fn = None - causal_conv1d_update = None try: from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated @@ -428,8 +425,18 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere ) assert sequence_packing_available, reason_for_no_sequence_packing + # Grab standard states conv_state, ssm_state = context.mamba_states_cache(self.layer_number - self.pp_layer_offset) + # Fetch intermediate states for speculative decoding + # (just buffers, existing data is overwritten) + int_conv_state = None + int_ssm_state = None + if context.num_speculative_tokens > 0: + int_conv_state, int_ssm_state = context.mamba_states_cache( + self.layer_number - self.pp_layer_offset, intermediate=True + ) + padded_dims = context.padded_batch_dimensions token_count = padded_dims.token_count decode_req_count = padded_dims.decode_req_count @@ -444,14 +451,25 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere # Decode if decode_req_count > 0: # For mixed batch, the decode tokens are at the start of zxBCdt - zxBCdt_decode = zxBCdt[:decode_req_count] if prefill_req_count > 0 else zxBCdt + seq_len = 1 + context.num_speculative_tokens + decode_token_count = decode_req_count * seq_len + + zxBCdt_decode = zxBCdt[:decode_token_count] if prefill_req_count > 0 else zxBCdt + + # Reshape from [N*S, 1, d] to [N, S, d] for the 3D Triton kernels + zxBCdt_decode = zxBCdt_decode.squeeze(1).view(decode_req_count, seq_len, -1) y_decode = self._ssm_decode( - zxBCdt_decode.transpose(0, 1), + zxBCdt_decode, conv_state, ssm_state, - context.mamba_metadata.batch_indices_decode, - ).transpose(0, 1) + batch_indices=context.mamba_metadata.batch_indices_decode, + intermediate_conv_state=int_conv_state, + intermediate_ssm_state=int_ssm_state, + ) + + # Flatten back to [N*S, 1, d] to match merge logic + y_decode = y_decode.view(decode_token_count, 1, -1) # Prefill if prefill_req_count > 0: @@ -853,27 +871,29 @@ def _ssm_decode( conv_state: torch.Tensor, ssm_state: torch.Tensor, batch_indices: Optional[torch.Tensor] = None, + intermediate_conv_state: Optional[torch.Tensor] = None, + intermediate_ssm_state: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Performs SSM computation for inference decode step. Args: - zxBCdt: The input tensor of shape (l, b, d), which is a concatenation of - z, x, B, C, and dt projections. For decoding, l must be 1. + zxBCdt: The input tensor of shape (b, s, d), which is a concatenation of + z, x, B, C, and dt projections. + s is the sequence length (1 + num_speculative_tokens). conv_state: The convolution state tensor for inference. ssm_state: The selective scan state tensor for inference. - batch_indices: A map from batch id to position in the Mamba state tensors for - dynamic inference. + batch_indices: A map from batch id to position in the Mamba state tensors. + intermediate_conv_state: Optional buffer for storing conv state at each + sequence step (for speculative decoding rollback). + intermediate_ssm_state: Optional buffer for storing SSM state at each + sequence step (for speculative decoding rollback). Returns: - The output tensor of shape (l, b, d). + The output tensor of shape (b, s, d). """ - seq_len, batch_size, _ = zxBCdt.shape + batch_size, seq_len, _ = zxBCdt.shape dtype = zxBCdt.dtype - assert seq_len == 1, "Only support decoding with 1 token at a time for now" - - # Remove sequence dimension - zxBCdt = zxBCdt.squeeze(0) z, xBC, dt = torch.split( zxBCdt, @@ -887,14 +907,17 @@ def _ssm_decode( # Conv step if causal_conv1d_update is None: + # TODO(ksanthanam): Consider deprecating this path + assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" + xBC_squeeze = xBC.squeeze(1) conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum( + conv_state[:, :, -1] = xBC_squeeze + xBC_squeeze = torch.sum( conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 ) # (B D) if self.conv1d.bias is not None: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(dtype=xBC.dtype) + xBC_squeeze = xBC_squeeze + self.conv1d.bias + xBC = self.act(xBC_squeeze).to(dtype=xBC.dtype).unsqueeze(1) else: # Conv state dtype might differ from params dtype, so cast xBC and weight / bias # tensors to the conv state dtype for causal_conv1d_update and then cast xBC @@ -908,6 +931,7 @@ def _ssm_decode( self.conv1d.bias.to(conv_state.dtype), self.activation, conv_state_indices=batch_indices, + intermediate_conv_states=intermediate_conv_state, ).to(xBC_dtype) x, B, C = torch.split( @@ -923,6 +947,16 @@ def _ssm_decode( # SSM step if selective_state_update is None: + # TODO(ksanthanam): Consider deprecating this path + assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" + + x = x.squeeze(1) + B = B.squeeze(1) + C = C.squeeze(1) + dt = dt.squeeze(1) + if z is not None: + z = z.squeeze(1) + if self.ngroups_local_tp > 1: B = rearrange(B, "b (g n) -> b g n", n=self.d_state) C = rearrange(C, "b (g n) -> b g n", n=self.d_state) @@ -967,16 +1001,20 @@ def _ssm_decode( y = rearrange(y, "b h p -> b (h p)") if not self.rmsnorm: y = y * self.act(z) # (B D) + + y = y.unsqueeze(1) # Restore seq dimension else: A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) - dt = repeat(dt, "b h -> b h p", p=self.headdim) + + # Incorporate sequence dimension in einops rearrengements + dt = repeat(dt, "b s h -> b s h p", p=self.headdim) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) D = repeat(self.D, "h -> h p", p=self.headdim) - B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local_tp) - C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local_tp) - x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + B = rearrange(B, "b s (g n) -> b s g n", g=self.ngroups_local_tp) + C = rearrange(C, "b s (g n) -> b s g n", g=self.ngroups_local_tp) + x_reshaped = rearrange(x, "b s (h p) -> b s h p", p=self.headdim) if not self.rmsnorm: - z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + z = rearrange(z, "b s (h p) -> b s h p", p=self.headdim) # Upcast the batch_indices to prevent integer overflow errors in the case of # large max request counts. @@ -995,14 +1033,14 @@ def _ssm_decode( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=batch_indices, + intermediate_ssm_states=intermediate_ssm_state, # SSM only ) - y = rearrange(y, "b h p -> b (h p)") + y = rearrange(y, "b s h p -> b s (h p)") if self.rmsnorm: y = self.norm(y, z) - # Restore sequence dimension - return y.unsqueeze(0) + return y def mamba_state_shapes_per_request(self) -> Tuple[Tuple[int], Tuple[int]]: """Returns the Mamba conv and ssm states shapes per request.""" diff --git a/megatron/core/ssm/ops/__init__.py b/megatron/core/ssm/ops/__init__.py new file mode 100644 index 00000000000..3e4afde2e29 --- /dev/null +++ b/megatron/core/ssm/ops/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py new file mode 100644 index 00000000000..36d14a1d91b --- /dev/null +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -0,0 +1,274 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +# Some of this code was adopted from https://github.com/Dao-AILab/causal-conv1d/ +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import triton +import triton.language as tl + + +@triton.jit +def causal_conv1d_update_kernel( + x_ptr, + x_b_stride, + x_s_stride, + x_c_stride, + conv_state_ptr, + conv_state_b_stride, + conv_state_c_stride, + conv_state_l_stride, + int_state_ptr, + int_state_b_stride, + int_state_s_stride, + int_state_c_stride, + int_state_l_stride, + weight_ptr, + weight_c_stride, + weight_width_stride, + bias_ptr, + bias_stride, + out_ptr, + out_b_stride, + out_s_stride, + out_c_stride, + conv_state_indices_ptr, + batch, + seq_len, + dim, + state_len, + WIDTH: tl.constexpr, + BLOCK_DIM: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_STATE_INDICES: tl.constexpr, + HAS_INT_STATE: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, +): + """Triton implementation of causal_conv1d_update (kernel).""" + batch_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + + channel_offsets = channel_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) + mask = channel_offsets < dim + + # State batch coordinate mapping + if HAS_STATE_INDICES: + state_batch_coord = tl.load(conv_state_indices_ptr + batch_id) + else: + state_batch_coord = batch_id + + # Base Pointers + conv_state_ptrs = ( + conv_state_ptr + + state_batch_coord * conv_state_b_stride + + channel_offsets * conv_state_c_stride + ) + weight_ptrs = weight_ptr + channel_offsets * weight_c_stride + + # Skip padding tokens (block-level uniform condition) + if state_batch_coord < 0: + for s in range(seq_len): + out_ptrs = ( + out_ptr + + batch_id * out_b_stride + + s * out_s_stride + + channel_offsets * out_c_stride + ) + tl.store(out_ptrs, 0.0, mask=mask) + return + + # Load Bias + if HAS_BIAS: + bias_val = tl.load(bias_ptr + channel_offsets * bias_stride, mask=mask).to(tl.float32) + else: + bias_val = tl.zeros([BLOCK_DIM], dtype=tl.float32) + + # Load Weights + if WIDTH == 2: + w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) + w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32) + elif WIDTH == 3: + w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) + w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32) + w2 = tl.load(weight_ptrs + 2 * weight_width_stride, mask=mask).to(tl.float32) + elif WIDTH == 4: + w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) + w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32) + w2 = tl.load(weight_ptrs + 2 * weight_width_stride, mask=mask).to(tl.float32) + w3 = tl.load(weight_ptrs + 3 * weight_width_stride, mask=mask).to(tl.float32) + + # Initialize independent x_vals to match unrolled float array + x_val_0 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + x_val_1 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + x_val_2 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + x_val_3 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + + # Loop over the sequence dimension (e.g., speculative tokens) + for s in range(seq_len): + x_ptrs = x_ptr + batch_id * x_b_stride + s * x_s_stride + channel_offsets * x_c_stride + out_ptrs = ( + out_ptr + batch_id * out_b_stride + s * out_s_stride + channel_offsets * out_c_stride + ) + + # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten + # by the shift + if WIDTH >= 2: + x_val_0 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask + ).to(tl.float32) + if WIDTH >= 3: + x_val_1 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask + ).to(tl.float32) + if WIDTH >= 4: + x_val_2 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask + ).to(tl.float32) + + # Shift the linear state buffer left by 1 + i = 0 + while i < state_len - 1: + val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) + tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) + i += 1 + + # Process the single token for the current sequence step + x_val = tl.load(x_ptrs, mask=mask) + + # Store the new token at the end of the linear state buffer + tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) + + # Write out to the intermediate state buffer if requested + if HAS_INT_STATE: + i = 0 + while i < state_len: + val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask) + int_ptr = ( + int_state_ptr + + state_batch_coord * int_state_b_stride + + s * int_state_s_stride + + channel_offsets * int_state_c_stride + + i * int_state_l_stride + ) + tl.store(int_ptr, val, mask=mask) + i += 1 + + # Advance registers for calculation + x_val_f32 = x_val.to(tl.float32) + if WIDTH == 2: + x_val_1 = x_val_f32 + elif WIDTH == 3: + x_val_2 = x_val_f32 + elif WIDTH == 4: + x_val_3 = x_val_f32 + + # Compute output + out_val = bias_val + if WIDTH == 2: + out_val += w0 * x_val_0 + w1 * x_val_1 + elif WIDTH == 3: + out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + elif WIDTH == 4: + out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + w3 * x_val_3 + + if SILU_ACTIVATION: + out_val = out_val * tl.sigmoid(out_val) + + tl.store(out_ptrs, out_val.to(out_ptrs.dtype.element_ty), mask=mask) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + conv_state_indices: torch.Tensor | None, + intermediate_conv_states: torch.Tensor | None = None, +) -> torch.Tensor: + """Triton implementation of causal_conv1d_update (entrypoint).""" + + # Check if input is 2D, temporarily treat as 3D for uniform processing + is_2d = x.dim() == 2 + if is_2d: + x = x.unsqueeze(1) + + batch, seq_len, dim = x.shape + out = torch.empty_like(x) + state_len = conv_state.shape[-1] + width = weight.shape[-1] + + if bias is not None: + bias_stride = bias.stride(0) + has_bias = True + else: + bias = x # Dummy pointer + bias_stride = 0 + has_bias = False + + if conv_state_indices is not None: + has_state_indices = True + else: + conv_state_indices = x # Dummy pointer + has_state_indices = False + + # Extract intermediate state strides if provided + if intermediate_conv_states is not None: + has_int_state = True + int_state_ptr = intermediate_conv_states + int_state_b_stride = intermediate_conv_states.stride(0) + int_state_s_stride = intermediate_conv_states.stride(1) + int_state_c_stride = intermediate_conv_states.stride(2) + int_state_l_stride = intermediate_conv_states.stride(3) + else: + has_int_state = False + int_state_ptr = x # Dummy pointer + int_state_b_stride = 0 + int_state_s_stride = 0 + int_state_c_stride = 0 + int_state_l_stride = 0 + + BLOCK_DIM = 64 + grid = (batch, triton.cdiv(dim, BLOCK_DIM)) + + causal_conv1d_update_kernel[grid]( + x, + x.stride(0), + x.stride(1), + x.stride(2), + conv_state, + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + int_state_ptr, + int_state_b_stride, + int_state_s_stride, + int_state_c_stride, + int_state_l_stride, + weight, + weight.stride(0), + weight.stride(1), + bias, + bias_stride, + out, + out.stride(0), + out.stride(1), + out.stride(2), + conv_state_indices, + batch, + seq_len, + dim, + state_len, + WIDTH=width, + BLOCK_DIM=BLOCK_DIM, + HAS_BIAS=has_bias, + HAS_STATE_INDICES=has_state_indices, + HAS_INT_STATE=has_int_state, + SILU_ACTIVATION=silu_activation == "silu", + ) + + if is_2d: + out = out.squeeze(1) + + return out diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py new file mode 100644 index 00000000000..cd2041eb084 --- /dev/null +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -0,0 +1,441 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + + +if TRITON3: + + @triton.jit + def softplus(dt): + """Optimized softplus.""" + return tl.math.log(tl.math.exp(dt) + 1) + +else: + + @triton.jit + def softplus(dt): + """Optimized softplus.""" + return tl.math.log1p(tl.exp(dt)) + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + {"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None} +) +@triton.heuristics({"HAS_INT_STATE": lambda args: args["int_state_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + int_state_ptr, + # Matrix dimensions + batch, + seq_len, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_seq, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_seq, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_seq, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_seq, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_seq, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_int_batch, + stride_int_seq, + stride_int_head, + stride_int_dim, + stride_int_dstate, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + HAS_INT_STATE: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + out_ptrs = out_ptr + offs_m * stride_out_dim + + # 1. State Mapping (handles dynamic batching slot allocation) + if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr) + # Skip padding tokens (e.g. from graph capture or inactive slots) + if state_batch_idx < 0: + for s in range(seq_len): + out_s_ptrs = out_ptrs + s * stride_out_seq + tl.store(out_s_ptrs, 0.0, mask=offs_m < dim) + return + state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head + if HAS_INT_STATE: + int_state_ptr += state_batch_idx * stride_int_batch + pid_h * stride_int_head + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + if HAS_INT_STATE: + int_state_ptr += pid_b * stride_int_batch + pid_h * stride_int_head + + # Base Pointers for Sequence iteration + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + + # Constant offsets (A, D, and bias do not have a sequence dimension) + state_ptrs = state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + if HAS_INT_STATE: + int_state_ptrs = int_state_ptr + ( + offs_m[:, None] * stride_int_dim + offs_n[None, :] * stride_int_dstate + ) + + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + D_ptrs = D_ptr + offs_m * stride_D_dim + + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + + # Load initial historical state and constant parameters + state = tl.load( + state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) + + if not TIE_HDIM: + A = tl.load( + A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) + else: + A = tl.load(A_ptr).to(tl.float32) + + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + # ---------------------------------------------------- + # Sequence Loop (Processes Main Token + Speculative Drafts) + # ---------------------------------------------------- + for s in range(seq_len): + x_s_ptrs = x_ptrs + s * stride_x_seq + dt_s_ptrs = dt_ptrs + s * stride_dt_seq + B_s_ptrs = B_ptrs + s * stride_B_seq + C_s_ptrs = C_ptrs + s * stride_C_seq + if HAS_Z: + z_s_ptrs = z_ptrs + s * stride_z_seq + + x = tl.load(x_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + # Calculate dt and dA + if not TIE_HDIM: + dt = tl.load(dt_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr + s * stride_dt_seq).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + dA = tl.exp(A * dt) + + # Load B and C + B = tl.load(B_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + if not TIE_HDIM: + dB = B[None, :] * dt[:, None] + else: + dB = B * dt + + # ---------------------------------------------------- + # The Core State Recurrence (h_t = dA * h_{t-1} + dB * x_t) + # ---------------------------------------------------- + state = state * dA + dB * x[:, None] + + # ---------------------------------------------------- + # Dump Intermediate Speculative State Snapshot + # ---------------------------------------------------- + if HAS_INT_STATE: + int_state_s_ptrs = int_state_ptrs + s * stride_int_seq + tl.store( + int_state_s_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + ) + + # Calculate Output + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + + out_s_ptrs = out_ptrs + s * stride_out_seq + tl.store(out_s_ptrs, out, mask=offs_m < dim) + + # After processing all sequence steps, persist the final state back to HBM + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + + +def selective_state_update( + state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + intermediate_ssm_states=None, +): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim), (batch, seqlen, dim), (batch, nheads, dim) or (batch, seqlen, nheads, dim) + dt: Matches x + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate), (batch, seqlen, dstate), (batch, ngroups, dstate) or + (batch, seqlen, ngroups, dstate) + C: Matches B + D: (dim,) or (nheads, dim) + z: Matches x + dt_bias: (dim,) or (nheads, dim) + intermediate_ssm_states: Optional buffer of shape (batch, seqlen, nheads, dim, dstate) + or (batch, seqlen, dim, dstate) + Return: + out: shape matches x + """ + has_heads = state.dim() > 3 + if not has_heads: + state = state.unsqueeze(1) + + # Standardize inputs to explicit sequence and head dimensions: (batch, seq_len, nheads, dim) + is_seq_unsq = False + if has_heads: + if x.dim() == 3: # (batch, nheads, dim) -> (batch, 1, nheads, dim) + x = x.unsqueeze(1) + dt = dt.unsqueeze(1) + B = B.unsqueeze(1) + C = C.unsqueeze(1) + if z is not None: + z = z.unsqueeze(1) + is_seq_unsq = True + else: + if x.dim() == 2: # (batch, dim) -> (batch, 1, 1, dim) + x = x.unsqueeze(1).unsqueeze(2) + dt = dt.unsqueeze(1).unsqueeze(2) + B = B.unsqueeze(1).unsqueeze(2) + C = C.unsqueeze(1).unsqueeze(2) + if z is not None: + z = z.unsqueeze(1).unsqueeze(2) + is_seq_unsq = True + elif x.dim() == 3: # (batch, seqlen, dim) -> (batch, seqlen, 1, dim) + x = x.unsqueeze(2) + dt = dt.unsqueeze(2) + B = B.unsqueeze(2) + C = C.unsqueeze(2) + if z is not None: + z = z.unsqueeze(2) + + if A.dim() == 2: + A = A.unsqueeze(0) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + + # Set up Intermediate State standardization + if intermediate_ssm_states is not None: + if not has_heads and intermediate_ssm_states.dim() == 4: + intermediate_ssm_states = intermediate_ssm_states.unsqueeze( + 2 + ) # (batch, seqlen, 1, dim, dstate) + int_state_strides = ( + intermediate_ssm_states.stride(0), + intermediate_ssm_states.stride(1), + intermediate_ssm_states.stride(2), + intermediate_ssm_states.stride(3), + intermediate_ssm_states.stride(4), + ) + else: + intermediate_ssm_states = x # Dummy pointer + int_state_strides = (0, 0, 0, 0, 0) + + batch, seq_len, nheads, dim = x.shape + dstate = state.shape[-1] + ngroups = B.shape[-2] + + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ( + (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0) + ) + + BLOCK_SIZE_M, num_warps = ( + (32, 4) + if dstate <= 16 + else ( + (16, 4) + if dstate <= 32 + else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) + ) + ) + + tie_hdim = ( + A.stride(-1) == 0 + and A.stride(-2) == 0 + and dt.stride(-1) == 0 + and (dt_bias is None or dt_bias.stride(-1) == 0) + ) + + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + intermediate_ssm_states, + batch, + seq_len, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + dt.stride(0), + dt.stride(1), + dt.stride(2), + dt.stride(3), + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0), + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(3), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + *(D.stride(0), D.stride(1)) if D is not None else (0, 0), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + *int_state_strides, + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + + # Revert dimensions back to match original x format + if not has_heads: + out = out.squeeze(2) + if is_seq_unsq: + out = out.squeeze(1) + + return out diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 28e3dde01c4..310a59bde35 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -799,7 +799,7 @@ def flash_decode_and_prefill( assert block_table is not None # Flash attn kernel. - if not is_decode_only: + if max_seqlen_q > 1: q = q.squeeze(1) if getattr(self, "softmax_scale", None) is not None: softmax_scale = self.softmax_scale diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 4ad2e517cfc..3426d83b7b2 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -612,6 +612,44 @@ def set_loss_scale(scale: torch.Tensor): MTPLossAutoScaler.main_loss_backward_scale = scale +def compute_mtp_inference_logits( + hidden_states: Tensor, + mtp_num_layers: int, + output_layer: Callable, + output_weight: Optional[Tensor], + runtime_gather_output: Optional[bool], +) -> tuple: + """Compute MTP logits for inference mode. + + Splits the concatenated hidden states and generates logits for each MTP layer. + + Args: + hidden_states (Tensor): Concatenated hidden states from main + MTP layers. + mtp_num_layers (int): Number of MTP layers. + output_layer (Callable): Output layer method to compute logits. + output_weight (Optional[Tensor]): Optional output weight for shared embeddings. + runtime_gather_output (Optional[bool]): Whether to gather output at runtime. + + Returns: + tuple: (hidden_states, mtp_logits_cache) where hidden_states is the main hidden + states and mtp_logits_cache is a tensor of shape + [mtp_num_layers, batch_size, vocab_size]. + """ + hidden_states_list = torch.chunk(hidden_states, 1 + mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + mtp_inference_logits = [] + for mtp_layer_number in range(mtp_num_layers): + mtp_logits, _ = output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # mtp logits shape [b, 1, vocab size] + mtp_inference_logits.append(mtp_logits.squeeze(1).unsqueeze(0)) + mtp_logits_cache = torch.cat(mtp_inference_logits, dim=0) + return hidden_states, mtp_logits_cache + + def process_mtp_loss( hidden_states: Tensor, labels: Tensor, @@ -998,6 +1036,53 @@ def _postprocess(self, hidden_states: torch.Tensor): return hidden_states + def forward_single_position( + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + embedding: Callable, + attention_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + inference_params=None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + ) -> Tensor: + """Forward for single positions without roll_tensor (speculative decoding). + + Unlike the regular forward which rolls input_ids to get the next token's + embedding, this method directly takes the correct next_token_ids. This is + used in speculative decoding where the correct next token is known after + verification. + + Args: + hidden_states (Tensor): Hidden states at positions of interest [N, B, H]. + next_token_ids (Tensor): The correct next token IDs [B, N]. + position_ids (Tensor): Position IDs for the next tokens [B, N]. + embedding (Callable): The embedding module. + + Returns: + Tensor: MTP hidden states [N, B, H]. + """ + decoder_input = embedding(input_ids=next_token_ids, position_ids=position_ids) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=False, keep_graph=False + ) + hidden_states = self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + return hidden_states + def _checkpointed_forward(self, forward_func, *args, **kwargs): def checkpoint_handler(): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 0c6462d26d1..ec8f1088be1 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -341,7 +341,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): mamba_inference_state_config=mamba_inference_state_config, pg_collection=pg_collection, use_flashinfer_fused_rope=args.use_flashinfer_fused_rope, - materialize_only_last_token_logits=not args.return_log_probs, + materialize_only_last_token_logits=(not args.return_log_probs and args.num_speculative_tokens == 0), track_generated_token_events=args.inference_dynamic_batching_track_generated_token_events, track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, @@ -350,6 +350,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): prefix_caching_coordinator_policy=PrefixCachingCoordinatorPolicy(args.inference_dynamic_batching_prefix_caching_coordinator_policy), metrics_writer=metrics_writer, logging_step_interval=args.inference_logging_step_interval, + num_speculative_tokens=args.num_speculative_tokens, ) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 99c891fc60d..7a2905d7a68 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1795,10 +1795,11 @@ def _add_inference_args(parser): '1) allocate `memory_buffer` in unified memory. ' 'Eventually, additional levels will be included to ' 'control other tensors within the context.') - # TODO(ksanthanam): Clean this up in future PR group.add_argument('--enable-chunked-prefill', dest='enable_chunked_prefill', action='store_true', default=False, help="Enable chunked prefill (disabled by default)") + group.add_argument('--num-speculative-tokens', type=int, default=0, + help='Number of speculative tokens generated during decode') group.add_argument('--inference-dynamic-batching-prefix-caching', dest='inference_dynamic_batching_enable_prefix_caching', action=argparse.BooleanOptionalAction, diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 7ac9c6bada9..bebf574b965 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1502,6 +1502,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('spec', force=True) _set_arg('num_experts', force=True) + _set_arg('mtp_num_layers', force=True) _set_arg('moe_layer_freq', force=True) if getattr(checkpoint_args, 'num_experts', None) is not None: _set_arg('moe_ffn_hidden_size', force=True) diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py index dd34061888e..7e76ce4b7b0 100644 --- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py +++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py @@ -100,6 +100,7 @@ def test_update_decode_only_exact_match(self, metadata_context): expected_decode = torch.arange(4, dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + assert metadata_context.batch_indices_prefill is None assert metadata_context.batch_indices_chunked_prefill is None assert metadata_context.device_decode_prefill is None @@ -124,6 +125,7 @@ def test_update_decode_only_padded(self, metadata_context): [0, 1, -1, -1], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + assert metadata_context.batch_indices_prefill is None assert metadata_context.batch_indices_chunked_prefill is None assert metadata_context.device_decode_prefill is None @@ -144,6 +146,7 @@ def test_update_chunked_enabled_no_prefill_reqs(self, metadata_context): # Should behave exactly like decode-only (chunked logic skipped if real_prefill == 0) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + assert metadata_context.batch_indices_chunked_prefill is None assert metadata_context.batch_indices_prefill is None assert metadata_context.cu_seqlens is None @@ -242,7 +245,7 @@ def test_update_mixed_batch_exact(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [2, 2], dtype=torch.int32, device=metadata_context.device + [2, 30], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) @@ -288,7 +291,7 @@ def test_update_padded_prefill_and_decode(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [1, 1], dtype=torch.int32, device=metadata_context.device + [1, 10], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) @@ -334,7 +337,7 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [1, 2], dtype=torch.int32, device=metadata_context.device + [1, 60], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) @@ -375,7 +378,7 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [2, 2], dtype=torch.int32, device=metadata_context.device + [2, 60], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index fcc6fb5a29e..e16ebaf4353 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -2,6 +2,7 @@ import contextlib import math +from unittest import mock import pytest import torch @@ -60,6 +61,7 @@ def _get_dynamic_context( layer_type_list=None, paused_buffer_size_gb=None, num_cuda_graphs=None, + num_speculative_tokens=0, ): if is_hybrid_model: if layer_type_list is None: @@ -93,6 +95,7 @@ def _get_dynamic_context( ), block_size_tokens=block_size_tokens, max_tokens=max_tokens, + num_speculative_tokens=num_speculative_tokens, mamba_inference_state_config=mamba_inference_state_config, use_flashinfer_fused_rope=None, # default to using flash-infer if available # this is for compatibility with the LTS environment @@ -826,6 +829,10 @@ def test_release_memory_blocks_for_finished_requests(self, is_hybrid_model): dynamic_context.request_to_kv_block_ids[i, 0] = initial_blocks[i] dynamic_context.request_query_lengths[i] = 1 dynamic_context.request_ids[i] = i + dynamic_context.request_last_kv_block_id[i] = initial_blocks[i] + dynamic_context.request_last_kv_block_offset[i] = 0 + dynamic_context.request_kv_block_counts[i] = 1 + dynamic_context.request_in_prefill_status_tensor[i] = 0 if is_hybrid_model: dynamic_context.mamba_conv_states[:, i, :, :].fill_( float(i + 1) @@ -914,9 +921,16 @@ def test_finished_requests_with_multiple_blocks(self, is_hybrid_model): for i in range(3): dynamic_context.request_query_lengths[i] = 1 dynamic_context.request_ids[i] = i + dynamic_context.request_last_kv_block_id[i] = dynamic_context.request_to_kv_block_ids[ + i, dynamic_context.request_kv_block_counts[i] - 1 + ] + dynamic_context.request_last_kv_block_offset[i] = 0 + dynamic_context.request_in_prefill_status_tensor[i] = 0 if is_hybrid_model: dynamic_context.mamba_conv_states[:, i, :, :].fill_(float(i + 1)) dynamic_context.mamba_ssm_states[:, i, :, :, :].fill_(float(i + 1)) + dynamic_context.mamba_metadata.request_to_mamba_state_idx[i] = i + dynamic_context.mamba_metadata.mamba_state_free_slot_count -= 1 # Create an active_requests_mask where all requests are finished active_requests_mask = torch.tensor([0, 0, 0], device=torch.cuda.current_device()) @@ -1404,8 +1418,9 @@ def test_max_requests_less_than_tp_size(self): @rounder_override(64) @pytest.mark.parametrize("is_hybrid_model", [False, True]) @pytest.mark.parametrize("num_cuda_graphs", [-1, 16, 32]) + @pytest.mark.parametrize("num_speculative_tokens", [0, 3]) def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( - self, is_hybrid_model: bool, num_cuda_graphs: int + self, is_hybrid_model: bool, num_cuda_graphs: int, num_speculative_tokens: int ): """The fast path (add_dummy_requests_for_expert_parallel_step) must leave the same observable state as the slow path @@ -1429,10 +1444,12 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( else None ), num_cuda_graphs=num_cuda_graphs, + num_speculative_tokens=num_speculative_tokens, ) smallest = min(ctx.cuda_graph_batch_dimensions_list) N = smallest.decode_req_count + T = smallest.token_count # N * (num_speculative_tokens + 1) assert smallest.prefill_req_count == 0, "smallest graph must be decode-only" # --- slow path (reference) --- @@ -1444,10 +1461,10 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( slow_request_query_lengths = ctx.request_query_lengths[:N].clone() slow_request_kv_length_offsets = ctx.request_kv_length_offsets[:N].clone() slow_request_to_kv_block_ids_col0 = ctx.request_to_kv_block_ids[:N, 0].clone() - slow_token_to_block_idx = ctx.token_to_block_idx[:N].clone() - slow_token_to_local_pos = ctx.token_to_local_position_within_kv_block[:N].clone() + slow_token_to_block_idx = ctx.token_to_block_idx[:T].clone() + slow_token_to_local_pos = ctx.token_to_local_position_within_kv_block[:T].clone() if is_hybrid_model: - slow_token_to_request_idx = ctx.token_to_request_idx[:N].clone() + slow_token_to_request_idx = ctx.token_to_request_idx[:T].clone() slow_mamba = ctx.mamba_metadata.request_to_mamba_state_idx[:N].clone() # --- reset and run fast path --- @@ -1466,13 +1483,13 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( # 3. Token-level state dummy_block_idx = ctx.block_allocator.dummy_block_idx - assert torch.all(ctx.token_to_block_idx[:N] == dummy_block_idx) - assert torch.equal(ctx.token_to_block_idx[:N], slow_token_to_block_idx) - assert torch.equal(ctx.token_to_local_position_within_kv_block[:N], slow_token_to_local_pos) + assert torch.all(ctx.token_to_block_idx[:T] == dummy_block_idx) + assert torch.equal(ctx.token_to_block_idx[:T], slow_token_to_block_idx) + assert torch.equal(ctx.token_to_local_position_within_kv_block[:T], slow_token_to_local_pos) if is_hybrid_model: # 4. token_to_request_idx - assert torch.equal(ctx.token_to_request_idx[:N], slow_token_to_request_idx) + assert torch.equal(ctx.token_to_request_idx[:T], slow_token_to_request_idx) # 5. Mamba state slots allocated (indices may differ, but must be valid and unique) fast_mamba = ctx.mamba_metadata.request_to_mamba_state_idx[:N] @@ -1506,3 +1523,1078 @@ def test_gqa_high_tp_partition_heads(self): # With TP=8 and GQA=2, num_attention_heads_per_partition should be clamped to 1 assert dynamic_context.num_attention_heads_per_partition == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_update_requests_speculative(self): + """Test update_requests correctly interleaves sampled and speculative tokens.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=256, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 2 active decode requests + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + ctx.request_ids[:2] = torch.tensor([10, 11]) + ctx.request_query_lengths[:2] = 1 + ctx.request_kv_length_offsets[:2] = torch.tensor([5, 8]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([5, 8]) + ctx.request_to_kv_block_ids[:2, 0] = torch.tensor([0, 1]) + ctx.request_last_kv_block_id[:2] = torch.tensor([0, 1]) + + active_requests_mask = torch.tensor([1, 1], device='cuda') + new_tokens = torch.tensor([99, 100], device='cuda') # Sampled tokens + new_speculative_tokens = torch.tensor( + [[991, 1001], [992, 1002]], device='cuda' + ) # Spec tokens + + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # Each request generates 1 (sampled) + 2 (speculative) = 3 tokens. + assert ctx.active_token_count == 6 + assert torch.equal( + ctx.request_query_lengths[:2], torch.tensor([3, 3], dtype=torch.int32, device='cuda') + ) + assert torch.equal( + ctx.request_kv_length_offsets[:2], + torch.tensor([6, 9], dtype=torch.int32, device='cuda'), + ) + + # Check interleaving: [sampled_1, spec1_1, spec2_1, sampled_2, spec1_2, spec2_2] + expected_tokens = torch.tensor([99, 991, 992, 100, 1001, 1002], device='cuda') + assert torch.equal(ctx.token_to_input_ids[:6], expected_tokens) + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_boundary_crossing(self): + """Test token block assignment when speculative tokens cross a KV block boundary.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=1024, + buffer_size_gb=0.1, + block_size_tokens=256, # FA2-compatible block size to force boundary crossing + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 1 active decode request + ctx.total_request_count = 1 + ctx.paused_request_count = 0 + ctx.active_token_count = 1 + + ctx.request_ids[0] = 10 + ctx.request_query_lengths[0] = 1 + ctx.request_kv_block_counts[0] = 1 + + # Length is 254, meaning existing tokens are at indices 0..253. + # The last inserted token was at offset 253. + # Adding 3 tokens places them at offsets 254, 255, and 256 (crosses block size of 256). + ctx.request_kv_length_offsets[0] = 254 + ctx.request_last_kv_block_offset[0] = 253 + + # Allocate one initial block manually + blocks = ctx.block_allocator.allocate_memory_blocks(1) + first_block = blocks[0] + ctx.request_to_kv_block_ids[0, 0] = first_block + ctx.request_last_kv_block_id[0] = first_block + + active_requests_mask = torch.tensor([1], device='cuda') + new_tokens = torch.tensor([50], device='cuda') + new_speculative_tokens = torch.tensor([[51], [52]], device='cuda') + + # Run update_requests natively. It will automatically: + # 1. Detect the boundary crossing and pause the request. + # 2. Clone the prev_last_block_ids internally. + # 3. Resume the request, allocating the new block. + # 4. Map the 3 new tokens across the boundary. + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # Verify a new block was natively allocated by the resume logic + assert ctx.request_kv_block_counts[0] == 2 + second_block = ctx.request_to_kv_block_ids[0, 1] + assert second_block != -1 + assert second_block != first_block + + # Expected token mapping for the 3 generated tokens (sampled, spec1, spec2) + # Token 0 (offset 2) -> first_block + # Token 1 (offset 3) -> first_block + # Token 2 (offset 4) -> second_block + expected_blocks = torch.tensor( + [first_block, first_block, second_block], dtype=torch.int, device='cuda' + ) + + assert torch.equal(ctx.token_to_block_idx[:3], expected_blocks) + + @pytest.mark.internal + @rounder_override(64) + def test_paused_speculative_tokens_tracking(self): + """ + Test that speculative tokens are correctly saved and concatenated + when requests are temporarily paused. + """ + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=1024, + buffer_size_gb=0.1, + block_size_tokens=256, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 2 active requests. Request 0 is about to overflow its block. + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + ctx.request_ids[:2] = torch.tensor([10, 11]) + ctx.request_query_lengths[:2] = 1 + + # Request 0 is at offset 254. Adding 1 sampled + 2 spec = 3 tokens will push it to 257, + # which is >= block_size_tokens (256). It will require a new block. + # Request 1 is at offset 5. It will not require a new block. + ctx.request_kv_length_offsets[:2] = torch.tensor([254, 5]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([254, 5]) + ctx.request_kv_block_counts[:2] = 1 + + # Allocate blocks + blocks = ctx.block_allocator.allocate_memory_blocks(2) + ctx.request_to_kv_block_ids[0, 0] = blocks[0] + ctx.request_to_kv_block_ids[1, 0] = blocks[1] + ctx.request_last_kv_block_id[:2] = blocks + + # Force the allocator to have no available blocks. + # This guarantees request 0 stays paused and cannot immediately resume. + ctx.block_allocator.total_avail = 0 + ctx.block_allocator.paused_count = 100 # Ensure it doesn't get completely evicted either + + active_requests_mask = torch.tensor([1, 1], device='cuda') + new_tokens = torch.tensor([99, 100], device='cuda') # Sampled + new_speculative_tokens = torch.tensor( + [[991, 1001], [992, 1002]], device='cuda' + ) # Speculative + + # In update_requests, request 0 will be paused to allocate a new block. + # Since total_avail is 0, it will stay paused and its tokens will be cached. + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # Verify paused state was populated correctly + assert ctx.paused_tokens is not None + assert ctx.paused_speculative_tokens is not None + + # Request 0 was the one paused, so its tokens should be shifted to + # index 0 of the paused tensors. + assert ctx.paused_request_count == 1 + assert ctx.total_request_count == 2 + + assert ctx.paused_tokens[0].item() == 99 + assert torch.equal( + ctx.paused_speculative_tokens[:, 0], torch.tensor([991, 992], device='cuda') + ) + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_tokens_less_than_block_size_assert(self): + self._setup_model_parallel_group(1, 1) + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=256, + num_speculative_tokens=256, + unified_memory_level=0, + ) + with pytest.raises( + AssertionError, match="num_speculative_tokens.*must be < block_size_tokens" + ): + DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + @pytest.mark.internal + @rounder_override(64) + def test_swap_book_keeping_tensors_with_speculative_tokens(self): + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=256, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + ctx.request_ids[:2] = torch.tensor([10, 11]) + next_tokens = torch.tensor([99, 100], device='cuda') + new_speculative_tokens = torch.tensor([[991, 1001], [992, 1002]], device='cuda') + + ctx._swap_book_keeping_tensors( + src_idxs=torch.tensor([0]), + dst_idxs=torch.tensor([1]), + next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + assert torch.equal(ctx.request_ids[:2], torch.tensor([11, 10], device='cuda')) + assert torch.equal(next_tokens[:2], torch.tensor([100, 99], device='cuda')) + assert torch.equal( + new_speculative_tokens[:, :2], torch.tensor([[1001, 991], [1002, 992]], device='cuda') + ) + + @pytest.mark.internal + @rounder_override(64) + def test_update_requests_with_finished_requests_and_speculative_tokens(self): + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 3 active requests: req0 (active), req1 (finished), req2 (active) + ctx.total_request_count = 3 + ctx.paused_request_count = 0 + ctx.active_token_count = 3 + ctx.request_ids[:3] = torch.tensor([10, 11, 12]) + ctx.request_query_lengths[:3] = 1 + ctx.request_kv_length_offsets[:3] = torch.tensor([5, 8, 12]) + ctx.request_last_kv_block_offset[:3] = torch.tensor([5, 8, 12]) + ctx.request_to_kv_block_ids[:3, 0] = torch.tensor([0, 1, 2]) + ctx.request_last_kv_block_id[:3] = torch.tensor([0, 1, 2]) + ctx.request_kv_block_counts[:3] = 1 + + active_requests_mask = torch.tensor([1, 0, 1], device='cuda') + new_tokens = torch.tensor([99, 100, 101], device='cuda') + new_speculative_tokens = torch.tensor([[991, 1001, 1011], [992, 1002, 1012]], device='cuda') + + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # req1 is finished. req2 moves to req1's position. + assert ctx.total_request_count == 2 + assert torch.equal( + ctx.request_ids[:2], torch.tensor([10, 12], device='cuda', dtype=torch.int32) + ) + + # Check interleaving for req0 and req2 + # req0: [99, 991, 992] + # req2: [101, 1011, 1012] + expected_tokens = torch.tensor([99, 991, 992, 101, 1011, 1012], device='cuda') + assert torch.equal(ctx.token_to_input_ids[:6], expected_tokens) + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_speculative_offset_math(self): + """ + Test that the active_token_count is correctly adjusted by chunked_prefill_offset + when a chunked prefill request continues in a speculative decoding setup. + """ + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.05, + block_size_tokens=128, + max_requests=256, + max_tokens=256, + num_speculative_tokens=3, # 3 spec tokens -> offset = 4 + enable_chunked_prefill=True, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + ctx.reset_tensors() + + # Setup a request that is already mid-chunked-prefill + ctx.total_request_count = 1 + ctx.chunked_prefill_request_id = 42 + ctx.request_ids[0] = 42 + + # Simulate active tokens from the previous step. + # Normally, the previous step generated a dummy token + spec tokens that + # need to be overwritten. + initial_active_tokens = 100 + ctx.active_token_count = initial_active_tokens + + req = DynamicInferenceRequest( + request_id=42, + prompt_tokens=torch.arange(0, 50, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10), + ) + # Mark as continuing chunked prefill + req.finished_chunk_token_count = 100 + + # Add the next chunk + chunk_length = 50 + ctx.add_request(req, chunk_length=chunk_length) + + # The new active token count should be: + # initial (100) - chunked_prefill_offset (1 + 3 = 4) + chunk_length (50) = 146 + expected_active_tokens = ( + initial_active_tokens - (1 + ctx.num_speculative_tokens) + chunk_length + ) + + assert ctx.active_token_count == expected_active_tokens + assert ( + ctx.request_output_lengths[0].item() + == req.finished_chunk_token_count + + chunk_length + + req.sampling_params.num_tokens_to_generate + ) + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_swap_with_speculative_tokens(self): + """Test that swapping a chunked prefill request to the end of the buffer + correctly brings along the 2D speculative tokens for the other decode requests. + """ + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=2, + enable_chunked_prefill=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 2 active requests in the WRONG order (violating the invariant) + # Index 0: Chunked Prefill Request (ID 42) + # Index 1: Standard Decode Request (ID 99) + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + + ctx.chunked_prefill_request_id = 42 + ctx.request_ids[:2] = torch.tensor([42, 99]) + + # Status: 1 = Prefill, 0 = Decode + ctx.request_in_prefill_status_tensor[:2] = torch.tensor([1, 0]) + ctx.request_query_lengths[:2] = 1 + ctx.request_kv_length_offsets[:2] = torch.tensor([10, 20]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([10, 20]) + ctx.request_to_kv_block_ids[:2, 0] = torch.tensor([0, 1]) + ctx.request_last_kv_block_id[:2] = torch.tensor([0, 1]) + ctx.request_kv_block_counts[:2] = 1 + + active_requests_mask = torch.tensor([1, 1], device='cuda') + + # New base tokens: [100 (for prefill), 200 (for decode)] + new_tokens = torch.tensor([100, 200], device='cuda') + + # New spec tokens: Col 0 for prefill (dummy), Col 1 for decode (real draft tokens) + new_speculative_tokens = torch.tensor([[101, 201], [102, 202]], device='cuda') + + # Trigger update_requests. + # It must detect ID 42 is at index 0, and swap it with index 1. + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # 1. Verify the IDs were swapped successfully + assert torch.equal( + ctx.request_ids[:2], torch.tensor([99, 42], dtype=torch.int32, device='cuda') + ) + + # 2. Verify the Decode request (now at Index 0) correctly flattened its + # base token (200) AND its specific speculative tokens (201, 202). + # 3. Verify the Prefill request (now at Index 1) flattened its tokens (100, 101, 102). + expected_flattened_tokens = torch.tensor( + [200, 201, 202, 100, 101, 102], # Decode request (ID 99) # Prefill request (ID 42) + device='cuda', + ) + + assert torch.equal( + ctx.token_to_input_ids[:6], expected_flattened_tokens + ), "Speculative tokens were not correctly swapped alongside the chunked prefill request!" + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_with_prefix_caching_shared_blocks(self): + """Test that prefix caching correctly shares blocks when speculative decoding is enabled.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + num_speculative_tokens=2, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 3, device='cuda') + + # First request registers blocks. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + first_blocks = [ctx.request_to_kv_block_ids[0][i].item() for i in range(3)] + avail_after_first = ctx.block_allocator.total_avail + + # Second request with same prefix should share all blocks. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + second_blocks = [ctx.request_to_kv_block_ids[1][i].item() for i in range(3)] + + # Blocks should be shared (same IDs, no pool consumption). + assert first_blocks == second_blocks + assert ctx.block_allocator.total_avail == avail_after_first + + # Ref counts should be 2. + for bid in first_blocks: + assert ctx.block_allocator.block_ref_counts[bid].item() == 2 + + # Second request should skip prefix tokens (query_length == 1 for full match). + assert ctx.request_query_lengths[1].item() == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_with_prefix_caching_kv_offset(self): + """Test that KV offset accounts for prefix skip when spec decoding is enabled.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + num_speculative_tokens=3, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # First request. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + # Second request with same prefix: should have kv_offset = prefix_skip_tokens. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # Full match: prefix_skip = min(2 * bs, 2*bs - 1) = 2*bs - 1 + expected_skip = 2 * bs - 1 + assert ctx.request_kv_length_offsets[1].item() == expected_skip + assert ctx.request_query_lengths[1].item() == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_update_then_release_with_prefix_caching(self): + """Test that update_requests with spec tokens + block release respects ref counts.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=4, + num_speculative_tokens=2, + enable_prefix_caching=True, + unified_memory_level=0, + max_requests=512, + max_tokens=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # Two requests sharing the same prefix. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + shared_blocks = [ctx.request_to_kv_block_ids[0][i].item() for i in range(2)] + + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # Verify initial ref counts are 2. + for bid in shared_blocks: + assert ctx.block_allocator.block_ref_counts[bid].item() == 2 + + # Release one request. Ref counts should decrement to 1. + ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) + for bid in shared_blocks: + assert ctx.block_allocator.block_ref_counts[bid].item() == 1 + + # Blocks should still be discoverable via hash map. + for bid in shared_blocks: + h = ctx.block_allocator.block_hashes[bid].item() + assert h in ctx.block_allocator.hash_to_block_id + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_boundary_crossing_with_prefix_caching(self): + """Test block boundary crossing from speculative tokens does not corrupt shared blocks.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=4, + num_speculative_tokens=2, + enable_prefix_caching=True, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # Request 1: adds prefix blocks. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + shared_b0 = ctx.request_to_kv_block_ids[0][0].item() + shared_b1 = ctx.request_to_kv_block_ids[0][1].item() + + # Request 2: shares prefix, gets its own decode block. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # Both requests share the same 2 blocks. + assert ctx.request_to_kv_block_ids[1][0].item() == shared_b0 + assert ctx.request_to_kv_block_ids[1][1].item() == shared_b1 + + # Set up request 0 for decode at offset that will cross block boundary. + # Place at offset (block_size - 1) in last block so adding 3 tokens crosses. + ctx.request_kv_length_offsets[0] = bs * 2 - 1 # one token from end of block 1 + # The local offset of index 6 is (6 % bs) + ctx.request_last_kv_block_offset[0] = bs - 2 + ctx.request_query_lengths[0] = 1 + ctx.request_in_prefill_status_tensor[0] = 0 + ctx.active_token_count = 2 + + active_mask = torch.tensor([1, 1], device='cuda', dtype=torch.int32) + new_tokens = torch.tensor([50, 50], device='cuda') + new_spec = torch.tensor([[51, 51], [52, 52]], device='cuda') + + ctx.update_requests( + active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec + ) + + # A new block should have been allocated for the boundary crossing. + assert ctx.request_kv_block_counts[0] == 3 + new_block = ctx.request_to_kv_block_ids[0][2].item() + assert new_block != -1 + assert new_block != shared_b0 + assert new_block != shared_b1 + + # Shared blocks should remain intact with ref count 2. + assert ctx.block_allocator.block_ref_counts[shared_b0].item() == 2 + assert ctx.block_allocator.block_ref_counts[shared_b1].item() == 2 + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_speculative_offset_with_prefix_caching(self): + """Test chunked prefill offset math combines correctly with prefix caching and spec decoding.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + max_requests=256, + max_tokens=256, + num_speculative_tokens=2, + enable_chunked_prefill=True, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + ctx.reset_tensors() + + bs = ctx.block_size_tokens + + # First request: register prefix blocks (bs * 3 tokens = 3 complete blocks). + first_prompt = torch.arange(bs * 3, device='cuda') + req_first = DynamicInferenceRequest( + request_id=1, + prompt_tokens=first_prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req_first) + + # Second request: same prefix, continuing chunked prefill. + # Simulate that this request already processed bs tokens in a prior chunk. + ctx.chunked_prefill_request_id = 42 + ctx.request_ids[ctx.total_request_count] = 42 + + # Manually set up as if request is mid-chunked-prefill. + ctx.total_request_count += 1 + current_id = ctx.total_request_count - 1 + ctx.request_ids[current_id] = 42 + + initial_active_tokens = ctx.active_token_count + 1 + ctx.num_speculative_tokens + ctx.active_token_count = initial_active_tokens + + req2 = DynamicInferenceRequest( + request_id=42, + prompt_tokens=first_prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + req2.finished_chunk_token_count = bs # Already processed 1 block + + chunk_length = bs * 2 # Process 2 more blocks + ctx.add_request(req2, chunk_length=chunk_length) + + # Prefix match should find 2 matching blocks (blocks 1 and 2 from req_first). + # The chunked_prefill_offset (1 + num_speculative_tokens = 3) should be subtracted. + chunked_prefill_offset = 1 + ctx.num_speculative_tokens + # With prefix match: 2 blocks matched -> skip (2*bs - 1) tokens + # effective_chunk_length = chunk_length - prefix_skip_tokens + (_, _, _, _, prefix_skip, eff_chunk) = ctx._compute_prefix_match(req2, chunk_length) + expected_active = initial_active_tokens - chunked_prefill_offset + eff_chunk + assert ctx.active_token_count == expected_active + + @pytest.mark.internal + @rounder_override(64) + def test_prefix_caching_check_availability_with_speculative(self): + """Test check_availability accounts for prefix match when spec decoding is enabled.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=3, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # First request registers blocks. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + # Exhaust the remaining pool. + while ctx.block_allocator.total_avail > 0: + ctx.block_allocator.allocate_memory_blocks(1) + + # A new request with the same prefix should still be schedulable + # because prefix matching means 0 new blocks are needed from pool. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + _, _, kv_available = ctx.check_availability(req2) + assert kv_available, "Matched blocks should not require pool allocation" + + @pytest.mark.internal + @rounder_override(64) + def test_prefix_match_exact_block_boundary(self): + """Test prefix matching when the shared prefix is an exact multiple of the block size.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=16, + enable_prefix_caching=True, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + + # req1: 32 tokens (exactly 2 complete blocks) + prompt1 = torch.arange(bs * 2, device='cuda') + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt1, + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + # req2: 35 tokens (first 32 tokens match req1) + prompt2 = torch.arange(bs * 2 + 3, device='cuda') + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt2, + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # req2 should have 3 blocks total + assert ctx.request_kv_block_counts[1].item() == 3 + + # The first 2 blocks should be shared + assert ctx.request_to_kv_block_ids[1, 0].item() == ctx.request_to_kv_block_ids[0, 0].item() + assert ctx.request_to_kv_block_ids[1, 1].item() == ctx.request_to_kv_block_ids[0, 1].item() + + # The 3rd block should be a newly allocated pool block + assert ctx.request_to_kv_block_ids[1, 2].item() != ctx.request_to_kv_block_ids[0, 1].item() + + # The offset points to the last token (index 34). In the 3rd block (indices 32-47), 34 is at offset 2. + assert ctx.request_last_kv_block_offset[1].item() == 2 + + # Effective query length should be 3 (35 total - 32 skipped) + assert ctx.request_query_lengths[1].item() == 3 + + @pytest.mark.internal + @rounder_override(64) + def test_eviction_with_shared_prefix_blocks(self): + """Test that evicting a request drops ref counts correctly without destroying shared blocks.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=16, + enable_prefix_caching=True, + unified_memory_level=0, + paused_buffer_size_gb=0.0, # 0 paused capacity to force immediate eviction + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # Add req1 and req2 with identical prompts + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + shared_b0 = ctx.request_to_kv_block_ids[0, 0].item() + shared_b1 = ctx.request_to_kv_block_ids[0, 1].item() + + # Both blocks should be safely shared with ref count 2 + assert ctx.block_allocator.block_ref_counts[shared_b0].item() == 2 + + # Mock the state to make req1 paused and req2 active + ctx.paused_request_count = 1 + ctx.total_request_count = 2 + ctx.request_ids[0] = 1 + ctx.request_ids[1] = 2 + ctx.request_kv_block_counts[0] = 2 + ctx.request_kv_block_counts[1] = 2 + + # Exhaust the active block allocator + ctx.block_allocator.total_avail = 0 + + # Trigger the eviction logic + # next_tokens must be sized to total_request_count (1 paused + 1 active = 2) + next_tokens = torch.tensor([50, 51], device='cuda') + evicted_ids = ctx.evict_overflow_paused_requests( + active_request_count=1, next_tokens=next_tokens + ) + + # req1 should be successfully evicted + assert evicted_ids is not None + assert evicted_ids[0].item() == 1 + + # req2 remains active, so the shared blocks should drop to a ref count of 1 + assert ctx.block_allocator.block_ref_counts[shared_b0].item() == 1 + assert ctx.block_allocator.block_ref_counts[shared_b1].item() == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_oom_during_speculative_boundary_crossing(self): + """Test boundary crossing with speculative tokens pauses the request gracefully when KV cache is full, keeping other requests active.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.1, + block_size_tokens=16, + num_speculative_tokens=2, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + bs = ctx.block_size_tokens + + # Setup 2 active requests. + # Request 0 is exactly 1 token away from its boundary (will OOM). + # Request 1 has plenty of space (will remain active). + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + + ctx.request_ids[:2] = torch.tensor([10, 11], device='cuda') + ctx.request_query_lengths[:2] = 1 + ctx.request_kv_block_counts[:2] = 1 + + # Request 0 offset is 15. Adding 1 sampled + 2 spec = 3 tokens crosses the boundary (16). + # Request 1 offset is 5. Adding 3 tokens = 8 (does not cross). + ctx.request_kv_length_offsets[:2] = torch.tensor( + [bs - 1, 5], device='cuda', dtype=torch.int32 + ) + ctx.request_last_kv_block_offset[:2] = torch.tensor( + [bs - 1, 5], device='cuda', dtype=torch.int32 + ) + + blocks = ctx.block_allocator.allocate_memory_blocks(2) + ctx.request_to_kv_block_ids[0, 0] = blocks[0] + ctx.request_to_kv_block_ids[1, 0] = blocks[1] + ctx.request_last_kv_block_id[:2] = blocks + + # Force OOM condition (no blocks left in the active pool) + ctx.block_allocator.total_avail = 0 + ctx.block_allocator.paused_count = 100 # Prevent immediate eviction out of the system + + active_mask = torch.tensor([1, 1], device='cuda', dtype=torch.int32) + new_tokens = torch.tensor([99, 88], device='cuda') + new_spec = torch.tensor([[100, 200], [101, 201]], device='cuda') + + # Run update requests + ctx.update_requests( + active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec + ) + + # Request 0 should detect OOM, fail to allocate a new block, and pause. + # Request 1 remains active, so active_request_count goes 2 -> 1, avoiding the deadlock assert. + assert ctx.paused_request_count == 1 + assert ctx.total_request_count == 2 + + # Request 1 generated 3 tokens (1 sampled + 2 spec) + assert ctx.active_token_count == 3 + + # Tokens must be cached in the paused buffers so Request 0 can resume cleanly later + assert ctx.paused_tokens is not None + assert ctx.paused_tokens[0].item() == 99 + + assert ctx.paused_speculative_tokens is not None + assert ctx.paused_speculative_tokens[0, 0].item() == 100 + assert ctx.paused_speculative_tokens[1, 0].item() == 101 + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_meets_prefix_caching(self): + """Test that chunks in a chunked-prefill pipeline properly hit the prefix cache mid-flight.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + enable_chunked_prefill=True, + enable_prefix_caching=True, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(128, device='cuda') + + # Cache req1 (fully processed) + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + req1_blocks = [ctx.request_to_kv_block_ids[0, i].item() for i in range(4)] + + # Start chunked prefill for req2. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + + # Add the first chunk (64 tokens) + req2.finished_chunk_token_count = 0 + ctx.chunked_prefill_request_id = 2 + ctx.add_request(req2, chunk_length=64) + + # Assert the first chunk perfectly matched the first 2 cached blocks + assert ctx.request_to_kv_block_ids[1, 0].item() == req1_blocks[0] + assert ctx.request_to_kv_block_ids[1, 1].item() == req1_blocks[1] + assert ctx.request_kv_block_counts[1].item() == 2 + + # Simulate update_requests completing the chunk + ctx.active_token_count += 1 + ctx.request_in_prefill_status_tensor[1] = 0 + + # Add the second chunk (64 tokens) + req2.finished_chunk_token_count = 64 + ctx.add_request(req2, chunk_length=64) + + # It should correctly discover the remaining prefix blocks despite being mid-prefill + assert ctx.request_to_kv_block_ids[1, 2].item() == req1_blocks[2] + assert ctx.request_to_kv_block_ids[1, 3].item() == req1_blocks[3] + assert ctx.request_kv_block_counts[1].item() == 4 + + # Verify block references updated appropriately + assert ctx.block_allocator.block_ref_counts[req1_blocks[2]].item() == 2 + assert ctx.block_allocator.block_ref_counts[req1_blocks[3]].item() == 2 diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 6a07d7a35ae..4117ef39b92 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from functools import partial from typing import Dict, List, Optional, Tuple +from unittest import mock import pytest import torch @@ -41,6 +42,7 @@ get_gpt_layer_local_spec, get_gpt_layer_with_inference_spec, get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec @@ -124,6 +126,7 @@ class DynamicEngineTestConfig: materialize_only_last_token_logits: bool = True skip_prompt_log_probs: bool = False enable_chunked_prefill: bool = False + enable_prefix_caching: bool = False cuda_graph_scope: List[CudaGraphScope] = field( default_factory=lambda: [CudaGraphScope.full_iteration_inference] ) @@ -139,19 +142,22 @@ class DynamicEngineTestConfig: kv_cache_management_mode: str = "persist" static_kv_memory_pointers: bool = True track_generated_token_events: bool = False - - fp8: bool = False + num_speculative_tokens: int = 0 def __post_init__(self): # Compute max_sequence_length. - assert self.max_sequence_length is None - assert self.num_tokens_to_generate is None or self.num_tokens_total is None - if self.num_tokens_to_generate is not None: - self.max_sequence_length = self.max_prompt_length + self.num_tokens_to_generate - else: - assert self.num_tokens_total is not None - self.max_sequence_length = self.num_tokens_total + if self.max_sequence_length is None: + assert self.num_tokens_to_generate is None or self.num_tokens_total is None + if self.num_tokens_to_generate is not None: + self.max_sequence_length = ( + self.max_prompt_length + + self.num_tokens_to_generate + + self.num_speculative_tokens + ) + else: + assert self.num_tokens_total is not None + self.max_sequence_length = self.num_tokens_total + self.num_speculative_tokens # Default paused buffer size. if self.context_paused_buffer_size_gb is None: @@ -259,10 +265,12 @@ def _build_inference_context( ), static_kv_memory_pointers=test_config.static_kv_memory_pointers, enable_chunked_prefill=test_config.enable_chunked_prefill, + enable_prefix_caching=test_config.enable_prefix_caching, use_flashinfer_fused_rope=None, # default to using flash-infer if available # this is for compatibility with the LTS environment unified_memory_level=0, # unit tests currently broken with UVM track_generated_token_events=test_config.track_generated_token_events, + num_speculative_tokens=test_config.num_speculative_tokens, ), ) @@ -296,6 +304,7 @@ def _build_test_env(cls, test_config): transformer_config = TransformerConfig( params_dtype=torch.bfloat16, num_layers=4, + mtp_num_layers=test_config.num_speculative_tokens, hidden_size=128 if test_config.fp8 else 32, num_attention_heads=4, use_cpu_initialization=True, @@ -337,6 +346,14 @@ def _build_test_env(cls, test_config): elif test_config.transformer_impl == "inference_optimized": layer_spec = get_gpt_layer_with_inference_spec() + # MTP block spec (needed for speculative decoding). + mtp_block_spec = None + if test_config.num_speculative_tokens > 0: + use_te = test_config.fp8 or test_config.transformer_impl == "transformer_engine" + mtp_block_spec = get_gpt_mtp_block_spec( + config=transformer_config, spec=layer_spec, use_transformer_engine=use_te + ) + # GPT model. model = GPTModel( config=transformer_config, @@ -346,6 +363,7 @@ def _build_test_env(cls, test_config): parallel_output=True, pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), + mtp_block_spec=mtp_block_spec, ).cuda() elif test_config.model_provider == "mamba": pp_size = test_config.pipeline_model_parallel_size @@ -355,6 +373,7 @@ def _build_test_env(cls, test_config): num_layers=( 3 if pp_size == 1 else 6 ), # 1 Mamba layer, 1 attention layer, 1 MLP layer + mtp_num_layers=test_config.num_speculative_tokens, hidden_size=256, # The Mamba layer places several constraints on this mamba_num_heads=16, num_attention_heads=16, @@ -1053,7 +1072,7 @@ def test_parallel_inference( if tp_size == 1 and pp_size == 1 and ep_size == 1: pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") elif not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") + Utils.initialize_distributed() world_size = torch.distributed.get_world_size() min_world_size = tp_size * pp_size * ep_size if world_size < min_world_size: @@ -1074,10 +1093,6 @@ def test_parallel_inference( "when tp_size > 1." ) ) - if model_provider == "mamba": - pytest.skip( - reason="Mamba model is not supported with the inference optimized transformer." - ) env = self._run_test( model_provider=model_provider, @@ -1999,3 +2014,689 @@ def test_staleness_tracking(self, use_checkpoint): assert (record[-1].policy_staleness == pre_ps + 1).all() assert (record[-1].kv_cache_staleness == 0).all() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_early_termination(self): + """Test that speculative decoding handles premature request termination safely + (e.g. hitting max_sequence_length mid-speculative-batch).""" + + # Set max_sequence_length tight so it terminates during a speculative step + test_config = DynamicEngineTestConfig( + num_requests=1, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=3, # Prompt (4) + Gen (3) = 7 + max_sequence_length=7, # Will force termination after 3 tokens + model_provider="gpt", + num_speculative_tokens=3, + materialize_only_last_token_logits=False, + ) + + env = self._build_test_env(test_config) + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward to return deterministic data so speculative tokens are always accepted + hidden_size = unwrapped_model.config.hidden_size + + def mock_mtp_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + + base_logits = torch.zeros( + tokens.size(0), + tokens.size(1), + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + base_logits[:, :, 0] = 100.0 # High probability for token 0 + + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + tokens.size(1), 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 + ) + return base_logits + + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits[:, :, 0] = 100.0 # High probability for token 0 + return hidden_states, logits + + unwrapped_model.forward = mock_mtp_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step + + env.engine._add_request(env.requests[0]) + env.engine.schedule_waiting_requests() + + # Step engine until finished naturally + # This allows the bookkeeping logic to gracefully truncate the + # speculative tokens to the max_sequence_length boundary. + while env.engine.has_unfinished_requests(): + env.engine.step_modern() + + assert env.requests[0].status == Status.COMPLETED + + # It should trim the output to the max_sequence_length boundary + # Prompt was 4, Max was 7, so it should have generated exactly 3 tokens. + assert len(env.requests[0].generated_tokens) == 3 + + # Validate the engine's tracking state is clean + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 + + @pytest.mark.internal + @torch.inference_mode() + def test_speculative_block_boundary_crossing(self): + """Test to verify KV cache block boundary crossing logic. + + When a request fills exactly one block and speculative decoding generates + multiple tokens, the first new token shouldn't incorrectly overwrite the old block. + """ + test_config = DynamicEngineTestConfig( + num_requests=1, + min_prompt_length=256, + max_prompt_length=256, + num_tokens_to_generate=3, + num_speculative_tokens=2, + context_block_size_tokens=256, # Exactly matches prompt length + context_max_requests=16, + model_provider="gpt", + materialize_only_last_token_logits=False, + use_fixed_output_lengths=True, + ) + env = self._build_test_env(test_config) + + req = env.requests[0] + req.sampling_params.num_tokens_to_generate = 3 + env.engine._add_request(req) + env.engine.schedule_waiting_requests() + + # Step 1: Prefill. Processes the 4 prompt tokens. + # At the end of this step, `update_requests` prepares the token indices for Step 2. + # It assigns block indices for the 3 upcoming tokens (1 base + 2 spec). + env.engine.step_modern() + + context = env.engine.context + + # The request has 2 blocks allocated now (1 for prompt, 1 for the new 3 tokens) + assigned_blocks = context.request_to_kv_block_ids[0] + first_block = assigned_blocks[0].item() + second_block = assigned_blocks[1].item() + + # The active_token_count for the next step should be 3 + assert context.active_token_count == 3 + + # Check which blocks the 3 new tokens are assigned to. + # Because the prompt exactly filled the first block, ALL 3 new tokens + # MUST go to the second block. + token_blocks = context.token_to_block_idx[: context.active_token_count].tolist() + + assert token_blocks == [ + second_block, + second_block, + second_block, + ], f"Expected all new tokens to go to block {second_block}, but got {token_blocks}." + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_stop_word_hit(self): + """Test that if an accepted speculative token completes a stop word, + the request correctly triggers the stop logic without crashing.""" + + test_config = DynamicEngineTestConfig( + num_requests=0, # We will manually add our request cleanly + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=10, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", + ) + env = self._build_test_env(test_config) + + unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size + + # Mock forward to deterministically output an ascending sequence (1->2->3...) + def mock_deterministic_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 + ) + return base_logits + + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict next_token_ids + 1 (continuing the ascending sequence) + pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, pred_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + + unwrapped_model.forward = mock_deterministic_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step + + # Add the request formally to ensure all internal state tensors align + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99), + ) + + # Inject the parsed stop word IDs + tracked_req = env.engine.get_request(0) + tracked_req.stop_word_ids = [[8, 9]] # The sequence will generate 5, 6, 7, 8, 9, ... + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + # Retrieve the finalized request from the engine's output + finished_req = finished_records[0].merge() + + assert finished_req.status == Status.COMPLETED + # Since num_tokens_to_generate=10, output should stop early at ~7 tokens + assert len(finished_req.generated_tokens) < 10 + # Verify the stop word was actually generated and caused the termination + token_pairs = [ + finished_req.generated_tokens[i : i + 2] + for i in range(len(finished_req.generated_tokens) - 1) + ] + assert [8, 9] in token_pairs + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_long_stop_word_hit(self): + """Test that if an accepted speculative token completes a long stop word + (length > num_speculative_tokens), it is correctly detected.""" + + test_config = DynamicEngineTestConfig( + num_requests=0, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=10, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", + ) + env = self._build_test_env(test_config) + + unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size + + # Mock forward to deterministically output an ascending sequence + def mock_deterministic_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 + ) + return base_logits + + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict next_token_ids + 1 (continuing the ascending sequence) + pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, pred_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + + unwrapped_model.forward = mock_deterministic_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step + + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99), + ) + + # Stop word length 3 > num_speculative_tokens (2) + tracked_req = env.engine.get_request(0) + tracked_req.stop_word_ids = [[7, 8, 9]] + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + finished_req = finished_records[0].merge() + + assert finished_req.status == Status.COMPLETED + assert len(finished_req.generated_tokens) < 10 + token_triplets = [ + finished_req.generated_tokens[i : i + 3] + for i in range(len(finished_req.generated_tokens) - 2) + ] + assert [7, 8, 9] in token_triplets + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_stop_word_truncates_trailing_tokens(self): + """Test that when a stop word lands in the middle of speculative tokens, + the extra tokens generated after the stop word are removed. + + With num_speculative_tokens=2, each step produces up to 3 tokens + (1 base + 2 speculative). If the stop word is [6] and the engine + generates [5, 6, 7] in one step, token 7 must be truncated so the + output ends with the stop word [6].""" + + test_config = DynamicEngineTestConfig( + num_requests=0, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=10, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", + ) + env = self._build_test_env(test_config) + + unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size + + # Mock forward to deterministically output an ascending sequence (1->2->3...) + def mock_deterministic_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 + ) + return base_logits + + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict next_token_ids + 1 (continuing the ascending sequence) + pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, pred_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + + unwrapped_model.forward = mock_deterministic_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step + + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99), + ) + + # Stop word [6] will land in the middle of a speculative batch [5, 6, 7]. + # Token 7 should be truncated from the output. + tracked_req = env.engine.get_request(0) + tracked_req.stop_word_ids = [[6]] + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + finished_req = finished_records[0].merge() + + assert finished_req.status == Status.COMPLETED + # The output should end exactly at the stop word, with no trailing tokens. + assert finished_req.generated_tokens[-1] == 6, ( + f"Expected last token to be stop word 6, " + f"got {finished_req.generated_tokens[-1]}. " + f"Trailing tokens after stop word were not truncated. " + f"Full output: {finished_req.generated_tokens}" + ) + # Verify no tokens after the stop word exist + assert 7 not in finished_req.generated_tokens, ( + f"Token 7 should have been truncated after stop word 6. " + f"Full output: {finished_req.generated_tokens}" + ) + + @pytest.mark.internal + @torch.inference_mode() + def test_speculative_sequence_length_double_counting(self): + """Test to verify active_sequence_lengths is not double-counted. + + If active sequence length is double-counted during speculative decoding, + the request will terminate prematurely before generating the requested tokens. + """ + test_config = DynamicEngineTestConfig( + num_requests=0, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=6, + max_sequence_length=10, # Exactly prompt (4) + generate (6) + context_max_requests=16, + num_speculative_tokens=2, + model_provider="gpt", + materialize_only_last_token_logits=False, + use_fixed_output_lengths=False, + context_max_tokens=512, + ) + env = self._build_test_env(test_config) + + # Mock forward pass to return deterministic base logits. + # Speculative tokens will be wrong (predicted by MTP as tokens + 5) + # to guarantee rejection every time. + model = env.engine.controller.inference_wrapped_model.model + hidden_size = model.config.hidden_size + + def mock_mtp_forward_reject(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + # Base model correctly predicts tokens + 1 + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + # Cache hidden states for serial MTP computation + model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 + ) + return base_logits + + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict wildly wrong tokens (+ 5) to guarantee rejection + wrong_toks = (next_token_ids + 5).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, wrong_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + + model.forward = mock_mtp_forward_reject + model.compute_mtp_single_step = mock_compute_mtp_single_step + + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=6, termination_id=99), + ) + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + finished_req = finished_records[0].merge() + + # If there is double counting, the tracked active length will outpace the actual + # generated tokens, causing premature termination when it thinks it hit max_sequence_length. + assert finished_req.status == Status.COMPLETED + assert ( + len(finished_req.generated_tokens) == 6 + ), f"Expected 6 tokens, got {len(finished_req.generated_tokens)}. Double counting occurred." + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_eviction_and_swapping(self): + """Test that speculative decoding works correctly when requests are paused and evicted. + + This exercises the `_swap_book_keeping_tensors` logic with the 2D `new_speculative_tokens` + tensor, ensuring no dimensional mismatch or index errors occur during tensor swapping. + """ + # Very constrained memory environment to force pausing and eviction + test_config = DynamicEngineTestConfig( + num_requests=3, + min_prompt_length=256, + max_prompt_length=256, + num_tokens_to_generate=512, + context_block_size_tokens=256, + num_speculative_tokens=2, + context_buffer_size_gb=0.00064, # 640 KB + context_paused_buffer_size_gb=0.0, # 0 paused buffer forces immediate eviction + model_provider="gpt", + materialize_only_last_token_logits=False, + use_fixed_output_lengths=True, + ) + + env = self._build_test_env(test_config) + + unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size + + # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes + # in torch.multinomial caused by randomly initialized weights. + def mock_safe_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + base_logits[:, :, 0] = 100.0 # Force model to deterministically pick token 0 + + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 + ) + return base_logits + + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits[:, :, 0] = 100.0 # Force speculative heads to also pick token 0 + return hidden_states, logits + + unwrapped_model.forward = mock_safe_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step + + # Add all requests at once. They will all start prefill, but as they generate + # and request more blocks, the engine will run out of active blocks. + # Since paused_buffer_size is 0, any request that pauses will immediately + # overflow the paused buffer and trigger an eviction. + for request in env.requests: + request.sampling_params.num_tokens_to_generate = 512 + env.engine._add_request(request) + + eviction_occurred = False + + # Step the engine manually until all requests finish. + while env.engine.has_unfinished_requests(): + # Record the number of evicted requests before the step + evicted_before = env.engine.evicted_request_count + + # Step the engine + env.engine.schedule_waiting_requests() + env.engine.step_modern() + + # Check if any request was evicted during this step + if env.engine.evicted_request_count > evicted_before: + eviction_occurred = True + + # Assert that our constrained memory actually caused an eviction, + # proving we exercised the evict_overflow_paused_requests path with spec tokens. + assert ( + eviction_occurred + ), "Test failed to trigger an eviction. The test environment memory wasn't tight enough." + + # Verify all requests successfully went back through the queue and finished cleanly. + # We MUST check the merged records from the engine, because eviction checkpoints + # the requests, leaving the original instances in env.requests permanently active. + for request_id, entry in env.engine.requests.items(): + merged_req = entry.record.merge() + assert ( + merged_req.status == Status.COMPLETED + ), f"Request {request_id} failed to complete." + assert ( + len(merged_req.generated_tokens) == 511 + ), f"Request {request_id} didn't generate expected tokens." + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_prefix_caching(self): + """Test that speculative decoding works correctly when prefix caching is enabled. + + Two requests share the same prompt prefix. The second request should reuse + cached KV blocks from the first and still generate correctly with spec decoding. + """ + test_config = DynamicEngineTestConfig( + num_requests=0, # Added manually below + min_prompt_length=256, + max_prompt_length=256, + num_tokens_to_generate=4, + num_speculative_tokens=2, + enable_prefix_caching=True, # Set at config level + context_block_size_tokens=256, # Ensure exact 1 block per prompt + materialize_only_last_token_logits=False, + model_provider="gpt", + context_max_tokens=4096, + context_max_requests=512, + ) + env = self._build_test_env(test_config) + + # Create two pairs of requests with identical shared prefixes. + shared_prompt_a = torch.randint( + 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' + ) + shared_prompt_b = torch.randint( + 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' + ) + + prompts = [shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b] + + for i, prompt in enumerate(prompts): + # Using the clean public API guarantees correct hashing and dataclass creation + env.engine.add_request( + request_id=i, + prompt=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), + ) + + # First, run schedule_waiting_requests and ONE step to allocate the prefill blocks. + # Req 0 and 2 will schedule immediately. Req 1 and 3 will defer because their hashes + # are currently pending (being registered by 0 and 2). + env.engine.schedule_waiting_requests() + env.engine.step_modern() + + # After step 1, Req 0 and 2 have completely registered their cached blocks. + # Now, schedule the deferred ones (Req 1 and 3). They will find the registered blocks! + env.engine.schedule_waiting_requests() + env.engine.step_modern() + + # 4 requests. 2 unique prefixes (1 block each). + # Without sharing, we'd need 8 blocks + 1 dummy = 9 active_used. + # With sharing, we need 2 shared blocks + 4 generation blocks + 1 dummy = 7 active_used. + active_used = env.engine.context.block_allocator.get_active_used() + assert ( + active_used <= 7 + ), f"Prefix caching failed, expected <= 7 active blocks but got {active_used}" + + while env.engine.has_unfinished_requests(): + env.engine.step_modern() + + # Context should be clean after all requests finish. + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_chunked_prefill_and_prefix_caching(self): + """End-to-end test combining speculative decoding, chunked prefill, and prefix caching. + + Verifies that all three features interact correctly: + - Prefix caching shares KV blocks between requests with common prompts + - Chunked prefill processes long prompts in chunks + - Speculative decoding generates multiple tokens per step + """ + test_config = DynamicEngineTestConfig( + num_requests=0, + min_prompt_length=512, + max_prompt_length=512, + num_tokens_to_generate=128, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + enable_chunked_prefill=True, + enable_prefix_caching=True, # Set at config level + context_block_size_tokens=256, + model_provider="gpt", + context_max_tokens=1536, # Force chunking + context_max_requests=48, + ) + env = self._build_test_env(test_config) + + # Create identical prompts for all 4 requests + shared_prompt = torch.randint( + 0, test_config.vocab_size - 1, (512,), dtype=torch.int64, device='cuda' + ) + + for i in range(4): + env.engine.add_request( + request_id=i, + prompt=shared_prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), + ) + + while env.engine.has_unfinished_requests(): + env.engine.step_modern() + + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index 81d538ec60e..f899f8b1c97 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -370,3 +370,102 @@ def test_one_rank_oversized_forces_no_match(self, num_cuda_graphs): result = _match(real, graph_list, ep_group=ep_group) _assert_consistent_across_ranks(result, ep_group) assert result is None, "All-reduce max from oversized rank should cause no match" + + +class TestSpeculativeDecodingBatchDimensions: + """Tests for batch dimensions specifically handling speculative decoding.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=Utils.world_size, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @staticmethod + def _get_ep_group(): + return ps.get_expert_model_parallel_group() + + @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 5]) + def test_generate_graphs_with_speculative_tokens(self, num_speculative_tokens): + """Verify graph generation strictly adheres to the speculative token multiplier.""" + graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=TP_SIZE, + num_cuda_graphs=4, + cuda_graph_max_tokens=MAX_REQUESTS * (num_speculative_tokens + 1), + cuda_graph_mixed_prefill_request_count=MIXED_PREFILL_COUNT, + max_requests=MAX_REQUESTS, + max_tokens=MAX_TOKENS, + max_sequence_length=MAX_SEQ_LEN, + use_cuda_graphs_for_non_decode_steps=True, + num_speculative_tokens=num_speculative_tokens, + ) + + # For pure decode graphs, token_count must exactly equal decode_req_count * (spec_tokens + 1) + decode_graphs = [g for g in graph_list if g.prefill_req_count == 0] + assert len(decode_graphs) > 0, "Should generate decode-only graphs" + + for g in decode_graphs: + expected_tokens = g.decode_req_count * (num_speculative_tokens + 1) + assert g.token_count == expected_tokens, ( + f"Mismatch in speculative token math: Expected {expected_tokens} tokens " + f"for {g.decode_req_count} requests with {num_speculative_tokens} spec tokens, got {g.token_count}." + ) + + def test_is_valid_with_speculative_tokens(self): + """Verify that validation correctly enforces speculative token budgets.""" + num_speculative_tokens = 4 + # 10 decode requests * (4 spec + 1 actual) = 50 tokens required. + + # 49 tokens is not enough -> should be invalid + bd_invalid = BD(token_count=49, prefill_req_count=0, decode_req_count=10) + assert not bd_invalid.is_valid( + max_requests=MAX_REQUESTS, + max_sequence_length=MAX_SEQ_LEN, + num_speculative_tokens=num_speculative_tokens, + ), "Should reject batch dimension without enough tokens for speculative budget." + + # Exactly 50 tokens -> should be valid + bd_valid = BD(token_count=50, prefill_req_count=0, decode_req_count=10) + assert bd_valid.is_valid( + max_requests=MAX_REQUESTS, + max_sequence_length=MAX_SEQ_LEN, + num_speculative_tokens=num_speculative_tokens, + ), "Should accept batch dimension with perfectly matched speculative budget." + + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, -1]) + def test_ep_sync_with_speculative_tokens(self, num_cuda_graphs): + """Verify matching and EP rank syncing scales correctly with speculative tokens.""" + ep_group = self._get_ep_group() + num_speculative_tokens = 2 + + graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=TP_SIZE, + num_cuda_graphs=num_cuda_graphs, + cuda_graph_max_tokens=MAX_REQUESTS * (num_speculative_tokens + 1), + cuda_graph_mixed_prefill_request_count=MIXED_PREFILL_COUNT, + max_requests=MAX_REQUESTS, + max_tokens=MAX_TOKENS, + max_sequence_length=MAX_SEQ_LEN, + use_cuda_graphs_for_non_decode_steps=True, + num_speculative_tokens=num_speculative_tokens, + ) + + rank = dist.get_rank() + + # Each rank has a different number of decode requests. + decode_reqs = (rank + 1) * 2 + token_count = decode_reqs * (num_speculative_tokens + 1) + real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=decode_reqs) + + result = _match(real, graph_list, ep_group=ep_group) + + # All ranks should end up syncing to the maximum requirement and picking the same graph + _assert_consistent_across_ranks(result, ep_group) + if result is not None: + # Confirm the selected graph preserves the speculative token mathematical invariance + assert result.token_count == result.decode_req_count * (num_speculative_tokens + 1) diff --git a/tests/unit_tests/inference/test_stop_words.py b/tests/unit_tests/inference/test_stop_words.py index 31665c0bb81..525194004db 100644 --- a/tests/unit_tests/inference/test_stop_words.py +++ b/tests/unit_tests/inference/test_stop_words.py @@ -31,132 +31,344 @@ class TestStopWordDetection: """Test stop word detection logic.""" def _check_stop_words_for_request_post_append( - self, request: MockDynamicInferenceRequest - ) -> bool: + self, request: MockDynamicInferenceRequest, num_speculative_tokens: int = 0 + ) -> tuple: """ Check if a request should stop due to stop words (after token is appended). - This mirrors the logic in DynamicInferenceEngine._check_stop_words_for_request_post_append + This mirrors the logic in DynamicInferenceEngine._check_stop_words_for_request_post_append. + Returns (stop_word_hit, num_tokens_trimmed). """ - # Check if request has stop words configured if request.stop_word_ids is None or len(request.stop_word_ids) == 0: - return False + return False, 0 generated_tokens = request.generated_tokens - # Check if the sequence ends with any stop word for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: - # Check if the last stop_len tokens match the stop word - if list(generated_tokens[-stop_len:]) == stop_word_ids: - return True + for i in range(num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + if i > 0: + request.generated_tokens = request.generated_tokens[:-i] + return True, i - return False + return False, 0 def test_no_stop_words_configured(self): """Test that requests without stop words configured don't trigger stop.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=None ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False + assert trim == 0 def test_empty_stop_words_list(self): """Test that empty stop words list doesn't trigger stop.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_single_token_stop_word_match(self): """Test detection of single-token stop word.""" - # Stop word is token 300 request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[300]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True + assert trim == 0 + assert request.generated_tokens == [100, 200, 300] def test_single_token_stop_word_no_match(self): """Test no detection when single-token stop word doesn't match.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[400]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_multi_token_stop_word_match(self): """Test detection of multi-token stop word.""" - # Stop word is tokens [200, 300] request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True + assert trim == 0 def test_multi_token_stop_word_no_match_partial(self): """Test no detection when only partial stop word matches.""" - # Stop word is [200, 300], but generated ends with [100, 200] request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200], stop_word_ids=[[200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_multi_token_stop_word_no_match_wrong_order(self): """Test no detection when tokens are present but in wrong order.""" - # Stop word is [200, 300], but generated ends with [300, 200] request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 300, 200], stop_word_ids=[[200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_multiple_stop_words_first_matches(self): """Test with multiple stop words where first one matches.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[300], [400], [500]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True def test_multiple_stop_words_second_matches(self): """Test with multiple stop words where second one matches.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 400], stop_word_ids=[[300], [400], [500]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True def test_multiple_stop_words_none_match(self): """Test with multiple stop words where none match.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 600], stop_word_ids=[[300], [400], [500]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_stop_word_longer_than_generated(self): """Test that stop word longer than generated tokens doesn't crash.""" - # Stop word is 5 tokens, but only 3 tokens generated request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[1, 2, 3, 4, 5]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_stop_word_exact_length_match(self): """Test stop word that matches entire generated sequence.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[100, 200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True def test_empty_generated_tokens(self): """Test with no generated tokens.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[], stop_word_ids=[[300]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_stop_word_in_middle_not_end(self): """Test that stop word in middle of sequence doesn't trigger (only end matters).""" - # Stop word is [200], which is in middle but not at end request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[200]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False + + +class TestStopWordSpeculativeDecoding: + """Test stop word detection and truncation with speculative decoding.""" + + def _check_stop_words_for_request_post_append( + self, request: MockDynamicInferenceRequest, num_speculative_tokens: int = 0 + ) -> tuple: + """Mirror of DynamicInferenceEngine._check_stop_words_for_request_post_append.""" + if request.stop_word_ids is None or len(request.stop_word_ids) == 0: + return False, 0 + + generated_tokens = request.generated_tokens + + for stop_word_ids in request.stop_word_ids: + stop_len = len(stop_word_ids) + if len(generated_tokens) >= stop_len: + for i in range(num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + if i > 0: + request.generated_tokens = request.generated_tokens[:-i] + return True, i + + return False, 0 + + def test_stop_word_at_end_no_trim(self): + """Stop word is the last token — no trimming needed.""" + # Speculative tokens: [tok1, STOP, tok3] appended, stop word at end of accepted + # But here STOP is at the very end after all tokens + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 0 + assert request.generated_tokens == [10, 20, 42] + + def test_stop_word_with_one_extra_token(self): + """Stop word is second-to-last — one extra token should be trimmed.""" + # Speculative appended [tok1, STOP, tok3], STOP=42 at position -2 + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42, 99], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 1 + assert request.generated_tokens == [10, 20, 42] + + def test_stop_word_with_two_extra_tokens(self): + """Stop word is third-to-last — two extra tokens should be trimmed.""" + # Speculative appended [STOP, tok2, tok3], STOP=42 at position -3 + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 42, 77, 88], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 2 + assert request.generated_tokens == [10, 42] + + def test_multi_token_stop_word_with_extra_tokens(self): + """Multi-token stop word found mid-speculative-batch.""" + # Speculative appended [tok1, STOP_A, STOP_B, tok4], stop word is [STOP_A, STOP_B] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42, 43, 99], stop_word_ids=[[42, 43]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 1 + assert request.generated_tokens == [10, 20, 42, 43] + + def test_multi_token_stop_word_with_two_extra(self): + """Multi-token stop word with two extra tokens after.""" + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 42, 43, 77, 88], stop_word_ids=[[42, 43]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 2 + assert request.generated_tokens == [10, 42, 43] + + def test_no_stop_word_speculative(self): + """No stop word in speculative batch — nothing happens.""" + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 30, 40], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is False + assert trim == 0 + assert request.generated_tokens == [10, 20, 30, 40] + + def test_stop_word_outside_speculative_window(self): + """Stop word exists but is outside the speculative search window.""" + # Stop word [42] is at position -4, but num_speculative_tokens=2 + # so we only check positions -1, -2, -3 (i=0,1,2) + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[42, 10, 20, 30], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is False + assert trim == 0 + + def test_log_probs_trimming_scenario(self): + """Verify that the trim count can be used to trim log probs correctly.""" + # Simulate: speculative batch appended [tok1, STOP, tok3] + # Log probs: [lp1, lp2, lp3] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42, 99], stop_word_ids=[[42]] + ) + log_probs = [-1.5, -0.3, -2.1] + + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 1 + + # Trim log probs the same way the engine does + if trim > 0: + log_probs = log_probs[:-trim] + + assert log_probs == [-1.5, -0.3] + assert request.generated_tokens == [10, 20, 42] + + def test_speculative_stop_word_at_end(self): + """Test stop word at end of speculative tokens (no truncation needed).""" + # Speculative tokens appended: [200, 300], stop word is [300] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[300]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=2) + is True + ) + assert request.generated_tokens == [100, 200, 300] + + def test_speculative_stop_word_in_middle_truncates(self): + """Test that stop word in middle of speculative tokens truncates trailing tokens.""" + # Speculative tokens appended: [200, 300, 400], stop word is [200] + # Token 200 is at position -3, so tokens [300, 400] should be truncated + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300, 400], stop_word_ids=[[200]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=3) + is True + ) + assert request.generated_tokens == [100, 200] + + def test_speculative_multi_token_stop_word_in_middle_truncates(self): + """Test multi-token stop word in middle of speculative tokens truncates.""" + # Generated: [100, 200, 300, 400, 500], stop word is [200, 300] + # Stop word ends at -2, so tokens [400, 500] should be truncated + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300, 400, 500], stop_word_ids=[[200, 300]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=4) + is True + ) + assert request.generated_tokens == [100, 200, 300] + + def test_speculative_stop_word_not_found(self): + """Test no stop word found even with speculative scanning.""" + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300, 400], stop_word_ids=[[999]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=3) + is False + ) + assert request.generated_tokens == [100, 200, 300, 400] + + def test_speculative_stop_word_one_trailing_token(self): + """Test stop word with exactly one trailing token to truncate.""" + # Generated: [100, 200, 300], stop word is [200], one trailing token [300] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[200]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=2) + is True + ) + assert request.generated_tokens == [100, 200] class TestStopWordTrackingFlow: diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index bdf95c2d9bf..ff296b68390 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -52,6 +52,11 @@ def setup_model( batch_size: int = 4, static: bool = True, use_training_random_init: bool = False, + materialize_only_last_token_logits: bool = False, + num_speculative_tokens: int = 0, + block_size_tokens: int = 256, + enable_prefix_caching: bool = False, + max_requests: int = None, ): Utils.initialize_model_parallel( tensor_model_parallel_size=tensor_model_parallel_size, @@ -108,10 +113,14 @@ def setup_model( inference_config=InferenceConfig( max_sequence_length=2048, buffer_size_gb=0.2, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, use_flashinfer_fused_rope=None, # default to using flash-infer if available # this is for compatibility with the LTS environment unified_memory_level=0, # unit tests currently broken with UVM + num_speculative_tokens=num_speculative_tokens, + block_size_tokens=block_size_tokens, + enable_prefix_caching=enable_prefix_caching, + max_requests=max_requests, ), ) @@ -224,11 +233,15 @@ def test_sample_from_dynamic_logits( self, backend: str, materialize_only_last_token_logits: bool ): batch_size = 12 - self.setup_model(torch.float32, batch_size=batch_size, static=False) + self.setup_model( + torch.float32, + batch_size=batch_size, + static=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, + ) self.mock_tokenizer.eod = self.vocab_size context = self.text_generation_controller.inference_wrapped_model.inference_context - context.materialize_only_last_token_logits = materialize_only_last_token_logits # Prepare sampling params in human-readable format, to aid with test maintenance. sampling_test_cases: List[Tuple[SamplingParams, List[int]]] = [ @@ -743,11 +756,15 @@ def test_dynamic_top_n_logprobs_calculation( 3. Correct number of tokens are returned for each request """ batch_size = 4 - self.setup_model(torch.bfloat16, batch_size=batch_size, static=False) + self.setup_model( + torch.bfloat16, + batch_size=batch_size, + static=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, + ) self.mock_tokenizer.eod = self.vocab_size context = self.text_generation_controller.inference_wrapped_model.inference_context - context.materialize_only_last_token_logits = materialize_only_last_token_logits # Prepare sampling params top_n = 5 @@ -1007,3 +1024,289 @@ def test_sampled_tokens_match_with_parallelism(self, static, tp_size, pp_size): assert ( expected == actual ), f"Rank {i} tokens differ from rank {local_rank} tokens for request {j}" + + @pytest.mark.internal + def test_speculative_verify_tokens(self): + """Test consecutive token acceptance logic for speculative decoding.""" + self.setup_model(torch.float32, static=False, num_speculative_tokens=2, max_requests=2) + + # Enable speculative decoding + self.text_generation_controller.num_speculative_tokens = 2 + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor( + [0, 0], device='cuda' + ) # Decode requests + ctx.request_query_lengths = torch.tensor( + [3, 3], dtype=torch.int32, device='cuda' + ) # 1 sampled + 2 spec + + # Init accepted tokens tensors + self.text_generation_controller._init_mtp_sampling_tensor() + + # Mock inputs: [Req 1 sampled, Req 1 spec1, Req 1 spec2, Req 2 sampled, Req 2 spec1, Req 2 spec2] + # Target tokens (what the model was fed): [T0, T1, T2, T3, T4, T5] + input_ids = torch.tensor([[10, 11, 12, 20, 21, 22]], device='cuda') + + # We need the sampling function to return a 1D tensor for base logits, + # and a 1D tensor for the flattened MTP logits. + def mock_sampling_func(logits, *args, **kwargs): + if logits.shape[0] == 6: + # Base logits -> return 1D tensor of shape [6] + # Req 1: Predicts [11, 12, 99]. Matches T1, T2. Rejects T3. -> Accepts 2 spec tokens. + # Req 2: Predicts [99, 22, 23]. Fails at first spec token (99 != 21). -> Accepts 0 spec tokens. + return torch.tensor([11, 12, 99, 99, 22, 23], dtype=torch.long, device='cuda') + else: + # MTP logits -> return 1D tensor of shape [12] + # The verification logic only uses base tokens, so we can return zeros here. + return torch.zeros((12,), dtype=torch.long, device='cuda') + + # Override sampling to return our predictable mock outputs + self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 1, 0.0)] + self.text_generation_controller._torch_sampling_func = mock.MagicMock( + side_effect=mock_sampling_func + ) + + # Mock logits matching input shape + logits = torch.randn(1, 6, self.vocab_size, device='cuda') + mtp_logits = torch.randn(2, 6, self.vocab_size, device='cuda') + + self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens( + logits, mtp_logits, input_ids + ) + + # Verify acceptance counts + accepted_counts = self.text_generation_controller._accepted_token_counts_per_request[:2] + assert torch.equal(accepted_counts, torch.tensor([2, 0], device='cuda')) + + # Verify accepted tokens tensor + accepted_tokens = self.text_generation_controller._accepted_tokens_per_request[:2] + # Req 1 accepted 2 tokens: 11, 12 + assert torch.equal(accepted_tokens[0], torch.tensor([11, 12], device='cuda')) + # Req 2 accepted 0 tokens, should remain -1 + assert torch.equal(accepted_tokens[1], torch.tensor([-1, -1], device='cuda')) + + @pytest.mark.internal + @pytest.mark.parametrize("is_hybrid_model", [False, True]) + def test_rewind_kv_cache(self, is_hybrid_model): + """Test KV cache state is properly rewound for rejected speculative tokens.""" + self.setup_model( + torch.float32, + static=False, + num_speculative_tokens=3, + block_size_tokens=4, + max_requests=16, + ) + self.text_generation_controller.num_speculative_tokens = 3 + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') + + # Initialize allocator and states + ctx.block_allocator.total_avail = 100 + ctx.request_kv_length_offsets[:2] = torch.tensor([10, 15], device='cuda') + ctx.request_kv_block_counts[:2] = torch.tensor([3, 4], device='cuda') + + # Req 0: offset 2. Rewinding 2 tokens -> offset 0. No block released. + # Req 1: offset 1. Rewinding 3 tokens -> offset 2 (prev block). 1 block released. + ctx.request_last_kv_block_offset[:2] = torch.tensor([2, 1], device='cuda') + ctx.request_last_kv_block_id[:2] = torch.tensor([50, 60], device='cuda') + ctx.request_to_kv_block_ids[:2, :4] = torch.tensor( + [[48, 49, 50, -1], [57, 58, 59, 60]], dtype=torch.int, device='cuda' + ) + + if is_hybrid_model: + ctx.is_hybrid_model = True + ctx.mamba_metadata = mock.MagicMock() + ctx.mamba_metadata.request_to_mamba_state_idx = torch.tensor([0, 1], device='cuda') + ctx.mamba_ssm_states = torch.zeros((1, 2, 16), device='cuda') + ctx.mamba_intermediate_ssm_states = torch.ones((1, 2, 4, 16), device='cuda') * 99 + ctx.mamba_conv_states = torch.zeros((1, 2, 8), device='cuda') + ctx.mamba_intermediate_conv_states = torch.ones((1, 2, 4, 8), device='cuda') * 77 + + # Mock accepted token counts: Req 0 accepts 1 (rejects 2), Req 1 accepts 0 (rejects 3) + self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( + [1, 0], device='cuda' + ) + + self.text_generation_controller._rewind_kv_cache() + + # Assert offsets updated + assert torch.equal( + ctx.request_last_kv_block_offset[:2], + torch.tensor([0, 2], dtype=torch.int, device='cuda'), + ) + assert torch.equal( + ctx.request_kv_length_offsets[:2], torch.tensor([8, 12], dtype=torch.int, device='cuda') + ) + + # Assert block counts and IDs updated for boundary crossing + assert torch.equal( + ctx.request_kv_block_counts[:2], torch.tensor([3, 3], dtype=torch.int, device='cuda') + ) + assert torch.equal( + ctx.request_last_kv_block_id[:2], torch.tensor([50, 59], dtype=torch.int, device='cuda') + ) + + # Assert released block is cleared + assert ctx.request_to_kv_block_ids[1, 3].item() == -1 + assert ctx.block_allocator.total_avail == 101 # 1 block released + + if is_hybrid_model: + # Check Mamba state was restored from intermediate cache based on accepted counts + assert torch.all(ctx.mamba_ssm_states[:, 0] == 99) # Req 0 accepted 1, loaded index 1 + assert torch.all(ctx.mamba_ssm_states[:, 1] == 99) # Req 1 accepted 0, loaded index 0 + assert torch.all(ctx.mamba_conv_states[:, 0] == 77) # Req 0 accepted 1, loaded index 1 + assert torch.all(ctx.mamba_conv_states[:, 1] == 77) # Req 1 accepted 0, loaded index 0 + + @pytest.mark.internal + def test_speculative_multinomial_sampling(self): + """Test that speculative decoding can successfully use non-greedy sampling + (top_k > 1, top_p > 0) by flattening 3D MTP logits for torch.multinomial.""" + num_spec = 3 + self.setup_model( + torch.float32, static=False, num_speculative_tokens=num_spec, max_requests=2 + ) + + # Enable speculative decoding + self.text_generation_controller.num_speculative_tokens = num_spec + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor( + [0, 0], device='cuda' + ) # Decode requests + # query lengths for decode with spec tokens is (1 + num_spec) = 4 + ctx.request_query_lengths = torch.tensor([4, 4], dtype=torch.int32, device='cuda') + + # Setup inputs + input_ids = torch.randint(0, self.vocab_size, (1, 8), device='cuda') + + # Create random logits + # Base logits shape: [1, 8, vocab_size] + logits = torch.randn(1, 8, self.vocab_size, device='cuda') + # MTP logits shape: [num_spec, 8, vocab_size] + mtp_logits = torch.randn(num_spec, 8, self.vocab_size, device='cuda') + + # Set up a bucket that forces multinomial sampling (top_p = 0.9, top_k = 0) + # _torch_sampling_buckets format: (indices, temp, top_k, top_p) + self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 0, 0.9)] + + # Since we are actually testing the internal math of `_torch_sampling_func` handling the shapes, + # we DO NOT mock `_torch_sampling_func` here. We want it to run natively to prove it doesn't crash. + + try: + self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens( + logits, mtp_logits, input_ids + ) + except RuntimeError as e: + if "prob_dist must be 1 or 2 dim" in str(e): + pytest.fail("MTP logits were not flattened before calling multinomial sampling.") + else: + raise e + + # Validate that sampling produced output arrays of the correct sizes + active_request_count = ctx.total_request_count + sampled_tokens = self.text_generation_controller._sampled_tokens_cuda[:active_request_count] + sampled_mtp_tokens = self.text_generation_controller._sampled_mtp_tokens_cuda[ + :, :active_request_count + ] + + assert sampled_tokens.shape == (2,) + assert sampled_mtp_tokens.shape == (num_spec, 2) + + @pytest.mark.internal + def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): + """Test that _rewind_kv_cache correctly decrements ref counts on shared blocks + when speculative token rejection causes a block boundary crossing.""" + self.setup_model( + torch.float32, + static=False, + num_speculative_tokens=2, + block_size_tokens=4, + enable_prefix_caching=True, + max_requests=16, + ) + + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') + + # Req 0: 3 blocks, offset 1 in last block. Rewinding 1 token -> no block release. + # Req 1: 3 blocks, offset 0 in last block. Rewinding 2 tokens -> crosses back, release block. + ctx.request_kv_length_offsets[:2] = torch.tensor([9, 9], device='cuda') + ctx.request_kv_block_counts[:2] = torch.tensor([3, 3], device='cuda') + ctx.request_last_kv_block_offset[:2] = torch.tensor([1, 0], device='cuda') + ctx.request_last_kv_block_id[:2] = torch.tensor([10, 20], device='cuda') + ctx.request_to_kv_block_ids[:2, :3] = torch.tensor( + [[8, 9, 10], [18, 19, 20]], dtype=torch.int, device='cuda' + ) + + # Set ref counts: block 20 is shared (ref=2), block 10 is exclusive (ref=1). + ctx.block_allocator.block_ref_counts[20] = 2 + ctx.block_allocator.block_ref_counts[10] = 1 + + initial_avail = ctx.block_allocator.total_avail + + # Req 0 accepts 1 (rewinds 1), Req 1 accepts 0 (rewinds 2, crosses boundary). + self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( + [1, 0], device='cuda' + ) + + self.text_generation_controller._rewind_kv_cache() + + # Req 1 should have released block 20 (ref count decremented). + assert ctx.block_allocator.block_ref_counts[20].item() == 1 + # Block 10 should be untouched. + assert ctx.block_allocator.block_ref_counts[10].item() == 1 + + @pytest.mark.internal + def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): + """Test that rewinding only releases the last block, never shared prefix blocks.""" + self.setup_model( + torch.float32, + static=False, + num_speculative_tokens=3, + block_size_tokens=4, + max_requests=16, + ) + + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.total_request_count = 1 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor([0], device='cuda') + + # 4 blocks. Offset 2 in last block. Rewinding 3 crosses into previous block. + ctx.request_kv_length_offsets[:1] = torch.tensor([14], device='cuda') + ctx.request_kv_block_counts[:1] = torch.tensor([4], device='cuda') + ctx.request_last_kv_block_offset[:1] = torch.tensor([2], device='cuda') + ctx.request_last_kv_block_id[:1] = torch.tensor([40], device='cuda') + ctx.request_to_kv_block_ids[0, :4] = torch.tensor( + [10, 20, 30, 40], dtype=torch.int, device='cuda' + ) + + # Blocks 10, 20 are shared prefix blocks. Block 30, 40 are exclusive. + ctx.block_allocator.total_avail = 50 + + self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( + [0], device='cuda' + ) + + self.text_generation_controller._rewind_kv_cache() + + # Only block 40 should be released, not blocks 10, 20, or 30. + assert ctx.request_kv_block_counts[0].item() == 3 + assert ctx.request_last_kv_block_id[0].item() == 30 + assert ctx.request_to_kv_block_ids[0, 3].item() == -1 + assert ctx.block_allocator.total_avail == 51 # exactly 1 block released + + # Prefix blocks remain in request_to_kv_block_ids. + assert ctx.request_to_kv_block_ids[0, 0].item() == 10 + assert ctx.request_to_kv_block_ids[0, 1].item() == 20 + assert ctx.request_to_kv_block_ids[0, 2].item() == 30 diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py new file mode 100644 index 00000000000..3015f5ed989 --- /dev/null +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -0,0 +1,258 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import pytest +import torch + +from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update + + +def _requires_cuda(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + +# ---------------------- Reference Implementations ---------------------- # + + +def causal_conv1d_update_ref(x, conv_state, weight, bias, silu_activation): + """Reference: linear (non-circular) causal conv1d update.""" + batch, seq_len, dim = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + out = torch.empty_like(x) + for b in range(batch): + for s in range(seq_len): + # Shift state left by 1 + conv_state[b, :, :-1] = conv_state[b, :, 1:].clone() + conv_state[b, :, -1] = x[b, s, :] + # Convolution over the last `width` elements + window = conv_state[b, :, state_len - width : state_len].float() + w = weight.float() + val = (window * w).sum(dim=1) + if bias is not None: + val = val + bias.float() + if silu_activation: + val = val * torch.sigmoid(val) + out[b, s, :] = val.to(x.dtype) + return out + + +# ---------------------- Tests ---------------------- # + + +@pytest.mark.internal +class TestCausalConv1dUpdate: + + def setup_method(self, method): + _requires_cuda() + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_no_bias(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 3, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state_triton, weight, bias=None, silu_activation=False, conv_state_indices=None + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=None, silu_activation=False + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(conv_state_triton, conv_state_ref, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_with_bias(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 3, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + bias = torch.randn(D, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state_triton, weight, bias=bias, silu_activation=False, conv_state_indices=None + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=bias, silu_activation=False + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_with_silu(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 1, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + bias = torch.randn(D, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state_triton, weight, bias=bias, silu_activation="silu", conv_state_indices=None + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=bias, silu_activation=True + ) + + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_2d_input(self): + """Test that 2D input (B, D) is handled correctly and returns 2D output.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + x = torch.randn(B, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state, weight, bias=None, silu_activation=False, conv_state_indices=None + ) + + assert result.dim() == 2 + assert result.shape == (B, D) + + def test_conv_state_indices(self): + """Test that conv_state_indices correctly maps batch to state entries.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + num_states = 4 + x = torch.randn(B, 1, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(num_states, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + # Map batch 0 -> state 2, batch 1 -> state 0 + state_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32) + + # Run with indices + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=state_indices, + ) + + # Run without indices by manually reordering + conv_state_reordered = conv_state[state_indices.long()].clone() + expected = causal_conv1d_update( + x, + conv_state_reordered, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + def test_negative_state_index_zeros_output(self): + """Padding batch entries (index < 0) should produce zero output.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + x = torch.randn(B, 1, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + state_indices = torch.tensor([-1, 0], device="cuda", dtype=torch.int32) + + result = causal_conv1d_update( + x, + conv_state, + weight, + bias=None, + silu_activation=False, + conv_state_indices=state_indices, + ) + + # Batch 0 (padded) should be all zeros + torch.testing.assert_close(result[0], torch.zeros(1, D, device="cuda", dtype=torch.float32)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_half_precision(self, dtype): + torch.manual_seed(42) + B, seq_len, D, state_len, width = 2, 1, 64, 8, 4 + x = torch.randn(B, seq_len, D, device="cuda", dtype=dtype) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=dtype) + weight = torch.randn(D, width, device="cuda", dtype=dtype) + + result = causal_conv1d_update( + x, conv_state, weight, bias=None, silu_activation=False, conv_state_indices=None + ) + + assert result.dtype == dtype + assert result.shape == (B, seq_len, D) + assert torch.isfinite(result).all() + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_intermediate_state(self, width): + """Test that intermediate conv states are correctly stored at each sequence step.""" + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 4, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + # Allocate intermediate state buffer: (B, seq_len, D, state_len) + int_states = torch.zeros(B, seq_len, D, state_len, device="cuda", dtype=torch.float32) + + # Run with intermediate state recording + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + intermediate_conv_states=int_states, + ) + + # Verify by running step-by-step and checking each intermediate + conv_state_ref = conv_state.clone() + for s in range(seq_len): + conv_state_ref[:, :, :-1] = conv_state_ref[:, :, 1:].clone() + conv_state_ref[:, :, -1] = x[:, s, :] + torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5) + + def test_intermediate_state_with_indices(self): + """Test intermediate states work correctly with conv_state_indices mapping.""" + torch.manual_seed(42) + B, seq_len, D, state_len, width = 2, 3, 64, 8, 4 + num_states = 4 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(num_states, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + state_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32) + + # Intermediate states are indexed by state_batch_coord (i.e., req index, not batch index) + int_states = torch.zeros( + num_states, seq_len, D, state_len, device="cuda", dtype=torch.float32 + ) + + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=state_indices, + intermediate_conv_states=int_states, + ) + + # The final intermediate state at last seq step should match the final conv_state + for b_idx in range(B): + req_idx = state_indices[b_idx].item() + torch.testing.assert_close( + int_states[req_idx, seq_len - 1, :, :], + conv_state_copy[req_idx, :, :], + atol=1e-5, + rtol=1e-5, + )