
NMFW-478: Add --load-vision-from for hetero MIMO encoder DCP loading #28

Draft
yashaswikarnati wants to merge 7 commits into ykarnati/nmfw-464-nemotron-vlm-with-hetero-parallel from ykarnati/nmfw-478-load-vision-from

Conversation

@yashaswikarnati
Owner

Summary

Adds --load-vision-from to the hetero MIMO loop so we can boot from Sanjeev's pre-trained C-RADIO Omni encoder DCP (matches pre-vlm-05 behavior). Without this, the encoder starts from randomly initialized weights and the loss curve diverges from Sanjeev's reference run.

Mirrors Sanjeev's _load_vision_from_checkpoint in examples/multimodal/v3/pretrain_vlm_energon.py (sasatheesh/megatron-lm!45), adapted to our hetero topology where vision lives at vision_submodule.encoders["radio_encoder"].radio_model.* (vs his model.vision_model.*).

CLI surface (3 new flags)

  • --load-vision-from PATH — Megatron-Bridge DCP path; loads only encoder weights on encoder ranks, only when --load resolves no full checkpoint.
  • --allow-missing-vision-projection-checkpoint — tolerate missing projector keys.
  • --radio-force-eval-mode / --no-radio-force-eval-mode — explicit override (default derives from --freeze-vit).
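
A minimal sketch of how an on/off flag with a derived default can be wired in argparse. It assumes `argparse.BooleanOptionalAction` (Python 3.9+); `build_parser` and `resolve_force_eval_mode` are hypothetical names for illustration, and the actual parser in `args.py` may differ.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--freeze-vit", action="store_true")
    # default=None distinguishes "flag not passed" from an explicit True/False.
    # BooleanOptionalAction also registers --no-radio-force-eval-mode.
    parser.add_argument(
        "--radio-force-eval-mode",
        action=argparse.BooleanOptionalAction,
        default=None,
    )
    return parser


def resolve_force_eval_mode(args: argparse.Namespace) -> bool:
    """Use the explicit override when given, else fall back to --freeze-vit."""
    explicit = getattr(args, "radio_force_eval_mode", None)
    return args.freeze_vit if explicit is None else explicit
```

With this shape, `--freeze-vit --no-radio-force-eval-mode` resolves to `False`, while `--freeze-vit` alone resolves to `True`.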

Behavioral contract

  • LLM-only ranks and ranks outside the encoder grid: early-return no-op.
  • Active only when start_iteration == 0 (no full-checkpoint resume). --load is authoritative when both are set.
  • Resolves iter_NNNNNNN/ from latest_checkpointed_iteration.txt, or treats the path as a flat DCP.
  • TP-slices encoder weights via the encoder grid's TP group; column-parallel (dim-0) and row-parallel (dim-1) handled.
  • Best-effort projector load; with --allow-missing-vision-projection-checkpoint, missing projector keys are skipped silently.
  • Rank-0 emits [load-vision-from] ViT loaded (X/Y tensors, Z skipped).
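
The path-resolution rule above (tracker file first, flat DCP otherwise) could look roughly like the sketch below. The function name `resolve_vision_dcp_dir` is a stand-in for illustration; the sketch assumes the tracker file holds a plain integer and does not handle any other tracker contents.

```python
from pathlib import Path

TRACKER_NAME = "latest_checkpointed_iteration.txt"


def resolve_vision_dcp_dir(path: str) -> Path:
    """Return the directory that holds the DCP shards.

    If a tracker file is present, follow it to iter_NNNNNNN/ (7 digits,
    zero padded). Otherwise treat the given path as a flat DCP.
    """
    root = Path(path)
    tracker = root / TRACKER_NAME
    if not tracker.is_file():
        return root
    # Assumption: the tracker contains only an integer iteration number.
    iteration = int(tracker.read_text().strip())
    return root / f"iter_{iteration:07d}"
```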

Code locations

  • CLI: examples/mimo/training/hetero/args.py:281-306
  • Loader + helpers: examples/mimo/training/hetero/checkpointing.py:337-491
  • Wire-in: examples/mimo/training/hetero/loop.py:69-73
  • Provider hook: examples/mimo/model_providers/nemotron_moe_vlm.py:599-603 (forwards --radio-force-eval-mode into RADIOEncoderWrapper).
  • Parity sbatch: examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh:220 (uncommented --load-vision-from "${VISION_CKPT}").

Tests

tests/unit_tests/mimo/test_load_vision_from.py (single-process, no GPU/dist):

  • _tp_slice — TP=1 passthrough, matching-shape passthrough, column-parallel (dim-0 split), row-parallel (dim-1 split).
  • _resolve_vision_dcp_dir — flat-DCP and tracker-file (iter_0000042/) layouts.
  • DCP round-trip — writes a Bridge-shaped DCP with two model.vision_model.* keys plus a decoy model.language_model.* key, asserts the prefix filter selects only vision keys and that dcp.load returns the saved tensor values.
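
For orientation, the four `_tp_slice` cases can be expressed as a shape-only calculation. This is a sketch, not the PR's implementation: `tp_slice_bounds` is a hypothetical helper that infers the split dimension by comparing the checkpoint tensor's full shape with the local parameter shape.

```python
from typing import Optional, Sequence, Tuple


def tp_slice_bounds(
    full_shape: Sequence[int],
    local_shape: Sequence[int],
    tp_rank: int,
    tp_size: int,
) -> Optional[Tuple[int, int, int]]:
    """Return (dim, start, stop) for this rank's shard, or None for passthrough."""
    full, local = tuple(full_shape), tuple(local_shape)
    # TP=1 or a replicated (matching-shape) tensor needs no slicing.
    if tp_size == 1 or full == local:
        return None
    if len(full) != len(local):
        raise ValueError(f"rank mismatch: {full} vs {local}")
    split_dims = [d for d, (f, l) in enumerate(zip(full, local)) if f != l]
    if len(split_dims) != 1:
        raise ValueError(f"expected exactly one sharded dim: {full} vs {local}")
    dim = split_dims[0]
    if full[dim] != local[dim] * tp_size:
        raise ValueError(f"dim {dim} is not evenly split across {tp_size} ranks")
    start = tp_rank * local[dim]
    return dim, start, start + local[dim]
```

Given a checkpoint tensor `t`, the shard would then be `t.narrow(dim, start, stop - start)`; dim 0 corresponds to column-parallel weights and dim 1 to row-parallel weights.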

Test plan (real cluster)

  • Submit examples/mimo/scripts/sbatch_hetero_nemotron_54l_hel_9n_parity.sh (uses the staged post-c-radio-omni DCP at /lustre/.../agents-scratch/encoders/post-c-radio-omni/).
  • Verify rank-0 log line [load-vision-from] ViT loaded (388/388 tensors) (matches Sanjeev's export-info tensor count).
  • Confirm encoder param sample hash matches the source DCP.
  • Confirm iter-1 lm_loss is in the expected range (~12.19 ± noise per Sanjeev's job 202967 — pending the correct_encoder_grad_for_partial_participation PR for full parity).

Notes

  • This PR is part of the NMFW-478 VLM training parity series. Two related PRs to follow: --dataloader-save (energon resumable state) and --correct-encoder-grad-for-partial-participation (projector grad scaling).
  • The branch base is ykarnati/nmfw-464-nemotron-vlm-with-hetero-parallel (which already contains the parity-recipe commit 2f6b2c05a).

🤖 Generated with Claude Code

Mirror Sanjeev's `--load-vision-from` from pre-vlm-05 (`pretrain_vlm_energon.py`)
inside the hetero MIMO training loop so encoder ranks can warm-start RADIO
weights from a Megatron-Bridge DCP (e.g. `post-c-radio-omni`) on the first run.
Without this the encoder is randomly initialized and parity with Sanjeev's
recipe is unreachable: iter-1 lm-loss only matches when the projector sees
trained RADIO features.

Behavior:
* New `--load-vision-from PATH` (plus `--allow-missing-vision-projection-checkpoint`
  and `--radio-force-eval-mode` for args-dump parity) in the `ckpt` group of
  the standalone hetero parser.
* `load_vision_from_checkpoint` in `examples/mimo/training/hetero/checkpointing.py`
  resolves either a flat DCP or a `latest_checkpointed_iteration.txt` +
  `iter_NNNNNNN/` layout, filters keys to `model.vision_model.*`, dcp-loads,
  TP-slices via `topology.vision_pg.tp`, and copies into
  `vision_submodule.encoders["radio_encoder"].radio_model.<rel>` (and
  best-effort into `input_projections[0]`).
* No-op on LLM-only ranks and on ranks outside the encoder grid.
* Called from `loop.py` only when `--load` resolved no checkpoint
  (`--load` stays authoritative on resume).
* `nemotron_moe_vlm.py` honors the explicit `--radio-force-eval-mode` knob
  and falls back to `args.freeze_vit` when unset.
* The 9n parity sbatch un-comments `--load-vision-from "${VISION_CKPT}"`.

Tests (`tests/unit_tests/mimo/test_load_vision_from.py`) cover the
deterministic helpers single-process: `_tp_slice` (column/row parallel +
passthrough), `_resolve_vision_dcp_dir` (flat vs tracker), and a real
mini-DCP round-trip exercising the `model.vision_model.*` prefix filter
plus `dcp.load`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
"force_eval_mode": args.freeze_vit,
"force_eval_mode": (
args.radio_force_eval_mode
if getattr(args, "radio_force_eval_mode", None) is not None
Owner Author

Do we need both args? Are they doing the same thing?

Comment thread examples/mimo/training/hetero/args.py Outdated
),
)
ckpt.add_argument(
"--allow-missing-vision-projection-checkpoint",
Owner Author

Do we need --allow-missing-vision-projection-checkpoint?

Comment thread examples/mimo/training/hetero/args.py Outdated
default=False,
help="Tolerate missing projector keys when loading the vision DCP.",
)
ckpt.add_argument(
Owner Author

Do we need this radio force eval mode arg? Is freeze vit enough?



def load_vision_from_checkpoint(
model: MimoModel,
Owner Author

Do we need to handle this so verbosely, just for checkpoint loading? It's a distributed checkpoint, so why do we need to handle TP slicing etc.? Also, we don't load the projection anyway, just the encoder part?

yashaswikarnati and others added 6 commits May 16, 2026 19:20
…o-load-strict

Addresses review feedback on PR 28:
- Replace bespoke per-tensor _tp_slice + DCP plumbing with the
  sharded_state_dict pattern from examples/mimo/utils/model_helpers.
  Each ShardedTensor carries TP-sharding metadata, so
  dist_checkpointing.load handles per-rank slicing automatically;
  no manual column/row-parallel branching.
- Strict-validate that every parameter the model expects landed in
  the checkpoint (extra_state buffers excluded) via post-load
  load_state_dict + incompatible-keys check. The old loop silently
  skipped tensor mismatches into a 'skipped' counter.
- Drop --allow-missing-vision-projection-checkpoint: vision-from is
  encoder-only, never touches the projector.
- Drop --radio-force-eval-mode: --freeze-vit is sufficient. Revert
  force_eval_mode wiring in nemotron_moe_vlm.py.
- Add --no-load-strict to the main checkpoint load path. Default
  enables StrictHandling.RAISE_ALL on dist_checkpointing.load so any
  missing or unexpected key raises immediately, confirming all params
  reloaded. Falls back to ASSUME_OK_UNEXPECTED when disabled (schema
  drift / partial loads).

Validated on cw-dfw 8-GPU 20L mock:
- Stage1 regression smoke: exit 0, grad norm 0.025 -> 0.013, language
  branch has no LR field (frozen-LM optimizer skipped from PR 27).
- Save + strict-reload round-trip: save iter 3, reload with default
  --load-strict (RAISE_ALL), resume iter 4-5 with cosine LR
  continuation (1.32e-4 -> 1.01e-4). No strict failures: every
  sharded key matched the checkpoint.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…der I/O

Vision Bridge DCPs (e.g. C-RADIO Omni `post-c-radio-omni`) are the
`torch.distributed.checkpoint` format (`.metadata` + `__N_M.distcp`), not
Megatron's `dist_checkpointing`. The previous loader walked the DCP keys
per-tensor, TP-sliced, copied — and silently absorbed any unmatched key
into a "skipped" counter. We retain that per-tensor + TP-slice approach
(verified on nb-hel: the Bridge DCP keys all land under
`model.vision_model.*` and map 387/387 to `RADIOViTModel.named_parameters()`)
but the loader now strict-validates: any key the model expects but the
checkpoint lacks (or vice versa) raises with the offending lists.
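
The strict check described above amounts to a two-way set comparison after the prefix filter. A minimal sketch, with hypothetical helper names (`strip_vision_prefix`, `validate_strict`) and plain key lists standing in for the real DCP metadata and `named_parameters()`:

```python
from typing import Dict, Iterable

VISION_PREFIX = "model.vision_model."


def strip_vision_prefix(
    ckpt_keys: Iterable[str], prefix: str = VISION_PREFIX
) -> Dict[str, str]:
    """Map model-relative name -> checkpoint key, keeping only vision keys."""
    return {k[len(prefix):]: k for k in ckpt_keys if k.startswith(prefix)}


def validate_strict(
    expected_names: Iterable[str], ckpt_keys: Iterable[str]
) -> Dict[str, str]:
    """Raise if the model and checkpoint disagree on the set of vision keys."""
    mapping = strip_vision_prefix(ckpt_keys)
    expected = set(expected_names)
    found = set(mapping)
    missing = sorted(expected - found)      # model expects, checkpoint lacks
    unexpected = sorted(found - expected)   # checkpoint has, model does not
    if missing or unexpected:
        raise KeyError(f"missing={missing} unexpected={unexpected}")
    return mapping
```

Keys outside the prefix (for example `model.language_model.*`) are dropped by the filter and never count as unexpected.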

Other changes vs the previous PR-28 attempt:
- Drop projector loading. `--load-vision-from` is encoder-only; the
  projector trains from scratch in stage1/stage2.
- Drop `--allow-missing-vision-projection-checkpoint`. Unused with
  encoder-only loading.
- Drop `--radio-force-eval-mode`. `--freeze-vit` is sufficient;
  `force_eval_mode` derives from it.
- Scope `dcp.load` to `topology.vision_pg.tp_dp_cp` (the 4-rank encoder
  grid). LLM-only ranks short-circuit out of the function.
- Bump `init_process_group` timeout to 1 hour. Lustre reads on encoder
  ranks can stall LLM ranks for minutes; the default 600 s c10d socket
  timeout caused TCPStore drops in earlier runs.
- Add `--no-load-strict` for the main `load_checkpoint` path. Default
  enables `StrictHandling.RAISE_ALL` so every key the model expects must
  come from the checkpoint, confirming a complete reload.
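
A sketch of a strict-by-default flag with an opt-out, under stated assumptions: `build_ckpt_parser` and `strict_policy_name` are hypothetical, and the policy is returned as a string so the example runs without Megatron installed. The real code would pass the corresponding `StrictHandling` member to the load call.

```python
import argparse


def build_ckpt_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Strict loading is the default; passing --no-load-strict turns it off.
    parser.add_argument(
        "--no-load-strict",
        dest="load_strict",
        action="store_false",
        default=True,
    )
    return parser


def strict_policy_name(load_strict: bool) -> str:
    """Name of the strictness policy to use for the main checkpoint load."""
    return "RAISE_ALL" if load_strict else "ASSUME_OK_UNEXPECTED"
```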

Validated on cw-dfw 8-GPU 20L mock:
- Stage1 regression smoke (no `--load-vision-from`): exit 0, grad norm
  0.025 -> 0.013, language branch correctly has no LR field.
- Strict save+load round-trip: save iter 3, reload with default
  `--load-strict`, resume iter 4-5 with cosine LR continuation. No
  strict failures: every sharded key matched the checkpoint.

Validated on nb-hel 8-GPU 20L mock:
- Strict load against real C-RADIO Omni Bridge DCP completes cleanly:
  `[load-vision-from] ViT loaded (387 tensors, strict)`, all keys
  matched against `model.vision_model.*` prefix.
- Iter-1 training hang observed downstream of the loaded weights is a
  separate issue (reproduces independent of `dcp.load` -- only requires
  the `param.data.copy_` into vision params) and is filed for follow-up;
  the loader itself behaves correctly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the manual `param.data.copy_(...)` loop with
`radio_model.load_state_dict(cleaned, strict=False)` + the standard
`incompatible` validation. Matches the pattern in
`examples/mimo/utils/model_helpers.load_submodule_ckpt`. We still build
the cleaned state_dict ourselves so we can TP-slice each ckpt tensor
to the local rank's parameter shape before handing it to PyTorch.

Diagnostic note: nb-hel testing showed iter-1 training hangs after this
load completes successfully (387/387 strict). The hang reproduces
regardless of whether the load uses `dcp.load`, the manual `copy_`
loop, or this `load_state_dict` path. Bisection isolated the trigger
to writing any value (real or uninitialized) into the radio params
post-DDP-wrap. Filed for separate follow-up; the loader itself is
correct.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Architectural fix attempting to resolve the iter-1 hang observed in
earlier revisions of this PR:

- Move `load_vision_from_checkpoint` from the loop-runtime into
  `build_mimo_runtime` so it runs AFTER `wrap_active_modules_with_ddp`
  but BEFORE the loop's `build_optimizer` call. The distributed
  optimizer's fp32 main-param mirror is then built from the loaded
  bf16 weights, which mirrors the standard megatron `load_checkpoint`
  path (training/training.py:1837) where the optimizer's main_params
  come from the checkpoint, not from random init.
- Add `_full_checkpoint_exists` helper so the warm-start skips when
  `--load` resolves a real full checkpoint (the main load handles
  encoder weights in that case).
- Pin `device_id` in `init_process_group` to silence pytorch's
  "Guessing device ID … can cause a hang if rank to GPU mapping is
  heterogeneous" warning. Our encoder grid (offset 0) and LLM grid
  (offset 4) make this exactly the case pytorch warns about.

Note: nb-hel testing still shows an init-time hang on the very first
post-load NCCL collective. The hang reproduces regardless of where the
load runs in the lifecycle (pre-DDP-wrap, between DDP-wrap and optim,
or post-optim) and regardless of which mechanism writes the weights
(direct `param.data.copy_`, `radio_model.load_state_dict`, or even
uninitialized memory). The trigger is bisected to writing any value
into encoder params on encoder ranks — the symptom moves from "iter 1
hangs" to "init hangs" as we load earlier in the lifecycle. This is a
real bug worth a separate investigation; the loader code itself is
correct and strict-validates 387/387 keys from the C-RADIO Omni DCP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Match megatron's `training.py:1826` post-modification path: after the
vision DCP load mutates encoder params in place, call
`optimizer.reload_model_params()` so the distributed optimizer's fp32
main-param mirror is rebuilt from the just-loaded bf16 weights.

Without this, the fp32 mirror stays at random init while the bf16 view
holds the loaded weights — the first optimizer step would silently
overwrite our load with the stale fp32, and (on nb-hel) the mismatch
deadlocked the very first post-init NCCL collective. With
`reload_model_params()` the run gets past init and into the training
loop. (A separate iter-1 hang remains and is being tracked.)
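
The stale-mirror failure mode can be illustrated with a toy model. This is not Megatron's distributed optimizer; `ToyMainParamOptimizer` is a hypothetical stand-in that uses Python lists to show why an in-place load needs a mirror refresh before the first step.

```python
from typing import List


class ToyMainParamOptimizer:
    """Toy optimizer that keeps its own 'main' copy of the model weights."""

    def __init__(self, model_params: List[float]) -> None:
        self.model_params = model_params        # shared with the "model"
        self.main_params = list(model_params)   # mirror captured at build time

    def reload_model_params(self) -> None:
        # Rebuild the mirror from whatever the model currently holds.
        self.main_params = list(self.model_params)

    def step(self, grads: List[float], lr: float) -> None:
        # Update the mirror, then write it back over the model weights.
        self.main_params = [p - lr * g for p, g in zip(self.main_params, grads)]
        self.model_params[:] = self.main_params


# Without a reload, the first step overwrites the loaded weights.
model = [0.5, -0.5]                 # "random init"
opt = ToyMainParamOptimizer(model)
model[:] = [1.0, 2.0]               # in-place "checkpoint load"
opt.step([0.0, 0.0], lr=0.1)
stale_result = list(model)

# With a reload, the loaded weights survive the step.
model2 = [0.5, -0.5]
opt2 = ToyMainParamOptimizer(model2)
model2[:] = [1.0, 2.0]
opt2.reload_model_params()
opt2.step([0.0, 0.0], lr=0.1)
fresh_result = list(model2)
```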

Revert the load-before-DDP-wrap relocation from the previous commit:
in-place mutation after DDP-wrap is the supported pattern (it is what
megatron's `load_checkpoint` does at `training.py:1837`); the prior
revision added complexity without fixing anything.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Enable `faulthandler.dump_traceback_later(120, repeat=True)` in
`train_hetero.py` so every rank dumps its python stack every 2 minutes.
Output goes to per-rank stderr, picked up by cog/slurm log capture
without any additional plumbing.

This was load-bearing for diagnosing the `--load-vision-from` iter-1
hang on nb-hel: dumps showed the encoder is stuck in
`radio.py:213 self.embedder(x)` → TP `all_gather_into_tensor`, and the
LLM is stuck in `BridgeCommunicator.recv_forward`. The hang is at the
first column-parallel `all_gather` inside the radio encoder, not in
the cross-grid pipeline.
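
A minimal sketch of the periodic stack-dump setup using the standard-library `faulthandler` module. The wrapper names are hypothetical; only `dump_traceback_later` and `cancel_dump_traceback_later` come from the library.

```python
import faulthandler
import sys


def enable_periodic_stack_dumps(interval_s: float = 120, file=None) -> None:
    """Dump every thread's Python stack each interval_s seconds until cancelled."""
    target = sys.stderr if file is None else file
    # repeat=True re-arms the watchdog after every dump; the file must have
    # a real file descriptor because the dump is written at the C level.
    faulthandler.dump_traceback_later(interval_s, repeat=True, file=target)


def disable_periodic_stack_dumps() -> None:
    faulthandler.cancel_dump_traceback_later()
```

Because the dump is written by a watchdog thread outside the interpreter loop, it still fires when the main thread is blocked inside a collective.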

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>