diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 8dd614fdaaa..5421dd92d7b 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -11,6 +11,9 @@ from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region @@ -201,6 +204,8 @@ def __init__( quant_config = get_quant_config_or_none(name, self.config.quant_recipe) module.finish_init(quant_config) + self.disable_param_offloading = True + def set_input_tensor(self, input_tensor: Tensor) -> None: """Sets input tensor to the model. @@ -217,6 +222,24 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' self.decoder.set_input_tensor(input_tensor[0]) + def preprocess_for_fine_grained_offloading(self): + """Preprocess for fine-grained activation offloading.""" + off_interface.init_chunk_handler( + vp_size=self.config.virtual_pipeline_model_parallel_size, + vp_stage=self.vp_stage, + min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, + ) + if self.disable_param_offloading: + for param in self.decoder.parameters(): + off_interface.mark_not_offloadable(param) + if self.pre_process: + for param in self.embedding.parameters(): + off_interface.mark_not_offloadable(param) + if self.post_process: + for param in self.output_layer.parameters(): + off_interface.mark_not_offloadable(param) + self.disable_param_offloading = False + def forward( self, input_ids: Tensor, @@ -241,6 +264,9 @@ def forward( # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + if self.config.fine_grained_activation_offloading: + self.preprocess_for_fine_grained_offloading() + inference_context = deprecate_inference_params(inference_context, inference_params) in_inference_mode = inference_context is not None and not self.training diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 08e46a039e2..4995de973db 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1,4 +1,5 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + from collections import deque from contextlib import nullcontext @@ -10,7 +11,10 @@ DEBUG = False DEBUG_RANK = 0 -from megatron.core.transformer.cuda_graphs import is_graph_capturing +from megatron.core.transformer.cuda_graphs import ( + is_graph_capturing, + set_external_join_stream_for_graph_capture, +) def debug_rank(message): @@ -341,6 +345,11 @@ def __init__(self, name): self.offload = True self.total_offload_bytes = 0 self.total_tensor_count = 0 + # Events should be created with `external=True` in case of graph capture, since the record + # and synchronization of the event may occur in different graphs. Create the event lazily + # for back compatibility with older pytorch versions. + self._offload_event_cudagraph = None + self._reload_event_cudagraph = None # Using memory pool is for the compatibility with cuda graph. # Shapes of tensors for expert_fc1 and moe_act are not known in advance, # so we do not use CPU pool for them. @@ -359,19 +368,35 @@ def pop_tensor(self, tag): def record_offload_event(self, stream): """Record the offload event.""" - self._offload_event.record(stream) + if is_graph_capturing(): + if self._offload_event_cudagraph is None: + self._offload_event_cudagraph = torch.cuda.Event(external=True) + self._offload_event_cudagraph.record(stream) + else: + self._offload_event.record(stream) def wait_offload_event(self, stream): """Wait for the offload event.""" - stream.wait_event(self._offload_event) + if is_graph_capturing(): + stream.wait_event(self._offload_event_cudagraph) + else: + stream.wait_event(self._offload_event) def record_reload_event(self, stream): """Record the reload event.""" - self._reload_event.record(stream) + if is_graph_capturing(): + if self._reload_event_cudagraph is None: + self._reload_event_cudagraph = torch.cuda.Event(external=True) + self._reload_event_cudagraph.record(stream) + else: + self._reload_event.record(stream) def wait_reload_event(self, stream): """Wait for the reload event.""" - stream.wait_event(self._reload_event) + if is_graph_capturing(): + stream.wait_event(self._reload_event_cudagraph) + else: + stream.wait_event(self._reload_event) def update_offload_info(self, tensor): """Update the offload information.""" @@ -867,6 +892,9 @@ def tensor_need_offloading_checker(self, tensor): # Respect tensor's offload preference if specified if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: return False + if hasattr(tensor, "_TE_do_not_offload") and tensor._TE_do_not_offload: + return False + return True def bulk_offload_group(self): @@ -903,8 +931,7 @@ def bulk_reload_group(self): torch.cuda.nvtx.range_push("activation reloading " + group_to_reload._name) with torch.cuda.stream(self.h2d_stream): # Wait for offload to complete before reloading - if not is_graph_capturing(): - group_to_reload.wait_offload_event(self.h2d_stream) + group_to_reload.wait_offload_event(self.h2d_stream) for tensor_tag, state in group_to_reload._tensors.items(): # Only reload if tensor was offloaded (stored as tuple) if isinstance(state, tuple): @@ -969,6 +996,11 @@ def on_group_commit_forward(self, forced_released_tensors): if not self.do_offload: return debug_rank("--on_group_commit_forward") + + if is_graph_capturing(): + # Mark that d2h_stream is used so it gets joined before capture ends + set_external_join_stream_for_graph_capture(self.d2h_stream) + # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload(forced_released_tensors) @@ -1005,7 +1037,7 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors - if not is_graph_capturing() and len(self._reloading_group) > 0: + if len(self._reloading_group) > 0: for reloading_group in self._reloading_group: if reloading_group._name == name: reloading_group.wait_reload_event(torch.cuda.current_stream()) @@ -1042,6 +1074,10 @@ def on_group_start_backward(self): if not self.do_offload: return debug_rank(f"--on_group_start_backward {self}") + + if is_graph_capturing(): + set_external_join_stream_for_graph_capture(self.h2d_stream) + # Wait for compute to finish before starting reload self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index cc71cdc32f6..dfc606516e2 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -25,6 +25,9 @@ tensor_merge, ) from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer import TransformerConfig @@ -397,6 +400,21 @@ def __init__( ) self.tp_group = pg_collection.tp + self.offload_in_proj = ( + self.config.fine_grained_activation_offloading + and "mamba_in_proj" in self.config.offload_modules + ) + + self.offload_out_proj = ( + self.config.fine_grained_activation_offloading + and "mamba_out_proj" in self.config.offload_modules + ) + + self.offload_ssm = ( + self.config.fine_grained_activation_offloading + and "mamba_ssm" in self.config.offload_modules + ) + def forward( self, hidden_states, @@ -429,7 +447,13 @@ def forward( out, out_bias = self._decode(hidden_states, conv_state, ssm_state) return out, out_bias - zxBCdt, _ = self.in_proj(hidden_states) + with off_interface(self.offload_in_proj, hidden_states, "mamba_in_proj") as hidden_states: + zxBCdt, _ = self.in_proj(hidden_states) + + if self.offload_in_proj: + zxBCdt = off_interface.group_commit( + zxBCdt, name="mamba_in_proj", forced_released_tensors=[] + ) zxBCdt = self.cp.pre_conv_ssm(zxBCdt, packed_seq_params) @@ -444,7 +468,11 @@ def forward( assert ssm_state is None y = self._ssm_training(zxBCdt, packed_seq_params) - out, out_bias = self.out_proj(y) + with off_interface(self.offload_out_proj, y, "mamba_out_proj") as y: + out, out_bias = self.out_proj(y) + + if self.offload_out_proj: + out = off_interface.group_commit(out, name="mamba_out_proj", forced_released_tensors=[]) return out, out_bias @@ -656,24 +684,29 @@ def _ssm_training( assert sequence_packing_available, reason_for_no_sequence_packing seq_idx = self._create_packed_seq_idx(packed_seq_params, zxBCdt.shape[1]) - y = mamba_split_conv1d_scan_combined( - zxBCdt, - rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), - self.cp.get_conv1d_bias(), - self.cp.get_dt_bias().float(), - A, - D=( - rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.cp.get_D() - ), - chunk_size=self.chunk_size, - activation=self.activation, - headdim=None if self.D_has_hdim else self.headdim, - ngroups=self.cp.ngroups_local_tpcp, - norm_before_gate=self.norm_before_gate, - seq_idx=seq_idx, - ) + with off_interface(self.offload_ssm, zxBCdt, "mamba_ssm") as zxBCdt: + + y = mamba_split_conv1d_scan_combined( + zxBCdt, + rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), + self.cp.get_conv1d_bias(), + self.cp.get_dt_bias().float(), + A, + D=( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ), + chunk_size=self.chunk_size, + activation=self.activation, + headdim=None if self.D_has_hdim else self.headdim, + ngroups=self.cp.ngroups_local_tpcp, + norm_before_gate=self.norm_before_gate, + seq_idx=seq_idx, + ) + + if self.offload_ssm: + y = off_interface.group_commit(y, name="mamba_ssm", forced_released_tensors=[]) y = rearrange(y, "b l d -> l b d").contiguous() y = self.cp.post_conv_ssm(y, packed_seq_params) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index dd0dad8eba5..0971e5532f6 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -66,6 +66,7 @@ _IS_GRAPH_CAPTURING = False _IS_GRAPH_WARMUP = False +_JOIN_STREAMS = [] logger = logging.getLogger(__name__) # Freeze GC during capture. @@ -98,6 +99,71 @@ def _set_capture_end(): _IS_GRAPH_CAPTURING = False +def set_external_join_stream_for_graph_capture(stream: torch.cuda.Stream): + """Add a stream that needs to be joined for CUDA graph capture.""" + _JOIN_STREAMS.append(stream) + + +def join_external_streams(): + """Join external streams back to current stream if it was used during capture.""" + global _JOIN_STREAMS + + for stream in _JOIN_STREAMS: + torch.cuda.current_stream().wait_stream(stream) + _JOIN_STREAMS = [] + + +class StreamTracker: + """Tracks tensor references and alternates CUDA streams across consecutive cudagraph replays. + + Multi-buffers both tensor references (intercepted via record_stream) and CUDA streams so + that consecutive cudagraph replays use different streams, allowing overlap with async comms. + This is needed as tensors attached to a stream with 'torch.Tensor.record_stream' + may be deallocated when the stream is joined. Because we may overlap the subsequent + cudagraph with the previous cudagraph's async comms (ie. offloads), the subsequent cudagraph + may invalidate the memory of attached tensors. So we intercept torch.Tensor.record_stream calls + and manually store references to such tensors until the graph is guarenteed to have finished. + + """ + + # Multi-buffered tensor references and CUDA streams, alternated via _idx + _num_buffers = 4 + _buffers = [[] for _ in range(_num_buffers)] + _streams = [None] * _num_buffers + _idx = 0 + + _original_record_stream = torch.Tensor.record_stream + + def _patched_record_stream(*args, **kwargs): + StreamTracker._original_record_stream(*args, **kwargs) + StreamTracker._buffers[StreamTracker._idx].append(args[0]) + + def __init__(self): + StreamTracker._idx = (StreamTracker._idx + 1) % StreamTracker._num_buffers + StreamTracker._buffers[StreamTracker._idx] = [] + + @classmethod + def clear(cls): + """Clear all record_stream tensor references.""" + cls._buffers = [[] for _ in range(cls._num_buffers)] + + @classmethod + def get_next_stream(cls): + """Return the next alternating stream for cudagraph replay.""" + if cls._streams[0] is None: + cls._streams = [torch.cuda.Stream() for _ in range(cls._num_buffers)] + idx = cls._idx + cls._idx = (cls._idx + 1) % cls._num_buffers + return cls._streams[idx] + + def __enter__(self): + torch.Tensor.record_stream = StreamTracker._patched_record_stream + return self + + def __exit__(self, *args): + torch.Tensor.record_stream = StreamTracker._original_record_stream + + def is_graph_warmup(): """Query if currently warming up for graph capture.""" return _IS_GRAPH_WARMUP @@ -108,10 +174,23 @@ def _set_warmup_start(): global _IS_GRAPH_WARMUP _IS_GRAPH_WARMUP = True + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_disable_offload, + ) + + fine_grained_offloading_disable_offload() + def _set_warmup_end(): """Set graph warmup has ended.""" global _IS_GRAPH_WARMUP + _IS_GRAPH_WARMUP = False + + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_enable_offload, + ) + + fine_grained_offloading_enable_offload() @dataclass @@ -368,6 +447,14 @@ def create_cudagraphs(cls): [isinstance(m, TransformerEngineBaseModule) for m in base_module.modules()] ) + # Graph captures requires offloading to be from a blank state. + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + + off_interface.reset() + + last_fwd_runner = [r for r in cls.cudagraph_record if r[1] == "fwd"][-1] progress_bar = enumerate(cls.cudagraph_record) time_start = time.time() mem_stats_start = torch.cuda.memory_stats() @@ -422,6 +509,16 @@ def format_mem_bytes(mem_bytes): if graph_type == 'fwd': args, kwargs, out = g[2:] runner.create_fwd_graph(args, kwargs, out, clone_inputs=True) + + if runner is last_fwd_runner: + StreamTracker.clear() + if FREEZE_GC: + # gc.collect() drops references to unreachable tensors created during + # capture, returning their storage to the allocator to avoid a slowdown + # during replay. However, it forces expensive global garbage collection, + # so must be done only on the last layer per-device to avoid slowing + # down graph creation. + gc.collect() else: assert fwd_buffer_reuse_ref_count == 0 runner.create_bwd_graph() @@ -463,6 +560,9 @@ def format_mem_bytes(mem_bytes): cls.cudagraph_created = True cls.cudagraph_record = [] + # Reset offloading data structures, which may have been advanced during capture + off_interface.reset() + # Finished capturing. _set_capture_end() if has_te_modules: @@ -576,15 +676,15 @@ def forward(ctx, runner, is_first_microbatch, *inputs): # Copy new data into fwd graph input buffer need_copy_inputs = [] + for user_input, cudagraph_input in zip(inputs, runner.fwd_graph_input_surface): - if ( - hasattr(cudagraph_input, "can_skip_replay_copy") - and cudagraph_input.can_skip_replay_copy - ): - need_copy_inputs.append(user_input) - assert user_input.data_ptr() == cudagraph_input.data_ptr() - else: + if hasattr(cudagraph_input, "can_skip_replay_copy"): + if not cudagraph_input.can_skip_replay_copy: + cudagraph_input.copy_(user_input) + need_copy_inputs.append(user_input) + elif user_input.data_ptr() != cudagraph_input.data_ptr(): cudagraph_input.copy_(user_input) + need_copy_inputs.append(user_input) ctx.runner = runner ctx.save_for_backward(*need_copy_inputs) @@ -609,7 +709,15 @@ def forward(ctx, runner, is_first_microbatch, *inputs): FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch) runner.fp8_param_cache_updated = is_first_microbatch - runner.fwd_graph.replay() + if runner.use_stream: + stream = StreamTracker.get_next_stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + runner.fwd_graph.replay() + torch.cuda.current_stream().wait_event(runner.fwd_completion_event) + else: + runner.fwd_graph.replay() + return runner.fwd_graph_output_surface @staticmethod @@ -626,12 +734,11 @@ def backward(ctx, *grads): assert len(grads) == len( runner.static_grad_outputs ), "Bwd cudagraph received a different number of tensors than what it was graphed with!" - need_copy_inputs = list(ctx.saved_tensors) for cudagraph_input in runner.fwd_graph_input_surface: if ( hasattr(cudagraph_input, "can_skip_replay_copy") - and cudagraph_input.can_skip_replay_copy + and not cudagraph_input.can_skip_replay_copy ): cudagraph_input.copy_(need_copy_inputs.pop(0)) @@ -642,7 +749,15 @@ def backward(ctx, *grads): if user_output_grad.data_ptr() != cudagraph_output_grad.data_ptr(): cudagraph_output_grad.copy_(user_output_grad) - runner.bwd_graph.replay() + if runner.use_stream: + stream = StreamTracker.get_next_stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + runner.bwd_graph.replay() + torch.cuda.current_stream().wait_event(runner.bwd_completion_event) + else: + runner.bwd_graph.replay() + runner.status = _GraphStatus.FWD_READY # Update FP8 scale factors if needed @@ -687,6 +802,7 @@ def __init__( self.base_module = base_module self.mempool = mempool + self.use_stream = False self.fwd_graph_input_arg_metas = [ArgMetadata(a) for a in fwd_graph_input_args] self.fwd_graph_input_kwarg_metas = { @@ -731,6 +847,12 @@ def __init__( self.fp8_runtime_enabled = None self.fp4_runtime_enabled = None + if self.base_module.config.fine_grained_activation_offloading: + # Use alternating streams from StreamTracker for graph replays + self.use_stream = True + self.fwd_completion_event = torch.cuda.Event(external=True, interprocess=True) + self.bwd_completion_event = torch.cuda.Event(external=True, interprocess=True) + if self.fp8_enabled: self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) @@ -878,7 +1000,7 @@ def _resolve_input_buffer(ten): if clone_inputs: # if a buffer is used for multiple inputs, create it now - for ten in self.get_tensors(args, kwargs): + for ten in self.get_arg_metas(args, kwargs): if ( hasattr(ten, 'cg_buffer_metadata') and ten.cg_buffer_metadata.input_use_count > 1 @@ -945,24 +1067,30 @@ def clone_ten(ten): if FREEZE_GC: gc.freeze() - with torch.cuda.graph( - self.fwd_graph, pool=self.mempool, capture_error_mode="thread_local" + if self.use_stream: + record_stream_tracker = StreamTracker() + else: + record_stream_tracker = nullcontext() + + with ( + torch.cuda.graph( + self.fwd_graph, pool=self.mempool, capture_error_mode="thread_local" + ), + record_stream_tracker, ): fwd_graph_outputs = self.func( *self.fwd_graph_input_args, **self.fwd_graph_input_kwargs ) + # Record completion event inside the graph so the current stream can + # proceed as soon as module compute finishes, before async comms are joined. + if self.use_stream: + self.fwd_completion_event.record() + join_external_streams() # Unfreeze GC. if FREEZE_GC: gc.unfreeze() - # gc.collect() drops references to unreachable tensors created during capture, - # returning their storage to the allocator to avoid a slowdown during replay. - # However, it forces expensive global garbage collection, so must be done - # only on the last layer per-device to avoid slowing down graph creation. - if self.is_last_layer: - gc.collect() - # save cudagraph output buffer self.fwd_graph_outputs = fwd_graph_outputs self.fwd_graph_output_surface = self.get_tensors(fwd_graph_outputs) @@ -1071,6 +1199,10 @@ def create_bwd_graph(self): allow_unused=True, ) + if self.use_stream: + self.bwd_completion_event.record() + join_external_streams() + # Unfreeze GC. if FREEZE_GC: gc.unfreeze() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 5648657d466..b5ea5c03004 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1296,6 +1296,9 @@ def __post_init__(self): "attn_norm", "mlp_norm", "qkv_linear", + "mamba_in_proj", + "mamba_out_proj", + "mamba_ssm", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( diff --git a/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml b/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml index f965ee1d9ef..316a623324d 100644 --- a/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml +++ b/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml @@ -41,4 +41,5 @@ MODEL_ARGS: --bf16: true --ckpt-format: torch --attention-backend: unfused + --exit-interval: 50 TEST_TYPE: regular