Skip to content
15 changes: 7 additions & 8 deletions examples/mimo/model_providers/nemotron_moe_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

import torch

from examples.mimo.utils.hetero import (
debug_rank,
get_grid_dim_size,
get_group_rank_or,
get_group_size_or,
is_process_group_member,
)
from megatron.core.activations import fast_gelu, squared_relu
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
Expand All @@ -28,14 +35,6 @@
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import sharded_state_dict_default

from examples.mimo.utils.hetero import (
debug_rank,
get_grid_dim_size,
get_group_rank_or,
get_group_size_or,
is_process_group_member,
)

try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ TRAIN_LAUNCH_ARGS=(
--seed 1234 # Sanjeev's value
--save "${CHECKPOINT_SAVE_PATH}"
--save-interval "${SAVE_INTERVAL}"
# --load-vision-from "${VISION_CKPT}" # TODO: enable once PR_load_vision_from lands
--load-vision-from "${VISION_CKPT}"
)

CONTAINER_MOUNTS="${SCRATCH_ROOT}:${SCRATCH_ROOT},/lustre/fsw/portfolios/llmservice:/lustre/fsw/portfolios/llmservice,/scratch/fsw/portfolios/llmservice:/scratch/fsw/portfolios/llmservice"
Expand Down
7 changes: 7 additions & 0 deletions examples/mimo/train_hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""Standalone heterogeneous MIMO training entrypoint."""

import faulthandler
import os
import sys

Expand All @@ -24,6 +25,12 @@

def main() -> None:
"""Program entrypoint."""
# Dump every rank's python stack every 120 s. Hands-off diagnostic for
# hetero MIMO hangs — output goes to each rank's stderr so cog/slurm log
# capture works without code changes.
faulthandler.enable()
faulthandler.dump_traceback_later(120, repeat=True)

args = parse_args()
if args.enable_experimental:
set_experimental_flag(True)
Expand Down
22 changes: 22 additions & 0 deletions examples/mimo/training/hetero/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,28 @@ def parse_args() -> argparse.Namespace:
"skip optimizer + scheduler state regardless of the other flags."
),
)
ckpt.add_argument(
"--no-load-strict",
dest="load_strict",
action="store_false",
default=True,
help=(
"Disable strict checkpoint validation. By default the load path uses "
"StrictHandling.RAISE_ALL so missing or unexpected keys raise immediately, "
"confirming every parameter the model expects came from the checkpoint. "
"Pass --no-load-strict to fall back to ASSUME_OK_UNEXPECTED for schema drift."
),
)
ckpt.add_argument(
"--load-vision-from",
type=str,
default=None,
help=(
"Path to a Megatron-Bridge DCP containing `model.vision_model.*` keys. "
"Loaded only on encoder ranks, only on first run (when --load resolves no "
"checkpoint). Matches Sanjeev's --load-vision-from semantics from pre-vlm-05."
),
)
ckpt.add_argument(
"--dist-ckpt-optim-fully-reshardable",
action=argparse.BooleanOptionalAction,
Expand Down
145 changes: 144 additions & 1 deletion examples/mimo/training/hetero/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@
from megatron.core import dist_checkpointing, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedObject
from megatron.core.dist_checkpointing.utils import _clean_metadata_for_serialization
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.models.mimo.model.base import MimoModel
from megatron.core.models.mimo.optimizer import MimoOptimizer
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

_TRACKER_FILE = "latest_checkpointed_iteration.txt"
_CHECKPOINT_VERSION = 3.0
# Sanjeev's pre-vlm-05 Bridge DCP wraps the whole VLM under a top-level `model.`
# attribute, so the RADIO encoder's tensors are keyed `model.vision_model.<radio>`.
_VISION_DCP_PREFIX = "model.vision_model."


def _iter_directory(root: str, iteration: int) -> str:
Expand Down Expand Up @@ -312,7 +316,8 @@ def load_checkpoint(
is_loading=True,
)

loaded = dist_checkpointing.load(sharded_state_dict, source_dir)
strict = StrictHandling.RAISE_ALL if args.load_strict else StrictHandling.ASSUME_OK_UNEXPECTED
loaded = dist_checkpointing.load(sharded_state_dict, source_dir, strict=strict)

model.load_state_dict(loaded["model"], strict=True)

Expand All @@ -332,3 +337,141 @@ def load_checkpoint(
resume_iter = 0 if is_finetune else int(loaded.get("iteration", iteration))
print_rank_0(f"resuming hetero training at iteration {resume_iter}")
return resume_iter


def _resolve_vision_dcp_dir(ckpt_dir: str) -> str:
"""Resolve a flat-DCP or `iter_NNNNNNN/` directory under `ckpt_dir`."""
tracker = os.path.join(ckpt_dir, _TRACKER_FILE)
if os.path.isfile(tracker):
with open(tracker) as f:
iteration = int(f.read().strip())
return _iter_directory(ckpt_dir, iteration)
return ckpt_dir


def _radio_model(model: MimoModel, encoder_name: str, radio_encoder_key: str):
"""Return the inner RADIOViTModel for the encoder branch, or None.

Returns None on ranks outside the encoder branch (e.g. LLM-only ranks where
the modality submodule was never instantiated) and on encoder ranks whose
submodule does not contain the expected radio encoder key.
"""
submodules = getattr(model, "modality_submodules", None)
if submodules is None or encoder_name not in submodules:
return None
vision_submodule = submodules[encoder_name]
inner = getattr(vision_submodule, "module", vision_submodule) # unwrap DDP
encoders = getattr(inner, "encoders", None)
if encoders is None or radio_encoder_key not in encoders:
return None
return getattr(encoders[radio_encoder_key], "radio_model", None)


def _tp_slice(tensor: torch.Tensor, param_shape, tp_rank: int, tp_size: int) -> torch.Tensor:
"""Slice a full (TP=1) tensor down to this rank's TP shard.

Handles column-parallel (split on dim 0) and row-parallel (split on dim 1).
Returns the tensor unchanged when it already matches the param shape.
"""
if tp_size == 1 or tuple(tensor.shape) == tuple(param_shape):
return tensor
if tensor.shape[0] != param_shape[0]:
start = tp_rank * param_shape[0]
return tensor[start : start + param_shape[0], ...]
if len(tensor.shape) > 1 and tensor.shape[1] != param_shape[1]:
start = tp_rank * param_shape[1]
return tensor[:, start : start + param_shape[1]]
return tensor


def load_vision_from_checkpoint(
model: MimoModel,
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to handle it this verbose? just for checkpoint loading? its a distributed checkpoint why do we need to handle tp slice etc. also we dont load projection any way? just the encoder part?

args: argparse.Namespace,
topology: HeteroTopology,
) -> None:
"""Load RADIO encoder weights from a torch-DCP Bridge checkpoint.

The Bridge format (e.g. ``post-c-radio-omni``) is the
``torch.distributed.checkpoint`` layout (``.metadata`` + ``__N_M.distcp``),
NOT Megatron's ``dist_checkpointing`` layout — so we cannot use
``radio_model.sharded_state_dict`` + ``dist_checkpointing.load``.

Strict guarantee: every RADIO parameter the model expects must come from
the checkpoint. Any unmatched key (missing or extra) raises. The old
silent ``skipped`` counter is gone.

Only encoder ranks call ``dcp.load`` — vision weights belong only to the
encoder branch. LLM-only ranks early-return. The init_process_group
timeout is bumped to 1 hour so LLM ranks do not drop their c10d store
while encoder ranks do lustre I/O.
"""
if not args.load_vision_from:
return
if topology.encoder_grid is None or not is_rank_in_grid(topology.encoder_grid):
return

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.metadata import TensorStorageMetadata

radio_encoder_key = getattr(args, "vision_encoder_key", "radio_encoder")
radio_model = _radio_model(model, topology.encoder_name, radio_encoder_key)
if radio_model is None:
return

iter_dir = _resolve_vision_dcp_dir(args.load_vision_from)
print_rank_0(f"[load-vision-from] loading ViT via torch DCP from {iter_dir}")

reader = FileSystemReader(iter_dir)
ckpt_meta = reader.read_metadata().state_dict_metadata
load_sd: Dict[str, torch.Tensor] = {
k: torch.empty(meta.size, dtype=meta.properties.dtype)
for k, meta in ckpt_meta.items()
if k.startswith(_VISION_DCP_PREFIX) and isinstance(meta, TensorStorageMetadata)
}
if not load_sd:
raise RuntimeError(
f"[load-vision-from] no '{_VISION_DCP_PREFIX}*' keys in {iter_dir}"
)

# Encoder-grid collective: only the 4 encoder ranks participate. LLM ranks
# short-circuited above; their c10d store stays alive thanks to the 1-hour
# init_process_group timeout in distributed.initialize_distributed().
dcp.load(load_sd, storage_reader=reader, process_group=topology.vision_pg.tp_dp_cp)

tp_pg = getattr(topology.vision_pg, "tp", None) if topology.vision_pg is not None else None
if is_process_group_member(tp_pg):
tp_rank, tp_size = tp_pg.rank(), tp_pg.size()
else:
tp_rank, tp_size = 0, 1

# Build a regular state_dict keyed by radio_model parameter names, TP-sliced
# for this rank's view, then run it through the canonical load_state_dict
# path so PyTorch handles any module hooks (DDP buckets, TE extra_state, …)
# rather than us mutating `param.data` directly.
cleaned: Dict[str, torch.Tensor] = {}
unexpected: list[str] = []
radio_state = radio_model.state_dict()
for ckpt_key, tensor in load_sd.items():
rel_key = ckpt_key[len(_VISION_DCP_PREFIX) :]
if "extra_state" in rel_key:
continue
target = radio_state.get(rel_key)
if target is None:
unexpected.append(rel_key)
continue
cleaned[rel_key] = _tp_slice(tensor, target.shape, tp_rank, tp_size).to(
dtype=target.dtype
)

incompatible = radio_model.load_state_dict(cleaned, strict=False)
missing = [k for k in incompatible.missing_keys if "extra_state" not in k]
extra = [k for k in incompatible.unexpected_keys if "extra_state" not in k] + unexpected
if missing or extra:
raise RuntimeError(
f"[load-vision-from] strict mismatch under '{_VISION_DCP_PREFIX}'. "
f"Missing (model expects but ckpt lacks): {missing}. "
f"Unexpected (ckpt has but model lacks): {extra}."
)

print_rank_0(f"[load-vision-from] ViT loaded ({len(cleaned)} tensors, strict)")
11 changes: 10 additions & 1 deletion examples/mimo/training/hetero/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import datetime
import sys

import torch
Expand All @@ -19,7 +20,15 @@ def initialize_distributed() -> None:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
# 1-hour collective timeout: lustre Bridge-DCP reads on encoder ranks can
# leave LLM ranks idle for several minutes; default 600 s is too short.
# device_id is explicit: pytorch's auto-guess from global rank can cause
# hangs in heterogeneous topologies (encoder/LLM offset != 0).
dist.init_process_group(
backend="nccl",
timeout=datetime.timedelta(hours=1),
device_id=torch.device(f"cuda:{local_rank}"),
)
assert_megatron_parallel_state_uninitialized()
try:
parallel_state.get_global_memory_buffer()
Expand Down
16 changes: 15 additions & 1 deletion examples/mimo/training/hetero/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import torch

from examples.mimo.training.hetero.args import prepare_args
from examples.mimo.training.hetero.checkpointing import load_checkpoint, save_checkpoint
from examples.mimo.training.hetero.checkpointing import (
load_checkpoint,
load_vision_from_checkpoint,
save_checkpoint,
)
from examples.mimo.training.hetero.data import select_data_iterator, validate_data_iterator
from examples.mimo.training.hetero.distributed import print_rank_0
from examples.mimo.training.hetero.grad_sync import configure_grad_sync
Expand Down Expand Up @@ -63,6 +67,16 @@ def run_train_loop(args: argparse.Namespace) -> None:
debug_rank("training setup ready")

start_iteration = load_checkpoint(model, optimizer, opt_param_scheduler, args, topology)
if start_iteration == 0 and args.load_vision_from is not None:
# Vision-only warm-start when `--load` did NOT resolve a full ckpt.
# The load mutates radio encoder params in place (post-DDP-wrap,
# post-optimizer build). We then call optimizer.reload_model_params()
# to refresh the distributed optimizer's fp32 main-param mirror from
# the just-loaded bf16 model — same pattern megatron uses after
# in-place param mutation in upcycling (training.py:1826).
load_vision_from_checkpoint(model, args, topology)
if optimizer is not None and not optimizer.is_stub_optimizer:
optimizer.reload_model_params()
if start_iteration >= args.train_iters:
print_rank_0(
f"Resume iteration ({start_iteration}) >= --train-iters ({args.train_iters}); "
Expand Down
26 changes: 26 additions & 0 deletions examples/mimo/training/hetero/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import os
from typing import Iterator, Optional

import torch
Expand Down Expand Up @@ -77,6 +78,31 @@ def build_mimo_runtime(args: argparse.Namespace, topology: HeteroTopology) -> Mi
return mimo_model


def _full_checkpoint_exists(load_root: Optional[str]) -> bool:
"""Whether a full hetero checkpoint exists at ``load_root``.

Mirrors `examples.mimo.training.hetero.checkpointing._read_tracker` but
returns a bool. Used to gate the vision warm-start: when ``--load``
resolves a real checkpoint, the main `load_checkpoint` path will restore
encoder weights from it, so the vision DCP warm-start must skip.
"""
if not load_root:
return False
tracker = os.path.join(load_root, "latest_checkpointed_iteration.txt")
local_iter = -1
if os.path.isfile(tracker):
try:
with open(tracker) as f:
local_iter = int(f.read().strip())
except (ValueError, OSError):
local_iter = -1
if torch.distributed.is_available() and torch.distributed.is_initialized():
iters = torch.tensor([local_iter], dtype=torch.long, device="cuda")
torch.distributed.all_reduce(iters, op=torch.distributed.ReduceOp.MAX)
return int(iters[0].item()) >= 0
return local_iter >= 0


def _resolve_bucket_size(
args: argparse.Namespace, module: Optional[torch.nn.Module]
) -> Optional[int]:
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/mimo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
Loading