From fc05ed8915112baa98e3ea431b1f055df732787c Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Mon, 2 Feb 2026 07:42:19 -0800 Subject: [PATCH 1/7] init commit Signed-off-by: Jimmy Zhang --- megatron/core/models/mamba/mamba_model.py | 26 ++++++ .../fine_grained_activation_offload.py | 71 ++++++++++++--- megatron/core/transformer/cuda_graphs.py | 86 ++++++++++++++++--- 3 files changed, 159 insertions(+), 24 deletions(-) diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 8d45e1d0147..405bae3688d 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -11,6 +11,9 @@ from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.quantization.utils import get_quant_config_or_none from megatron.core.tensor_parallel import gather_from_sequence_parallel_region @@ -157,6 +160,8 @@ def __init__( quant_config = get_quant_config_or_none(name, self.config.quant_recipe) module.finish_init(quant_config) + self.disable_param_offloading = True + def set_input_tensor(self, input_tensor: Tensor) -> None: """Sets input tensor to the model. @@ -173,6 +178,24 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' self.decoder.set_input_tensor(input_tensor[0]) + def preprocess_for_fine_grained_offloading(self): + """Preprocess for fine-grained activation offloading.""" + off_interface.init_chunk_handler( + vp_size=self.config.virtual_pipeline_model_parallel_size, + vp_stage=self.vp_stage, + min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, + ) + if self.disable_param_offloading: + for param in self.decoder.parameters(): + off_interface.mark_not_offloadable(param) + if self.pre_process: + for param in self.embedding.parameters(): + off_interface.mark_not_offloadable(param) + if self.post_process: + for param in self.output_layer.parameters(): + off_interface.mark_not_offloadable(param) + self.disable_param_offloading = False + def forward( self, input_ids: Tensor, @@ -196,6 +219,9 @@ def forward( # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + if self.config.fine_grained_activation_offloading: + self.preprocess_for_fine_grained_offloading() + inference_context = deprecate_inference_params(inference_context, inference_params) in_inference_mode = inference_context is not None and not self.training diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 08e46a039e2..7612e5736eb 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -10,7 +10,11 @@ DEBUG = False DEBUG_RANK = 0 -from megatron.core.transformer.cuda_graphs import is_graph_capturing +from megatron.core.transformer.cuda_graphs import ( + is_graph_capturing, + is_graph_warmup, + set_external_join_stream_for_graph_capture, +) def debug_rank(message): @@ -341,6 +345,11 @@ def __init__(self, name): self.offload = True self.total_offload_bytes = 0 self.total_tensor_count = 0 + # Events should be created with `external=True` in case of graph capture, since the record + # and synchronization of the event may occur in different graphs. Create the event lazily + # for back compatibility with older pytorch versions. + self._offload_event_cudagraph = None + self._reload_event_cudagraph = None # Using memory pool is for the compatibility with cuda graph. # Shapes of tensors for expert_fc1 and moe_act are not known in advance, # so we do not use CPU pool for them. @@ -359,19 +368,35 @@ def pop_tensor(self, tag): def record_offload_event(self, stream): """Record the offload event.""" - self._offload_event.record(stream) + if is_graph_capturing(): + if self._offload_event_cudagraph is None: + self._offload_event_cudagraph = torch.cuda.Event(external=True) + self._offload_event_cudagraph.record(stream) + else: + self._offload_event.record(stream) def wait_offload_event(self, stream): """Wait for the offload event.""" - stream.wait_event(self._offload_event) + if is_graph_capturing(): + stream.wait_event(self._offload_event_cudagraph) + else: + stream.wait_event(self._offload_event) def record_reload_event(self, stream): """Record the reload event.""" - self._reload_event.record(stream) + if is_graph_capturing(): + if self._reload_event_cudagraph is None: + self._reload_event_cudagraph = torch.cuda.Event(external=True) + self._reload_event_cudagraph.record(stream) + else: + self._reload_event.record(stream) def wait_reload_event(self, stream): """Wait for the reload event.""" - stream.wait_event(self._reload_event) + if is_graph_capturing(): + stream.wait_event(self._reload_event_cudagraph) + else: + stream.wait_event(self._reload_event) def update_offload_info(self, tensor): """Update the offload information.""" @@ -903,8 +928,7 @@ def bulk_reload_group(self): torch.cuda.nvtx.range_push("activation reloading " + group_to_reload._name) with torch.cuda.stream(self.h2d_stream): # Wait for offload to complete before reloading - if not is_graph_capturing(): - group_to_reload.wait_offload_event(self.h2d_stream) + group_to_reload.wait_offload_event(self.h2d_stream) for tensor_tag, state in group_to_reload._tensors.items(): # Only reload if tensor was offloaded (stored as tuple) if isinstance(state, tuple): @@ -969,8 +993,15 @@ def on_group_commit_forward(self, forced_released_tensors): if not self.do_offload: return debug_rank("--on_group_commit_forward") - # Wait for compute to finish before starting offload - self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + if is_graph_capturing(): + # Mark that d2h_stream is used so it gets joined before capture ends + set_external_join_stream_for_graph_capture(self.d2h_stream) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + self.d2h_stream.wait_event(event) + else: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload(forced_released_tensors) def bulk_reload(self): @@ -1005,7 +1036,7 @@ def on_group_commit_backward(self, name): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}" # Wait for reload to complete before using tensors - if not is_graph_capturing() and len(self._reloading_group) > 0: + if len(self._reloading_group) > 0: for reloading_group in self._reloading_group: if reloading_group._name == name: reloading_group.wait_reload_event(torch.cuda.current_stream()) @@ -1042,8 +1073,14 @@ def on_group_start_backward(self): if not self.do_offload: return debug_rank(f"--on_group_start_backward {self}") - # Wait for compute to finish before starting reload - self.h2d_stream.wait_stream(torch.cuda.current_stream()) + + if is_graph_capturing(): + set_external_join_stream_for_graph_capture(self.h2d_stream) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + self.h2d_stream.wait_event(event) + else: + self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() @@ -1215,6 +1252,9 @@ def __init__(self, offload: bool, tensor: torch.Tensor, name: str): def __enter__(self): """Enter context manager to enable activation offloading hooks.""" + if is_graph_warmup(): + return self.tensor + if self.offload: self.tensor = fine_grained_offloading_group_start(self.tensor, self.name) PipelineOffloadManager.get_instance().__enter__() @@ -1222,6 +1262,9 @@ def __enter__(self): def __exit__(self, *args: Any): """Exit context manager to disable activation offloading hooks.""" + if is_graph_warmup(): + return self.tensor + if self.offload: PipelineOffloadManager.get_instance().__exit__() @@ -1240,6 +1283,10 @@ def get_context(flag): @staticmethod def group_commit(tensor, name, forced_released_tensors=None, delay_offload=False): """Group commit the tensors.""" + + if is_graph_warmup(): + return tensor + return fine_grained_offloading_group_commit( tensor, name, forced_released_tensors, delay_offload ) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 3643c42c3ce..6c70af84bd5 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -66,6 +66,7 @@ _IS_GRAPH_CAPTURING = False _IS_GRAPH_WARMUP = False +_JOIN_STREAMS = [] logger = logging.getLogger(__name__) # Freeze GC during capture. @@ -98,6 +99,20 @@ def _set_capture_end(): _IS_GRAPH_CAPTURING = False +def set_external_join_stream_for_graph_capture(stream: torch.cuda.Stream): + """Add a stream that needs to be joined for CUDA graph capture.""" + _JOIN_STREAMS.append(stream) + + +def join_external_streams(): + """Join external streams back to current stream if it was used during capture.""" + global _JOIN_STREAMS + + for stream in _JOIN_STREAMS: + torch.cuda.current_stream().wait_stream(stream) + _JOIN_STREAMS = [] + + def is_graph_warmup(): """Query if currently warming up for graph capture.""" return _IS_GRAPH_WARMUP @@ -112,6 +127,7 @@ def _set_warmup_start(): def _set_warmup_end(): """Set graph warmup has ended.""" global _IS_GRAPH_WARMUP + _IS_GRAPH_WARMUP = False @dataclass @@ -576,15 +592,20 @@ def forward(ctx, runner, is_first_microbatch, *inputs): # Copy new data into fwd graph input buffer need_copy_inputs = [] + for user_input, cudagraph_input in zip(inputs, runner.fwd_graph_input_surface): - if ( - hasattr(cudagraph_input, "can_skip_replay_copy") - and cudagraph_input.can_skip_replay_copy - ): - need_copy_inputs.append(user_input) - assert user_input.data_ptr() == cudagraph_input.data_ptr() + if hasattr(cudagraph_input, "can_skip_replay_copy"): + if not cudagraph_input.can_skip_replay_copy: + cudagraph_input.copy_(user_input) + need_copy_inputs.append(user_input) else: - cudagraph_input.copy_(user_input) + assert user_input.data_ptr() == cudagraph_input.data_ptr(), ( + f"Static CUDA graph input tensor changed memory address between graph capture and replay. " + f"Expected data_ptr={cudagraph_input.data_ptr()}, got data_ptr={user_input.data_ptr()}. " + f"Tensor shape={user_input.shape}, dtype={user_input.dtype}. " + f"Static inputs must maintain the same memory address across replays. " + f"Consider marking this input as copyable during graph capture." + ) ctx.runner = runner ctx.save_for_backward(*need_copy_inputs) @@ -609,7 +630,14 @@ def forward(ctx, runner, is_first_microbatch, *inputs): FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch) runner.fp8_param_cache_updated = is_first_microbatch - runner.fwd_graph.replay() + if runner.stream is not None: + runner.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(runner.stream): + runner.fwd_graph.replay() + torch.cuda.current_stream().wait_event(runner.fwd_completion_event) + else: + runner.fwd_graph.replay() + return runner.fwd_graph_output_surface @staticmethod @@ -626,12 +654,11 @@ def backward(ctx, *grads): assert len(grads) == len( runner.static_grad_outputs ), "Bwd cudagraph received a different number of tensors than what it was graphed with!" - need_copy_inputs = list(ctx.saved_tensors) for cudagraph_input in runner.fwd_graph_input_surface: if ( hasattr(cudagraph_input, "can_skip_replay_copy") - and cudagraph_input.can_skip_replay_copy + and not cudagraph_input.can_skip_replay_copy ): cudagraph_input.copy_(need_copy_inputs.pop(0)) @@ -642,7 +669,14 @@ def backward(ctx, *grads): if user_output_grad.data_ptr() != cudagraph_output_grad.data_ptr(): cudagraph_output_grad.copy_(user_output_grad) - runner.bwd_graph.replay() + if runner.stream is not None: + runner.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(runner.stream): + runner.bwd_graph.replay() + torch.cuda.current_stream().wait_event(runner.bwd_completion_event) + else: + runner.bwd_graph.replay() + runner.status = _GraphStatus.FWD_READY # Update FP8 scale factors if needed @@ -678,6 +712,7 @@ def __init__( fwd_graph_input_kwargs: Dict[str, Any], func, need_backward, + stream: torch.cuda.Stream = None, ): """Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs, which are not created until this runner records its graph creation into @@ -687,6 +722,7 @@ def __init__( self.base_module = base_module self.mempool = mempool + self.stream = None self.fwd_graph_input_arg_metas = [ArgMetadata(a) for a in fwd_graph_input_args] self.fwd_graph_input_kwarg_metas = { @@ -731,6 +767,20 @@ def __init__( self.fp8_runtime_enabled = None self.fp4_runtime_enabled = None + if self.base_module.config.fine_grained_activation_offloading: + # Dedicated stream for this runner's graph replays + self.stream = stream + self.fwd_completion_event = ( + torch.cuda.Event(external=True, interprocess=True) + if stream is not None + else None + ) + self.bwd_completion_event = ( + torch.cuda.Event(external=True, interprocess=True) + if stream is not None + else None + ) + if self.fp8_enabled: self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) @@ -878,7 +928,7 @@ def _resolve_input_buffer(ten): if clone_inputs: # if a buffer is used for multiple inputs, create it now - for ten in self.get_tensors(args, kwargs): + for ten in self.get_arg_metas(args, kwargs): if ( hasattr(ten, 'cg_buffer_metadata') and ten.cg_buffer_metadata.input_use_count > 1 @@ -951,6 +1001,11 @@ def clone_ten(ten): fwd_graph_outputs = self.func( *self.fwd_graph_input_args, **self.fwd_graph_input_kwargs ) + # Record completion event inside the graph so the current stream can + # proceed as soon as module compute finishes, before async comms are joined. + if self.fwd_completion_event is not None: + self.fwd_completion_event.record() + join_external_streams() # Unfreeze GC. if FREEZE_GC: @@ -1071,6 +1126,10 @@ def create_bwd_graph(self): allow_unused=True, ) + if self.bwd_completion_event is not None: + self.bwd_completion_event.record() + join_external_streams() + # Unfreeze GC. if FREEZE_GC: gc.unfreeze() @@ -1422,6 +1481,7 @@ def wrapped_func(*args, **kwargs): self.cudagraph_runners: list[_CudaGraphRunner] = [] self.inference_cudagraphs_lookup_table: dict = defaultdict(lambda: None) self.is_first_microbatch = False + self.stream = torch.cuda.Stream() # Without pipeline parallelism, microbatches execute one at a time. # Therefore modules will always execute in the same order, so cudagraphs @@ -1498,6 +1558,7 @@ def is_valid(r): kwargs, self.func, self.need_backward, + stream=self.stream, ) self.cudagraph_runners.append(runner) if is_inference_mode: @@ -1522,6 +1583,7 @@ def is_valid(r): kwargs, self.func, self.need_backward, + stream=self.stream, ) self.cudagraph_runners.append(runner) From 036331aff59cd4d7ff866f5f8a9ea539ca42f7f7 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Mon, 2 Feb 2026 13:03:04 -0800 Subject: [PATCH 2/7] support mamba offloading Signed-off-by: Jimmy Zhang --- megatron/core/ssm/mamba_mixer.py | 31 +++++++++++++++++-- megatron/core/transformer/cuda_graphs.py | 13 ++++++-- .../core/transformer/transformer_config.py | 2 ++ 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index cc71cdc32f6..9ad89b2fac7 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -24,6 +24,9 @@ tensor_masked_update, tensor_merge, ) +from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, +) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import get_cuda_rng_tracker @@ -397,6 +400,17 @@ def __init__( ) self.tp_group = pg_collection.tp + self.offload_in_proj = ( + self.config.fine_grained_activation_offloading + and "mamba_in_proj" in self.config.offload_modules + ) + + self.offload_out_proj = ( + self.config.fine_grained_activation_offloading + and "mamba_out_proj" in self.config.offload_modules + ) + + def forward( self, hidden_states, @@ -429,7 +443,13 @@ def forward( out, out_bias = self._decode(hidden_states, conv_state, ssm_state) return out, out_bias - zxBCdt, _ = self.in_proj(hidden_states) + with off_interface(self.offload_in_proj, hidden_states, "mamba_in_proj") as hidden_states: + zxBCdt, _ = self.in_proj(hidden_states) + + if self.offload_in_proj: + zxBCdt = off_interface.group_commit( + zxBCdt, name="mamba_in_proj", forced_released_tensors=[] + ) zxBCdt = self.cp.pre_conv_ssm(zxBCdt, packed_seq_params) @@ -444,7 +464,14 @@ def forward( assert ssm_state is None y = self._ssm_training(zxBCdt, packed_seq_params) - out, out_bias = self.out_proj(y) + + with off_interface(self.offload_out_proj, y, "mamba_out_proj") as y: + out, out_bias = self.out_proj(y) + + if self.offload_out_proj: + out = off_interface.group_commit( + out, name="mamba_out_proj", forced_released_tensors=[] + ) return out, out_bias diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 6c70af84bd5..f0e4ea6050c 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -384,6 +384,12 @@ def create_cudagraphs(cls): [isinstance(m, TransformerEngineBaseModule) for m in base_module.modules()] ) + # Graph captures requires offloading to be from a blank state. + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + FineGrainedActivationOffloadingInterface as off_interface, + ) + off_interface.reset() + progress_bar = enumerate(cls.cudagraph_record) time_start = time.time() mem_stats_start = torch.cuda.memory_stats() @@ -479,6 +485,9 @@ def format_mem_bytes(mem_bytes): cls.cudagraph_created = True cls.cudagraph_record = [] + # Reset offloading data structures, which may have been advanced during capture + off_interface.reset() + # Finished capturing. _set_capture_end() if has_te_modules: @@ -1003,7 +1012,7 @@ def clone_ten(ten): ) # Record completion event inside the graph so the current stream can # proceed as soon as module compute finishes, before async comms are joined. - if self.fwd_completion_event is not None: + if self.stream is not None: self.fwd_completion_event.record() join_external_streams() @@ -1126,7 +1135,7 @@ def create_bwd_graph(self): allow_unused=True, ) - if self.bwd_completion_event is not None: + if self.stream is not None: self.bwd_completion_event.record() join_external_streams() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index eaae585905e..5462d331ff8 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1278,6 +1278,8 @@ def __post_init__(self): "attn_norm", "mlp_norm", "qkv_linear", + "mamba_in_proj", + "mamba_out_proj", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( From 4fd8b869cdaafe0a0b9014c65ad6cefbc6eed425 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Tue, 3 Feb 2026 08:09:40 -0800 Subject: [PATCH 3/7] cleanup Signed-off-by: Jimmy Zhang --- .../fine_grained_activation_offload.py | 2 +- megatron/core/ssm/mamba_mixer.py | 8 ++------ megatron/core/transformer/cuda_graphs.py | 13 +++---------- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 7612e5736eb..e140317e42b 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1263,7 +1263,7 @@ def __enter__(self): def __exit__(self, *args: Any): """Exit context manager to disable activation offloading hooks.""" if is_graph_warmup(): - return self.tensor + return if self.offload: PipelineOffloadManager.get_instance().__exit__() diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 9ad89b2fac7..d92e2821e47 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -24,10 +24,10 @@ tensor_masked_update, tensor_merge, ) +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer import TransformerConfig @@ -410,7 +410,6 @@ def __init__( and "mamba_out_proj" in self.config.offload_modules ) - def forward( self, hidden_states, @@ -464,14 +463,11 @@ def forward( assert ssm_state is None y = self._ssm_training(zxBCdt, packed_seq_params) - with off_interface(self.offload_out_proj, y, "mamba_out_proj") as y: out, out_bias = self.out_proj(y) if self.offload_out_proj: - out = off_interface.group_commit( - out, name="mamba_out_proj", forced_released_tensors=[] - ) + out = off_interface.group_commit(out, name="mamba_out_proj", forced_released_tensors=[]) return out, out_bias diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index f0e4ea6050c..8a5bb65c0c0 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -388,6 +388,7 @@ def create_cudagraphs(cls): from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) + off_interface.reset() progress_bar = enumerate(cls.cudagraph_record) @@ -779,16 +780,8 @@ def __init__( if self.base_module.config.fine_grained_activation_offloading: # Dedicated stream for this runner's graph replays self.stream = stream - self.fwd_completion_event = ( - torch.cuda.Event(external=True, interprocess=True) - if stream is not None - else None - ) - self.bwd_completion_event = ( - torch.cuda.Event(external=True, interprocess=True) - if stream is not None - else None - ) + self.fwd_completion_event = torch.cuda.Event(external=True, interprocess=True) + self.bwd_completion_event = torch.cuda.Event(external=True, interprocess=True) if self.fp8_enabled: self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() From 5c9d48c60c68255be3c0f04e4f95bc6dd5c40460 Mon Sep 17 00:00:00 2001 From: Jieming Zhang Date: Thu, 5 Feb 2026 13:03:41 -0800 Subject: [PATCH 4/7] cleanup, handle race Signed-off-by: Jieming Zhang --- .../fine_grained_activation_offload.py | 31 ++--- megatron/core/transformer/cuda_graphs.py | 130 ++++++++++++++---- 2 files changed, 110 insertions(+), 51 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index e140317e42b..b6961bc4532 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1,4 +1,5 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + from collections import deque from contextlib import nullcontext @@ -12,7 +13,6 @@ from megatron.core.transformer.cuda_graphs import ( is_graph_capturing, - is_graph_warmup, set_external_join_stream_for_graph_capture, ) @@ -892,6 +892,9 @@ def tensor_need_offloading_checker(self, tensor): # Respect tensor's offload preference if specified if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: return False + if hasattr(tensor, "_TE_do_not_offload") and tensor._TE_do_not_offload: + return False + return True def bulk_offload_group(self): @@ -997,11 +1000,8 @@ def on_group_commit_forward(self, forced_released_tensors): if is_graph_capturing(): # Mark that d2h_stream is used so it gets joined before capture ends set_external_join_stream_for_graph_capture(self.d2h_stream) - event = torch.cuda.Event() - event.record(torch.cuda.current_stream()) - self.d2h_stream.wait_event(event) - else: - self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload(forced_released_tensors) def bulk_reload(self): @@ -1076,11 +1076,8 @@ def on_group_start_backward(self): if is_graph_capturing(): set_external_join_stream_for_graph_capture(self.h2d_stream) - event = torch.cuda.Event() - event.record(torch.cuda.current_stream()) - self.h2d_stream.wait_event(event) - else: - self.h2d_stream.wait_stream(torch.cuda.current_stream()) + + self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() @@ -1252,9 +1249,6 @@ def __init__(self, offload: bool, tensor: torch.Tensor, name: str): def __enter__(self): """Enter context manager to enable activation offloading hooks.""" - if is_graph_warmup(): - return self.tensor - if self.offload: self.tensor = fine_grained_offloading_group_start(self.tensor, self.name) PipelineOffloadManager.get_instance().__enter__() @@ -1262,9 +1256,6 @@ def __enter__(self): def __exit__(self, *args: Any): """Exit context manager to disable activation offloading hooks.""" - if is_graph_warmup(): - return - if self.offload: PipelineOffloadManager.get_instance().__exit__() @@ -1283,10 +1274,6 @@ def get_context(flag): @staticmethod def group_commit(tensor, name, forced_released_tensors=None, delay_offload=False): """Group commit the tensors.""" - - if is_graph_warmup(): - return tensor - return fine_grained_offloading_group_commit( tensor, name, forced_released_tensors, delay_offload ) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 8a5bb65c0c0..154219eda85 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -113,6 +113,56 @@ def join_external_streams(): _JOIN_STREAMS = [] +class RecordStreamTracker: + """Tracks tensor references and alternates CUDA streams across consecutive cudagraph replays. + + Double-buffers both tensor references (intercepted via record_stream) and CUDA streams so + that consecutive cudagraph replays use different streams, allowing overlap with async comms. + This is needed as any tensors that have been attached to a stream with record_stream + may be deallocated when the stream is joined. Because we may wish to overlap the subsequent + cudagraph with the previous cudagraph's async comms (ie. offloads), the subsequent cudagraph + may invalidate the memory of async comm. So we intercept torch.Tensor.record_stream calls + and manually store references to such tensors. + + """ + + # Double-buffered tensor references and CUDA streams, alternated via _idx + _buffers = [[], []] + _streams = [None, None] + _idx = 0 + + _original_record_stream = torch.Tensor.record_stream + + def _patched_record_stream(*args, **kwargs): + RecordStreamTracker._original_record_stream(*args, **kwargs) + RecordStreamTracker._buffers[RecordStreamTracker._idx].append(args[0]) + + def __init__(self): + RecordStreamTracker._idx = 0 if RecordStreamTracker._idx == 1 else 1 + RecordStreamTracker._buffers[RecordStreamTracker._idx] = [] + + @classmethod + def clear(cls): + """Clear all record_stream tensor references.""" + cls._buffers = [[], []] + + @classmethod + def get_next_stream(cls): + """Return the next alternating stream for cudagraph replay.""" + if cls._streams[0] is None: + cls._streams = [torch.cuda.Stream(), torch.cuda.Stream()] + idx = cls._idx + cls._idx = 0 if cls._idx == 1 else 1 + return cls._streams[idx] + + def __enter__(self): + torch.Tensor.record_stream = RecordStreamTracker._patched_record_stream + return self + + def __exit__(self, *args): + torch.Tensor.record_stream = RecordStreamTracker._original_record_stream + + def is_graph_warmup(): """Query if currently warming up for graph capture.""" return _IS_GRAPH_WARMUP @@ -123,12 +173,24 @@ def _set_warmup_start(): global _IS_GRAPH_WARMUP _IS_GRAPH_WARMUP = True + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_disable_offload, + ) + + fine_grained_offloading_disable_offload() + def _set_warmup_end(): """Set graph warmup has ended.""" global _IS_GRAPH_WARMUP _IS_GRAPH_WARMUP = False + from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( + fine_grained_offloading_enable_offload, + ) + + fine_grained_offloading_enable_offload() + @dataclass class CudagraphBufferMetadata: @@ -391,6 +453,7 @@ def create_cudagraphs(cls): off_interface.reset() + last_fwd_runner = [r for r in cls.cudagraph_record if r[1] == "fwd"][-1] progress_bar = enumerate(cls.cudagraph_record) time_start = time.time() mem_stats_start = torch.cuda.memory_stats() @@ -445,6 +508,16 @@ def format_mem_bytes(mem_bytes): if graph_type == 'fwd': args, kwargs, out = g[2:] runner.create_fwd_graph(args, kwargs, out, clone_inputs=True) + + if runner is last_fwd_runner: + RecordStreamTracker.clear() + if FREEZE_GC: + # gc.collect() drops references to unreachable tensors created during + # capture, returning their storage to the allocator to avoid a slowdown + # during replay. However, it forces expensive global garbage collection, + # so must be done only on the last layer per-device to avoid slowing + # down graph creation. + gc.collect() else: assert fwd_buffer_reuse_ref_count == 0 runner.create_bwd_graph() @@ -610,11 +683,11 @@ def forward(ctx, runner, is_first_microbatch, *inputs): need_copy_inputs.append(user_input) else: assert user_input.data_ptr() == cudagraph_input.data_ptr(), ( - f"Static CUDA graph input tensor changed memory address between graph capture and replay. " - f"Expected data_ptr={cudagraph_input.data_ptr()}, got data_ptr={user_input.data_ptr()}. " - f"Tensor shape={user_input.shape}, dtype={user_input.dtype}. " - f"Static inputs must maintain the same memory address across replays. " - f"Consider marking this input as copyable during graph capture." + f"Static CUDA graph input tensor changed memory address between graph", + f"capture and replay. Expected data_ptr={cudagraph_input.data_ptr()}, " + f"got data_ptr={user_input.data_ptr()}. Static inputs must maintain ", + f"the same memory address across replays. Consider marking this input as ", + f"copyable during graph capture.", ) ctx.runner = runner @@ -640,9 +713,10 @@ def forward(ctx, runner, is_first_microbatch, *inputs): FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch) runner.fp8_param_cache_updated = is_first_microbatch - if runner.stream is not None: - runner.stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(runner.stream): + if runner.use_stream: + stream = RecordStreamTracker.get_next_stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): runner.fwd_graph.replay() torch.cuda.current_stream().wait_event(runner.fwd_completion_event) else: @@ -679,9 +753,10 @@ def backward(ctx, *grads): if user_output_grad.data_ptr() != cudagraph_output_grad.data_ptr(): cudagraph_output_grad.copy_(user_output_grad) - if runner.stream is not None: - runner.stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(runner.stream): + if runner.use_stream: + stream = RecordStreamTracker.get_next_stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): runner.bwd_graph.replay() torch.cuda.current_stream().wait_event(runner.bwd_completion_event) else: @@ -722,7 +797,6 @@ def __init__( fwd_graph_input_kwargs: Dict[str, Any], func, need_backward, - stream: torch.cuda.Stream = None, ): """Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs, which are not created until this runner records its graph creation into @@ -732,7 +806,7 @@ def __init__( self.base_module = base_module self.mempool = mempool - self.stream = None + self.use_stream = False self.fwd_graph_input_arg_metas = [ArgMetadata(a) for a in fwd_graph_input_args] self.fwd_graph_input_kwarg_metas = { @@ -778,8 +852,8 @@ def __init__( self.fp4_runtime_enabled = None if self.base_module.config.fine_grained_activation_offloading: - # Dedicated stream for this runner's graph replays - self.stream = stream + # Use alternating streams from RecordStreamTracker for graph replays + self.use_stream = True self.fwd_completion_event = torch.cuda.Event(external=True, interprocess=True) self.bwd_completion_event = torch.cuda.Event(external=True, interprocess=True) @@ -997,15 +1071,23 @@ def clone_ten(ten): if FREEZE_GC: gc.freeze() - with torch.cuda.graph( - self.fwd_graph, pool=self.mempool, capture_error_mode="thread_local" + if self.use_stream: + record_stream_tracker = RecordStreamTracker() + else: + record_stream_tracker = nullcontext() + + with ( + torch.cuda.graph( + self.fwd_graph, pool=self.mempool, capture_error_mode="thread_local" + ), + record_stream_tracker, ): fwd_graph_outputs = self.func( *self.fwd_graph_input_args, **self.fwd_graph_input_kwargs ) # Record completion event inside the graph so the current stream can # proceed as soon as module compute finishes, before async comms are joined. - if self.stream is not None: + if self.use_stream: self.fwd_completion_event.record() join_external_streams() @@ -1013,13 +1095,6 @@ def clone_ten(ten): if FREEZE_GC: gc.unfreeze() - # gc.collect() drops references to unreachable tensors created during capture, - # returning their storage to the allocator to avoid a slowdown during replay. - # However, it forces expensive global garbage collection, so must be done - # only on the last layer per-device to avoid slowing down graph creation. - if self.is_last_layer: - gc.collect() - # save cudagraph output buffer self.fwd_graph_outputs = fwd_graph_outputs self.fwd_graph_output_surface = self.get_tensors(fwd_graph_outputs) @@ -1128,7 +1203,7 @@ def create_bwd_graph(self): allow_unused=True, ) - if self.stream is not None: + if self.use_stream: self.bwd_completion_event.record() join_external_streams() @@ -1483,7 +1558,6 @@ def wrapped_func(*args, **kwargs): self.cudagraph_runners: list[_CudaGraphRunner] = [] self.inference_cudagraphs_lookup_table: dict = defaultdict(lambda: None) self.is_first_microbatch = False - self.stream = torch.cuda.Stream() # Without pipeline parallelism, microbatches execute one at a time. # Therefore modules will always execute in the same order, so cudagraphs @@ -1560,7 +1634,6 @@ def is_valid(r): kwargs, self.func, self.need_backward, - stream=self.stream, ) self.cudagraph_runners.append(runner) if is_inference_mode: @@ -1585,7 +1658,6 @@ def is_valid(r): kwargs, self.func, self.need_backward, - stream=self.stream, ) self.cudagraph_runners.append(runner) From ab82a5c0f9da42367c6bf4ca80c62d77415e9a11 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Thu, 5 Feb 2026 15:01:41 -0800 Subject: [PATCH 5/7] offload ssm Signed-off-by: Jimmy Zhang --- megatron/core/ssm/mamba_mixer.py | 48 ++++++++++++------- .../core/transformer/transformer_config.py | 1 + 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index d92e2821e47..8780f8ec1a2 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -410,6 +410,11 @@ def __init__( and "mamba_out_proj" in self.config.offload_modules ) + self.offload_ssm = ( + self.config.fine_grained_activation_offloading + and "mamba_ssm" in self.config.offload_modules + ) + def forward( self, hidden_states, @@ -679,24 +684,31 @@ def _ssm_training( assert sequence_packing_available, reason_for_no_sequence_packing seq_idx = self._create_packed_seq_idx(packed_seq_params, zxBCdt.shape[1]) - y = mamba_split_conv1d_scan_combined( - zxBCdt, - rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), - self.cp.get_conv1d_bias(), - self.cp.get_dt_bias().float(), - A, - D=( - rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.cp.get_D() - ), - chunk_size=self.chunk_size, - activation=self.activation, - headdim=None if self.D_has_hdim else self.headdim, - ngroups=self.cp.ngroups_local_tpcp, - norm_before_gate=self.norm_before_gate, - seq_idx=seq_idx, - ) + with off_interface(self.offload_ssm, zxBCdt, "mamba_ssm") as zxBCdt: + + y = mamba_split_conv1d_scan_combined( + zxBCdt, + rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), + self.cp.get_conv1d_bias(), + self.cp.get_dt_bias().float(), + A, + D=( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ), + chunk_size=self.chunk_size, + activation=self.activation, + headdim=None if self.D_has_hdim else self.headdim, + ngroups=self.cp.ngroups_local_tpcp, + norm_before_gate=self.norm_before_gate, + seq_idx=seq_idx, + ) + + if self.offload_ssm: + y = off_interface.group_commit( + y, name="mamba_ssm", forced_released_tensors=[] + ) y = rearrange(y, "b l d -> l b d").contiguous() y = self.cp.post_conv_ssm(y, packed_seq_params) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 5462d331ff8..7618a2fc5b7 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1280,6 +1280,7 @@ def __post_init__(self): "qkv_linear", "mamba_in_proj", "mamba_out_proj", + "mamba_ssm", } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( From c4370858410d4311cde60f7a9866e376b70ac3f7 Mon Sep 17 00:00:00 2001 From: Jieming Zhang Date: Thu, 5 Feb 2026 15:35:09 -0800 Subject: [PATCH 6/7] cleanup Signed-off-by: Jieming Zhang --- .../fine_grained_activation_offload.py | 2 + megatron/core/transformer/cuda_graphs.py | 47 ++++++++++--------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index b6961bc4532..4995de973db 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -1001,6 +1001,7 @@ def on_group_commit_forward(self, forced_released_tensors): # Mark that d2h_stream is used so it gets joined before capture ends set_external_join_stream_for_graph_capture(self.d2h_stream) + # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload(forced_released_tensors) @@ -1077,6 +1078,7 @@ def on_group_start_backward(self): if is_graph_capturing(): set_external_join_stream_for_graph_capture(self.h2d_stream) + # Wait for compute to finish before starting reload self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 154219eda85..ee1bef7391f 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -113,54 +113,55 @@ def join_external_streams(): _JOIN_STREAMS = [] -class RecordStreamTracker: +class StreamTracker: """Tracks tensor references and alternates CUDA streams across consecutive cudagraph replays. - Double-buffers both tensor references (intercepted via record_stream) and CUDA streams so + Multi-buffers both tensor references (intercepted via record_stream) and CUDA streams so that consecutive cudagraph replays use different streams, allowing overlap with async comms. - This is needed as any tensors that have been attached to a stream with record_stream - may be deallocated when the stream is joined. Because we may wish to overlap the subsequent + This is needed as tensors attached to a stream with 'torch.Tensor.record_stream' + may be deallocated when the stream is joined. Because we may overlap the subsequent cudagraph with the previous cudagraph's async comms (ie. offloads), the subsequent cudagraph - may invalidate the memory of async comm. So we intercept torch.Tensor.record_stream calls - and manually store references to such tensors. + may invalidate the memory of attached tensors. So we intercept torch.Tensor.record_stream calls + and manually store references to such tensors until the graph is guarenteed to have finished. """ - # Double-buffered tensor references and CUDA streams, alternated via _idx - _buffers = [[], []] - _streams = [None, None] + # Multi-buffered tensor references and CUDA streams, alternated via _idx + _num_buffers = 4 + _buffers = [[] for _ in range(_num_buffers)] + _streams = [None] * _num_buffers _idx = 0 _original_record_stream = torch.Tensor.record_stream def _patched_record_stream(*args, **kwargs): - RecordStreamTracker._original_record_stream(*args, **kwargs) - RecordStreamTracker._buffers[RecordStreamTracker._idx].append(args[0]) + StreamTracker._original_record_stream(*args, **kwargs) + StreamTracker._buffers[StreamTracker._idx].append(args[0]) def __init__(self): - RecordStreamTracker._idx = 0 if RecordStreamTracker._idx == 1 else 1 - RecordStreamTracker._buffers[RecordStreamTracker._idx] = [] + StreamTracker._idx = (StreamTracker._idx + 1) % StreamTracker._num_buffers + StreamTracker._buffers[StreamTracker._idx] = [] @classmethod def clear(cls): """Clear all record_stream tensor references.""" - cls._buffers = [[], []] + cls._buffers = [[] for _ in range(cls._num_buffers)] @classmethod def get_next_stream(cls): """Return the next alternating stream for cudagraph replay.""" if cls._streams[0] is None: - cls._streams = [torch.cuda.Stream(), torch.cuda.Stream()] + cls._streams = [torch.cuda.Stream() for _ in range(cls._num_buffers)] idx = cls._idx - cls._idx = 0 if cls._idx == 1 else 1 + cls._idx = (cls._idx + 1) % cls._num_buffers return cls._streams[idx] def __enter__(self): - torch.Tensor.record_stream = RecordStreamTracker._patched_record_stream + torch.Tensor.record_stream = StreamTracker._patched_record_stream return self def __exit__(self, *args): - torch.Tensor.record_stream = RecordStreamTracker._original_record_stream + torch.Tensor.record_stream = StreamTracker._original_record_stream def is_graph_warmup(): @@ -510,7 +511,7 @@ def format_mem_bytes(mem_bytes): runner.create_fwd_graph(args, kwargs, out, clone_inputs=True) if runner is last_fwd_runner: - RecordStreamTracker.clear() + StreamTracker.clear() if FREEZE_GC: # gc.collect() drops references to unreachable tensors created during # capture, returning their storage to the allocator to avoid a slowdown @@ -714,7 +715,7 @@ def forward(ctx, runner, is_first_microbatch, *inputs): runner.fp8_param_cache_updated = is_first_microbatch if runner.use_stream: - stream = RecordStreamTracker.get_next_stream() + stream = StreamTracker.get_next_stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): runner.fwd_graph.replay() @@ -754,7 +755,7 @@ def backward(ctx, *grads): cudagraph_output_grad.copy_(user_output_grad) if runner.use_stream: - stream = RecordStreamTracker.get_next_stream() + stream = StreamTracker.get_next_stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): runner.bwd_graph.replay() @@ -852,7 +853,7 @@ def __init__( self.fp4_runtime_enabled = None if self.base_module.config.fine_grained_activation_offloading: - # Use alternating streams from RecordStreamTracker for graph replays + # Use alternating streams from StreamTracker for graph replays self.use_stream = True self.fwd_completion_event = torch.cuda.Event(external=True, interprocess=True) self.bwd_completion_event = torch.cuda.Event(external=True, interprocess=True) @@ -1072,7 +1073,7 @@ def clone_ten(ten): gc.freeze() if self.use_stream: - record_stream_tracker = RecordStreamTracker() + record_stream_tracker = StreamTracker() else: record_stream_tracker = nullcontext() From 494e675f15c3fd5d09a01e9278791633178b3297 Mon Sep 17 00:00:00 2001 From: Jimmy Zhang Date: Fri, 6 Feb 2026 08:03:38 -0800 Subject: [PATCH 7/7] fix tests Signed-off-by: Jimmy Zhang --- megatron/core/ssm/mamba_mixer.py | 4 +--- megatron/core/transformer/cuda_graphs.py | 11 +++-------- .../bert/bert_mcore_tp2_pp2/model_config.yaml | 1 + 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 8780f8ec1a2..dfc606516e2 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -706,9 +706,7 @@ def _ssm_training( ) if self.offload_ssm: - y = off_interface.group_commit( - y, name="mamba_ssm", forced_released_tensors=[] - ) + y = off_interface.group_commit(y, name="mamba_ssm", forced_released_tensors=[]) y = rearrange(y, "b l d -> l b d").contiguous() y = self.cp.post_conv_ssm(y, packed_seq_params) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 8ce3e3e6e80..0971e5532f6 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -682,14 +682,9 @@ def forward(ctx, runner, is_first_microbatch, *inputs): if not cudagraph_input.can_skip_replay_copy: cudagraph_input.copy_(user_input) need_copy_inputs.append(user_input) - else: - assert user_input.data_ptr() == cudagraph_input.data_ptr(), ( - f"Static CUDA graph input tensor changed memory address between graph", - f"capture and replay. Expected data_ptr={cudagraph_input.data_ptr()}, " - f"got data_ptr={user_input.data_ptr()}. Static inputs must maintain ", - f"the same memory address across replays. Consider marking this input as ", - f"copyable during graph capture.", - ) + elif user_input.data_ptr() != cudagraph_input.data_ptr(): + cudagraph_input.copy_(user_input) + need_copy_inputs.append(user_input) ctx.runner = runner ctx.save_for_backward(*need_copy_inputs) diff --git a/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml b/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml index f965ee1d9ef..316a623324d 100644 --- a/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml +++ b/tests/functional_tests/test_cases/bert/bert_mcore_tp2_pp2/model_config.yaml @@ -41,4 +41,5 @@ MODEL_ARGS: --bf16: true --ckpt-format: torch --attention-backend: unfused + --exit-interval: 50 TEST_TYPE: regular