# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Three-phase schedule for colocated MIMO training with LLM PP>1.

Phase 1: Encoder forward + communicate for the full batch (all ranks synchronized).
Phase 2: LLM 1F1B pipeline with detached encoder embeddings sliced per microbatch.
Phase 3: Encoder backward for the full batch (all ranks synchronized).

Encoder runs on all ranks (PP=1) and its TP/DP collectives require all ranks
to participate simultaneously. The 1F1B pipeline staggers ranks across PP stages,
so encoder collectives cannot run inside the pipeline. The three-phase design
separates encoder (synchronized) from LLM (pipelined) by detaching the autograd
graph at the encoder-LLM boundary.
"""

from functools import partial
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed as dist

if TYPE_CHECKING:
    # Only referenced in annotations, so it is not needed at import time.
    from megatron.core.hyper_comm_grid import HyperCommGrid

# Per-microbatch text fields forwarded unchanged from the data batch to the LLM.
_LM_TEXT_KEYS = ('input_ids', 'labels', 'loss_mask', 'position_ids')


def colocated_forward_backward_with_pp(
    mimo_model,
    data_iterator,
    num_microbatches: int,
    encoder_grid: Optional["HyperCommGrid"] = None,
    llm_grid: Optional["HyperCommGrid"] = None,
    encoder_name: str = "images",
    forward_only: bool = False,
    **schedule_kwargs,
):
    """Three-phase colocated training: encoder batch -> LLM pipeline -> encoder backward.

    Args:
        mimo_model: MimoModel with colocated communicators and lm_has_pp=True.
        data_iterator: Yields dicts with input_ids, labels, etc. Exactly
            ``num_microbatches`` items are consumed up front.
        num_microbatches: Number of microbatches for the LLM pipeline.
        encoder_grid: Encoder HyperCommGrid (for DP fan-in slicing).
        llm_grid: LLM HyperCommGrid (for PP group).
        encoder_name: Modality name for the encoder (e.g., "images").
        forward_only: Skip backward passes if True.
        **schedule_kwargs: Passed to forward_backward_pipelining_without_interleaving.
            Must include p2p_communicator, pg_collection, seq_length, micro_batch_size.

    Returns:
        Whatever ``forward_backward_pipelining_without_interleaving`` returns.
    """
    # Imported lazily: only this entry point needs the pipeline schedule, the
    # helpers below depend on torch alone.
    from megatron.core.pipeline_parallel import schedules

    # Explicit None check (matches _slice_for_encoder_dp) instead of relying on
    # the truthiness of the grid object.
    has_pp_dim = llm_grid is not None and 'pp' in llm_grid.dim_names
    pp_group = llm_grid.get_pg("pp") if has_pp_dim else None
    is_pp_first = pp_group is None or pp_group.rank() == 0

    # ── Phase 1: Encoder forward on full batch (one pass) ────────────────
    # All ranks participate (encoder is PP=1, communicate is collective).
    all_batches = [next(data_iterator) for _ in range(num_microbatches)]
    full_encoder_input = _concat_encoder_inputs(all_batches, encoder_name)
    _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid)

    enc_out = mimo_model.encode_and_communicate({encoder_name: full_encoder_input})

    # Detach: sever the autograd link to the encoder so Phase 2 has no encoder
    # collectives. The microbatch slices built below are views of these leaf
    # tensors, so backward through a slice accumulates into the leaf's .grad.
    detached_full = {k: v.detach().requires_grad_(True) for k, v in enc_out.items()}
    lm_data = _build_lm_microbatches(detached_full, all_batches, num_microbatches)

    # ── Phase 2: LLM 1F1B pipeline ──────────────────────────────────────
    # Only LLM P2P communication (within PP group). No encoder collectives.
    cache_iter = iter(lm_data)

    def _lm_forward_step(data_iterator_unused, model, *args):
        # The schedule's iterator argument is ignored; microbatches come from
        # the pre-built cache so the encoder is never re-run here.
        cached = next(cache_iter)
        output_tensor, _ = model(
            input_ids=cached['input_ids'],
            labels=cached['labels'],
            loss_mask=cached['loss_mask'],
            position_ids=cached['position_ids'],
            encoder_embeddings=cached['encoder_embeddings'],
        )
        return output_tensor, partial(_loss_func, cached['loss_mask'])

    losses = schedules.forward_backward_pipelining_without_interleaving(
        forward_step_func=_lm_forward_step,
        data_iterator=cache_iter,
        model=[mimo_model],
        num_microbatches=num_microbatches,
        forward_only=forward_only,
        **schedule_kwargs,
    )

    # ── Phase 3: Encoder backward (one pass, all ranks sync) ────────────
    # detached_full[...].grad was populated by Phase 2's per-microbatch LLM
    # backward on PP stage 0. Broadcast it to the other PP stages, then run one
    # encoder backward for the full batch.
    if not forward_only and enc_out:
        _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first)
        for key in enc_out:
            grad = detached_full[key].grad
            if grad is not None:
                torch.autograd.backward(enc_out[key], grad_tensors=grad)

    return losses


# ── Helpers ──────────────────────────────────────────────────────────────


def _concat_encoder_inputs(all_batches, encoder_name):
    """Concatenate encoder inputs from all microbatches along batch dim (dim 1).

    The set of encoders and keys is taken from the first batch. For each key the
    tensor values of all batches that carry ``encoder_name`` are concatenated on
    dim 1; if no value is a tensor, the first batch's value is passed through.

    Returns:
        ``{enc_name: {key: value}}``, or ``{}`` when the first batch has no
        inputs for ``encoder_name``.
    """
    first = all_batches[0]
    result = {}
    if not (first.get('modality_inputs') and encoder_name in first['modality_inputs']):
        return result
    for enc_name, enc_inputs in first['modality_inputs'][encoder_name].items():
        result[enc_name] = {}
        for key in enc_inputs:
            vals = [
                b['modality_inputs'][encoder_name][enc_name][key]
                for b in all_batches
                if b.get('modality_inputs') and encoder_name in b['modality_inputs']
            ]
            tensors = [v for v in vals if isinstance(v, torch.Tensor)]
            result[enc_name][key] = torch.cat(tensors, dim=1) if tensors else vals[0]
    return result


def _slice_for_encoder_dp(full_encoder_input, encoder_grid, llm_grid):
    """Slice concatenated encoder input for fan-in (enc_dp > llm_dp).

    Mutates ``full_encoder_input`` in place: every tensor with at least two
    dimensions is replaced by this rank's contiguous share of dim 1. It is a
    no-op when either grid is None or when ``enc_dp <= llm_dp``.

    Raises:
        ValueError: If dim 1 of a tensor is smaller than the fan-in scale, so
            that a rank would receive an empty batch.
    """
    if encoder_grid is None or llm_grid is None:
        return
    enc_dp_group = encoder_grid.get_pg("dp")
    enc_dp = enc_dp_group.size()
    llm_dp = llm_grid.get_pg("dp").size()
    if enc_dp <= llm_dp:
        return
    scale = enc_dp // llm_dp
    slot = enc_dp_group.rank() % scale
    for enc_inputs in full_encoder_input.values():
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, tensor in enc_inputs.items():
            if not (isinstance(tensor, torch.Tensor) and tensor.ndim >= 2):
                continue
            bs = tensor.shape[1]
            # Floor division: when bs is not a multiple of scale the trailing
            # bs % scale samples are not assigned to any slot.
            ss = bs // scale
            if ss == 0:
                raise ValueError(
                    f"Encoder fan-in produces zero-sized batch: "
                    f"total_batch={bs}, scale={scale}. Increase micro_batch_size."
                )
            # narrow() works for any rank >= 2; the previous three-subscript
            # slice raised IndexError on 2-D tensors that pass the ndim guard.
            enc_inputs[key] = tensor.narrow(1, slot * ss, ss).contiguous()


def _lm_text_fields(batch):
    """Return the text-side fields of one data batch (missing keys become None)."""
    return {key: batch.get(key) for key in _LM_TEXT_KEYS}


def _build_lm_microbatches(detached_full, all_batches, num_microbatches):
    """Slice detached encoder output into per-microbatch views for the LLM pipeline.

    3-D tensors are sliced on dim 1, all others on dim 0. The slices are views,
    so gradients flowing into them accumulate into the tensors of
    ``detached_full``.

    Returns:
        A list of ``num_microbatches`` dicts with keys ``encoder_embeddings``,
        ``input_ids``, ``labels``, ``loss_mask`` and ``position_ids``.
    """
    if not detached_full:
        # Text-only batch: no encoder embeddings to slice
        return [
            {'encoder_embeddings': {}, **_lm_text_fields(all_batches[mb])}
            for mb in range(num_microbatches)
        ]

    # The first tensor decides the total batch size used for every modality.
    sample = next(iter(detached_full.values()))
    batch_dim = 1 if sample.ndim == 3 else 0
    total_batch = sample.shape[batch_dim]
    assert total_batch % num_microbatches == 0, (
        f"Encoder output batch ({total_batch}) must be divisible "
        f"by num_microbatches ({num_microbatches})"
    )
    mb_size = total_batch // num_microbatches

    lm_data = []
    for mb in range(num_microbatches):
        s, e = mb * mb_size, (mb + 1) * mb_size
        mb_enc = {
            k: (v[:, s:e, :] if v.ndim == 3 else v[s:e, :]) for k, v in detached_full.items()
        }
        lm_data.append({'encoder_embeddings': mb_enc, **_lm_text_fields(all_batches[mb])})
    return lm_data


def _broadcast_encoder_grad(detached_full, enc_out, pp_group, is_pp_first):
    """Broadcast encoder gradient from PP stage 0 to stage 1+ ranks.

    On the first stage the existing ``.grad`` of each detached tensor is sent;
    on the other stages a zero buffer of the same shape receives it and is
    installed as ``.grad``. No-op when there is no PP group or it has one rank.
    """
    if pp_group is None or pp_group.size() <= 1:
        return
    src = dist.get_global_rank(pp_group, 0)
    for key in enc_out:
        if is_pp_first:
            assert (
                detached_full[key].grad is not None
            ), f"No encoder gradient on PP stage 0 for '{key}'"
            dist.broadcast(detached_full[key].grad, src=src, group=pp_group)
        else:
            grad = torch.zeros_like(detached_full[key])
            dist.broadcast(grad, src=src, group=pp_group)
            detached_full[key].grad = grad


def _loss_func(loss_mask, output_tensor):
    """Default loss function for the LLM pipeline.

    Returns ``(loss, {'loss_reduced': float})`` where the loss is the float sum
    of ``output_tensor``. ``loss_mask`` is accepted for signature compatibility
    but not applied here.
    NOTE(review): presumably the model output is already masked — confirm.
    """
    if output_tensor is None:
        # No output to reduce: return a zero placeholder. Fall back to CPU when
        # CUDA is unavailable instead of failing on a hard-coded device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.tensor(0.0, device=device, requires_grad=True), {'loss_reduced': 0.0}
    loss = output_tensor.float().sum()
    return loss, {'loss_reduced': loss.detach().item()}
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist

if TYPE_CHECKING:
    # Only referenced in annotations, so it is not needed at import time.
    from megatron.core.hyper_comm_grid import HyperCommGrid


@dataclass
class SliceInfo:
    """Batch dimension slice information for a rank's data partition."""

    start: int
    size: int


class ColocatedBridgeCommunicator:
    """Handles tensor communication between colocated modules with different TP/DP layouts.

    Three cases are distinguished by the DP sizes of the two grids:

    * fan-out (src DP < dest DP): each rank keeps a slice of the batch;
    * fan-in (src DP > dest DP): ranks all-gather their batches;
    * equal DP: the tensor is passed through.
    """

    def __init__(
        self,
        src_grid: "HyperCommGrid",
        dest_grid: "HyperCommGrid",
        src_module_name: str = "src",
        dest_module_name: str = "dest",
        dim_mapping: Optional[Dict[str, int]] = None,
    ):
        """Validate the grids, build rank mappings and (for fan-in) all-gather groups.

        Args:
            src_grid: Grid of the producing module (must have PP=1).
            dest_grid: Grid of the consuming module.
            src_module_name: Name used in log messages.
            dest_module_name: Name used in log messages.
            dim_mapping: Maps 'b'/'s'/'h' to tensor dimensions; defaults to
                ``{'b': 0, 's': 1, 'h': 2}``. Only 'b' is read by this module.

        Raises:
            ValueError: If the grids are incompatible (see ``_validate_grids``).
        """
        self.src_grid = src_grid
        self.dest_grid = dest_grid
        self.src_module_name = src_module_name
        self.dest_module_name = dest_module_name
        self.dim_mapping = dim_mapping or {'b': 0, 's': 1, 'h': 2}
        self.current_rank = dist.get_rank()

        self._validate_grids()
        self._extract_parallelism_info()
        self._build_rank_mappings()

        self.all_gather_pg: Optional[dist.ProcessGroup] = None
        self.all_gather_group_ranks: List[List[int]] = []
        # Defined on every instance; only fan-in assigns a real index.
        self._my_all_gather_group_idx: Optional[int] = None

        if self.is_fan_in():
            self._build_all_gather_groups()

        logging.info(
            f"[Rank {self.current_rank}] ColocatedBridgeCommunicator: "
            f"{src_module_name}({self.src_tp_size}TP/{self.src_dp_size}DP) -> "
            f"{dest_module_name}({self.dest_tp_size}TP/{self.dest_dp_size}DP), "
            f"scale_factor={self.dp_scale_factor}"
        )

    def _validate_grids(self):
        """Check that both grids cover the same ranks and have compatible shapes.

        Raises:
            ValueError: On differing size or rank offset, source PP != 1, or DP
                sizes where neither divides the other.
        """
        if self.src_grid.size != self.dest_grid.size:
            raise ValueError(
                f"Grids must span same number of ranks: "
                f"src={self.src_grid.size}, dest={self.dest_grid.size}"
            )

        if self.src_grid.rank_offset != self.dest_grid.rank_offset:
            raise ValueError(
                f"Grids must have same rank offset: "
                f"src={self.src_grid.rank_offset}, dest={self.dest_grid.rank_offset}"
            )

        # Source (encoder) must have PP=1. Dest (LLM) may have PP>1 —
        # the communicator only maps to dest's first PP stage.
        if 'pp' in self.src_grid.dim_names:
            pp_size = self.src_grid.shape[self.src_grid.dim_names.index('pp')]
            if pp_size != 1:
                raise ValueError(f"Source PP must be 1 for colocated, got {pp_size}")

        src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')]
        dest_dp = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')]
        if src_dp % dest_dp != 0 and dest_dp % src_dp != 0:
            raise ValueError(
                f"DP sizes must be evenly divisible: src_dp={src_dp}, dest_dp={dest_dp}"
            )

    def _extract_parallelism_info(self):
        """Cache TP/DP sizes of both grids and the src/dest DP ratio."""
        self.src_tp_size = self.src_grid.shape[self.src_grid.dim_names.index('tp')]
        self.src_dp_size = self.src_grid.shape[self.src_grid.dim_names.index('dp')]
        self.dest_tp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('tp')]
        self.dest_dp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')]
        # Kept as a float for logging and existing readers; slicing uses the
        # exact integer ratio from _dp_scale() instead.
        self.dp_scale_factor = self.src_dp_size / self.dest_dp_size

    def _dp_scale(self) -> int:
        """Return the integer ratio between the larger and the smaller DP size.

        Computed with integer division so no float round-trip is involved.
        """
        larger = max(self.src_dp_size, self.dest_dp_size)
        smaller = min(self.src_dp_size, self.dest_dp_size)
        return larger // smaller

    @staticmethod
    def _get_rank_dim_coord(rank, grid, dim_name):
        """Extract a rank's coordinate for a specific grid dimension.

        Treats ``rank - grid.rank_offset`` as a mixed-radix number whose first
        grid dimension varies fastest.
        """
        dim_idx = grid.dim_names.index(dim_name)
        temp = rank - grid.rank_offset
        for i in range(dim_idx):
            temp //= grid.shape[i]
        return temp % grid.shape[dim_idx]

    def _build_rank_mappings(self):
        """Map every rank to its ``(dp_idx, tp_idx)`` position in src and dest.

        NOTE(review): assumes ``_gen_rank_enum(['tp'])`` yields TP groups in
        increasing DP order — confirm against HyperCommGrid.
        """
        self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {}
        self.rank_to_dest_pos: Dict[int, Tuple[int, int]] = {}

        src_tp_groups = self.src_grid._gen_rank_enum(['tp'])
        for dp_idx, tp_group in enumerate(src_tp_groups):
            for tp_idx, rank in enumerate(tp_group):
                self.rank_to_src_pos[rank] = (dp_idx, tp_idx)

        # For dest, only map ranks at PP stage 0 (first pipeline stage).
        # When dest has PP>1, _gen_rank_enum(['tp']) returns dp*pp groups.
        # We filter to PP=0 so dp_idx correctly indexes the DP dimension only.
        dest_has_pp = 'pp' in self.dest_grid.dim_names
        dest_tp_groups = self.dest_grid._gen_rank_enum(['tp'])
        dp_idx = 0
        for tp_group in dest_tp_groups:
            if dest_has_pp:
                pp_coord = self._get_rank_dim_coord(tp_group[0], self.dest_grid, 'pp')
                if pp_coord != 0:
                    continue
            for tp_idx, rank in enumerate(tp_group):
                self.rank_to_dest_pos[rank] = (dp_idx, tp_idx)
            dp_idx += 1

    def _build_all_gather_groups(self):
        """Create one all-gather group per (dest DP index, src TP index) pair.

        Each group holds the ranks of the ``scale`` consecutive source DP
        replicas that feed one destination DP replica, at a fixed source TP
        index.
        """
        scale = self._dp_scale()
        # Reverse lookup built once instead of scanning the mapping per slot.
        pos_to_rank = {pos: rank for rank, pos in self.rank_to_src_pos.items()}
        all_groups: List[List[int]] = []

        for dest_dp_idx in range(self.dest_dp_size):
            src_dp_start = dest_dp_idx * scale
            for src_tp_idx in range(self.src_tp_size):
                group_ranks = [
                    pos_to_rank[(src_dp_idx, src_tp_idx)]
                    for src_dp_idx in range(src_dp_start, src_dp_start + scale)
                    if (src_dp_idx, src_tp_idx) in pos_to_rank
                ]
                all_groups.append(sorted(group_ranks))

        self.all_gather_group_ranks = all_groups
        self.all_gather_pg, _ = dist.new_subgroups_by_enumeration(all_groups, backend='nccl')

        self._my_all_gather_group_idx = None
        for idx, group_ranks in enumerate(all_groups):
            if self.current_rank in group_ranks:
                self._my_all_gather_group_idx = idx
                break

        logging.debug(
            f"[Rank {self.current_rank}] All-gather groups: {all_groups}, "
            f"my_group_idx={self._my_all_gather_group_idx}"
        )

    def get_all_gather_group(self) -> Optional["dist.ProcessGroup"]:
        """Return the all-gather process group for fan-in communication."""
        return self.all_gather_pg

    def get_all_gather_world_size(self) -> int:
        """Return the world size of the all-gather group."""
        if self.all_gather_pg is None:
            return 1
        return dist.get_world_size(self.all_gather_pg)

    def get_slice_info(self, batch_size: int) -> SliceInfo:
        """Compute batch slice info for the current rank given the full batch size."""
        if self.is_fan_out():
            return self._get_fan_out_slice_info(batch_size)
        if self.is_fan_in():
            return self._get_fan_in_slice_info(batch_size)
        return SliceInfo(start=0, size=batch_size)

    def _get_fan_out_slice_info(self, batch_size: int) -> SliceInfo:
        """Return this rank's slice of a source batch when dest has more DP replicas."""
        # For PP>1 dest, only PP stage 0 ranks are in rank_to_dest_pos.
        # PP stage 1+ ranks still call communicate() but the result is unused.
        if self.current_rank not in self.rank_to_dest_pos:
            return SliceInfo(start=0, size=batch_size)
        dest_dp_idx = self.rank_to_dest_pos[self.current_rank][0]
        scale = self._dp_scale()
        slot = dest_dp_idx % scale
        slice_size = batch_size // scale
        return SliceInfo(start=slot * slice_size, size=slice_size)

    def _get_fan_in_slice_info(self, batch_size: int) -> SliceInfo:
        """Return this rank's slice of a gathered batch when src has more DP replicas."""
        src_dp_idx = self.rank_to_src_pos[self.current_rank][0]
        scale = self._dp_scale()
        slot = src_dp_idx % scale
        slice_size = batch_size // scale
        return SliceInfo(start=slot * slice_size, size=slice_size)

    def is_fan_out(self) -> bool:
        """Return True if src DP < dest DP (encoder has fewer replicas)."""
        return self.src_dp_size < self.dest_dp_size

    def is_fan_in(self) -> bool:
        """Return True if src DP > dest DP (encoder has more replicas)."""
        return self.src_dp_size > self.dest_dp_size

    def is_equal_dp(self) -> bool:
        """Return True if src and dest have same DP size."""
        return self.src_dp_size == self.dest_dp_size

    def communicate(self, tensor: torch.Tensor) -> torch.Tensor:
        """Transform tensor from src TP/DP layout to dest TP/DP layout."""
        return _ColocatedCommunicate.apply(tensor, self)


class _ColocatedCommunicate(torch.autograd.Function):
    """Autograd function for colocated communication with correct backward pass."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, comm: ColocatedBridgeCommunicator) -> torch.Tensor:
        """Forward: fan-out slices, fan-in all-gathers, equal copies."""
        ctx.comm = comm
        ctx.batch_dim = comm.dim_mapping['b']
        batch_size = tensor.shape[ctx.batch_dim]

        if comm.is_fan_out():
            # Remembered so backward can rebuild a gradient of the input shape.
            ctx.input_batch_size = batch_size
            slice_info = comm.get_slice_info(batch_size)
            return tensor.narrow(ctx.batch_dim, slice_info.start, slice_info.size).contiguous()

        if comm.is_fan_in():
            group = comm.get_all_gather_group()
            world_size = comm.get_all_gather_world_size()
            # Make the input contiguous first so the receive buffers allocated
            # from it are contiguous as well.
            local = tensor.contiguous()
            gathered_list = [torch.empty_like(local) for _ in range(world_size)]
            dist.all_gather(gathered_list, local, group=group)
            # NOTE(review): presumably group-rank order equals source-DP order
            # because the group rank lists are sorted — confirm.
            return torch.cat(gathered_list, dim=ctx.batch_dim)

        return tensor.contiguous()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Backward: adjoint of forward (zero-pad for fan-out, slice for fan-in)."""
        comm = ctx.comm
        batch_dim = ctx.batch_dim

        if comm.is_fan_out():
            # Gradient is zero outside the slice this rank kept in forward.
            full_shape = list(grad_output.shape)
            full_shape[batch_dim] = ctx.input_batch_size
            grad_input = grad_output.new_zeros(full_shape)
            slice_info = comm.get_slice_info(ctx.input_batch_size)
            grad_input.narrow(batch_dim, slice_info.start, slice_info.size).copy_(grad_output)
            return grad_input, None

        if comm.is_fan_in():
            output_batch_size = grad_output.shape[batch_dim]
            slice_info = comm.get_slice_info(output_batch_size)
            return (
                grad_output.narrow(batch_dim, slice_info.start, slice_info.size).contiguous(),
                None,
            )

        return grad_output.contiguous(), None
The language model must use the key MIMO_LANGUAGE_MODULE_KEY. - When None, all modules are assumed to be colocated on the same ranks. + corresponding HyperCommGrid configurations. The language model must use + the key MIMO_LANGUAGE_MODULE_KEY. + When grids span the same ranks → colocated (same or different TP/DP). + When grids span disjoint ranks → non-colocated (pipeline parallel). + When None → colocated with legacy global parallel_state. kv_format (str): Key-value format for attention: "sbhd" (seq-batch-head-dim) or "thd" (total-head-dim). Default is "sbhd". diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 77c2512e8e6..8ef7a4be3c1 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -24,22 +24,17 @@ class ModuleLayout(Enum): Determines how modules are distributed across ranks and which forward path is used. - UNIFIED: No module_to_grid_map. All modules share same ranks and - parallelism. Uses the unified forward path (_forward_all_modules). + COLOCATED: All modules share the same ranks. Covers both legacy + (no grid map, global parallel_state) and heterogeneous TP/DP + (grid map with overlapping ranks). Uses _forward_all_modules. NON_COLOCATED: module_to_grid_map is set with non-overlapping rank ranges. Each rank runs EITHER encoder(s) OR the language model. Uses role-based dispatch with separate forward paths. - - COLOCATED: (future) module_to_grid_map is set with overlapping rank - ranges. Encoder(s) and language model share ranks but have - different parallelism configs. Uses role-based dispatch but - allows both module types on the same rank. 
""" - UNIFIED = "unified" - NON_COLOCATED = "non_colocated" COLOCATED = "colocated" + NON_COLOCATED = "non_colocated" @dataclass @@ -70,17 +65,17 @@ class RankRole: """ modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) - mode: ModuleLayout = ModuleLayout.UNIFIED + mode: ModuleLayout = ModuleLayout.COLOCATED @classmethod - def unified(cls, module_names: List[str]) -> 'RankRole': - """Create a role for the unified case: every module, first+last stage.""" + def colocated(cls, module_names: List[str]) -> 'RankRole': + """Create a role for colocated: every module on every rank, PP=1.""" return cls( modules={ name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) for name in module_names }, - mode=ModuleLayout.UNIFIED, + mode=ModuleLayout.COLOCATED, ) @classmethod diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index b1c12f521c3..12b08dea2a0 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -7,6 +7,7 @@ import torch from megatron.core.distributed import DistributedDataParallel +from megatron.core.models.mimo.comm.colocated_communicator import ColocatedBridgeCommunicator from megatron.core.models.mimo.config import MimoModelConfig from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig @@ -59,10 +60,35 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - self.mimo_config = mimo_config modality_names = list(mimo_config.modality_submodules_spec.keys()) + self.colocated_comms = {} if mimo_config.module_to_grid_map: - self.role = RankRole.from_grid_map(mimo_config.module_to_grid_map, modality_names) + if self._is_colocated(mimo_config.module_to_grid_map): + self.role = RankRole.colocated(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) + self._build_colocated_communicators() + else: + self.role = 
RankRole.from_grid_map(mimo_config.module_to_grid_map, modality_names) else: - self.role = RankRole.unified(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) + self.role = RankRole.colocated(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) + + # Detect LLM PP>1 for two-phase colocated execution + self.lm_has_pp = False + self.lm_is_first_pp_stage = True + if mimo_config.module_to_grid_map: + lang_grid = mimo_config.module_to_grid_map.get(MIMO_LANGUAGE_MODULE_KEY) + if lang_grid and 'pp' in lang_grid.dim_names: + pp_idx = lang_grid.dim_names.index('pp') + if lang_grid.shape[pp_idx] > 1: + self.lm_has_pp = True + pp_group = lang_grid.get_pg('pp') + pp_rank = pp_group.rank() + pp_size = pp_group.size() + self.lm_is_first_pp_stage = pp_rank == 0 + # Update language module stage info for PP>1 + from megatron.core.models.mimo.config.role import ModuleStageInfo + + self.role.modules[MIMO_LANGUAGE_MODULE_KEY] = ModuleStageInfo( + is_first_stage=(pp_rank == 0), is_last_stage=(pp_rank == pp_size - 1) + ) # Use special token IDs from the config self.special_token_ids = ( @@ -315,6 +341,7 @@ def forward( labels: Optional[torch.Tensor] = None, modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None, packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass through the multimodal model. 
@@ -355,10 +382,23 @@ def forward( - Language module ranks: language model output (logits or loss) - No role (all modules colocated): language model output """ - # Get any tensors passed via set_input_tensor - input_tensors = getattr(self, 'input_tensors', None) + if self.role.mode == ModuleLayout.COLOCATED: + input_tensors = getattr(self, 'input_tensors', None) + + if self.lm_has_pp and input_tensors is not None: + # PP>1 non-first stage: hidden states from P2P + lm_result = self._forward_language_module( + input_ids, + position_ids, + attention_mask, + labels, + {MIMO_LANGUAGE_MODULE_KEY: input_tensors}, + ) + # Unwrap dict for P2P (schedule uses plain tensors, not dicts) + if isinstance(lm_result, dict): + lm_result = lm_result[MIMO_LANGUAGE_MODULE_KEY] + return lm_result, loss_mask - if self.role.mode == ModuleLayout.UNIFIED: return self._forward_all_modules( input_ids, position_ids, @@ -367,8 +407,12 @@ def forward( labels, modality_inputs, packing_kwargs, + encoder_embeddings=encoder_embeddings, ) + # Get any tensors passed via set_input_tensor + input_tensors = getattr(self, 'input_tensors', None) + if self.role.mode == ModuleLayout.NON_COLOCATED: if self.role.has_modality_modules: return self._forward_encoders(modality_inputs, input_tensors), loss_mask @@ -491,6 +535,56 @@ def _forward_language_module( return lm_output + @staticmethod + def _is_colocated(module_to_grid_map): + """Check if all grids span the same ranks (colocated).""" + grids = list(module_to_grid_map.values()) + first = grids[0] + return all(g.rank_offset == first.rank_offset and g.size == first.size for g in grids[1:]) + + def _build_colocated_communicators(self): + """Build communicators for each encoder → language edge.""" + grid_map = self.mimo_config.module_to_grid_map + lang_key = MIMO_LANGUAGE_MODULE_KEY + lang_grid = grid_map[lang_key] + for mod_name in self.mimo_config.modality_submodules_spec: + if mod_name in grid_map and mod_name != lang_key: + src_grid = grid_map[mod_name] + if 
src_grid.size == lang_grid.size: + self.colocated_comms[(mod_name, lang_key)] = ColocatedBridgeCommunicator( + src_grid=src_grid, + dest_grid=lang_grid, + src_module_name=mod_name, + dest_module_name=lang_key, + ) + + def _apply_colocated_comms(self, modality_embeddings): + """Transform encoder embeddings from encoder TP/DP to LLM TP/DP layout.""" + lang_key = MIMO_LANGUAGE_MODULE_KEY + for modality_name in list(modality_embeddings.keys()): + comm = self.colocated_comms.get((modality_name, lang_key)) + if comm is not None: + modality_embeddings[modality_name] = comm.communicate( + modality_embeddings[modality_name] + ) + return modality_embeddings + + def encode_and_communicate(self, modality_inputs): + """Run encoder forward + colocated TP/DP transform (collective).""" + modality_embeddings = {} + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) + return modality_embeddings + def _forward_all_modules( self, input_ids: torch.Tensor, @@ -500,6 +594,7 @@ def _forward_all_modules( labels: Optional[torch.Tensor], modality_inputs: Optional[Dict[str, Dict[str, Any]]], packing_kwargs: Optional[dict] = None, + encoder_embeddings: Optional[Dict[str, torch.Tensor]] = None, ): """Forward pass when all modules are on all ranks (no multi-module PP). @@ -516,22 +611,27 @@ def _forward_all_modules( packed_seq_params.qkv_format = 'thd' logger.debug(f"Packed sequence parameters: {packed_seq_params}") - # 1. Process each modality to get embeddings - modality_embeddings = {} + if encoder_embeddings is not None: + modality_embeddings = encoder_embeddings + else: + # 1. 
Process each modality to get embeddings + modality_embeddings = {} - for modality_name, submodule in self.modality_submodules.items(): - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - modality_embeddings[modality_name] = embeddings - logger.debug( - f"Generated embeddings for {modality_name} with shape {embeddings.shape}" - ) + for modality_name, submodule in self.modality_submodules.items(): + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + logger.debug(f"Processing {modality_name} modality") + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + logger.debug(f"{modality_name} embeddings: {embeddings.shape}") + + # Apply colocated communication if configured (no-op when colocated_comms is empty) + if self.colocated_comms: + modality_embeddings = self._apply_colocated_comms(modality_embeddings) # Get text embeddings text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py new file mode 100644 index 00000000000..d62b2f2c07c --- /dev/null +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -0,0 +1,432 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+import logging +import os +import sys + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.mimo.comm.colocated_communicator import ( + ColocatedBridgeCommunicator, + SliceInfo, +) +from tests.unit_tests.test_utilities import Utils + +logging.basicConfig(level=logging.DEBUG, stream=sys.stderr) + +_active_grids: list = [] + + +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): + grid = HyperCommGrid( + shape=[tp, cp, pp, dp], + dim_names=["tp", "cp", "pp", "dp"], + rank_offset=offset, + backend="nccl", + ) + grid.create_pg(["tp"]) + grid.create_pg(["cp"]) + grid.create_pg(["pp"]) + grid.create_pg(["dp"]) + _active_grids.append(grid) + return grid + + +def destroy_all_grids(): + for grid in _active_grids: + grid.destroy() + _active_grids.clear() + + +# ── Test 1: Rank mappings ────────────────────────────────────────────────────── + + +class TestRankMappings: + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + def teardown_method(self): + destroy_all_grids() + + @pytest.mark.parametrize( + "src_tp, src_dp, dest_tp, dest_dp, expected_src_pos, expected_dest_pos", + [ + # Fan-in: TP2/DP4 → TP4/DP2 + ( + 2, + 4, + 4, + 2, + { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + 4: (2, 0), + 5: (2, 1), + 6: (3, 0), + 7: (3, 1), + }, + { + 0: (0, 0), + 1: (0, 1), + 2: (0, 2), + 3: (0, 3), + 4: (1, 0), + 5: (1, 1), + 6: (1, 2), + 7: (1, 3), + }, + ), + # Fan-out: TP4/DP2 → TP2/DP4 + ( + 4, + 2, + 2, + 4, + { + 0: (0, 0), + 1: (0, 1), + 2: (0, 2), + 3: (0, 3), + 4: (1, 0), + 5: (1, 1), + 6: (1, 2), + 7: (1, 3), + }, + { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + 4: (2, 0), + 5: (2, 1), + 6: (3, 0), + 7: (3, 1), + }, + ), + # Equal: TP4/DP2 → TP4/DP2 + ( + 4, + 2, + 4, 
class TestAllGatherGroups:
    """Verify the all-gather rank groups the colocated communicator builds."""

    @classmethod
    def setup_class(cls):
        """Initialise the default NCCL process group and select this process's GPU."""
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        if torch.cuda.is_available():
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            torch.cuda.set_device(local_rank)

    def teardown_method(self):
        """Destroy every grid created during the test that just finished."""
        destroy_all_grids()

    @pytest.mark.parametrize(
        "src_tp, src_dp, dest_tp, dest_dp, expected_groups",
        [
            # Fan-in: TP2/DP4 → TP4/DP2
            pytest.param(2, 4, 4, 2, [[0, 2], [1, 3], [4, 6], [5, 7]], id="fan_in_2x"),
            # Extreme: TP1/DP8 → TP8/DP1
            pytest.param(1, 8, 8, 1, [[0, 1, 2, 3, 4, 5, 6, 7]], id="extreme_8x"),
        ],
    )
    def test_fan_in_all_gather_groups(self, src_tp, src_dp, dest_tp, dest_dp, expected_groups):
        """Fan-in layouts must expose the expected rank groups and a live all-gather PG."""
        source = create_hypercomm_grid(tp=src_tp, dp=src_dp)
        destination = create_hypercomm_grid(tp=dest_tp, dp=dest_dp)
        bridge = ColocatedBridgeCommunicator(source, destination)

        assert bridge.all_gather_group_ranks == expected_groups
        assert bridge.all_gather_pg is not None

    def test_fan_out_no_all_gather(self):
        """A fan-out layout (TP4/DP2 → TP2/DP4) must not create any all-gather group."""
        source = create_hypercomm_grid(tp=4, dp=2)
        destination = create_hypercomm_grid(tp=2, dp=4)
        bridge = ColocatedBridgeCommunicator(source, destination)

        assert bridge.all_gather_group_ranks == []
        assert bridge.all_gather_pg is None
class TestGolden:
    """Golden forward/backward comparison of the colocated bridge against a TP1 reference.

    An encoder block (TP2/DP4) and an LLM block (TP4/DP2) are connected through
    ColocatedBridgeCommunicator and compared with unsharded (TP1) reference blocks
    that hold the same weights.
    """

    @classmethod
    def setup_class(cls):
        """Set the NVTE environment flags, then initialise NCCL and select the local GPU."""
        # NOTE(review): these env vars are set for the whole process and are not
        # restored afterwards -- confirm later tests in the same session tolerate this.
        os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        if torch.cuda.is_available():
            torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    @classmethod
    def teardown_class(cls):
        """Destroy model-parallel state once all tests in the class have run."""
        Utils.destroy_model_parallel()

    def teardown_method(self):
        """Destroy every grid created during the test that just finished."""
        destroy_all_grids()

    @pytest.mark.skipif(
        version.parse(torch.__version__) < version.parse("2.3.0"), reason="Requires PyTorch 2.3+"
    )
    def test_forward_backward_golden(self):
        """Compare sharded encoder → bridge → LLM against the TP1 reference.

        The forward output is checked per LLM DP slice and the input gradient per
        encoder DP slice. The rank arithmetic below assumes 8 ranks (TP2*DP4 == TP4*DP2 == 8).
        """
        # Helpers are imported lazily from the sibling bridge-communicator test module.
        from tests.unit_tests.pipeline_parallel.test_bridge_communicator import (
            _avg_params,
            _create_transformer_block,
            _get_pg_collection_from_grid,
            _shard_and_copy_,
        )

        hidden_size = 1024
        seq_len = 16
        micro_batch = 8
        dtype = torch.float32
        rank = dist.get_rank()

        # Encoder TP2/DP4, LLM TP4/DP2
        enc_tp, enc_dp = 2, 4
        llm_tp, llm_dp = 4, 2

        Utils.initialize_model_parallel(
            tensor_model_parallel_size=1, create_gloo_process_groups=False
        )

        # Reference TP1 blocks
        ref_grid = create_hypercomm_grid(tp=1, dp=8)
        ref_pg = _get_pg_collection_from_grid(ref_grid)
        ref_enc = _create_transformer_block(
            dtype=dtype, hidden_size=hidden_size, pg_collection=ref_pg
        )
        # Presumably averages parameters over the DP group so every rank holds the
        # same reference weights -- TODO confirm against the helper.
        _avg_params(ref_enc, ref_grid.get_pg("dp"))
        ref_llm = _create_transformer_block(
            dtype=dtype, hidden_size=hidden_size, pg_collection=ref_pg
        )
        _avg_params(ref_llm, ref_grid.get_pg("dp"))

        # Sharded encoder block (TP2/DP4)
        enc_grid = create_hypercomm_grid(tp=enc_tp, dp=enc_dp)
        enc_pg = _get_pg_collection_from_grid(enc_grid)
        enc_block = _create_transformer_block(
            dtype=dtype, hidden_size=hidden_size, pg_collection=enc_pg
        )
        # Copy this rank's TP shard of the reference weights into the sharded block
        # (helper semantics assumed from its name and arguments -- verify).
        _shard_and_copy_(ref_enc, enc_block, enc_tp, enc_pg.tp.rank())

        # Sharded LLM block (TP4/DP2)
        llm_grid = create_hypercomm_grid(tp=llm_tp, dp=llm_dp)
        llm_pg = _get_pg_collection_from_grid(llm_grid)
        llm_block = _create_transformer_block(
            dtype=dtype, hidden_size=hidden_size, pg_collection=llm_pg
        )
        _shard_and_copy_(ref_llm, llm_block, llm_tp, llm_pg.tp.rank())

        # Make sure every rank has finished building and copying weights.
        dist.barrier()

        # Communicator
        comm = ColocatedBridgeCommunicator(
            enc_grid,
            llm_grid,
            src_module_name="encoder",
            dest_module_name="llm",
            dim_mapping={"s": 0, "b": 1, "h": 2},
        )

        # ── Reference forward (full batch, TP1) ───────────────────────────
        # Same seed on every rank, so all ranks draw the identical full input.
        torch.manual_seed(42)
        full_input = torch.randn(seq_len, micro_batch, hidden_size, device="cuda", dtype=dtype)
        full_input_ref = full_input.clone().detach().requires_grad_(True)
        ref_enc_out = ref_enc(hidden_states=full_input_ref, attention_mask=None)
        ref_llm_out = ref_llm(hidden_states=ref_enc_out, attention_mask=None)

        # ── Colocated forward ──────────────────────────────────────────────
        # Each rank gets its encoder DP slice
        # Element 0 of the position tuple is used as the DP index.
        enc_dp_idx = comm.rank_to_src_pos[rank][0]
        enc_slice_size = micro_batch // enc_dp
        enc_input_slice = (
            full_input[:, enc_dp_idx * enc_slice_size : (enc_dp_idx + 1) * enc_slice_size, :]
            .clone()
            .detach()
            .requires_grad_(True)
        )

        enc_out = enc_block(hidden_states=enc_input_slice, attention_mask=None)
        bridged = comm.communicate(enc_out)
        llm_out = llm_block(hidden_states=bridged, attention_mask=None)

        # ── Compare forward outputs ────────────────────────────────────────
        # The reference ran on the full batch; compare only this rank's LLM DP slice.
        llm_dp_idx = comm.rank_to_dest_pos[rank][0]
        llm_slice_size = micro_batch // llm_dp
        ref_slice = ref_llm_out[
            :, llm_dp_idx * llm_slice_size : (llm_dp_idx + 1) * llm_slice_size, :
        ].detach()

        torch.testing.assert_close(llm_out.detach(), ref_slice, rtol=1e-3, atol=1e-3)

        # ── Backward ──────────────────────────────────────────────────────
        llm_out.sum().backward()
        ref_llm_out.sum().backward()

        # The input gradient is compared on this rank's encoder DP slice.
        ref_input_grad_slice = full_input_ref.grad[
            :, enc_dp_idx * enc_slice_size : (enc_dp_idx + 1) * enc_slice_size, :
        ]
        torch.testing.assert_close(enc_input_slice.grad, ref_input_grad_slice, rtol=1e-5, atol=1e-5)

        # Only reached when every assertion passed; teardown_class also calls this.
        Utils.destroy_model_parallel()
+ +Run with: + uv run python -m torch.distributed.run --nproc_per_node=8 -m pytest tests/unit_tests/models/test_mimo_colocated_e2e.py -v +""" + +import logging +from contextlib import ExitStack, contextmanager +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.models.mimo.optimizer import get_mimo_optimizer +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) +except ImportError: + 
def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1):
    """Build a HyperCommGrid, create its process groups, and track it for teardown.

    The grid has the dimensions [tp, cp, pp, dp, ep, expt_dp]; ``ep`` and
    ``expt_dp`` are fixed to size 1. The new grid is appended to
    ``_active_grids`` and returned.
    """
    dim_names = ["tp", "cp", "pp", "dp", "ep", "expt_dp"]
    dim_sizes = {"tp": tp, "cp": cp, "pp": pp, "dp": dp, "ep": 1, "expt_dp": 1}
    grid = HyperCommGrid(
        shape=[dim_sizes[name] for name in dim_names],
        dim_names=dim_names,
        rank_offset=offset,
        backend="nccl",
    )

    # Process groups are created in exactly this order on every rank.
    pg_layouts = (
        ("tp",),
        ("cp",),
        ("pp",),
        ("dp",),
        ("dp", "cp"),
        ("ep",),
        ("expt_dp",),
        # Required by _get_pg_collection_for_optimizer
        ("tp", "pp"),
        ("tp", "ep", "pp"),
        ("dp", "ep"),
        ("tp", "cp", "ep", "pp", "dp"),
    )
    for layout in pg_layouts:
        grid.create_pg(list(layout))

    _active_grids.append(grid)
    return grid
def add_embedding_groups(pg_collection, is_language_model=False):
    """Attach the cached embedding process groups to ``pg_collection`` and return it.

    A collection without a PP group is returned untouched. Otherwise the groups
    are looked up in ``_embedding_pg_cache`` by the sorted PP ranks, so they must
    already have been created. ``pos_embd`` is set only on the first PP stage;
    ``embd`` is set only for a language model on the first or last PP stage.
    """
    pp_group = pg_collection.pp
    if not pp_group:
        return pg_collection

    lookup_key = tuple(sorted(dist.get_process_group_ranks(pp_group)))
    cached_pos_embd, cached_embd = _embedding_pg_cache[lookup_key]

    on_first_stage = is_pp_first_stage(pp_group)
    pg_collection.pos_embd = cached_pos_embd if on_first_stage else None

    if is_language_model and (is_pp_last_stage(pp_group) or on_first_stage):
        pg_collection.embd = cached_embd
    else:
        pg_collection.embd = None

    return pg_collection
def get_projection_config(hidden_size):
    """Build the TransformerConfig used by the vision projection MLP.

    The config is constructed with one layer and one attention head, then the
    FFN width, bias settings and GELU activation are assigned on the instance.
    """
    projection_cfg = TransformerConfig(
        num_layers=1, hidden_size=hidden_size, num_attention_heads=1
    )
    # Applied after construction, in this order, as plain attribute assignments.
    overrides = (
        ("ffn_hidden_size", hidden_size),
        ("bias_activation_fusion", True),
        ("add_bias_linear", True),
        ("activation_func", torch.nn.functional.gelu),
    )
    for attr_name, attr_value in overrides:
        setattr(projection_cfg, attr_name, attr_value)
    return projection_cfg
class DataIterator:
    """Endless iterator that fabricates VLM-style batches on the CUDA device.

    Every batch starts with ``image_seq_length`` image-placeholder tokens followed
    by random text tokens, and carries random encoder hidden states for the modality.
    """

    def __init__(
        self,
        hidden_size,
        seq_length,
        micro_batch_size,
        vocab_size,
        encoder_name,
        image_token_id=50257,
        image_seq_length=None,
    ):
        """Store the batch geometry; a falsy ``image_seq_length`` means half the sequence."""
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size
        self.vocab_size = vocab_size
        self.encoder_name = encoder_name
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length if image_seq_length else seq_length // 2

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def __next__(self):
        """Generate one batch dict (input_ids, labels, loss_mask, position_ids, modality_inputs)."""
        batch = self.micro_batch_size
        image_len = self.image_seq_length
        text_len = self.seq_length - image_len

        # Random draws stay in this order: encoder features first, text ids second.
        vision_features = torch.randn(
            image_len, batch, self.hidden_size, device='cuda', dtype=torch.bfloat16
        )
        image_ids = torch.full(
            (batch, image_len), self.image_token_id, dtype=torch.long, device='cuda'
        )
        text_ids = torch.randint(1, self.vocab_size, (batch, text_len), device='cuda')
        input_ids = torch.cat([image_ids, text_ids], dim=1)

        # Image positions are excluded from the loss: label -100 and mask 0.
        is_image = input_ids == self.image_token_id
        labels = input_ids.clone()
        labels[is_image] = -100

        loss_mask = torch.ones(batch, self.seq_length, device='cuda', dtype=torch.float32)
        loss_mask[is_image] = 0.0

        positions = torch.arange(self.seq_length, device='cuda')
        position_ids = positions.unsqueeze(0).expand(batch, -1).clone()

        encoder_inputs = {'hidden_states': vision_features, 'attention_mask': None}
        return {
            "input_ids": input_ids,
            "labels": labels,
            "loss_mask": loss_mask,
            "position_ids": position_ids,
            "modality_inputs": {self.encoder_name: {"clip_encoder": encoder_inputs}},
        }
def loss_func(loss_mask, output_tensor):
    """Reduce the model output to a scalar loss plus a reporting dict.

    Args:
        loss_mask: Accepted so the function can be bound with ``partial``; not used.
        output_tensor: Model output, or None when there is no output to reduce.

    Returns:
        ``(loss, {'loss_reduced': float})``. With an output tensor the loss is the
        float32 sum of all its elements; with None it is a zero scalar on CUDA
        that requires grad.
    """
    if output_tensor is not None:
        total = output_tensor.float().sum()
        return total, {'loss_reduced': total.detach().item()}

    placeholder = torch.tensor(0.0, device='cuda', requires_grad=True)
    return placeholder, {'loss_reduced': 0.0}
def run_colocated_test(
    encoder_tp,
    encoder_dp,
    llm_tp,
    llm_dp,
    hidden_size=256,
    num_layers=2,
    vocab_size=1000,
    seq_length=64,
    micro_batch_size=2,
    num_microbatches=2,
):
    """Run MIMO model through forward_backward_no_pipelining with colocated modules.

    Builds encoder and LLM grids at offset 0 with the given TP/DP sizes, creates
    the DDP-wrapped MIMO model and its optimizer, runs three forward/backward +
    optimizer-step iterations, and validates the collected per-microbatch losses.

    Args:
        encoder_tp: Encoder tensor-parallel size.
        encoder_dp: Encoder data-parallel size.
        llm_tp: LLM tensor-parallel size.
        llm_dp: LLM data-parallel size.
        hidden_size: Hidden size used for both encoder and LLM.
        num_layers: Number of layers for both encoder and LLM.
        vocab_size: LLM vocabulary size.
        seq_length: Total sequence length of each sample.
        micro_batch_size: Per-rank microbatch size produced by the data iterator.
        num_microbatches: Microbatches per forward/backward call.

    Returns:
        The list of loss dicts collected over all iterations.

    Raises:
        AssertionError: If an optimizer step fails, a loss is NaN/inf, every loss
            is zero, or the number of loss entries is not
            ``num_iterations * num_microbatches``.
    """
    # Clear NVTE env vars that the conftest set_env fixture sets to '0'.
    # GPTModel (LanguageModule) asserts these are unset or match the attention backend.
    import math
    import os

    os.environ.pop('NVTE_FLASH_ATTN', None)
    os.environ.pop('NVTE_FUSED_ATTN', None)
    os.environ.pop('NVTE_UNFUSED_ATTN', None)

    encoder_name = "images"

    # Both grids at offset=0 (colocated on same ranks)
    encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp)
    llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=1, dp=llm_dp)

    # Create all embedding PGs upfront -- dist.new_group is a collective
    create_all_embedding_groups([encoder_grid, llm_grid])

    torch.manual_seed(12345)

    mimo_model, language_pg, vision_pg = get_mimo_model_colocated(
        encoder_name=encoder_name,
        encoder_grid=encoder_grid,
        llm_grid=llm_grid,
        hidden_size=hidden_size,
        num_layers=num_layers,
        vocab_size=vocab_size,
        seq_len=seq_length,
    )

    # Create MIMO optimizer (handles per-module DP groups, global grad norm)
    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=1e-4,
        weight_decay=0.01,
        clip_grad=1.0,
        bf16=True,
        use_distributed_optimizer=True,
    )
    optimizer = get_mimo_optimizer(mimo_model, opt_config)

    # Build schedule functions
    @contextmanager
    def no_sync_func():
        # Enter no_sync() on the language model and every modality submodule at once.
        with ExitStack() as stack:
            if mimo_model.language_model is not None:
                stack.enter_context(mimo_model.language_model.no_sync())
            for submodule in mimo_model.modality_submodules.values():
                if submodule is not None:
                    stack.enter_context(submodule.no_sync())
            yield

    def finalize_grads_func(*args, **kwargs):
        # Each module is finalized with its own process-group collection.
        if mimo_model.language_model is not None:
            finalize_model_grads(
                [mimo_model.language_model], num_tokens=None, pg_collection=language_pg
            )
        for submodule in mimo_model.modality_submodules.values():
            if submodule is not None:
                finalize_model_grads([submodule], num_tokens=None, pg_collection=vision_pg)

    mimo_model.config.no_sync_func = no_sync_func
    mimo_model.config.finalize_model_grads_func = finalize_grads_func
    # Plain Python numbers are wrapped into a CUDA tensor; tensors pass through.
    mimo_model.config.grad_scale_func = lambda loss: (
        torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True)
        if isinstance(loss, (int, float))
        else loss
    )

    # Create data iterator -- all ranks need data since PP=1 and all are colocated
    data_iterator = DataIterator(
        hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name
    )

    # Run multiple iterations of forward_backward + optimizer step
    all_losses = []
    num_iterations = 3
    rank = dist.get_rank()
    optimizer.zero_grad()

    for iteration in range(num_iterations):
        losses = schedule.forward_backward_no_pipelining(
            forward_step_func=partial(
                forward_step,
                encoder_grid=encoder_grid,
                llm_grid=llm_grid,
                encoder_name=encoder_name,
            ),
            data_iterator=data_iterator,
            model=[mimo_model],
            num_microbatches=num_microbatches,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            forward_only=False,
            pg_collection=language_pg,
        )

        # MIMO optimizer step (handles per-module DP groups + global grad norm)
        success, _grad_norm, _num_zeros = optimizer.step()
        assert success, f"Rank {rank}: Optimizer step failed at iteration {iteration}"
        optimizer.zero_grad()

        all_losses.extend(losses)
        logger.info(f"Rank {rank}: iteration {iteration} completed with {len(losses)} microbatches")

    # Verify losses from all iterations
    assert len(all_losses) > 0, f"Rank {rank}: Expected non-empty losses list"

    # Normalize each entry to a Python float once and reuse it for every check.
    loss_values = []
    for i, loss_dict in enumerate(all_losses):
        assert 'loss_reduced' in loss_dict, f"Rank {rank}: Missing 'loss_reduced' in microbatch {i}"
        loss_val = loss_dict['loss_reduced']
        if isinstance(loss_val, torch.Tensor):
            loss_val = loss_val.item()
        assert not math.isnan(loss_val), f"Rank {rank}: Loss is NaN at microbatch {i}"
        assert not math.isinf(loss_val), f"Rank {rank}: Loss is inf at microbatch {i}"
        logger.info(f"Rank {rank}: microbatch {i} loss = {loss_val}")
        loss_values.append(loss_val)

    # At least one microbatch should have non-zero loss
    any_nonzero = any(value != 0.0 for value in loss_values)
    assert any_nonzero, f"Rank {rank}: All losses are zero -- model did not compute anything"

    # Verify we got losses from all iterations (num_iterations * num_microbatches)
    expected_total = num_iterations * num_microbatches
    assert len(all_losses) == expected_total, (
        f"Rank {rank}: Expected {expected_total} loss entries "
        f"({num_iterations} iterations x {num_microbatches} microbatches), "
        f"got {len(all_losses)}"
    )

    return all_losses
All rights reserved. +"""Tests for colocated MIMO training with LLM PP>1. + +Uses two-phase execution: encoder pre-compute + 1F1B LLM pipeline. + +Run individually (8 GPUs): + uv run python -m torch.distributed.run --nproc_per_node=8 \ + -m pytest tests/unit_tests/models/test_mimo_colocated_pp.py -v +""" + +import logging +from contextlib import ExitStack, contextmanager +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.core.models.mimo.colocated_schedule import colocated_forward_backward_with_pp +from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.optimizer import get_mimo_optimizer +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator +from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.mlp import MLP, MLPSubmodules 
def create_all_embedding_groups(grids):
    """Create and cache the embedding process groups for each grid's PP group.

    For every grid with a PP group, two groups are created and stored in
    ``_embedding_pg_cache`` under the sorted PP ranks: one with only the first PP
    rank, and one with the first and (if different) last PP rank. Grids without a
    PP group, and PP rank sets already cached, are skipped.
    """
    for grid in grids:
        pipeline_pg = grid.get_pg("pp")
        if not pipeline_pg:
            continue

        ranks = sorted(dist.get_process_group_ranks(pipeline_pg))
        cache_key = tuple(ranks)
        if cache_key in _embedding_pg_cache:
            continue

        first_rank, last_rank = ranks[0], ranks[-1]
        embd_members = [first_rank] if last_rank == first_rank else [first_rank, last_rank]
        # Creation order is fixed: position-embedding group first, then embedding group.
        pos_embd_group = dist.new_group(ranks=[first_rank])
        embd_group = dist.new_group(ranks=embd_members)
        _embedding_pg_cache[cache_key] = (pos_embd_group, embd_group)
tuple(pp_ranks) + if key in _embedding_pg_cache: + pos_pg, embd_pg = _embedding_pg_cache[key] + pg.pos_embd = pos_pg if is_pp_first_stage(pg.pp) else None + pg.embd = ( + embd_pg + if is_language_model and (is_pp_last_stage(pg.pp) or is_pp_first_stage(pg.pp)) + else None + ) + return pg + + +def get_language_model_spec( + num_layers, hidden_size, num_attention_heads, vocab_size, seq_len, pg_collection +): + pp_rank = dist.get_rank(pg_collection.pp) + pp_size = dist.get_world_size(pg_collection.pp) + tp_size = pg_collection.tp.size() if pg_collection.tp else 1 + lm_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16, + bf16=True, + cross_entropy_loss_fusion=True, + cross_entropy_fusion_impl='te', + ) + return ModuleSpec( + module=GPTModel, + params={ + "config": lm_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + "pre_process": (pp_rank == 0), + "post_process": (pp_rank == pp_size - 1), + "pg_collection": pg_collection, + }, + ) + + +def get_vision_submodules_spec( + num_layers, hidden_size, num_attention_heads, language_hidden_size, pg_collection +): + from megatron.core.transformer.transformer_block import TransformerBlock + + tp_size = pg_collection.tp.size() if pg_collection.tp else 1 + vision_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=1, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + proj_cfg = TransformerConfig( + num_layers=1, 
hidden_size=language_hidden_size, num_attention_heads=1 + ) + proj_cfg.ffn_hidden_size = language_hidden_size + proj_cfg.bias_activation_fusion = True + proj_cfg.add_bias_linear = True + proj_cfg.activation_func = torch.nn.functional.gelu + + return ModuleSpec( + module=VisionModalitySubmodules, + submodules={ + "encoders": { + "clip_encoder": ModuleSpec( + module=TransformerBlock, + params={ + "config": vision_config, + "spec": get_gpt_layer_with_transformer_engine_spec(), + "pg_collection": pg_collection, + "pre_process": True, + "post_process": True, + }, + ) + }, + "input_projections": [ + ModuleSpec( + module=MultimodalProjector, + params={ + "config": proj_cfg, + "submodules": MLPSubmodules( + linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + "projector_type": "mlp", + "input_size": vision_config.hidden_size, + "tp_group": pg_collection.tp, + }, + ) + ], + }, + ) + + +class DataIterator: + def __init__( + self, + hidden_size, + seq_length, + micro_batch_size, + vocab_size, + encoder_name, + image_token_id=50257, + image_seq_length=None, + ): + self.hidden_size = hidden_size + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.vocab_size = vocab_size + self.encoder_name = encoder_name + self.image_token_id = image_token_id + self.image_seq_length = image_seq_length or (seq_length // 2) + + def __iter__(self): + return self + + def __next__(self): + encoder_hidden_states = torch.randn( + self.image_seq_length, + self.micro_batch_size, + self.hidden_size, + device='cuda', + dtype=torch.bfloat16, + ) + image_tokens = torch.full( + (self.micro_batch_size, self.image_seq_length), + self.image_token_id, + dtype=torch.long, + device='cuda', + ) + text_tokens = torch.randint( + 1, + self.vocab_size, + (self.micro_batch_size, self.seq_length - self.image_seq_length), + device='cuda', + ) + input_ids = torch.cat([image_tokens, text_tokens], dim=1) + labels = input_ids.clone() + labels[input_ids == self.image_token_id] = 
-100 + loss_mask = (input_ids != self.image_token_id).float() + position_ids = ( + torch.arange(self.seq_length, device='cuda') + .unsqueeze(0) + .expand(self.micro_batch_size, -1) + .clone() + ) + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "modality_inputs": { + self.encoder_name: { + "clip_encoder": {'hidden_states': encoder_hidden_states, 'attention_mask': None} + } + }, + } + + +def run_colocated_pp_test( + encoder_tp, + encoder_dp, + llm_tp, + llm_pp, + llm_dp, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, +): + """Run colocated MIMO with encoder PP=1 + LLM PP>1.""" + import os + + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + + encoder_name = "images" + + encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp) + llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) + create_all_embedding_groups([encoder_grid, llm_grid]) + torch.manual_seed(12345) + + vision_pg = get_pg_collection(encoder_grid, is_language_model=False) + language_pg = get_pg_collection(llm_grid, is_language_model=True) + + language_model_spec = get_language_model_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + vocab_size=vocab_size, + seq_len=seq_length, + pg_collection=language_pg, + ) + vision_submodule_spec = get_vision_submodules_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + language_hidden_size=hidden_size, + pg_collection=vision_pg, + ) + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={encoder_name: vision_submodule_spec}, + special_token_ids={encoder_name: 50257}, + module_to_grid_map={encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid}, + ) + + mimo_model = 
MimoModel(mimo_config) + mimo_model.to(torch.device("cuda")).to(torch.bfloat16) + mimo_model.model_type = ModelType.encoder_or_decoder + + # Wrap with DDP (per-module process groups) + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=False, bucket_size=10000, use_distributed_optimizer=True + ) + if mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=language_pg, + ) + if encoder_name in mimo_model.modality_submodules: + submodule = mimo_model.modality_submodules[encoder_name] + if submodule is not None: + mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( + config=submodule.encoders['clip_encoder'].config, + ddp_config=ddp_config, + module=submodule, + pg_collection=vision_pg, + ) + + @contextmanager + def no_sync_func(): + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for sub in mimo_model.modality_submodules.values(): + if sub is not None: + stack.enter_context(sub.no_sync()) + yield + + def finalize_grads_func(*args, **kwargs): + if mimo_model.language_model is not None: + finalize_model_grads( + [mimo_model.language_model], num_tokens=None, pg_collection=language_pg + ) + for sub in mimo_model.modality_submodules.values(): + if sub is not None: + finalize_model_grads([sub], num_tokens=None, pg_collection=vision_pg) + + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) + else loss + ) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + optimizer = 
get_mimo_optimizer(mimo_model, opt_config) + + data_iterator = DataIterator( + hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name + ) + lm_pp_group = llm_grid.get_pg("pp") + + rank = dist.get_rank() + num_iterations = 2 + all_losses = [] + optimizer.zero_grad() + + for iteration in range(num_iterations): + losses = colocated_forward_backward_with_pp( + mimo_model=mimo_model, + data_iterator=data_iterator, + num_microbatches=num_microbatches, + encoder_grid=encoder_grid, + llm_grid=llm_grid, + encoder_name=encoder_name, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + p2p_communicator=P2PCommunicator(pp_group=lm_pp_group, config=mimo_model.config), + pg_collection=language_pg, + ) + + success, grad_norm, _ = optimizer.step() + assert success, f"Rank {rank}: Optimizer step failed at iteration {iteration}" + optimizer.zero_grad() + + all_losses.extend(losses or []) + logger.info(f"Rank {rank}: iteration {iteration} done, losses={len(losses or [])}") + + # Verify on last PP stage + if is_pp_last_stage(lm_pp_group): + assert len(all_losses) > 0, f"Rank {rank}: No losses on last stage" + for i, loss_dict in enumerate(all_losses): + loss_val = loss_dict.get('loss_reduced', 0) + if isinstance(loss_val, torch.Tensor): + loss_val = loss_val.item() + assert loss_val == loss_val, f"Rank {rank}: NaN loss at mb {i}" + assert abs(loss_val) != float('inf'), f"Rank {rank}: Inf loss at mb {i}" + + return all_losses + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh requires PyTorch 2.3+", +) +class TestMimoColocatedPP: + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def teardown_method(self): + destroy_all_grids() + + def test_fan_in_enc_tp2_dp4_llm_tp2_dp2_pp2(self): + """Fan-in: encoder TP2/DP4 → LLM TP2/DP2/PP2.""" + if self.world_size != 8: + 
pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_pp_test( + encoder_tp=2, encoder_dp=4, llm_tp=2, llm_pp=2, llm_dp=2, num_microbatches=4 + ) + + def test_equal_dp_enc_tp4_dp2_llm_tp2_dp2_pp2(self): + """Equal DP: encoder TP4/DP2 → LLM TP2/DP2/PP2 (enc_dp == llm_dp).""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_pp_test( + encoder_tp=4, encoder_dp=2, llm_tp=2, llm_pp=2, llm_dp=2, num_microbatches=4 + ) + + def test_fan_in_with_grad_acc(self): + """Fan-in with gradient accumulation (num_microbatches > pp_size).""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_pp_test( + encoder_tp=2, + encoder_dp=4, + llm_tp=2, + llm_pp=2, + llm_dp=2, + num_microbatches=6, # > pp_size=2, tests grad accumulation + ) + + def test_fan_in_enc_tp1_dp8_llm_tp4_dp1_pp2(self): + """Fan-in extreme: encoder TP1/DP8 → LLM TP4/DP1/PP2.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + # micro_batch_size must be >= fan-in scale (enc_dp/llm_dp = 8/1 = 8) + # to avoid zero-sized slices in _slice_for_encoder_dp. + run_colocated_pp_test( + encoder_tp=1, + encoder_dp=8, + llm_tp=4, + llm_pp=2, + llm_dp=1, + micro_batch_size=8, + num_microbatches=4, + ) diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index e1c4b6e89bf..8beeb8a6ce5 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -550,7 +550,7 @@ def test_role_determination(self): self.patch_dim, {"images": 50257}, ) - assert model_no_grid.role.mode == ModuleLayout.UNIFIED + assert model_no_grid.role.mode == ModuleLayout.COLOCATED assert model_no_grid.role.has_language_module is True assert model_no_grid.role.has_modality_modules is True