diff --git a/examples/mimo/configs/nemotron_moe_vlm.py b/examples/mimo/configs/nemotron_moe_vlm.py new file mode 100644 index 00000000000..62cf84fd461 --- /dev/null +++ b/examples/mimo/configs/nemotron_moe_vlm.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Configuration utilities for the Nemotron6-MoE VLM with RADIO vision encoder. + +Provides TransformerConfig builders for: +- Nemotron6-MoE (3B hybrid Mamba-MoE) language model +- RADIO ViT vision encoder +- Vision-to-language MLP projector +- Layer specs for each component +""" + +import dataclasses +from typing import Optional + +import torch + +from megatron.core.activations import fast_gelu, squared_relu +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec +from megatron.core.models.vision.vit_layer_specs import ( + get_vit_layer_with_transformer_engine_spec, +) +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig + + +# --------------------------------------------------------------------------- +# Language model config: Nemotron6-MoE (3B hybrid Mamba-MoE) +# --------------------------------------------------------------------------- +def get_nemotron_moe_language_model_config(args=None) -> TransformerConfig: + """Return a TransformerConfig for **Nemotron6-MoE** (3B hybrid Mamba-MoE). + + Builds the config in two layers: + + 1. **Nemotron defaults** — model-specific values (MoE routing, Mamba heads, + hidden sizes, etc.) that define the 3B architecture. + 2. **CLI overrides** — if ``args`` (the Megatron args namespace) is provided, + any field whose value differs from TransformerConfig's dataclass default + is applied on top. This covers parallelism, precision, and any arch + field the training script explicitly sets. + + When ``args`` is ``None`` the pure Nemotron defaults are returned (useful + for unit tests or standalone usage). + """ + num_layers = getattr(args, "num_layers", 52) if args else 52 + + cfg = TransformerConfig( + num_layers=num_layers, + hidden_size=2688, + num_attention_heads=32, + ) + + # GQA + cfg.num_query_groups = 2 + cfg.kv_channels = 128 + + # FFN + cfg.ffn_hidden_size = 1856 + cfg.activation_func = squared_relu + cfg.gated_linear_unit = False + + # Normalisation + cfg.normalization = "RMSNorm" + + # No bias + cfg.add_bias_linear = False + + # Mamba SSM + cfg.mamba_num_heads = 64 + cfg.mamba_head_dim = 64 + + # MoE + cfg.num_moe_experts = 128 + cfg.moe_ffn_hidden_size = 1856 + cfg.moe_router_topk = 6 + cfg.moe_grouped_gemm = True + cfg.moe_router_score_function = "sigmoid" + cfg.moe_router_topk_scaling_factor = 2.5 + cfg.moe_router_enable_expert_bias = True + cfg.moe_router_dtype = "fp32" + cfg.moe_router_load_balancing_type = "seq_aux_loss" + cfg.moe_aux_loss_coeff = 0.0001 + cfg.moe_shared_expert_intermediate_size = 3712 + cfg.moe_shared_expert_overlap = True + cfg.moe_token_dispatcher_type = "alltoall" + + # Positional embeddings (Mamba handles position internally) + cfg.position_embedding_type = "none" + + # Sequence length + cfg.seq_length = 4096 + cfg.max_position_embeddings = 4096 + + # Dropout + cfg.attention_dropout = 0.0 + cfg.hidden_dropout = 0.0 + + # TE / kernel fusions + cfg.bias_activation_fusion = False + cfg.masked_softmax_fusion = True + cfg.persist_layer_norm = True + cfg.bias_dropout_fusion = False + + # ── CLI overrides ─────────────────────────────────────────────────── + # Override Nemotron defaults with values from CLI args. A CLI value + # is considered "explicitly set" when it differs from the + # TransformerConfig dataclass default — this preserves Nemotron- + # specific values (e.g. moe_router_topk=6) when the script doesn't + # pass them (CLI default also equals the dataclass default of 2). + if args is not None: + for field in dataclasses.fields(TransformerConfig): + if field.default is dataclasses.MISSING: + continue + arg_val = getattr(args, field.name, None) + if arg_val is None: + continue + if arg_val == field.default and getattr(cfg, field.name) != field.default: + continue + setattr(cfg, field.name, arg_val) + + return cfg + + +# --------------------------------------------------------------------------- +# Vision encoder config: RADIO ViT +# --------------------------------------------------------------------------- +def get_radio_vision_config( + config: Optional[TransformerConfig] = None, +) -> TransformerConfig: + """Return a TransformerConfig for the **RADIO** vision encoder. + + Parameters match ``examples/multimodal/config.py`` for ``vision_model_type == "radio"``. + """ + cfg = TransformerConfig( + num_layers=32, + hidden_size=1280, + num_attention_heads=16, + ) + + cfg.kv_channels = 80 + cfg.num_query_groups = 16 + cfg.ffn_hidden_size = 5120 + cfg.gated_linear_unit = False + cfg.activation_func = fast_gelu + + cfg.add_bias_linear = True + cfg.add_qkv_bias = True + + cfg.normalization = "LayerNorm" + cfg.layernorm_epsilon = 1e-6 + cfg.layernorm_zero_centered_gamma = False + + cfg.apply_rope_fusion = False + cfg.qk_layernorm = False + cfg.bias_activation_fusion = False + cfg.bias_dropout_fusion = False + cfg.attention_softmax_in_fp32 = True + + cfg.attention_dropout = 0.0 + cfg.hidden_dropout = 0.0 + + # Apply user overrides last. + if config is not None: + for field, value in vars(config).items(): + setattr(cfg, field, value) + + return cfg + + +# --------------------------------------------------------------------------- +# Vision → language projection MLP +# --------------------------------------------------------------------------- +def get_vlm_projection_config( + hidden_size: int = 2688, + config: Optional[TransformerConfig] = None, +) -> TransformerConfig: + """Return a TransformerConfig for the vision→language projection MLP. + + ``hidden_size`` should match the language model's hidden size. + + Must match the original pretrain_vlm.py architecture: + - activation_func = squared_relu (inherited from language model base config) + - normalization = "RMSNorm" (inherited from language model base config) + - bias_activation_fusion = False + - bias_dropout_fusion = False + """ + cfg = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=1, + ) + cfg.ffn_hidden_size = 4 * 5120 + cfg.bias_activation_fusion = False + cfg.bias_dropout_fusion = False + cfg.add_bias_linear = False + cfg.activation_func = squared_relu + cfg.normalization = "RMSNorm" + + if config is not None: + for field, value in vars(config).items(): + setattr(cfg, field, value) + + return cfg + + +# --------------------------------------------------------------------------- +# Layer specs +# --------------------------------------------------------------------------- +def get_radio_vision_layer_spec() -> ModuleSpec: + """Layer spec for the RADIO ViT encoder (Transformer-Engine).""" + return get_vit_layer_with_transformer_engine_spec() + + +def get_nemotron_moe_language_layer_spec() -> ModuleSpec: + """Layer spec for the Nemotron6-MoE hybrid Mamba stack. + + Returns the ``mamba_stack_spec`` from ``mamba_layer_specs`` which + supports Mamba / attention / MLP / MoE layers via + ``hybrid_override_pattern``. + """ + return mamba_stack_spec + + +def get_vlm_projection_layer_spec() -> ModuleSpec: + """Layer spec for the vision→language projection MLP. + + Uses TELayerNormColumnParallelLinear for fc1 to match the original + pretrain_vlm.py architecture (examples/multimodal/model.py). The fused + layer norm normalizes the vision encoder output before the first linear + layer, which is critical for training stability. + """ + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ) diff --git a/examples/mimo/model_providers/nemotron_moe_vlm.py b/examples/mimo/model_providers/nemotron_moe_vlm.py new file mode 100644 index 00000000000..f151f0653b1 --- /dev/null +++ b/examples/mimo/model_providers/nemotron_moe_vlm.py @@ -0,0 +1,277 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Model provider for Nemotron6-MoE VLM with RADIO vision encoder. + +Assembles a MIMO model consisting of: +- RADIO ViT vision encoder (1280 hidden, pixel shuffle → 5120) +- MLP projector (5120 → language hidden size) +- Nemotron6-MoE hybrid Mamba language model +""" + +import torch +from configs.nemotron_moe_vlm import ( + get_nemotron_moe_language_layer_spec, + get_nemotron_moe_language_model_config, + get_radio_vision_config, + get_radio_vision_layer_spec, + get_vlm_projection_config, + get_vlm_projection_layer_spec, +) + +from examples.mimo.model_providers.radio_encoder import RADIOEncoderWrapper +from examples.mimo.utils.logging import print_mimo_structure +from examples.mimo.utils.model_helpers import load_nemotron_vlm_ckpt, load_submodule_ckpt +from megatron.core.models.mamba.mamba_model import MambaModel +from megatron.core.models.mimo import MimoModel, MimoModelConfig +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.transformer.spec_utils import ModuleSpec + + +def add_nemotron_moe_vlm_args(parser): + """Add Nemotron-MoE VLM specific arguments. + + These are args specific to the RADIO + Mamba-MoE model provider. + Note: --img-h, --img-w, --patch-dim are already in Megatron core args. + """ + group = parser.add_argument_group( + 'Nemotron-MoE VLM', 'Nemotron-MoE VLM model provider arguments' + ) + group.add_argument('--pixel-shuffle', action='store_true', default=False, + help='Apply pixel shuffle post-processing to vision encoder output') + group.add_argument('--max-num-tiles', type=int, default=1, + help='Max number of image tiles') + group.add_argument('--use-tiling', action='store_true', default=False, + help='Enable image tiling') + group.add_argument('--use-thumbnail', action='store_true', default=False, + help='Enable thumbnail generation') + group.add_argument('--disable-vision-class-token', action='store_true', default=False, + help='Do not drop class tokens from vision encoder') + group.add_argument('--freeze-lm', action='store_true', default=False, + help='Freeze language model parameters') + group.add_argument('--freeze-vit', action='store_true', default=False, + help='Freeze vision encoder parameters') + group.add_argument('--freeze-projection', action='store_true', default=False, + help='Freeze vision-to-language projection MLP parameters') + group.add_argument('--vision-model-type', type=str, default='radio', + help='Vision model type (e.g. radio)') + group.add_argument('--class-token-len', type=int, default=None, + help='Number of class tokens in vision encoder') + group.add_argument('--nemotron-checkpoint', type=str, default=None, + help='Path to a non-MIMO Nemotron VLM checkpoint directory. ' + 'Loads vision_model/vision_projection/language_model with key remapping.') + group.add_argument('--skip-projection-checkpoint', action='store_true', default=False, + help='When loading --nemotron-checkpoint, skip vision_projection weights ' + '(projection stays randomly initialized). Use for adapter-only training.') + + # MultimodalTokenizer args (required by megatron/training/tokenizer/tokenizer.py) + group.add_argument('--special-tokens', nargs='*', default=[''], + help='Special tokens for the multimodal tokenizer') + group.add_argument('--tokenizer-prompt-format', type=str, default='nemotron6-moe', + help='Prompt format for MultimodalTokenizer') + group.add_argument('--image-tag-type', type=str, default='', + help='Image tag type (e.g. nvlm, internvl, or empty)') + group.add_argument('--force-system-message', action='store_true', default=False, + help='Force a specific system message in the tokenizer') + return parser + + +def model_provider_nemotron_moe_vlm( + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + **kwargs, +): + """Build a Nemotron6-MoE VLM MIMO model. + + Composes RADIO vision encoder + MLP projector + MambaModel (hybrid + Mamba-MoE) into a single :class:`MimoModel`. + + Args: + pre_process: Include embedding layer (pipeline parallelism). + post_process: Include output layer (pipeline parallelism). + add_encoder: Unused (PP not yet supported in MIMO). + add_decoder: Unused (PP not yet supported in MIMO). + + Reads ``--image-token-id`` directly from CLI args (same source as the + data provider) to keep both sides in sync. + """ + from megatron.training import get_args + from megatron.training.global_vars import get_tokenizer + + args = get_args() + + # Derive image_token_id from the tokenizer (matches remote reference). + tokenizer = get_tokenizer() + image_special_token_id = tokenizer.convert_tokens_to_ids("") + + # ── Configs ────────────────────────────────────────────────────────── + # Language config: Nemotron defaults + CLI overrides (parallelism, + # precision, and any explicitly-set arch fields) — all handled inside. + language_config = get_nemotron_moe_language_model_config(args) + + # Vision / projection: fixed architectures, sync precision and parallelism. + vision_config = get_radio_vision_config() + projection_config = get_vlm_projection_config( + hidden_size=language_config.hidden_size, + ) + for cfg in (vision_config, projection_config): + cfg.params_dtype = language_config.params_dtype + cfg.bf16 = language_config.bf16 + cfg.fp16 = language_config.fp16 + cfg.use_cpu_initialization = language_config.use_cpu_initialization + cfg.perform_initialization = language_config.perform_initialization + # TP group is global — vision/projection layers already use the global + # TP=2 process group at init time. Sync the config so that + # sharded_state_dict() computes matching global shapes. + cfg.tensor_model_parallel_size = language_config.tensor_model_parallel_size + + # ── Vision encoder (RADIO) ─────────────────────────────────────────── + # Image args from CLI + img_h = getattr(args, "img_h", 512) + img_w = getattr(args, "img_w", 512) + patch_dim = getattr(args, "patch_dim", 16) + apply_pixel_shuffle = getattr(args, "pixel_shuffle", False) + class_token_len = getattr(args, "class_token_len", 8) or 8 + disable_vision_class_token = getattr(args, "disable_vision_class_token", False) + + # After pixel shuffle: hidden * 4 = 1280 * 4 = 5120 + vision_input_size = vision_config.hidden_size * 4 if apply_pixel_shuffle else vision_config.hidden_size + + vision_encoder = ModuleSpec( + module=RADIOEncoderWrapper, + params={ + "transformer_config": vision_config, + "transformer_layer_spec": get_radio_vision_layer_spec(), + "img_h": img_h, + "img_w": img_w, + "patch_dim": patch_dim, + "class_token_len": class_token_len, + "drop_class_token": disable_vision_class_token, + "apply_pixel_shuffle": apply_pixel_shuffle, + }, + ) + + # ── Vision → language projection ───────────────────────────────────── + vision_projection = ModuleSpec( + module=MultimodalProjector, + params={ + "config": projection_config, + "submodules": get_vlm_projection_layer_spec().submodules, + "projector_type": "mlp", + "input_size": vision_input_size, + }, + ) + + # ── Vision submodule spec ──────────────────────────────────────────── + vision_submodule_spec = ModuleSpec( + module=VisionModalitySubmodules, + params={}, + submodules={ + "encoders": {"radio_encoder": vision_encoder}, + "input_projections": [vision_projection], + }, + ) + + # ── Language model (MambaModel for hybrid Mamba-MoE) ───────────────── + hybrid_override_pattern = getattr(args, "hybrid_override_pattern", None) + + language_model_spec = ModuleSpec( + module=MambaModel, + params={ + "config": language_config, + "mamba_stack_spec": get_nemotron_moe_language_layer_spec(), + "vocab_size": args.padded_vocab_size, + "max_sequence_length": args.max_position_embeddings, + "pre_process": pre_process, + "post_process": post_process, + "hybrid_override_pattern": hybrid_override_pattern, + "position_embedding_type": "none", + # Disable scatter in embedding — MIMO combines modality embeddings + # at full sequence length, then scatters to SP in forward(). + "scatter_embedding_sequence_parallel": False, + }, + ) + + # ── Assemble MIMO model ────────────────────────────────────────────── + mimo_model_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_submodule_spec}, + special_token_ids={"images": image_special_token_id}, + ) + + mimo_model = MimoModel(mimo_model_config) + print("*" * 100) + print_mimo_structure(mimo_model) + print("*" * 100) + + # ── Load pre-trained checkpoints ────────────────────────────────────── + if getattr(args, "nemotron_checkpoint", None) is not None: + if getattr(args, "load", None) is not None: + raise ValueError( + "--nemotron-checkpoint and --load cannot both be set. " + "Use --nemotron-checkpoint for initial loading of a non-MIMO checkpoint, " + "or --load for resuming from a MIMO-native checkpoint." + ) + load_nemotron_vlm_ckpt( + mimo_model, + args.nemotron_checkpoint, + skip_projection=getattr(args, "skip_projection_checkpoint", False), + ) + print(f"Successfully loaded nemotron checkpoint from {args.nemotron_checkpoint}") + + if getattr(args, "language_model_checkpoint", None) is not None: + load_submodule_ckpt(mimo_model.language_model, args.language_model_checkpoint) + print(f"Successfully loaded language model from {args.language_model_checkpoint}") + + if getattr(args, "vision_encoder_checkpoint", None) is not None: + load_submodule_ckpt( + mimo_model.modality_submodules.images.encoders.radio_encoder.radio_model, + args.vision_encoder_checkpoint, + ignore_missing_keys=("class_token",), + ) + print(f"Successfully loaded vision encoder from {args.vision_encoder_checkpoint}") + + # ── Freeze / unfreeze based on CLI flags ───────────────────────────── + if getattr(args, "freeze_vit", False): + for p in mimo_model.modality_submodules.images.encoders.radio_encoder.parameters(): + p.requires_grad = False + + if getattr(args, "freeze_lm", False): + for p in mimo_model.language_model.parameters(): + p.requires_grad = False + + if getattr(args, "freeze_projection", False): + for proj in mimo_model.modality_submodules.images.input_projections: + for p in proj.parameters(): + p.requires_grad = False + + # Log trainable vs frozen parameter counts. + _log_freeze_summary(mimo_model) + + return mimo_model + + +def _log_freeze_summary(model: MimoModel): + """Print trainable/frozen parameter counts per component.""" + components = { + "vision_encoder": model.modality_submodules.images.encoders.radio_encoder, + "projection": model.modality_submodules.images.input_projections, + "language_model": model.language_model, + } + total_trainable = 0 + total_frozen = 0 + print("=" * 60) + print("Freeze summary:") + for name, module in components.items(): + trainable = sum(p.numel() for p in module.parameters() if p.requires_grad) + frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad) + total_trainable += trainable + total_frozen += frozen + status = "FROZEN" if trainable == 0 else ("TRAINABLE" if frozen == 0 else "PARTIAL") + print(f" {name:20s}: {status:10s} " + f"(trainable={trainable:>12,}, frozen={frozen:>12,})") + print(f" {'TOTAL':20s}: " + f"(trainable={total_trainable:>12,}, frozen={total_frozen:>12,})") + print("=" * 60) diff --git a/examples/mimo/model_providers/radio_encoder.py b/examples/mimo/model_providers/radio_encoder.py new file mode 100644 index 00000000000..d88f24f2cbe --- /dev/null +++ b/examples/mimo/model_providers/radio_encoder.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""RADIO ViT encoder wrapper for MIMO. + +Wraps RADIOViTModel to conform to the MIMO encoder forward(**kwargs) interface, +with optional class-token dropping and pixel-shuffle post-processing. +""" + +from typing import Optional + +import torch +import torch.nn as nn + +from megatron.core.models.multimodal.llava_model import pixel_shuffle +from megatron.core.models.vision.radio import RADIOViTModel +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import sharded_state_dict_default + + +class RADIOEncoderWrapper(nn.Module): + """RADIO encoder wrapper for MIMO's encoder interface. + + Instantiates a ``RADIOViTModel`` and adds optional class-token dropping + and pixel-shuffle post-processing so that the output is ready for the + multimodal projector. + + After pixel shuffle the hidden dimension is ``hidden_size * 4`` and the + sequence length is reduced by 4×. + """ + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + img_h: int = 512, + img_w: int = 512, + patch_dim: int = 16, + class_token_len: int = 8, + drop_class_token: bool = True, + apply_pixel_shuffle: bool = False, + max_img_h: int = 2048, + max_img_w: int = 2048, + has_cpe: bool = True, + embedder_bias: bool = False, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__() + self.drop_class_token = drop_class_token + self.class_token_len = class_token_len + self._apply_pixel_shuffle = apply_pixel_shuffle + + self.radio_model = RADIOViTModel( + transformer_config=transformer_config, + transformer_layer_spec=transformer_layer_spec, + patch_dim=patch_dim, + img_h=img_h, + img_w=img_w, + class_token_len=class_token_len, + add_class_token=True, + max_img_h=max_img_h, + max_img_w=max_img_w, + has_cpe=has_cpe, + embedder_bias=embedder_bias, + pg_collection=pg_collection, + ) + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Delegate to radio_model for distributed checkpoint TP resharding support.""" + sharded_sd = {} + for name, child in self.named_children(): + sharded_sd.update( + sharded_state_dict_default(child, f'{prefix}{name}.', sharded_offsets, metadata) + ) + return sharded_sd + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run RADIO encoder. + + Args: + x: Input images of shape ``[num_tiles, 3, img_h, img_w]``. + + Returns: + Encoded embeddings. With default settings the shape is + ``[num_tiles, reduced_seq, hidden_size * 4]``. + """ + x = x.to(dtype=self.radio_model.embedder.weight.dtype) + embeddings = self.radio_model(x) # [num_tiles, seq_len, hidden] + + if self.drop_class_token: + embeddings = embeddings[:, self.class_token_len :, :] + + if self._apply_pixel_shuffle: + embeddings = pixel_shuffle(embeddings, scale_factor=0.5) + + return embeddings diff --git a/megatron/core/hyper_comm_grid.py b/megatron/core/hyper_comm_grid.py index 4b860396c4e..e8ea34bd9a8 100644 --- a/megatron/core/hyper_comm_grid.py +++ b/megatron/core/hyper_comm_grid.py @@ -77,6 +77,37 @@ class HyperCommGrid: rank_offset: Starting rank when the grid doesn't span the entire communication world. Default 0. backend: Backend for creating process group. Default None and will use default backend. + alt_factorizations: Optional alternate factorizations of a contiguous block of the primary + dim_names. Each entry re-expresses the same rank slab under different axis names with + a different shape. Used to overlap expert parallelism (EP / ETP / EDP) onto the same + ranks that carry TP / CP / DP without inflating the world size. The mapping has the + shape ``{name: {"shape": [...], "dim_names": [...], "replaces": [...]}}``. + + Constraints (enforced at construction): + + * ``replaces`` must be a contiguous slice of the primary ``dim_names``. + * The product of the alt ``shape`` must equal the product of the primary shape values + at the covered positions. + * Alt ``dim_names`` must not collide with primary ``dim_names`` or with names from + any other alt factorization. + + Example for NMFW-464 expert overlap (8 ranks, ``tp=cp=dp=2``, ``ep=etp=edp=2``):: + + HyperCommGrid( + [2, 2, 2, 1], ["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + }, + }, + ) + + ``create_pg("ep")`` then enumerates the same rank slab as ``create_pg("cp")`` would + for the primary, but under the expert factorization. Mixing covered primary dims + (e.g. ``tp``) with alt dims (e.g. ``ep``) in a single ``create_pg`` call is rejected + because the two views share ranks and the combined group is ambiguous. """ def __init__( @@ -85,6 +116,7 @@ def __init__( dim_names: list[str], rank_offset: int = 0, backend: Optional[str] = None, + alt_factorizations: Optional[dict[str, dict[str, Any]]] = None, ) -> None: if len(shape) != len(dim_names): raise ValueError(f"len(shape) {shape} != len(dim_names) {dim_names}") @@ -117,6 +149,134 @@ def __init__( self.backend = backend self._pgs: dict[str, dist.ProcessGroup] = {} + # Alt factorizations: each builds a "shadow" (dim_names, shape) that expresses the same + # flat rank range under a different naming, by replacing the contiguous slice of primary + # dim_names listed in ``replaces`` with the alt's dim_names and shape. The shadow drives + # einops enumeration when the caller asks for groups along alt axes. + self._alt_shadows: dict[str, Tuple[list[str], list[int]]] = {} + # Map from primary dim name → alt name that replaces it. Used to detect ambiguous + # mixed-factorization requests in ``_resolve_dims``. + self._replaced_to_alt: dict[str, str] = {} + # Map from alt-axis dim name → alt name that owns it. + self._dim_to_alt: dict[str, str] = {} + if alt_factorizations: + for alt_name, alt_spec in alt_factorizations.items(): + shadow = self._validate_and_build_alt(alt_name, alt_spec) + self._alt_shadows[alt_name] = shadow + for d in alt_spec["dim_names"]: + self._dim_to_alt[d] = alt_name + for d in alt_spec["replaces"]: + if d in self._replaced_to_alt: + other = self._replaced_to_alt[d] + raise ValueError( + f"alt_factorization {alt_name!r}: primary dim {d!r} is already " + f"replaced by alt factorization {other!r}; alt factorizations must " + f"replace disjoint slices of the primary" + ) + self._replaced_to_alt[d] = alt_name + + def _validate_and_build_alt( + self, alt_name: str, alt_spec: dict[str, Any] + ) -> Tuple[list[str], list[int]]: + r"""Validate one alt factorization and return its ``(shadow_dim_names, shadow_shape)``.""" + for required in ("shape", "dim_names", "replaces"): + if required not in alt_spec: + raise ValueError( + f"alt_factorization {alt_name!r} is missing required key {required!r}" + ) + alt_shape = list(alt_spec["shape"]) + alt_dim_names = list(alt_spec["dim_names"]) + replaces = list(alt_spec["replaces"]) + if len(alt_shape) != len(alt_dim_names): + raise ValueError( + f"alt_factorization {alt_name!r}: len(shape) {alt_shape} != " + f"len(dim_names) {alt_dim_names}" + ) + if not replaces: + raise ValueError(f"alt_factorization {alt_name!r}: replaces must be non-empty") + + # replaces must be a contiguous slice of primary dim_names + for c in replaces: + if c not in self.dim_names: + raise ValueError( + f"alt_factorization {alt_name!r}: replaces entry {c!r} is not a primary dim" + ) + first_idx = self.dim_names.index(replaces[0]) + expected = self.dim_names[first_idx : first_idx + len(replaces)] + if expected != replaces: + raise ValueError( + f"alt_factorization {alt_name!r}: replaces {replaces} must be a contiguous slice " + f"of primary dim_names {self.dim_names}" + ) + + # product(alt.shape) == product(primary.shape over replaced positions) + primary_replaced_prod = int(np.prod(self.shape[first_idx : first_idx + len(replaces)])) + alt_prod = int(np.prod(alt_shape)) + if alt_prod != primary_replaced_prod: + raise ValueError( + f"alt_factorization {alt_name!r}: product(shape) {alt_prod} != product of " + f"primary replaced dims {primary_replaced_prod}" + ) + + # alt dim_names must not collide with primary or other alt names + for d in alt_dim_names: + if d in self.dim_names: + raise ValueError( + f"alt_factorization {alt_name!r}: dim {d!r} collides with primary dim_names" + ) + if d in self._dim_to_alt: + other = self._dim_to_alt[d] + raise ValueError( + f"alt_factorization {alt_name!r}: dim {d!r} collides with alt " + f"factorization {other!r}" + ) + + # Build shadow by replacing the contiguous slice with alt + shadow_dim_names = ( + self.dim_names[:first_idx] + alt_dim_names + self.dim_names[first_idx + len(replaces) :] + ) + shadow_shape = self.shape[:first_idx] + alt_shape + self.shape[first_idx + len(replaces) :] + return shadow_dim_names, shadow_shape + + def _resolve_dims(self, dims_list: list[str]) -> Tuple[list[str], list[int]]: + r"""Pick the layout (primary or alt-shadow) that should handle ``dims_list``. + + Returns: + ``(dim_names, shape)`` — the layout to use for rank-enumeration. + + Raises: + KeyError: if a requested dim is unknown. + ValueError: if the request mixes replaced primary dims with alt dims, or alt dims + from different factorizations. + """ + alts_used: set[str] = set() + has_replaced_primary = False + for d in dims_list: + if d in self._dim_to_alt: + alts_used.add(self._dim_to_alt[d]) + elif d in self.dim_names: + if d in self._replaced_to_alt: + has_replaced_primary = True + else: + raise KeyError(f"Dimension {d!r} is not a primary or alt dim of this grid") + + if len(alts_used) > 1: + raise ValueError( + f"create_pg/get_pg cannot mix dims from multiple alt factorizations: " + f"{sorted(alts_used)}" + ) + + if alts_used: + alt_name = next(iter(alts_used)) + if has_replaced_primary: + raise ValueError( + f"Cannot combine replaced primary dims with dims from alt factorization " + f"{alt_name!r}; the views share ranks and the combined group is ambiguous" + ) + return self._alt_shadows[alt_name] + + return self.dim_names, self.shape + def create_pg(self, dims: Union[str, list[str]], **kwargs: Any) -> dist.ProcessGroup | None: r"""Create a process group based on a list of dimension names @@ -145,8 +305,10 @@ def create_pg(self, dims: Union[str, list[str]], **kwargs: Any) -> dist.ProcessG Raises: KeyError: If attempting to recreate a process group with an existing key. """ - # ordered_dims and unique_group_key will follow the reversed order of self.dim_names - ordered_dims, unique_group_key = self._order_dims(dims) + dims_list = [dims] if isinstance(dims, str) else list(dims) + layout_names, layout_shape = self._resolve_dims(dims_list) + # ordered_dims and unique_group_key follow the reversed order of layout_names + ordered_dims, unique_group_key = self._order_dims(dims, dim_names=layout_names) if unique_group_key in self._pgs: raise KeyError( @@ -155,10 +317,10 @@ def create_pg(self, dims: Union[str, list[str]], **kwargs: Any) -> dist.ProcessG f"of returning the process group that has already been created before." ) - rank_enum = self._gen_rank_enum(ordered_dims) + rank_enum = self._gen_rank_enum(ordered_dims, dim_names=layout_names, shape=layout_shape) pg, _ = dist.new_subgroups_by_enumeration(rank_enum, backend=self.backend, **kwargs) - if dist.get_rank() == 0: + if dist.is_initialized() and dist.get_rank() == 0: logging.info( f"Generated process group for {unique_group_key} with enumeration {rank_enum}" ) @@ -178,7 +340,9 @@ def get_pg(self, dims: Union[str, list[str]]) -> dist.ProcessGroup: Args: dims: Name of leading dimensions to create process group """ - _, unique_group_key = self._order_dims(dims) + dims_list = [dims] if isinstance(dims, str) else list(dims) + layout_names, _ = self._resolve_dims(dims_list) + _, unique_group_key = self._order_dims(dims, dim_names=layout_names) if unique_group_key not in self._pgs: raise KeyError( @@ -200,10 +364,17 @@ def get_rank_enum(self, dims: Union[str, list[str]]) -> list[list[int]]: Returns: List of rank lists (one per subgroup). """ - ordered_dims, _ = self._order_dims(dims) - return self._gen_rank_enum(ordered_dims) + dims_list = [dims] if isinstance(dims, str) else list(dims) + layout_names, layout_shape = self._resolve_dims(dims_list) + ordered_dims, _ = self._order_dims(dims, dim_names=layout_names) + return self._gen_rank_enum(ordered_dims, dim_names=layout_names, shape=layout_shape) - def _gen_rank_enum(self, dims: list[str]) -> list[list[int]]: + def _gen_rank_enum( + self, + dims: list[str], + dim_names: Optional[list[str]] = None, + shape: Optional[list[int]] = None, + ) -> list[list[int]]: r"""Generate rank enumeration before calling new_subgroups_by_enumeration This function returns ranks grouped by the specified dimensions, but in REVERSE order @@ -220,6 +391,11 @@ def _gen_rank_enum(self, dims: list[str]) -> list[list[int]]: Args: dims: Name of leading dimensions to create process group + dim_names: Layout dim_names to use; defaults to ``self.dim_names``. When the caller + requests groups along an alt factorization, this is overridden by the alt's + shadow dim_names. + shape: Layout shape to use; defaults to ``self.shape``. Like ``dim_names``, this is + overridden when generating groups along an alt factorization. Although the function is lightweight enough to be inlined, a standalone one makes it easier to test against MCore's RankGenerator @@ -230,8 +406,11 @@ def _gen_rank_enum(self, dims: list[str]) -> list[list[int]]: "einops is not installed. Please install it with `pip install einops`." ) + layout_names = self.dim_names if dim_names is None else dim_names + layout_shape = self.shape if shape is None else shape + # Need to reverse order of dim_names to match MCore convention - dim_names_reverse = self.dim_names[::-1] + dim_names_reverse = layout_names[::-1] remaining_dims = [] for v in dim_names_reverse: @@ -243,17 +422,20 @@ def _gen_rank_enum(self, dims: list[str]) -> list[list[int]]: ) logging.debug(rearrange_str) - shape_dict = {d: s for d, s in zip(self.dim_names, self.shape)} + shape_dict = {d: s for d, s in zip(layout_names, layout_shape)} return einops.rearrange( np.arange(self.rank_offset, self.rank_offset + self.size), rearrange_str, **shape_dict ).tolist() - def _order_dims(self, dims: Union[str, list[str]]) -> Tuple[list[str], str]: - r"""Reorder dims based on the order of self.dim_names""" + def _order_dims( + self, dims: Union[str, list[str]], dim_names: Optional[list[str]] = None + ) -> Tuple[list[str], str]: + r"""Reorder dims based on the order of ``dim_names`` (defaults to ``self.dim_names``).""" + layout_names = self.dim_names if dim_names is None else dim_names if not isinstance(dims, list): ordered_dims = [dims] else: - dim_names_reverse = self.dim_names[::-1] + dim_names_reverse = layout_names[::-1] indices = sorted([dim_names_reverse.index(d) for d in dims]) if len(indices) == 1: ordered_dims = [dim_names_reverse[indices[0]]] diff --git a/megatron/core/models/hybrid/hybrid_layer_allocation.py b/megatron/core/models/hybrid/hybrid_layer_allocation.py index f1ba94ef7fa..e30869ee32b 100644 --- a/megatron/core/models/hybrid/hybrid_layer_allocation.py +++ b/megatron/core/models/hybrid/hybrid_layer_allocation.py @@ -333,6 +333,8 @@ def select_pipeline_segment( vp_stage: Optional[int], first_stage_layers: Optional[int] = None, last_stage_layers: Optional[int] = None, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + dp_cp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Tuple[List[str], int]: """Select and validate the pipeline segment for the given PP rank and VP stage. @@ -445,6 +447,8 @@ def select_pipeline_segment( f"HybridModel: pp_rank={pp_rank}/{pp_size}, vp_stage={vp_stage}, " f"layers='{''.join(selected)}' ({len(selected)} layers), " f"layer_offset={offset} (auto-split)", + tp_group=tp_group, + dp_cp_group=dp_cp_group, ) return selected, offset @@ -479,6 +483,8 @@ def select_pipeline_segment( f"segment_index={segment_index}/{len(segments)}, " f"layers='{my_segment}' ({len(layer_type_list)} layers), " f"layer_offset={layer_offset}", + tp_group=tp_group, + dp_cp_group=dp_cp_group, ) return layer_type_list, layer_offset diff --git a/megatron/core/models/hybrid/hybrid_model.py b/megatron/core/models/hybrid/hybrid_model.py index 4399c6984a7..d648bd96c78 100644 --- a/megatron/core/models/hybrid/hybrid_model.py +++ b/megatron/core/models/hybrid/hybrid_model.py @@ -192,6 +192,8 @@ def __init__( vp_stage, first_stage_layers=self.config.num_layers_in_first_pipeline_stage, last_stage_layers=self.config.num_layers_in_last_pipeline_stage, + tp_group=getattr(self.pg_collection, "tp", None), + dp_cp_group=getattr(self.pg_collection, "dp_cp", None), ) # Determine if MTP is needed (based on pattern parsing) diff --git a/megatron/core/models/vision/radio.py b/megatron/core/models/vision/radio.py index 5e9525adfee..e7c84266d6d 100644 --- a/megatron/core/models/vision/radio.py +++ b/megatron/core/models/vision/radio.py @@ -125,6 +125,8 @@ def __init__( self.pos_dropout = pos_dropout self.has_cpe = has_cpe + self.pg_collection = pg_collection + # Using non-TE version so we can force gather_output self.embedder = ColumnParallelLinear( input_size=3 * self.patch_dim * self.patch_dim, @@ -133,13 +135,13 @@ def __init__( config=transformer_config, gather_output=True, init_method=lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=1.0), + tp_group=pg_collection.tp if pg_collection is not None else None, ) self.model_type = ModelType.encoder_or_decoder self.ln_pre = None self.ln_post = None - self.pg_collection = pg_collection self.vp_stage = vp_stage if ln_pre_impl is not None: self.ln_pre = build_module( diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index 6c1e3651387..48460578520 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -165,6 +165,95 @@ def __repr__(self): else "ProcessGroupCollection(empty)" ) + @classmethod + def from_hyper_comm_grid(cls, grid) -> "ProcessGroupCollection": + """Build a ``ProcessGroupCollection`` directly from a ``HyperCommGrid``. + + Uses the grid's primary axes ``tp``, ``cp``, ``dp``, ``pp`` (each optional) to + populate the standard non-expert fields, and the grid's alt-factorization axes + ``etp``/``ep``/``edp`` to populate the expert fields when an alt factorization + named ``"expert"`` (or any alt that names those axes) is present. No global + ``parallel_state`` initialization is required. + + Args: + grid: A ``HyperCommGrid`` instance. May carry an alt factorization that names + expert axes (``etp``/``ep``/``edp``) over the same per-PP-stage rank slab as + the primary ``tp``/``cp``/``dp`` axes (see NMFW-464 expert overlap). + + Returns: + A populated ``ProcessGroupCollection``. Primary fields (``tp``/``cp``/``dp``/``pp`` + and the standard combined groups ``dp_cp``/``tp_cp``/``tp_dp_cp``/``mp``) are only + set when the corresponding axes exist in the grid; on ranks that aren't members of + the grid, the value is ``None`` rather than the field being absent. Expert fields + (``ep``/``expt_tp``/``expt_dp``/``tp_ep``/``tp_ep_pp``) are *always* populated — + either with the alt-factorization group, or with ``None`` when no alt is present — + so DDP / optimizer / MoE call sites can use a uniform ``hasattr`` + ``is not None`` + probe regardless of whether the grid has an expert alt. + """ + from megatron.core.hyper_comm_grid import HyperCommGrid # local import to avoid cycle + + if not isinstance(grid, HyperCommGrid): + raise TypeError(f"grid must be a HyperCommGrid, got {type(grid)}") + + primary = set(grid.dim_names) + alts = set(grid._dim_to_alt) if hasattr(grid, "_dim_to_alt") else set() + + def _make(dims): + """Create the requested group if all named axes are present in the grid. + + Returns the PG (which may be ``None`` on ranks that aren't members of this + grid — ``dist.new_subgroups_by_enumeration`` returns ``None`` for non-members). + Returns the sentinel ``"absent"`` if the axis isn't present in the grid at + all (so the field is skipped entirely rather than set to ``None``). + """ + names = [dims] if isinstance(dims, str) else list(dims) + if not all((d in primary) or (d in alts) for d in names): + return "absent" + return grid.create_pg(dims) + + kwargs = {} + for field_name, axis_spec in ( + ("tp", "tp"), + ("cp", "cp"), + ("dp", "dp"), + ("pp", "pp"), + ("dp_cp", ["dp", "cp"]), + ("tp_cp", ["tp", "cp"]), + ("tp_dp_cp", ["tp", "dp", "cp"]), + ("mp", ["tp", "pp"]), + ): + pg = _make(axis_spec) + if pg != "absent": + # Set the field even when ``pg is None`` so non-member ranks see + # ``collection.tp is None`` rather than ``AttributeError``. + kwargs[field_name] = pg + + # Expert axes via alt factorization. As above, ``pg`` may legitimately be + # ``None`` on ranks outside this grid — set the field so callers can probe. + # When the grid carries no alt factorization at all (typical for non-MoE + # encoder grids), populate the standard expert fields with ``None`` so that + # callers like ``setup_process_groups_for_ddp`` (which uses ``hasattr``) can + # uniformly probe for them without distinguishing MoE-grid from non-MoE-grid. + kwargs["ep"] = grid.create_pg("ep") if "ep" in alts else None + kwargs["expt_tp"] = grid.create_pg("etp") if "etp" in alts else None + kwargs["expt_dp"] = grid.create_pg("edp") if "edp" in alts else None + kwargs["tp_ep"] = grid.create_pg(["etp", "ep"]) if {"etp", "ep"}.issubset(alts) else None + kwargs["tp_ep_pp"] = ( + grid.create_pg(["etp", "ep", "pp"]) + if {"etp", "ep"}.issubset(alts) and "pp" in primary + else None + ) + + # Aliases used by DDP / optimizer when there's a single distributed-optimizer instance. + # Falling back to the non-partial groups keeps ``hasattr`` / fallback paths happy without + # needing to allocate separate NCCL groups. + if "dp_cp" in kwargs and "intra_dp_cp" not in kwargs: + kwargs["intra_dp_cp"] = kwargs["dp_cp"] + if "expt_dp" in kwargs and "intra_expt_dp" not in kwargs: + kwargs["intra_expt_dp"] = kwargs["expt_dp"] + + return cls(**kwargs) + @classmethod def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None): """ diff --git a/scripts/nmfw464_e2e_batch.sh b/scripts/nmfw464_e2e_batch.sh new file mode 100755 index 00000000000..2d09568c6a7 --- /dev/null +++ b/scripts/nmfw464_e2e_batch.sh @@ -0,0 +1,75 @@ +#!/usr/bin/env bash +# NMFW-464 Phase 1 E2E batch test runner. +# +# Each test runs in its own ``torch.distributed.run`` invocation so global +# singletons (``parallel_state``, NCCL groups, RNG tracker) cannot leak +# between tests. Fails fast on the first non-zero pytest exit so the cog +# batch job returns a meaningful status to the caller. +set -euo pipefail +cd "$(dirname "$0")/.." + +PYTEST_TESTS=( + # Unit tests (mock, fast). + 'tests/unit_tests/test_hyper_comm_grid.py::TestHyperCommGrid' + 'tests/unit_tests/test_hyper_comm_grid.py::TestHyperCommGridAltFactorization' + + # Distributed integration (8 GPUs). + 'tests/unit_tests/test_hyper_comm_grid.py::TestHyperCommGridIntegration' + 'tests/unit_tests/test_process_groups_config.py::TestPGConfigFromHyperCommGrid' + + # MoE-on-schedule. + 'tests/unit_tests/transformer/moe/test_moe_with_hcg_pg.py' + + # MIMO + MoE end-to-end matrix. Each test gets its own torchrun. + 'tests/unit_tests/models/test_mimo_moe_e2e.py::TestMimoMoEColocated::test_mimo_moe_colocated_8gpu[False]' + 'tests/unit_tests/models/test_mimo_moe_e2e.py::TestMimoMoEColocated::test_mimo_moe_colocated_8gpu[True]' + 'tests/unit_tests/models/test_mimo_moe_e2e.py::TestMimoMoEColocated::test_mimo_nemotron_mamba_moe_colocated_8gpu' + 'tests/unit_tests/models/test_mimo_moe_e2e.py::TestMimoMoEColocated::test_mimo_nemotron_radio_mamba_moe_colocated_8gpu' + 'tests/unit_tests/models/test_mimo_moe_e2e.py::TestMimoMoEColocated::test_mimo_mamba_moe_non_colocated_8gpu' + 'tests/unit_tests/models/test_mimo_moe_e2e.py::TestMimoMoEColocated::test_mimo_nemotron_radio_mamba_moe_non_colocated_8gpu' +) + +results=() +for t in "${PYTEST_TESTS[@]}"; do + echo + echo "===================================================================" + echo "RUN: $t" + echo "===================================================================" + # Hash the test id so log directories never collide on long parametrized names. + log_id="$(printf '%s' "$t" | sha1sum | cut -c1-12)" + log_dir="${TORCHRUN_LOG_DIR:-/tmp}/${log_id}" + mkdir -p "$log_dir" + + # Authoritative success/failure comes from pytest's exit code (PIPESTATUS[0]), + # not from grep matching a "passed" line. The grep is purely for one-line + # summaries in the batch stdout and is allowed to print nothing. + set +e + uv run --no-sync python -m torch.distributed.run \ + --nproc-per-node 8 \ + --redirects=3 --tee=3 \ + --log-dir "$log_dir" \ + -m pytest "$t" -q --tb=line 2>&1 | tee "$log_dir/run.log" + pytest_rc=${PIPESTATUS[0]} + set -e + + grep -E '^\[default0\]:.*passed|^\[default0\]:.*failed' "$log_dir/run.log" | tail -3 || true + + if [ "$pytest_rc" -eq 0 ]; then + results+=("PASS $t") + else + results+=("FAIL $t (rc=$pytest_rc)") + echo "===================================================================" + echo "FAILED: $t (pytest exit code $pytest_rc)" + echo "Last 80 lines of run.log:" + tail -80 "$log_dir/run.log" + echo "===================================================================" + printf '%s\n' "${results[@]}" + exit "$pytest_rc" + fi +done + +echo +echo "===================================================================" +echo "ALL TESTS PASSED" +echo "===================================================================" +printf '%s\n' "${results[@]}" diff --git a/tests/unit_tests/models/test_mimo_moe_e2e.py b/tests/unit_tests/models/test_mimo_moe_e2e.py new file mode 100644 index 00000000000..8decbf8cad7 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_moe_e2e.py @@ -0,0 +1,1152 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""End-to-end MIMO + MoE smoke test (NMFW-464 Phase 1). + +Builds a small MimoModel (TransformerBlock encoder + MoE GPT LLM) with process +groups built from HyperCommGrid alt-factorization, wraps in DDP, and drives one +``forward_backward_no_pipelining`` step on 8 GPUs with mock data. No global +``parallel_state`` is initialized for model-parallel groups. + +This is the colocated half of the Phase-1 smoke matrix; encoder and LLM share +the same 8 ranks but the LLM also carries an EP/ETP/EDP alt-factorization +overlapping its TP/CP/DP axes (the NMFW-464 expert-overlap fix). +""" + +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist + +from megatron.core import parallel_state, pipeline_parallel +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_stack_spec +from megatron.core.models.hybrid.hybrid_model import HybridModel +from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator +from megatron.core.process_groups_config import ( + MultiModuleProcessGroupCollection, + ProcessGroupCollection, +) +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.enums import AttnBackend, ModelType +from megatron.core.transformer.mlp import MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) +except ImportError: + TEColumnParallelLinear = None + TERowParallelLinear = None + + +def _build_grids_and_pgs(): + """Build a single HCG that hosts both encoder and LLM (colocated) on 8 ranks. + + The LLM grid carries an expert alt-factorization (etp=1, ep=4, edp=2) over + the same tp=2 cp=1 dp=4 slab. Encoder uses the primary axes only. + """ + encoder_grid = HyperCommGrid( + shape=[1, 1, 8, 1], dim_names=["tp", "cp", "dp", "pp"], backend="nccl" + ) + llm_grid = HyperCommGrid( + shape=[1, 1, 8, 1], + dim_names=["tp", "cp", "dp", "pp"], + backend="nccl", + alt_factorizations={ + "expert": { + # Re-factor the 8-rank slab as ep=4, edp=2 (no expert TP) overlaying dp=8. + "shape": [1, 4, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + + encoder_pg = ProcessGroupCollection.from_hyper_comm_grid(encoder_grid) + llm_pg = ProcessGroupCollection.from_hyper_comm_grid(llm_grid) + + # PP=1 singleton groups stand in for embd / pos_embd. + encoder_pg.embd = encoder_pg.pp + encoder_pg.pos_embd = encoder_pg.pp + llm_pg.embd = llm_pg.pp + llm_pg.pos_embd = llm_pg.pp + + return encoder_grid, llm_grid, encoder_pg, llm_pg + + +def _build_mamba_moe_language_spec(pg, num_layers, hidden, num_experts, vocab_size, seq_len): + """Build a MambaModel/HybridModel language spec with Nemotron-MoE shape (Mamba + MoE).""" + from megatron.core.activations import squared_relu + + # ``pg.tp`` is None on ranks that aren't members of this grid (non-colocated layout). + tp_size = pg.tp.size() if pg.tp is not None else 1 + config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden, + num_attention_heads=4, + num_query_groups=2, + kv_channels=64, + ffn_hidden_size=hidden * 2, + num_moe_experts=num_experts, + moe_router_topk=min(num_experts, 4), + moe_token_dispatcher_type="alltoall", + moe_grouped_gemm=False, + moe_ffn_hidden_size=hidden, + add_bias_linear=False, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + context_parallel_size=1, + pipeline_model_parallel_size=1, + sequence_parallel=False, + bf16=False, + params_dtype=torch.float32, + attention_backend=AttnBackend.unfused, + variable_seq_lengths=True, + ) + config.activation_func = squared_relu + config.gated_linear_unit = False + config.normalization = "RMSNorm" + config.position_embedding_type = "none" + # Mamba SSM dims (smallest values that pass Mamba's divisibility checks for hidden=128). + config.mamba_state_dim = 64 + config.mamba_head_dim = 64 + config.mamba_num_heads = 2 + config.mamba_num_groups = 1 + return ModuleSpec( + module=HybridModel, + params={ + "config": config, + "hybrid_stack_spec": hybrid_stack_spec, + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + # 4-layer Nemotron-style mini pattern: Mamba, attention, Mamba, Expert. + "hybrid_layer_pattern": "M*ME"[:num_layers], + "pre_process": True, + "post_process": True, + "pg_collection": pg, + }, + ) + + +def _build_language_spec( + pg, num_layers, hidden, num_experts, vocab_size, seq_len, nemotron_flavor=False +): + """Build a GPT-MoE language spec. Tolerates ``pg.tp is None`` for non-colocated ranks. + + When ``nemotron_flavor=True`` the config matches Nemotron6-MoE in every dimension + that doesn't require the Mamba SSM kernels: squared_relu activation, RMSNorm, + GQA with num_query_groups=2, sigmoid router with topk=6 (capped to num_experts), + shared experts, alltoall dispatcher, no bias. Mamba SSM layers themselves are + out of scope (they need ``mamba_ssm``); the rest of the architecture reachable + through ``TransformerConfig`` is exercised here. + """ + from megatron.core.activations import squared_relu + + tp_size = pg.tp.size() if pg.tp is not None else 1 + config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden, + num_attention_heads=4, + ffn_hidden_size=4 * hidden, + num_moe_experts=num_experts, + moe_router_topk=min(num_experts, 6) if nemotron_flavor else 2, + moe_token_dispatcher_type="alltoall", + moe_grouped_gemm=False, + moe_ffn_hidden_size=2 * hidden, + add_bias_linear=False, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + context_parallel_size=1, + pipeline_model_parallel_size=1, + sequence_parallel=False, + bf16=False, + params_dtype=torch.float32, + attention_backend=AttnBackend.unfused, + variable_seq_lengths=True, + ) + if nemotron_flavor: + # Nemotron6-MoE knobs that are reachable through TransformerConfig (everything + # below comes from configs/nemotron_moe_vlm.py), minus the Mamba-specific fields. + config.activation_func = squared_relu + config.gated_linear_unit = False + config.normalization = "RMSNorm" + config.num_query_groups = 2 + config.kv_channels = 128 + config.moe_router_score_function = "sigmoid" + config.moe_router_topk_scaling_factor = 2.5 + config.moe_router_enable_expert_bias = True + config.moe_router_dtype = "fp32" + config.moe_router_load_balancing_type = "seq_aux_loss" + config.moe_aux_loss_coeff = 0.0001 + config.moe_shared_expert_intermediate_size = 2 * hidden + # Note: shared_expert_overlap reaches for an internal MP group not set in our + # HCG-only setup; turn off for the smoke test. + config.moe_shared_expert_overlap = False + config.position_embedding_type = "none" + config.attention_dropout = 0.0 + config.hidden_dropout = 0.0 + config.bias_activation_fusion = False + config.masked_softmax_fusion = True + config.persist_layer_norm = True + config.bias_dropout_fusion = False + return ModuleSpec( + module=GPTModel, + params={ + "config": config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, moe_grouped_gemm=False + ), + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + "pre_process": True, + "post_process": True, + "pg_collection": pg, + }, + ) + + +def _build_radio_submodules_spec( + pg, num_layers, hidden, language_hidden, img_h, img_w, patch_dim, class_token_len +): + """Build a vision-modality submodules spec using the literal RADIOEncoderWrapper.""" + from examples.mimo.model_providers.radio_encoder import RADIOEncoderWrapper + + if TEColumnParallelLinear is None or TERowParallelLinear is None: + pytest.skip("TE column/row parallel linear not available") + tp_size = pg.tp.size() if pg.tp is not None else 1 + radio_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden, + num_attention_heads=4, + ffn_hidden_size=4 * hidden, + kv_channels=hidden // 4, + num_query_groups=4, + gated_linear_unit=False, + add_bias_linear=True, + add_qkv_bias=True, + normalization="LayerNorm", + layernorm_epsilon=1e-6, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + context_parallel_size=1, + pipeline_model_parallel_size=1, + sequence_parallel=False, + bf16=False, + params_dtype=torch.float32, + attention_backend=AttnBackend.unfused, + moe_token_dispatcher_type="alltoall", + variable_seq_lengths=True, + ) + radio_layer_spec = get_vit_layer_with_transformer_engine_spec() + encoder_spec = ModuleSpec( + module=RADIOEncoderWrapper, + params={ + "transformer_config": radio_config, + "transformer_layer_spec": radio_layer_spec, + "img_h": img_h, + "img_w": img_w, + "patch_dim": patch_dim, + "class_token_len": class_token_len, + "drop_class_token": True, + "apply_pixel_shuffle": False, + "max_img_h": img_h, + "max_img_w": img_w, + "has_cpe": True, + "embedder_bias": False, + # Threading the encoder pg_collection into RADIOViTModel so it doesn't + # fall back to ``parallel_state`` for dp / dp_cp / intra_dp_cp groups. + "pg_collection": pg, + }, + ) + proj_config = TransformerConfig( + num_layers=1, + hidden_size=language_hidden, + num_attention_heads=tp_size, + ffn_hidden_size=language_hidden, + add_bias_linear=False, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + bf16=False, + params_dtype=torch.float32, + ) + proj_config.activation_func = torch.nn.functional.gelu + proj_spec = ModuleSpec( + module=MultimodalProjector, + params={ + "config": proj_config, + "submodules": MLPSubmodules( + linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + "projector_type": "mlp", + "input_size": hidden, # RADIO hidden -> projection -> language hidden + "tp_group": pg.tp, + }, + ) + return ModuleSpec( + module=VisionModalitySubmodules, + params={"pg_collection": pg}, + submodules={"encoders": {"radio_encoder": encoder_spec}, "input_projections": [proj_spec]}, + ) + + +def _build_vision_submodules_spec(pg, num_layers, hidden, language_hidden): + if TEColumnParallelLinear is None or TERowParallelLinear is None: + pytest.skip("TE column/row parallel linear not available") + tp_size = pg.tp.size() if pg.tp is not None else 1 + vision_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden, + num_attention_heads=4, + ffn_hidden_size=4 * hidden, + add_bias_linear=False, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + context_parallel_size=1, + pipeline_model_parallel_size=1, + sequence_parallel=False, + bf16=False, + params_dtype=torch.float32, + attention_backend=AttnBackend.unfused, + moe_token_dispatcher_type="alltoall", + variable_seq_lengths=True, + ) + encoder_spec = ModuleSpec( + module=TransformerBlock, + params={ + "config": vision_config, + "spec": get_gpt_layer_with_transformer_engine_spec(), + "pg_collection": pg, + "pre_process": True, + "post_process": True, + }, + ) + proj_config = TransformerConfig( + num_layers=1, + hidden_size=hidden, + num_attention_heads=tp_size, + ffn_hidden_size=hidden, + add_bias_linear=False, + use_cpu_initialization=True, + tensor_model_parallel_size=tp_size, + bf16=False, + params_dtype=torch.float32, + ) + proj_config.activation_func = torch.nn.functional.gelu + proj_spec = ModuleSpec( + module=MultimodalProjector, + params={ + "config": proj_config, + "submodules": MLPSubmodules( + linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + "projector_type": "mlp", + "input_size": hidden, + "tp_group": pg.tp, + }, + ) + return ModuleSpec( + module=VisionModalitySubmodules, + params={"pg_collection": pg}, + submodules={"encoders": {"clip_encoder": encoder_spec}, "input_projections": [proj_spec]}, + ) + + +def _mock_radio_data_iterator( + seq_length, micro_batch_size, vocab_size, num_image_tokens, img_h, img_w +): + """Mock data iterator that emits raw images (for RADIO) + token IDs with placeholders.""" + while True: + input_ids = torch.randint(1, vocab_size, (micro_batch_size, seq_length), device="cuda") + # Place exactly num_image_tokens placeholder tokens (id 0) per batch row. + input_ids[:, :num_image_tokens] = 0 + loss_mask = torch.ones((micro_batch_size, seq_length), device="cuda") + position_ids = ( + torch.arange(seq_length, device="cuda").unsqueeze(0).expand(micro_batch_size, -1) + ) + # RADIO takes [num_tiles, 3, H, W]. With one tile per batch row, num_tiles == B. + images = torch.randn(micro_batch_size, 3, img_h, img_w, device="cuda") + yield { + "tokens": input_ids, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": None, + "modality_inputs": {"radio_encoder": {"radio_encoder": {"x": images}}}, + } + + +def _mock_mimo_data_iterator(seq_length, micro_batch_size, vocab_size, image_seq_len, hidden): + while True: + # Reserve token id 0 as the image-placeholder; sample remaining tokens from [1, V). + input_ids = torch.randint(1, vocab_size, (micro_batch_size, seq_length), device="cuda") + # Place exactly image_seq_len placeholder tokens (id 0) per batch row. + input_ids[:, :image_seq_len] = 0 + loss_mask = torch.ones((micro_batch_size, seq_length), device="cuda") + position_ids = ( + torch.arange(seq_length, device="cuda").unsqueeze(0).expand(micro_batch_size, -1) + ) + # Vision input: image features [seq=image_seq_len, batch, hidden]. + image_features = torch.randn(image_seq_len, micro_batch_size, hidden, device="cuda") + yield { + "tokens": input_ids, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": None, + "modality_inputs": { + "clip_encoder": { + "clip_encoder": {"hidden_states": image_features, "attention_mask": None} + } + }, + } + + +def _loss_func(loss_mask, output_tensor): + if isinstance(output_tensor, (tuple, list)): + output_tensor = output_tensor[0] + loss = output_tensor.float().sum() + return loss, {"lm_loss": loss.detach()} + + +def _forward_step(data_iterator, model): + batch = next(data_iterator) + output_tensor = model( + input_ids=batch["tokens"], + position_ids=batch["position_ids"], + attention_mask=batch["attention_mask"], + modality_inputs=batch["modality_inputs"], + ) + return output_tensor, partial(_loss_func, batch["loss_mask"]) + + +class TestMimoMoEColocated: + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + try: + dist.init_process_group(backend="nccl") + except Exception as e: + pytest.skip(f"Cannot initialize distributed: {e}") + os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") + + @staticmethod + def _reset_global_singletons(): + """Drop any leftover ``parallel_state`` groups so this test's HCG-built groups + are the only model-parallel topology the process knows about. + + The MIMO/hetero path threads process groups through ``HyperCommGrid`` and + ``ProcessGroupCollection``; we deliberately do *not* call + ``parallel_state.initialize_model_parallel`` so any code that silently reaches + for it via the global singletons surfaces immediately rather than picking up + a wrong-topology group. + + ``_GLOBAL_MEMORY_BUFFER`` and ``_CUDA_RNG_STATE_TRACKER`` are intentionally + left alone — touching them at test-method scope can race with live CUDA state + and crash pytest fixture teardown. The batch runner already gives each test + its own ``torch.distributed.run`` invocation, so a fresh process is the real + isolation boundary. + """ + if parallel_state.model_parallel_is_initialized(): + parallel_state.destroy_model_parallel() + + def _run_mimo_step( + self, + language_spec, + hidden, + seq_length, + image_seq_len, + micro_batch_size, + vocab_size, + encoder_pg, + llm_pg, + ): + """Build MimoModel from the given language spec and drive one fwd/bwd step.""" + vision_spec = _build_vision_submodules_spec( + pg=encoder_pg, num_layers=1, hidden=hidden, language_hidden=hidden + ) + mimo_config = MimoModelConfig( + language_model_spec=language_spec, + modality_submodules_spec={"clip_encoder": vision_spec}, + special_token_ids={"clip_encoder": 0}, + ) + # Pass tp_group from llm_pg so PartitionConfig doesn't fall back to parallel_state. + mimo_model = MimoModel(mimo_config, tp_group=llm_pg.tp, cp_group=llm_pg.cp).cuda() + mimo_model.model_type = ModelType.encoder_or_decoder + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=False, + overlap_grad_reduce=False, + use_distributed_optimizer=False, + check_for_nan_in_grad=False, + bucket_size=None, + average_in_collective=False, + ) + ddp_model = DistributedDataParallel( + config=language_spec.params["config"], + ddp_config=ddp_config, + module=mimo_model, + pg_collection=llm_pg, + ) + + data_iter = _mock_mimo_data_iterator( + seq_length, micro_batch_size, vocab_size, image_seq_len, hidden + ) + losses = pipeline_parallel.schedules.forward_backward_no_pipelining( + forward_step_func=_forward_step, + data_iterator=data_iter, + model=ddp_model, + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + pg_collection=llm_pg, + ) + assert isinstance(losses, list) and len(losses) == 1 + loss_dict = losses[0] + assert "lm_loss" in loss_dict + assert torch.isfinite( + loss_dict["lm_loss"] + ).item(), f"loss not finite: {loss_dict['lm_loss']}" + + any_grad = False + for p in mimo_model.parameters(): + if p.grad is not None and p.grad.abs().sum() > 0: + any_grad = True + break + if hasattr(p, "main_grad") and p.main_grad is not None and p.main_grad.abs().sum() > 0: + any_grad = True + break + assert any_grad, "no parameter received a non-zero gradient" + return mimo_model + + def _setup_pgs_and_rng(self): + torch.manual_seed(12345) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + encoder_grid, llm_grid, encoder_pg, llm_pg = _build_grids_and_pgs() + assert sorted(dist.get_process_group_ranks(llm_pg.ep)) != sorted( + dist.get_process_group_ranks(llm_pg.dp) + ), "ep / dp must not alias under alt factorization with these shapes" + tp_rank = dist.get_rank(group=llm_pg.tp) + ep_rank = dist.get_rank(group=llm_pg.ep) + etp_rank = dist.get_rank(group=llm_pg.expt_tp) + model_parallel_cuda_manual_seed( + seed=12345, tp_rank=tp_rank, ep_rank=ep_rank, etp_rank=etp_rank + ) + if parallel_state._GLOBAL_MEMORY_BUFFER is None: + parallel_state._set_global_memory_buffer() + return encoder_grid, llm_grid, encoder_pg, llm_pg + + @pytest.mark.parametrize("nemotron_flavor", [False, True]) + def test_mimo_moe_colocated_8gpu(self, nemotron_flavor): + """Smoke: GPT-style MoE LLM, basic + Nemotron-flavor configs.""" + if not dist.is_initialized() or dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + # GPT path doesn't strictly need ``parallel_state``, but resetting it + # explicitly makes the test order-independent — matches the other tests + # in the class. + self._reset_global_singletons() + encoder_grid, llm_grid, encoder_pg, llm_pg = self._setup_pgs_and_rng() + hidden = 64 + language_spec = _build_language_spec( + pg=llm_pg, + num_layers=2, + hidden=hidden, + num_experts=4, + vocab_size=128, + seq_len=32, + nemotron_flavor=nemotron_flavor, + ) + self._run_mimo_step( + language_spec, + hidden=hidden, + seq_length=32, + image_seq_len=8, + micro_batch_size=2, + vocab_size=128, + encoder_pg=encoder_pg, + llm_pg=llm_pg, + ) + encoder_grid.destroy() + llm_grid.destroy() + + def test_mimo_nemotron_radio_mamba_moe_colocated_8gpu(self): + """E2E: RADIO ViT encoder + MLP projection + literal Mamba-MoE LLM. + + This is the literal Nemotron VLM assembly (small variant): RADIOEncoderWrapper + as the vision encoder, MultimodalProjector for vision→language, and + ``MambaModel`` (HybridModel with Mamba + MoE) as the language model. Wired + through HyperCommGrid alt-factorization, ``forward_backward_no_pipelining``, + and DDP. Mock images + token IDs with placeholders. Colocated mode. + """ + if not dist.is_initialized() or dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + self._reset_global_singletons() + encoder_grid, llm_grid, encoder_pg, llm_pg = self._setup_pgs_and_rng() + + hidden = 128 + num_experts = 4 + seq_length = 64 + # 64x64 image / patch=16 -> 16 patches; +8 class tokens; drop class -> 16 image tokens. + img_h = img_w = 64 + patch_dim = 16 + class_token_len = 8 + num_image_tokens_per_image = (img_h // patch_dim) * (img_w // patch_dim) + micro_batch_size = 2 + # Total image tokens placed in input_ids: one tile per batch row + # contributes num_image_tokens_per_image tokens; with B rows, the per-row + # placeholder count must equal num_image_tokens_per_image. + vocab_size = 128 + + language_spec = _build_mamba_moe_language_spec( + pg=llm_pg, + num_layers=4, + hidden=hidden, + num_experts=num_experts, + vocab_size=vocab_size, + seq_len=seq_length, + ) + vision_spec = _build_radio_submodules_spec( + pg=encoder_pg, + num_layers=2, + hidden=hidden, + language_hidden=hidden, + img_h=img_h, + img_w=img_w, + patch_dim=patch_dim, + class_token_len=class_token_len, + ) + + mimo_config = MimoModelConfig( + language_model_spec=language_spec, + modality_submodules_spec={"radio_encoder": vision_spec}, + special_token_ids={"radio_encoder": 0}, + ) + mimo_model = MimoModel(mimo_config, tp_group=llm_pg.tp, cp_group=llm_pg.cp).cuda() + mimo_model.model_type = ModelType.encoder_or_decoder + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=False, + overlap_grad_reduce=False, + use_distributed_optimizer=False, + check_for_nan_in_grad=False, + bucket_size=None, + average_in_collective=False, + ) + ddp_model = DistributedDataParallel( + config=language_spec.params["config"], + ddp_config=ddp_config, + module=mimo_model, + pg_collection=llm_pg, + ) + + data_iter = _mock_radio_data_iterator( + seq_length=seq_length, + micro_batch_size=micro_batch_size, + vocab_size=vocab_size, + num_image_tokens=num_image_tokens_per_image, + img_h=img_h, + img_w=img_w, + ) + + losses = pipeline_parallel.schedules.forward_backward_no_pipelining( + forward_step_func=_forward_step, + data_iterator=data_iter, + model=ddp_model, + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + pg_collection=llm_pg, + ) + assert isinstance(losses, list) and len(losses) == 1 + assert torch.isfinite(losses[0]["lm_loss"]).item() + any_grad = any( + (p.grad is not None and p.grad.abs().sum() > 0) + or (hasattr(p, "main_grad") and p.main_grad is not None and p.main_grad.abs().sum() > 0) + for p in mimo_model.parameters() + ) + assert any_grad, "no parameter received a non-zero gradient" + encoder_grid.destroy() + llm_grid.destroy() + + def test_mimo_nemotron_radio_mamba_moe_non_colocated_8gpu(self): + """E2E non-colocated literal Nemotron VLM: RADIO encoder (ranks 0-3) + Mamba-MoE + LLM (ranks 4-7) bridged by ``MultiModulePipelineCommunicator``.""" + if not dist.is_initialized() or dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + self._reset_global_singletons() + torch.manual_seed(12345) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + + encoder_grid = HyperCommGrid( + shape=[1, 1, 4, 1], dim_names=["tp", "cp", "dp", "pp"], rank_offset=0, backend="nccl" + ) + llm_grid = HyperCommGrid( + shape=[1, 1, 4, 1], + dim_names=["tp", "cp", "dp", "pp"], + rank_offset=4, + backend="nccl", + alt_factorizations={ + "expert": { + "shape": [1, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + encoder_pg = ProcessGroupCollection.from_hyper_comm_grid(encoder_grid) + llm_pg = ProcessGroupCollection.from_hyper_comm_grid(llm_grid) + + rank = dist.get_rank() + in_encoder = 0 <= rank < 4 + in_llm = 4 <= rank < 8 + if in_encoder: + encoder_pg.embd = encoder_pg.pp + encoder_pg.pos_embd = encoder_pg.pp + if in_llm: + llm_pg.embd = llm_pg.pp + llm_pg.pos_embd = llm_pg.pp + + # This test assumes symmetric DP (no MBS scaling across the bridge). + if in_encoder: + assert encoder_pg.dp.size() == 4 + if in_llm: + assert llm_pg.dp.size() == 4 + + ep_rank = dist.get_rank(group=llm_pg.ep) if in_llm else 0 + model_parallel_cuda_manual_seed(seed=12345, tp_rank=0, ep_rank=ep_rank, etp_rank=0) + if parallel_state._GLOBAL_MEMORY_BUFFER is None: + parallel_state._set_global_memory_buffer() + + hidden = 128 + num_experts = 4 + seq_length = 64 + img_h = img_w = 64 + patch_dim = 16 + class_token_len = 8 + num_image_tokens_per_image = (img_h // patch_dim) * (img_w // patch_dim) + micro_batch_size = 2 + vocab_size = 128 + encoder_name = "radio_encoder" + + language_spec = _build_mamba_moe_language_spec( + pg=llm_pg, + num_layers=4, + hidden=hidden, + num_experts=num_experts, + vocab_size=vocab_size, + seq_len=seq_length, + ) + vision_spec = _build_radio_submodules_spec( + pg=encoder_pg, + num_layers=2, + hidden=hidden, + language_hidden=hidden, + img_h=img_h, + img_w=img_w, + patch_dim=patch_dim, + class_token_len=class_token_len, + ) + + module_to_grid_map = {encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid} + topology = {encoder_name: [MIMO_LANGUAGE_MODULE_KEY], MIMO_LANGUAGE_MODULE_KEY: []} + mimo_config = MimoModelConfig( + language_model_spec=language_spec, + modality_submodules_spec={encoder_name: vision_spec}, + special_token_ids={encoder_name: 0}, + module_to_grid_map=module_to_grid_map, + ) + tp_group = llm_pg.tp if in_llm else encoder_pg.tp + cp_group = llm_pg.cp if in_llm else encoder_pg.cp + mimo_model = MimoModel(mimo_config, tp_group=tp_group, cp_group=cp_group).cuda() + mimo_model.model_type = ModelType.encoder_or_decoder + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=False, + overlap_grad_reduce=False, + use_distributed_optimizer=False, + check_for_nan_in_grad=False, + bucket_size=None, + average_in_collective=False, + ) + if in_llm and mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=llm_pg, + ) + if in_encoder and encoder_name in mimo_model.modality_submodules: + sub = mimo_model.modality_submodules[encoder_name] + if sub is not None: + # The encoder-spec submodule key is "radio_encoder" here. + ddp_cfg_src = sub.encoders["radio_encoder"].radio_model.config + mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( + config=ddp_cfg_src, ddp_config=ddp_config, module=sub, pg_collection=encoder_pg + ) + + communicator = MultiModulePipelineCommunicator( + module_to_grid_map, + topology, + mimo_model.config, + dim_mapping={"s": 0, "h": 2, "b": 1}, + module_output_ndim={encoder_name: 2}, + ) + + def _data_iter(): + while True: + images = torch.randn(micro_batch_size, 3, img_h, img_w, device="cuda") + input_ids = torch.randint( + 1, vocab_size, (micro_batch_size, seq_length), device="cuda" + ) + input_ids[:, :num_image_tokens_per_image] = 0 + position_ids = ( + torch.arange(seq_length, device="cuda") + .unsqueeze(0) + .expand(micro_batch_size, -1) + ) + loss_mask = torch.ones((micro_batch_size, seq_length), device="cuda") + yield { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": None, + "loss_mask": loss_mask, + "modality_inputs": {encoder_name: {"radio_encoder": {"x": images}}}, + } + + module_pgs = {} + language_model_module_name = None + if in_encoder: + module_pgs[encoder_name] = encoder_pg + if in_llm: + module_pgs[MIMO_LANGUAGE_MODULE_KEY] = llm_pg + language_model_module_name = MIMO_LANGUAGE_MODULE_KEY + sched_pg = MultiModuleProcessGroupCollection( + module_pgs=module_pgs, language_model_module_name=language_model_module_name + ) + + def step_func(data_iterator, model): + def loss_func(loss_mask, output_tensor): + if output_tensor is None: + return torch.tensor(0.0, device="cuda", requires_grad=True), { + "loss_reduced": 0.0 + } + if isinstance(output_tensor, dict): + out = output_tensor.get( + MIMO_LANGUAGE_MODULE_KEY, next(iter(output_tensor.values()), None) + ) + else: + out = output_tensor + if out is None: + return torch.tensor(0.0, device="cuda", requires_grad=True), { + "loss_reduced": 0.0 + } + if isinstance(out, (tuple, list)): + out = out[0] + loss = out.float().sum() + return loss, {"loss_reduced": loss.detach()} + + batch = next(data_iterator) if data_iterator is not None else {} + output_tensor = model(**batch) + if isinstance(output_tensor, tuple): + output_tensor = output_tensor[0] + return output_tensor, partial(loss_func, batch.get("loss_mask")) + + losses = pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving( + forward_step_func=step_func, + data_iterator=_data_iter(), + model=[mimo_model], + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=communicator, + pg_collection=sched_pg, + ) + if in_llm: + assert isinstance(losses, list) + for ld in losses: + assert "loss_reduced" in ld + encoder_grid.destroy() + llm_grid.destroy() + + def test_mimo_mamba_moe_non_colocated_8gpu(self): + """E2E non-colocated: encoder on ranks 0-3, Mamba-MoE LLM on ranks 4-7. + + Bridge between the two rank slabs is ``MultiModulePipelineCommunicator``; + the schedule is ``forward_backward_pipelining_without_interleaving``. + Encoder DP=4, LLM DP=4 (symmetric, no MBS scaling needed). LLM grid carries + an alt-factorization ep=2 edp=2 over its dp=4 axis. + """ + if not dist.is_initialized() or dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + # parallel_state needs to match the LLM topology because HybridModel still + # reaches into log_on_each_pipeline_stage / get_tensor_model_parallel_*. We + # initialize it once with TP=1 EP=2 PP=1 (matching the LLM slab on ranks 4-7). + self._reset_global_singletons() + + torch.manual_seed(12345) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + + # Disjoint grids: encoder on ranks [0..3], LLM on ranks [4..7]. Each tp=1 dp=4 pp=1. + encoder_grid = HyperCommGrid( + shape=[1, 1, 4, 1], dim_names=["tp", "cp", "dp", "pp"], rank_offset=0, backend="nccl" + ) + llm_grid = HyperCommGrid( + shape=[1, 1, 4, 1], + dim_names=["tp", "cp", "dp", "pp"], + rank_offset=4, + backend="nccl", + alt_factorizations={ + "expert": { + "shape": [1, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + encoder_pg = ProcessGroupCollection.from_hyper_comm_grid(encoder_grid) + llm_pg = ProcessGroupCollection.from_hyper_comm_grid(llm_grid) + + rank = dist.get_rank() + in_encoder = 0 <= rank < 4 + in_llm = 4 <= rank < 8 + + # PP=1 stand-ins for embd / pos_embd, only on ranks that are in the relevant grid + # (``from_hyper_comm_grid`` only populates fields for ranks that are members). + if in_encoder: + encoder_pg.embd = encoder_pg.pp + encoder_pg.pos_embd = encoder_pg.pp + if in_llm: + llm_pg.embd = llm_pg.pp + llm_pg.pos_embd = llm_pg.pp + + # This test assumes symmetric DP (no MBS scaling across the bridge). + if in_encoder: + assert encoder_pg.dp.size() == 4 + if in_llm: + assert llm_pg.dp.size() == 4 + + # RNG init on whichever grid this rank belongs to (TP rank for both is 0 + # since tp=1; ep_rank only meaningful on LLM ranks). + ep_rank = dist.get_rank(group=llm_pg.ep) if in_llm else 0 + model_parallel_cuda_manual_seed(seed=12345, tp_rank=0, ep_rank=ep_rank, etp_rank=0) + if parallel_state._GLOBAL_MEMORY_BUFFER is None: + parallel_state._set_global_memory_buffer() + + hidden = 128 + num_experts = 4 + seq_length = 32 + image_seq_len = 8 + micro_batch_size = 2 + vocab_size = 128 + + encoder_name = "clip_encoder" + # Build language + vision specs with their own pg_collections. + language_spec = _build_mamba_moe_language_spec( + pg=llm_pg, + num_layers=4, + hidden=hidden, + num_experts=num_experts, + vocab_size=vocab_size, + seq_len=seq_length, + ) + # Use the simpler TransformerBlock encoder for this smoke test (RADIO would + # add image-tile shape constraints across the bridge — orthogonal to what + # this test verifies). + vision_spec = _build_vision_submodules_spec( + pg=encoder_pg, num_layers=1, hidden=hidden, language_hidden=hidden + ) + + module_to_grid_map = {encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid} + topology = {encoder_name: [MIMO_LANGUAGE_MODULE_KEY], MIMO_LANGUAGE_MODULE_KEY: []} + + mimo_config = MimoModelConfig( + language_model_spec=language_spec, + modality_submodules_spec={encoder_name: vision_spec}, + special_token_ids={encoder_name: 0}, + module_to_grid_map=module_to_grid_map, + ) + # tp_group / cp_group come from whichever side this rank lives on. + tp_group = llm_pg.tp if in_llm else encoder_pg.tp + cp_group = llm_pg.cp if in_llm else encoder_pg.cp + mimo_model = MimoModel(mimo_config, tp_group=tp_group, cp_group=cp_group).cuda() + mimo_model.model_type = ModelType.encoder_or_decoder + + # DDP-wrap each side independently with its own pg_collection. + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=False, + overlap_grad_reduce=False, + use_distributed_optimizer=False, + check_for_nan_in_grad=False, + bucket_size=None, + average_in_collective=False, + ) + if in_llm and mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=llm_pg, + ) + if in_encoder and encoder_name in mimo_model.modality_submodules: + sub = mimo_model.modality_submodules[encoder_name] + if sub is not None: + mimo_model.modality_submodules[encoder_name] = DistributedDataParallel( + config=sub.encoders["clip_encoder"].config, + ddp_config=ddp_config, + module=sub, + pg_collection=encoder_pg, + ) + + # Multi-module bridge: encoder hidden flows from encoder ranks to LLM ranks. + communicator = MultiModulePipelineCommunicator( + module_to_grid_map, + topology, + mimo_model.config, + dim_mapping={"s": 0, "h": 2, "b": 1}, + module_output_ndim={encoder_name: 2}, + ) + + # Per-rank data iterator: only encoder ranks (which feed images) and LLM ranks + # (which read text + run loss) need data. With encoder_dp == llm_dp the MBS is + # the same on both sides. + def _data_iter(): + while True: + # Mock encoder hidden states (the "image features" the bridge will ship). + encoder_hidden_states = torch.randn( + image_seq_len, micro_batch_size, hidden, device="cuda" + ) + input_ids = torch.randint( + 1, vocab_size, (micro_batch_size, seq_length), device="cuda" + ) + input_ids[:, :image_seq_len] = 0 + position_ids = ( + torch.arange(seq_length, device="cuda") + .unsqueeze(0) + .expand(micro_batch_size, -1) + ) + loss_mask = torch.ones((micro_batch_size, seq_length), device="cuda") + yield { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": None, + "loss_mask": loss_mask, + "modality_inputs": { + encoder_name: { + "clip_encoder": { + "hidden_states": encoder_hidden_states, + "attention_mask": None, + } + } + }, + } + + data_iterator = _data_iter() + + # Schedule's MultiModuleProcessGroupCollection: only this rank's pg. + module_pgs = {} + language_model_module_name = None + if in_encoder: + module_pgs[encoder_name] = encoder_pg + if in_llm: + module_pgs[MIMO_LANGUAGE_MODULE_KEY] = llm_pg + language_model_module_name = MIMO_LANGUAGE_MODULE_KEY + sched_pg = MultiModuleProcessGroupCollection( + module_pgs=module_pgs, language_model_module_name=language_model_module_name + ) + + def step_func(data_iterator, model): + def loss_func(loss_mask, output_tensor): + if output_tensor is None: + return torch.tensor(0.0, device="cuda", requires_grad=True), { + "loss_reduced": 0.0 + } + if isinstance(output_tensor, dict): + out = output_tensor.get( + MIMO_LANGUAGE_MODULE_KEY, next(iter(output_tensor.values()), None) + ) + else: + out = output_tensor + if out is None: + return torch.tensor(0.0, device="cuda", requires_grad=True), { + "loss_reduced": 0.0 + } + if isinstance(out, (tuple, list)): + out = out[0] + loss = out.float().sum() + return loss, {"loss_reduced": loss.detach()} + + batch = next(data_iterator) if data_iterator is not None else {} + output_tensor = model(**batch) + # MoE LLMs return ``(logits, extras)``; the schedule expects a single tensor. + if isinstance(output_tensor, tuple): + output_tensor = output_tensor[0] + return output_tensor, partial(loss_func, batch.get("loss_mask")) + + losses = pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving( + forward_step_func=step_func, + data_iterator=data_iterator, + model=[mimo_model], + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=communicator, + pg_collection=sched_pg, + ) + + # On LLM-last-stage ranks the schedule returns loss dicts; elsewhere []. + if in_llm: + assert isinstance(losses, list) + for ld in losses: + assert "loss_reduced" in ld + # loss_reduced may be 0.0 sentinel on intermediate stages; + # at least one rank in the LLM dp group should have a real tensor. + encoder_grid.destroy() + llm_grid.destroy() + + def test_mimo_nemotron_mamba_moe_colocated_8gpu(self): + """Smoke: literal Nemotron-shape Mamba-MoE LLM via HybridModel. + + HybridModel still has a few hard ``parallel_state`` touchpoints + (notably ``log_on_each_pipeline_stage``); for this smoke test we + minimally initialize parallel_state with shapes matching the + HyperCommGrid so those calls succeed. Real training would route + these through pg_collection — that's a follow-up cleanup beyond + the NMFW-464 Phase-1 substrate. + """ + if not dist.is_initialized() or dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + # Match LLM grid topology: tp=1 cp=1 dp=8 pp=1 with ep=4 etp=1 edp=2. + self._reset_global_singletons() + encoder_grid, llm_grid, encoder_pg, llm_pg = self._setup_pgs_and_rng() + hidden = 128 + language_spec = _build_mamba_moe_language_spec( + pg=llm_pg, num_layers=4, hidden=hidden, num_experts=4, vocab_size=128, seq_len=32 + ) + self._run_mimo_step( + language_spec, + hidden=hidden, + seq_length=32, + image_seq_len=8, + micro_batch_size=2, + vocab_size=128, + encoder_pg=encoder_pg, + llm_pg=llm_pg, + ) + encoder_grid.destroy() + llm_grid.destroy() diff --git a/tests/unit_tests/test_hyper_comm_grid.py b/tests/unit_tests/test_hyper_comm_grid.py index dd27f84f60d..7e6e43cdce5 100644 --- a/tests/unit_tests/test_hyper_comm_grid.py +++ b/tests/unit_tests/test_hyper_comm_grid.py @@ -310,6 +310,356 @@ def test_rank_enumeration_correctness(self): assert rank_enum_ab == expected_ab +class TestHyperCommGridAltFactorization: + """Tests for the alt-factorization feature (NMFW-464 expert overlap).""" + + def _expert_grid(self): + """Standard 8-rank LLM grid with expert overlap: tp=cp=dp=2 / ep=etp=edp=2.""" + return HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + + def test_world_size_unchanged_with_alt(self): + """The alt factorization must not inflate world_size.""" + grid = self._expert_grid() + # 8 ranks, not 8 * 8 = 64. + assert grid.size == 8 + + def test_alt_dim_names_registered(self): + """Alt dim names should be discoverable for dispatch.""" + grid = self._expert_grid() + assert grid._dim_to_alt == {"etp": "expert", "ep": "expert", "edp": "expert"} + + def test_alt_only_rank_enum_matches_primary(self): + """When alt shape == covered primary shape, alt axes enumerate the same rank groups + as the corresponding primary axes (just under different names). + """ + grid = self._expert_grid() + # In primary, tp / cp / dp produce these enumerations: + primary_tp = grid._gen_rank_enum(["tp"]) + primary_cp = grid._gen_rank_enum(["cp"]) + primary_dp = grid._gen_rank_enum(["dp"]) + + # In alt, etp / ep / edp sit at the same positions in the shadow layout. + etp_enum = grid.get_rank_enum("etp") + ep_enum = grid.get_rank_enum("ep") + edp_enum = grid.get_rank_enum("edp") + + assert etp_enum == primary_tp + assert ep_enum == primary_cp + assert edp_enum == primary_dp + + def test_alt_multi_dim_rank_enum(self): + """Multi-dim alt groups (e.g. tp_ep semantically: ['ep', 'etp']) match primary.""" + grid = self._expert_grid() + # Combined alt group [ep, etp] matches combined primary [cp, tp]. + assert grid.get_rank_enum(["ep", "etp"]) == grid._gen_rank_enum(["cp", "tp"]) + + def test_alt_with_shared_dim(self): + """Combining alt dims with a primary shared (uncovered) dim should work.""" + grid = HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + # ep + pp should resolve via alt shadow layout. + # With pp=1, this collapses to the same as ep alone. + assert grid.get_rank_enum(["ep", "pp"]) == grid.get_rank_enum("ep") + + def test_alt_with_nontrivial_pp(self): + """Multi-stage PP keeps the alt factorization confined to per-stage rank slabs.""" + os.environ["WORLD_SIZE"] = "16" + try: + grid = HyperCommGrid( + shape=[2, 2, 2, 2], # tp=2 cp=2 dp=2 pp=2 -> 16 ranks + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + ep_enum = grid.get_rank_enum("ep") + # ep groups must only contain ranks within the same PP stage. + for group in ep_enum: + stages = {r // 8 for r in group} + assert len(stages) == 1, ( + f"ep group {group} crosses PP stages {stages}; expert axes must be confined " + f"to a single PP stage's rank slab" + ) + finally: + os.environ["WORLD_SIZE"] = "8" + + def test_alt_unequal_shape(self): + """Alt factorization may differ in shape from primary covers as long as products match.""" + # Primary tp*cp*dp = 2*2*2 = 8. Alt re-factor as ep=4, etp=2, edp=1 (product 8). + grid = HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 4, 1], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + # ep has size 4, so 2 groups of 4 ranks each (one per edp value, with edp=1 it's one + # outer group; but let's check structure). + ep_enum = grid.get_rank_enum("ep") + # 8 / 4 = 2 groups, each of size 4. + assert len(ep_enum) == 2 + for group in ep_enum: + assert len(group) == 4 + # Together they must cover all 8 ranks exactly once. + flat = [r for grp in ep_enum for r in grp] + assert sorted(flat) == list(range(8)) + + def test_alt_constraint_violated_product_mismatch(self): + """Mismatched product must raise.""" + with pytest.raises(ValueError, match="product"): + HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 4], # product 16 != 8 + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + + def test_alt_constraint_violated_non_contiguous_replaces(self): + """``replaces`` that isn't a contiguous slice of primary dim_names must raise.""" + with pytest.raises(ValueError, match="contiguous"): + HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2], + "dim_names": ["etp", "ep"], + "replaces": ["tp", "dp"], # tp and dp skip over cp -> non-contiguous + } + }, + ) + + def test_alt_constraint_violated_unknown_cover(self): + """Cover entries must be primary dim names.""" + with pytest.raises(ValueError, match="not a primary dim"): + HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": {"shape": [2], "dim_names": ["ep"], "replaces": ["xx"]} + }, + ) + + def test_alt_dim_name_collision_with_primary(self): + """Alt dim_names must not collide with primary dim_names.""" + with pytest.raises(ValueError, match="collides with primary"): + HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["tp", "ep", "edp"], # tp collides + "replaces": ["tp", "cp", "dp"], + } + }, + ) + + def test_alt_dim_name_collision_across_alts(self): + """Two alt factorizations must not share dim names.""" + with pytest.raises(ValueError, match="collides with alt"): + HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + }, + "second": { + "shape": [2, 2, 2], + "dim_names": ["etp", "x", "y"], # etp re-used + "replaces": ["tp", "cp", "dp"], + }, + }, + ) + + def test_create_pg_rejects_mixing_replaced_primary_and_alt(self): + """tp + ep together is ambiguous; reject at create_pg time.""" + grid = self._expert_grid() + with pytest.raises(ValueError, match="combine replaced primary"): + grid.create_pg(["tp", "ep"]) + + def test_get_rank_enum_rejects_unknown_dim(self): + grid = self._expert_grid() + with pytest.raises(KeyError, match="not a primary or alt dim"): + grid.get_rank_enum("zz") + + @patch('torch.distributed.new_subgroups_by_enumeration') + def test_create_pg_alt_then_get_pg_alt(self, mock_new_subgroups): + """create_pg for an alt dim, get_pg returns the same group.""" + mock_pg = MagicMock(spec=dist.ProcessGroup) + mock_new_subgroups.return_value = (mock_pg, None) + + grid = self._expert_grid() + ep_pg = grid.create_pg("ep") + assert ep_pg == mock_pg + assert grid.get_pg("ep") == mock_pg + + # Verify the rank enumeration that was passed to new_subgroups_by_enumeration matches + # what the primary "cp" would produce, since etp/ep/edp share shapes with tp/cp/dp. + args, _ = mock_new_subgroups.call_args + assert args[0] == grid._gen_rank_enum(["cp"]) + + def test_alt_full_primary_no_shared_dims(self): + """Alt covering the full primary (no shared dims) is allowed.""" + grid = HyperCommGrid( + shape=[2, 2, 2], # tp cp dp; no pp + dim_names=["tp", "cp", "dp"], + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + assert grid.get_rank_enum("ep") == grid._gen_rank_enum(["cp"]) + assert grid.get_rank_enum(["ep", "etp"]) == grid._gen_rank_enum(["cp", "tp"]) + + def test_disjoint_alts_positive(self): + """Two alt factorizations may replace disjoint slices of the primary.""" + os.environ["WORLD_SIZE"] = "16" + try: + grid = HyperCommGrid( + shape=[2, 2, 2, 2], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2], + "dim_names": ["etp", "ep"], + "replaces": ["tp", "cp"], + }, + "video": {"shape": [2], "dim_names": ["vdp"], "replaces": ["dp"]}, + }, + ) + # vdp aliases dp; ep aliases cp. + assert grid.get_rank_enum("vdp") == grid._gen_rank_enum(["dp"]) + assert grid.get_rank_enum("ep") == grid._gen_rank_enum(["cp"]) + finally: + os.environ["WORLD_SIZE"] = "8" + + def test_disjoint_alts_overlapping_replaces_rejected(self): + """Two alt factorizations replacing overlapping primary dims must raise.""" + with pytest.raises(ValueError, match="already replaced"): + HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2], + "dim_names": ["etp", "ep"], + "replaces": ["tp", "cp"], + }, + "second": {"shape": [2], "dim_names": ["x"], "replaces": ["cp"]}, + }, + ) + + def test_resolve_rejects_mixing_two_alt_factorizations(self): + """Mixing dims from two different alt factorizations at request time must raise.""" + os.environ["WORLD_SIZE"] = "16" + try: + grid = HyperCommGrid( + shape=[2, 2, 2, 2], + dim_names=["tp", "cp", "dp", "pp"], + alt_factorizations={ + "expert": { + "shape": [2, 2], + "dim_names": ["etp", "ep"], + "replaces": ["tp", "cp"], + }, + "video": {"shape": [2], "dim_names": ["vdp"], "replaces": ["dp"]}, + }, + ) + with pytest.raises(ValueError, match="multiple alt factorizations"): + grid.get_rank_enum(["ep", "vdp"]) + finally: + os.environ["WORLD_SIZE"] = "8" + + def test_alt_with_rank_offset(self): + """rank_offset shifts the alt enumeration by the same amount as the primary.""" + os.environ["WORLD_SIZE"] = "16" + try: + grid = HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + rank_offset=8, # grid lives on ranks 8..15 + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + ep_enum = grid.get_rank_enum("ep") + cp_enum = grid._gen_rank_enum(["cp"]) + assert ep_enum == cp_enum + # Every rank in the enumeration must be in [8, 16). + for group in ep_enum: + for r in group: + assert 8 <= r < 16 + finally: + os.environ["WORLD_SIZE"] = "8" + + def test_get_rank_enum_multi_axis_alt_via_public_api(self): + """Multi-axis alt groups must be reachable through the public API.""" + grid = self._expert_grid() + assert grid.get_rank_enum(["ep", "etp"]) == grid._gen_rank_enum(["cp", "tp"]) + assert grid.get_rank_enum(["edp", "etp"]) == grid._gen_rank_enum(["dp", "tp"]) + + @patch('torch.distributed.new_subgroups_by_enumeration') + def test_create_pg_primary_and_alt_have_distinct_keys(self, mock_new_subgroups): + """Primary 'cp' and alt 'ep' both create groups; the keys must not collide.""" + mock_pg_cp = MagicMock(spec=dist.ProcessGroup) + mock_pg_ep = MagicMock(spec=dist.ProcessGroup) + mock_new_subgroups.side_effect = [(mock_pg_cp, None), (mock_pg_ep, None)] + + grid = self._expert_grid() + cp_pg = grid.create_pg("cp") + ep_pg = grid.create_pg("ep") + + assert cp_pg == mock_pg_cp + assert ep_pg == mock_pg_ep + assert grid.get_pg("cp") == mock_pg_cp + assert grid.get_pg("ep") == mock_pg_ep + assert "cp" in grid._pgs + assert "ep" in grid._pgs + + class TestHyperCommGridIntegration: """Integration tests for HyperCommGrid with real distributed initialization.""" @@ -513,6 +863,84 @@ def test_real_distributed_error_handling(self): with pytest.raises(KeyError, match="Process group.*has already been created"): grid.create_pg("tp") + def test_real_distributed_alt_factorization_overlap(self): + """Verify that alt-factorization expert groups live on the same ranks as the + primary tp/cp/dp groups (NMFW-464 overlap).""" + if not dist.is_initialized(): + pytest.skip("Distributed not initialized") + + world_size = dist.get_world_size() + if world_size != 8: + pytest.skip("This test specifically requires 8 GPUs") + + grid = HyperCommGrid( + shape=[2, 2, 2, 1], + dim_names=["tp", "cp", "dp", "pp"], + backend="nccl", + alt_factorizations={ + "expert": { + "shape": [2, 2, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + + # Build matching groups under each factorization. + tp_pg = grid.create_pg("tp") + cp_pg = grid.create_pg("cp") + dp_pg = grid.create_pg("dp") + etp_pg = grid.create_pg("etp") + ep_pg = grid.create_pg("ep") + edp_pg = grid.create_pg("edp") + + # Each pair must enumerate the SAME physical ranks for the current process. + assert dist.get_process_group_ranks(tp_pg) == dist.get_process_group_ranks( + etp_pg + ), "etp must alias tp ranks under expert overlap" + assert dist.get_process_group_ranks(cp_pg) == dist.get_process_group_ranks( + ep_pg + ), "ep must alias cp ranks under expert overlap" + assert dist.get_process_group_ranks(dp_pg) == dist.get_process_group_ranks( + edp_pg + ), "edp must alias dp ranks under expert overlap" + + # Sanity: communication actually works on the alt group (all-reduce within ep). + rank = dist.get_rank() + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + tensor = torch.tensor([rank], dtype=torch.float, device=device) + dist.all_reduce(tensor, group=ep_pg) + ep_ranks = dist.get_process_group_ranks(ep_pg) + assert tensor.item() == sum(ep_ranks) + + def test_real_distributed_alt_factorization_pp_confined(self): + """With PP>1, alt-factorization expert groups must stay within a PP stage's rank slab.""" + if not dist.is_initialized(): + pytest.skip("Distributed not initialized") + + world_size = dist.get_world_size() + if world_size != 8: + pytest.skip("This test specifically requires 8 GPUs") + + # tp=2 cp=1 dp=2 pp=2 -> 8 ranks; per-PP-stage slab has 4 ranks. + grid = HyperCommGrid( + shape=[2, 1, 2, 2], + dim_names=["tp", "cp", "dp", "pp"], + backend="nccl", + alt_factorizations={ + "expert": { + "shape": [2, 1, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + edp_pg = grid.create_pg("edp") + edp_ranks = dist.get_process_group_ranks(edp_pg) + # All ranks in the edp group must share a PP stage (ranks 0-3 or 4-7). + stages = {r // 4 for r in edp_ranks} + assert len(stages) == 1, f"edp group {edp_ranks} crosses PP stages {stages}" + def test_real_distributed_rank_enumeration_verification(self): """Verify rank enumeration produces correct communication patterns.""" if not dist.is_initialized(): diff --git a/tests/unit_tests/test_process_groups_config.py b/tests/unit_tests/test_process_groups_config.py index b49962b1a5a..a3eb8655365 100644 --- a/tests/unit_tests/test_process_groups_config.py +++ b/tests/unit_tests/test_process_groups_config.py @@ -134,3 +134,75 @@ def test_default_initialization(self): # Test that an error is raised if an invalid process group is requested with pytest.raises(ValueError, match=r"Invalid process groups requested"): model_pgs = ProcessGroupCollection.use_mpu_process_groups(['tp', 'pp', 'foo']) + + +class TestPGConfigFromHyperCommGrid: + """Build ProcessGroupCollection from a HyperCommGrid (no parallel_state init). + + Uses real distributed groups via NCCL on whatever world the runner provides. + """ + + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + try: + dist.init_process_group(backend="nccl") + except Exception as e: + pytest.skip(f"Cannot initialize distributed: {e}") + + def test_from_hyper_comm_grid_dense_8gpu(self): + if not dist.is_initialized(): + pytest.skip("Distributed not initialized") + if dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + + from megatron.core.hyper_comm_grid import HyperCommGrid + + grid = HyperCommGrid(shape=[2, 1, 4, 1], dim_names=["tp", "cp", "dp", "pp"], backend="nccl") + pg = ProcessGroupCollection.from_hyper_comm_grid(grid) + assert hasattr(pg, "tp") and pg.tp.size() == 2 + assert hasattr(pg, "cp") and pg.cp.size() == 1 + assert hasattr(pg, "dp") and pg.dp.size() == 4 + assert hasattr(pg, "pp") and pg.pp.size() == 1 + assert hasattr(pg, "dp_cp") and pg.dp_cp.size() == 4 + # No alt factorization -> expert fields are set to ``None`` (so callers like + # DDP's ``hasattr`` check pass uniformly without distinguishing MoE-grid + # from non-MoE-grid). + assert pg.ep is None + assert pg.expt_tp is None + assert pg.expt_dp is None + + def test_from_hyper_comm_grid_with_expert_alt_8gpu(self): + if not dist.is_initialized(): + pytest.skip("Distributed not initialized") + if dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + + from megatron.core.hyper_comm_grid import HyperCommGrid + + grid = HyperCommGrid( + shape=[2, 1, 4, 1], + dim_names=["tp", "cp", "dp", "pp"], + backend="nccl", + alt_factorizations={ + "expert": { + "shape": [1, 4, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + pg = ProcessGroupCollection.from_hyper_comm_grid(grid) + # Standard fields populated. + assert pg.tp.size() == 2 + assert pg.dp.size() == 4 + # Expert fields populated from alt factorization. + assert pg.ep.size() == 4 + assert pg.expt_tp.size() == 1 + assert pg.expt_dp.size() == 2 + # Combined expert groups. + assert pg.tp_ep.size() == 4 # etp(1) * ep(4) + # Structural invariant: ep ranks must live entirely within this rank's + # primary tp*cp*dp slab (i.e. within the same PP stage). With pp=1, that's + # the full world; the meaningful check is that the size matches. + assert len(dist.get_process_group_ranks(pg.ep)) == 4 diff --git a/tests/unit_tests/transformer/moe/test_moe_with_hcg_pg.py b/tests/unit_tests/transformer/moe/test_moe_with_hcg_pg.py new file mode 100644 index 00000000000..0e7ada04085 --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_moe_with_hcg_pg.py @@ -0,0 +1,224 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Integration test: small MoE GPT model driven by ``forward_backward_no_pipelining`` +with process groups built from a HyperCommGrid (alt-factorization for EP overlap), +and NO global ``parallel_state`` initialization. + +Verifies the NMFW-464 wiring: HyperCommGrid → ProcessGroupCollection → MoE layers + +DDP + schedule, all without touching ``parallel_state``. +""" + +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist + +from megatron.core import parallel_state, pipeline_parallel +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.enums import AttnBackend +from megatron.core.transformer.transformer_config import TransformerConfig + + +def _build_grid_and_pg_collection(): + """Build an 8-rank HyperCommGrid with expert overlap and a pg_collection.""" + grid = HyperCommGrid( + shape=[2, 1, 4, 1], # tp=2 cp=1 dp=4 pp=1 + dim_names=["tp", "cp", "dp", "pp"], + backend="nccl", + alt_factorizations={ + "expert": { + # ep=4 carved out of the dp=4 axis; etp=1, edp=1. + # Product: 1 * 4 * 1 == 2 * 1 * 4 == 8. + "shape": [1, 4, 2], + "dim_names": ["etp", "ep", "edp"], + "replaces": ["tp", "cp", "dp"], + } + }, + ) + pg = ProcessGroupCollection.from_hyper_comm_grid(grid) + # forward_backward_no_pipelining reads pg.embd / pg.pos_embd if present; + # with pp=1 a singleton group per rank is fine. Reuse pp group. + pg.embd = pg.pp + pg.pos_embd = pg.pp + return grid, pg + + +def _mock_data_iterator(seq_length: int, micro_batch_size: int, vocab_size: int): + while True: + input_ids = torch.randint(0, vocab_size, (micro_batch_size, seq_length), device="cuda") + labels = torch.randint(0, vocab_size, (micro_batch_size, seq_length), device="cuda") + loss_mask = torch.ones((micro_batch_size, seq_length), device="cuda") + position_ids = ( + torch.arange(seq_length, device="cuda").unsqueeze(0).expand(micro_batch_size, -1) + ) + attention_mask = None + yield { + "tokens": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + + +def _loss_func(loss_mask, output_tensor): + """Sum-of-logits surrogate loss, matching test_mimo_1f1b_schedule. + + We avoid VocabParallelCrossEntropy because it calls ``parallel_state``; this test verifies + that the schedule path itself works without ``parallel_state`` initialization. + """ + loss = output_tensor.float().sum() + return loss, {"lm_loss": loss.detach()} + + +def _forward_step(data_iterator, model): + batch = next(data_iterator) + # Pass labels=None so the model returns logits and we can apply a no-parallel-state loss. + output_tensor = model( + input_ids=batch["tokens"], + position_ids=batch["position_ids"], + attention_mask=batch["attention_mask"], + ) + return output_tensor, partial(_loss_func, batch["loss_mask"]) + + +class TestMoEWithHCGProcessGroups: + @classmethod + def setup_class(cls): + if not dist.is_initialized(): + try: + dist.init_process_group(backend="nccl") + except Exception as e: + pytest.skip(f"Cannot initialize distributed: {e}") + os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") + + def test_moe_gpt_forward_backward_no_pipelining_8gpu(self): + if not dist.is_initialized() or dist.get_world_size() != 8: + pytest.skip("Requires exactly 8 GPUs") + + torch.manual_seed(12345) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + + grid, pg = _build_grid_and_pg_collection() + + # Tightening assertion (per code review): prove the alt-factorization path was taken + # by showing ep group ranks differ from dp group ranks. With primary tp=2 cp=1 dp=4 + # and alt etp=1 ep=4 edp=2, dp groups are [[0,2,4,6],[1,3,5,7]] but ep groups are + # [[0,1,2,3],[4,5,6,7]] — distinct. A silent orthogonal fallback would alias them. + ep_ranks = sorted(dist.get_process_group_ranks(pg.ep)) + dp_ranks = sorted(dist.get_process_group_ranks(pg.dp)) + assert ep_ranks != dp_ranks, ( + f"ep ranks {ep_ranks} match dp ranks {dp_ranks}; alt-factorization is likely " + f"silently aliasing dp instead of carving its own slab" + ) + + # Initialize model-parallel CUDA RNG tracker without parallel_state, by passing + # explicit ranks derived from the HyperCommGrid groups. + tp_rank = dist.get_rank(group=pg.tp) + ep_rank = dist.get_rank(group=pg.ep) + etp_rank = dist.get_rank(group=pg.expt_tp) + model_parallel_cuda_manual_seed( + seed=12345, tp_rank=tp_rank, ep_rank=ep_rank, etp_rank=etp_rank + ) + + # Sequence-parallel allgather buffer lives in parallel_state but is a side-effect-free + # global cache; we initialize it without touching the model-parallel groups. + if parallel_state._GLOBAL_MEMORY_BUFFER is None: + parallel_state._set_global_memory_buffer() + + num_experts = 4 + hidden = 64 + seq_length = 32 + micro_batch_size = 2 + vocab_size = 128 + + config = TransformerConfig( + num_layers=2, + hidden_size=hidden, + num_attention_heads=4, + ffn_hidden_size=4 * hidden, + num_moe_experts=num_experts, + moe_router_topk=2, + moe_token_dispatcher_type="alltoall", + moe_grouped_gemm=False, + moe_ffn_hidden_size=2 * hidden, + add_bias_linear=False, + use_cpu_initialization=True, + tensor_model_parallel_size=2, + context_parallel_size=1, + pipeline_model_parallel_size=1, + sequence_parallel=True, + calculate_per_token_loss=False, + bf16=False, + params_dtype=torch.float32, + attention_backend=AttnBackend.unfused, + ) + + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, moe_grouped_gemm=False + ) + + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=vocab_size, + max_sequence_length=seq_length, + pre_process=True, + post_process=True, + pg_collection=pg, + ).cuda() + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=False, + overlap_grad_reduce=False, + use_distributed_optimizer=False, + check_for_nan_in_grad=False, + bucket_size=None, + average_in_collective=False, + ) + ddp_model = DistributedDataParallel( + config=config, ddp_config=ddp_config, module=model, pg_collection=pg + ) + + data_iter = _mock_data_iterator(seq_length, micro_batch_size, vocab_size) + + losses = pipeline_parallel.schedules.forward_backward_no_pipelining( + forward_step_func=_forward_step, + data_iterator=data_iter, + model=ddp_model, + num_microbatches=1, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + pg_collection=pg, + ) + + # Loss should be finite scalar. + assert isinstance(losses, list) and len(losses) == 1 + loss_dict = losses[0] + assert "lm_loss" in loss_dict + loss_val = loss_dict["lm_loss"] + assert torch.isfinite(loss_val).item(), f"loss not finite: {loss_val}" + + # At least one parameter must have a non-zero gradient. + any_grad = False + for p in model.parameters(): + if p.grad is not None and p.grad.abs().sum() > 0: + any_grad = True + break + if hasattr(p, "main_grad") and p.main_grad is not None and p.main_grad.abs().sum() > 0: + any_grad = True + break + assert any_grad, "no parameter received a non-zero gradient" + + # Cleanup: keep things tidy for the next test run inside the same process. + ddp_model.zero_grad_buffer() + del ddp_model, model + torch.cuda.empty_cache()