From 214c08bb441759cc46e9f5b8d9603b54a5be145b Mon Sep 17 00:00:00 2001 From: ykarnati Date: Fri, 17 Apr 2026 12:17:38 -0700 Subject: [PATCH 01/11] Add PP>1 support for LLM in colocated MIMO training (NMFW-19) Three-phase execution for colocated encoder PP=1 + LLM PP>1: - Phase 1: Encoder forward + bridge communicate on the full batch, with all ranks participating in the collective. - Phase 2: 1F1B LLM pipeline over microbatch slices of the detached encoder embeddings. - Phase 3: Encoder backward on the full batch, with the encoder gradient broadcast from PP rank 0 to the other PP ranks first. Changes: - MimoModel detects LLM PP>1 from module_to_grid_map and overrides the language ModuleStageInfo so is_first_stage / is_last_stage reflect PP position. - MimoModel.forward routes non-first PP stages through _forward_language_module using P2P hidden states and unwraps the dict return for the schedule. - _forward_all_modules accepts a precomputed encoder_embeddings dict to skip encoder forward inside each LLM microbatch iteration. - New encode_and_communicate() helper runs encoder forward + bridge transform; used by Phase 1 and reused by the 3-phase schedule. - colocated_schedule.py implements colocated_forward_backward_with_pp which drives the three phases and broadcasts encoder gradients. Tests: - test_mimo_colocated_pp.py: fan-in, equal-DP, and grad-accumulation cases at LLM PP=2 / encoder PP=1 on 8 GPUs. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/models/mimo/colocated_schedule.py | 215 ++++++++ megatron/core/models/mimo/model/base.py | 97 +++- .../models/test_mimo_colocated_pp.py | 499 ++++++++++++++++++ tests/unit_tests/models/test_mimo_model.py | 6 + 4 files changed, 798 insertions(+), 19 deletions(-) create mode 100644 megatron/core/models/mimo/colocated_schedule.py create mode 100644 tests/unit_tests/models/test_mimo_colocated_pp.py diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py new file mode 100644 index 00000000000..8cf96ded6c6 --- /dev/null +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -0,0 +1,215 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Three-phase schedule for colocated MIMO training with LLM PP>1. + +Phase 1: Encoder forward + communicate for the full batch (all ranks synchronized). +Phase 2: LLM 1F1B pipeline with detached encoder embeddings sliced per microbatch. +Phase 3: Encoder backward for the full batch (all ranks synchronized). + +Encoder runs on all ranks (PP=1) and its TP/DP collectives require all ranks +to participate simultaneously. The 1F1B pipeline staggers ranks across PP stages, +so encoder collectives cannot run inside the pipeline. The three-phase design +separates encoder (synchronized) from LLM (pipelined) by detaching the autograd +graph at the encoder-LLM boundary. +""" + +from functools import partial +from typing import Optional + +import torch +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.pipeline_parallel import schedules + + +def colocated_forward_backward_with_pp( + mimo_model, + data_iterator, + num_microbatches: int, + encoder_grid: Optional[HyperCommGrid] = None, + llm_grid: Optional[HyperCommGrid] = None, + encoder_name: str = "images", + forward_only: bool = False, + **schedule_kwargs, +): + """Three-phase colocated training: encoder batch -> LLM pipeline -> encoder backward. + + Args: + mimo_model: MimoModel with colocated communicators and lm_has_pp=True. + data_iterator: Yields dicts with input_ids, labels, etc. + num_microbatches: Number of microbatches for the LLM pipeline. + encoder_grid: Encoder HyperCommGrid (for DP fan-in slicing). + llm_grid: LLM HyperCommGrid (for PP group). + encoder_name: Modality name for the encoder (e.g., "images"). + forward_only: Skip backward passes if True. + **schedule_kwargs: Passed to forward_backward_pipelining_without_interleaving. + Must include p2p_communicator, pg_collection, seq_length, micro_batch_size. + """ + pp_group = llm_grid.get_pg("pp") if llm_grid and 'pp' in llm_grid.dim_names else None + is_pp_first = pp_group is None or pp_group.rank() == 0 + + # ── Phase 1: Encoder forward on full batch (one pass) ──────────────── + # All ranks participate (encoder is PP=1, communicate is collective). + all_batches = [next(data_iterator) for _ in range(num_microbatches)] + full_encoder_input = _concat_encoder_inputs(all_batches, encoder_name) + _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid) + + enc_out = mimo_model.encode_and_communicate({encoder_name: full_encoder_input}) + + # Detach: sever autograd link to encoder so Phase 2 has no encoder collectives. + # Microbatch slices are views into detached_full — their .grad accumulates + # into detached_full.grad automatically via PyTorch's view gradient semantics. + detached_full = {k: v.detach().requires_grad_(True) for k, v in enc_out.items()} + lm_data = _build_lm_microbatches(detached_full, all_batches, num_microbatches) + + # ── Phase 2: LLM 1F1B pipeline ────────────────────────────────────── + # Only LLM P2P communication (within PP group). No encoder collectives. + cache_iter = iter(lm_data) + + def _lm_forward_step(data_iterator_unused, model, *args): + cached = next(cache_iter) + output_tensor, loss_mask = model( + input_ids=cached['input_ids'], + labels=cached['labels'], + loss_mask=cached['loss_mask'], + position_ids=cached['position_ids'], + encoder_embeddings=cached['encoder_embeddings'], + ) + return output_tensor, partial(_loss_func, cached['loss_mask']) + + losses = schedules.forward_backward_pipelining_without_interleaving( + forward_step_func=_lm_forward_step, + data_iterator=cache_iter, + model=[mimo_model], + num_microbatches=num_microbatches, + forward_only=forward_only, + **schedule_kwargs, + ) + + # ── Phase 3: Encoder backward (one pass, all ranks sync) ──────────── + # detached_full.grad was populated by Phase 2's per-microbatch LLM backward + # (accumulated across microbatch view slices on PP stage 0). + # Broadcast to PP stage 1+ then run one encoder backward for the full batch. + if not forward_only and enc_out: + _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first) + for key in enc_out: + grad = detached_full[key].grad + if grad is not None: + torch.autograd.backward(enc_out[key], grad_tensors=grad) + + return losses + + +# ── Helpers ────────────────────────────────────────────────────────────── + + +def _concat_encoder_inputs(all_batches, encoder_name): + """Concatenate encoder inputs from all microbatches along batch dim (dim 1).""" + first = all_batches[0] + result = {} + if not (first.get('modality_inputs') and encoder_name in first['modality_inputs']): + return result + for enc_name in first['modality_inputs'][encoder_name]: + result[enc_name] = {} + for key in first['modality_inputs'][encoder_name][enc_name]: + vals = [ + b['modality_inputs'][encoder_name][enc_name][key] + for b in all_batches + if b.get('modality_inputs') and encoder_name in b['modality_inputs'] + ] + tensors = [v for v in vals if isinstance(v, torch.Tensor)] + result[enc_name][key] = torch.cat(tensors, dim=1) if tensors else vals[0] + return result + + +def _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid): + """Slice concatenated encoder input for fan-in (enc_dp > llm_dp).""" + if encoder_grid is None or llm_grid is None: + return + enc_dp = encoder_grid.get_pg("dp").size() + llm_dp = llm_grid.get_pg("dp").size() + if enc_dp <= llm_dp: + return + scale = enc_dp // llm_dp + slot = encoder_grid.get_pg("dp").rank() % scale + for enc_name in full_encoder_input: + for key, tensor in full_encoder_input[enc_name].items(): + if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2: + bs = tensor.shape[1] + ss = bs // scale + if ss == 0: + raise ValueError( + f"Encoder fan-in produces zero-sized batch: " + f"total_batch={bs}, scale={scale}. Increase micro_batch_size." + ) + full_encoder_input[enc_name][key] = tensor[ + :, slot * ss : (slot + 1) * ss, : + ].contiguous() + + +def _build_lm_microbatches(detached_full, all_batches, num_microbatches): + """Slice detached encoder output into per-microbatch views for the LLM pipeline.""" + if not detached_full: + # Text-only batch: no encoder embeddings to slice + return [ + { + 'encoder_embeddings': {}, + 'input_ids': all_batches[mb].get('input_ids'), + 'labels': all_batches[mb].get('labels'), + 'loss_mask': all_batches[mb].get('loss_mask'), + 'position_ids': all_batches[mb].get('position_ids'), + } + for mb in range(num_microbatches) + ] + + sample = next(iter(detached_full.values())) + batch_dim = 1 if sample.ndim == 3 else 0 + total_batch = sample.shape[batch_dim] + assert total_batch % num_microbatches == 0, ( + f"Encoder output batch ({total_batch}) must be divisible " + f"by num_microbatches ({num_microbatches})" + ) + mb_size = total_batch // num_microbatches + + lm_data = [] + for mb in range(num_microbatches): + s, e = mb * mb_size, (mb + 1) * mb_size + mb_enc = {} + for k, v in detached_full.items(): + mb_enc[k] = v[:, s:e, :] if v.ndim == 3 else v[s:e, :] + lm_data.append( + { + 'encoder_embeddings': mb_enc, + 'input_ids': all_batches[mb].get('input_ids'), + 'labels': all_batches[mb].get('labels'), + 'loss_mask': all_batches[mb].get('loss_mask'), + 'position_ids': all_batches[mb].get('position_ids'), + } + ) + return lm_data + + +def _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first): + """Broadcast encoder gradient from PP stage 0 to stage 1+ ranks.""" + if pp_group is None or pp_group.size() <= 1: + return + src = dist.get_global_rank(pp_group, 0) + for key in enc_out: + if is_pp_first: + assert ( + detached_full[key].grad is not None + ), f"No encoder gradient on PP stage 0 for '{key}'" + dist.broadcast(detached_full[key].grad, src=src, group=pp_group) + else: + grad = torch.zeros_like(detached_full[key]) + dist.broadcast(grad, src=src, group=pp_group) + detached_full[key].grad = grad + + +def _loss_func(loss_mask, output_tensor): + """Default loss function for the LLM pipeline.""" + if output_tensor is None: + return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} + loss = output_tensor.float().sum() + return loss, {'loss_reduced': loss.detach().item()} diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 703bc9d9950..95a1651ef5c 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -67,6 +67,26 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # in TP/DP within those ranks. self._build_colocated_communicators() + # Detect LLM PP>1 for three-phase colocated execution + self.lm_has_pp = False + self.lm_is_first_pp_stage = True + if mimo_config.module_to_grid_map: + lang_grid = mimo_config.module_to_grid_map.get(MIMO_LANGUAGE_MODULE_KEY) + if lang_grid and 'pp' in lang_grid.dim_names: + pp_idx = lang_grid.dim_names.index('pp') + if lang_grid.shape[pp_idx] > 1: + self.lm_has_pp = True + pp_group = lang_grid.get_pg('pp') + pp_rank = pp_group.rank() + pp_size = pp_group.size() + self.lm_is_first_pp_stage = pp_rank == 0 + # Update language module stage info for PP>1 + from megatron.core.models.mimo.config.role import ModuleStageInfo + + self.role.modules[MIMO_LANGUAGE_MODULE_KEY] = ModuleStageInfo( + is_first_stage=(pp_rank == 0), is_last_stage=(pp_rank == pp_size - 1) + ) + # Use special token IDs from the config self.special_token_ids = ( mimo_config.special_token_ids.copy() if mimo_config.special_token_ids else {} @@ -318,6 +338,7 @@ def forward( labels: Optional[torch.Tensor] = None, modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None, packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass through the multimodal model. @@ -362,6 +383,20 @@ def forward( input_tensors = getattr(self, 'input_tensors', None) if self.role.mode == ModuleLayout.COLOCATED: + if self.lm_has_pp and input_tensors is not None: + # PP>1 non-first stage: hidden states from P2P + lm_result = self._forward_language_module( + input_ids, + position_ids, + attention_mask, + labels, + {MIMO_LANGUAGE_MODULE_KEY: input_tensors}, + ) + # Unwrap dict for P2P (schedule uses plain tensors, not dicts) + if isinstance(lm_result, dict): + lm_result = lm_result[MIMO_LANGUAGE_MODULE_KEY] + return lm_result, loss_mask + return self._forward_all_modules( input_ids, position_ids, @@ -370,6 +405,7 @@ def forward( labels, modality_inputs, packing_kwargs, + encoder_embeddings=encoder_embeddings, ) if self.role.mode == ModuleLayout.NON_COLOCATED: @@ -531,6 +567,22 @@ def _apply_colocated_comms(self, modality_embeddings): ) return modality_embeddings + def encode_and_communicate(self, modality_inputs): + """Run encoder forward + colocated TP/DP transform (collective).""" + modality_embeddings = {} + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) + return modality_embeddings + def _forward_all_modules( self, input_ids: torch.Tensor, @@ -540,6 +592,7 @@ def _forward_all_modules( labels: Optional[torch.Tensor], modality_inputs: Optional[Dict[str, Dict[str, Any]]], packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass when all modules are on all ranks (no multi-module PP). @@ -556,26 +609,32 @@ def _forward_all_modules( packed_seq_params.qkv_format = 'thd' logger.debug(f"Packed sequence parameters: {packed_seq_params}") - # 1. Process each modality to get embeddings - modality_embeddings = {} - - for modality_name, submodule in self.modality_submodules.items(): - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - modality_embeddings[modality_name] = embeddings - logger.debug( - f"Generated embeddings for {modality_name} with shape {embeddings.shape}" - ) + if encoder_embeddings is not None: + # PP>1 path: encoder forward + communicate already ran in Phase 1; + # reuse the precomputed embeddings for every LLM microbatch. + modality_embeddings = encoder_embeddings + else: + # 1. Process each modality to get embeddings + modality_embeddings = {} - # Apply colocated communication if configured (no-op when colocated_comms is empty) - if self.colocated_comms: - modality_embeddings = self._apply_colocated_comms(modality_embeddings) + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + logger.debug(f"Processing {modality_name} modality") + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + logger.debug( + f"Generated embeddings for {modality_name} with shape " + f"{embeddings.shape}" + ) + + # Apply colocated communication if configured (no-op when colocated_comms is empty) + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) diff --git a/tests/unit_tests/models/test_mimo_colocated_pp.py b/tests/unit_tests/models/test_mimo_colocated_pp.py new file mode 100644 index 00000000000..5aacdcad0d9 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_colocated_pp.py @@ -0,0 +1,499 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""Tests for colocated MIMO training with LLM PP>1. + +Uses two-phase execution: encoder pre-compute + 1F1B LLM pipeline. + +Run individually (8 GPUs): + uv run python -m torch.distributed.run --nproc_per_node=8 \ + -m pytest tests/unit_tests/models/test_mimo_colocated_pp.py -v +""" + +import logging +from contextlib import ExitStack, contextmanager +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp +from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.optimizer import get_mimo_optimizer +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) +except ImportError: + TEColumnParallelLinear = None + TERowParallelLinear = None + +logger = logging.getLogger(__name__) + +_active_grids: list = [] +_embedding_pg_cache: dict = {} + + +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): + grid = HyperCommGrid( + shape=[tp, cp, pp, dp, 1, 1], + dim_names=["tp", "cp", "pp", "dp", "ep", "expt_dp"], + rank_offset=offset, + backend="nccl", + ) + for dim in ["tp", "cp", "pp", "dp", "ep", "expt_dp"]: + grid.create_pg([dim]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["tp", "pp"]) + grid.create_pg(["tp", "ep", "pp"]) + grid.create_pg(["dp", "ep"]) + grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) + _active_grids.append(grid) + return grid + + +def destroy_all_grids(): + for g in _active_grids: + g.destroy() + _active_grids.clear() + _embedding_pg_cache.clear() + BridgeCommunicator.destroy_broadcast_pgs() + + +def create_all_embedding_groups(grids): + for grid in grids: + pp_group = grid.get_pg("pp") + if not pp_group: + continue + pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) + key = tuple(pp_ranks) + if key not in _embedding_pg_cache: + pos = [pp_ranks[0]] + embd = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd.append(pp_ranks[-1]) + _embedding_pg_cache[key] = (dist.new_group(ranks=pos), dist.new_group(ranks=embd)) + + +def get_pg_collection(grid, is_language_model=False): + pg = ProcessGroupCollection() + pg.tp = grid.get_pg("tp") + pg.cp = grid.get_pg("cp") + pg.pp = grid.get_pg("pp") + pg.ep = grid.get_pg("ep") + pg.dp = grid.get_pg("dp") + pg.dp_cp = grid.get_pg(["dp", "cp"]) + pg.expt_dp = grid.get_pg("expt_dp") + pp_ranks = sorted(dist.get_process_group_ranks(pg.pp)) + key = tuple(pp_ranks) + if key in _embedding_pg_cache: + pos_pg, embd_pg = _embedding_pg_cache[key] + pg.pos_embd = pos_pg if is_pp_first_stage(pg.pp) else None + pg.embd = ( + embd_pg + if is_language_model and (is_pp_last_stage(pg.pp) or is_pp_first_stage(pg.pp)) + else None + ) + return pg + + +def get_language_model_spec( + num_layers, hidden_size, num_attention_heads, vocab_size, seq_len, pg_collection +): + pp_rank = dist.get_rank(pg_collection.pp) + pp_size = dist.get_world_size(pg_collection.pp) + tp_size = pg_collection.tp.size() if pg_collection.tp else 1 + lm_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16, + bf16=True, + cross_entropy_loss_fusion=True, + cross_entropy_fusion_impl='te', + ) + return ModuleSpec( + module=GPTModel, + params={ + "config": lm_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + "pre_process": (pp_rank == 0), + "post_process": (pp_rank == pp_size - 1), + "pg_collection": pg_collection, + }, + ) + + +def get_vision_submodules_spec( + num_layers, hidden_size, num_attention_heads, language_hidden_size, pg_collection +): + from megatron.core.transformer.transformer_block import TransformerBlock + + tp_size = pg_collection.tp.size() if pg_collection.tp else 1 + vision_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=1, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + proj_cfg = TransformerConfig( + num_layers=1, hidden_size=language_hidden_size, num_attention_heads=1 + ) + proj_cfg.ffn_hidden_size = language_hidden_size + proj_cfg.bias_activation_fusion = True + proj_cfg.add_bias_linear = True + proj_cfg.activation_func = torch.nn.functional.gelu + + return ModuleSpec( + module=VisionModalitySubmodules, + submodules={ + "encoders": { + "clip_encoder": ModuleSpec( + module=TransformerBlock, + params={ + "config": vision_config, + "spec": get_gpt_layer_with_transformer_engine_spec(), + "pg_collection": pg_collection, + "pre_process": True, + "post_process": True, + }, + ) + }, + "input_projections": [ + ModuleSpec( + module=MultimodalProjector, + params={ + "config": proj_cfg, + "submodules": MLPSubmodules( + linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + "projector_type": "mlp", + "input_size": vision_config.hidden_size, + "tp_group": pg_collection.tp, + }, + ) + ], + }, + ) + + +class DataIterator: + def __init__( + self, + hidden_size, + seq_length, + micro_batch_size, + vocab_size, + encoder_name, + image_token_id=50257, + image_seq_length=None, + ): + self.hidden_size = hidden_size + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.vocab_size = vocab_size + self.encoder_name = encoder_name + self.image_token_id = image_token_id + self.image_seq_length = image_seq_length or (seq_length // 2) + + def __iter__(self): + return self + + def __next__(self): + encoder_hidden_states = torch.randn( + self.image_seq_length, + self.micro_batch_size, + self.hidden_size, + device='cuda', + dtype=torch.bfloat16, + ) + image_tokens = torch.full( + (self.micro_batch_size, self.image_seq_length), + self.image_token_id, + dtype=torch.long, + device='cuda', + ) + text_tokens = torch.randint( + 1, + self.vocab_size, + (self.micro_batch_size, self.seq_length - self.image_seq_length), + device='cuda', + ) + input_ids = torch.cat([image_tokens, text_tokens], dim=1) + labels = input_ids.clone() + labels[input_ids == self.image_token_id] = -100 + loss_mask = (input_ids != self.image_token_id).float() + position_ids = ( + torch.arange(self.seq_length, device='cuda') + .unsqueeze(0) + .expand(self.micro_batch_size, -1) + .clone() + ) + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "modality_inputs": { + self.encoder_name: { + "clip_encoder": {'hidden_states': encoder_hidden_states, 'attention_mask': None} + } + }, + } + + +def run_colocated_pp_test( + encoder_tp, + encoder_dp, + llm_tp, + llm_pp, + llm_dp, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, +): + """Run colocated MIMO with encoder PP=1 + LLM PP>1.""" + import os + + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + + encoder_name = "images" + + encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp) + llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) + create_all_embedding_groups([encoder_grid, llm_grid]) + torch.manual_seed(12345) + + vision_pg = get_pg_collection(encoder_grid, is_language_model=False) + language_pg = get_pg_collection(llm_grid, is_language_model=True) + + language_model_spec = get_language_model_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + vocab_size=vocab_size, + seq_len=seq_length, + pg_collection=language_pg, + ) + vision_submodule_spec = get_vision_submodules_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + language_hidden_size=hidden_size, + pg_collection=vision_pg, + ) + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={encoder_name: vision_submodule_spec}, + special_token_ids={encoder_name: 50257}, + module_to_grid_map={encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid}, + ) + + mimo_model = MimoModel(mimo_config) + mimo_model.to(torch.device("cuda")).to(torch.bfloat16) + mimo_model.model_type = ModelType.encoder_or_decoder + + # Wrap with DDP (per-module process groups) + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=False, bucket_size=10000, use_distributed_optimizer=True + ) + if mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=language_pg, + ) + if encoder_name in mimo_model.modality_submodules: + submodule = mimo_model.modality_submodules[encoder_name] + if submodule is not None: + mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( + config=submodule.encoders['clip_encoder'].config, + ddp_config=ddp_config, + module=submodule, + pg_collection=vision_pg, + ) + + @contextmanager + def no_sync_func(): + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for sub in mimo_model.modality_submodules.values(): + if sub is not None: + stack.enter_context(sub.no_sync()) + yield + + def finalize_grads_func(*args, **kwargs): + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], num_tokens=None, pg_collection=language_pg + ) + for sub in mimo_model.modality_submodules.values(): + if sub is not None: + finalize_model_grads([sub], num_tokens=None, pg_collection=vision_pg) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + optimizer = get_mimo_optimizer(mimo_model, opt_config) + + data_iterator = DataIterator( + hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name + ) + lm_pp_group = llm_grid.get_pg("pp") + + rank = dist.get_rank() + num_iterations = 2 + all_losses = [] + optimizer.zero_grad() + + for iteration in range(num_iterations): + losses = colocated_forward_backward_with_pp( + mimo_model=mimo_model, + data_iterator=data_iterator, + num_microbatches=num_microbatches, + encoder_grid=encoder_grid, + llm_grid=llm_grid, + encoder_name=encoder_name, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + p2p_communicator=P2PCommunicator(pp_group=lm_pp_group, config=mimo_model.config), + pg_collection=language_pg, + ) + + success, grad_norm, _ = optimizer.step() + assert success, f"Rank {rank}: Optimizer step failed at iteration {iteration}" + optimizer.zero_grad() + + all_losses.extend(losses or []) + logger.info(f"Rank {rank}: iteration {iteration} done, losses={len(losses or [])}") + + # Verify on last PP stage + if is_pp_last_stage(lm_pp_group): + assert len(all_losses) > 0, f"Rank {rank}: No losses on last stage" + for i, loss_dict in enumerate(all_losses): + loss_val = loss_dict.get('loss_reduced', 0) + if isinstance(loss_val, torch.Tensor): + loss_val = loss_val.item() + assert loss_val == loss_val, f"Rank {rank}: NaN loss at mb {i}" + assert abs(loss_val) != float('inf'), f"Rank {rank}: Inf loss at mb {i}" + + return all_losses + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh requires PyTorch 2.3+", +) +class TestMimoColocatedPP: + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def teardown_method(self): + destroy_all_grids() + + def test_fan_in_enc_tp2_dp4_llm_tp2_dp2_pp2(self): + """Fan-in: encoder TP2/DP4 → LLM TP2/DP2/PP2.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_pp_test( + encoder_tp=2, encoder_dp=4, llm_tp=2, llm_pp=2, llm_dp=2, num_microbatches=4 + ) + + def test_equal_dp_enc_tp4_dp2_llm_tp2_dp2_pp2(self): + """Equal DP: encoder TP4/DP2 → LLM TP2/DP2/PP2 (enc_dp == llm_dp).""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_pp_test( + encoder_tp=4, encoder_dp=2, llm_tp=2, llm_pp=2, llm_dp=2, num_microbatches=4 + ) + + def test_fan_in_with_grad_acc(self): + """Fan-in with gradient accumulation (num_microbatches > pp_size).""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_pp_test( + encoder_tp=2, + encoder_dp=4, + llm_tp=2, + llm_pp=2, + llm_dp=2, + num_microbatches=6, # > pp_size=2, tests grad accumulation + ) + + def test_fan_in_enc_tp1_dp8_llm_tp4_dp1_pp2(self): + """Fan-in extreme: encoder TP1/DP8 → LLM TP4/DP1/PP2.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + # micro_batch_size must be >= fan-in scale (enc_dp/llm_dp = 8/1 = 8) + # to avoid zero-sized slices in _slice_for_encoder_dp. + run_colocated_pp_test( + encoder_tp=1, + encoder_dp=8, + llm_tp=4, + llm_pp=2, + llm_dp=1, + micro_batch_size=8, + num_microbatches=4, + ) diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index 3ab65a3616e..325f8cb305e 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -453,8 +453,14 @@ def __init__(self, rank_offset=0, size=1, dim_names=None, pp_rank=0, pp_size=1): self.rank_offset = rank_offset self.size = size self.dim_names = dim_names or [] + self._pp_rank = pp_rank + self._pp_size = pp_size self._pp_group = MockProcessGroup(pp_rank, pp_size) + @property + def shape(self): + return tuple(self._pp_size if d == "pp" else 1 for d in self.dim_names) + def get_pg(self, dims): if dims == "pp": return self._pp_group From faa946754707ea301655c5b2455437b23b200a29 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 16:07:27 +0000 Subject: [PATCH 02/11] Fix encoder DP grad sync and add PP weight oracle (NMFW-19) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The three-phase colocated schedule used to let the inner 1F1B PP schedule invoke ``config.finalize_model_grads_func`` at end-of-schedule. At that point the encoder has zero grads (Phase 3 has not run), so its DDP ``finish_grad_sync`` all-reduces zeros and the subsequent Phase 3 encoder grads stay local to each rank — Adam then steps on un-reduced encoder grads and diverges from an equal-DP reference. ``finish_grad_sync`` is not idempotent, so a second post-Phase-3 call is unsafe; instead, swap in a no-op during the PP schedule and invoke the user-provided finalize once after Phase 3 so a single DP reduction covers both LLM (Phase 2) and encoder (Phase 3) grads. Tests: extend ``test_mimo_colocated_pp.py`` with a real post-step weight oracle that compares the PP>1 dist run against an equal-DP PP=1 reference (identity bridge). Adds PP-aware LLM weight reshaping so both models start from identical state, runs one Adam step on each via their respective schedules, then asserts shard-wise encoder equality within bf16 tolerance. Parametrized for ``num_mb == pp`` and ``num_mb > pp`` to cover single-microbatch-per-stage and grad-accumulation-across-views cases. Existing smoke tests also gain ``grad_norm > 0`` assertions and params-changed snapshots to catch silently-zeroed encoder grads. Co-Authored-By: Claude Opus 4.7 --- .../core/models/mimo/colocated_schedule.py | 53 +- .../models/test_mimo_colocated_pp.py | 660 +++++++++++++++++- 2 files changed, 704 insertions(+), 9 deletions(-) diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py index 8cf96ded6c6..3e24be10097 100644 --- a/megatron/core/models/mimo/colocated_schedule.py +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -78,14 +78,30 @@ def _lm_forward_step(data_iterator_unused, model, *args): ) return output_tensor, partial(_loss_func, cached['loss_mask']) - losses = schedules.forward_backward_pipelining_without_interleaving( - forward_step_func=_lm_forward_step, - data_iterator=cache_iter, - model=[mimo_model], - num_microbatches=num_microbatches, - forward_only=forward_only, - **schedule_kwargs, - ) + # Defer finalize until AFTER Phase 3. The inner PP schedule would call + # ``config.finalize_model_grads_func`` at end-of-schedule, which runs + # DDP ``finish_grad_sync`` on both the LLM and the encoder. At that + # point the encoder has zero grads (Phase 3 has not run yet), so its + # DP all-reduce operates on zeros and the Phase 3 grads that follow + # are never synced — ``finish_grad_sync`` is not safe to call twice + # (it asserts on the outstanding async handle and has no idempotency + # guarantee). We swap in a no-op so the schedule proceeds normally, + # then invoke the original finalize once after Phase 3 so the single + # DP reduction covers both the LLM grads from Phase 2 and the encoder + # grads from Phase 3. + original_finalize = mimo_model.config.finalize_model_grads_func + mimo_model.config.finalize_model_grads_func = _noop_finalize + try: + losses = schedules.forward_backward_pipelining_without_interleaving( + forward_step_func=_lm_forward_step, + data_iterator=cache_iter, + model=[mimo_model], + num_microbatches=num_microbatches, + forward_only=forward_only, + **schedule_kwargs, + ) + finally: + mimo_model.config.finalize_model_grads_func = original_finalize # ── Phase 3: Encoder backward (one pass, all ranks sync) ──────────── # detached_full.grad was populated by Phase 2's per-microbatch LLM backward @@ -98,6 +114,18 @@ def _lm_forward_step(data_iterator_unused, model, *args): if grad is not None: torch.autograd.backward(enc_out[key], grad_tensors=grad) + # Single post-Phase-3 finalize: reduces LLM grads (from Phase 2) and + # encoder grads (from Phase 3) together. Without this call, encoder + # grads remain local to each rank and Adam steps on un-reduced grads, + # causing silent divergence from the equal-DP reference. + if not forward_only and original_finalize is not None: + original_finalize( + [mimo_model], + None, + pg_collection=schedule_kwargs.get('pg_collection'), + force_all_reduce=False, + ) + return losses @@ -213,3 +241,12 @@ def _loss_func(loss_mask, output_tensor): return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} loss = output_tensor.float().sum() return loss, {'loss_reduced': loss.detach().item()} + + +def _noop_finalize(*args, **kwargs): + """Placeholder used to suppress the inner PP schedule's finalize call. + + The three-phase schedule needs to defer grad finalization until after + Phase 3 runs the encoder backward. See ``colocated_forward_backward_with_pp``. + """ + return None diff --git a/tests/unit_tests/models/test_mimo_colocated_pp.py b/tests/unit_tests/models/test_mimo_colocated_pp.py index 5aacdcad0d9..f74b930826b 100644 --- a/tests/unit_tests/models/test_mimo_colocated_pp.py +++ b/tests/unit_tests/models/test_mimo_colocated_pp.py @@ -290,7 +290,18 @@ def run_colocated_pp_test( micro_batch_size=2, num_microbatches=4, ): - """Run colocated MIMO with encoder PP=1 + LLM PP>1.""" + """Run colocated MIMO with encoder PP=1 + LLM PP>1. + + Beyond "loss is finite", this helper verifies: + * ``optimizer.step`` returns grad_norm > 0 — catches silently-zeroed + encoder grads (e.g. broadcast never populating detached_full.grad + on non-first PP stages). + * Encoder params' data changed after the step — catches the case + where grads flow but the update is a no-op (wrong PG, wrong + device, clipping to zero). + * LLM params' data changed on every PP stage — catches the case + where the pipeline runs but a PP stage's grads never backprop. + """ import os os.environ.pop('NVTE_FLASH_ATTN', None) @@ -402,6 +413,27 @@ def finalize_grads_func(*args, **kwargs): all_losses = [] optimizer.zero_grad() + # Snapshot initial params to verify the step actually moves them. + # A silently-zeroed encoder grad (e.g. PP>1 grad broadcast missing) would + # leave these unchanged despite grad_norm appearing nonzero. + encoder_module = ( + mimo_model.modality_submodules[encoder_name].module + if encoder_name in mimo_model.modality_submodules + and mimo_model.modality_submodules[encoder_name] is not None + else None + ) + llm_module = mimo_model.language_model.module if mimo_model.language_model is not None else None + initial_encoder_params = ( + {n: p.detach().clone() for n, p in encoder_module.named_parameters()} + if encoder_module is not None + else {} + ) + initial_llm_params = ( + {n: p.detach().clone() for n, p in llm_module.named_parameters()} + if llm_module is not None + else {} + ) + for iteration in range(num_iterations): losses = colocated_forward_backward_with_pp( mimo_model=mimo_model, @@ -418,6 +450,13 @@ def finalize_grads_func(*args, **kwargs): success, grad_norm, _ = optimizer.step() assert success, f"Rank {rank}: Optimizer step failed at iteration {iteration}" + # grad_norm must be strictly positive: zero means every tracked param + # had zero grad, which indicates the schedule never wired a usable + # gradient into the param.grad buffers. + assert grad_norm is not None and grad_norm > 0, ( + f"Rank {rank}: grad_norm={grad_norm} at iter {iteration} — encoder or " + f"LLM grads were silently zeroed (did Phase 3 broadcast/backward run?)" + ) optimizer.zero_grad() all_losses.extend(losses or []) @@ -433,9 +472,565 @@ def finalize_grads_func(*args, **kwargs): assert loss_val == loss_val, f"Rank {rank}: NaN loss at mb {i}" assert abs(loss_val) != float('inf'), f"Rank {rank}: Inf loss at mb {i}" + # Oracle: at least one param of each module's shard must have changed. + # Under correct three-phase execution, every encoder rank accumulates a + # nonzero DP=1 gradient (via Phase 3 backward from the broadcast grad), + # and every LLM PP stage accumulates nonzero grads from the 1F1B pass. + # A blanket "all shards unchanged" outcome means the optimizer step was + # effectively a no-op for that module on this rank. + if encoder_module is not None: + changed = any( + not torch.equal(p.detach(), initial_encoder_params[n]) + for n, p in encoder_module.named_parameters() + if n in initial_encoder_params + ) + assert changed, ( + f"Rank {rank}: no encoder params changed after {num_iterations} steps — " + f"Phase 3 encoder backward likely did not populate grads on this rank" + ) + if llm_module is not None: + changed = any( + not torch.equal(p.detach(), initial_llm_params[n]) + for n, p in llm_module.named_parameters() + if n in initial_llm_params + ) + assert changed, ( + f"Rank {rank}: no LLM params changed after {num_iterations} steps — " + f"PP stage {dist.get_rank(lm_pp_group)} may have received no gradient" + ) + return all_losses +# --------------------------------------------------------------------------- +# Weight-oracle helpers: dist (PP>1, heterogeneous) vs ref (PP=1, equal-DP). +# --------------------------------------------------------------------------- + + +def _build_pp_oracle_model( + encoder_tp, + encoder_dp, + llm_tp, + llm_pp, + llm_dp, + hidden_size, + num_layers, + vocab_size, + seq_length, + ddp_config, + encoder_name="images", +): + """Build a MimoModel + DDP wrap for the weight-oracle test. Returns the + model plus its encoder_grid/llm_grid and pg_collections. Mirrors the + setup in ``run_colocated_pp_test`` but accepts an explicit ``ddp_config`` + so both dist and ref can share ``gradient_reduce_div_factor=1``. + """ + encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp) + llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) + create_all_embedding_groups([encoder_grid, llm_grid]) + + vision_pg = get_pg_collection(encoder_grid, is_language_model=False) + language_pg = get_pg_collection(llm_grid, is_language_model=True) + + language_model_spec = get_language_model_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + vocab_size=vocab_size, + seq_len=seq_length, + pg_collection=language_pg, + ) + vision_submodule_spec = get_vision_submodules_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + language_hidden_size=hidden_size, + pg_collection=vision_pg, + ) + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={encoder_name: vision_submodule_spec}, + special_token_ids={encoder_name: 50257}, + module_to_grid_map={encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid}, + ) + + mimo_model = MimoModel(mimo_config) + mimo_model.to(torch.device("cuda")).to(torch.bfloat16) + mimo_model.model_type = ModelType.encoder_or_decoder + + if mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=language_pg, + ) + if encoder_name in mimo_model.modality_submodules: + submodule = mimo_model.modality_submodules[encoder_name] + if submodule is not None: + mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( + config=submodule.encoders['clip_encoder'].config, + ddp_config=ddp_config, + module=submodule, + pg_collection=vision_pg, + ) + + @contextmanager + def no_sync_func(): + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for sub in mimo_model.modality_submodules.values(): + if sub is not None: + stack.enter_context(sub.no_sync()) + yield + + def finalize_grads_func(*args, **kwargs): + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], num_tokens=None, pg_collection=language_pg + ) + for sub in mimo_model.modality_submodules.values(): + if sub is not None: + finalize_model_grads([sub], num_tokens=None, pg_collection=vision_pg) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + return mimo_model, encoder_grid, llm_grid, language_pg, vision_pg + + +def _generate_shared_global_batches( + num_batches, + global_batch_size, + seq_length, + hidden_size, + vocab_size, + encoder_name, + image_token_id=50257, +): + """Generate global batches on rank 0 and broadcast so every rank sees + identical data. Encoder input shape is [seq, batch, hidden] (sbh), + matching ``DataIterator`` above. + """ + rank = dist.get_rank() + image_seq_length = seq_length // 2 + batches = [] + for _ in range(num_batches): + if rank == 0: + encoder_hidden_states = torch.randn( + image_seq_length, + global_batch_size, + hidden_size, + device='cuda', + dtype=torch.bfloat16, + ) + image_tokens = torch.full( + (global_batch_size, image_seq_length), + image_token_id, + dtype=torch.long, + device='cuda', + ) + text_tokens = torch.randint( + 1, + vocab_size, + (global_batch_size, seq_length - image_seq_length), + device='cuda', + ) + input_ids = torch.cat([image_tokens, text_tokens], dim=1) + else: + encoder_hidden_states = torch.empty( + image_seq_length, + global_batch_size, + hidden_size, + device='cuda', + dtype=torch.bfloat16, + ) + input_ids = torch.empty( + global_batch_size, seq_length, dtype=torch.long, device='cuda' + ) + dist.broadcast(encoder_hidden_states, src=0) + dist.broadcast(input_ids, src=0) + + labels = input_ids.clone() + labels[input_ids == image_token_id] = -100 + loss_mask = (input_ids != image_token_id).float() + position_ids = ( + torch.arange(seq_length, device='cuda') + .unsqueeze(0) + .expand(global_batch_size, -1) + .clone() + ) + batches.append( + { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "modality_inputs": { + encoder_name: { + "clip_encoder": { + 'hidden_states': encoder_hidden_states, + 'attention_mask': None, + } + } + }, + } + ) + return batches + + +def _slice_batch_along_dim0(batch, split, idx): + """Return ``idx``-th of ``split`` equal slices along the batch dim.""" + b = batch['input_ids'].shape[0] + size = b // split + s, e = idx * size, (idx + 1) * size + out = {k: batch[k][s:e].contiguous() for k in ['input_ids', 'labels', 'loss_mask', 'position_ids']} + mod_new = {} + for m, md in batch['modality_inputs'].items(): + mod_new[m] = {} + for enc, ed in md.items(): + mod_new[m][enc] = {} + for k, t in ed.items(): + if isinstance(t, torch.Tensor): + # modality hidden_states shape [seq, batch, hidden] — dim 1 + mod_new[m][enc][k] = t[:, s:e, :].contiguous() + else: + mod_new[m][enc][k] = t + out['modality_inputs'] = mod_new + return out + + +def _copy_encoder_params(ref_module, dist_module): + """Copy encoder params ref → dist. Encoder layouts match by construction + (same enc_tp and enc_dp in both models), so shards line up 1:1. + """ + ref_params = dict(ref_module.named_parameters()) + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + assert name in ref_params, f"Encoder param '{name}' missing in ref" + ref_param = ref_params[name] + assert ref_param.shape == dist_param.shape, ( + f"Encoder param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)} — enc_tp/enc_dp must match " + f"between ref and dist for shard-wise comparison." + ) + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) + + +def _copy_llm_params_pp_aware( + ref_module, dist_module, pp_rank, pp_size, num_layers, dist_tp_group, ref_tp_group +): + """Copy LLM params ref (PP=1) → dist (PP>=1) with layer-index remapping. + + In Megatron's ``TransformerBlock``, ``self.layers`` is a ``ModuleList`` + indexed 0..N-1 *locally* per PP stage. The global layer number is + ``local_idx + pp_rank * layers_per_stage``. For a PP=1 reference, all + N layers live at local indices 0..N-1 on each rank; for a PP>1 dist + model, PP stage ``s``'s local layer ``i`` corresponds to ref's global + layer ``s*layers_per_stage + i``. + + Non-layer params (embedding, final_layernorm, output_layer) are only + present on stages with the relevant ``pre_process``/``post_process`` + flag, and their names match exactly between ref (which has them all) + and whichever dist stage owns them. + + If dist_tp != ref_tp the helper falls back to the PR-10 pattern of + gather-across-ref-tp + slice-for-dist-tp. Same-TP is the normal path + (this helper is designed for tests where ``dist_llm_tp == ref_llm_tp``, + so the gather path is a no-op fallback). + """ + import re + + assert num_layers % pp_size == 0, ( + f"num_layers={num_layers} not divisible by pp_size={pp_size}; " + f"oracle requires even PP split." + ) + layers_per_stage = num_layers // pp_size + layer_rx = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') + + ref_params = dict(ref_module.named_parameters()) + ref_tp_size = dist.get_world_size(ref_tp_group) + dist_tp_rank = dist.get_rank(dist_tp_group) + dist_tp_size = dist.get_world_size(dist_tp_group) + + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + m = layer_rx.match(name) + if m: + prefix, local_idx_s, suffix = m.groups() + global_idx = pp_rank * layers_per_stage + int(local_idx_s) + ref_name = f"{prefix}{global_idx}{suffix}" + else: + ref_name = name + + assert ref_name in ref_params, ( + f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " + f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + partition_dim = getattr(dist_param, 'partition_dim', -1) + + if ref_param.shape == dist_param.shape: + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) + continue + + assert partition_dim >= 0, ( + f"LLM param '{name}': shapes differ (ref={tuple(ref_param.shape)}, " + f"dist={tuple(dist_param.shape)}) but partition_dim<0 — cannot reshape " + f"a replicated param." + ) + shards = [torch.empty_like(ref_param.data) for _ in range(ref_tp_size)] + dist.all_gather(shards, ref_param.data.contiguous(), group=ref_tp_group) + full = torch.cat(shards, dim=partition_dim) + sliced = torch.tensor_split(full, dist_tp_size, dim=partition_dim)[dist_tp_rank] + assert sliced.shape == dist_param.shape + dist_param.data.copy_(sliced.to(dist_param.dtype)) + + +def _sum_loss_func(loss_mask_unused, output_tensor): + """Match the ``.sum()`` loss used by ``colocated_schedule._loss_func`` so + the reference's forward_backward_no_pipelining path produces comparable + gradient magnitudes. + """ + if output_tensor is None: + return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} + loss = output_tensor.float().sum() + return loss, {'loss_reduced': loss.detach().item()} + + +def _assert_encoder_shards_match(ref_module, dist_module, rtol=1e-2, atol=1e-2): + """Assert every dist encoder shard matches the ref encoder shard. + + Tolerance accounts for bf16 accumulation-order drift between the ref's + LLM-flat (pp=1) gradient path and the dist's PP>1 1F1B path. Both paths + yield the same DP=1 encoder gradient in exact arithmetic; bf16 rounding + bounds the drift within the tolerance below. + """ + ref_params = dict(ref_module.named_parameters()) + mismatches = [] + for name, dist_param in dist_module.named_parameters(): + ref_param = ref_params[name] + assert ref_param.shape == dist_param.shape + try: + torch.testing.assert_close(dist_param.data, ref_param.data, rtol=rtol, atol=atol) + except AssertionError as e: + mismatches.append((name, str(e))) + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n}: {msg}" for n, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} encoder param(s) diverged between " + f"PP>1 dist and equal-DP PP=1 ref:\n{details}" + ) + + +def _run_pp_weight_oracle( + dist_enc_tp, + dist_enc_dp, + dist_llm_tp, + dist_llm_pp, + dist_llm_dp, + num_microbatches, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size_llm=2, +): + """Drive the dist-vs-ref weight oracle described in + ``test_pp_matches_pp1_equal_dp_reference``. + """ + import os + + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + encoder_name = "images" + + # Equal-DP reference: enc_tp=dist_enc_tp, enc_dp=dist_enc_dp, + # llm_tp=dist_enc_tp (→ same encoder & LLM TP layout), llm_dp=dist_enc_dp, + # llm_pp=1 (identity bridge, only PP value compatible with equal-DP on + # a fixed rank count). + ref_enc_tp, ref_enc_dp = dist_enc_tp, dist_enc_dp + ref_llm_tp, ref_llm_pp, ref_llm_dp = dist_enc_tp, 1, dist_enc_dp + + # Global batch size spans the larger DP side. Dist's LLM DP is smaller + # (fan-in), so each LLM rank holds micro_batch_size_llm samples. + global_batch_size = micro_batch_size_llm * dist_llm_dp + # For ref (equal-DP, llm_dp == enc_dp): per-rank batch = global_batch / enc_dp. + ref_per_rank_mbs = global_batch_size // ref_llm_dp + + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=False, + bucket_size=10000, + use_distributed_optimizer=True, + gradient_reduce_div_factor=1, + ) + + # Build dist first (heterogeneous TP/DP + PP>1). + torch.manual_seed(12345) + dist_model, dist_enc_grid, dist_llm_grid, dist_lang_pg, dist_vis_pg = _build_pp_oracle_model( + encoder_tp=dist_enc_tp, + encoder_dp=dist_enc_dp, + llm_tp=dist_llm_tp, + llm_pp=dist_llm_pp, + llm_dp=dist_llm_dp, + hidden_size=hidden_size, + num_layers=num_layers, + vocab_size=vocab_size, + seq_length=seq_length, + ddp_config=ddp_config, + ) + # Build ref (equal-DP, PP=1). + torch.manual_seed(12345) + ref_model, ref_enc_grid, ref_llm_grid, ref_lang_pg, ref_vis_pg = _build_pp_oracle_model( + encoder_tp=ref_enc_tp, + encoder_dp=ref_enc_dp, + llm_tp=ref_llm_tp, + llm_pp=ref_llm_pp, + llm_dp=ref_llm_dp, + hidden_size=hidden_size, + num_layers=num_layers, + vocab_size=vocab_size, + seq_length=seq_length, + ddp_config=ddp_config, + ) + + # Force identical initial state. Encoder: same TP/DP → shard-wise copy. + # LLM: ref has pp=1 (all layers), dist has pp>=1 (layers split); remap. + _copy_encoder_params( + ref_model.modality_submodules[encoder_name].module, + dist_model.modality_submodules[encoder_name].module, + ) + dist_pp_rank = dist_llm_grid.get_pg("pp").rank() + _copy_llm_params_pp_aware( + ref_model.language_model.module, + dist_model.language_model.module, + pp_rank=dist_pp_rank, + pp_size=dist_llm_pp, + num_layers=num_layers, + dist_tp_group=dist_llm_grid.get_pg("tp"), + ref_tp_group=ref_llm_grid.get_pg("tp"), + ) + + # Build optimizers AFTER weight copy (distributed optimizer snapshots + # fp32 master weights at __init__). + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + dist_optimizer = get_mimo_optimizer(dist_model, opt_config) + ref_optimizer = get_mimo_optimizer(ref_model, opt_config) + + # Deterministic shared global data. Both models consume the same global + # batches but slice differently: + # - Dist's data_iterator yields per-LLM-rank micro_batch_size samples + # (schedule then fan-in-slices on the encoder side). + # - Ref's data_iterator yields per-rank ref_per_rank_mbs samples. + torch.manual_seed(99999) + global_batches = _generate_shared_global_batches( + num_batches=num_microbatches, + global_batch_size=global_batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + vocab_size=vocab_size, + encoder_name=encoder_name, + ) + dist_llm_dp_pg = dist_llm_grid.get_pg("dp") + ref_enc_dp_pg = ref_enc_grid.get_pg("dp") + dist_per_rank_batches = [ + _slice_batch_along_dim0(b, dist_llm_dp, dist_llm_dp_pg.rank()) + for b in global_batches + ] + ref_per_rank_batches = [ + _slice_batch_along_dim0(b, ref_enc_dp, ref_enc_dp_pg.rank()) + for b in global_batches + ] + + # ── Dist forward/backward: three-phase colocated schedule ──────────── + class _ListIter: + def __init__(self, items): + self._items = items + self._i = 0 + + def __iter__(self): + return self + + def __next__(self): + if self._i >= len(self._items): + raise StopIteration + v = self._items[self._i] + self._i += 1 + return v + + dist_optimizer.zero_grad() + colocated_forward_backward_with_pp( + mimo_model=dist_model, + data_iterator=_ListIter(dist_per_rank_batches), + num_microbatches=num_microbatches, + encoder_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + encoder_name=encoder_name, + seq_length=seq_length, + micro_batch_size=micro_batch_size_llm, + p2p_communicator=P2PCommunicator( + pp_group=dist_llm_grid.get_pg("pp"), config=dist_model.config + ), + pg_collection=dist_lang_pg, + ) + dist_ok, dist_gn, _ = dist_optimizer.step() + assert dist_ok, "Dist optimizer step failed" + assert dist_gn is not None and dist_gn > 0, ( + f"Dist grad_norm={dist_gn} — three-phase schedule produced zero encoder/LLM grads." + ) + + # ── Ref forward/backward: plain no-pipelining schedule ─────────────── + def _ref_forward_step(data_iterator, model, *args): + batch = next(data_iterator) + output_tensor, loss_mask = model( + input_ids=batch['input_ids'], + labels=batch['labels'], + loss_mask=batch['loss_mask'], + position_ids=batch['position_ids'], + modality_inputs=batch['modality_inputs'], + ) + return output_tensor, partial(_sum_loss_func, loss_mask) + + ref_optimizer.zero_grad() + schedule.forward_backward_no_pipelining( + forward_step_func=_ref_forward_step, + data_iterator=_ListIter(ref_per_rank_batches), + model=[ref_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=ref_per_rank_mbs, + forward_only=False, + pg_collection=ref_lang_pg, + ) + ref_ok, ref_gn, _ = ref_optimizer.step() + assert ref_ok, "Ref optimizer step failed" + assert ref_gn is not None and ref_gn > 0, f"Ref grad_norm={ref_gn}" + + # Main oracle: post-step encoder shards match 1:1 (same enc_tp, enc_dp). + _assert_encoder_shards_match( + ref_model.modality_submodules[encoder_name].module, + dist_model.modality_submodules[encoder_name].module, + rtol=1e-2, + atol=1e-2, + ) + + @pytest.mark.skipif( version.parse(torch.__version__) < version.parse('2.3.0'), reason="Device mesh requires PyTorch 2.3+", @@ -497,3 +1092,66 @@ def test_fan_in_enc_tp1_dp8_llm_tp4_dp1_pp2(self): micro_batch_size=8, num_microbatches=4, ) + + @pytest.mark.parametrize( + "num_microbatches", + [2, 4], + ids=["num_mb_eq_pp", "num_mb_gt_pp_grad_acc"], + ) + def test_pp_matches_pp1_equal_dp_reference(self, num_microbatches): + """Post-step encoder weights under PP>1 match equal-DP PP=1 reference. + + This is the real correctness oracle for PR-9's three-phase schedule. + Parallels the PR-10 oracle (``test_mimo_colocated_correctness.py``), + extended with PP-aware LLM weight reshaping so the reference has + ``llm_pp=1`` (the only config compatible with equal-DP on a fixed + rank count: ``enc_tp * enc_dp == llm_tp * llm_pp * llm_dp`` and + ``enc_dp == llm_dp`` force ``llm_pp = enc_tp / llm_tp``; with + ``enc_tp == llm_tp`` this means ``llm_pp = 1``). + + * Dist: fan-in + PP>1 (the config under test). Runs through + ``colocated_forward_backward_with_pp`` (three-phase schedule). + * Ref: ``enc_tp=dist_enc_tp``, ``enc_dp=dist_enc_dp``, + ``llm_tp=dist_enc_tp``, ``llm_dp=dist_enc_dp``, ``llm_pp=1``. + Identity bridge (``BridgeDirection.EQUAL``); runs through + ``forward_backward_no_pipelining``. + + Both use ``gradient_reduce_div_factor=1`` with an identical + ``.sum()`` loss (matching the loss in ``colocated_schedule.py``), + so the DDP reduction yields the DP=1 aggregate gradient on every + encoder shard regardless of LLM layout. Encoder TP matches across + the two models, so shards line up 1:1. LLM TP matches too; LLM + weights differ only in PP partitioning, which the copy helper + below reshapes. Under correct PP>1 encoder grad accumulation + + broadcast, one Adam step yields shard-wise equal post-step encoder + weights modulo bf16 accumulation drift. + + If the three-phase schedule mishandles any of: encoder grad + accumulation across microbatches, PP-stage-0→stage-N broadcast, + or the detach/reattach boundary, encoder shards diverge and this + test fails. + + Two parametrized cases: + * ``num_mb_eq_pp`` (num_microbatches=2, pp=2): the minimal + pipeline with one microbatch per PP stage. No grad + accumulation across microbatches. + * ``num_mb_gt_pp_grad_acc`` (num_microbatches=4, pp=2): 1F1B + pipeline runs with 2 microbatches per stage, so encoder + embedding views for 4 microbatches all accumulate into the + same ``detached_full.grad`` via PyTorch view-gradient + semantics. If the microbatch slicing in + ``_build_lm_microbatches`` does not produce proper views of + ``detached_full`` (e.g., accidentally cloning), grad + accumulation across microbatches is silently dropped and the + encoder shards diverge. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + _run_pp_weight_oracle( + dist_enc_tp=2, + dist_enc_dp=4, + dist_llm_tp=2, + dist_llm_pp=2, + dist_llm_dp=2, + num_microbatches=num_microbatches, + ) From b5a75e99906157ff856e52c3bdb16a69b51d3a5c Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 17:32:46 +0000 Subject: [PATCH 03/11] Allow dest PP>1 in ColocatedBridgeCommunicator (NMFW-19) The PP>1 schedule orchestrates the LLM pipeline; the bridge only needs src (encoder) PP=1 since encode_and_communicate runs on every rank synchronously. For fan-in, gather groups are keyed by src position so each rank lands in exactly one group regardless of its llm_pp index; the EQUAL path does no collective at all. Co-Authored-By: Claude Opus 4.7 --- .../mimo/comm/colocated_communicator.py | 27 ++++++++++++------- .../test_mimo_colocated_communicator.py | 9 ++++++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/megatron/core/models/mimo/comm/colocated_communicator.py b/megatron/core/models/mimo/comm/colocated_communicator.py index dd0241d8f80..157114b1003 100644 --- a/megatron/core/models/mimo/comm/colocated_communicator.py +++ b/megatron/core/models/mimo/comm/colocated_communicator.py @@ -128,8 +128,11 @@ def _validate_grids(self): f"src={self.src_grid.rank_offset}, dest={self.dest_grid.rank_offset}" ) - # Per-grid dim checks: tp/dp required; pp and cp (if present) must be 1. - # CP>1 also corrupts dp_idx when iterating get_rank_enum(['tp']) groups. + # Per-grid dim checks: tp/dp required; cp (if present) must be 1. + # Src PP must be 1; dest PP>1 is allowed — the three-phase colocated + # schedule orchestrates LLM PP, and fan-in groups are keyed by src + # (PP=1) position so each rank lands in exactly one group. + # CP>1 corrupts dp_idx when iterating get_rank_enum(['tp']) groups. for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]: for required in ('tp', 'dp'): if required not in grid.dim_names: @@ -137,14 +140,18 @@ def _validate_grids(self): f"{name} grid must have '{required}' dimension, " f"got dim_names={grid.dim_names}" ) - for singleton in ('pp', 'cp'): - if singleton in grid.dim_names: - size = grid.shape[grid.dim_names.index(singleton)] - if size != 1: - raise ValueError( - f"{name} {singleton.upper()} must be 1 for " - f"ColocatedBridgeCommunicator, got {size}" - ) + if 'cp' in grid.dim_names: + cp_size = grid.shape[grid.dim_names.index('cp')] + if cp_size != 1: + raise ValueError( + f"{name} CP must be 1 for ColocatedBridgeCommunicator, got {cp_size}" + ) + if 'pp' in self.src_grid.dim_names: + src_pp = self.src_grid.shape[self.src_grid.dim_names.index('pp')] + if src_pp != 1: + raise ValueError( + f"src PP must be 1 for ColocatedBridgeCommunicator, got {src_pp}" + ) src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')] dest_dp = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py index 546a1648fd5..236e5295dde 100644 --- a/tests/unit_tests/models/test_mimo_colocated_communicator.py +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -236,8 +236,8 @@ def test_rank_offset_mismatch(self): "side,dim,expected", [ ("src", "pp", "src PP must be 1"), - ("dest", "pp", "dest PP must be 1"), ("src", "cp", "CP must be 1"), + ("dest", "cp", "CP must be 1"), ], ) def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): @@ -252,6 +252,13 @@ def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): with pytest.raises(ValueError, match=expected): make_comm(src_grid, dest_grid) + def test_dest_pp_gt_one_accepted(self): + # Dest PP>1 is valid: the three-phase colocated schedule handles + # the LLM pipeline orchestration. The bridge only needs src PP=1. + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, pp=2, dp=2) + make_comm(src_grid, dest_grid) + def test_dp_not_divisible(self): # 6-rank grids with DP sizes (3 vs 2) that neither divides the other. # Fits inside an 8-rank world (HyperCommGrid enforces size <= world - offset). From f199f9533e26f480d016e5f9de773ff17e6bceba Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 20:05:36 +0000 Subject: [PATCH 04/11] colocated_comm: drop verbose PP=1 rationale comment The ValueError message is self-documenting; schedule rationale belongs in colocated_schedule.py, not the communicator. Co-Authored-By: Claude Opus 4.7 --- megatron/core/models/mimo/comm/colocated_communicator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/megatron/core/models/mimo/comm/colocated_communicator.py b/megatron/core/models/mimo/comm/colocated_communicator.py index 157114b1003..b501d911bbb 100644 --- a/megatron/core/models/mimo/comm/colocated_communicator.py +++ b/megatron/core/models/mimo/comm/colocated_communicator.py @@ -129,10 +129,8 @@ def _validate_grids(self): ) # Per-grid dim checks: tp/dp required; cp (if present) must be 1. - # Src PP must be 1; dest PP>1 is allowed — the three-phase colocated - # schedule orchestrates LLM PP, and fan-in groups are keyed by src - # (PP=1) position so each rank lands in exactly one group. - # CP>1 corrupts dp_idx when iterating get_rank_enum(['tp']) groups. + # Src PP must be 1; dest PP>1 is allowed. CP>1 corrupts dp_idx when + # iterating get_rank_enum(['tp']) groups. for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]: for required in ('tp', 'dp'): if required not in grid.dim_names: From c5376e634e8da319d55d177fffd7fffd72af4ae0 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 20:07:21 +0000 Subject: [PATCH 05/11] mimo: fold LLM PP stage detection into RankRole.colocated RankRole.colocated now accepts the grid map and derives per-module PP stage info from each grid's pp group, removing the post-build mutation of self.role.modules from MimoModel.__init__. Co-Authored-By: Claude Opus 4.7 --- megatron/core/models/mimo/config/role.py | 35 +++++++++++++++++------- megatron/core/models/mimo/model/base.py | 26 +++++------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 7732ad85253..ed0d7aa8ca4 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -79,7 +79,7 @@ def build( Grids differ → NON_COLOCATED with PP-stage info per module. """ if module_to_grid_map is None or cls._all_grids_colocated(module_to_grid_map): - return cls._colocated(modality_module_names) + return cls._colocated(modality_module_names, module_to_grid_map) return cls._from_grid_map(module_to_grid_map) @staticmethod @@ -91,16 +91,31 @@ def _all_grids_colocated(module_to_grid_map: Dict[str, 'HyperCommGrid']) -> bool ) @classmethod - def _colocated(cls, modality_module_names: List[str]) -> 'RankRole': - """Colocated layout: every module on every rank, PP=1.""" + def _colocated( + cls, + modality_module_names: List[str], + module_to_grid_map: Optional[Dict[str, 'HyperCommGrid']] = None, + ) -> 'RankRole': + """Colocated layout: every module on every rank. + + When a grid map is supplied, per-module stage info is derived from + each grid's pp group (LLM PP>1 is allowed). With no grid map, every + module is both first and last stage. + """ all_module_names = list(modality_module_names) + [MIMO_LANGUAGE_MODULE_KEY] - return cls( - modules={ - name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) - for name in all_module_names - }, - mode=ModuleLayout.COLOCATED, - ) + modules = {} + for name in all_module_names: + grid = module_to_grid_map.get(name) if module_to_grid_map else None + if grid is not None and 'pp' in grid.dim_names: + pp_group = grid.get_pg('pp') + pp_rank, pp_size = pp_group.rank(), pp_group.size() + modules[name] = ModuleStageInfo( + is_first_stage=(pp_rank == 0), + is_last_stage=(pp_rank == pp_size - 1), + ) + else: + modules[name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + return cls(modules=modules, mode=ModuleLayout.COLOCATED) @classmethod def _from_grid_map(cls, module_to_grid_map: Dict[str, HyperCommGrid]) -> 'RankRole': diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 95a1651ef5c..dd9bb860894 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -67,25 +67,13 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # in TP/DP within those ranks. self._build_colocated_communicators() - # Detect LLM PP>1 for three-phase colocated execution - self.lm_has_pp = False - self.lm_is_first_pp_stage = True - if mimo_config.module_to_grid_map: - lang_grid = mimo_config.module_to_grid_map.get(MIMO_LANGUAGE_MODULE_KEY) - if lang_grid and 'pp' in lang_grid.dim_names: - pp_idx = lang_grid.dim_names.index('pp') - if lang_grid.shape[pp_idx] > 1: - self.lm_has_pp = True - pp_group = lang_grid.get_pg('pp') - pp_rank = pp_group.rank() - pp_size = pp_group.size() - self.lm_is_first_pp_stage = pp_rank == 0 - # Update language module stage info for PP>1 - from megatron.core.models.mimo.config.role import ModuleStageInfo - - self.role.modules[MIMO_LANGUAGE_MODULE_KEY] = ModuleStageInfo( - is_first_stage=(pp_rank == 0), is_last_stage=(pp_rank == pp_size - 1) - ) + # LLM PP>1 is already reflected in self.role; expose convenience flags + # for the three-phase colocated schedule. + lang_info = self.role.modules.get(MIMO_LANGUAGE_MODULE_KEY) + self.lm_has_pp = lang_info is not None and not ( + lang_info.is_first_stage and lang_info.is_last_stage + ) + self.lm_is_first_pp_stage = lang_info is None or lang_info.is_first_stage # Use special token IDs from the config self.special_token_ids = ( From 72fb43d012b6baf83cab8ed07124a5423b9168bb Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 20:07:42 +0000 Subject: [PATCH 06/11] mimo: reuse encode_and_communicate in _forward_all_modules Collapse the inlined modality-forward block (duplicating encode_and_communicate) into a single call, dropping redundant per-modality debug logs. Co-Authored-By: Claude Opus 4.7 --- megatron/core/models/mimo/model/base.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index dd9bb860894..83f3c50e6c4 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -602,27 +602,7 @@ def _forward_all_modules( # reuse the precomputed embeddings for every LLM microbatch. modality_embeddings = encoder_embeddings else: - # 1. Process each modality to get embeddings - modality_embeddings = {} - - for modality_name, submodule in self.modality_submodules.items(): - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - modality_embeddings[modality_name] = embeddings - logger.debug( - f"Generated embeddings for {modality_name} with shape " - f"{embeddings.shape}" - ) - - # Apply colocated communication if configured (no-op when colocated_comms is empty) - if self.colocated_comms: - modality_embeddings = self._apply_colocated_comms(modality_embeddings) + modality_embeddings = self.encode_and_communicate(modality_inputs) # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) From 228ed24c090faefc367ab1bb0551a914a6a67b9d Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 20:08:22 +0000 Subject: [PATCH 07/11] colocated_schedule: tighten finalize-defer comment Collapse the finish_grad_sync exposition into three lines that keep the why (encoder grads don't exist yet) without the DDP-internal detail. Co-Authored-By: Claude Opus 4.7 --- megatron/core/models/mimo/colocated_schedule.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py index 3e24be10097..73873e34619 100644 --- a/megatron/core/models/mimo/colocated_schedule.py +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -78,17 +78,9 @@ def _lm_forward_step(data_iterator_unused, model, *args): ) return output_tensor, partial(_loss_func, cached['loss_mask']) - # Defer finalize until AFTER Phase 3. The inner PP schedule would call - # ``config.finalize_model_grads_func`` at end-of-schedule, which runs - # DDP ``finish_grad_sync`` on both the LLM and the encoder. At that - # point the encoder has zero grads (Phase 3 has not run yet), so its - # DP all-reduce operates on zeros and the Phase 3 grads that follow - # are never synced — ``finish_grad_sync`` is not safe to call twice - # (it asserts on the outstanding async handle and has no idempotency - # guarantee). We swap in a no-op so the schedule proceeds normally, - # then invoke the original finalize once after Phase 3 so the single - # DP reduction covers both the LLM grads from Phase 2 and the encoder - # grads from Phase 3. + # Swap in a no-op finalize so the inner PP schedule does not run DDP + # grad sync before Phase 3 has produced encoder grads. We invoke the + # original finalize once after Phase 3 (see below). original_finalize = mimo_model.config.finalize_model_grads_func mimo_model.config.finalize_model_grads_func = _noop_finalize try: From 06eaae843fed193c2516d405220381043afa2eb4 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 20:09:03 +0000 Subject: [PATCH 08/11] colocated_schedule: wrap finalize swap in a context manager Introduce _deferred_finalize contextmanager that yields the original callable, so the post-Phase-3 invocation keeps access to it while the swap/restore logic is encapsulated. Co-Authored-By: Claude Opus 4.7 --- .../core/models/mimo/colocated_schedule.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py index 73873e34619..e7e9598c002 100644 --- a/megatron/core/models/mimo/colocated_schedule.py +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -13,6 +13,7 @@ graph at the encoder-LLM boundary. """ +from contextlib import contextmanager from functools import partial from typing import Optional @@ -81,9 +82,7 @@ def _lm_forward_step(data_iterator_unused, model, *args): # Swap in a no-op finalize so the inner PP schedule does not run DDP # grad sync before Phase 3 has produced encoder grads. We invoke the # original finalize once after Phase 3 (see below). - original_finalize = mimo_model.config.finalize_model_grads_func - mimo_model.config.finalize_model_grads_func = _noop_finalize - try: + with _deferred_finalize(mimo_model.config) as original_finalize: losses = schedules.forward_backward_pipelining_without_interleaving( forward_step_func=_lm_forward_step, data_iterator=cache_iter, @@ -92,8 +91,6 @@ def _lm_forward_step(data_iterator_unused, model, *args): forward_only=forward_only, **schedule_kwargs, ) - finally: - mimo_model.config.finalize_model_grads_func = original_finalize # ── Phase 3: Encoder backward (one pass, all ranks sync) ──────────── # detached_full.grad was populated by Phase 2's per-microbatch LLM backward @@ -242,3 +239,16 @@ def _noop_finalize(*args, **kwargs): Phase 3 runs the encoder backward. See ``colocated_forward_backward_with_pp``. """ return None + + +@contextmanager +def _deferred_finalize(config): + """Suppress the PP schedule's end-of-run DDP grad sync; yield the + original so callers can invoke it once after Phase 3. + """ + original = config.finalize_model_grads_func + config.finalize_model_grads_func = _noop_finalize + try: + yield original + finally: + config.finalize_model_grads_func = original From aa4568ff830dbe0d3b3503fea7c179dc113ff678 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Tue, 21 Apr 2026 20:11:48 +0000 Subject: [PATCH 09/11] test_mimo_colocated_pp: drop duplicated helpers + smoke tests, reuse PR-10 oracle infra MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete the smoke test (run_colocated_pp_test) and four smoke test methods: the PP weight oracle already subsumes them. Delete the DataIterator, _build_pp_oracle_model, grid/spec helpers, param-copy helpers, shard-match helper, and the shared-batch generator — they are 1:1 duplicates of infra already exported from test_mimo_1f1b_schedule.py and test_mimo_colocated_correctness.py. Keep only the new piece specific to PP>1: _copy_llm_params_pp_aware, with the unused same-TP fallback all-gather branch removed and the docstring tightened. _run_pp_weight_oracle is rewritten as a short driver built on the imported helpers. Co-Authored-By: Claude Opus 4.7 --- .../models/test_mimo_colocated_pp.py | 1040 ++--------------- 1 file changed, 99 insertions(+), 941 deletions(-) diff --git a/tests/unit_tests/models/test_mimo_colocated_pp.py b/tests/unit_tests/models/test_mimo_colocated_pp.py index f74b930826b..7d6143b2b2f 100644 --- a/tests/unit_tests/models/test_mimo_colocated_pp.py +++ b/tests/unit_tests/models/test_mimo_colocated_pp.py @@ -1,15 +1,12 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """Tests for colocated MIMO training with LLM PP>1. -Uses two-phase execution: encoder pre-compute + 1F1B LLM pipeline. - Run individually (8 GPUs): uv run python -m torch.distributed.run --nproc_per_node=8 \ -m pytest tests/unit_tests/models/test_mimo_colocated_pp.py -v """ -import logging -from contextlib import ExitStack, contextmanager +import re from functools import partial import pytest @@ -18,747 +15,45 @@ from packaging import version import megatron.core.pipeline_parallel.schedules as schedule -from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig -from megatron.core.distributed.finalize_model_grads import finalize_model_grads -from megatron.core.hyper_comm_grid import HyperCommGrid -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp -from megatron.core.models.mimo.config.base_configs import MimoModelConfig -from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY -from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.optimizer import get_mimo_optimizer -from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules -from megatron.core.models.vision.multimodal_projector import MultimodalProjector from megatron.core.optimizer.optimizer_config import OptimizerConfig -from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator -from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage -from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.models.test_mimo_1f1b_schedule import ( + create_all_embedding_groups, + create_hypercomm_grid, + destroy_all_grids, + get_mimo_model, +) +from tests.unit_tests.models.test_mimo_colocated_correctness import ( + _assert_encoder_weights_match, + _BatchIterator, + _copy_ref_params_to_dist, + _generate_and_broadcast_global_batches, + _slice_batch, + _wire_training_hooks, +) from tests.unit_tests.test_utilities import Utils -try: - from megatron.core.extensions.transformer_engine import ( - TEColumnParallelLinear, - TERowParallelLinear, - ) -except ImportError: - TEColumnParallelLinear = None - TERowParallelLinear = None - -logger = logging.getLogger(__name__) - -_active_grids: list = [] -_embedding_pg_cache: dict = {} - - -def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): - grid = HyperCommGrid( - shape=[tp, cp, pp, dp, 1, 1], - dim_names=["tp", "cp", "pp", "dp", "ep", "expt_dp"], - rank_offset=offset, - backend="nccl", - ) - for dim in ["tp", "cp", "pp", "dp", "ep", "expt_dp"]: - grid.create_pg([dim]) - grid.create_pg(["dp", "cp"]) - grid.create_pg(["tp", "pp"]) - grid.create_pg(["tp", "ep", "pp"]) - grid.create_pg(["dp", "ep"]) - grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) - _active_grids.append(grid) - return grid - - -def destroy_all_grids(): - for g in _active_grids: - g.destroy() - _active_grids.clear() - _embedding_pg_cache.clear() - BridgeCommunicator.destroy_broadcast_pgs() - - -def create_all_embedding_groups(grids): - for grid in grids: - pp_group = grid.get_pg("pp") - if not pp_group: - continue - pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) - key = tuple(pp_ranks) - if key not in _embedding_pg_cache: - pos = [pp_ranks[0]] - embd = [pp_ranks[0]] - if pp_ranks[-1] != pp_ranks[0]: - embd.append(pp_ranks[-1]) - _embedding_pg_cache[key] = (dist.new_group(ranks=pos), dist.new_group(ranks=embd)) - - -def get_pg_collection(grid, is_language_model=False): - pg = ProcessGroupCollection() - pg.tp = grid.get_pg("tp") - pg.cp = grid.get_pg("cp") - pg.pp = grid.get_pg("pp") - pg.ep = grid.get_pg("ep") - pg.dp = grid.get_pg("dp") - pg.dp_cp = grid.get_pg(["dp", "cp"]) - pg.expt_dp = grid.get_pg("expt_dp") - pp_ranks = sorted(dist.get_process_group_ranks(pg.pp)) - key = tuple(pp_ranks) - if key in _embedding_pg_cache: - pos_pg, embd_pg = _embedding_pg_cache[key] - pg.pos_embd = pos_pg if is_pp_first_stage(pg.pp) else None - pg.embd = ( - embd_pg - if is_language_model and (is_pp_last_stage(pg.pp) or is_pp_first_stage(pg.pp)) - else None - ) - return pg - - -def get_language_model_spec( - num_layers, hidden_size, num_attention_heads, vocab_size, seq_len, pg_collection -): - pp_rank = dist.get_rank(pg_collection.pp) - pp_size = dist.get_world_size(pg_collection.pp) - tp_size = pg_collection.tp.size() if pg_collection.tp else 1 - lm_config = TransformerConfig( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - use_cpu_initialization=True, - variable_seq_lengths=True, - moe_token_dispatcher_type='alltoall', - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=pp_size, - pipeline_dtype=torch.bfloat16, - bf16=True, - cross_entropy_loss_fusion=True, - cross_entropy_fusion_impl='te', - ) - return ModuleSpec( - module=GPTModel, - params={ - "config": lm_config, - "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), - "vocab_size": vocab_size, - "max_sequence_length": seq_len, - "pre_process": (pp_rank == 0), - "post_process": (pp_rank == pp_size - 1), - "pg_collection": pg_collection, - }, - ) - - -def get_vision_submodules_spec( - num_layers, hidden_size, num_attention_heads, language_hidden_size, pg_collection -): - from megatron.core.transformer.transformer_block import TransformerBlock - - tp_size = pg_collection.tp.size() if pg_collection.tp else 1 - vision_config = TransformerConfig( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - use_cpu_initialization=True, - variable_seq_lengths=True, - moe_token_dispatcher_type='alltoall', - tensor_model_parallel_size=tp_size, - pipeline_model_parallel_size=1, - pipeline_dtype=torch.bfloat16, - bf16=True, - ) - proj_cfg = TransformerConfig( - num_layers=1, hidden_size=language_hidden_size, num_attention_heads=1 - ) - proj_cfg.ffn_hidden_size = language_hidden_size - proj_cfg.bias_activation_fusion = True - proj_cfg.add_bias_linear = True - proj_cfg.activation_func = torch.nn.functional.gelu - - return ModuleSpec( - module=VisionModalitySubmodules, - submodules={ - "encoders": { - "clip_encoder": ModuleSpec( - module=TransformerBlock, - params={ - "config": vision_config, - "spec": get_gpt_layer_with_transformer_engine_spec(), - "pg_collection": pg_collection, - "pre_process": True, - "post_process": True, - }, - ) - }, - "input_projections": [ - ModuleSpec( - module=MultimodalProjector, - params={ - "config": proj_cfg, - "submodules": MLPSubmodules( - linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - "projector_type": "mlp", - "input_size": vision_config.hidden_size, - "tp_group": pg_collection.tp, - }, - ) - ], - }, - ) - - -class DataIterator: - def __init__( - self, - hidden_size, - seq_length, - micro_batch_size, - vocab_size, - encoder_name, - image_token_id=50257, - image_seq_length=None, - ): - self.hidden_size = hidden_size - self.seq_length = seq_length - self.micro_batch_size = micro_batch_size - self.vocab_size = vocab_size - self.encoder_name = encoder_name - self.image_token_id = image_token_id - self.image_seq_length = image_seq_length or (seq_length // 2) - - def __iter__(self): - return self - - def __next__(self): - encoder_hidden_states = torch.randn( - self.image_seq_length, - self.micro_batch_size, - self.hidden_size, - device='cuda', - dtype=torch.bfloat16, - ) - image_tokens = torch.full( - (self.micro_batch_size, self.image_seq_length), - self.image_token_id, - dtype=torch.long, - device='cuda', - ) - text_tokens = torch.randint( - 1, - self.vocab_size, - (self.micro_batch_size, self.seq_length - self.image_seq_length), - device='cuda', - ) - input_ids = torch.cat([image_tokens, text_tokens], dim=1) - labels = input_ids.clone() - labels[input_ids == self.image_token_id] = -100 - loss_mask = (input_ids != self.image_token_id).float() - position_ids = ( - torch.arange(self.seq_length, device='cuda') - .unsqueeze(0) - .expand(self.micro_batch_size, -1) - .clone() - ) - return { - "input_ids": input_ids, - "labels": labels, - "loss_mask": loss_mask, - "position_ids": position_ids, - "modality_inputs": { - self.encoder_name: { - "clip_encoder": {'hidden_states': encoder_hidden_states, 'attention_mask': None} - } - }, - } - - -def run_colocated_pp_test( - encoder_tp, - encoder_dp, - llm_tp, - llm_pp, - llm_dp, - hidden_size=256, - num_layers=2, - vocab_size=1000, - seq_length=64, - micro_batch_size=2, - num_microbatches=4, -): - """Run colocated MIMO with encoder PP=1 + LLM PP>1. - - Beyond "loss is finite", this helper verifies: - * ``optimizer.step`` returns grad_norm > 0 — catches silently-zeroed - encoder grads (e.g. broadcast never populating detached_full.grad - on non-first PP stages). - * Encoder params' data changed after the step — catches the case - where grads flow but the update is a no-op (wrong PG, wrong - device, clipping to zero). - * LLM params' data changed on every PP stage — catches the case - where the pipeline runs but a PP stage's grads never backprop. - """ - import os - - os.environ.pop('NVTE_FLASH_ATTN', None) - os.environ.pop('NVTE_FUSED_ATTN', None) - os.environ.pop('NVTE_UNFUSED_ATTN', None) - - encoder_name = "images" - - encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp) - llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) - create_all_embedding_groups([encoder_grid, llm_grid]) - torch.manual_seed(12345) - - vision_pg = get_pg_collection(encoder_grid, is_language_model=False) - language_pg = get_pg_collection(llm_grid, is_language_model=True) - - language_model_spec = get_language_model_spec( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=8, - vocab_size=vocab_size, - seq_len=seq_length, - pg_collection=language_pg, - ) - vision_submodule_spec = get_vision_submodules_spec( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=8, - language_hidden_size=hidden_size, - pg_collection=vision_pg, - ) - - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={encoder_name: vision_submodule_spec}, - special_token_ids={encoder_name: 50257}, - module_to_grid_map={encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid}, - ) - - mimo_model = MimoModel(mimo_config) - mimo_model.to(torch.device("cuda")).to(torch.bfloat16) - mimo_model.model_type = ModelType.encoder_or_decoder - - # Wrap with DDP (per-module process groups) - ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=False, bucket_size=10000, use_distributed_optimizer=True - ) - if mimo_model.language_model is not None: - mimo_model.language_model = DistributedDataParallel( - config=mimo_model.language_model.config, - ddp_config=ddp_config, - module=mimo_model.language_model, - pg_collection=language_pg, - ) - if encoder_name in mimo_model.modality_submodules: - submodule = mimo_model.modality_submodules[encoder_name] - if submodule is not None: - mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( - config=submodule.encoders['clip_encoder'].config, - ddp_config=ddp_config, - module=submodule, - pg_collection=vision_pg, - ) - - @contextmanager - def no_sync_func(): - with ExitStack() as stack: - if mimo_model.language_model is not None: - stack.enter_context(mimo_model.language_model.no_sync()) - for sub in mimo_model.modality_submodules.values(): - if sub is not None: - stack.enter_context(sub.no_sync()) - yield - - def finalize_grads_func(*args, **kwargs): - if mimo_model.language_model is not None: - finalize_model_grads( - [mimo_model.language_model], num_tokens=None, pg_collection=language_pg - ) - for sub in mimo_model.modality_submodules.values(): - if sub is not None: - finalize_model_grads([sub], num_tokens=None, pg_collection=vision_pg) - - mimo_model.config.no_sync_func = no_sync_func - mimo_model.config.finalize_model_grads_func = finalize_grads_func - mimo_model.config.grad_scale_func = lambda loss: ( - torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) - if isinstance(loss, (int, float)) - else loss - ) - - opt_config = OptimizerConfig( - optimizer='adam', - lr=1e-4, - weight_decay=0.01, - clip_grad=1.0, - bf16=True, - use_distributed_optimizer=True, - ) - optimizer = get_mimo_optimizer(mimo_model, opt_config) - - data_iterator = DataIterator( - hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name - ) - lm_pp_group = llm_grid.get_pg("pp") - - rank = dist.get_rank() - num_iterations = 2 - all_losses = [] - optimizer.zero_grad() - - # Snapshot initial params to verify the step actually moves them. - # A silently-zeroed encoder grad (e.g. PP>1 grad broadcast missing) would - # leave these unchanged despite grad_norm appearing nonzero. - encoder_module = ( - mimo_model.modality_submodules[encoder_name].module - if encoder_name in mimo_model.modality_submodules - and mimo_model.modality_submodules[encoder_name] is not None - else None - ) - llm_module = mimo_model.language_model.module if mimo_model.language_model is not None else None - initial_encoder_params = ( - {n: p.detach().clone() for n, p in encoder_module.named_parameters()} - if encoder_module is not None - else {} - ) - initial_llm_params = ( - {n: p.detach().clone() for n, p in llm_module.named_parameters()} - if llm_module is not None - else {} - ) - - for iteration in range(num_iterations): - losses = colocated_forward_backward_with_pp( - mimo_model=mimo_model, - data_iterator=data_iterator, - num_microbatches=num_microbatches, - encoder_grid=encoder_grid, - llm_grid=llm_grid, - encoder_name=encoder_name, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - p2p_communicator=P2PCommunicator(pp_group=lm_pp_group, config=mimo_model.config), - pg_collection=language_pg, - ) - - success, grad_norm, _ = optimizer.step() - assert success, f"Rank {rank}: Optimizer step failed at iteration {iteration}" - # grad_norm must be strictly positive: zero means every tracked param - # had zero grad, which indicates the schedule never wired a usable - # gradient into the param.grad buffers. - assert grad_norm is not None and grad_norm > 0, ( - f"Rank {rank}: grad_norm={grad_norm} at iter {iteration} — encoder or " - f"LLM grads were silently zeroed (did Phase 3 broadcast/backward run?)" - ) - optimizer.zero_grad() - - all_losses.extend(losses or []) - logger.info(f"Rank {rank}: iteration {iteration} done, losses={len(losses or [])}") - - # Verify on last PP stage - if is_pp_last_stage(lm_pp_group): - assert len(all_losses) > 0, f"Rank {rank}: No losses on last stage" - for i, loss_dict in enumerate(all_losses): - loss_val = loss_dict.get('loss_reduced', 0) - if isinstance(loss_val, torch.Tensor): - loss_val = loss_val.item() - assert loss_val == loss_val, f"Rank {rank}: NaN loss at mb {i}" - assert abs(loss_val) != float('inf'), f"Rank {rank}: Inf loss at mb {i}" - - # Oracle: at least one param of each module's shard must have changed. - # Under correct three-phase execution, every encoder rank accumulates a - # nonzero DP=1 gradient (via Phase 3 backward from the broadcast grad), - # and every LLM PP stage accumulates nonzero grads from the 1F1B pass. - # A blanket "all shards unchanged" outcome means the optimizer step was - # effectively a no-op for that module on this rank. - if encoder_module is not None: - changed = any( - not torch.equal(p.detach(), initial_encoder_params[n]) - for n, p in encoder_module.named_parameters() - if n in initial_encoder_params - ) - assert changed, ( - f"Rank {rank}: no encoder params changed after {num_iterations} steps — " - f"Phase 3 encoder backward likely did not populate grads on this rank" - ) - if llm_module is not None: - changed = any( - not torch.equal(p.detach(), initial_llm_params[n]) - for n, p in llm_module.named_parameters() - if n in initial_llm_params - ) - assert changed, ( - f"Rank {rank}: no LLM params changed after {num_iterations} steps — " - f"PP stage {dist.get_rank(lm_pp_group)} may have received no gradient" - ) - - return all_losses - -# --------------------------------------------------------------------------- -# Weight-oracle helpers: dist (PP>1, heterogeneous) vs ref (PP=1, equal-DP). -# --------------------------------------------------------------------------- - - -def _build_pp_oracle_model( - encoder_tp, - encoder_dp, - llm_tp, - llm_pp, - llm_dp, - hidden_size, - num_layers, - vocab_size, - seq_length, - ddp_config, - encoder_name="images", -): - """Build a MimoModel + DDP wrap for the weight-oracle test. Returns the - model plus its encoder_grid/llm_grid and pg_collections. Mirrors the - setup in ``run_colocated_pp_test`` but accepts an explicit ``ddp_config`` - so both dist and ref can share ``gradient_reduce_div_factor=1``. - """ - encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp) - llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) - create_all_embedding_groups([encoder_grid, llm_grid]) - - vision_pg = get_pg_collection(encoder_grid, is_language_model=False) - language_pg = get_pg_collection(llm_grid, is_language_model=True) - - language_model_spec = get_language_model_spec( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=8, - vocab_size=vocab_size, - seq_len=seq_length, - pg_collection=language_pg, - ) - vision_submodule_spec = get_vision_submodules_spec( - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=8, - language_hidden_size=hidden_size, - pg_collection=vision_pg, - ) - - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={encoder_name: vision_submodule_spec}, - special_token_ids={encoder_name: 50257}, - module_to_grid_map={encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid}, - ) - - mimo_model = MimoModel(mimo_config) - mimo_model.to(torch.device("cuda")).to(torch.bfloat16) - mimo_model.model_type = ModelType.encoder_or_decoder - - if mimo_model.language_model is not None: - mimo_model.language_model = DistributedDataParallel( - config=mimo_model.language_model.config, - ddp_config=ddp_config, - module=mimo_model.language_model, - pg_collection=language_pg, - ) - if encoder_name in mimo_model.modality_submodules: - submodule = mimo_model.modality_submodules[encoder_name] - if submodule is not None: - mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( - config=submodule.encoders['clip_encoder'].config, - ddp_config=ddp_config, - module=submodule, - pg_collection=vision_pg, - ) - - @contextmanager - def no_sync_func(): - with ExitStack() as stack: - if mimo_model.language_model is not None: - stack.enter_context(mimo_model.language_model.no_sync()) - for sub in mimo_model.modality_submodules.values(): - if sub is not None: - stack.enter_context(sub.no_sync()) - yield - - def finalize_grads_func(*args, **kwargs): - if mimo_model.language_model is not None: - finalize_model_grads( - [mimo_model.language_model], num_tokens=None, pg_collection=language_pg - ) - for sub in mimo_model.modality_submodules.values(): - if sub is not None: - finalize_model_grads([sub], num_tokens=None, pg_collection=vision_pg) - - mimo_model.config.no_sync_func = no_sync_func - mimo_model.config.finalize_model_grads_func = finalize_grads_func - mimo_model.config.grad_scale_func = lambda loss: ( - torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) - if isinstance(loss, (int, float)) - else loss - ) - - return mimo_model, encoder_grid, llm_grid, language_pg, vision_pg - - -def _generate_shared_global_batches( - num_batches, - global_batch_size, - seq_length, - hidden_size, - vocab_size, - encoder_name, - image_token_id=50257, -): - """Generate global batches on rank 0 and broadcast so every rank sees - identical data. Encoder input shape is [seq, batch, hidden] (sbh), - matching ``DataIterator`` above. - """ - rank = dist.get_rank() - image_seq_length = seq_length // 2 - batches = [] - for _ in range(num_batches): - if rank == 0: - encoder_hidden_states = torch.randn( - image_seq_length, - global_batch_size, - hidden_size, - device='cuda', - dtype=torch.bfloat16, - ) - image_tokens = torch.full( - (global_batch_size, image_seq_length), - image_token_id, - dtype=torch.long, - device='cuda', - ) - text_tokens = torch.randint( - 1, - vocab_size, - (global_batch_size, seq_length - image_seq_length), - device='cuda', - ) - input_ids = torch.cat([image_tokens, text_tokens], dim=1) - else: - encoder_hidden_states = torch.empty( - image_seq_length, - global_batch_size, - hidden_size, - device='cuda', - dtype=torch.bfloat16, - ) - input_ids = torch.empty( - global_batch_size, seq_length, dtype=torch.long, device='cuda' - ) - dist.broadcast(encoder_hidden_states, src=0) - dist.broadcast(input_ids, src=0) - - labels = input_ids.clone() - labels[input_ids == image_token_id] = -100 - loss_mask = (input_ids != image_token_id).float() - position_ids = ( - torch.arange(seq_length, device='cuda') - .unsqueeze(0) - .expand(global_batch_size, -1) - .clone() - ) - batches.append( - { - "input_ids": input_ids, - "labels": labels, - "loss_mask": loss_mask, - "position_ids": position_ids, - "modality_inputs": { - encoder_name: { - "clip_encoder": { - 'hidden_states': encoder_hidden_states, - 'attention_mask': None, - } - } - }, - } - ) - return batches - - -def _slice_batch_along_dim0(batch, split, idx): - """Return ``idx``-th of ``split`` equal slices along the batch dim.""" - b = batch['input_ids'].shape[0] - size = b // split - s, e = idx * size, (idx + 1) * size - out = {k: batch[k][s:e].contiguous() for k in ['input_ids', 'labels', 'loss_mask', 'position_ids']} - mod_new = {} - for m, md in batch['modality_inputs'].items(): - mod_new[m] = {} - for enc, ed in md.items(): - mod_new[m][enc] = {} - for k, t in ed.items(): - if isinstance(t, torch.Tensor): - # modality hidden_states shape [seq, batch, hidden] — dim 1 - mod_new[m][enc][k] = t[:, s:e, :].contiguous() - else: - mod_new[m][enc][k] = t - out['modality_inputs'] = mod_new - return out - - -def _copy_encoder_params(ref_module, dist_module): - """Copy encoder params ref → dist. Encoder layouts match by construction - (same enc_tp and enc_dp in both models), so shards line up 1:1. - """ - ref_params = dict(ref_module.named_parameters()) - with torch.no_grad(): - for name, dist_param in dist_module.named_parameters(): - assert name in ref_params, f"Encoder param '{name}' missing in ref" - ref_param = ref_params[name] - assert ref_param.shape == dist_param.shape, ( - f"Encoder param '{name}': ref.shape={tuple(ref_param.shape)} != " - f"dist.shape={tuple(dist_param.shape)} — enc_tp/enc_dp must match " - f"between ref and dist for shard-wise comparison." - ) - dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) - - -def _copy_llm_params_pp_aware( - ref_module, dist_module, pp_rank, pp_size, num_layers, dist_tp_group, ref_tp_group -): +def _copy_llm_params_pp_aware(ref_module, dist_module, pp_rank, pp_size, num_layers): """Copy LLM params ref (PP=1) → dist (PP>=1) with layer-index remapping. - In Megatron's ``TransformerBlock``, ``self.layers`` is a ``ModuleList`` - indexed 0..N-1 *locally* per PP stage. The global layer number is - ``local_idx + pp_rank * layers_per_stage``. For a PP=1 reference, all - N layers live at local indices 0..N-1 on each rank; for a PP>1 dist - model, PP stage ``s``'s local layer ``i`` corresponds to ref's global - layer ``s*layers_per_stage + i``. - - Non-layer params (embedding, final_layernorm, output_layer) are only - present on stages with the relevant ``pre_process``/``post_process`` - flag, and their names match exactly between ref (which has them all) - and whichever dist stage owns them. - - If dist_tp != ref_tp the helper falls back to the PR-10 pattern of - gather-across-ref-tp + slice-for-dist-tp. Same-TP is the normal path - (this helper is designed for tests where ``dist_llm_tp == ref_llm_tp``, - so the gather path is a no-op fallback). + Dist's ``decoder.layers.{local_idx}`` on PP stage ``s`` corresponds to + ref's global layer ``s * layers_per_stage + local_idx``. Non-layer + params (embedding, final_layernorm, output_layer) are only present on + stages that own them and their names match exactly between ref and + dist. Assumes ``dist_llm_tp == ref_llm_tp`` so shards line up 1:1. """ - import re - assert num_layers % pp_size == 0, ( f"num_layers={num_layers} not divisible by pp_size={pp_size}; " f"oracle requires even PP split." ) layers_per_stage = num_layers // pp_size layer_rx = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') - ref_params = dict(ref_module.named_parameters()) - ref_tp_size = dist.get_world_size(ref_tp_group) - dist_tp_rank = dist.get_rank(dist_tp_group) - dist_tp_size = dist.get_world_size(dist_tp_group) with torch.no_grad(): for name, dist_param in dist_module.named_parameters(): @@ -769,66 +64,17 @@ def _copy_llm_params_pp_aware( ref_name = f"{prefix}{global_idx}{suffix}" else: ref_name = name - assert ref_name in ref_params, ( f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." ) ref_param = ref_params[ref_name] - partition_dim = getattr(dist_param, 'partition_dim', -1) - - if ref_param.shape == dist_param.shape: - dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) - continue - - assert partition_dim >= 0, ( - f"LLM param '{name}': shapes differ (ref={tuple(ref_param.shape)}, " - f"dist={tuple(dist_param.shape)}) but partition_dim<0 — cannot reshape " - f"a replicated param." + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)} — oracle requires " + f"dist_llm_tp == ref_llm_tp." ) - shards = [torch.empty_like(ref_param.data) for _ in range(ref_tp_size)] - dist.all_gather(shards, ref_param.data.contiguous(), group=ref_tp_group) - full = torch.cat(shards, dim=partition_dim) - sliced = torch.tensor_split(full, dist_tp_size, dim=partition_dim)[dist_tp_rank] - assert sliced.shape == dist_param.shape - dist_param.data.copy_(sliced.to(dist_param.dtype)) - - -def _sum_loss_func(loss_mask_unused, output_tensor): - """Match the ``.sum()`` loss used by ``colocated_schedule._loss_func`` so - the reference's forward_backward_no_pipelining path produces comparable - gradient magnitudes. - """ - if output_tensor is None: - return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} - loss = output_tensor.float().sum() - return loss, {'loss_reduced': loss.detach().item()} - - -def _assert_encoder_shards_match(ref_module, dist_module, rtol=1e-2, atol=1e-2): - """Assert every dist encoder shard matches the ref encoder shard. - - Tolerance accounts for bf16 accumulation-order drift between the ref's - LLM-flat (pp=1) gradient path and the dist's PP>1 1F1B path. Both paths - yield the same DP=1 encoder gradient in exact arithmetic; bf16 rounding - bounds the drift within the tolerance below. - """ - ref_params = dict(ref_module.named_parameters()) - mismatches = [] - for name, dist_param in dist_module.named_parameters(): - ref_param = ref_params[name] - assert ref_param.shape == dist_param.shape - try: - torch.testing.assert_close(dist_param.data, ref_param.data, rtol=rtol, atol=atol) - except AssertionError as e: - mismatches.append((name, str(e))) - if mismatches: - rank = dist.get_rank() - details = "\n".join(f" {n}: {msg}" for n, msg in mismatches) - raise AssertionError( - f"Rank {rank}: {len(mismatches)} encoder param(s) diverged between " - f"PP>1 dist and equal-DP PP=1 ref:\n{details}" - ) + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) def _run_pp_weight_oracle( @@ -844,9 +90,7 @@ def _run_pp_weight_oracle( seq_length=64, micro_batch_size_llm=2, ): - """Drive the dist-vs-ref weight oracle described in - ``test_pp_matches_pp1_equal_dp_reference``. - """ + """Drive the dist (PP>1) vs ref (PP=1, equal-DP) weight oracle.""" import os os.environ.pop('NVTE_FLASH_ATTN', None) @@ -854,17 +98,13 @@ def _run_pp_weight_oracle( os.environ.pop('NVTE_UNFUSED_ATTN', None) encoder_name = "images" - # Equal-DP reference: enc_tp=dist_enc_tp, enc_dp=dist_enc_dp, - # llm_tp=dist_enc_tp (→ same encoder & LLM TP layout), llm_dp=dist_enc_dp, - # llm_pp=1 (identity bridge, only PP value compatible with equal-DP on - # a fixed rank count). + # Equal-DP reference: same encoder TP/DP; LLM matches encoder TP/DP and + # uses PP=1 (the only PP value compatible with equal-DP on a fixed rank + # count when enc_tp == llm_tp). ref_enc_tp, ref_enc_dp = dist_enc_tp, dist_enc_dp ref_llm_tp, ref_llm_pp, ref_llm_dp = dist_enc_tp, 1, dist_enc_dp - # Global batch size spans the larger DP side. Dist's LLM DP is smaller - # (fan-in), so each LLM rank holds micro_batch_size_llm samples. global_batch_size = micro_batch_size_llm * dist_llm_dp - # For ref (equal-DP, llm_dp == enc_dp): per-rank batch = global_batch / enc_dp. ref_per_rank_mbs = global_batch_size // ref_llm_dp ddp_config = DistributedDataParallelConfig( @@ -874,54 +114,63 @@ def _run_pp_weight_oracle( gradient_reduce_div_factor=1, ) - # Build dist first (heterogeneous TP/DP + PP>1). + dist_enc_grid = create_hypercomm_grid( + offset=0, tp=dist_enc_tp, cp=1, pp=1, dp=dist_enc_dp + ) + dist_llm_grid = create_hypercomm_grid( + offset=0, tp=dist_llm_tp, cp=1, pp=dist_llm_pp, dp=dist_llm_dp + ) + ref_enc_grid = create_hypercomm_grid( + offset=0, tp=ref_enc_tp, cp=1, pp=1, dp=ref_enc_dp + ) + ref_llm_grid = create_hypercomm_grid( + offset=0, tp=ref_llm_tp, cp=1, pp=ref_llm_pp, dp=ref_llm_dp + ) + create_all_embedding_groups([dist_enc_grid, dist_llm_grid, ref_enc_grid, ref_llm_grid]) + torch.manual_seed(12345) - dist_model, dist_enc_grid, dist_llm_grid, dist_lang_pg, dist_vis_pg = _build_pp_oracle_model( - encoder_tp=dist_enc_tp, - encoder_dp=dist_enc_dp, - llm_tp=dist_llm_tp, - llm_pp=dist_llm_pp, - llm_dp=dist_llm_dp, + dist_model, _, _, dist_lang_pg, dist_vis_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=dist_enc_grid, + llm_grid=dist_llm_grid, hidden_size=hidden_size, num_layers=num_layers, vocab_size=vocab_size, - seq_length=seq_length, + seq_len=seq_length, ddp_config=ddp_config, ) - # Build ref (equal-DP, PP=1). + dist_model.model_type = ModelType.encoder_or_decoder + torch.manual_seed(12345) - ref_model, ref_enc_grid, ref_llm_grid, ref_lang_pg, ref_vis_pg = _build_pp_oracle_model( - encoder_tp=ref_enc_tp, - encoder_dp=ref_enc_dp, - llm_tp=ref_llm_tp, - llm_pp=ref_llm_pp, - llm_dp=ref_llm_dp, + ref_model, _, _, ref_lang_pg, ref_vis_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=ref_enc_grid, + llm_grid=ref_llm_grid, hidden_size=hidden_size, num_layers=num_layers, vocab_size=vocab_size, - seq_length=seq_length, + seq_len=seq_length, ddp_config=ddp_config, ) + ref_model.model_type = ModelType.encoder_or_decoder - # Force identical initial state. Encoder: same TP/DP → shard-wise copy. - # LLM: ref has pp=1 (all layers), dist has pp>=1 (layers split); remap. - _copy_encoder_params( + _copy_ref_params_to_dist( ref_model.modality_submodules[encoder_name].module, dist_model.modality_submodules[encoder_name].module, + ref_enc_grid.get_pg("tp"), + dist_enc_grid.get_pg("tp"), ) - dist_pp_rank = dist_llm_grid.get_pg("pp").rank() _copy_llm_params_pp_aware( ref_model.language_model.module, dist_model.language_model.module, - pp_rank=dist_pp_rank, + pp_rank=dist_llm_grid.get_pg("pp").rank(), pp_size=dist_llm_pp, num_layers=num_layers, - dist_tp_group=dist_llm_grid.get_pg("tp"), - ref_tp_group=ref_llm_grid.get_pg("tp"), ) - # Build optimizers AFTER weight copy (distributed optimizer snapshots - # fp32 master weights at __init__). + _wire_training_hooks(dist_model, dist_lang_pg, dist_vis_pg) + _wire_training_hooks(ref_model, ref_lang_pg, ref_vis_pg) + opt_config = OptimizerConfig( optimizer='adam', lr=1e-4, @@ -933,51 +182,28 @@ def _run_pp_weight_oracle( dist_optimizer = get_mimo_optimizer(dist_model, opt_config) ref_optimizer = get_mimo_optimizer(ref_model, opt_config) - # Deterministic shared global data. Both models consume the same global - # batches but slice differently: - # - Dist's data_iterator yields per-LLM-rank micro_batch_size samples - # (schedule then fan-in-slices on the encoder side). - # - Ref's data_iterator yields per-rank ref_per_rank_mbs samples. torch.manual_seed(99999) - global_batches = _generate_shared_global_batches( - num_batches=num_microbatches, - global_batch_size=global_batch_size, + global_batches = _generate_and_broadcast_global_batches( + global_mbs=global_batch_size, seq_length=seq_length, hidden_size=hidden_size, vocab_size=vocab_size, encoder_name=encoder_name, + num_batches=num_microbatches, ) - dist_llm_dp_pg = dist_llm_grid.get_pg("dp") - ref_enc_dp_pg = ref_enc_grid.get_pg("dp") - dist_per_rank_batches = [ - _slice_batch_along_dim0(b, dist_llm_dp, dist_llm_dp_pg.rank()) + dist_batches = [ + _slice_batch(b, dist_llm_dp, dist_llm_grid.get_pg("dp").rank()) for b in global_batches ] - ref_per_rank_batches = [ - _slice_batch_along_dim0(b, ref_enc_dp, ref_enc_dp_pg.rank()) + ref_batches = [ + _slice_batch(b, ref_enc_dp, ref_enc_grid.get_pg("dp").rank()) for b in global_batches ] - # ── Dist forward/backward: three-phase colocated schedule ──────────── - class _ListIter: - def __init__(self, items): - self._items = items - self._i = 0 - - def __iter__(self): - return self - - def __next__(self): - if self._i >= len(self._items): - raise StopIteration - v = self._items[self._i] - self._i += 1 - return v - dist_optimizer.zero_grad() colocated_forward_backward_with_pp( mimo_model=dist_model, - data_iterator=_ListIter(dist_per_rank_batches), + data_iterator=_BatchIterator(dist_batches), num_microbatches=num_microbatches, encoder_grid=dist_enc_grid, llm_grid=dist_llm_grid, @@ -992,10 +218,18 @@ def __next__(self): dist_ok, dist_gn, _ = dist_optimizer.step() assert dist_ok, "Dist optimizer step failed" assert dist_gn is not None and dist_gn > 0, ( - f"Dist grad_norm={dist_gn} — three-phase schedule produced zero encoder/LLM grads." + f"Dist grad_norm={dist_gn} — three-phase schedule produced zero grads." ) - # ── Ref forward/backward: plain no-pipelining schedule ─────────────── + def _sum_loss(loss_mask_unused, output_tensor): + if output_tensor is None: + return ( + torch.tensor(0.0, device='cuda', requires_grad=True), + {'loss_reduced': 0.0}, + ) + loss = output_tensor.float().sum() + return loss, {'loss_reduced': loss.detach().item()} + def _ref_forward_step(data_iterator, model, *args): batch = next(data_iterator) output_tensor, loss_mask = model( @@ -1005,12 +239,12 @@ def _ref_forward_step(data_iterator, model, *args): position_ids=batch['position_ids'], modality_inputs=batch['modality_inputs'], ) - return output_tensor, partial(_sum_loss_func, loss_mask) + return output_tensor, partial(_sum_loss, loss_mask) ref_optimizer.zero_grad() schedule.forward_backward_no_pipelining( forward_step_func=_ref_forward_step, - data_iterator=_ListIter(ref_per_rank_batches), + data_iterator=_BatchIterator(ref_batches), model=[ref_model], num_microbatches=num_microbatches, seq_length=seq_length, @@ -1022,8 +256,9 @@ def _ref_forward_step(data_iterator, model, *args): assert ref_ok, "Ref optimizer step failed" assert ref_gn is not None and ref_gn > 0, f"Ref grad_norm={ref_gn}" - # Main oracle: post-step encoder shards match 1:1 (same enc_tp, enc_dp). - _assert_encoder_shards_match( + # bf16 accumulation drift from the differing LLM paths (1F1B vs. + # no-pipelining) requires slightly looser tolerances than bf16 rounding. + _assert_encoder_weights_match( ref_model.modality_submodules[encoder_name].module, dist_model.modality_submodules[encoder_name].module, rtol=1e-2, @@ -1048,51 +283,6 @@ def teardown_class(cls): def teardown_method(self): destroy_all_grids() - def test_fan_in_enc_tp2_dp4_llm_tp2_dp2_pp2(self): - """Fan-in: encoder TP2/DP4 → LLM TP2/DP2/PP2.""" - if self.world_size != 8: - pytest.skip(f"Requires 8 GPUs, got {self.world_size}") - run_colocated_pp_test( - encoder_tp=2, encoder_dp=4, llm_tp=2, llm_pp=2, llm_dp=2, num_microbatches=4 - ) - - def test_equal_dp_enc_tp4_dp2_llm_tp2_dp2_pp2(self): - """Equal DP: encoder TP4/DP2 → LLM TP2/DP2/PP2 (enc_dp == llm_dp).""" - if self.world_size != 8: - pytest.skip(f"Requires 8 GPUs, got {self.world_size}") - run_colocated_pp_test( - encoder_tp=4, encoder_dp=2, llm_tp=2, llm_pp=2, llm_dp=2, num_microbatches=4 - ) - - def test_fan_in_with_grad_acc(self): - """Fan-in with gradient accumulation (num_microbatches > pp_size).""" - if self.world_size != 8: - pytest.skip(f"Requires 8 GPUs, got {self.world_size}") - run_colocated_pp_test( - encoder_tp=2, - encoder_dp=4, - llm_tp=2, - llm_pp=2, - llm_dp=2, - num_microbatches=6, # > pp_size=2, tests grad accumulation - ) - - def test_fan_in_enc_tp1_dp8_llm_tp4_dp1_pp2(self): - """Fan-in extreme: encoder TP1/DP8 → LLM TP4/DP1/PP2.""" - if self.world_size != 8: - pytest.skip(f"Requires 8 GPUs, got {self.world_size}") - # micro_batch_size must be >= fan-in scale (enc_dp/llm_dp = 8/1 = 8) - # to avoid zero-sized slices in _slice_for_encoder_dp. - run_colocated_pp_test( - encoder_tp=1, - encoder_dp=8, - llm_tp=4, - llm_pp=2, - llm_dp=1, - micro_batch_size=8, - num_microbatches=4, - ) - @pytest.mark.parametrize( "num_microbatches", [2, 4], @@ -1101,49 +291,17 @@ def test_fan_in_enc_tp1_dp8_llm_tp4_dp1_pp2(self): def test_pp_matches_pp1_equal_dp_reference(self, num_microbatches): """Post-step encoder weights under PP>1 match equal-DP PP=1 reference. - This is the real correctness oracle for PR-9's three-phase schedule. - Parallels the PR-10 oracle (``test_mimo_colocated_correctness.py``), - extended with PP-aware LLM weight reshaping so the reference has - ``llm_pp=1`` (the only config compatible with equal-DP on a fixed - rank count: ``enc_tp * enc_dp == llm_tp * llm_pp * llm_dp`` and - ``enc_dp == llm_dp`` force ``llm_pp = enc_tp / llm_tp``; with - ``enc_tp == llm_tp`` this means ``llm_pp = 1``). - - * Dist: fan-in + PP>1 (the config under test). Runs through - ``colocated_forward_backward_with_pp`` (three-phase schedule). - * Ref: ``enc_tp=dist_enc_tp``, ``enc_dp=dist_enc_dp``, - ``llm_tp=dist_enc_tp``, ``llm_dp=dist_enc_dp``, ``llm_pp=1``. - Identity bridge (``BridgeDirection.EQUAL``); runs through - ``forward_backward_no_pipelining``. - - Both use ``gradient_reduce_div_factor=1`` with an identical - ``.sum()`` loss (matching the loss in ``colocated_schedule.py``), - so the DDP reduction yields the DP=1 aggregate gradient on every - encoder shard regardless of LLM layout. Encoder TP matches across - the two models, so shards line up 1:1. LLM TP matches too; LLM - weights differ only in PP partitioning, which the copy helper - below reshapes. Under correct PP>1 encoder grad accumulation + - broadcast, one Adam step yields shard-wise equal post-step encoder - weights modulo bf16 accumulation drift. - - If the three-phase schedule mishandles any of: encoder grad - accumulation across microbatches, PP-stage-0→stage-N broadcast, - or the detach/reattach boundary, encoder shards diverge and this - test fails. - - Two parametrized cases: - * ``num_mb_eq_pp`` (num_microbatches=2, pp=2): the minimal - pipeline with one microbatch per PP stage. No grad - accumulation across microbatches. - * ``num_mb_gt_pp_grad_acc`` (num_microbatches=4, pp=2): 1F1B - pipeline runs with 2 microbatches per stage, so encoder - embedding views for 4 microbatches all accumulate into the - same ``detached_full.grad`` via PyTorch view-gradient - semantics. If the microbatch slicing in - ``_build_lm_microbatches`` does not produce proper views of - ``detached_full`` (e.g., accidentally cloning), grad - accumulation across microbatches is silently dropped and the - encoder shards diverge. + Dist runs ``colocated_forward_backward_with_pp`` (three-phase + schedule with PP=2 on the LLM); ref runs + ``forward_backward_no_pipelining`` with matching encoder TP/DP and + LLM PP=1. Under correct PP>1 encoder grad accumulation + broadcast, + one Adam step yields shard-wise equal post-step encoder weights + (modulo bf16 drift). + + The ``num_mb_gt_pp_grad_acc`` case runs more microbatches than PP + stages so encoder embedding views for every microbatch must + accumulate into the same ``detached_full.grad`` via PyTorch + view-gradient semantics — a regression there surfaces here. """ if self.world_size != 8: pytest.skip(f"Requires 8 GPUs, got {self.world_size}") From 2b6005abf7e34eaad1f27c141776b3d97d1e3ba1 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Wed, 22 Apr 2026 03:11:44 +0000 Subject: [PATCH 10/11] Address PR #9 review comments (NMFW-19) - Drop unused lm_is_first_pp_stage on MimoModel and tighten the PP-flag derivation; lm_has_pp is the only flag actually consumed by the three-phase schedule. - Drop MockGrid.shape + _pp_rank/_pp_size in test_mimo_model: unused after _colocated switched to grid.get_pg('pp') for PP-stage derivation. - Tighten the Phase-1 detach comment in colocated_schedule to a single "why". - Add PP-aware LLM weight parity check (_assert_llm_weights_match_pp_aware) alongside the existing encoder check in test_mimo_colocated_pp; dist PP>1 LLM shards must match the PP=1 ref via the same layer-index remap used for init. Co-Authored-By: Claude Opus 4.7 --- .../core/models/mimo/colocated_schedule.py | 5 +- megatron/core/models/mimo/model/base.py | 3 -- .../models/test_mimo_colocated_pp.py | 54 +++++++++++++++++++ tests/unit_tests/models/test_mimo_model.py | 6 --- 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py index e7e9598c002..d53cf9c95b1 100644 --- a/megatron/core/models/mimo/colocated_schedule.py +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -58,9 +58,8 @@ def colocated_forward_backward_with_pp( enc_out = mimo_model.encode_and_communicate({encoder_name: full_encoder_input}) - # Detach: sever autograd link to encoder so Phase 2 has no encoder collectives. - # Microbatch slices are views into detached_full — their .grad accumulates - # into detached_full.grad automatically via PyTorch's view gradient semantics. + # Detach so Phase 2 runs no encoder collectives; microbatch views accumulate + # .grad into detached_full.grad automatically. detached_full = {k: v.detach().requires_grad_(True) for k, v in enc_out.items()} lm_data = _build_lm_microbatches(detached_full, all_batches, num_microbatches) diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 83f3c50e6c4..0af467276b3 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -67,13 +67,10 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # in TP/DP within those ranks. self._build_colocated_communicators() - # LLM PP>1 is already reflected in self.role; expose convenience flags - # for the three-phase colocated schedule. lang_info = self.role.modules.get(MIMO_LANGUAGE_MODULE_KEY) self.lm_has_pp = lang_info is not None and not ( lang_info.is_first_stage and lang_info.is_last_stage ) - self.lm_is_first_pp_stage = lang_info is None or lang_info.is_first_stage # Use special token IDs from the config self.special_token_ids = ( diff --git a/tests/unit_tests/models/test_mimo_colocated_pp.py b/tests/unit_tests/models/test_mimo_colocated_pp.py index 7d6143b2b2f..0cb73b4d8f0 100644 --- a/tests/unit_tests/models/test_mimo_colocated_pp.py +++ b/tests/unit_tests/models/test_mimo_colocated_pp.py @@ -38,6 +38,51 @@ from tests.unit_tests.test_utilities import Utils +def _assert_llm_weights_match_pp_aware( + ref_module, dist_module, pp_rank, pp_size, num_layers, rtol=1e-2, atol=1e-2 +): + """Assert dist LLM shards match ref (PP=1) via the same layer-index remap + ``_copy_llm_params_pp_aware`` uses. Non-layer params (embedding, + final_layernorm, output_layer) only exist on stages that own them. + """ + layers_per_stage = num_layers // pp_size + layer_rx = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') + ref_params = dict(ref_module.named_parameters()) + + mismatches = [] + for name, dist_param in dist_module.named_parameters(): + m = layer_rx.match(name) + if m: + prefix, local_idx_s, suffix = m.groups() + global_idx = pp_rank * layers_per_stage + int(local_idx_s) + ref_name = f"{prefix}{global_idx}{suffix}" + else: + ref_name = name + assert ref_name in ref_params, ( + f"LLM param '{name}' maps to ref '{ref_name}' which does not exist " + f"(ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)}." + ) + try: + torch.testing.assert_close( + dist_param.data, ref_param.data, rtol=rtol, atol=atol + ) + except AssertionError as e: + mismatches.append((name, ref_name, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n} -> {rn}: {msg}" for n, rn, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} LLM param(s) diverged between " + f"PP>1 dist model and PP=1 reference:\n{details}" + ) + + def _copy_llm_params_pp_aware(ref_module, dist_module, pp_rank, pp_size, num_layers): """Copy LLM params ref (PP=1) → dist (PP>=1) with layer-index remapping. @@ -264,6 +309,15 @@ def _ref_forward_step(data_iterator, model, *args): rtol=1e-2, atol=1e-2, ) + _assert_llm_weights_match_pp_aware( + ref_model.language_model.module, + dist_model.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=dist_llm_pp, + num_layers=num_layers, + rtol=1e-2, + atol=1e-2, + ) @pytest.mark.skipif( diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index 325f8cb305e..3ab65a3616e 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -453,14 +453,8 @@ def __init__(self, rank_offset=0, size=1, dim_names=None, pp_rank=0, pp_size=1): self.rank_offset = rank_offset self.size = size self.dim_names = dim_names or [] - self._pp_rank = pp_rank - self._pp_size = pp_size self._pp_group = MockProcessGroup(pp_rank, pp_size) - @property - def shape(self): - return tuple(self._pp_size if d == "pp" else 1 for d in self.dim_names) - def get_pg(self, dims): if dims == "pp": return self._pp_group From ed63e63532568cece0c56efc4352b03b291c4aa7 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Wed, 22 Apr 2026 03:31:12 +0000 Subject: [PATCH 11/11] colocated_schedule: thread num_tokens for calculate_per_token_loss=True MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #10 replaced the ad-hoc gradient_reduce_div_factor DDP knob with calculate_per_token_loss=True on both sub-model configs plus a custom finalize hook that divides grads by the global valid-token count. The three-phase PP schedule now has to forward the schedule's total_num_tokens to the deferred finalize call, otherwise the hook's assertion fails and per-token normalization never happens on the encoder/LLM grads. * _loss_func now returns the 3-tuple (local_sum, local_num_tokens, log_dict) contract the schedule expects when per-token loss is on. * _deferred_finalize swaps the finalize hook with a capturing stub that records the num_tokens the inner schedule would have passed; after Phase 3, we invoke the original finalize with the captured value. test_mimo_colocated_pp: adopt per-token-loss wiring, add PP broadcast _wire_training_hooks from the PR #10 correctness test only all-reduces num_tokens over the LLM DP group. With LLM PP>1, non-last PP stages see num_tokens=0 from the inner schedule (loss runs only on the last stage), so the DP sum would land at N_last_stage instead of N_global and encoder/LLM grads would end up scaled differently per PP stage. _wire_pp_training_hooks broadcasts num_tokens from the last LLM PP rank first, then all-reduces across DP — every rank arrives at the same N_global. The PP test also drops the removed gradient_reduce_div_factor kwarg, switches both models to fp32 / no-bias / no-dropout for exact comparison, and uses the 3-tuple loss shape on the ref forward path. --- .../core/models/mimo/colocated_schedule.py | 63 +++++++---- .../models/test_mimo_colocated_pp.py | 100 +++++++++++++++--- 2 files changed, 131 insertions(+), 32 deletions(-) diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py index d53cf9c95b1..e1fe93aec0a 100644 --- a/megatron/core/models/mimo/colocated_schedule.py +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -78,10 +78,12 @@ def _lm_forward_step(data_iterator_unused, model, *args): ) return output_tensor, partial(_loss_func, cached['loss_mask']) - # Swap in a no-op finalize so the inner PP schedule does not run DDP - # grad sync before Phase 3 has produced encoder grads. We invoke the - # original finalize once after Phase 3 (see below). - with _deferred_finalize(mimo_model.config) as original_finalize: + # Swap in a capturing finalize so the inner PP schedule does not run DDP + # grad sync before Phase 3 has produced encoder grads. The capture also + # records ``num_tokens`` that the inner schedule would have passed — we + # forward it to the original finalize after Phase 3 so per-token-loss + # configs see the correct global divisor. + with _deferred_finalize(mimo_model.config) as (original_finalize, capture): losses = schedules.forward_backward_pipelining_without_interleaving( forward_step_func=_lm_forward_step, data_iterator=cache_iter, @@ -109,7 +111,7 @@ def _lm_forward_step(data_iterator_unused, model, *args): if not forward_only and original_finalize is not None: original_finalize( [mimo_model], - None, + capture.num_tokens, pg_collection=schedule_kwargs.get('pg_collection'), force_all_reduce=False, ) @@ -224,30 +226,53 @@ def _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first): def _loss_func(loss_mask, output_tensor): - """Default loss function for the LLM pipeline.""" - if output_tensor is None: - return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} - loss = output_tensor.float().sum() - return loss, {'loss_reduced': loss.detach().item()} + """Default loss function for the LLM pipeline. + Returns the 3-tuple ``(local_sum, local_num_tokens, log_dict)`` contract + expected when ``calculate_per_token_loss=True`` is set on the + TransformerConfig. When it is not set, the schedule divides + ``local_sum`` by ``local_num_tokens`` (clamped to 1), so the 3-tuple + form is also safe for standard per-microbatch-mean configs. + """ + if output_tensor is None: + zero_loss = torch.tensor(0.0, device='cuda', requires_grad=True) + zero_count = torch.tensor(0, device='cuda', dtype=torch.int) + return zero_loss, zero_count, {'loss_reduced': 0.0} + masked = output_tensor.float() * loss_mask.float() + local_sum = masked.sum() + local_num_tokens = loss_mask.float().sum().to(torch.int) + return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()} + + +class _CapturingFinalize: + """Capture the ``num_tokens`` the inner PP schedule would have passed. + + The three-phase schedule defers grad finalization until after Phase 3 + runs encoder backward. Replacing the config's ``finalize_model_grads_func`` + with this object absorbs the inner schedule's invocation and stores + ``num_tokens`` so the post-Phase-3 call to the original finalize can + forward it — required for ``calculate_per_token_loss=True`` configs + whose finalize hook divides by the global valid-token count. + """ -def _noop_finalize(*args, **kwargs): - """Placeholder used to suppress the inner PP schedule's finalize call. + def __init__(self): + self.num_tokens = None - The three-phase schedule needs to defer grad finalization until after - Phase 3 runs the encoder backward. See ``colocated_forward_backward_with_pp``. - """ - return None + def __call__(self, model_list, num_tokens, *args, **kwargs): + self.num_tokens = num_tokens + return None @contextmanager def _deferred_finalize(config): """Suppress the PP schedule's end-of-run DDP grad sync; yield the - original so callers can invoke it once after Phase 3. + original finalize and a capture object so callers can invoke the + original (with the captured ``num_tokens``) once after Phase 3. """ original = config.finalize_model_grads_func - config.finalize_model_grads_func = _noop_finalize + capture = _CapturingFinalize() + config.finalize_model_grads_func = capture try: - yield original + yield original, capture finally: config.finalize_model_grads_func = original diff --git a/tests/unit_tests/models/test_mimo_colocated_pp.py b/tests/unit_tests/models/test_mimo_colocated_pp.py index 0cb73b4d8f0..87280c45911 100644 --- a/tests/unit_tests/models/test_mimo_colocated_pp.py +++ b/tests/unit_tests/models/test_mimo_colocated_pp.py @@ -16,12 +16,14 @@ import megatron.core.pipeline_parallel.schedules as schedule from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp from megatron.core.models.mimo.optimizer import get_mimo_optimizer from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.transformer.enums import ModelType from tests.unit_tests.models.test_mimo_1f1b_schedule import ( + build_no_sync_func, create_all_embedding_groups, create_hypercomm_grid, destroy_all_grids, @@ -38,6 +40,68 @@ from tests.unit_tests.test_utilities import Utils +def _wire_pp_training_hooks(mimo_model, language_pg, vision_pg, llm_grid): + """PP-aware variant of ``_wire_training_hooks`` for LLM PP>1. + + ``calculate_per_token_loss=True`` on both sub-model configs pins DDP's + gradient_scaling_factor to 1.0 (pure SUM across DP). But with LLM PP>1, + the inner schedule only populates ``num_tokens`` on the last PP stage; + non-last stages get 0. Before the DP all-reduce, this helper broadcasts + ``num_tokens`` from the last LLM PP rank to earlier ones so every rank + arrives at the same ``N_global`` and the per-token divisor lands + uniformly on encoder + LLM grads. + """ + + no_sync_func = build_no_sync_func(mimo_model) + pp_group = llm_grid.get_pg("pp") + + def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs): + assert num_tokens is not None, ( + "finalize_grads_func expects calculate_per_token_loss=True on the " + "TransformerConfig so the schedule forwards total_num_tokens; got None." + ) + + if pp_group.size() > 1: + last_rank = dist.get_global_rank(pp_group, pp_group.size() - 1) + dist.broadcast(num_tokens, src=last_rank, group=pp_group) + + llm_dp_pg = language_pg.dp_cp if language_pg.dp_cp is not None else language_pg.dp + dist.all_reduce(num_tokens, group=llm_dp_pg, op=dist.ReduceOp.SUM) + n_global = num_tokens.item() + + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], + num_tokens=None, + pg_collection=language_pg, + force_all_reduce=force_all_reduce, + ) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + finalize_model_grads( + [submodule], + num_tokens=None, + pg_collection=vision_pg, + force_all_reduce=force_all_reduce, + ) + + if n_global > 0: + inv = 1.0 / n_global + if mimo_model.language_model is not None: + mimo_model.language_model.scale_gradients(inv) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + submodule.scale_gradients(inv) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + def _assert_llm_weights_match_pp_aware( ref_module, dist_module, pp_rank, pp_size, num_layers, rtol=1e-2, atol=1e-2 ): @@ -153,10 +217,9 @@ def _run_pp_weight_oracle( ref_per_rank_mbs = global_batch_size // ref_llm_dp ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=False, + overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True, - gradient_reduce_div_factor=1, ) dist_enc_grid = create_hypercomm_grid( @@ -183,6 +246,10 @@ def _run_pp_weight_oracle( vocab_size=vocab_size, seq_len=seq_length, ddp_config=ddp_config, + bf16=False, + bias=False, + dropout=False, + per_token_loss=True, ) dist_model.model_type = ModelType.encoder_or_decoder @@ -196,6 +263,10 @@ def _run_pp_weight_oracle( vocab_size=vocab_size, seq_len=seq_length, ddp_config=ddp_config, + bf16=False, + bias=False, + dropout=False, + per_token_loss=True, ) ref_model.model_type = ModelType.encoder_or_decoder @@ -213,7 +284,7 @@ def _run_pp_weight_oracle( num_layers=num_layers, ) - _wire_training_hooks(dist_model, dist_lang_pg, dist_vis_pg) + _wire_pp_training_hooks(dist_model, dist_lang_pg, dist_vis_pg, dist_llm_grid) _wire_training_hooks(ref_model, ref_lang_pg, ref_vis_pg) opt_config = OptimizerConfig( @@ -221,7 +292,7 @@ def _run_pp_weight_oracle( lr=1e-4, weight_decay=0.01, clip_grad=1.0, - bf16=True, + bf16=False, use_distributed_optimizer=True, ) dist_optimizer = get_mimo_optimizer(dist_model, opt_config) @@ -266,14 +337,16 @@ def _run_pp_weight_oracle( f"Dist grad_norm={dist_gn} — three-phase schedule produced zero grads." ) - def _sum_loss(loss_mask_unused, output_tensor): + def _sum_loss(loss_mask, output_tensor): + """Per-token-loss 3-tuple matching ``_wire_training_hooks`` contract.""" if output_tensor is None: - return ( - torch.tensor(0.0, device='cuda', requires_grad=True), - {'loss_reduced': 0.0}, - ) - loss = output_tensor.float().sum() - return loss, {'loss_reduced': loss.detach().item()} + zero_loss = torch.tensor(0.0, device='cuda', requires_grad=True) + zero_count = torch.tensor(0, device='cuda', dtype=torch.int) + return zero_loss, zero_count, {'loss_reduced': 0.0} + masked = output_tensor.float() * loss_mask.float() + local_sum = masked.sum() + local_num_tokens = loss_mask.float().sum().to(torch.int) + return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()} def _ref_forward_step(data_iterator, model, *args): batch = next(data_iterator) @@ -301,8 +374,9 @@ def _ref_forward_step(data_iterator, model, *args): assert ref_ok, "Ref optimizer step failed" assert ref_gn is not None and ref_gn > 0, f"Ref grad_norm={ref_gn}" - # bf16 accumulation drift from the differing LLM paths (1F1B vs. - # no-pipelining) requires slightly looser tolerances than bf16 rounding. + # LLM forward differs between 1F1B (dist, PP>1) and no-pipelining (ref), + # and TP shards may accumulate in a different order; keep tolerances + # loose enough to absorb that drift even in fp32. _assert_encoder_weights_match( ref_model.modality_submodules[encoder_name].module, dist_model.modality_submodules[encoder_name].module,