diff --git a/examples/mimo/model_providers/nemotron_moe_vlm.py b/examples/mimo/model_providers/nemotron_moe_vlm.py index 2129d943045..a93e6242116 100644 --- a/examples/mimo/model_providers/nemotron_moe_vlm.py +++ b/examples/mimo/model_providers/nemotron_moe_vlm.py @@ -10,6 +10,13 @@ import torch +from examples.mimo.utils.hetero import ( + debug_rank, + get_grid_dim_size, + get_group_rank_or, + get_group_size_or, + is_process_group_member, +) from megatron.core.activations import fast_gelu, squared_relu from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec @@ -28,14 +35,6 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import sharded_state_dict_default -from examples.mimo.utils.hetero import ( - debug_rank, - get_grid_dim_size, - get_group_rank_or, - get_group_size_or, - is_process_group_member, -) - try: from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, diff --git a/examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh b/examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh index 8f6ff13d024..ee738f69223 100755 --- a/examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh +++ b/examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh @@ -217,7 +217,7 @@ TRAIN_LAUNCH_ARGS=( --seed 1234 # Sanjeev's value --save "${CHECKPOINT_SAVE_PATH}" --save-interval "${SAVE_INTERVAL}" - # --load-vision-from "${VISION_CKPT}" # TODO: enable once PR_load_vision_from lands + --load-vision-from "${VISION_CKPT}" ) CONTAINER_MOUNTS="${SCRATCH_ROOT}:${SCRATCH_ROOT},/lustre/fsw/portfolios/llmservice:/lustre/fsw/portfolios/llmservice,/scratch/fsw/portfolios/llmservice:/scratch/fsw/portfolios/llmservice" diff --git a/examples/mimo/train_hetero.py b/examples/mimo/train_hetero.py index af46e3235ea..629c5732b31 100644 --- a/examples/mimo/train_hetero.py +++ 
b/examples/mimo/train_hetero.py @@ -2,6 +2,7 @@ """Standalone heterogeneous MIMO training entrypoint.""" +import faulthandler import os import sys @@ -24,6 +25,12 @@ def main() -> None: """Program entrypoint.""" + # Dump every rank's python stack every 120 s. Hands-off diagnostic for + # hetero MIMO hangs — output goes to each rank's stderr so cog/slurm log + # capture works without code changes. + faulthandler.enable() + faulthandler.dump_traceback_later(120, repeat=True) + args = parse_args() if args.enable_experimental: set_experimental_flag(True) diff --git a/examples/mimo/training/hetero/args.py b/examples/mimo/training/hetero/args.py index efd4f87136c..7d82050d5c7 100644 --- a/examples/mimo/training/hetero/args.py +++ b/examples/mimo/training/hetero/args.py @@ -264,6 +264,28 @@ def parse_args() -> argparse.Namespace: "skip optimizer + scheduler state regardless of the other flags." ), ) + ckpt.add_argument( + "--no-load-strict", + dest="load_strict", + action="store_false", + default=True, + help=( + "Disable strict checkpoint validation. By default the load path uses " + "StrictHandling.RAISE_ALL so missing or unexpected keys raise immediately, " + "confirming every parameter the model expects came from the checkpoint. " + "Pass --no-load-strict to fall back to ASSUME_OK_UNEXPECTED for schema drift." + ), + ) + ckpt.add_argument( + "--load-vision-from", + type=str, + default=None, + help=( + "Path to a Megatron-Bridge DCP containing `model.vision_model.*` keys. " + "Loaded only on encoder ranks, only on first run (when --load resolves no " + "checkpoint). Matches Sanjeev's --load-vision-from semantics from pre-vlm-05." 
+ ), + ) ckpt.add_argument( "--dist-ckpt-optim-fully-reshardable", action=argparse.BooleanOptionalAction, diff --git a/examples/mimo/training/hetero/checkpointing.py b/examples/mimo/training/hetero/checkpointing.py index 90a40c42059..a13bab24e70 100644 --- a/examples/mimo/training/hetero/checkpointing.py +++ b/examples/mimo/training/hetero/checkpointing.py @@ -35,12 +35,16 @@ from megatron.core import dist_checkpointing, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject from megatron.core.dist_checkpointing.utils import _clean_metadata_for_serialization +from megatron.core.dist_checkpointing.validation import StrictHandling from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.optimizer import MimoOptimizer from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler _TRACKER_FILE = "latest_checkpointed_iteration.txt" _CHECKPOINT_VERSION = 3.0 +# Sanjeev's pre-vlm-05 Bridge DCP wraps the whole VLM under a top-level `model.` +# attribute, so the RADIO encoder's tensors are keyed `model.vision_model.`. +_VISION_DCP_PREFIX = "model.vision_model." 
def _resolve_vision_dcp_dir(ckpt_dir: str) -> str:
    """Resolve the directory that actually holds the vision DCP shards.

    If ``ckpt_dir`` contains the tracker file (``_TRACKER_FILE``), the integer
    iteration stored in it is read and ``_iter_directory(ckpt_dir, iteration)``
    is returned. Without a tracker, ``ckpt_dir`` is treated as a flat DCP
    directory and returned unchanged.

    Raises:
        ValueError: the tracker file exists but does not hold an integer.
    """
    tracker = os.path.join(ckpt_dir, _TRACKER_FILE)
    if not os.path.isfile(tracker):
        return ckpt_dir
    with open(tracker, encoding="utf-8") as f:
        raw = f.read().strip()
    try:
        iteration = int(raw)
    except ValueError as err:
        # Same exception type as a bare int() failure, but it names the file.
        raise ValueError(
            f"[load-vision-from] tracker {tracker} holds {raw!r}; expected an integer iteration"
        ) from err
    return _iter_directory(ckpt_dir, iteration)


def _radio_model(model: "MimoModel", encoder_name: str, radio_encoder_key: str):
    """Return the ``radio_model`` attribute of the named encoder, or None.

    Looks up ``model.modality_submodules[encoder_name]``, unwraps one level of
    ``.module`` if present (e.g. a DDP wrapper), then reads
    ``encoders[radio_encoder_key].radio_model``.

    Returns None when any link in that chain is absent: the model has no
    ``modality_submodules``, ``encoder_name`` is not in it, the submodule has
    no ``encoders``, ``radio_encoder_key`` is not in ``encoders``, or the
    encoder has no ``radio_model`` attribute.
    """
    submodules = getattr(model, "modality_submodules", None)
    if submodules is None or encoder_name not in submodules:
        return None
    vision_submodule = submodules[encoder_name]
    inner = getattr(vision_submodule, "module", vision_submodule)  # unwrap DDP
    encoders = getattr(inner, "encoders", None)
    if encoders is None or radio_encoder_key not in encoders:
        return None
    return getattr(encoders[radio_encoder_key], "radio_model", None)


def _tp_slice(tensor: torch.Tensor, param_shape, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Slice a full (TP=1) tensor down to this rank's TP shard.

    The split dimension is the single dimension whose size differs between
    ``tensor`` and ``param_shape``. This covers column-parallel weights and
    1-D biases (dim 0), row-parallel weights (dim 1), and any later dim.

    The tensor is returned unchanged (same object) when ``tp_size == 1`` or
    when it already has ``param_shape``.

    NOTE(review): a contiguous block per rank is assumed. Fused or interleaved
    layouts (e.g. interleaved QKV) would pass the shape checks yet need a
    different slicing scheme -- confirm against the checkpoint producer.

    Args:
        tensor: Full, unsharded tensor read from the checkpoint.
        param_shape: Shape of the local (sharded) parameter.
        tp_rank: This rank's index inside the tensor-parallel group.
        tp_size: Size of the tensor-parallel group.

    Returns:
        A view of ``tensor`` with shape ``param_shape``.

    Raises:
        ValueError: ``tp_rank`` is out of range, the ranks of the shapes
            differ, zero or more than one dimension differs, or the differing
            dimension is not exactly ``tp_size`` times the local size.
    """
    shard_shape = tuple(param_shape)
    full_shape = tuple(tensor.shape)
    if tp_size == 1 or full_shape == shard_shape:
        return tensor
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank {tp_rank} out of range for tp_size {tp_size}")
    if len(full_shape) != len(shard_shape):
        raise ValueError(
            f"cannot TP-slice tensor of shape {full_shape} to parameter shape {shard_shape}: "
            "number of dimensions differs"
        )
    split_dims = [
        dim
        for dim, (full, shard) in enumerate(zip(full_shape, shard_shape))
        if full != shard
    ]
    if len(split_dims) != 1:
        raise ValueError(
            f"cannot TP-slice tensor of shape {full_shape} to parameter shape {shard_shape}: "
            f"expected exactly one differing dimension, found {split_dims}"
        )
    dim = split_dims[0]
    shard = shard_shape[dim]
    # Without this check an uneven split would silently drop or misalign rows
    # and could still satisfy load_state_dict's shape check.
    if full_shape[dim] != shard * tp_size:
        raise ValueError(
            f"cannot TP-slice dim {dim} of size {full_shape[dim]} into {tp_size} shards "
            f"of size {shard}"
        )
    return tensor.narrow(dim, tp_rank * shard, shard)


def load_vision_from_checkpoint(
    model: "MimoModel",
    args: argparse.Namespace,
    topology: "HeteroTopology",
) -> None:
    """Load RADIO encoder weights from a torch-DCP (Bridge) checkpoint.

    The checkpoint is read with ``torch.distributed.checkpoint`` (not
    Megatron's ``dist_checkpointing``). Only tensor entries whose key starts
    with ``_VISION_DCP_PREFIX`` are requested. Each one is TP-sliced to the
    local parameter shape, cast to the local dtype, and applied through
    ``radio_model.load_state_dict``.

    The function returns without loading anything when:
      * ``args.load_vision_from`` is unset or empty,
      * this rank is not in ``topology.encoder_grid``,
      * ``_radio_model`` finds no RADIO model on this rank.

    Keys containing ``extra_state`` are ignored on both sides. Every other
    key must match one-to-one between checkpoint and model.

    ``dcp.load`` runs on ``topology.vision_pg.tp_dp_cp``, so ranks that
    returned early do not take part in it. The comment in
    ``distributed.initialize_distributed`` says the process-group timeout was
    raised so those ranks can wait out the read.

    Raises:
        RuntimeError: ``topology.vision_pg`` is None on an encoder rank, the
            checkpoint has no ``_VISION_DCP_PREFIX`` tensors, or the key sets
            of checkpoint and model differ.
        ValueError: a checkpoint tensor cannot be TP-sliced to the local
            parameter shape (see ``_tp_slice``).
    """
    if not args.load_vision_from:
        return
    if topology.encoder_grid is None or not is_rank_in_grid(topology.encoder_grid):
        return

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.metadata import TensorStorageMetadata

    radio_encoder_key = getattr(args, "vision_encoder_key", "radio_encoder")
    radio_model = _radio_model(model, topology.encoder_name, radio_encoder_key)
    if radio_model is None:
        # NOTE(review): this is a silent skip on an encoder rank. If only some
        # encoder ranks take this branch, the others would wait in the
        # collective dcp.load below -- confirm this cannot happen in practice.
        return

    # Resolve the process groups once and fail clearly. The original code read
    # `.tp_dp_cp` off `topology.vision_pg` before its own None-guard, which
    # would have surfaced as an AttributeError.
    vision_pg = topology.vision_pg
    if vision_pg is None:
        raise RuntimeError(
            "[load-vision-from] topology.vision_pg is None on an encoder rank; "
            "cannot run the encoder-group DCP load"
        )

    iter_dir = _resolve_vision_dcp_dir(args.load_vision_from)
    print_rank_0(f"[load-vision-from] loading ViT via torch DCP from {iter_dir}")

    reader = FileSystemReader(iter_dir)
    ckpt_meta = reader.read_metadata().state_dict_metadata
    # Pre-allocate full-size destination tensors for every vision tensor in
    # the checkpoint; dcp.load fills them in place.
    load_sd: Dict[str, torch.Tensor] = {
        k: torch.empty(meta.size, dtype=meta.properties.dtype)
        for k, meta in ckpt_meta.items()
        if k.startswith(_VISION_DCP_PREFIX) and isinstance(meta, TensorStorageMetadata)
    }
    if not load_sd:
        raise RuntimeError(
            f"[load-vision-from] no '{_VISION_DCP_PREFIX}*' keys in {iter_dir}"
        )

    # Collective over the encoder process group only: ranks that returned
    # early above do not participate.
    dcp.load(load_sd, storage_reader=reader, process_group=vision_pg.tp_dp_cp)

    tp_pg = getattr(vision_pg, "tp", None)
    if is_process_group_member(tp_pg):
        tp_rank, tp_size = tp_pg.rank(), tp_pg.size()
    else:
        tp_rank, tp_size = 0, 1

    # Build a regular state_dict keyed by radio_model parameter names, TP-sliced
    # for this rank, and go through load_state_dict rather than mutating
    # `param.data` directly.
    cleaned: Dict[str, torch.Tensor] = {}
    unexpected: list[str] = []
    radio_state = radio_model.state_dict()
    for ckpt_key, tensor in load_sd.items():
        rel_key = ckpt_key[len(_VISION_DCP_PREFIX) :]
        if "extra_state" in rel_key:
            continue
        target = radio_state.get(rel_key)
        if target is None:
            unexpected.append(rel_key)
            continue
        cleaned[rel_key] = _tp_slice(tensor, target.shape, tp_rank, tp_size).to(
            dtype=target.dtype
        )

    # strict=False so that extra_state keys can be filtered out before the
    # key sets are compared; the check below enforces strictness.
    incompatible = radio_model.load_state_dict(cleaned, strict=False)
    missing = [k for k in incompatible.missing_keys if "extra_state" not in k]
    extra = [k for k in incompatible.unexpected_keys if "extra_state" not in k] + unexpected
    if missing or extra:
        raise RuntimeError(
            f"[load-vision-from] strict mismatch under '{_VISION_DCP_PREFIX}'. "
            f"Missing (model expects but ckpt lacks): {missing}. "
            f"Unexpected (ckpt has but model lacks): {extra}."
        )

    print_rank_0(f"[load-vision-from] ViT loaded ({len(cleaned)} tensors, strict)")
+ dist.init_process_group( + backend="nccl", + timeout=datetime.timedelta(hours=1), + device_id=torch.device(f"cuda:{local_rank}"), + ) assert_megatron_parallel_state_uninitialized() try: parallel_state.get_global_memory_buffer() diff --git a/examples/mimo/training/hetero/loop.py b/examples/mimo/training/hetero/loop.py index ac3d723eda3..8641ab014b0 100644 --- a/examples/mimo/training/hetero/loop.py +++ b/examples/mimo/training/hetero/loop.py @@ -10,7 +10,11 @@ import torch from examples.mimo.training.hetero.args import prepare_args -from examples.mimo.training.hetero.checkpointing import load_checkpoint, save_checkpoint +from examples.mimo.training.hetero.checkpointing import ( + load_checkpoint, + load_vision_from_checkpoint, + save_checkpoint, +) from examples.mimo.training.hetero.data import select_data_iterator, validate_data_iterator from examples.mimo.training.hetero.distributed import print_rank_0 from examples.mimo.training.hetero.grad_sync import configure_grad_sync @@ -63,6 +67,16 @@ def run_train_loop(args: argparse.Namespace) -> None: debug_rank("training setup ready") start_iteration = load_checkpoint(model, optimizer, opt_param_scheduler, args, topology) + if start_iteration == 0 and args.load_vision_from is not None: + # Vision-only warm-start when `--load` did NOT resolve a full ckpt. + # The load mutates radio encoder params in place (post-DDP-wrap, + # post-optimizer build). We then call optimizer.reload_model_params() + # to refresh the distributed optimizer's fp32 main-param mirror from + # the just-loaded bf16 model — same pattern megatron uses after + # in-place param mutation in upcycling (training.py:1826). 
def _full_checkpoint_exists(load_root: Optional[str]) -> bool:
    """Report whether a tracker file with a non-negative iteration is present.

    Reads ``latest_checkpointed_iteration.txt`` under ``load_root``. A missing
    file, an unreadable file, or non-integer content all count as "no
    checkpoint" on this rank (recorded as -1).

    When ``torch.distributed`` is initialized, the per-rank values are
    MAX-reduced across the default group, so the result is True if any rank
    found a checkpoint. Otherwise only the local value is used.

    NOTE(review): written to gate the vision warm-start, but no caller is
    visible in this change -- confirm where it is meant to be wired in.

    Args:
        load_root: Checkpoint root directory; None or empty returns False.

    Returns:
        True if a non-negative iteration was found, False otherwise.
    """
    if not load_root:
        return False

    tracker_path = os.path.join(load_root, "latest_checkpointed_iteration.txt")
    found_iter = -1
    if os.path.isfile(tracker_path):
        try:
            with open(tracker_path) as handle:
                found_iter = int(handle.read().strip())
        except (ValueError, OSError):
            # Treat an unreadable or malformed tracker as "nothing found".
            found_iter = -1

    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return found_iter >= 0

    # Agree on a single answer across ranks: the highest iteration seen.
    reduced = torch.tensor([found_iter], dtype=torch.long, device="cuda")
    dist.all_reduce(reduced, op=dist.ReduceOp.MAX)
    return int(reduced[0].item()) >= 0
# ---------------------------------------------------------------------------
# _tp_slice: pass-through cases and column-/row-parallel sharding.
# ---------------------------------------------------------------------------


def test_tp_slice_tp1_is_passthrough():
    """With a TP group of size 1 the very same tensor object comes back."""
    source = torch.arange(8).view(4, 2)
    assert _tp_slice(source, source.shape, tp_rank=0, tp_size=1) is source


def test_tp_slice_matching_shape_is_passthrough():
    """A tensor that already has the local shape is returned untouched."""
    source = torch.arange(4).view(2, 2)
    assert _tp_slice(source, (2, 2), tp_rank=1, tp_size=2) is source


def test_tp_slice_column_parallel_splits_first_dim():
    """Column-parallel weight: the full [out, in] tensor is split on dim 0."""
    whole = torch.arange(16, dtype=torch.float32).view(8, 2)
    local_shape = (4, 2)
    pieces = [_tp_slice(whole, local_shape, tp_rank=rank, tp_size=2) for rank in (0, 1)]
    for piece in pieces:
        assert tuple(piece.shape) == local_shape
    for piece, reference in zip(pieces, (whole[:4], whole[4:])):
        torch.testing.assert_close(piece, reference)


def test_tp_slice_row_parallel_splits_second_dim():
    """Row-parallel weight: the full [out, in] tensor is split on dim 1."""
    whole = torch.arange(16, dtype=torch.float32).view(2, 8)
    local_shape = (2, 4)
    pieces = [_tp_slice(whole, local_shape, tp_rank=rank, tp_size=2) for rank in (0, 1)]
    for piece in pieces:
        assert tuple(piece.shape) == local_shape
    for piece, reference in zip(pieces, (whole[:, :4], whole[:, 4:])):
        torch.testing.assert_close(piece, reference)


# ---------------------------------------------------------------------------
# _resolve_vision_dcp_dir: flat layout versus tracker-file layout.
# ---------------------------------------------------------------------------


def test_resolve_vision_dcp_dir_flat(tmp_path: Path):
    """No tracker file: the given directory itself is the DCP directory."""
    root = str(tmp_path)
    assert _resolve_vision_dcp_dir(root) == root


def test_resolve_vision_dcp_dir_with_tracker(tmp_path: Path):
    """A tracker file redirects the loader into the iter_NNNNNNN/ directory."""
    tracker_file = tmp_path / "latest_checkpointed_iteration.txt"
    tracker_file.write_text("42\n")
    resolved = _resolve_vision_dcp_dir(str(tmp_path))
    assert resolved == os.path.join(str(tmp_path), "iter_0000042")


# ---------------------------------------------------------------------------
# DCP prefix filter: write a tiny Bridge-shaped DCP and read it back.
# ---------------------------------------------------------------------------


def _write_mock_vision_dcp(dcp_dir: str) -> dict[str, torch.Tensor]:
    """Save a three-entry DCP: two `model.vision_model.*` keys plus one decoy.

    The saved state dict is returned so the caller can compare tensors exactly.
    """
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemWriter

    embed_weight = torch.arange(12, dtype=torch.float32).view(3, 4)
    embed_bias = torch.arange(3, dtype=torch.float32)
    # The decoy lives outside the vision prefix, so the loader's filter must drop it.
    decoy = torch.zeros(2, 4, dtype=torch.float32)
    state = {
        "model.vision_model.embedder.weight": embed_weight,
        "model.vision_model.embedder.bias": embed_bias,
        "model.language_model.embed_tokens.weight": decoy,
    }
    dcp.save(
        state,
        storage_writer=FileSystemWriter(dcp_dir, single_file_per_rank=True),
        no_dist=True,
    )
    return state


def test_dcp_prefix_filter_and_read(tmp_path: Path):
    """The prefix filter keeps only `model.vision_model.*` tensors and
    `dcp.load` restores them with the values that were saved.

    Runs in a single process with no distributed init: it mirrors the
    metadata-read, filter, and load steps without building a MimoModel.
    """
    pytest.importorskip("torch.distributed.checkpoint")
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.metadata import TensorStorageMetadata

    ckpt_path = tmp_path / "post-c-radio-omni"
    ckpt_path.mkdir()
    written = _write_mock_vision_dcp(str(ckpt_path))

    reader = FileSystemReader(str(ckpt_path))
    entries = reader.read_metadata().state_dict_metadata

    # Step 1: apply the same prefix + tensor-type predicate as the loader.
    selected = set()
    for key, entry in entries.items():
        if key.startswith(_VISION_DCP_PREFIX) and isinstance(entry, TensorStorageMetadata):
            selected.add(key)
    expected_keys = {
        "model.vision_model.embedder.weight",
        "model.vision_model.embedder.bias",
    }
    assert (
        selected == expected_keys
    ), "decoy `model.language_model.*` key leaked through the prefix filter"

    # Step 2: allocate empty destinations from the metadata and load into them.
    request = {}
    for key in selected:
        entry = entries[key]
        request[key] = torch.empty(entry.size, dtype=entry.properties.dtype)
    dcp.load(request, storage_reader=reader)

    for key in selected:
        torch.testing.assert_close(request[key], written[key])