diff --git a/examples/mimo/training/hetero/args.py b/examples/mimo/training/hetero/args.py index 12b7f9f041d..5fc7b19d77c 100644 --- a/examples/mimo/training/hetero/args.py +++ b/examples/mimo/training/hetero/args.py @@ -74,6 +74,18 @@ def parse_args() -> argparse.Namespace: default=0, help="Dense data-parallel replica to trace when --timeline-ranks=dp-replica.", ) + runtime.add_argument( + "--timeline-iter-start", + type=int, + default=None, + help="First training iteration to record in timeline traces.", + ) + runtime.add_argument( + "--timeline-iter-end", + type=int, + default=None, + help="Last training iteration to record in timeline traces.", + ) runtime.add_argument( "--timeline-cuda-events", action="store_true", @@ -152,6 +164,16 @@ def validate_args(args: argparse.Namespace, world_size: int) -> tuple[int, int]: raise ValueError("--log-interval must be >= 1") if args.timeline_dp_replica < 0: raise ValueError("--timeline-dp-replica must be >= 0") + if args.timeline_iter_start is not None and args.timeline_iter_start < 1: + raise ValueError("--timeline-iter-start must be >= 1") + if args.timeline_iter_end is not None and args.timeline_iter_end < 1: + raise ValueError("--timeline-iter-end must be >= 1") + if ( + args.timeline_iter_start is not None + and args.timeline_iter_end is not None + and args.timeline_iter_end < args.timeline_iter_start + ): + raise ValueError("--timeline-iter-end must be >= --timeline-iter-start") validate_model_provider_args(args) if args.dataset_provider == "mock": diff --git a/examples/mimo/training/hetero/grad_sync.py b/examples/mimo/training/hetero/grad_sync.py index be5ee7b0c78..86592a26bf0 100644 --- a/examples/mimo/training/hetero/grad_sync.py +++ b/examples/mimo/training/hetero/grad_sync.py @@ -14,6 +14,7 @@ from examples.mimo.utils.hetero import debug_rank, is_process_group_member from megatron.core.distributed.finalize_model_grads import finalize_model_grads from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.pipeline_parallel.timeline import timeline_event from megatron.core.pipeline_parallel.utils import is_pp_last_stage @@ -34,48 +35,52 @@ def finalize_grads_func(_model_list, num_tokens, force_all_reduce=False, **_kwar if num_tokens is None: raise RuntimeError("train_hetero.py expects calculate_per_token_loss=True") - global_num_tokens = torch.zeros(1, dtype=torch.float32, device="cuda") - if is_token_source_rank(): - # MCore has already summed loss-mask token counts across microbatches - # for this gradient-accumulation step. Match Megatron's normalization - # domain by reducing the language last-stage count over DP and CP. - token_count = num_tokens.to(device="cuda", dtype=torch.float32).sum().view(1) - dist.all_reduce(token_count, op=dist.ReduceOp.SUM, group=language_pg.dp_cp) - if dist.get_rank(language_pg.dp_cp) == 0: - global_num_tokens.copy_(token_count) - # Publish the already DP/CP-reduced language token count to encoder ranks too. - dist.all_reduce(global_num_tokens, op=dist.ReduceOp.MAX) - global_num_tokens_value = global_num_tokens.item() + with timeline_event("grad_finalize.token_count_reduce"): + global_num_tokens = torch.zeros(1, dtype=torch.float32, device="cuda") + if is_token_source_rank(): + # MCore has already summed loss-mask token counts across microbatches + # for this gradient-accumulation step. Match Megatron's normalization + # domain by reducing the language last-stage count over DP and CP. + token_count = num_tokens.to(device="cuda", dtype=torch.float32).sum().view(1) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM, group=language_pg.dp_cp) + if dist.get_rank(language_pg.dp_cp) == 0: + global_num_tokens.copy_(token_count) + # Publish the already DP/CP-reduced language token count to encoder ranks too. + dist.all_reduce(global_num_tokens, op=dist.ReduceOp.MAX) + global_num_tokens_value = global_num_tokens.item() if mimo_model.language_model is not None: debug_rank("finalizing language grads") - finalize_model_grads( - [mimo_model.language_model], - num_tokens=None, - pg_collection=language_pg, - force_all_reduce=force_all_reduce, - ) - debug_rank("language grads finalized") - for submodule in mimo_model.modality_submodules.values(): - if submodule is not None: - debug_rank("finalizing vision grads") + with timeline_event("grad_finalize.language"): finalize_model_grads( - [submodule], + [mimo_model.language_model], num_tokens=None, - pg_collection=vision_pg, + pg_collection=language_pg, force_all_reduce=force_all_reduce, ) + debug_rank("language grads finalized") + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + debug_rank("finalizing vision grads") + with timeline_event("grad_finalize.vision"): + finalize_model_grads( + [submodule], + num_tokens=None, + pg_collection=vision_pg, + force_all_reduce=force_all_reduce, + ) debug_rank("vision grads finalized") if global_num_tokens_value > 0: scale = 1.0 / global_num_tokens_value - if mimo_model.language_model is not None: - debug_rank("scaling language grads") - mimo_model.language_model.scale_gradients(scale) - for submodule in mimo_model.modality_submodules.values(): - if submodule is not None: - debug_rank("scaling vision grads") - submodule.scale_gradients(scale) + with timeline_event("grad_finalize.scale"): + if mimo_model.language_model is not None: + debug_rank("scaling language grads") + mimo_model.language_model.scale_gradients(scale) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + debug_rank("scaling vision grads") + submodule.scale_gradients(scale) mimo_model.config.no_sync_func = build_no_sync_func(mimo_model) mimo_model.config.finalize_model_grads_func = finalize_grads_func diff --git a/examples/mimo/training/hetero/loop.py b/examples/mimo/training/hetero/loop.py index 5e9574ccfe9..f0233626b7d 100644 --- a/examples/mimo/training/hetero/loop.py +++ b/examples/mimo/training/hetero/loop.py @@ -25,7 +25,9 @@ from megatron.core.pipeline_parallel.timeline import ( close_pipeline_timeline, flush_pipeline_timeline, + is_pipeline_timeline_active, set_pipeline_timeline_iteration, + timeline_instant, ) @@ -74,6 +76,7 @@ def run_train_loop(args: argparse.Namespace) -> None: result = train_step( args, model, topology, optimizer, opt_param_scheduler, communicator, data_iterator ) + record_cuda_memory_snapshot() flush_pipeline_timeline() logger.record_step(result) logger.maybe_log(iteration, optimizer, result) @@ -97,3 +100,16 @@ def build_pipeline_communicator( dim_mapping={"s": 0, "h": 2, "b": 1}, module_output_ndim={topology.encoder_name: 2}, ) + + +def record_cuda_memory_snapshot() -> None: + """Record CUDA memory usage in the active timeline without cross-rank synchronization.""" + if not torch.cuda.is_available() or not is_pipeline_timeline_active(): + return + timeline_instant( + "cuda.memory", + memory_allocated_bytes=torch.cuda.memory_allocated(), + max_memory_allocated_bytes=torch.cuda.max_memory_allocated(), + memory_reserved_bytes=torch.cuda.memory_reserved(), + max_memory_reserved_bytes=torch.cuda.max_memory_reserved(), + ) diff --git a/examples/mimo/training/hetero/step.py b/examples/mimo/training/hetero/step.py index f136de45873..97fcaa09261 100644 --- a/examples/mimo/training/hetero/step.py +++ b/examples/mimo/training/hetero/step.py @@ -98,8 +98,10 @@ def train_step( data_iterator, ) -> TrainStepResult: """Run one Megatron-shaped hetero training step.""" - zero_active_grad_buffers(model) - optimizer.zero_grad() + with timeline_event("grad_buffers.zero"): + zero_active_grad_buffers(model) + with timeline_event("optimizer.zero_grad"): + optimizer.zero_grad() debug_rank("starting forward/backward schedule") losses = schedule.forward_backward_pipelining_without_interleaving( @@ -116,12 +118,15 @@ def train_step( debug_rank("schedule complete") debug_rank("optimizer step starting") - update_successful, grad_norm, num_zeros_in_grad = optimizer.step() - update_successful = reduce_update_success(update_successful) + with timeline_event("optimizer.step"): + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + with timeline_event("update_success.reduce"): + update_successful = reduce_update_success(update_successful) debug_rank("optimizer step complete") if update_successful: - opt_param_scheduler.step(increment=get_global_batch_size(args)) + with timeline_event("optimizer_param_scheduler.step"): + opt_param_scheduler.step(increment=get_global_batch_size(args)) skipped_iter = 0 else: # Match Megatron train_step semantics: failed updates skip LR advancement but diff --git a/examples/mimo/training/hetero/timeline.py b/examples/mimo/training/hetero/timeline.py index 1fff25782bb..99d697af04b 100644 --- a/examples/mimo/training/hetero/timeline.py +++ b/examples/mimo/training/hetero/timeline.py @@ -33,6 +33,18 @@ def configure_hetero_timeline(args: argparse.Namespace, topology: HeteroTopology scope = os.environ.get("MIMO_TIMELINE_RANKS", args.timeline_ranks) dp_replica = int(os.environ.get("MIMO_TIMELINE_DP_REPLICA", args.timeline_dp_replica)) output_dir = args.timeline_dir or os.environ.get("MIMO_TIMELINE_DIR", "mimo_timeline") + iteration_start = _optional_env_int("MIMO_TIMELINE_ITER_START", args.timeline_iter_start) + iteration_end = _optional_env_int("MIMO_TIMELINE_ITER_END", args.timeline_iter_end) + if iteration_start is not None and iteration_start < 1: + raise ValueError("timeline iteration start must be >= 1") + if iteration_end is not None and iteration_end < 1: + raise ValueError("timeline iteration end must be >= 1") + if ( + iteration_start is not None + and iteration_end is not None + and iteration_end < iteration_start + ): + raise ValueError("timeline iteration end must be >= timeline iteration start") selected_ranks = select_timeline_ranks(scope, dp_replica, topology, world_size) role, coords = rank_role_and_coords(rank, topology) @@ -49,13 +61,16 @@ def configure_hetero_timeline(args: argparse.Namespace, topology: HeteroTopology }, cuda_events=args.timeline_cuda_events or env_flag_enabled("MIMO_TIMELINE_CUDA_EVENTS"), nvtx=args.timeline_nvtx or env_flag_enabled("MIMO_TIMELINE_NVTX"), + iteration_start=iteration_start, + iteration_end=iteration_end, ) if rank != 0: return None return ( "Pipeline timeline enabled: " - f"dir={output_dir}, scope={scope}, selected_ranks={len(selected_ranks)}" + f"dir={output_dir}, scope={scope}, selected_ranks={len(selected_ranks)}, " + f"iter_start={iteration_start}, iter_end={iteration_end}" ) @@ -108,3 +123,10 @@ def grid_coords(grid: HyperCommGrid, rank: int) -> dict[str, int]: def env_flag_enabled(name: str) -> bool: """Return whether an environment flag is set to a truthy value.""" return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"} + + +def _optional_env_int(name: str, default: Optional[int]) -> Optional[int]: + value = os.environ.get(name) + if value is None or value.strip() == "": + return default + return int(value) diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 9bbaf61e9ce..f1b629b6f6a 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -12,6 +12,7 @@ from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.timeline import timeline_event from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import build_module from megatron.core.transformer.utils import sharded_state_dict_default @@ -133,6 +134,19 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): ) return sharded_sd + @staticmethod + def _timeline_tensor_meta( + tensor: Optional[torch.Tensor], prefix: str = "tensor" + ) -> Dict[str, Any]: + """Return shape-only tensor metadata without forcing device synchronization.""" + if tensor is None: + return {f"{prefix}_is_none": True} + return { + f"{prefix}_shape": tuple(tensor.shape), + f"{prefix}_numel": int(tensor.numel()), + f"{prefix}_dtype": str(tensor.dtype), + } + def align_embeddings_by_token_positions( self, modality_embeddings: Dict[str, torch.Tensor], # [num_embeddings, hidden_dim] @@ -430,16 +444,31 @@ def _forward_encoders( submodule = self.modality_submodules[encoder_name] encoder_inputs = modality_inputs.get(encoder_name) if modality_inputs else None hidden_states = input_tensors.get(encoder_name) if input_tensors else None - output = submodule.forward(encoder_inputs=encoder_inputs, hidden_states=hidden_states) + with timeline_event( + "mimo.encoder_forward", + cuda=True, + encoder_name=encoder_name, + has_encoder_inputs=encoder_inputs is not None, + has_hidden_states=hidden_states is not None, + ): + output = submodule.forward( + encoder_inputs=encoder_inputs, hidden_states=hidden_states + ) if output is None and encoder_inputs is None and hidden_states is None: if self._has_encoder_tokens(input_ids, encoder_name): raise RuntimeError( f"{encoder_name} inputs are missing, but matching special tokens exist" ) - output = self._empty_encoder_output(submodule, input_ids) + with timeline_event("mimo.encoder_empty_output", encoder_name=encoder_name): + output = self._empty_encoder_output(submodule, input_ids) if output is not None: - self._attach_modality_split_sizes(output, input_ids, encoder_name) + with timeline_event( + "mimo.encoder_attach_split_sizes", + encoder_name=encoder_name, + **self._timeline_tensor_meta(output, "output"), + ): + self._attach_modality_split_sizes(output, input_ids, encoder_name) outputs[encoder_name] = output return outputs @@ -506,74 +535,125 @@ def _forward_language_module( Tuple of language model output and the matching, possibly sharded loss mask. """ lang_name = MIMO_LANGUAGE_MODULE_KEY - packed_seq_params = self._build_packed_seq_params(packing_kwargs) + with timeline_event( + "mimo.build_packed_seq_params", has_packing_kwargs=packing_kwargs is not None + ): + packed_seq_params = self._build_packed_seq_params(packing_kwargs) if self.role.is_first_stage(lang_name): # First stage: receive encoder embeddings, combine with text, pass to LM # Build modality embeddings dict from encoder outputs modality_embeddings = {} - if input_tensors: - for name, tensor in input_tensors.items(): - if name != lang_name: - modality_embeddings[name] = tensor + with timeline_event( + "mimo.collect_modality_embeddings", + input_tensor_names=list(input_tensors.keys()) if input_tensors else [], + ): + if input_tensors: + for name, tensor in input_tensors.items(): + if name != lang_name: + modality_embeddings[name] = tensor # Get text embeddings - text_embeddings = self.get_text_embeddings( - input_ids, position_ids, self.special_token_ids - ) + with timeline_event( + "mimo.text_embeddings", + cuda=True, + **self._timeline_tensor_meta(input_ids, "input_ids"), + ): + text_embeddings = self.get_text_embeddings( + input_ids, position_ids, self.special_token_ids + ) modality_embeddings["text"] = text_embeddings # Combine all embeddings - combined_embeddings = self.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, - input_ids=input_ids, - special_token_ids=self.special_token_ids, - ) + with timeline_event( + "mimo.align_embeddings", + cuda=True, + modality_names=list(modality_embeddings.keys()), + modality_shapes={ + name: tuple(tensor.shape) for name, tensor in modality_embeddings.items() + }, + **self._timeline_tensor_meta(input_ids, "input_ids"), + ): + combined_embeddings = self.align_embeddings_by_token_positions( + modality_embeddings=modality_embeddings, + input_ids=input_ids, + special_token_ids=self.special_token_ids, + ) + + with timeline_event( + "mimo.prepare_language_inputs", + cuda=True, + shard_loss_inputs=self.role.is_last_stage(lang_name), + **self._timeline_tensor_meta(combined_embeddings, "embeddings"), + ): + combined_embeddings, labels, loss_mask, packed_seq_params = ( + self._prepare_language_inputs( + embeddings=combined_embeddings, + labels=labels, + loss_mask=loss_mask, + packed_seq_params=packed_seq_params, + shard_loss_inputs=self.role.is_last_stage(lang_name), + ) + ) - combined_embeddings, labels, loss_mask, packed_seq_params = ( - self._prepare_language_inputs( - embeddings=combined_embeddings, + with timeline_event( + "mimo.language_model", + cuda=True, + is_first_stage=True, + is_last_stage=self.role.is_last_stage(lang_name), + **self._timeline_tensor_meta(combined_embeddings, "decoder_input"), + ): + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=combined_embeddings, labels=labels, - loss_mask=loss_mask, + attention_mask=attention_mask, packed_seq_params=packed_seq_params, - shard_loss_inputs=self.role.is_last_stage(lang_name), ) - ) - - lm_output = self.language_model( - input_ids=None, - position_ids=None, - decoder_input=combined_embeddings, - labels=labels, - attention_mask=attention_mask, - packed_seq_params=packed_seq_params, - ) else: # Non-first stage: receive hidden states from previous LM stage hidden_states = input_tensors.get(lang_name) if input_tensors else None - _, labels, loss_mask, packed_seq_params = self._prepare_language_inputs( - embeddings=None, - labels=labels, - loss_mask=loss_mask, - packed_seq_params=packed_seq_params, + with timeline_event( + "mimo.prepare_language_inputs", + cuda=True, shard_loss_inputs=self.role.is_last_stage(lang_name), - ) + **self._timeline_tensor_meta(hidden_states, "hidden_states"), + ): + _, labels, loss_mask, packed_seq_params = self._prepare_language_inputs( + embeddings=None, + labels=labels, + loss_mask=loss_mask, + packed_seq_params=packed_seq_params, + shard_loss_inputs=self.role.is_last_stage(lang_name), + ) # Set input tensor on language model for PP (unwrap DDP to reach GPTModel) if hidden_states is not None: - underlying_lm = unwrap_model(self.language_model) - if hasattr(underlying_lm, 'set_input_tensor'): - underlying_lm.set_input_tensor(hidden_states) - - lm_output = self.language_model( - input_ids=None, - position_ids=None, - decoder_input=None, - labels=labels, - attention_mask=attention_mask, - packed_seq_params=packed_seq_params, - ) + with timeline_event( + "mimo.set_language_input_tensor", + **self._timeline_tensor_meta(hidden_states, "hidden_states"), + ): + underlying_lm = unwrap_model(self.language_model) + if hasattr(underlying_lm, 'set_input_tensor'): + underlying_lm.set_input_tensor(hidden_states) + + with timeline_event( + "mimo.language_model", + cuda=True, + is_first_stage=False, + is_last_stage=self.role.is_last_stage(lang_name), + **self._timeline_tensor_meta(hidden_states, "hidden_states"), + ): + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=None, + labels=labels, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + ) # Key output for non-last stages so schedule can route to next LM stage if not self.role.is_last_stage(lang_name): diff --git a/megatron/core/pipeline_parallel/bridge_communicator.py b/megatron/core/pipeline_parallel/bridge_communicator.py index bc028970ff4..ed179c3ddbc 100644 --- a/megatron/core/pipeline_parallel/bridge_communicator.py +++ b/megatron/core/pipeline_parallel/bridge_communicator.py @@ -9,6 +9,7 @@ import torch.distributed as dist from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.pipeline_parallel.timeline import timeline_event class CommRole(Enum): @@ -330,6 +331,31 @@ def build_comm_map(self, src_tp_leaders: List[int], dest_tp_leaders: List[int]): role=CommRole.RECEIVER, recv_from_ranks=[src_rank] ) + def _timeline_meta(self, **metadata): + """Return bridge metadata shared by detailed timeline events.""" + return { + "src_module": self.src_module_name, + "dest_module": self.dest_module_name, + **metadata, + } + + @staticmethod + def _tensor_timeline_meta(tensor: torch.Tensor, prefix: str = "tensor") -> Dict[str, object]: + """Return shape-only tensor metadata without forcing device synchronization.""" + return { + f"{prefix}_shape": tuple(tensor.shape), + f"{prefix}_numel": int(tensor.numel()), + f"{prefix}_dtype": str(tensor.dtype), + } + + @staticmethod + def _shapes_timeline_meta(shapes: List[Tuple[int, ...]], prefix: str) -> Dict[str, object]: + """Return compact metadata for a list of communicated tensor shapes.""" + return { + f"{prefix}_shapes": [tuple(shape) for shape in shapes], + f"{prefix}_numel": [int(torch.Size(shape).numel()) for shape in shapes], + } + def send_forward(self, tensor_to_send: torch.Tensor): """Send forward activation tensor. @@ -349,14 +375,28 @@ def send_forward(self, tensor_to_send: torch.Tensor): # Send splits to destination ranks num_sends = len(rank_info.send_to_ranks) if num_sends > 0: - tensor_splits = self._split_tensor_at_batch_dim(tensor_to_send, num_sends) + with timeline_event( + "bridge.split.forward", + **self._timeline_meta( + peers=rank_info.send_to_ranks, + num_peers=num_sends, + **self._tensor_timeline_meta(tensor_to_send, "input"), + ), + ): + tensor_splits = self._split_tensor_at_batch_dim(tensor_to_send, num_sends) self._communicate_shapes(tensor_to_send_next=tensor_splits) for dest_rank, tensor_split in zip(rank_info.send_to_ranks, tensor_splits): logging.debug( f"[Bridge Comunicator] [send_forward] Rank {self.current_rank} " f"send to rank {dest_rank}" ) - dist.send(tensor_split, dst=dest_rank) + with timeline_event( + "bridge.tensor_send.forward.peer", + **self._timeline_meta( + peer_rank=dest_rank, **self._tensor_timeline_meta(tensor_split) + ), + ): + dist.send(tensor_split, dst=dest_rank) def recv_forward(self) -> torch.Tensor: """Receive forward activation tensor. @@ -399,14 +439,26 @@ def recv_forward(self) -> torch.Tensor: dtype=self.comm_dtype, requires_grad=True, ) - dist.recv(tensor_to_recv, src=src_rank) + with timeline_event( + "bridge.tensor_recv.forward.peer", + **self._timeline_meta(peer_rank=src_rank, tensor_shape=tuple(shape)), + ): + dist.recv(tensor_to_recv, src=src_rank) logging.debug( f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " f"received tensor from src rank {src_rank} " f"shape {tensor_to_recv.shape} sum {tensor_to_recv.sum()}" ) received_tensors_list.append(tensor_to_recv) - aggregated_tensor = torch.cat(received_tensors_list, dim=self._batch_dim) + with timeline_event( + "bridge.cat.forward", + **self._timeline_meta( + num_tensors=len(received_tensors_list), + batch_dim=self._batch_dim, + **self._shapes_timeline_meta(recv_forward_shapes, "input"), + ), + ): + aggregated_tensor = torch.cat(received_tensors_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " f"broadcasting tensor {aggregated_tensor.shape} sum {aggregated_tensor.sum()}" @@ -416,12 +468,30 @@ def recv_forward(self) -> torch.Tensor: shape_tensor = torch.tensor( aggregated_tensor.shape, device=aggregated_tensor.device, dtype=torch.int64 ) - dist.broadcast(shape_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg) + with timeline_event( + "bridge.broadcast.forward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_tensor), + ), + ): + dist.broadcast( + shape_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg + ) # Step 2: broadcast the actual tensor - dist.broadcast( - aggregated_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.forward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_tensor), + ), + ): + dist.broadcast( + aggregated_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg + ) return aggregated_tensor @@ -433,9 +503,15 @@ def recv_forward(self) -> torch.Tensor: shape_tensor = torch.empty( (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast( - shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.forward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, src_rank=self.dest_local_leader_rank + ), + ): + dist.broadcast( + shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg + ) received_shape = tuple(shape_tensor.tolist()) received_tensor = torch.empty( @@ -446,9 +522,19 @@ def recv_forward(self) -> torch.Tensor: ) # Receive the full tensor via broadcast - dist.broadcast( - received_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.forward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.dest_local_leader_rank, + tensor_shape=received_shape, + ), + ): + dist.broadcast( + received_tensor, + src=self.dest_local_leader_rank, + group=self.dest_grid_broadcast_pg, + ) logging.debug( f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " @@ -479,7 +565,15 @@ def send_backward(self, grad_tensor: torch.Tensor): ), f"Rank {self.current_rank} is not the leader rank" # Send gradients back to source ranks num_receives = len(rank_info.recv_from_ranks) - tensor_splits = self._split_tensor_at_batch_dim(grad_tensor, num_receives) + with timeline_event( + "bridge.split.backward", + **self._timeline_meta( + peers=rank_info.recv_from_ranks, + num_peers=num_receives, + **self._tensor_timeline_meta(grad_tensor, "input"), + ), + ): + tensor_splits = self._split_tensor_at_batch_dim(grad_tensor, num_receives) self._communicate_shapes(tensor_to_send_prev=tensor_splits) if num_receives > 0: for src_rank, tensor_split in zip(rank_info.recv_from_ranks, tensor_splits): @@ -489,7 +583,13 @@ def send_backward(self, grad_tensor: torch.Tensor): f"sending gradient to src rank {src_rank} " f"shape {tensor_split.shape} sum {tensor_split.sum()}" ) - dist.send(tensor_split, dst=src_rank) + with timeline_event( + "bridge.tensor_send.backward.peer", + **self._timeline_meta( + peer_rank=src_rank, **self._tensor_timeline_meta(tensor_split) + ), + ): + dist.send(tensor_split, dst=src_rank) def recv_backward(self) -> torch.Tensor: """Receive backward gradient tensor. @@ -528,7 +628,11 @@ def recv_backward(self) -> torch.Tensor: grad_tensor = torch.empty( grad_shape, device=torch.cuda.current_device(), dtype=self.comm_dtype ) - dist.recv(grad_tensor, src=dest_rank) + with timeline_event( + "bridge.tensor_recv.backward.peer", + **self._timeline_meta(peer_rank=dest_rank, tensor_shape=tuple(grad_shape)), + ): + dist.recv(grad_tensor, src=dest_rank) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " f"received gradient from dest rank {dest_rank} " @@ -537,7 +641,15 @@ def recv_backward(self) -> torch.Tensor: received_gradients_list.append(grad_tensor) # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim) + with timeline_event( + "bridge.cat.backward", + **self._timeline_meta( + num_tensors=len(received_gradients_list), + batch_dim=self._batch_dim, + **self._shapes_timeline_meta(recv_grad_shapes, "input"), + ), + ): + aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -546,12 +658,30 @@ def recv_backward(self) -> torch.Tensor: shape_tensor = torch.tensor( aggregated_gradient.shape, device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast(shape_tensor, src=self.current_rank, group=self.src_grid_broadcast_pg) + with timeline_event( + "bridge.broadcast.backward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_gradient), + ), + ): + dist.broadcast( + shape_tensor, src=self.current_rank, group=self.src_grid_broadcast_pg + ) # Scatter the tensors to all ranks in the group - dist.broadcast( - aggregated_gradient, src=self.current_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_gradient), + ), + ): + dist.broadcast( + aggregated_gradient, src=self.current_rank, group=self.src_grid_broadcast_pg + ) return aggregated_gradient elif ( @@ -562,9 +692,15 @@ def recv_backward(self) -> torch.Tensor: shape_tensor = torch.empty( (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast( - shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, src_rank=self.src_local_leader_rank + ), + ): + dist.broadcast( + shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg + ) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " @@ -575,9 +711,19 @@ def recv_backward(self) -> torch.Tensor: received_shape, device=torch.cuda.current_device(), dtype=self.comm_dtype ) - dist.broadcast( - received_gradient, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.src_local_leader_rank, + tensor_shape=received_shape, + ), + ): + dist.broadcast( + received_gradient, + src=self.src_local_leader_rank, + group=self.src_grid_broadcast_pg, + ) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " f"received gradient from scatter operation, shape {received_gradient.shape}" @@ -615,7 +761,16 @@ def send_forward_recv_backward( ), f"Rank {self.current_rank} is not the leader rank" num_sends = len(rank_info.send_to_ranks) - activation_splits = self._split_tensor_at_batch_dim(input_tensor, num_sends) + with timeline_event( + "bridge.split.forward", + **self._timeline_meta( + op="send_forward_recv_backward", + peers=rank_info.send_to_ranks, + num_peers=num_sends, + **self._tensor_timeline_meta(input_tensor, "input"), + ), + ): + activation_splits = self._split_tensor_at_batch_dim(input_tensor, num_sends) # Communicate shapes for both directions (send forward, receive backward) recv_forward_shapes, recv_grad_shapes = self._communicate_shapes( tensor_to_send_next=activation_splits, recv_next=True @@ -655,12 +810,31 @@ def send_forward_recv_backward( f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " f"executing {len(ops)} simultaneous P2P operations" ) - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() + with timeline_event( + "bridge.p2p.forward_backward", + **self._timeline_meta( + peers=rank_info.send_to_ranks, + num_ops=len(ops), + **self._shapes_timeline_meta( + [tuple(tensor.shape) for tensor in activation_splits], "send" + ), + **self._shapes_timeline_meta(recv_grad_shapes, "recv"), + ), + ): + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim) + with timeline_event( + "bridge.cat.backward", + **self._timeline_meta( + num_tensors=len(received_gradients_list), + batch_dim=self._batch_dim, + **self._shapes_timeline_meta(recv_grad_shapes, "input"), + ), + ): + aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -670,14 +844,30 @@ def send_forward_recv_backward( shape_tensor = torch.tensor( tensor_shape_to_broadcast, device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast( - shape_tensor, src=self.current_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_gradient), + ), + ): + dist.broadcast( + shape_tensor, src=self.current_rank, group=self.src_grid_broadcast_pg + ) # Broadcast the tensors to all ranks in the group - dist.broadcast( - aggregated_gradient, src=self.current_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_gradient), + ), + ): + dist.broadcast( + aggregated_gradient, src=self.current_rank, group=self.src_grid_broadcast_pg + ) return aggregated_gradient @@ -689,18 +879,34 @@ def send_forward_recv_backward( shape_tensor = torch.empty( (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast( - shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, src_rank=self.src_local_leader_rank + ), + ): + dist.broadcast( + shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg + ) # Use the received shape to create tensor for broadcast received_shape = tuple(shape_tensor.tolist()) received_gradient = torch.empty( received_shape, device=torch.cuda.current_device(), dtype=self.comm_dtype ) - dist.broadcast( - received_gradient, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.backward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.src_local_leader_rank, + tensor_shape=received_shape, + ), + ): + dist.broadcast( + received_gradient, + src=self.src_local_leader_rank, + group=self.src_grid_broadcast_pg, + ) logging.debug( f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " f"received gradient from broadcast, shape {received_gradient.shape}" @@ -734,7 +940,16 @@ def send_backward_recv_forward( ), f"Rank {self.current_rank} is not the leader rank" num_receives = len(rank_info.recv_from_ranks) - gradient_splits = self._split_tensor_at_batch_dim(grad_tensor, num_receives) + with timeline_event( + "bridge.split.backward", + **self._timeline_meta( + op="send_backward_recv_forward", + peers=rank_info.recv_from_ranks, + num_peers=num_receives, + **self._tensor_timeline_meta(grad_tensor, "input"), + ), + ): + gradient_splits = self._split_tensor_at_batch_dim(grad_tensor, num_receives) # Communicate shapes for both directions (send backward, receive forward) recv_forward_shapes, recv_grad_shapes = self._communicate_shapes( tensor_to_send_prev=gradient_splits, recv_prev=True @@ -779,12 +994,33 @@ def send_backward_recv_forward( f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " f"executing {len(ops)} simultaneous P2P operations" ) - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() + with timeline_event( + "bridge.p2p.backward_forward", + **self._timeline_meta( + peers=rank_info.recv_from_ranks, + num_ops=len(ops), + **self._shapes_timeline_meta( + [tuple(tensor.shape) for tensor in gradient_splits], "send" + ), + **self._shapes_timeline_meta(recv_forward_shapes, "recv"), + ), + ): + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() # Concatenate received activations - aggregated_activation = torch.cat(received_activations_list, dim=self._batch_dim) + with timeline_event( + "bridge.cat.forward", + **self._timeline_meta( + num_tensors=len(received_activations_list), + batch_dim=self._batch_dim, + **self._shapes_timeline_meta(recv_forward_shapes, "input"), + ), + ): + aggregated_activation = torch.cat( + received_activations_list, dim=self._batch_dim + ) logging.debug( f"[Bridge Communicator] [send_backward_recv_forward] Rank {self.current_rank} " f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}" @@ -795,14 +1031,32 @@ def send_backward_recv_forward( shape_tensor = torch.tensor( tensor_shape_to_scatter, device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast( - shape_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.forward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_activation), + ), + ): + dist.broadcast( + shape_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg + ) # Scatter the tensors to all ranks in the group - dist.broadcast( - aggregated_activation, src=self.current_rank, group=self.dest_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.forward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.current_rank, + **self._tensor_timeline_meta(aggregated_activation), + ), + ): + dist.broadcast( + aggregated_activation, + src=self.current_rank, + group=self.dest_grid_broadcast_pg, + ) return aggregated_activation elif ( @@ -812,9 +1066,15 @@ def send_backward_recv_forward( shape_tensor = torch.empty( (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 ) - dist.broadcast( - shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg - ) + with timeline_event( + "bridge.broadcast.forward_shape", + **self._timeline_meta( + bridge_role=rank_info.role.value, src_rank=self.dest_local_leader_rank + ), + ): + dist.broadcast( + shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg + ) # Use the received shape to create tensor for scatter operation received_shape = tuple(shape_tensor.tolist()) @@ -824,11 +1084,19 @@ def send_backward_recv_forward( dtype=self.comm_dtype, requires_grad=True, ) - dist.broadcast( - received_activation, - src=self.dest_local_leader_rank, - group=self.dest_grid_broadcast_pg, - ) + with timeline_event( + "bridge.broadcast.forward_tensor", + **self._timeline_meta( + bridge_role=rank_info.role.value, + src_rank=self.dest_local_leader_rank, + tensor_shape=received_shape, + ), + ): + dist.broadcast( + received_activation, + src=self.dest_local_leader_rank, + group=self.dest_grid_broadcast_pg, + ) logging.debug( f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " f"received activation from scatter operation, shape {received_activation.shape}" @@ -874,6 +1142,8 @@ def _communicate_shapes( ops = [] recv_forward_shape_tensors = [] recv_grad_shape_tensors = [] + send_forward_shapes = [] + send_grad_shapes = [] if rank_info.role == CommRole.SENDER: # Prepare send operations for forward shapes @@ -883,6 +1153,7 @@ def _communicate_shapes( ) # Add send operations for each destination for dest_rank, tensor in zip(rank_info.send_to_ranks, tensors_to_send): + send_forward_shapes.append(tuple(tensor.shape)) send_shape_tensor = torch.tensor( tensor.shape, device=torch.cuda.current_device(), dtype=torch.int64 ) @@ -927,6 +1198,7 @@ def _communicate_shapes( ) for src_rank, tensor in zip(rank_info.recv_from_ranks, tensors_to_send): + send_grad_shapes.append(tuple(tensor.shape)) grad_shape_tensor = torch.tensor( tensor.shape, device=torch.cuda.current_device(), dtype=torch.int64 ) @@ -938,9 +1210,40 @@ def _communicate_shapes( # Execute all operations in a single batch if ops: - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() + meta = self._timeline_meta( + bridge_role=rank_info.role.value, + num_ops=len(ops), + send_next=bool(tensor_to_send_next is not None), + recv_next=bool(recv_next), + recv_prev=bool(recv_prev), + send_prev=bool(tensor_to_send_prev is not None), + send_next_ranks=( + rank_info.send_to_ranks + if rank_info.role == CommRole.SENDER and tensor_to_send_next is not None + else [] + ), + recv_next_ranks=( + rank_info.send_to_ranks + if rank_info.role == CommRole.SENDER and recv_next + else [] + ), + recv_prev_ranks=( + rank_info.recv_from_ranks + if rank_info.role == CommRole.RECEIVER and recv_prev + else [] + ), + send_prev_ranks=( + rank_info.recv_from_ranks + if rank_info.role == CommRole.RECEIVER and tensor_to_send_prev is not None + else [] + ), + **self._shapes_timeline_meta(send_forward_shapes, "send_forward"), + **self._shapes_timeline_meta(send_grad_shapes, "send_grad"), + ) + with timeline_event("bridge.shape_exchange", **meta): + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() # Extract shapes from received tensors for forward_shape_tensor in recv_forward_shape_tensors: diff --git a/megatron/core/pipeline_parallel/timeline.py b/megatron/core/pipeline_parallel/timeline.py index 051408c257c..ae2c23cff81 100644 --- a/megatron/core/pipeline_parallel/timeline.py +++ b/megatron/core/pipeline_parallel/timeline.py @@ -25,6 +25,8 @@ class PipelineTimelineRecorder: metadata: dict[str, Any] = field(default_factory=dict) cuda_events: bool = False nvtx: bool = False + iteration_start: Optional[int] = None + iteration_end: Optional[int] = None iteration: Optional[int] = None _records: list[dict[str, Any]] = field(default_factory=list) _context_stack: list[dict[str, Any]] = field(default_factory=list) @@ -73,6 +75,7 @@ def record(self, event: str, cuda: bool = False, **metadata) -> Iterator[None]: "world_size": self.world_size, "role": self.role, "start_time_ns": start_time_ns, + "start_perf_ns": start_perf_ns, "duration_us": (end_perf_ns - start_perf_ns) / 1000.0, "ok": ok, } @@ -109,6 +112,18 @@ def close(self) -> None: self._file.close() self._file = None + def is_active(self) -> bool: + """Return whether the current iteration should be recorded.""" + if self.iteration_start is None and self.iteration_end is None: + return True + if self.iteration is None: + return False + if self.iteration_start is not None and self.iteration < self.iteration_start: + return False + if self.iteration_end is not None and self.iteration > self.iteration_end: + return False + return True + def _format_nvtx(self, event: str, metadata: dict[str, Any]) -> str: microbatch = metadata.get("microbatch") if microbatch is None: @@ -129,6 +144,8 @@ def configure_pipeline_timeline( metadata: Optional[dict[str, Any]] = None, cuda_events: bool = False, nvtx: bool = False, + iteration_start: Optional[int] = None, + iteration_end: Optional[int] = None, ) -> None: """Configure the process-local pipeline timeline recorder.""" global _RECORDER @@ -144,6 +161,8 @@ def configure_pipeline_timeline( metadata=metadata or {}, cuda_events=cuda_events, nvtx=nvtx, + iteration_start=iteration_start, + iteration_end=iteration_end, ) @@ -169,11 +188,24 @@ def close_pipeline_timeline() -> None: def timeline_event(event: str, cuda: bool = False, **metadata): """Return a no-op or recording context manager for one timeline event.""" - if _RECORDER is None: + if _RECORDER is None or not _RECORDER.is_active(): return contextlib.nullcontext() return _RECORDER.record(event, cuda=cuda, **metadata) +def is_pipeline_timeline_active() -> bool: + """Return whether the current rank/iteration is writing pipeline timeline events.""" + return _RECORDER is not None and _RECORDER.is_active() + + +def timeline_instant(event: str, **metadata) -> None: + """Write a zero-work timeline event with metadata for the current rank/iteration.""" + if _RECORDER is None or not _RECORDER.is_active(): + return + with _RECORDER.record(event, **metadata): + pass + + def _jsonable(value): """Convert common non-JSON values used in trace metadata.""" if isinstance(value, dict):