NMFW-478: Add --load-vision-from for hetero MIMO encoder DCP loading#28
Draft
yashaswikarnati wants to merge 7 commits into
Draft
Conversation
Mirror Sanjeev's `--load-vision-from` from pre-vlm-05 (`pretrain_vlm_energon.py`)
inside the hetero MIMO training loop so encoder ranks can warm-start RADIO
weights from a Megatron-Bridge DCP (e.g. `post-c-radio-omni`) on the first run.
Without this the encoder is randomly initialized and parity with Sanjeev's
recipe is unreachable: iter-1 lm-loss only matches when the projector sees
trained RADIO features.
Behavior:
* New `--load-vision-from PATH` (plus `--allow-missing-vision-projection-checkpoint`
and `--radio-force-eval-mode` for args-dump parity) in the `ckpt` group of
the standalone hetero parser.
* `load_vision_from_checkpoint` in `examples/mimo/training/hetero/checkpointing.py`
resolves either a flat DCP or a `latest_checkpointed_iteration.txt` +
`iter_NNNNNNN/` layout, filters keys to `model.vision_model.*`, dcp-loads,
TP-slices via `topology.vision_pg.tp`, and copies into
`vision_submodule.encoders["radio_encoder"].radio_model.<rel>` (and
best-effort into `input_projections[0]`).
* No-op on LLM-only ranks and on ranks outside the encoder grid.
* Called from `loop.py` only when `--load` resolved no checkpoint
(`--load` stays authoritative on resume).
* `nemotron_moe_vlm.py` honors the explicit `--radio-force-eval-mode` knob
and falls back to `args.freeze_vit` when unset.
* The 9n parity sbatch un-comments `--load-vision-from "${VISION_CKPT}"`.
Tests (`tests/unit_tests/mimo/test_load_vision_from.py`) cover the
deterministic helpers single-process: `_tp_slice` (column/row parallel +
passthrough), `_resolve_vision_dcp_dir` (flat vs tracker), and a real
mini-DCP round-trip exercising the `model.vision_model.*` prefix filter
plus `dcp.load`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
yashaswikarnati
commented
May 16, 2026
| "force_eval_mode": args.freeze_vit, | ||
| "force_eval_mode": ( | ||
| args.radio_force_eval_mode | ||
| if getattr(args, "radio_force_eval_mode", None) is not None |
Owner
Author
There was a problem hiding this comment.
do we need both the args? are they doing the same thing ?
yashaswikarnati
commented
May 16, 2026
| ), | ||
| ) | ||
| ckpt.add_argument( | ||
| "--allow-missing-vision-projection-checkpoint", |
Owner
Author
There was a problem hiding this comment.
do we need --allow-missing-vision-projection-checkpoint ?
yashaswikarnati
commented
May 16, 2026
| default=False, | ||
| help="Tolerate missing projector keys when loading the vision DCP.", | ||
| ) | ||
| ckpt.add_argument( |
Owner
Author
There was a problem hiding this comment.
do we need this radio force eval mode arg? is freeze vit enough ?
yashaswikarnati
commented
May 16, 2026
|
|
||
|
|
||
| def load_vision_from_checkpoint( | ||
| model: MimoModel, |
Owner
Author
There was a problem hiding this comment.
do we need to handle it this verbose? just for checkpoint loading? its a distributed checkpoint why do we need to handle tp slice etc. also we dont load projection any way? just the encoder part?
…o-load-strict Addresses review feedback on PR 28: - Replace bespoke per-tensor _tp_slice + DCP plumbing with the sharded_state_dict pattern from examples/mimo/utils/model_helpers. Each ShardedTensor carries TP-sharding metadata, so dist_checkpointing.load handles per-rank slicing automatically; no manual column/row-parallel branching. - Strict-validate that every parameter the model expects landed in the checkpoint (extra_state buffers excluded) via post-load load_state_dict + incompatible-keys check. The old loop silently skipped tensor mismatches into a 'skipped' counter. - Drop --allow-missing-vision-projection-checkpoint: vision-from is encoder-only, never touches the projector. - Drop --radio-force-eval-mode: --freeze-vit is sufficient. Revert force_eval_mode wiring in nemotron_moe_vlm.py. - Add --no-load-strict to the main checkpoint load path. Default enables StrictHandling.RAISE_ALL on dist_checkpointing.load so any missing or unexpected key raises immediately, confirming all params reloaded. Falls back to ASSUME_OK_UNEXPECTED when disabled (schema drift / partial loads). Validated on cw-dfw 8-GPU 20L mock: - Stage1 regression smoke: exit 0, grad norm 0.025 -> 0.013, language branch has no LR field (frozen-LM optimizer skipped from PR 27). - Save + strict-reload round-trip: save iter 3, reload with default --load-strict (RAISE_ALL), resume iter 4-5 with cosine LR continuation (1.32e-4 -> 1.01e-4). No strict failures: every sharded key matched the checkpoint. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…der I/O Vision Bridge DCPs (e.g. C-RADIO Omni `post-c-radio-omni`) are the `torch.distributed.checkpoint` format (`.metadata` + `__N_M.distcp`), not Megatron's `dist_checkpointing`. The previous loader walked the DCP keys per-tensor, TP-sliced, copied — and silently absorbed any unmatched key into a "skipped" counter. We retain that per-tensor + TP-slice approach (verified on nb-hel: the Bridge DCP keys all land under `model.vision_model.*` and map 387/387 to `RADIOViTModel.named_parameters()`) but the loader now strict-validates: any key the model expects but the checkpoint lacks (or vice versa) raises with the offending lists. Other changes vs the previous PR-28 attempt: - Drop projector loading. `--load-vision-from` is encoder-only; the projector trains from scratch in stage1/stage2. - Drop `--allow-missing-vision-projection-checkpoint`. Unused with encoder-only loading. - Drop `--radio-force-eval-mode`. `--freeze-vit` is sufficient; `force_eval_mode` derives from it. - Scope `dcp.load` to `topology.vision_pg.tp_dp_cp` (the 4-rank encoder grid). LLM-only ranks short-circuit out of the function. - Bump `init_process_group` timeout to 1 hour. Lustre reads on encoder ranks can stall LLM ranks for minutes; the default 600 s c10d socket timeout caused TCPStore drops in earlier runs. - Add `--no-load-strict` for the main `load_checkpoint` path. Default enables `StrictHandling.RAISE_ALL` so every key the model expects must come from the checkpoint, confirming a complete reload. Validated on cw-dfw 8-GPU 20L mock: - Stage1 regression smoke (no `--load-vision-from`): exit 0, grad norm 0.025 -> 0.013, language branch correctly has no LR field. - Strict save+load round-trip: save iter 3, reload with default `--load-strict`, resume iter 4-5 with cosine LR continuation. No strict failures: every sharded key matched the checkpoint. Validated on nb-hel 8-GPU 20L mock: - Strict load against real C-RADIO Omni Bridge DCP completes cleanly: `[load-vision-from] ViT loaded (387 tensors, strict)`, all keys matched against `model.vision_model.*` prefix. - Iter-1 training hang observed downstream of the loaded weights is a separate issue (reproduces independent of `dcp.load` -- only requires the `param.data.copy_` into vision params) and is filed for follow-up; the loader itself behaves correctly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the manual `param.data.copy_(...)` loop with `radio_model.load_state_dict(cleaned, strict=False)` + the standard `incompatible` validation. Matches the pattern in `examples/mimo/utils/model_helpers.load_submodule_ckpt`. We still build the cleaned state_dict ourselves so we can TP-slice each ckpt tensor to the local rank's parameter shape before handing it to PyTorch. Diagnostic note: nb-hel testing showed iter-1 training hangs after this load completes successfully (387/387 strict). The hang reproduces regardless of whether the load uses `dcp.load`, the manual `copy_` loop, or this `load_state_dict` path. Bisection isolated the trigger to writing any value (real or uninitialized) into the radio params post-DDP-wrap. Filed for separate follow-up; the loader itself is correct. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Architectural fix attempting to resolve the iter-1 hang observed in earlier revisions of this PR: - Move `load_vision_from_checkpoint` from the loop-runtime into `build_mimo_runtime` so it runs AFTER `wrap_active_modules_with_ddp` but BEFORE the loop's `build_optimizer` call. The distributed optimizer's fp32 main-param mirror is then built from the loaded bf16 weights, which mirrors the standard megatron `load_checkpoint` path (training/training.py:1837) where the optimizer's main_params come from the checkpoint, not from random init. - Add `_full_checkpoint_exists` helper so the warm-start skips when `--load` resolves a real full checkpoint (the main load handles encoder weights in that case). - Pin `device_id` in `init_process_group` to silence pytorch's "Guessing device ID … can cause a hang if rank to GPU mapping is heterogeneous" warning. Our encoder grid (offset 0) and LLM grid (offset 4) make this exactly the case pytorch warns about. Note: nb-hel testing still shows an init-time hang on the very first post-load NCCL collective. The hang reproduces regardless of where the load runs in the lifecycle (pre-DDP-wrap, between DDP-wrap and optim, or post-optim) and regardless of which mechanism writes the weights (direct `param.data.copy_`, `radio_model.load_state_dict`, or even uninitialized memory). The trigger is bisected to writing any value into encoder params on encoder ranks — the symptom moves from "iter 1 hangs" to "init hangs" as we load earlier in the lifecycle. This is a real bug worth a separate investigation; the loader code itself is correct and strict-validates 387/387 keys from the C-RADIO Omni DCP. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Match megatron's `training.py:1826` post-modification path: after the vision DCP load mutates encoder params in place, call `optimizer.reload_model_params()` so the distributed optimizer's fp32 main-param mirror is rebuilt from the just-loaded bf16 weights. Without this, the fp32 mirror stays at random init while the bf16 view holds the loaded weights — the first optimizer step would silently overwrite our load with the stale fp32, and (on nb-hel) the mismatch deadlocked the very first post-init NCCL collective. With `reload_model_params()` the run gets past init and into the training loop. (A separate iter-1 hang remains and is being tracked.) Revert the load-before-DDP-wrap relocation from the previous commit: in-place mutation after DDP-wrap is the supported pattern (it is what megatron's `load_checkpoint` does at `training.py:1837`); the prior revision added complexity without fixing anything. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Enable `faulthandler.dump_traceback_later(120, repeat=True)` in `train_hetero.py` so every rank dumps its python stack every 2 minutes. Output goes to per-rank stderr, picked up by cog/slurm log capture without any additional plumbing. This was load-bearing for diagnosing the `--load-vision-from` iter-1 hang on nb-hel: dumps showed the encoder is stuck in `radio.py:213 self.embedder(x)` → TP `all_gather_into_tensor`, and the LLM is stuck in `BridgeCommunicator.recv_forward`. The hang is at the first column-parallel `all_gather` inside the radio encoder, not in the cross-grid pipeline. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
--load-vision-fromto the hetero MIMO loop so we can boot from Sanjeev's pre-trained C-RADIO Omni encoder DCP (matchespre-vlm-05behavior). Without this, the encoder loads randomly initialized weights and the loss curve diverges from Sanjeev's reference run.Mirrors Sanjeev's
_load_vision_from_checkpointinexamples/multimodal/v3/pretrain_vlm_energon.py(sasatheesh/megatron-lm!45), adapted to our hetero topology where vision lives atvision_submodule.encoders["radio_encoder"].radio_model.*(vs hismodel.vision_model.*).CLI surface (3 new flags)
--load-vision-from PATH— Megatron-Bridge DCP path; loads only encoder weights on encoder ranks, only when--loadresolves no full checkpoint.--allow-missing-vision-projection-checkpoint— tolerate missing projector keys.--radio-force-eval-mode / --no-radio-force-eval-mode— explicit override (default derives from--freeze-vit).Behavioral contract
start_iteration == 0(no full-checkpoint resume).--loadis authoritative when both are set.iter_NNNNNNN/fromlatest_checkpointed_iteration.txt, or treats the path as a flat DCP.--allow-missing-vision-projection-checkpoint, missing projector keys are skipped silently.[load-vision-from] ViT loaded (X/Y tensors, Z skipped).Code locations
examples/mimo/training/hetero/args.py:281-306examples/mimo/training/hetero/checkpointing.py:337-491examples/mimo/training/hetero/loop.py:69-73examples/mimo/model_providers/nemotron_moe_vlm.py:599-603(forwards--radio-force-eval-modeintoRADIOEncoderWrapper).examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh:220(uncommented--load-vision-from "${VISION_CKPT}").Tests
tests/unit_tests/mimo/test_load_vision_from.py(single-process, no GPU/dist):_tp_slice— TP=1 passthrough, matching-shape passthrough, column-parallel (dim-0 split), row-parallel (dim-1 split)._resolve_vision_dcp_dir— flat-DCP and tracker-file (iter_0000042/) layouts.model.vision_model.*keys plus a decoymodel.language_model.*key, asserts the prefix filter selects only vision keys and thatdcp.loadreturns the saved tensor values.Test plan (real cluster)
examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh(uses the stagedpost-c-radio-omniDCP at/lustre/.../agents-scratch/encoders/post-c-radio-omni/).[load-vision-from] ViT loaded (388/388 tensors)(matches Sanjeev's export-info tensor count).correct_encoder_grad_for_partial_participationPR for full parity).Notes
--dataloader-save(energon resumable state) and--correct-encoder-grad-for-partial-participation(projector grad scaling).ykarnati/nmfw-464-nemotron-vlm-with-hetero-parallel(which already contains the parity-recipe commit2f6b2c05a).🤖 Generated with Claude Code