diff --git a/megatron/core/models/mimo/colocated_schedule.py b/megatron/core/models/mimo/colocated_schedule.py new file mode 100644 index 00000000000..e1fe93aec0a --- /dev/null +++ b/megatron/core/models/mimo/colocated_schedule.py @@ -0,0 +1,278 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Three-phase schedule for colocated MIMO training with LLM PP>1. + +Phase 1: Encoder forward + communicate for the full batch (all ranks synchronized). +Phase 2: LLM 1F1B pipeline with detached encoder embeddings sliced per microbatch. +Phase 3: Encoder backward for the full batch (all ranks synchronized). + +Encoder runs on all ranks (PP=1) and its TP/DP collectives require all ranks +to participate simultaneously. The 1F1B pipeline staggers ranks across PP stages, +so encoder collectives cannot run inside the pipeline. The three-phase design +separates encoder (synchronized) from LLM (pipelined) by detaching the autograd +graph at the encoder-LLM boundary. +""" + +from contextlib import contextmanager +from functools import partial +from typing import Optional + +import torch +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.pipeline_parallel import schedules + + +def colocated_forward_backward_with_pp( + mimo_model, + data_iterator, + num_microbatches: int, + encoder_grid: Optional[HyperCommGrid] = None, + llm_grid: Optional[HyperCommGrid] = None, + encoder_name: str = "images", + forward_only: bool = False, + **schedule_kwargs, +): + """Three-phase colocated training: encoder batch -> LLM pipeline -> encoder backward. + + Args: + mimo_model: MimoModel with colocated communicators and lm_has_pp=True. + data_iterator: Yields dicts with input_ids, labels, etc. + num_microbatches: Number of microbatches for the LLM pipeline. + encoder_grid: Encoder HyperCommGrid (for DP fan-in slicing). + llm_grid: LLM HyperCommGrid (for PP group). 
+ encoder_name: Modality name for the encoder (e.g., "images"). + forward_only: Skip backward passes if True. + **schedule_kwargs: Passed to forward_backward_pipelining_without_interleaving. + Must include p2p_communicator, pg_collection, seq_length, micro_batch_size. + """ + pp_group = llm_grid.get_pg("pp") if llm_grid and 'pp' in llm_grid.dim_names else None + is_pp_first = pp_group is None or pp_group.rank() == 0 + + # ── Phase 1: Encoder forward on full batch (one pass) ──────────────── + # All ranks participate (encoder is PP=1, communicate is collective). + all_batches = [next(data_iterator) for _ in range(num_microbatches)] + full_encoder_input = _concat_encoder_inputs(all_batches, encoder_name) + _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid) + + enc_out = mimo_model.encode_and_communicate({encoder_name: full_encoder_input}) + + # Detach so Phase 2 runs no encoder collectives; microbatch views accumulate + # .grad into detached_full.grad automatically. + detached_full = {k: v.detach().requires_grad_(True) for k, v in enc_out.items()} + lm_data = _build_lm_microbatches(detached_full, all_batches, num_microbatches) + + # ── Phase 2: LLM 1F1B pipeline ────────────────────────────────────── + # Only LLM P2P communication (within PP group). No encoder collectives. + cache_iter = iter(lm_data) + + def _lm_forward_step(data_iterator_unused, model, *args): + cached = next(cache_iter) + output_tensor, loss_mask = model( + input_ids=cached['input_ids'], + labels=cached['labels'], + loss_mask=cached['loss_mask'], + position_ids=cached['position_ids'], + encoder_embeddings=cached['encoder_embeddings'], + ) + return output_tensor, partial(_loss_func, cached['loss_mask']) + + # Swap in a capturing finalize so the inner PP schedule does not run DDP + # grad sync before Phase 3 has produced encoder grads. 
The capture also + # records ``num_tokens`` that the inner schedule would have passed — we + # forward it to the original finalize after Phase 3 so per-token-loss + # configs see the correct global divisor. + with _deferred_finalize(mimo_model.config) as (original_finalize, capture): + losses = schedules.forward_backward_pipelining_without_interleaving( + forward_step_func=_lm_forward_step, + data_iterator=cache_iter, + model=[mimo_model], + num_microbatches=num_microbatches, + forward_only=forward_only, + **schedule_kwargs, + ) + + # ── Phase 3: Encoder backward (one pass, all ranks sync) ──────────── + # detached_full.grad was populated by Phase 2's per-microbatch LLM backward + # (accumulated across microbatch view slices on PP stage 0). + # Broadcast to PP stage 1+ then run one encoder backward for the full batch. + if not forward_only and enc_out: + _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first) + for key in enc_out: + grad = detached_full[key].grad + if grad is not None: + torch.autograd.backward(enc_out[key], grad_tensors=grad) + + # Single post-Phase-3 finalize: reduces LLM grads (from Phase 2) and + # encoder grads (from Phase 3) together. Without this call, encoder + # grads remain local to each rank and Adam steps on un-reduced grads, + # causing silent divergence from the equal-DP reference. 
+ if not forward_only and original_finalize is not None: + original_finalize( + [mimo_model], + capture.num_tokens, + pg_collection=schedule_kwargs.get('pg_collection'), + force_all_reduce=False, + ) + + return losses + + +# ── Helpers ────────────────────────────────────────────────────────────── + + +def _concat_encoder_inputs(all_batches, encoder_name): + """Concatenate encoder inputs from all microbatches along batch dim (dim 1).""" + first = all_batches[0] + result = {} + if not (first.get('modality_inputs') and encoder_name in first['modality_inputs']): + return result + for enc_name in first['modality_inputs'][encoder_name]: + result[enc_name] = {} + for key in first['modality_inputs'][encoder_name][enc_name]: + vals = [ + b['modality_inputs'][encoder_name][enc_name][key] + for b in all_batches + if b.get('modality_inputs') and encoder_name in b['modality_inputs'] + ] + tensors = [v for v in vals if isinstance(v, torch.Tensor)] + result[enc_name][key] = torch.cat(tensors, dim=1) if tensors else vals[0] + return result + + +def _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid): + """Slice concatenated encoder input for fan-in (enc_dp > llm_dp).""" + if encoder_grid is None or llm_grid is None: + return + enc_dp = encoder_grid.get_pg("dp").size() + llm_dp = llm_grid.get_pg("dp").size() + if enc_dp <= llm_dp: + return + scale = enc_dp // llm_dp + slot = encoder_grid.get_pg("dp").rank() % scale + for enc_name in full_encoder_input: + for key, tensor in full_encoder_input[enc_name].items(): + if isinstance(tensor, torch.Tensor) and tensor.ndim >= 2: + bs = tensor.shape[1] + ss = bs // scale + if ss == 0: + raise ValueError( + f"Encoder fan-in produces zero-sized batch: " + f"total_batch={bs}, scale={scale}. Increase micro_batch_size." 
+ ) + full_encoder_input[enc_name][key] = tensor[ + :, slot * ss : (slot + 1) * ss + ].contiguous() + + +def _build_lm_microbatches(detached_full, all_batches, num_microbatches): + """Slice detached encoder output into per-microbatch views for the LLM pipeline.""" + if not detached_full: + # Text-only batch: no encoder embeddings to slice + return [ + { + 'encoder_embeddings': {}, + 'input_ids': all_batches[mb].get('input_ids'), + 'labels': all_batches[mb].get('labels'), + 'loss_mask': all_batches[mb].get('loss_mask'), + 'position_ids': all_batches[mb].get('position_ids'), + } + for mb in range(num_microbatches) + ] + + sample = next(iter(detached_full.values())) + batch_dim = 1 if sample.ndim == 3 else 0 + total_batch = sample.shape[batch_dim] + assert total_batch % num_microbatches == 0, ( + f"Encoder output batch ({total_batch}) must be divisible " + f"by num_microbatches ({num_microbatches})" + ) + mb_size = total_batch // num_microbatches + + lm_data = [] + for mb in range(num_microbatches): + s, e = mb * mb_size, (mb + 1) * mb_size + mb_enc = {} + for k, v in detached_full.items(): + mb_enc[k] = v[:, s:e, :] if v.ndim == 3 else v[s:e, :] + lm_data.append( + { + 'encoder_embeddings': mb_enc, + 'input_ids': all_batches[mb].get('input_ids'), + 'labels': all_batches[mb].get('labels'), + 'loss_mask': all_batches[mb].get('loss_mask'), + 'position_ids': all_batches[mb].get('position_ids'), + } + ) + return lm_data + + +def _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first): + """Broadcast encoder gradient from PP stage 0 to stage 1+ ranks.""" + if pp_group is None or pp_group.size() <= 1: + return + src = dist.get_global_rank(pp_group, 0) + for key in enc_out: + if is_pp_first: + assert ( + detached_full[key].grad is not None + ), f"No encoder gradient on PP stage 0 for '{key}'" + dist.broadcast(detached_full[key].grad, src=src, group=pp_group) + else: + grad = torch.zeros_like(detached_full[key]) + dist.broadcast(grad, src=src, 
group=pp_group) + detached_full[key].grad = grad + + +def _loss_func(loss_mask, output_tensor): + """Default loss function for the LLM pipeline. + + Returns the 3-tuple ``(local_sum, local_num_tokens, log_dict)`` contract + expected when ``calculate_per_token_loss=True`` is set on the + TransformerConfig. When it is not set, the schedule divides + ``local_sum`` by ``local_num_tokens`` (clamped to 1), so the 3-tuple + form is also safe for standard per-microbatch-mean configs. + """ + if output_tensor is None: + zero_loss = torch.tensor(0.0, device='cuda', requires_grad=True) + zero_count = torch.tensor(0, device='cuda', dtype=torch.int) + return zero_loss, zero_count, {'loss_reduced': 0.0} + masked = output_tensor.float() * loss_mask.float() + local_sum = masked.sum() + local_num_tokens = loss_mask.float().sum().to(torch.int) + return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()} + + +class _CapturingFinalize: + """Capture the ``num_tokens`` the inner PP schedule would have passed. + + The three-phase schedule defers grad finalization until after Phase 3 + runs encoder backward. Replacing the config's ``finalize_model_grads_func`` + with this object absorbs the inner schedule's invocation and stores + ``num_tokens`` so the post-Phase-3 call to the original finalize can + forward it — required for ``calculate_per_token_loss=True`` configs + whose finalize hook divides by the global valid-token count. + """ + + def __init__(self): + self.num_tokens = None + + def __call__(self, model_list, num_tokens, *args, **kwargs): + self.num_tokens = num_tokens + return None + + +@contextmanager +def _deferred_finalize(config): + """Suppress the PP schedule's end-of-run DDP grad sync; yield the + original finalize and a capture object so callers can invoke the + original (with the captured ``num_tokens``) once after Phase 3. 
+ """ + original = config.finalize_model_grads_func + capture = _CapturingFinalize() + config.finalize_model_grads_func = capture + try: + yield original, capture + finally: + config.finalize_model_grads_func = original diff --git a/megatron/core/models/mimo/comm/colocated_communicator.py b/megatron/core/models/mimo/comm/colocated_communicator.py index dd0241d8f80..b501d911bbb 100644 --- a/megatron/core/models/mimo/comm/colocated_communicator.py +++ b/megatron/core/models/mimo/comm/colocated_communicator.py @@ -128,8 +128,9 @@ def _validate_grids(self): f"src={self.src_grid.rank_offset}, dest={self.dest_grid.rank_offset}" ) - # Per-grid dim checks: tp/dp required; pp and cp (if present) must be 1. - # CP>1 also corrupts dp_idx when iterating get_rank_enum(['tp']) groups. + # Per-grid dim checks: tp/dp required; cp (if present) must be 1. + # Src PP must be 1; dest PP>1 is allowed. CP>1 corrupts dp_idx when + # iterating get_rank_enum(['tp']) groups. for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]: for required in ('tp', 'dp'): if required not in grid.dim_names: @@ -137,14 +138,18 @@ def _validate_grids(self): f"{name} grid must have '{required}' dimension, " f"got dim_names={grid.dim_names}" ) - for singleton in ('pp', 'cp'): - if singleton in grid.dim_names: - size = grid.shape[grid.dim_names.index(singleton)] - if size != 1: - raise ValueError( - f"{name} {singleton.upper()} must be 1 for " - f"ColocatedBridgeCommunicator, got {size}" - ) + if 'cp' in grid.dim_names: + cp_size = grid.shape[grid.dim_names.index('cp')] + if cp_size != 1: + raise ValueError( + f"{name} CP must be 1 for ColocatedBridgeCommunicator, got {cp_size}" + ) + if 'pp' in self.src_grid.dim_names: + src_pp = self.src_grid.shape[self.src_grid.dim_names.index('pp')] + if src_pp != 1: + raise ValueError( + f"src PP must be 1 for ColocatedBridgeCommunicator, got {src_pp}" + ) src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')] dest_dp = 
self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 7732ad85253..ed0d7aa8ca4 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -79,7 +79,7 @@ def build( Grids differ → NON_COLOCATED with PP-stage info per module. """ if module_to_grid_map is None or cls._all_grids_colocated(module_to_grid_map): - return cls._colocated(modality_module_names) + return cls._colocated(modality_module_names, module_to_grid_map) return cls._from_grid_map(module_to_grid_map) @staticmethod @@ -91,16 +91,31 @@ def _all_grids_colocated(module_to_grid_map: Dict[str, 'HyperCommGrid']) -> bool ) @classmethod - def _colocated(cls, modality_module_names: List[str]) -> 'RankRole': - """Colocated layout: every module on every rank, PP=1.""" + def _colocated( + cls, + modality_module_names: List[str], + module_to_grid_map: Optional[Dict[str, 'HyperCommGrid']] = None, + ) -> 'RankRole': + """Colocated layout: every module on every rank. + + When a grid map is supplied, per-module stage info is derived from + each grid's pp group (LLM PP>1 is allowed). With no grid map, every + module is both first and last stage. 
+ """ all_module_names = list(modality_module_names) + [MIMO_LANGUAGE_MODULE_KEY] - return cls( - modules={ - name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) - for name in all_module_names - }, - mode=ModuleLayout.COLOCATED, - ) + modules = {} + for name in all_module_names: + grid = module_to_grid_map.get(name) if module_to_grid_map else None + if grid is not None and 'pp' in grid.dim_names: + pp_group = grid.get_pg('pp') + pp_rank, pp_size = pp_group.rank(), pp_group.size() + modules[name] = ModuleStageInfo( + is_first_stage=(pp_rank == 0), + is_last_stage=(pp_rank == pp_size - 1), + ) + else: + modules[name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + return cls(modules=modules, mode=ModuleLayout.COLOCATED) @classmethod def _from_grid_map(cls, module_to_grid_map: Dict[str, HyperCommGrid]) -> 'RankRole': diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 703bc9d9950..0af467276b3 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -67,6 +67,11 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # in TP/DP within those ranks. self._build_colocated_communicators() + lang_info = self.role.modules.get(MIMO_LANGUAGE_MODULE_KEY) + self.lm_has_pp = lang_info is not None and not ( + lang_info.is_first_stage and lang_info.is_last_stage + ) + # Use special token IDs from the config self.special_token_ids = ( mimo_config.special_token_ids.copy() if mimo_config.special_token_ids else {} @@ -318,6 +323,7 @@ def forward( labels: Optional[torch.Tensor] = None, modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None, packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass through the multimodal model. 
@@ -362,6 +368,20 @@ def forward( input_tensors = getattr(self, 'input_tensors', None) if self.role.mode == ModuleLayout.COLOCATED: + if self.lm_has_pp and input_tensors is not None: + # PP>1 non-first stage: hidden states from P2P + lm_result = self._forward_language_module( + input_ids, + position_ids, + attention_mask, + labels, + {MIMO_LANGUAGE_MODULE_KEY: input_tensors}, + ) + # Unwrap dict for P2P (schedule uses plain tensors, not dicts) + if isinstance(lm_result, dict): + lm_result = lm_result[MIMO_LANGUAGE_MODULE_KEY] + return lm_result, loss_mask + return self._forward_all_modules( input_ids, position_ids, @@ -370,6 +390,7 @@ def forward( labels, modality_inputs, packing_kwargs, + encoder_embeddings=encoder_embeddings, ) if self.role.mode == ModuleLayout.NON_COLOCATED: @@ -531,6 +552,22 @@ def _apply_colocated_comms(self, modality_embeddings): ) return modality_embeddings + def encode_and_communicate(self, modality_inputs): + """Run encoder forward + colocated TP/DP transform (collective).""" + modality_embeddings = {} + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) + return modality_embeddings + def _forward_all_modules( self, input_ids: torch.Tensor, @@ -540,6 +577,7 @@ def _forward_all_modules( labels: Optional[torch.Tensor], modality_inputs: Optional[Dict[str, Dict[str, Any]]], packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass when all modules are on all ranks (no multi-module PP). 
@@ -556,26 +594,12 @@ def _forward_all_modules( packed_seq_params.qkv_format = 'thd' logger.debug(f"Packed sequence parameters: {packed_seq_params}") - # 1. Process each modality to get embeddings - modality_embeddings = {} - - for modality_name, submodule in self.modality_submodules.items(): - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - modality_embeddings[modality_name] = embeddings - logger.debug( - f"Generated embeddings for {modality_name} with shape {embeddings.shape}" - ) - - # Apply colocated communication if configured (no-op when colocated_comms is empty) - if self.colocated_comms: - modality_embeddings = self._apply_colocated_comms(modality_embeddings) + if encoder_embeddings is not None: + # PP>1 path: encoder forward + communicate already ran in Phase 1; + # reuse the precomputed embeddings for every LLM microbatch. 
+ modality_embeddings = encoder_embeddings + else: + modality_embeddings = self.encode_and_communicate(modality_inputs) # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py index 546a1648fd5..236e5295dde 100644 --- a/tests/unit_tests/models/test_mimo_colocated_communicator.py +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -236,8 +236,8 @@ def test_rank_offset_mismatch(self): "side,dim,expected", [ ("src", "pp", "src PP must be 1"), - ("dest", "pp", "dest PP must be 1"), ("src", "cp", "CP must be 1"), + ("dest", "cp", "CP must be 1"), ], ) def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): @@ -252,6 +252,13 @@ def test_pp_or_cp_gt_one_rejected(self, side, dim, expected): with pytest.raises(ValueError, match=expected): make_comm(src_grid, dest_grid) + def test_dest_pp_gt_one_accepted(self): + # Dest PP>1 is valid: the three-phase colocated schedule handles + # the LLM pipeline orchestration. The bridge only needs src PP=1. + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=2, pp=2, dp=2) + make_comm(src_grid, dest_grid) + def test_dp_not_divisible(self): # 6-rank grids with DP sizes (3 vs 2) that neither divides the other. # Fits inside an 8-rank world (HyperCommGrid enforces size <= world - offset). diff --git a/tests/unit_tests/models/test_mimo_colocated_pp.py b/tests/unit_tests/models/test_mimo_colocated_pp.py new file mode 100644 index 00000000000..87280c45911 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_colocated_pp.py @@ -0,0 +1,443 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""Tests for colocated MIMO training with LLM PP>1. 
+ +Run individually (8 GPUs): + uv run python -m torch.distributed.run --nproc_per_node=8 \ + -m pytest tests/unit_tests/models/test_mimo_colocated_pp.py -v +""" + +import re +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp +from megatron.core.models.mimo.optimizer import get_mimo_optimizer +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator +from megatron.core.transformer.enums import ModelType +from tests.unit_tests.models.test_mimo_1f1b_schedule import ( + build_no_sync_func, + create_all_embedding_groups, + create_hypercomm_grid, + destroy_all_grids, + get_mimo_model, +) +from tests.unit_tests.models.test_mimo_colocated_correctness import ( + _assert_encoder_weights_match, + _BatchIterator, + _copy_ref_params_to_dist, + _generate_and_broadcast_global_batches, + _slice_batch, + _wire_training_hooks, +) +from tests.unit_tests.test_utilities import Utils + + +def _wire_pp_training_hooks(mimo_model, language_pg, vision_pg, llm_grid): + """PP-aware variant of ``_wire_training_hooks`` for LLM PP>1. + + ``calculate_per_token_loss=True`` on both sub-model configs pins DDP's + gradient_scaling_factor to 1.0 (pure SUM across DP). But with LLM PP>1, + the inner schedule only populates ``num_tokens`` on the last PP stage; + non-last stages get 0. Before the DP all-reduce, this helper broadcasts + ``num_tokens`` from the last LLM PP rank to earlier ones so every rank + arrives at the same ``N_global`` and the per-token divisor lands + uniformly on encoder + LLM grads. 
+ """ + + no_sync_func = build_no_sync_func(mimo_model) + pp_group = llm_grid.get_pg("pp") + + def finalize_grads_func(model_list, num_tokens, force_all_reduce=False, **kwargs): + assert num_tokens is not None, ( + "finalize_grads_func expects calculate_per_token_loss=True on the " + "TransformerConfig so the schedule forwards total_num_tokens; got None." + ) + + if pp_group.size() > 1: + last_rank = dist.get_global_rank(pp_group, pp_group.size() - 1) + dist.broadcast(num_tokens, src=last_rank, group=pp_group) + + llm_dp_pg = language_pg.dp_cp if language_pg.dp_cp is not None else language_pg.dp + dist.all_reduce(num_tokens, group=llm_dp_pg, op=dist.ReduceOp.SUM) + n_global = num_tokens.item() + + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], + num_tokens=None, + pg_collection=language_pg, + force_all_reduce=force_all_reduce, + ) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + finalize_model_grads( + [submodule], + num_tokens=None, + pg_collection=vision_pg, + force_all_reduce=force_all_reduce, + ) + + if n_global > 0: + inv = 1.0 / n_global + if mimo_model.language_model is not None: + mimo_model.language_model.scale_gradients(inv) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + submodule.scale_gradients(inv) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + +def _assert_llm_weights_match_pp_aware( + ref_module, dist_module, pp_rank, pp_size, num_layers, rtol=1e-2, atol=1e-2 +): + """Assert dist LLM shards match ref (PP=1) via the same layer-index remap + ``_copy_llm_params_pp_aware`` uses. 
Non-layer params (embedding, + final_layernorm, output_layer) only exist on stages that own them. + """ + layers_per_stage = num_layers // pp_size + layer_rx = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') + ref_params = dict(ref_module.named_parameters()) + + mismatches = [] + for name, dist_param in dist_module.named_parameters(): + m = layer_rx.match(name) + if m: + prefix, local_idx_s, suffix = m.groups() + global_idx = pp_rank * layers_per_stage + int(local_idx_s) + ref_name = f"{prefix}{global_idx}{suffix}" + else: + ref_name = name + assert ref_name in ref_params, ( + f"LLM param '{name}' maps to ref '{ref_name}' which does not exist " + f"(ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)}." + ) + try: + torch.testing.assert_close( + dist_param.data, ref_param.data, rtol=rtol, atol=atol + ) + except AssertionError as e: + mismatches.append((name, ref_name, str(e))) + + if mismatches: + rank = dist.get_rank() + details = "\n".join(f" {n} -> {rn}: {msg}" for n, rn, msg in mismatches) + raise AssertionError( + f"Rank {rank}: {len(mismatches)} LLM param(s) diverged between " + f"PP>1 dist model and PP=1 reference:\n{details}" + ) + + +def _copy_llm_params_pp_aware(ref_module, dist_module, pp_rank, pp_size, num_layers): + """Copy LLM params ref (PP=1) → dist (PP>=1) with layer-index remapping. + + Dist's ``decoder.layers.{local_idx}`` on PP stage ``s`` corresponds to + ref's global layer ``s * layers_per_stage + local_idx``. Non-layer + params (embedding, final_layernorm, output_layer) are only present on + stages that own them and their names match exactly between ref and + dist. Assumes ``dist_llm_tp == ref_llm_tp`` so shards line up 1:1. + """ + assert num_layers % pp_size == 0, ( + f"num_layers={num_layers} not divisible by pp_size={pp_size}; " + f"oracle requires even PP split." 
+ ) + layers_per_stage = num_layers // pp_size + layer_rx = re.compile(r'^(.*decoder\.layers\.)(\d+)(\..*)$') + ref_params = dict(ref_module.named_parameters()) + + with torch.no_grad(): + for name, dist_param in dist_module.named_parameters(): + m = layer_rx.match(name) + if m: + prefix, local_idx_s, suffix = m.groups() + global_idx = pp_rank * layers_per_stage + int(local_idx_s) + ref_name = f"{prefix}{global_idx}{suffix}" + else: + ref_name = name + assert ref_name in ref_params, ( + f"LLM param '{name}' on PP stage {pp_rank} maps to ref name " + f"'{ref_name}' which does not exist in ref (ref has llm_pp=1)." + ) + ref_param = ref_params[ref_name] + assert ref_param.shape == dist_param.shape, ( + f"LLM param '{name}': ref.shape={tuple(ref_param.shape)} != " + f"dist.shape={tuple(dist_param.shape)} — oracle requires " + f"dist_llm_tp == ref_llm_tp." + ) + dist_param.data.copy_(ref_param.data.to(dist_param.dtype)) + + +def _run_pp_weight_oracle( + dist_enc_tp, + dist_enc_dp, + dist_llm_tp, + dist_llm_pp, + dist_llm_dp, + num_microbatches, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size_llm=2, +): + """Drive the dist (PP>1) vs ref (PP=1, equal-DP) weight oracle.""" + import os + + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + encoder_name = "images" + + # Equal-DP reference: same encoder TP/DP; LLM matches encoder TP/DP and + # uses PP=1 (the only PP value compatible with equal-DP on a fixed rank + # count when enc_tp == llm_tp). 
+ ref_enc_tp, ref_enc_dp = dist_enc_tp, dist_enc_dp + ref_llm_tp, ref_llm_pp, ref_llm_dp = dist_enc_tp, 1, dist_enc_dp + + global_batch_size = micro_batch_size_llm * dist_llm_dp + ref_per_rank_mbs = global_batch_size // ref_llm_dp + + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, + bucket_size=10000, + use_distributed_optimizer=True, + ) + + dist_enc_grid = create_hypercomm_grid( + offset=0, tp=dist_enc_tp, cp=1, pp=1, dp=dist_enc_dp + ) + dist_llm_grid = create_hypercomm_grid( + offset=0, tp=dist_llm_tp, cp=1, pp=dist_llm_pp, dp=dist_llm_dp + ) + ref_enc_grid = create_hypercomm_grid( + offset=0, tp=ref_enc_tp, cp=1, pp=1, dp=ref_enc_dp + ) + ref_llm_grid = create_hypercomm_grid( + offset=0, tp=ref_llm_tp, cp=1, pp=ref_llm_pp, dp=ref_llm_dp + ) + create_all_embedding_groups([dist_enc_grid, dist_llm_grid, ref_enc_grid, ref_llm_grid]) + + torch.manual_seed(12345) + dist_model, _, _, dist_lang_pg, dist_vis_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + hidden_size=hidden_size, + num_layers=num_layers, + vocab_size=vocab_size, + seq_len=seq_length, + ddp_config=ddp_config, + bf16=False, + bias=False, + dropout=False, + per_token_loss=True, + ) + dist_model.model_type = ModelType.encoder_or_decoder + + torch.manual_seed(12345) + ref_model, _, _, ref_lang_pg, ref_vis_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=ref_enc_grid, + llm_grid=ref_llm_grid, + hidden_size=hidden_size, + num_layers=num_layers, + vocab_size=vocab_size, + seq_len=seq_length, + ddp_config=ddp_config, + bf16=False, + bias=False, + dropout=False, + per_token_loss=True, + ) + ref_model.model_type = ModelType.encoder_or_decoder + + _copy_ref_params_to_dist( + ref_model.modality_submodules[encoder_name].module, + dist_model.modality_submodules[encoder_name].module, + ref_enc_grid.get_pg("tp"), + dist_enc_grid.get_pg("tp"), + ) + _copy_llm_params_pp_aware( + ref_model.language_model.module, + 
dist_model.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=dist_llm_pp, + num_layers=num_layers, + ) + + _wire_pp_training_hooks(dist_model, dist_lang_pg, dist_vis_pg, dist_llm_grid) + _wire_training_hooks(ref_model, ref_lang_pg, ref_vis_pg) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=False, + use_distributed_optimizer=True, + ) + dist_optimizer = get_mimo_optimizer(dist_model, opt_config) + ref_optimizer = get_mimo_optimizer(ref_model, opt_config) + + torch.manual_seed(99999) + global_batches = _generate_and_broadcast_global_batches( + global_mbs=global_batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + vocab_size=vocab_size, + encoder_name=encoder_name, + num_batches=num_microbatches, + ) + dist_batches = [ + _slice_batch(b, dist_llm_dp, dist_llm_grid.get_pg("dp").rank()) + for b in global_batches + ] + ref_batches = [ + _slice_batch(b, ref_enc_dp, ref_enc_grid.get_pg("dp").rank()) + for b in global_batches + ] + + dist_optimizer.zero_grad() + colocated_forward_backward_with_pp( + mimo_model=dist_model, + data_iterator=_BatchIterator(dist_batches), + num_microbatches=num_microbatches, + encoder_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + encoder_name=encoder_name, + seq_length=seq_length, + micro_batch_size=micro_batch_size_llm, + p2p_communicator=P2PCommunicator( + pp_group=dist_llm_grid.get_pg("pp"), config=dist_model.config + ), + pg_collection=dist_lang_pg, + ) + dist_ok, dist_gn, _ = dist_optimizer.step() + assert dist_ok, "Dist optimizer step failed" + assert dist_gn is not None and dist_gn > 0, ( + f"Dist grad_norm={dist_gn} — three-phase schedule produced zero grads." 
+ ) + + def _sum_loss(loss_mask, output_tensor): + """Per-token-loss 3-tuple matching ``_wire_training_hooks`` contract.""" + if output_tensor is None: + zero_loss = torch.tensor(0.0, device='cuda', requires_grad=True) + zero_count = torch.tensor(0, device='cuda', dtype=torch.int) + return zero_loss, zero_count, {'loss_reduced': 0.0} + masked = output_tensor.float() * loss_mask.float() + local_sum = masked.sum() + local_num_tokens = loss_mask.float().sum().to(torch.int) + return local_sum, local_num_tokens, {'loss_reduced': local_sum.detach().item()} + + def _ref_forward_step(data_iterator, model, *args): + batch = next(data_iterator) + output_tensor, loss_mask = model( + input_ids=batch['input_ids'], + labels=batch['labels'], + loss_mask=batch['loss_mask'], + position_ids=batch['position_ids'], + modality_inputs=batch['modality_inputs'], + ) + return output_tensor, partial(_sum_loss, loss_mask) + + ref_optimizer.zero_grad() + schedule.forward_backward_no_pipelining( + forward_step_func=_ref_forward_step, + data_iterator=_BatchIterator(ref_batches), + model=[ref_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=ref_per_rank_mbs, + forward_only=False, + pg_collection=ref_lang_pg, + ) + ref_ok, ref_gn, _ = ref_optimizer.step() + assert ref_ok, "Ref optimizer step failed" + assert ref_gn is not None and ref_gn > 0, f"Ref grad_norm={ref_gn}" + + # LLM forward differs between 1F1B (dist, PP>1) and no-pipelining (ref), + # and TP shards may accumulate in a different order; keep tolerances + # loose enough to absorb that drift even in fp32. 
+ _assert_encoder_weights_match( + ref_model.modality_submodules[encoder_name].module, + dist_model.modality_submodules[encoder_name].module, + rtol=1e-2, + atol=1e-2, + ) + _assert_llm_weights_match_pp_aware( + ref_model.language_model.module, + dist_model.language_model.module, + pp_rank=dist_llm_grid.get_pg("pp").rank(), + pp_size=dist_llm_pp, + num_layers=num_layers, + rtol=1e-2, + atol=1e-2, + ) + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh requires PyTorch 2.3+", +) +class TestMimoColocatedPP: + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def teardown_method(self): + destroy_all_grids() + + @pytest.mark.parametrize( + "num_microbatches", + [2, 4], + ids=["num_mb_eq_pp", "num_mb_gt_pp_grad_acc"], + ) + def test_pp_matches_pp1_equal_dp_reference(self, num_microbatches): + """Post-step encoder and LLM weights under PP>1 match equal-DP PP=1 reference. + + Dist runs ``colocated_forward_backward_with_pp`` (three-phase + schedule with PP=2 on the LLM); ref runs + ``forward_backward_no_pipelining`` with matching encoder TP/DP and + LLM PP=1. Under correct PP>1 encoder grad accumulation + broadcast, + one Adam step yields shard-wise equal post-step encoder weights + (modulo fp32 accumulation-order drift). + + The ``num_mb_gt_pp_grad_acc`` case runs more microbatches than PP + stages so encoder embedding views for every microbatch must + accumulate into the same ``detached_full.grad`` via PyTorch + view-gradient semantics — a regression there surfaces here. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + _run_pp_weight_oracle( + dist_enc_tp=2, + dist_enc_dp=4, + dist_llm_tp=2, + dist_llm_pp=2, + dist_llm_dp=2, + num_microbatches=num_microbatches, + )