From 690f57d9d545d59bbdc5e1805c17ab9d15ad024e Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 14 Apr 2026 15:48:46 -0700 Subject: [PATCH 1/3] Add MTP cuda graph support Signed-off-by: Keshav Santhanam --- .../inference/gpt/gpt_dynamic_inference.py | 37 +- .../gpt/gpt_dynamic_inference_12b.sh | 1 - .../gpt/gpt_dynamic_inference_357m.sh | 1 - examples/inference/gpt/utils.py | 17 +- .../core/inference/engines/dynamic_engine.py | 94 ++++ .../text_generation_controller.py | 531 +++++++++--------- .../common/language_module/language_module.py | 10 + megatron/core/transformer/cuda_graphs.py | 36 +- .../transformer/multi_token_prediction.py | 24 +- .../cuda_graphs.sh | 1 - .../test_text_generation_controller.py | 97 +++- .../test_attention_variant_dsa.py | 2 + 12 files changed, 521 insertions(+), 330 deletions(-) diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index f02aae9c221..2800742c420 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -284,10 +284,6 @@ def main(): args_defaults={'no_load_rng': True, 'no_load_optim': True}, ) - # Start Nsight profiler. - if os.environ.get("NSIGHT_PREFIX"): - torch.cuda.cudart().cudaProfilerStart() - level_str = os.getenv("LOG_LEVEL", "INFO").upper() level = getattr(logging, level_str, logging.INFO) logging.basicConfig(level=level, force=True) @@ -350,8 +346,23 @@ def main(): print(setup_prefix) print("~~~") + # Warmup: run one untimed iteration so CUDA caches, JIT kernels, and + # allocator pools are ready before the measured runs. + if args.inference_repeat_n > 1: + print("Running warmup iteration ...") + engine.reset() + run_inference(requests, engine) + torch.cuda.synchronize() + engine.reset() + + # Start CUDA profiler after warmup so nsys traces only the measured runs. + if os.environ.get("NSIGHT_PREFIX"): + torch.cuda.cudart().cudaProfilerStart() + # Run and time test, optionally `args.inference_repeat_n` times. throughputs = [] + cuda_start_event = torch.cuda.Event(enable_timing=True) + cuda_end_event = torch.cuda.Event(enable_timing=True) for _ in range(args.inference_repeat_n): # Reset engine. @@ -359,19 +370,29 @@ def main(): torch.cuda.reset_peak_memory_stats() - # Trial. + # Synchronize before starting the timer to avoid measuring stale GPU work. + torch.cuda.synchronize() + + # Trial — use both wall-clock and CUDA events for accurate GPU timing. t = get_curr_time() + cuda_start_event.record() result = run_inference(requests, engine) + cuda_end_event.record() step_times = result["step_times"] add_times = result["add_times"] output_times = result["output_times"] total_output_tokens = result["total_output_tokens"] torch.cuda.synchronize() total_time = get_curr_time() - t + cuda_elapsed_ms = cuda_start_event.elapsed_time(cuda_end_event) stats = torch.cuda.memory_stats() throughput = total_output_tokens / total_time throughputs.append(throughput) + # Stop CUDA profiler after measured runs. + if os.environ.get("NSIGHT_PREFIX"): + torch.cuda.cudart().cudaProfilerStop() + # Validate all requests finished. for request in requests: assert request.state == "finished", f"request.state == '{request.state}' != 'finished'." @@ -505,19 +526,17 @@ def escape_str(s): # f"count [ p {p_count}, d {d_count} ]." # ) capture_str = f"{engine.capture_stats['time']:.2f} sec" if engine.capture_stats else "--" + cuda_throughput = total_output_tokens / (cuda_elapsed_ms / 1000.0) print( f"{setup_prefix} … " f"throughput: {throughput:.3f} tok/s … ", f"total time: {total_time:.3f}s … " + f"cuda time: {cuda_elapsed_ms:.1f}ms ({cuda_throughput:.3f} tok/s) … " f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … " f"steps: {engine.context.step_count:d} … " f"capture {capture_str}", ) print("~~~") - # Stop Nsight profiler. - if os.environ.get("NSIGHT_PREFIX"): - torch.cuda.cudart().cudaProfilerStop() - if __name__ == "__main__": main() diff --git a/examples/inference/gpt/gpt_dynamic_inference_12b.sh b/examples/inference/gpt/gpt_dynamic_inference_12b.sh index ca21bb170a5..d848fdb51b7 100644 --- a/examples/inference/gpt/gpt_dynamic_inference_12b.sh +++ b/examples/inference/gpt/gpt_dynamic_inference_12b.sh @@ -5,7 +5,6 @@ set -u # Libraries. -pip install simpy pip install sentencepiece pip install tiktoken diff --git a/examples/inference/gpt/gpt_dynamic_inference_357m.sh b/examples/inference/gpt/gpt_dynamic_inference_357m.sh index cc99bdddec1..d0c126cd191 100644 --- a/examples/inference/gpt/gpt_dynamic_inference_357m.sh +++ b/examples/inference/gpt/gpt_dynamic_inference_357m.sh @@ -5,7 +5,6 @@ set -u # Libraries. -pip install simpy pip install sentencepiece pip install tiktoken diff --git a/examples/inference/gpt/utils.py b/examples/inference/gpt/utils.py index c9b1c05c544..ca26985e046 100644 --- a/examples/inference/gpt/utils.py +++ b/examples/inference/gpt/utils.py @@ -106,18 +106,13 @@ def get_time_offsets( random.seed(seed) - import simpy # Guard against this import in test case - - # Generate random time offsets. - def arrival(r): - while True: - yield env.timeout(random.expovariate(r)) - time_offsets.append(env.now) - + # Generate Poisson arrival times by accumulating exponential inter-arrival intervals. time_offsets = [] - env = simpy.Environment() - env.process(arrival(incoming_requests_per_sec)) - env.run(incoming_requests_duration) + current_time = 0.0 + while current_time < incoming_requests_duration: + current_time += random.expovariate(incoming_requests_per_sec) + if current_time < incoming_requests_duration: + time_offsets.append(current_time) # Ensure at least a single request. if len(time_offsets) == 0: diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index e2aef4b27c6..b5e9a7e8125 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -69,6 +69,7 @@ internal_api, is_row_parallel_linear, trace_async_exceptions, + unwrap_model, ) from .async_zmq_communicator import AsyncZMQCommunicator @@ -452,6 +453,11 @@ def create_cuda_graphs(self, reset_context: bool = True): if is_inference_optimized_ep: unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) + # MTP CUDA graph warmup: capture graphs for the MTP TransformerLayers + # used during speculative decoding. This must happen after decoder graph + # warmup so that the MTP graphs are captured independently. + self._create_mtp_cuda_graphs(controller, context) + # Memory usage. time_end = time.time() mem_stats_end = torch.cuda.memory_stats() @@ -476,6 +482,94 @@ def create_cuda_graphs(self, reset_context: bool = True): self.capture_stats = capture_stats + def _create_mtp_cuda_graphs(self, controller, context): + """Capture CUDA graphs for MTP layers used in speculative decoding. + + Derives the set of MTP batch sizes from the decoder CUDA graph batch + dimensions, then runs ``compute_mtp_single_step`` per batch size to + trigger graph capture. With ``mtp_use_repeated_layer`` one call covers + every depth; with unique layers the remaining depths capture lazily. + """ + num_mtp_heads = controller.num_mtp_heads + num_spec_tokens = controller.num_speculative_tokens or 0 + if num_mtp_heads == 0 or num_spec_tokens == 0: + return + + model = controller.inference_wrapped_model.model + unwrapped = unwrap_model(model) + if not hasattr(unwrapped, 'mtp'): + return + + model_config = model.config + + # Only proceed when local CUDA graphs are enabled. + if model_config.cuda_graph_impl != "local": + return + + # Collect batch sizes from all graph dimensions. MTP serial forward + # runs on all active requests (decode + prefill), so we need graphs + # for total request counts, not just decode-only counts. + tp_size = get_pg_size(controller.inference_wrapped_model.tp_group) + sp_enabled = model_config.sequence_parallel and tp_size > 1 + mtp_batch_sizes = set() + for dim in context.cuda_graph_batch_dimensions_list: + n = dim.req_count + if n > 0: + if sp_enabled: + n += (tp_size - n % tp_size) % tp_size + mtp_batch_sizes.add(n) + if not mtp_batch_sizes: + return + + # Flag that MTP CUDA graphs are available. The actual padded count is + # re-derived at runtime from padded_batch_dimensions.req_count. + controller._has_mtp_cuda_graphs = True + + device = torch.cuda.current_device() + dtype = model_config.params_dtype + hidden_size = model_config.hidden_size + + # Enable inference dispatcher for EP during MTP graph capture. + is_inference_optimized_ep = ( + model_config.transformer_impl == "inference_optimized" + and model_config.expert_model_parallel_size > 1 + ) + if is_inference_optimized_ep: + set_inference_cuda_graphed_iteration_for_ep_inference(model) + + logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_batch_sizes)) + + from megatron.core.transformer.cuda_graphs import _set_capture_end, _set_capture_start + + _set_capture_start() + for batch_size in sorted(mtp_batch_sizes): + dummy_hidden = torch.zeros((batch_size, 1, hidden_size), device=device, dtype=dtype) + if sp_enabled: + from megatron.core.tensor_parallel.mappings import ( + scatter_to_sequence_parallel_region, + ) + + dummy_hidden = scatter_to_sequence_parallel_region( + dummy_hidden, group=controller.inference_wrapped_model.tp_group + ) + dummy_token_ids = torch.zeros((1, batch_size), device=device, dtype=torch.long) + dummy_position_ids = torch.zeros((1, batch_size), device=device, dtype=torch.int64) + + # One call per batch size; depth=0 warms the shared layer (repeated + # mode) or the first unique layer (non-repeated mode). + unwrapped.compute_mtp_single_step( + hidden_states=dummy_hidden, + next_token_ids=dummy_token_ids, + position_ids=dummy_position_ids, + depth=0, + ) + _set_capture_end() + + if is_inference_optimized_ep: + unset_inference_cuda_graphed_iteration_for_ep_inference(model) + + logging.info("> MTP CUDA graph warmup complete") + @internal_api async def start_listening_to_data_parallel_coordinator( self, diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index abf1bbf585b..c57fb1b34d0 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -39,6 +39,8 @@ get_asyncio_loop, get_model_config, get_pg_size, + nvtx_range_pop, + nvtx_range_push, unwrap_model, ) @@ -51,6 +53,12 @@ HAVE_TE = False from megatron.core.inference.batch_dimensions_utils import InferenceBatchDimensions +from megatron.core.inference.text_generation_controllers.triton_kernels import ( + mamba_state_selective_copy, + prepare_next_forward_pass, + rewind_kv_cache, + verify_speculative_tokens, +) # pylint: disable=line-too-long @@ -174,6 +182,27 @@ def _init_mtp_sampling_tensor(self): ) * -1 ) + self._accepted_token_counts_per_request = torch.zeros( + max_requests, dtype=torch.int64, device=device + ) + self._last_accepted_seq_indices_buf = torch.empty( + max_requests, dtype=torch.int64, device=device + ) + + # Cache values that are constant across inference steps. + self._unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + self._is_last_pp_stage = is_pipeline_last_stage(self.pp_group) + self._tp_size = get_pg_size(self.inference_wrapped_model.tp_group) + self._sp_enabled = self.model_config.sequence_parallel and self._tp_size > 1 + self._num_mtp_depths = min(self.num_speculative_tokens, self.num_mtp_heads) + + # Pre-allocate padded buffers for per-depth token/position IDs. + self._mtp_token_ids_buf = torch.empty( + [1, max_requests], dtype=torch.int64, device=device + ) + self._mtp_position_ids_buf = torch.empty( + [1, max_requests], dtype=torch.int64, device=device + ) @staticmethod def tokenize_prompt(tokenizer, prompt: str, add_BOS: bool = False) -> List[int]: @@ -581,6 +610,27 @@ def _dynamic_step_context_init( is_expert_parallel_dummy_cuda_graph_step=is_dummy_forward, ) + # Derive the MTP padded batch size from the EP-synced graph dimensions. + # In eager mode MTP uses locally SP-aligned batch size instead. + if getattr(self, '_has_mtp_cuda_graphs', False) and context.using_cuda_graph_this_step(): + self._mtp_resolved_padded_count = context.padded_batch_dimensions.req_count + if self._sp_enabled: + self._mtp_resolved_padded_count += ( + self._tp_size - self._mtp_resolved_padded_count % self._tp_size + ) % self._tp_size + else: + self._mtp_resolved_padded_count = None + + # Tell MTP layers whether to use CUDA graphs this step. When the main + # model falls back to eager mode, MTP must also run eagerly across all + # EP ranks — otherwise some ranks may replay a captured graph while + # others run eagerly, causing EP collectives to hang. + if getattr(self, '_has_mtp_cuda_graphs', False): + use_mtp_graphs = context.using_cuda_graph_this_step() + if hasattr(unwrapped_model, 'mtp'): + for layer in unwrapped_model.mtp.layers: + layer.use_mtp_cuda_graphs = use_mtp_graphs + # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels symmetric_ar_type = self.model_config.symmetric_ar_type @@ -696,118 +746,76 @@ def _dynamic_step_sample_bookkeeping(self): bucket_map[sampling_params].append(request_index) # Just unpack the key directly! + device = torch.cuda.current_device() self._torch_sampling_buckets = [ (indices, *sampling_params) for sampling_params, indices in bucket_map.items() ] + # Pre-compute index tensors on GPU to avoid per-step H2D copies. + self._torch_sampling_bucket_index_tensors = [ + torch.tensor(indices, device=device, dtype=torch.long) + for indices, *_ in self._torch_sampling_buckets + ] - def _rewind_kv_cache(self): + def _rewind_kv_cache(self) -> tuple: """Update the KV cache bookkeeping for speculative decoding. After forward pass with speculative tokens, some tokens may be rejected. - This function "rewinds" the KV cache bookkeeping to reflect only the accepted tokens. - - When speculative tokens are rejected, we need to: - 1. Update request_kv_length_offsets (total sequence length) - 2. Update request_last_kv_block_offset (position within last block) - 3. If rewinding crosses a block boundary: - - Reduce request_kv_block_counts - - Update request_last_kv_block_id to point to the previous block - - Clear the entry in request_to_kv_block_ids for the released block - - Release the block back to the allocator + This function "rewinds" the KV cache bookkeeping to reflect only the accepted + tokens. The core bookkeeping is handled by a Triton kernel (one thread per + request). Mamba hybrid-model state updates remain in PyTorch. + + Returns (blocks_to_release, remove_mask) for the caller to release blocks + back to the allocator outside the compiled graph. """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) - # Get the accepted token counts for each request - # Note: _accepted_token_counts is indexed from 0 to active_request_count-1 accepted_tokens_per_request = self._accepted_token_counts_per_request[:active_request_count] - # Number of tokens to rewind (rejected speculative tokens) - num_tokens_to_rewind = self.num_speculative_tokens - accepted_tokens_per_request - - # For prefill requests, no speculative tokens were forwarded through the model, - # so there is nothing to rewind. request_in_prefill_status = context.request_in_prefill_status_tensor[active_request_slice] - num_tokens_to_rewind[request_in_prefill_status == 1] = 0 - - # Save the original offset BEFORE modifying to correctly detect block boundary crossing - original_offset = context.request_last_kv_block_offset[active_request_slice].clone() - - # Check which requests need to rewind to a previous block BEFORE modifying - # A request crosses back to a previous block if: original_offset - num_tokens_to_rewind < 0 - remove_allocated_blocks_mask = (original_offset - num_tokens_to_rewind) < 0 - - # Update the offsets - context.request_last_kv_block_offset[active_request_slice] = ( - original_offset - num_tokens_to_rewind - ) % context.block_size_tokens - - context.request_kv_length_offsets[active_request_slice] = ( - context.request_kv_length_offsets[active_request_slice] - num_tokens_to_rewind + request_last_kv_block_offset = context.request_last_kv_block_offset[active_request_slice] + request_kv_length_offsets = context.request_kv_length_offsets[active_request_slice] + request_kv_block_counts = context.request_kv_block_counts[active_request_slice] + request_last_kv_block_id = context.request_last_kv_block_id[active_request_slice] + request_to_kv_block_ids = context.request_to_kv_block_ids[active_request_slice] + + # --- Triton kernel: core KV-cache rewind --- + blocks_to_release, remove_mask = rewind_kv_cache( + accepted_counts=accepted_tokens_per_request, + prefill_status=request_in_prefill_status, + last_kv_block_offset=request_last_kv_block_offset, + kv_length_offsets=request_kv_length_offsets, + kv_block_counts=request_kv_block_counts, + last_kv_block_id=request_last_kv_block_id, + kv_block_ids=request_to_kv_block_ids, + num_speculative_tokens=self.num_speculative_tokens, + block_size_tokens=context.block_size_tokens, ) - # No need to update request_query_lengths (It will be set correctly in the next iteration) - - # For requests that crossed back to a previous block, we need to: - # 1. Reduce the block count by 1 - # 2. Get the block ID to release (current request_last_kv_block_id) - # 3. Update request_last_kv_block_id to point to the previous block - # 4. Clear the entry in request_to_kv_block_ids for the released block - # 5. Release the block back to the allocator - if remove_allocated_blocks_mask.any(): - # Get indices of requests that need to release a block (relative to active requests) - requests_needing_release = torch.nonzero(remove_allocated_blocks_mask, as_tuple=True)[0] - # Convert to absolute indices in the context tensors - absolute_indices = requests_needing_release + context.paused_request_count - - # No clone needed: advanced (fancy) indexing with a tensor already returns - # a copy, not a view. - blocks_to_release = context.request_last_kv_block_id[absolute_indices] - - # Reduce block counts for requests that crossed back - context.request_kv_block_counts[absolute_indices] -= 1 - - # Get the new block counts after decrement - new_block_counts = context.request_kv_block_counts[absolute_indices] - - # Update request_last_kv_block_id to point to the previous block - # and clear the released block entry in request_to_kv_block_ids - # Vectorized implementation using advanced indexing: - # Note: new_block_counts is guaranteed to be > 0 for all requests here, since - # crossing back to a previous block implies the request had at least 2 blocks. - - # Update request_last_kv_block_id to point to the previous block (at index new_count - 1) - context.request_last_kv_block_id[absolute_indices] = context.request_to_kv_block_ids[ - absolute_indices, new_block_counts - 1 - ] - - # Clear the released block entry (at index new_count, which was the old last block) - context.request_to_kv_block_ids[absolute_indices, new_block_counts] = -1 - - # Release the blocks back to the allocator - context.kv_block_allocator.release_memory_blocks(blocks_to_release) - - # Mamba speculative rewind state update + # Mamba speculative rewind: copy accepted intermediate states in-place. if context.is_hybrid_model: - active_mamba_indices = context.mamba_metadata.request_to_mamba_state_idx[ + mamba_state_idx = context.mamba_metadata.request_to_mamba_state_idx[ active_request_slice ] - is_decode_mask = context.request_in_prefill_status_tensor[active_request_slice] == 0 - decode_mamba_indices = active_mamba_indices[is_decode_mask] - accepted_tokens_per_decode_request = accepted_tokens_per_request[is_decode_mask] - - if decode_mamba_indices.numel() > 0: - context.mamba_conv_states[:, decode_mamba_indices] = ( - context.mamba_intermediate_conv_states[ - :, decode_mamba_indices, accepted_tokens_per_decode_request - ] - ) - context.mamba_ssm_states[:, decode_mamba_indices] = ( - context.mamba_intermediate_ssm_states[ - :, decode_mamba_indices, accepted_tokens_per_decode_request - ] - ) + mamba_state_selective_copy( + intermediate_states=context.mamba_intermediate_conv_states, + current_states=context.mamba_conv_states, + prefill_status=request_in_prefill_status, + state_idx=mamba_state_idx, + accepted_counts=accepted_tokens_per_request, + num_layers=context.num_mamba_layers, + ) + mamba_state_selective_copy( + intermediate_states=context.mamba_intermediate_ssm_states, + current_states=context.mamba_ssm_states, + prefill_status=request_in_prefill_status, + state_idx=mamba_state_idx, + accepted_counts=accepted_tokens_per_request, + num_layers=context.num_mamba_layers, + ) + + return blocks_to_release, remove_mask def _sample_from_logits_2d(self, logits_2d: Tensor) -> Tensor: """Sample tokens from 2D logits using existing sampling parameters. @@ -819,18 +827,15 @@ def _sample_from_logits_2d(self, logits_2d: Tensor) -> Tensor: Tensor: Sampled tokens of shape [num_requests]. """ spec_token_list = [] - indices_list = [] - for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: - request_indices_tensor = torch.tensor( - request_indices, device=logits_2d.device, dtype=torch.long - ) + for idx_tensor, (_, temp, top_k, top_p) in zip( + self._torch_sampling_bucket_index_tensors, self._torch_sampling_buckets + ): spec_token_list.append( - self._torch_sampling_func(logits_2d[request_indices_tensor, :], temp, top_k, top_p) + self._torch_sampling_func(logits_2d[idx_tensor, :], temp, top_k, top_p) ) - indices_list.append(request_indices_tensor) spec_tokens = torch.empty(logits_2d.shape[0], device=logits_2d.device, dtype=torch.int64) - for tokens, indices in zip(spec_token_list, indices_list): + for tokens, indices in zip(spec_token_list, self._torch_sampling_bucket_index_tensors): spec_tokens[indices] = tokens return spec_tokens @@ -846,14 +851,15 @@ def _compute_serial_mtp_and_sample(self): (scattered along the first dimension) between MTP depths to avoid a redundant gather + scatter round-trip per depth. """ + nvtx_range_push("mtp-spec-decoding/serial-mtp-init") context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_slice = slice(context.paused_request_count, context.total_request_count) - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + unwrapped_model = self._unwrapped_model # On non-last pipeline stages, the model won't have decoder hidden states. - has_mtp = is_pipeline_last_stage(self.pp_group) and hasattr( + has_mtp = self._is_last_pp_stage and hasattr( unwrapped_model, '_decoder_hidden_states_cache' ) @@ -864,7 +870,7 @@ def _compute_serial_mtp_and_sample(self): # When SP is active the decoder output is in scattered format # [S/TP, B, H], but _last_accepted_seq_indices are indices into # the full (gathered) sequence. - if self.model_config.sequence_parallel: + if self._sp_enabled: hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.inference_wrapped_model.tp_group ) @@ -879,72 +885,87 @@ def _compute_serial_mtp_and_sample(self): # The next position to predict starts at that cache length. adjusted_offsets = context.request_kv_length_offsets[active_slice] processed_tokens = context.request_query_lengths[active_slice] - base_position = adjusted_offsets + processed_tokens + # Cast to int64 to match CUDA graph capture dtype expectations. + base_position = (adjusted_offsets + processed_tokens).to(torch.int64) # Start with the freshly sampled base token. next_token_ids = self._sampled_tokens_cuda[:active_request_count].clone() current_hidden = last_accepted_hidden if has_mtp else None - # Compute padding needed to make batch a multiple of tp_size for SP compatibility. - tp_size = get_pg_size(self.inference_wrapped_model.tp_group) - sp_enabled = self.model_config.sequence_parallel and tp_size > 1 - if sp_enabled: - pad_count = (tp_size - active_request_count % tp_size) % tp_size - padded_count = active_request_count + pad_count + # Compute padding needed to make batch compatible with SP and CUDA graphs. + if getattr(self, '_mtp_resolved_padded_count', None) is not None: + # CUDA-graph path: use the EP-synced padded count. + padded_count = self._mtp_resolved_padded_count + assert not self._sp_enabled or padded_count % self._tp_size == 0 + elif has_mtp: + # Eager path: pad only for SP alignment. + padded_count = active_request_count + if self._sp_enabled: + padded_count += (self._tp_size - padded_count % self._tp_size) % self._tp_size else: - pad_count = 0 + padded_count = active_request_count + pad_count = padded_count - active_request_count - # Pad hidden states to align with the tensor parallel size. - if has_mtp and sp_enabled: - if pad_count > 0: - current_hidden = F.pad(current_hidden, (0, 0, 0, 0, 0, pad_count)) + # Pad hidden states and scatter for sequence parallelism. + if has_mtp: + current_hidden = F.pad(current_hidden, (0, 0, 0, 0, 0, pad_count)) + if self._sp_enabled: + current_hidden = scatter_to_sequence_parallel_region( + current_hidden, group=self.inference_wrapped_model.tp_group + ) - current_hidden = scatter_to_sequence_parallel_region( - current_hidden, group=self.inference_wrapped_model.tp_group - ) + token_ids_buf = self._mtp_token_ids_buf[:, :padded_count] + position_ids_buf = self._mtp_position_ids_buf[:, :padded_count] + + # Zero-fill padding slots so the embedding layer never sees out-of-range IDs. + token_ids_buf[0, active_request_count:] = 0 + position_ids_buf[0, active_request_count:] = 0 - num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) - for depth in range(num_depths): - position_ids = (base_position + depth).unsqueeze(0) # [1, active_request_count] - token_ids = next_token_ids.unsqueeze(0) # [1, active_request_count] + nvtx_range_pop("mtp-spec-decoding/serial-mtp-init") + for depth in range(self._num_mtp_depths): + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}") + + token_ids_buf[0, :active_request_count] = next_token_ids + position_ids_buf[0, :active_request_count] = base_position + depth mtp_logits_2d = None if has_mtp: - # Pad token_ids and position_ids each iteration (they change per depth). - if pad_count > 0: - token_ids = F.pad(token_ids, (0, pad_count)) - position_ids = F.pad(position_ids, (0, pad_count)) - + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/forward") current_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( hidden_states=current_hidden, - next_token_ids=token_ids, - position_ids=position_ids, + next_token_ids=token_ids_buf, + position_ids=position_ids_buf, depth=depth, ) + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/forward") - # Strip padding from logits only. Hidden states stay padded+SP + # Strip padding from logits only. Hidden states stay padded+SP # between depths to avoid redundant gather/scatter round-trips. - if pad_count > 0: - mtp_logits = mtp_logits[:active_request_count] + mtp_logits = mtp_logits[:active_request_count] # mtp_logits: [active_request_count, 1, vocab_size] mtp_logits_2d = mtp_logits.squeeze(1) # [active_request_count, vocab_size] # Broadcast MTP logits across pipeline stages. if self.model_is_pipeline_parallel: + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/pp-broadcast") mtp_logits_2d = broadcast_from_last_pipeline_stage( [active_request_count, self.vocab_size], dtype=self.model_config.params_dtype, tensor=mtp_logits_2d, pp_group=self.pp_group, ) + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/pp-broadcast") # Sample speculative token using the same sampling parameters. + nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/sample") spec_tokens = self._sample_from_logits_2d(mtp_logits_2d) self._sampled_mtp_tokens_cuda[depth, :active_request_count] = spec_tokens + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/sample") # Use sampled token as input for the next depth. next_token_ids = spec_tokens + nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}") # Clean up cached hidden states. if has_mtp: @@ -983,13 +1004,10 @@ def _sample_speculative_logits( output_tokens_jumbled_list = [] token_order_list = [] - for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: - request_indices_tensor = torch.tensor( - request_indices, device=token_to_request_index.device - ) - required_indices = torch.where( - torch.isin(token_to_request_index, request_indices_tensor) - )[0] + for idx_tensor, (_, temp, top_k, top_p) in zip( + self._torch_sampling_bucket_index_tensors, self._torch_sampling_buckets + ): + required_indices = torch.where(torch.isin(token_to_request_index, idx_tensor))[0] output_tokens_jumbled_list.append( self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) ) @@ -1011,86 +1029,19 @@ def _verify_speculative_tokens( self, output_tokens: Tensor, input_tokens_required: Tensor, - request_in_prefill_status_tensor: Tensor, - repeats: Tensor, num_decode_requests: int, num_prefill_requests: int, active_request_count: int, ) -> tuple: - """Verify speculative tokens against input tokens and compute acceptance. - - Creates an accepted tokens mask where: - - For prefill requests, the token is always accepted. - - For decode requests, the first token (base token) is always accepted, then we compare - sampled tokens with input tokens and accept consecutive matches. - Then finds the index of the last accepted token per request. - - Example (assume 1, 2, and 0 spec tokens are accepted in the first 3 decode requests): - input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 - Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] - Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] - Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] - Last one indices [ 1 | 5 | 6 | 9 | 10 ] - - Returns: - tuple: (last_one_indices, accepted_tokens_mask, input_tokens_required) where - last_one_indices contains the index of the last accepted token per request. - """ - if input_tokens_required.ndim == 2: - assert ( - input_tokens_required.shape[0] == 1 - ), f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" - input_tokens_required = input_tokens_required.squeeze(0) - - # Initialize mask with False to prevent boundary bleed - accepted_tokens_mask = torch.zeros_like(input_tokens_required, dtype=torch.bool) - - # Make all prefill tokens accepted - token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) - accepted_tokens_mask[token_to_prefill_idx == 1] = True - - # Safe decode token verification without cross-batch boundary contamination - decode_mask_2d = None - if num_decode_requests > 0: - decode_len = num_decode_requests * (self.num_speculative_tokens + 1) - - decode_inputs = input_tokens_required[:decode_len].reshape( - num_decode_requests, self.num_speculative_tokens + 1 - ) - decode_outputs = output_tokens[:decode_len].reshape( - num_decode_requests, self.num_speculative_tokens + 1 - ) - - # Shift outputs right by 1 *within* each request to align sampled tokens with input targets - decode_outputs_shifted = decode_outputs.roll(1, dims=1) - decode_mask_2d = decode_inputs == decode_outputs_shifted - # The first token (base token) is always accepted - decode_mask_2d[:, 0] = True - # Enforce consecutive acceptance: cummin propagates False to the right - decode_mask_2d = decode_mask_2d.cummin(dim=1).values - accepted_tokens_mask[:decode_len] = decode_mask_2d.flatten() - - last_one_indices = torch.full( - (active_request_count,), -1, device=input_tokens_required.device + """Verify speculative tokens against input tokens (Triton kernel).""" + return verify_speculative_tokens( + input_tokens=input_tokens_required, + output_tokens=output_tokens, + num_decode_requests=num_decode_requests, + num_prefill_requests=num_prefill_requests, + num_speculative_tokens=self.num_speculative_tokens, ) - if num_decode_requests > 0: - # Summing the consecutive mask gives the count; subtract 1 for the local index - local_last_indices = decode_mask_2d.sum(dim=1) - 1 - row_offsets = torch.arange(num_decode_requests, device=last_one_indices.device) * ( - self.num_speculative_tokens + 1 - ) - last_one_indices[:num_decode_requests] = row_offsets + local_last_indices - - if num_prefill_requests > 0: - decode_len = num_decode_requests * (self.num_speculative_tokens + 1) - prefill_valid = ( - torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len - ) - last_one_indices[num_decode_requests:] = prefill_valid - - return last_one_indices, accepted_tokens_mask, input_tokens_required - def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_ids: Tensor): """ Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. @@ -1101,16 +1052,11 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ context.paused_request_count : context.total_request_count ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count - ] - - num_prefill_requests = request_in_prefill_status_tensor.sum().item() - num_decode_requests = active_request_count - num_prefill_requests # Get the logit indices for tokens that need sampling. # These indices are always needed for input_ids slicing and tracking # accepted sequence positions, even when logits are pre-sliced. + nvtx_range_push("mtp-spec-decoding/verify/logit-indices") required_logit_indices = context.speculative_required_logit_indices(logits.device) if context.config.materialize_only_last_token_logits: @@ -1120,57 +1066,76 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id required_logits = logits.squeeze(0)[ required_logit_indices, : ] # Shape [num_required, vocab_size] + nvtx_range_pop("mtp-spec-decoding/verify/logit-indices") # Sample tokens from logits + nvtx_range_push("mtp-spec-decoding/verify/sample") output_tokens, repeats = self._sample_speculative_logits( required_logits, request_in_prefill_status_tensor ) + nvtx_range_pop("mtp-spec-decoding/verify/sample") + + num_prefill_requests = context.num_prefill_requests + num_decode_requests = active_request_count - num_prefill_requests # Verify speculative tokens against input tokens. + nvtx_range_push("mtp-spec-decoding/verify/verify-tokens") input_tokens_required = input_ids[0, required_logit_indices] last_one_indices, accepted_tokens_mask, input_tokens_required = ( self._verify_speculative_tokens( output_tokens, input_tokens_required, - request_in_prefill_status_tensor, - repeats, num_decode_requests, num_prefill_requests, active_request_count, ) ) + nvtx_range_pop("mtp-spec-decoding/verify/verify-tokens") + + nvtx_range_push("mtp-spec-decoding/verify/prepare-next") + self._prepare_speculative_tokens_for_next_forward_pass( + num_decode_requests, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_tokens_mask, + input_tokens_required, + ) + nvtx_range_pop("mtp-spec-decoding/verify/prepare-next") - # Store the final sampled tokens for the next forward pass. - final_sampled_tokens = output_tokens[last_one_indices] - self._sampled_tokens_cuda[: len(final_sampled_tokens)] = final_sampled_tokens - - # Store the last accepted positions in the packed sequence for serial - # MTP computation after verification. - self._last_accepted_seq_indices = required_logit_indices[last_one_indices] - - # Extract accepted tokens and counts for decode requests. - # For prefill it is always set to 1. For decode, the first token is always accepted, - # then we compare with input tokens and accept the next tokens if its a match. - # - # Example (continuing from above): - # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] - # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] - # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only decode requests (prefill defaults to -1) - # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 - input_tokens_required[accepted_tokens_mask == 0] = -1 # Mask out non-accepted tokens - input_tokens_decode_mode = input_tokens_required[ - : num_decode_requests * (self.num_speculative_tokens + 1) - ] - input_tokens_reshaped = input_tokens_decode_mode.reshape( - -1, self.num_speculative_tokens + 1 - ) # shape: [num_decode_requests, num_speculative_tokens + 1] - - # Skip the first token of every decode request (i.e a5, b3, c6) - accepted_tokens = input_tokens_reshaped[:, 1:] - self._accepted_tokens_per_request[: accepted_tokens.shape[0], :] = accepted_tokens - self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum( - dim=1 + def _prepare_speculative_tokens_for_next_forward_pass( + self, + num_decode_requests: int, + output_tokens: torch.Tensor, + required_logit_indices: torch.Tensor, + last_one_indices: torch.Tensor, + accepted_tokens_mask: torch.Tensor, + input_tokens_required: torch.Tensor, + ): + """Prepare accepted speculative tokens for the next forward pass (Triton kernel). + + Example: + input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] + Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] (decode only; prefill → -1) + Accepted token counts [ 1 | 2 | 0 ] (prefill defaults to 0) + """ + active_request_count = last_one_indices.shape[0] + prepare_next_forward_pass( + num_decode_requests=num_decode_requests, + output_tokens=output_tokens, + required_logit_indices=required_logit_indices, + last_one_indices=last_one_indices, + accepted_tokens_mask=accepted_tokens_mask, + input_tokens=input_tokens_required, + sampled_tokens_buf=self._sampled_tokens_cuda, + last_accepted_seq_buf=self._last_accepted_seq_indices_buf, + accepted_tokens_per_request=self._accepted_tokens_per_request, + accepted_token_counts=self._accepted_token_counts_per_request, + num_speculative_tokens=self.num_speculative_tokens, ) + # Expose the active slice so downstream code sees the right length. + self._last_accepted_seq_indices = self._last_accepted_seq_indices_buf[:active_request_count] def _dynamic_step_sample_logits(self, logits: Tensor): """Sample tokens from logits for dynamic batching. @@ -1628,7 +1593,8 @@ def dummy_forward(self): if not context.cuda_graph_batch_dimensions_list: self.inference_wrapped_model.dummy_forward() - # Disable MoE padding for MTP computation + # Disable MoE padding for MTP computation. + # No CUDA graphs in this path (cuda_graph_batch_dimensions_list is empty). if self.model_config.moe_pad_experts_for_cuda_graph_inference: unwrapped_model = unwrap_model(self.inference_wrapped_model.model) set_decode_expert_padding(unwrapped_model, False) @@ -1651,10 +1617,12 @@ def dummy_forward(self): # fallback to eager dummy forward self.inference_wrapped_model.dummy_forward() - # Disable MoE padding for MTP computation + # Disable MoE padding for MTP computation, unless CUDA graphs + # are active (the graphs were captured with padding enabled). if self.model_config.moe_pad_experts_for_cuda_graph_inference: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - set_decode_expert_padding(unwrapped_model, False) + if not context.using_cuda_graph_this_step(): + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + set_decode_expert_padding(unwrapped_model, False) # When speculative decoding is active, the real EP ranks perform serial # MTP forward passes after the main forward pass. MTP layers may contain @@ -1686,10 +1654,11 @@ def _dummy_serial_mtp_forward(self): if self.model_config.expert_model_parallel_size <= 1: return - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + unwrapped_model = self._unwrapped_model - is_last_stage = is_pipeline_last_stage(self.pp_group) - has_mtp = is_last_stage and hasattr(unwrapped_model, '_decoder_hidden_states_cache') + has_mtp = self._is_last_pp_stage and hasattr( + unwrapped_model, '_decoder_hidden_states_cache' + ) if not has_mtp and not self.model_is_pipeline_parallel: # No MTP on this rank and no PP broadcast to participate in. return @@ -1697,24 +1666,30 @@ def _dummy_serial_mtp_forward(self): device = torch.cuda.current_device() dtype = self.model_config.params_dtype hidden_size = self.model_config.hidden_size - num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) - # Pad token_ids/position_ids to nearest multiple of tp_size so that the - # embedding can reduce-scatter evenly across TP ranks. - tp_size = get_pg_size(self.inference_wrapped_model.tp_group) - sp_enabled = self.model_config.sequence_parallel and tp_size > 1 - padded_count = tp_size if sp_enabled else 1 + # Use precomputed MTP CUDA graph batch size when available; + # otherwise use minimal SP-compatible size. + if getattr(self, '_mtp_resolved_padded_count', None) is not None: + padded_count = self._mtp_resolved_padded_count + assert not self._sp_enabled or padded_count % self._tp_size == 0 + elif has_mtp: + # Eager path: use TP-aligned minimum size for dummy tensors. + padded_count = self._tp_size if self._sp_enabled else 1 dummy_hidden = None if has_mtp: - # Minimal dummy tensors — just enough to drive the MTP layer forward + # Minimal dummy tensors to drive the MTP layer forward # so that the MoE all-to-all collectives are issued. - # Depth 0 uses full-format hidden; subsequent depths use SP format. - dummy_hidden = torch.zeros((1, 1, hidden_size), device=device, dtype=dtype) + dummy_hidden = torch.zeros((padded_count, 1, hidden_size), device=device, dtype=dtype) + if self._sp_enabled: + dummy_hidden = scatter_to_sequence_parallel_region( + dummy_hidden, group=self.inference_wrapped_model.tp_group + ) dummy_token_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long) dummy_position_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long) - for depth in range(num_depths): + for depth in range(self._num_mtp_depths): + nvtx_range_push(f"mtp-spec-decoding/dummy-depth-{depth}") mtp_logits_2d = None if has_mtp: dummy_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( @@ -1733,6 +1708,7 @@ def _dummy_serial_mtp_forward(self): tensor=mtp_logits_2d, pp_group=self.pp_group, ) + nvtx_range_pop(f"mtp-spec-decoding/dummy-depth-{depth}") def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: """Update the dynamic inference context after sampling. @@ -1873,17 +1849,28 @@ async def async_generate_output_tokens_dynamic_batch( if self.num_speculative_tokens > 0: # Phase 1: Verify speculative tokens using base logits only. + nvtx_range_push("mtp-spec-decoding/verify") self._dynamic_step_sample_logits_and_verify_tokens(logits, input_ids) + nvtx_range_pop("mtp-spec-decoding/verify") # Phase 2: Rewind KV cache for rejected tokens. - self._rewind_kv_cache() + nvtx_range_push("mtp-spec-decoding/rewind-kv-cache") + blocks_to_release, remove_mask = self._rewind_kv_cache() + nvtx_range_pop("mtp-spec-decoding/rewind-kv-cache") - # Disable MoE padding for MTP computation + # Disable MoE padding for MTP computation, unless CUDA graphs + # are active (the graphs were captured with padding enabled). if self.model_config.moe_pad_experts_for_cuda_graph_inference: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - set_decode_expert_padding(unwrapped_model, False) + if not context.using_cuda_graph_this_step(): + set_decode_expert_padding(self._unwrapped_model, False) # Phase 3: Compute MTP serially with correct (verified) inputs. + nvtx_range_push("mtp-spec-decoding/serial-mtp") self._compute_serial_mtp_and_sample() + nvtx_range_pop("mtp-spec-decoding/serial-mtp") + + # Phase 4: Release freed blocks. Deferred from Phase 2 so the + # data-dependent boolean-mask sync overlaps with MTP GPU work. + context.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) else: self._dynamic_step_sample_logits(logits) diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index e8bb564e759..75ff640b1b9 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -30,6 +30,8 @@ get_tensor_model_parallel_group_if_none, is_te_min_version, make_tp_sharded_tensor_for_checkpoint, + nvtx_range_pop, + nvtx_range_push, ) @@ -343,19 +345,27 @@ def compute_mtp_single_step( """ layer_idx = 0 if self.mtp.mtp_use_repeated_layer else depth + nvtx_range_push(f"mtp-single-step/depth-{depth}/mtp-layer") mtp_hidden = self.mtp.layers[layer_idx].forward_single_position( hidden_states=hidden_states, next_token_ids=next_token_ids, position_ids=position_ids, embedding=self.embedding, ) + # CudaGraphManager.replay_graph_capture always wraps outputs in a + # tuple. Unwrap when forward_single_position is CUDA-graphed. + if isinstance(mtp_hidden, tuple): + mtp_hidden = mtp_hidden[0] + nvtx_range_pop(f"mtp-single-step/depth-{depth}/mtp-layer") output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() + nvtx_range_push(f"mtp-single-step/depth-{depth}/output-layer") logits, _ = self.output_layer(mtp_hidden, weight=output_weight, runtime_gather_output=True) logits = self._scale_logits(logits) + nvtx_range_pop(f"mtp-single-step/depth-{depth}/output-layer") return mtp_hidden, logits diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index c7631519e43..a124bc05af7 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -241,8 +241,8 @@ def _check_supported_type(meta): DynamicInferenceContext, ArgMetadata, } - assert meta.type in _SUPPORTED_TYPES or is_dataclass( - meta.value + assert ( + meta.type in _SUPPORTED_TYPES or is_dataclass(meta.value) or callable(meta.value) ), f"Cudagraphs received an arg of type {meta.type} which is not supported." @@ -258,6 +258,10 @@ def _determine_if_first_last_layer_of_this_vp_chunk(base_module): if not hasattr(base_module, "layer_number"): return True, True + # MTP layers are self-contained; don't chain them with decoder layers. + if getattr(base_module, 'is_mtp_layer', False): + return True, True + # find all first/last layers of this PP stage first_layer_numbers = [] last_layer_numbers = [] @@ -1418,8 +1422,12 @@ def __init__( config: TransformerConfig object containing CUDA graph settings for memory pooling, graph retention, gradient accumulation, FP8/FP4, and warmup steps. """ + from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer + rng_tracker = get_cuda_rng_tracker() self.need_backward = need_backward + # MTP is only cuda-graphed for inference (forward_single_position). + self.is_mtp = isinstance(base_module, MultiTokenPredictionLayer) if function_name is not None: func = getattr(base_module, function_name) @@ -1494,6 +1502,7 @@ def get_cudagraph_runner(self, megatron_module, args, kwargs, reuse_cudagraphs): over different microbatches by tracking their respective fwd and bwd passes.''' if reuse_cudagraphs: is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context'] + is_mtp_inference = self.is_mtp if is_inference_mode: is_static_batching = kwargs['inference_context'].is_static_batching() if is_static_batching: @@ -1503,6 +1512,10 @@ def get_cudagraph_runner(self, megatron_module, args, kwargs, reuse_cudagraphs): else: padded_batch_dimensions = kwargs['inference_context'].padded_batch_dimensions runner = self.inference_cudagraphs_lookup_table[padded_batch_dimensions] + elif is_mtp_inference: + # MTP layers have no inference_context; key by hidden_states shape. + mtp_key = ('mtp', kwargs['hidden_states'].shape) + runner = self.inference_cudagraphs_lookup_table.get(mtp_key) else: # Todo: For training, we could also cache runners based on input shape. # If autograd is currently disabled, it doesnt matter if a runner was created @@ -1545,6 +1558,8 @@ def is_valid(r): ) else: self.inference_cudagraphs_lookup_table[padded_batch_dimensions] = runner + elif is_mtp_inference: + self.inference_cudagraphs_lookup_table[mtp_key] = runner else: # Create cudagraphs for every microbatch if _CudagraphGlobalRecord.cudagraph_created: @@ -1574,7 +1589,9 @@ def __call__(self, megatron_module, args, kwargs): kwargs (dict): The keyword args to be passed to the module. """ - is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context'] + is_inference_mode = ( + 'inference_context' in kwargs.keys() and kwargs['inference_context'] + ) or self.is_mtp is_in_checkpoint_fwd = is_checkpointing() if HAVE_TE_GRAPHS: is_in_checkpoint_fwd = is_in_checkpoint_fwd or is_fp8_activation_recompute_enabled() @@ -1590,9 +1607,22 @@ def __call__(self, megatron_module, args, kwargs): out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs) else: if is_inference_mode: + # MTP must match the main model's eager/graph mode so all EP + # ranks take the same code path. Skip during graph capture. + if ( + self.is_mtp + and not getattr(megatron_module, 'use_mtp_cuda_graphs', False) + and not is_graph_capturing() + ): + return self.func(*args, **kwargs) + # Inference generation mode creates graphs immediately runner = self.get_cudagraph_runner(megatron_module, args, kwargs, True) + if not runner.fwd_graph_recorded and self.is_mtp and not is_graph_capturing(): + # No pre-warmed graph for this batch size — run eagerly. + return self.func(*args, **kwargs) + if not runner.fwd_graph_recorded: # Reuse graph input-output buffers for inference local_args, local_kwargs = args, kwargs diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index e4865d3d89a..9b59ba86b6d 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -843,6 +843,7 @@ def __init__( mamba_submodules: Optional[MambaStackSubmodules] = None, ): super().__init__(config=config) + self.is_mtp_layer = True self.sequence_parallel = config.sequence_parallel self.submodules = submodules self.layer_number = layer_number + get_mtp_layer_offset(self.config, vp_stage) @@ -949,6 +950,19 @@ def __init__( ) self.offload_context = nullcontext() + # Create cuda graph manager for forward_single_position so that + # the full MTP forward (embedding, projection, transformer, layernorm) + # is captured in a single graph. + if config.cuda_graph_impl == "local" and not config.cuda_graph_scope: + from megatron.core.transformer.cuda_graphs import CudaGraphManager + + self.cudagraph_manager = CudaGraphManager( + config, + base_module=self, + function_name="forward_single_position", + need_backward=False, + ) + def _get_embeddings( self, input_ids: torch.Tensor, @@ -1110,6 +1124,14 @@ def _postprocess(self, hidden_states: torch.Tensor): return hidden_states + def _should_call_local_cudagraph(self, *args, **kwargs): + """MTP cuda-graphs forward_single_position, not forward. + + Disable the MegatronModule.__call__ interceptor so the training forward + path is not routed through the cuda graph manager. + """ + return False + def forward_single_position( self, hidden_states: Tensor, @@ -1120,7 +1142,6 @@ def forward_single_position( rotary_pos_emb: Optional[Tensor] = None, rotary_pos_cos: Optional[Tensor] = None, rotary_pos_sin: Optional[Tensor] = None, - inference_params=None, packed_seq_params: Optional[PackedSeqParams] = None, sequence_len_offset: Optional[Tensor] = None, ) -> Tensor: @@ -1151,7 +1172,6 @@ def forward_single_position( rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, - inference_params=inference_params, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, ) diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh index 641019c9750..ed0a5a622c2 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh @@ -3,7 +3,6 @@ set -u # Libraries. -uv pip install simpy uv pip install tiktoken # Environment variables. diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index 9c6564f6989..6bd96d5dae5 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -14,7 +14,7 @@ from transformer_engine.pytorch.fp8 import check_fp8_support from megatron.core import parallel_state -from megatron.core.inference.config import InferenceConfig +from megatron.core.inference.config import InferenceConfig, MambaInferenceStateConfig from megatron.core.inference.contexts import DynamicInferenceContext, StaticInferenceContext from megatron.core.inference.contexts.dynamic_context import MaxSequenceLengthOverflowError from megatron.core.inference.inference_request import ( @@ -34,6 +34,8 @@ get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import Float16Module @@ -64,6 +66,7 @@ def setup_model( sequence_parallel: bool = False, expert_model_parallel_size: int = 1, num_moe_experts: int = None, + hybrid_layer_pattern: str = None, ): Utils.initialize_model_parallel( tensor_model_parallel_size=tensor_model_parallel_size, @@ -98,31 +101,51 @@ def setup_model( expert_model_parallel_size=expert_model_parallel_size, num_moe_experts=num_moe_experts, add_bias_linear=num_moe_experts is None, + **( + dict(is_hybrid_model=True, mamba_num_heads=2, mamba_head_dim=16, mamba_num_groups=2) + if hybrid_layer_pattern + else {} + ), ) if dtype == torch.bfloat16: transformer_config.bf16 = True - layer_spec = get_gpt_layer_local_spec() + mamba_inference_state_config = None + if hybrid_layer_pattern: + model = MambaModel( + config=transformer_config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + hybrid_layer_pattern=hybrid_layer_pattern, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ).cuda() + mamba_inference_state_config = MambaInferenceStateConfig.from_model(model) + else: + layer_spec = get_gpt_layer_local_spec() - mtp_block_spec = None - if mtp_num_layers > 0: - mtp_block_spec = get_gpt_mtp_block_spec( - config=transformer_config, spec=layer_spec, use_transformer_engine=False - ) + mtp_block_spec = None + if mtp_num_layers > 0: + mtp_block_spec = get_gpt_mtp_block_spec( + config=transformer_config, spec=layer_spec, use_transformer_engine=False + ) - gpt_model = GPTModel( - config=transformer_config, - transformer_layer_spec=layer_spec, - vocab_size=self.vocab_size, - max_sequence_length=self.sequence_length, - parallel_output=True, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), - mtp_block_spec=mtp_block_spec, - ).cuda() - gpt_model.eval() + model = GPTModel( + config=transformer_config, + transformer_layer_spec=layer_spec, + vocab_size=self.vocab_size, + max_sequence_length=self.sequence_length, + parallel_output=True, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + mtp_block_spec=mtp_block_spec, + ).cuda() + + model.eval() if dtype == torch.bfloat16: - gpt_model = Float16Module(gpt_model.config, gpt_model) + model = Float16Module(model.config, model) if static: inference_context = StaticInferenceContext( @@ -142,10 +165,11 @@ def setup_model( block_size_tokens=block_size_tokens, enable_prefix_caching=enable_prefix_caching, max_requests=max_requests, + mamba_inference_state_config=mamba_inference_state_config, ), ) - inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_context) + inference_wrapped_model = GPTInferenceWrapper(model, inference_context) inference_wrapped_model.model_is_pipeline_parallel = not ( parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() @@ -1124,6 +1148,9 @@ def mock_sampling_func(logits, *args, **kwargs): # Override sampling to return our predictable mock outputs self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 1, 0.0)] + self.text_generation_controller._torch_sampling_bucket_index_tensors = [ + torch.tensor([0, 1], device='cuda', dtype=torch.long) + ] self.text_generation_controller._torch_sampling_func = mock.MagicMock( side_effect=mock_sampling_func ) @@ -1156,6 +1183,7 @@ def test_rewind_kv_cache(self, is_hybrid_model): num_speculative_tokens=3, block_size_tokens=4, max_requests=16, + hybrid_layer_pattern="***M" if is_hybrid_model else None, ) self.text_generation_controller.num_speculative_tokens = 3 ctx = self.text_generation_controller.inference_wrapped_model.inference_context @@ -1177,13 +1205,13 @@ def test_rewind_kv_cache(self, is_hybrid_model): ) if is_hybrid_model: - ctx.is_hybrid_model = True - ctx.mamba_metadata = mock.MagicMock() - ctx.mamba_metadata.request_to_mamba_state_idx = torch.tensor([0, 1], device='cuda') - ctx.mamba_ssm_states = torch.zeros((1, 2, 16), device='cuda') - ctx.mamba_intermediate_ssm_states = torch.ones((1, 2, 4, 16), device='cuda') * 99 - ctx.mamba_conv_states = torch.zeros((1, 2, 8), device='cuda') - ctx.mamba_intermediate_conv_states = torch.ones((1, 2, 4, 8), device='cuda') * 77 + ctx.mamba_metadata.request_to_mamba_state_idx[:2] = torch.tensor( + [0, 1], dtype=torch.int32, device='cuda' + ) + ctx.mamba_ssm_states.zero_() + ctx.mamba_intermediate_ssm_states.fill_(99) + ctx.mamba_conv_states.zero_() + ctx.mamba_intermediate_conv_states.fill_(77) # Mock accepted token counts: Req 0 accepts 1 (rejects 2), Req 1 accepts 0 (rejects 3) self.text_generation_controller._init_mtp_sampling_tensor() @@ -1191,7 +1219,8 @@ def test_rewind_kv_cache(self, is_hybrid_model): [1, 0], device='cuda' ) - self.text_generation_controller._rewind_kv_cache() + blocks_to_release, remove_mask = self.text_generation_controller._rewind_kv_cache() + ctx.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) # Assert offsets updated assert torch.equal( @@ -1251,6 +1280,9 @@ def test_speculative_multinomial_sampling(self): # Set up a bucket that forces multinomial sampling (top_p = 0.9, top_k = 0) # _torch_sampling_buckets format: (indices, temp, top_k, top_p) self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 0, 0.9)] + self.text_generation_controller._torch_sampling_bucket_index_tensors = [ + torch.tensor([0, 1], device='cuda', dtype=torch.long) + ] # Since we are actually testing the internal math of `_torch_sampling_func` handling the shapes, # we DO NOT mock `_torch_sampling_func` here. We want it to run natively to prove it doesn't crash. @@ -1315,7 +1347,8 @@ def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): [1, 0], device='cuda' ) - self.text_generation_controller._rewind_kv_cache() + blocks_to_release, remove_mask = self.text_generation_controller._rewind_kv_cache() + ctx.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) # Req 1 should have released block 20 (ref count decremented). assert ctx.kv_block_allocator.block_ref_counts[20].item() == 1 @@ -1355,7 +1388,8 @@ def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): [0], device='cuda' ) - self.text_generation_controller._rewind_kv_cache() + blocks_to_release, remove_mask = self.text_generation_controller._rewind_kv_cache() + ctx.kv_block_allocator.release_memory_blocks(blocks_to_release[remove_mask]) # Only block 40 should be released, not blocks 10, 20, or 30. assert ctx.request_kv_block_counts[0].item() == 3 @@ -1488,6 +1522,9 @@ def test_mtp_sp_padding_real_ranks(self, active_request_count): # Greedy sampling: top_k=1 selects the argmax token deterministically. ctrl._torch_sampling_buckets = [(list(range(active_request_count)), 1.0, 1, 0.0)] + ctrl._torch_sampling_bucket_index_tensors = [ + torch.arange(active_request_count, device='cuda', dtype=torch.long) + ] # Run the MTP forward pass ctrl._compute_serial_mtp_and_sample() diff --git a/tests/unit_tests/transformer/experimental_attention_variant/test_attention_variant_dsa.py b/tests/unit_tests/transformer/experimental_attention_variant/test_attention_variant_dsa.py index 96253a4ca10..192ff4ef594 100644 --- a/tests/unit_tests/transformer/experimental_attention_variant/test_attention_variant_dsa.py +++ b/tests/unit_tests/transformer/experimental_attention_variant/test_attention_variant_dsa.py @@ -67,6 +67,7 @@ def setup_method(self): yield Utils.destroy_model_parallel() + @pytest.mark.flaky_in_dev def test_rotate_activation_shape(self): """Test that rotate_activation preserves shape.""" batch_size = 2 @@ -79,6 +80,7 @@ def test_rotate_activation_shape(self): assert output.shape == x.shape assert output.dtype == torch.bfloat16 + @pytest.mark.flaky_in_dev def test_rotate_activation_dtype_check(self): """Test that rotate_activation only accepts bfloat16.""" x = torch.randn(16, 2, 128, dtype=torch.float32).cuda() From 06484e340df4da6fabbe060a75d3152b90921230 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Wed, 15 Apr 2026 09:37:20 -0700 Subject: [PATCH 2/3] Add triton kernels Signed-off-by: Keshav Santhanam --- .../triton_kernels.py | 437 ++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 megatron/core/inference/text_generation_controllers/triton_kernels.py diff --git a/megatron/core/inference/text_generation_controllers/triton_kernels.py b/megatron/core/inference/text_generation_controllers/triton_kernels.py new file mode 100644 index 00000000000..a27995f601b --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/triton_kernels.py @@ -0,0 +1,437 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import math + +import torch + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + from unittest.mock import MagicMock + + from megatron.core.utils import null_decorator + + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + HAVE_TRITON = False + + +# --------------------------------------------------------------------------- +# Kernel 1: KV-cache rewind for speculative decoding +# --------------------------------------------------------------------------- +@triton.jit +def _rewind_kv_cache_kernel( + # Per-request input (read-only) + ACCEPTED_COUNTS_PTR, + PREFILL_STATUS_PTR, + # Per-request state (read-write, updated in-place) + LAST_KV_BLOCK_OFFSET_PTR, + KV_LENGTH_OFFSETS_PTR, + KV_BLOCK_COUNTS_PTR, + LAST_KV_BLOCK_ID_PTR, + # 2-D table [N, max_blocks] (read-write) + KV_BLOCK_IDS_PTR, + # Per-request outputs + BLOCKS_TO_RELEASE_PTR, + REMOVE_MASK_PTR, + # Strides / limits + kv_block_ids_stride, + max_blocks_minus_1, + # Compile-time constants + NUM_SPEC_TOKENS: tl.constexpr, + BLOCK_SIZE_TOKENS: tl.constexpr, +): + """Rewind KV-cache bookkeeping for one request after speculative verification. + + Grid: (active_request_count,) + Each program handles exactly one request. + """ + pid = tl.program_id(0) + + # --- Load per-request scalars --- + accepted = tl.load(ACCEPTED_COUNTS_PTR + pid) + prefill = tl.load(PREFILL_STATUS_PTR + pid) + last_offset = tl.load(LAST_KV_BLOCK_OFFSET_PTR + pid) + kv_length = tl.load(KV_LENGTH_OFFSETS_PTR + pid) + block_count = tl.load(KV_BLOCK_COUNTS_PTR + pid) + last_block_id = tl.load(LAST_KV_BLOCK_ID_PTR + pid) + + # --- Compute rewind (zero for prefill requests) --- + num_to_rewind = tl.where(prefill == 1, 0, NUM_SPEC_TOKENS - accepted) + diff = last_offset - num_to_rewind + remove = diff < 0 + + # Python-style modulo: ((diff % M) + M) % M to handle negative diff + new_offset = ((diff % BLOCK_SIZE_TOKENS) + BLOCK_SIZE_TOKENS) % BLOCK_SIZE_TOKENS + tl.store(LAST_KV_BLOCK_OFFSET_PTR + pid, new_offset) + tl.store(KV_LENGTH_OFFSETS_PTR + pid, kv_length - num_to_rewind) + + # Save current last block id (will be released by caller if remove is True) + tl.store(BLOCKS_TO_RELEASE_PTR + pid, last_block_id) + + # Decrement block count when a block boundary was crossed + new_block_count = tl.where(remove, block_count - 1, block_count) + tl.store(KV_BLOCK_COUNTS_PTR + pid, new_block_count) + + # Gather previous block id from the 2-D table + kv_row_base = pid.to(tl.int64) * kv_block_ids_stride + prev_idx = tl.maximum(new_block_count - 1, 0) + prev_block_id = tl.load(KV_BLOCK_IDS_PTR + kv_row_base + prev_idx) + + # Conditionally update last block id + tl.store(LAST_KV_BLOCK_ID_PTR + pid, tl.where(remove, prev_block_id, last_block_id)) + + # Clear released block entry via scatter + scatter_idx = tl.minimum(new_block_count, max_blocks_minus_1) + current_val = tl.load(KV_BLOCK_IDS_PTR + kv_row_base + scatter_idx) + tl.store(KV_BLOCK_IDS_PTR + kv_row_base + scatter_idx, tl.where(remove, -1, current_val)) + + # Output remove mask for the caller (to release blocks outside this kernel) + tl.store(REMOVE_MASK_PTR + pid, remove) + + +def rewind_kv_cache( + accepted_counts, + prefill_status, + last_kv_block_offset, + kv_length_offsets, + kv_block_counts, + last_kv_block_id, + kv_block_ids, + num_speculative_tokens, + block_size_tokens, +): + """Launch the KV-cache rewind Triton kernel. + + Returns: + (blocks_to_release, remove_mask) — same semantics as the original + torch.compile'd ``_rewind_kv_cache`` (KV-cache portion only; Mamba + state updates are handled separately by the caller). + """ + N = accepted_counts.shape[0] + if N == 0: + return ( + torch.empty(0, device=accepted_counts.device, dtype=last_kv_block_id.dtype), + torch.empty(0, device=accepted_counts.device, dtype=torch.bool), + ) + + blocks_to_release = torch.empty_like(last_kv_block_id) + remove_mask = torch.empty(N, device=accepted_counts.device, dtype=torch.bool) + + _rewind_kv_cache_kernel[(N,)]( + accepted_counts, + prefill_status, + last_kv_block_offset, + kv_length_offsets, + kv_block_counts, + last_kv_block_id, + kv_block_ids, + blocks_to_release, + remove_mask, + kv_block_ids_stride=kv_block_ids.stride(0), + max_blocks_minus_1=kv_block_ids.shape[1] - 1, + NUM_SPEC_TOKENS=num_speculative_tokens, + BLOCK_SIZE_TOKENS=block_size_tokens, + ) + return blocks_to_release, remove_mask + + +# --------------------------------------------------------------------------- +# Kernel 2: Verify speculative tokens +# --------------------------------------------------------------------------- +@triton.jit +def _verify_speculative_tokens_kernel( + INPUT_TOKENS_PTR, + OUTPUT_TOKENS_PTR, + # Outputs + ACCEPTED_MASK_PTR, + LAST_ONE_INDICES_PTR, + # Runtime scalars + num_decode_requests, + decode_len, + # Compile-time constants + STRIDE: tl.constexpr, # num_speculative_tokens + 1 + BLOCK_SIZE: tl.constexpr, # next_power_of_2(STRIDE) +): + """Verify speculative tokens for one request. + + Grid: (active_request_count,) + Programs 0..num_decode_requests-1 handle decode requests. + Programs num_decode_requests..end handle prefill requests. + """ + pid = tl.program_id(0) + + if pid < num_decode_requests: + base = pid * STRIDE + offsets = tl.arange(0, BLOCK_SIZE) + valid = offsets < STRIDE + + input_toks = tl.load(INPUT_TOKENS_PTR + base + offsets, mask=valid, other=0) + + # Build shifted output: shifted[i] = output[i-1]. + # Position 0 uses a dummy load (always accepted regardless). + safe_shifted = tl.where(offsets > 0, offsets - 1, 0) + shifted_output = tl.load(OUTPUT_TOKENS_PTR + base + safe_shifted, mask=valid, other=0) + + # First token is always accepted; rest must match shifted output. + match = tl.where(offsets == 0, 1, (input_toks == shifted_output).to(tl.int32)) + match = tl.where(valid, match, 0) + + # Consecutive acceptance via cumulative-sum trick: + # accepted[i] iff cumsum(match)[i] == i + 1 + cumsum = tl.cumsum(match, axis=0) + accepted = (cumsum == (offsets + 1)) & valid + + tl.store(ACCEPTED_MASK_PTR + base + offsets, accepted, mask=valid) + + accepted_count = tl.sum(accepted.to(tl.int32)) + tl.store(LAST_ONE_INDICES_PTR + pid, (base + accepted_count - 1).to(tl.int64)) + else: + # Prefill request — single token, always accepted + prefill_idx = decode_len + (pid - num_decode_requests) + tl.store(ACCEPTED_MASK_PTR + prefill_idx, 1) + tl.store(LAST_ONE_INDICES_PTR + pid, prefill_idx.to(tl.int64)) + + +def verify_speculative_tokens( + input_tokens, output_tokens, num_decode_requests, num_prefill_requests, num_speculative_tokens +): + """Launch the speculative-token verification Triton kernel. + + Returns: + (last_one_indices, accepted_tokens_mask, input_tokens) + matching the original ``_verify_speculative_tokens`` signature. + """ + if input_tokens.ndim == 2: + input_tokens = input_tokens.squeeze(0) + + device = input_tokens.device + active_request_count = num_decode_requests + num_prefill_requests + stride = num_speculative_tokens + 1 + decode_len = num_decode_requests * stride + + accepted_tokens_mask = torch.zeros_like(input_tokens, dtype=torch.bool) + last_one_indices = torch.full((active_request_count,), -1, device=device, dtype=torch.long) + + if active_request_count > 0: + block_size = triton.next_power_of_2(stride) + _verify_speculative_tokens_kernel[(active_request_count,)]( + input_tokens, + output_tokens, + accepted_tokens_mask, + last_one_indices, + num_decode_requests=num_decode_requests, + decode_len=decode_len, + STRIDE=stride, + BLOCK_SIZE=block_size, + ) + + return last_one_indices, accepted_tokens_mask, input_tokens + + +# --------------------------------------------------------------------------- +# Kernel 3: Prepare speculative tokens for next forward pass +# --------------------------------------------------------------------------- +@triton.jit +def _prepare_next_forward_pass_kernel( + OUTPUT_TOKENS_PTR, + REQUIRED_LOGIT_INDICES_PTR, + LAST_ONE_INDICES_PTR, + INPUT_TOKENS_PTR, + ACCEPTED_MASK_PTR, + # Outputs + SAMPLED_TOKENS_OUT_PTR, + LAST_ACCEPTED_SEQ_OUT_PTR, + ACCEPTED_TOKENS_OUT_PTR, + ACCEPTED_COUNTS_OUT_PTR, + # Strides + accepted_tokens_out_stride, + # Runtime scalars + num_decode_requests, + # Compile-time constants + STRIDE: tl.constexpr, # num_speculative_tokens + 1 + NUM_SPEC_TOKENS: tl.constexpr, + SPEC_BLOCK_SIZE: tl.constexpr, # next_power_of_2(NUM_SPEC_TOKENS) +): + """Gather final tokens and extract accepted speculative tokens per request. + + Grid: (active_request_count,) + """ + pid = tl.program_id(0) + + # --- Gather final sampled token and sequence index for every request --- + idx = tl.load(LAST_ONE_INDICES_PTR + pid) + tl.store(SAMPLED_TOKENS_OUT_PTR + pid, tl.load(OUTPUT_TOKENS_PTR + idx)) + tl.store(LAST_ACCEPTED_SEQ_OUT_PTR + pid, tl.load(REQUIRED_LOGIT_INDICES_PTR + idx)) + + # --- For decode requests: extract accepted tokens and count --- + if pid < num_decode_requests: + base = pid * STRIDE + spec_offsets = tl.arange(0, SPEC_BLOCK_SIZE) + spec_valid = spec_offsets < NUM_SPEC_TOKENS + token_positions = base + 1 + spec_offsets # skip first (base) token + + tokens = tl.load(INPUT_TOKENS_PTR + token_positions, mask=spec_valid, other=0) + mask_val = tl.load(ACCEPTED_MASK_PTR + token_positions, mask=spec_valid, other=0) + accepted = mask_val != 0 + + result = tl.where(accepted & spec_valid, tokens, -1) + + out_base = pid.to(tl.int64) * accepted_tokens_out_stride + tl.store(ACCEPTED_TOKENS_OUT_PTR + out_base + spec_offsets, result, mask=spec_valid) + + count = tl.sum((accepted & spec_valid).to(tl.int64)) + tl.store(ACCEPTED_COUNTS_OUT_PTR + pid, count) + + +def prepare_next_forward_pass( + num_decode_requests, + output_tokens, + required_logit_indices, + last_one_indices, + accepted_tokens_mask, + input_tokens, + sampled_tokens_buf, + last_accepted_seq_buf, + accepted_tokens_per_request, + accepted_token_counts, + num_speculative_tokens, +): + """Launch the prepare-next-forward-pass Triton kernel. + + Writes results into the pre-allocated buffers provided by the caller. + """ + active_request_count = last_one_indices.shape[0] + if active_request_count == 0: + return + + stride = num_speculative_tokens + 1 + spec_block_size = triton.next_power_of_2(num_speculative_tokens) + + _prepare_next_forward_pass_kernel[(active_request_count,)]( + output_tokens, + required_logit_indices, + last_one_indices, + input_tokens, + accepted_tokens_mask, + sampled_tokens_buf, + last_accepted_seq_buf, + accepted_tokens_per_request, + accepted_token_counts, + accepted_tokens_out_stride=accepted_tokens_per_request.stride(0), + num_decode_requests=num_decode_requests, + STRIDE=stride, + NUM_SPEC_TOKENS=num_speculative_tokens, + SPEC_BLOCK_SIZE=spec_block_size, + ) + + +# --------------------------------------------------------------------------- +# Kernel 4: Mamba state selective copy (eliminates temporary allocations) +# --------------------------------------------------------------------------- +@triton.jit +def _mamba_state_selective_copy_kernel( + # Source: intermediate states [L, M, S+1, *state_shape] + SRC_PTR, + # Destination: current states [L, M, *state_shape] + DST_PTR, + # Per-request index arrays + PREFILL_STATUS_PTR, # [N] 0=decode, 1=prefill + STATE_IDX_PTR, # [N] maps request → mamba state slot + ACCEPTED_PTR, # [N] accepted token index per request + # Strides (in elements) + src_stride_layer, + src_stride_slot, + src_stride_spec, + dst_stride_layer, + dst_stride_slot, + # Data size + STATE_SIZE, + # Compile-time + BLOCK_SIZE: tl.constexpr, +): + """Copy intermediate Mamba state to current state for decode requests. + + Grid: (N, L, num_chunks) + - dim 0: active request index + - dim 1: mamba layer index + - dim 2: chunk of the flattened state vector + + No-op for prefill requests. + """ + pid_req = tl.program_id(0) + pid_layer = tl.program_id(1) + pid_chunk = tl.program_id(2) + + # Skip prefill requests immediately. + prefill = tl.load(PREFILL_STATUS_PTR + pid_req) + if prefill == 1: + return + + state_idx = tl.load(STATE_IDX_PTR + pid_req).to(tl.int64) + accepted = tl.load(ACCEPTED_PTR + pid_req).to(tl.int64) + + chunk_start = pid_chunk * BLOCK_SIZE + offsets = tl.arange(0, BLOCK_SIZE) + elem_offsets = chunk_start + offsets + mask = elem_offsets < STATE_SIZE + + src_base = ( + pid_layer.to(tl.int64) * src_stride_layer + + state_idx * src_stride_slot + + accepted * src_stride_spec + ) + dst_base = pid_layer.to(tl.int64) * dst_stride_layer + state_idx * dst_stride_slot + + data = tl.load(SRC_PTR + src_base + elem_offsets, mask=mask) + tl.store(DST_PTR + dst_base + elem_offsets, data, mask=mask) + + +def mamba_state_selective_copy( + intermediate_states, current_states, prefill_status, state_idx, accepted_counts, num_layers +): + """Copy accepted intermediate Mamba states to current states in-place. + + For each decode request, copies + ``intermediate[layer, slot, accepted_count, ...]`` → + ``current[layer, slot, ...]`` for every Mamba layer. + + Args: + intermediate_states: ``(L, M, S+1, *state_shape)`` — intermediate buffer. + current_states: ``(L, M, *state_shape)`` — current state buffer (updated in-place). + prefill_status: ``(N,)`` int tensor — 0 for decode, 1 for prefill. + state_idx: ``(N,)`` int tensor — mamba state slot index per request. + accepted_counts: ``(N,)`` int tensor — accepted token index per request. + num_layers: number of Mamba layers (first dim of the state tensors). + """ + N = prefill_status.shape[0] + if N == 0: + return + + # The state vector to copy per (layer, request) is the product of all + # trailing dimensions after the speculative-token axis. + # intermediate shape: (L, M, S+1, *state_shape) → state_size = prod(state_shape) + state_size = math.prod(intermediate_states.shape[3:]) + + BLOCK_SIZE = 1024 + num_chunks = triton.cdiv(state_size, BLOCK_SIZE) + grid = (N, num_layers, num_chunks) + + _mamba_state_selective_copy_kernel[grid]( + intermediate_states, + current_states, + prefill_status, + state_idx, + accepted_counts, + src_stride_layer=intermediate_states.stride(0), + src_stride_slot=intermediate_states.stride(1), + src_stride_spec=intermediate_states.stride(2), + dst_stride_layer=current_states.stride(0), + dst_stride_slot=current_states.stride(1), + STATE_SIZE=state_size, + BLOCK_SIZE=BLOCK_SIZE, + ) From f8424ad3219edd63dac2de527d1da38d4863b76f Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Wed, 15 Apr 2026 13:29:56 -0700 Subject: [PATCH 3/3] Load MTP and latent MoE layers in bf16 Signed-off-by: Keshav Santhanam --- megatron/core/transformer/moe/moe_layer.py | 53 +++++++++++-------- .../transformer/multi_token_prediction.py | 12 ++++- .../core/transformer/transformer_config.py | 2 +- megatron/training/arguments.py | 12 ++++- 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 35b567679fe..147e97ff58d 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from contextlib import nullcontext from dataclasses import dataclass from typing import Optional, Protocol @@ -249,28 +250,36 @@ def __init__( linear_cls = InferenceLinear else: linear_cls = TELinear - self.fc1_latent_proj = linear_cls( - self.config.hidden_size, - self.config.moe_latent_size, - parallel_mode="duplicated", - config=self.config, - init_method=self.config.init_method, - bias=self.config.add_bias_linear, - skip_bias_add=False, - skip_weight_param_allocation=False, - is_expert=False, - ) - self.fc2_latent_proj = linear_cls( - self.config.moe_latent_size, - self.config.hidden_size, - parallel_mode="duplicated", - config=self.config, - init_method=self.config.output_layer_init_method, - bias=self.config.add_bias_linear, - skip_bias_add=False, - skip_weight_param_allocation=False, - is_expert=False, - ) + # Latent projections remain in bf16 for inference; disable fp8_model_init + if not torch.is_grad_enabled() and self.config.fp8_param: + import transformer_engine.pytorch + + disable_fp8_ctx = transformer_engine.pytorch.fp8_model_init(enabled=False) + else: + disable_fp8_ctx = nullcontext() + with disable_fp8_ctx: + self.fc1_latent_proj = linear_cls( + self.config.hidden_size, + self.config.moe_latent_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + skip_weight_param_allocation=False, + is_expert=False, + ) + self.fc2_latent_proj = linear_cls( + self.config.moe_latent_size, + self.config.hidden_size, + parallel_mode="duplicated", + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + skip_bias_add=False, + skip_weight_param_allocation=False, + is_expert=False, + ) # Initialize token dispatcher if config.moe_token_dispatcher_type == "allgather": diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 9b59ba86b6d..371fbd2334d 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -1468,7 +1468,11 @@ def _build_layers(self, pg_collection): def build_layer_legacy(layer_spec, layer_number): """Build layer using legacy spec-based approach.""" - fp8_init_context = get_fp8_context(self.config, is_init=True) + # MTP layers remain in bf16 for inference; skip fp8_model_init + if not torch.is_grad_enabled(): + fp8_init_context = nullcontext() + else: + fp8_init_context = get_fp8_context(self.config, is_init=True) with fp8_init_context: module = build_module( layer_spec, @@ -1482,7 +1486,11 @@ def build_layer_legacy(layer_spec, layer_number): def build_layer_with_pattern(layer_spec, layer_number, mtp_layer_pattern, mamba_submodules): """Build layer using pattern-based approach (new Mamba path).""" - fp8_init_context = get_fp8_context(self.config, is_init=True) + # MTP layers remain in bf16 for inference; skip fp8_model_init + if not torch.is_grad_enabled(): + fp8_init_context = nullcontext() + else: + fp8_init_context = get_fp8_context(self.config, is_init=True) with fp8_init_context: module = build_module( layer_spec, diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index af69d3fe662..9df95abfe7f 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1204,7 +1204,7 @@ def __post_init__(self): raise ValueError( "fp8_param must be enabled when using " "--transformer-impl='inference_optimized' with --fp8-recipe='mxfp8'. " - "Please set --fp8-param-gather." + "Please set --fp8-param." ) assert self.inference_grouped_gemm_backend in ('auto', 'torch', 'te'), ( diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index e4755971edf..b1683161318 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -925,6 +925,10 @@ def validate_args(args, defaults={}): assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \ '--fp8-param-gather only supported with distributed optimizer, torch fsdp2, megatron fsdp, or inference mode' + if getattr(args, 'fp8_param', False) and not args.fp8_param_gather: + assert not torch.is_grad_enabled(), \ + '--fp8-param (without --fp8-param-gather) is only supported in inference mode' + # FP4 and FP8 are mutually exclusive if args.fp4 and args.fp8: raise ValueError("--fp4-format and --fp8-format cannot be used simultaneously. Please choose one.") @@ -1665,7 +1669,8 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['rotary_interleaved'] = args.rotary_interleaved kw_args['num_layers_in_first_pipeline_stage']= args.decoder_first_pipeline_num_layers kw_args['num_layers_in_last_pipeline_stage']= args.decoder_last_pipeline_num_layers - kw_args['fp8_param'] = args.fp8_param_gather + kw_args['fp8_param'] = getattr(args, 'fp8_param', False) or args.fp8_param_gather + if args.swiglu: kw_args['activation_func'] = F.silu kw_args['gated_linear_unit'] = True @@ -1736,6 +1741,11 @@ def _add_transformer_engine_args(parser): help='Keep the compute param in fp8 (do not use any other intermediate ' 'dtype) and perform the param all-gather in fp8.') + group.add_argument('--fp8-param', action='store_true', + help='Initialize model parameters in fp8 format. ' + 'Use for inference with --fp8-recipe=mxfp8. ' + 'For training, use --fp8-param-gather instead.') + # FP4 related arguments group.add_argument('--te-precision-config-file', default=None, help='Configuration file to select per-module precision overrides. '