NMFW-464 Nemotron VLM hetero mock training#16
Conversation
| has_bos=True, | ||
| has_system_role=True, | ||
| ) | ||
| elif prompt_format == "nemotron6-moe": |
There was a problem hiding this comment.
did we need this ? was this there in the other branch ?
|
|
||
| return cls(**init_dict) | ||
|
|
||
| @classmethod |
There was a problem hiding this comment.
why did we need this function now? we just added expert groups ? how was this working without expert groups before?
| default=MOCK_MODEL_PRESET, | ||
| help="Model config preset. The Nemotron preset matches the 20L reference script.", | ||
| ) | ||
| model.add_argument("--hidden-size", type=int, default=128) |
There was a problem hiding this comment.
do we need all these hidden size, num attention heads etc as part of args? the model provider covers this anyway? and we can just specify the moel provider or model preset, i'm not sure if model preset is a right word, should we call it model provider ?
| model.add_argument( | ||
| "--freeze-projection", action="store_true", help="Freeze vision projection params" | ||
| ) | ||
| model.add_argument( |
There was a problem hiding this comment.
also these are specific to nemotron vlm. i'm not sure if training loop args should be seperate from model specific ones and if we can group model specific ones separately so the the training loop is generic across models
| return parser.parse_args() | ||
|
|
||
|
|
||
| def apply_model_preset(args: argparse.Namespace) -> None: |
There was a problem hiding this comment.
do we need this function ?
| args.image_seq_length = NEMOTRON_20L_IMAGE_SEQ_PER_TILE * args.num_image_tiles | ||
|
|
||
|
|
||
| def apply_training_stage(args: argparse.Namespace) -> None: |
There was a problem hiding this comment.
this should belong to nemtotron vlm provider ?
| args.training_stage = stage | ||
|
|
||
|
|
||
| def resolve_image_token_id(args: argparse.Namespace) -> None: |
There was a problem hiding this comment.
this is also nemtron vlm specific? also the 20l ones is one specifc one, we will add multiple variants later with different num layers. also i'm not sure if we need is_nemtron_20l all over the place ?
| from megatron.core import parallel_state | ||
|
|
||
|
|
||
| def clear_transformer_engine_env() -> None: |
There was a problem hiding this comment.
we dont need this
| def initialize_distributed() -> None: | ||
| """Initialize torch.distributed for torchrun.""" | ||
| clear_transformer_engine_env() | ||
| os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") |
There was a problem hiding this comment.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") is this required ?
| elif num_tokens is not None: | ||
| loss_acc[1] += float(num_tokens) | ||
|
|
||
| dist.all_reduce(loss_acc, op=dist.ReduceOp.SUM, group=language_pg.tp_dp_cp) |
There was a problem hiding this comment.
why is this tp_dp_cp? does this match existing Megatron lm train loop?
| ) | ||
|
|
||
| debug_rank("creating MIMO optimizer stats group") | ||
| optimizer_stats_group = dist.new_group(ranks=list(range(world_size)), backend="nccl") |
There was a problem hiding this comment.
why do we need, optimizer stats group? also this seems to be whole world ? doesnt mimo optimizer handle this already?
| NEMOTRON_VISION_ENCODER_KEY = "radio_encoder" | ||
|
|
||
|
|
||
| def is_nemotron_20l(args) -> bool: |
There was a problem hiding this comment.
do we need this ?
| loss = fused_vocab_parallel_cross_entropy(logits, labels, self.pg_collection.tp) | ||
| else: | ||
| loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels) | ||
| loss = tensor_parallel.vocab_parallel_cross_entropy( |
There was a problem hiding this comment.
lets not use unfused loss, so lets not make this change to decrease the surface area of changes
| if packing_kwargs is not None: | ||
| for key in packing_kwargs: | ||
| if 'cu_seqlens' in key and packing_kwargs[key] is not None: | ||
| packing_kwargs[key] = packing_kwargs[key].to(dtype=torch.int32) |
There was a problem hiding this comment.
lets not include convert to int 32 or setting format here, this should be concern of data loader
| # Key output for non-last stages so schedule can route to next LM stage | ||
| if not self.role.is_last_stage(lang_name): | ||
| return {lang_name: lm_output} | ||
| return {lang_name: lm_output}, loss_mask |
There was a problem hiding this comment.
do we really need to return loss mask here ?
|
|
||
| if self.cfg.seq_parallel and embeddings is not None: | ||
| embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) | ||
| # GPT/Hybrid output layers gather sequence-parallel hidden states |
There was a problem hiding this comment.
i dont think this is correct, do we need this ? this was working before ?
There was a problem hiding this comment.
lets double check
| opt.zero_grad(set_to_none) | ||
|
|
||
| def get_loss_scale(self) -> torch.Tensor: | ||
| """Return the loss scale from the first active optimizer, or one for stubs.""" |
There was a problem hiding this comment.
dont add new doc strings unless asked
| self, | ||
| module_infos: Dict[str, ModuleOptimizerInfo], | ||
| config: OptimizerConfig, | ||
| stats_group: Optional[torch.distributed.ProcessGroup] = None, |
There was a problem hiding this comment.
do we really need this stats group?
| grid.create_pg(["tp", "ep", "pp"]) | ||
| grid.create_pg(["dp", "ep"]) | ||
| grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) | ||
| grid.create_pg("tp_ep_pp") |
There was a problem hiding this comment.
why tp_ep_pp insteaf of tp, ep, pp ?
| gathered_params_key = [ | ||
| None for _ in range(torch.distributed.get_world_size(group=param_group_sync_group)) | ||
| ] | ||
| torch.distributed.all_gather_object( |
There was a problem hiding this comment.
why do we need param group sync group ?
| shard_factor = None | ||
| seq_dim = None # which dimension holds the token sequence | ||
|
|
||
| # MimoModel.forward() passes embeddings in batch-first layout |
There was a problem hiding this comment.
is this comment correct ? at which part of forward is it B,S,H ?
| """Annotate flat modality outputs with per-sample split sizes for bridge fan-out.""" | ||
| if ( | ||
| not isinstance(output, torch.Tensor) | ||
| or output.ndim != 2 |
There was a problem hiding this comment.
why do we need so many checks here ? with or ?
| # Non-first stage: receive hidden states from previous LM stage | ||
| hidden_states = input_tensors.get(lang_name) if input_tensors else None | ||
|
|
||
| if self.partition_adapter is not None: |
There was a problem hiding this comment.
if self.partition_adapter is not None looks like this logic is replicated at multiple places? what are these multiple places?
| input_ids: Input token IDs. Shape: (B, S) | ||
| position_ids: Position IDs. Shape: (B, S) | ||
| attention_mask: Attention mask. Shape: (B, S) | ||
| attention_mask: Accepted for API compatibility. This path currently relies on |
There was a problem hiding this comment.
we dont have to change this doc string
| """Return an empty projected activation for text-only non-colocated batches.""" | ||
| hidden_size = self.config.hidden_size | ||
| param = next(submodule.parameters(), None) | ||
| if param is not None: |
There was a problem hiding this comment.
why do we have so much if else here, do we need all these ? also when exactly is _empty_modality_output called? in non colocated what if submodule is on differen rank?
| raise RuntimeError( | ||
| f"{encoder_name} inputs are missing, but matching special tokens exist" | ||
| ) | ||
| output = self._empty_modality_output(submodule, input_ids) |
There was a problem hiding this comment.
does this empty output works through bridge communicator ?
| decoder_input=combined_embeddings, | ||
| labels=labels, | ||
| attention_mask=attention_mask, | ||
| attention_mask=None, |
There was a problem hiding this comment.
we dont have to change this ? just pass attention_mask=attention_mask ?
| decoder_input=None, | ||
| labels=labels, | ||
| attention_mask=attention_mask, | ||
| attention_mask=None, |
There was a problem hiding this comment.
attention_mask=attention_mask ?
* Add 54L Nemotron MoE VLM provider * Address 54L provider review comments * Inline static Nemotron provider config * Hardcode Nemotron language architecture config
…ng loop (#26) * Add distributed-checkpoint save/load to the hetero MIMO training loop Adds the standalone `examples/mimo/training/hetero/checkpointing.py` module plus the CLI surface and loop wiring needed to round-trip MimoModel, MimoOptimizer (ChainedOptimizer-of-DistributedOptimizers in the MoE recipe) and the LR/WD scheduler through `megatron.core.dist_checkpointing` without depending on the `parallel_state` singleton. Layout stays compatible with `megatron/training/checkpointing.py` output: `<save>/latest_checkpointed_iteration.txt` plus per-iteration directories containing `common.pt`, `metadata.json`, `.metadata`, and torch_dist shards. Common state now carries `args`, `checkpoint_version=3.0`, the LR scheduler state, and a per-branch `mimo.{branch}.rng_state` ShardedObject; the tracker read uses a cross-rank MAX reduce to mirror megatron's `read_metadata`. Fixes three pre-existing dist-ckpt bugs that hetero usage uncovered: - `megatron/core/ssm/mamba_mixer.py` was calling `make_sharded_tensors_for_checkpoint` without passing `tp_group` and `dp_cp_group`, which fell back to the parallel_state singleton and asserted in hetero mode (gated_delta_net was already correct). - `MimoOptimizer.sharded_state_dict` now applies `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` to each per-branch optimizer sub-dict so two modules' identical internal ShardedObject keys (e.g. `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide. - `_get_replica_id` now folds in `tp_rank` so two TP ranks within DP=0 don't both claim primary writer for the same shard. Also routes DistributedOptimizer's per-module `param_state_sharding_type` config string through a new ShardedObject (`_extract_*` helpers) so the non-rank-0 module owner doesn't lose it when only rank 0's common.pt is authoritative. A `_propagate_tp_groups_for_checkpoint` walker stamps `self.tp_group` on descendants that omit it (e.g. `ExtendedRMSNorm`, RADIO submodules) so the default `MegatronModule.sharded_state_dict` path doesn't fall through to `parallel_state.get_tensor_model_parallel_group`. Validated end-to-end on cw-dfw 8-GPU 20L mock (stage2): - Save iter 3 (DistributedOptimizer + EP=4 + TP=2 + 2-module Chained) - Reload iter 3 → resume at iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4), losses match prior trajectory. New flags: `--save`, `--load`, `--save-interval`, `--no-save-optim`, `--no-load-optim`, `--no-load-scheduler`, `--no-save-rng`, `--no-load-rng`, `--finetune`, `--dist-ckpt-optim-fully-reshardable`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Adopt MimoOptimizer checkpoint patterns from NVIDIA#4801 Three convergent simplifications to MimoOptimizer's distributed-checkpoint path, matching Kamran's MimoOptimizer fixes PR: 1. Replace the ShardedObject-based round-trip for `param_state_sharding_type` with a metadata stash. The sharding type is not per-rank state — it's a load-time interpretation hint that the caller supplies via the `metadata` kwarg on `sharded_state_dict()`. We stash that metadata in `self._last_sharded_metadata` at save and re-inject the sharding type into each per-module sub state-dict during `load_state_dict()` for ranks that lost it via dist_checkpointing's common-state path (i.e. non-rank-0 module owners in non-colocated layouts). Drops `_extract_param_state_sharding_type` / `_restore_param_state_sharding_type` along with their ShardedObject keys. 2. `_restore_param_groups` now uses `setdefault('optimizer', {})` before writing back `param_groups`. After `_extract_param_groups` deletes `param_groups` at save time, the leftover empty `'optimizer'` dict can be dropped by the common-state round-trip on ranks whose active module wasn't on rank 0 at save. The setdefault makes the restore path tolerant of that drop. 3. `_get_replica_id` reorders to `(tp_rank, pp_rank, dp_rank)` to match the convention used by `make_sharded_object_for_checkpoint` in `megatron/core/transformer/utils.py:168-172`. Dedup math is unchanged — `(0, 0, 0)` is still the primary replica — but the order is now consistent with the rest of the codebase. Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2): save iter 3, reload, resume iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4) and matching loss trajectory. Save exit 0, load exit 0. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Document _propagate_tp_groups_for_checkpoint escapees from no-walker run Disabled `_propagate_tp_groups_for_checkpoint` and re-ran the 20L mock to enumerate exactly which modules fall through to `parallel_state.get_tensor_model_parallel_group()` and assert. Confirmed both branches escape: - RADIO encoder internals (first failure, reached via `nemotron_moe_vlm.RadioEncoder.sharded_state_dict` → HF radio_model leaves with no tp_group + no own sharded_state_dict). - `MambaLayer.__init__` in `megatron/core/ssm/mamba_layer.py` plumbs pg_collection to the mixer but never sets `self.tp_group`. - `ExtendedRMSNorm` at `megatron/core/ssm/mamba_mixer.py:93` never sees pg_collection at all. Fixing each at the source would mean patches across core (Mamba) plus a partial walk of RADIO's HF wrapper, validated against all existing non-hetero users of those modules. The walker is the smaller intervention: one place, hasattr-guarded, applied per branch with the correct pg. Re-enables the walker (it was already in PR1; this commit only updates the docstring to record the experiment's findings). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Switch back to ShardedObject for param_state_sharding_type (NVIDIA#4791) Kamran reverted the metadata-stash approach in NVIDIA#4801 (discussion r3250847203) and adopted Li Ding's PR NVIDIA#4791 pattern, which is the same ShardedObject round-trip we had originally. Align our MimoOptimizer with that final shape: - Restore `_extract_param_state_sharding_type` / `_restore_param_state_sharding_type` helpers. Hooks back into the existing `_iter_optimizer_sub_dicts` loop. - Add `if not opt_sub: del sub_sd['optimizer']` to `_extract_param_groups` (from NVIDIA#4791) so the now-empty `'optimizer'` wrapper doesn't round-trip through common-state with undefined behavior on the load side. - Drop `self._last_sharded_metadata` and the metadata-stash recover path from `load_state_dict` / `sharded_state_dict`. The ShardedObject route is self-contained and doesn't need caller-state coupling. Kept (not in NVIDIA#4791, specific to our non-colocated hetero layout): - `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` so the two branches' identical inner ShardedObject keys (e.g. `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide. - `_get_replica_id` returning `(tp_rank, pp_rank, dp_rank)` (from NVIDIA#4801). Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2): save iter 3 exit 0, reload + resume iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4), matching loss trajectory across the boundary. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Drop _stamp_tp_group walker; fix the three constructors at the source Three modules under our hetero save path don't store `self.tp_group` in their constructors and therefore trip `MegatronModule.sharded_state_dict`'s parallel_state fallback (`megatron/core/transformer/module.py:85`) in heterogeneous-parallelism layouts where parallel_state is intentionally not initialized. Fix them at the source instead of papering over with the hasattr-guarded walker: - `megatron/core/models/vision/radio.py:RADIOViTModel.__init__` — already extracts `tp_group` at line 129 for the embedder; now also stamps `self.tp_group = tp_group`. - `megatron/core/ssm/mamba_layer.py:MambaLayer.__init__` — takes pg_collection and plumbs it into the mixer; now also stores `self.tp_group = pg_collection.tp` on the layer itself. - `megatron/core/ssm/mamba_mixer.py:ExtendedRMSNorm` — adds an `__init__(*args, tp_group=None, **kwargs)` override that stores `self.tp_group` eagerly, and updates the single call site at line ~369 to pass `tp_group=self.pg_collection.tp`. The lazy `hasattr` fallback inside `sharded_state_dict` is preserved for callers that don't pass tp_group. With these three constructor fixes in place, the `_propagate_tp_groups_for_checkpoint` walker (and `_stamp_tp_group` helper) in `examples/mimo/training/hetero/runtime.py` is no longer needed. Removed entirely. Validated on cw-dfw 1-node 8-GPU 20L mock with the walker disabled: - save iter 3 exit 0 (DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2) - reload iter 3 → resume iter 4-5 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4), exit 0 - losses match prior runs (iter 1: 12.187, iter 2: 12.190, iter 3: 12.177, resume iter 4: 11.817, iter 5: 11.264) The downstream check `if not hasattr(self, 'tp_group')` in subsequent descendants (TransformerBlock, TransformerLayer, Attention, MLP, ColumnParallelLinear) was already satisfied by their own constructors; verified by reading those files. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t_param_groups all-gather (#27) Two convergent fixes for MIMO + frozen modules under non-colocated parallelism, cherry-picking the pattern from NVIDIA#4790 (Li Ding, "Fix MIMO optimizer setup for frozen modules") with comments updated to cite the source and link to NMFW-464 context. ## Problem 1 — placeholder optimizer for all-frozen modules `get_mimo_optimizer` previously called `get_megatron_optimizer` on every non-None module, including modules whose params are all frozen on every rank in the module's group. The most common trigger is `--training-stage stage1` (the LLaVA projector-only recipe), which sets `--freeze-vit --freeze-lm` and leaves the language model with zero trainable parameters on the LLM ranks. The resulting placeholder DistributedOptimizer either crashes in downstream setup or behaves silently incorrectly (e.g., LR scheduler advancing an empty group, save of a degenerate optimizer state). ## Problem 2 — `_get_param_groups` all-gathers over WORLD `_get_param_groups` reconciles param-group keys across ranks of the same model via `all_gather_object(params_key, …)` over the global default group. In non-colocated MIMO, encoder ranks and LLM ranks are disjoint and own different params (RADIO vs Mamba), so a WORLD-group all-gather pollutes both branches with the other branch's keys. ## Fixes - `megatron/core/models/mimo/optimizer.py`: add `_module_has_any_trainable_parameters(module, pg_collection)` — an all-reduce-MAX over `pg_collection.intra_dist_opt` of the local trainable-param count. Gate the `get_megatron_optimizer` call on it. When false, leave `info.optimizer = None` so `MimoOptimizer.is_stub_optimizer` handles the branch (that path was already designed for this; the gating never triggered it before). - `megatron/core/optimizer/__init__.py`: add an optional `process_group=None` kwarg to `_get_param_groups` / `_get_param_groups_and_buffers`, plumbed through `_get_megatron_emerging_optimizer` and `get_megatron_optimizer`, so the cross-rank `all_gather_object` can target a specific group. MIMO passes `pg_collection.intra_dist_opt`. Default `None` preserves the current WORLD-group behavior for every existing non-MIMO caller. ## Validation cw-dfw 1-node 8-GPU 20L mock, stage1 (vit + lm frozen, projector-only): - 3 iters, exit 0 - grad norm 0.025 / 0.020 / 0.016 (consistent with only projector params having live grads) - The `learning rate: ...` field in the iteration log is now absent — the previous-PR-era smoke printed it, because the placeholder language optimizer was being queried for an LR; with this fix the language optimizer is correctly `None` and the logger has nothing to query. That's the visible signature of the fix landing. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reproduces Sanjeev's `examples/multimodal/v3/pretrain_3b_nano_vlm_sota_90t_10v.sh`
(`sasatheesh/megatron-lm!45`, job 202967) on 9 HEL nodes with our hetero MIMO
training loop, modulo three items deferred to follow-up PRs (load_vision_from,
correct_encoder_grad_for_partial_participation, dataloader_save) and MTP.
Layout
- LLM grid: TP=2, EP=16, DP=32 (encoder grid unchanged at TP=1, DP=8, EP=1).
- 9n sbatch wrapper bumps `LLM_TP 4 -> 2`, `LLM_DP 16 -> 32`, `NUM_WORKERS 0 -> 2`.
Model provider (`nemotron_moe_vlm.py`)
- `moe_aux_loss_coeff: 1e-9 -> 1e-4` for non-trivial router pressure.
- `bias_dropout_fusion: False -> True` for args-dump parity (inert under
add_bias_linear=False but matches Sanjeev).
- Make `mamba_num_groups=8`, `mamba_state_dim=128`, `linear_conv_kernel_dim=4`
explicit in `nemotron_language_config` (mcore defaults already match; we
declare to be defensive).
- Pass `share_embeddings_and_output_weights=False` to `MambaModel` for
explicit `untie_embeddings_and_output_weights=True` parity.
Vision encoder / data path
- Add `--dynamic-resolution / --no-dynamic-resolution`,
`--dynamic-resolution-min-patches`, `--dynamic-resolution-max-patches` CLI
flags. Default `dynamic_resolution=True` for `nemotron-moe-vlm-*` providers.
- Pin `args.use_thumbnail=False` and `args.use_tiling=False` under dynamic
resolution so `DynamicResolutionImageTilingStrategy` does not emit an extra
thumbnail tile (Sanjeev's run has both False).
- Thread `dynamic_resolution_min/max_patches` and `dynamic_resolution_min/
max_side` through `VisionConfig` in `energon_multimodal_provider.py`.
Optimizer + WSD scheduler
- Add `--train-samples`, `--lr-warmup-samples`, `--lr-decay-samples`,
`--lr-wsd-decay-samples`, `--lr-wsd-decay-style`. Extend `--lr-decay-style`
choices to include `WSD`.
- `validate_args` derives `train_iters = ceil(train_samples / gbs)` and
enforces WSD requires both wsd_decay_samples and wsd_decay_style.
- `build_optimizer_param_scheduler` honors the sample-based knobs (taking
precedence over iter-based) and passes `wsd_decay_steps` /
`lr_wsd_decay_style` to `OptimizerParamScheduler`.
- `run_hetero` lifts `LR/MIN_LR/WEIGHT_DECAY/LR_DECAY_STYLE` to env-var
defaults and threads optional sample-based knobs through.
DDP
- Add `--overlap-param-gather`, `--ddp-num-buckets`, and
`--ddp-pad-buckets-for-high-nccl-busbw` CLI flags.
- `runtime._resolve_bucket_size` derives DDP bucket_size from
num_parameters // num_buckets when `--ddp-num-buckets` is set, else honors
`--ddp-bucket-size`, else returns None for mcore auto-default.
- Vision DDP keeps both overlap knobs forced OFF (partial-participation
safety: text-only batches leave some encoder DP ranks with zero grads).
Self-contained parity sbatch
- `sbatch_hetero_nemotron_54l_hel_9n_parity.sh`: every training value pinned
inline (no `${VAR:-default}` fallbacks). GBS is canonical and
NUM_MICROBATCHES is derived as `GBS / (MBS * LLM_DP) = 24`.
- Uses `megatron-venv-baked-206674.sqsh`; staged tokenizer + post-c-radio-omni
encoder under `agents-scratch`.
- Sanjeev-faithful values: lr=1.2e-3, min_lr=1.2e-5, wd=0.1, WSD with
minus_sqrt tail, warmup 1.024M samples, decay 35.6M samples, wsd-decay
5.49M samples, train_samples 36.62M (~300B tokens, 47684 iters at gbs=768),
log_interval=100, save_interval=1000, seed=1234, class_token_len=10,
image_tag_type=internvl, max_num_tiles=1, overlap-grad-reduce,
overlap-param-gather, ddp-num-buckets=8, ddp-pad-buckets.
- `--load-vision-from` commented out pending PR_load_vision_from.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add Nemotron-format ckpt loader for hetero MIMO
- Add load_nemotron_vlm_ckpt_hetero and --load-nemotron-checkpoint flag for
loading pre-vlm-05 Nemotron-format VLM dist-ckpts from the hetero pipeline.
- After the custom load, refresh the DistributedOptimizer's FP32 main-param
shards via optimizer.reload_model_params(); the standard load_checkpoint
path does this automatically, but the custom loader bypasses it. Without
this the optimizer steps with the model-provider init weights instead of
the loaded ckpt weights.
- Seed python random + numpy + torch in the hetero entry to match
Megatron's _set_random_seed (energon's text_packing shuffle uses the
global random module).
- Add --correct-encoder-grad-for-partial-participation flag (consumed in a
follow-up commit by grad_sync.py).
- Add --train-samples flag (samples-based budget; --train-iters is derived).
- Fix RADIO pos-embedding bilinear interpolation to align_corners=False to
match upstream RADIO.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Hetero correctness fixes: fp32 grad reduce + encoder grad participation
- runtime.py: set grad_reduce_in_fp32=True on both language and vision DDP
configs (mirrors --accumulate-allreduce-grads-in-fp32). The default False
produces bf16 main_grad, which drifts step-2 weights after Adam.
- grad_sync.py: when only some encoder DP ranks process images in a step,
scale vision grads post-DP-reduce by encoder_dp_size / participation_count.
Without this the vision encoder learns at a diluted rate.
- nemotron_moe_vlm.py: set moe_router_fusion=False to match the
TransformerConfig default. The fused softmax/topk kernel takes a different
bf16 reduction path, slightly perturbing router probs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Hetero data pipeline parity
- energon_multimodal_provider.py: forward HF tokenizer's chat_template and
apply_chat_template through TokenizerAdapter so energon's
tokenize_and_prepare can find them. Add _supported_kwargs filter on
VisionConfig so the same recipe args work across energon versions whose
VisionConfig accepts different kwargs.
- hetero_energon.py: use get_savable_loader (SavableDatasetWrapper sets a
worker_id_offset and per-worker init that affects step-0 sample order).
Use the unsalted seed in the single-lane iterator; energon's
WorkerConfig(rank=lane, world_size=llm_dp) already salts per-rank, so
adding a +lane offset over-salts and de-aligns sample ordering.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Parity sbatch scripts + broaden modelopt import guard
- Add sbatch_hetero_parity_100step.sh and sbatch_sanjeev_parity_100step.sh
as paired 150-step train-loss parity drivers (hetero MIMO vs reference
recipe). Settings match the Sanjeev-202967 hetero arg-parity table.
- training.py: broaden the modelopt distill-plugin import guard from
ImportError to Exception. Some container builds ship modelopt against a
transformers version that removed transformers.modeling_utils.Conv1D, so
importing the distill plugin raises AttributeError during module init.
Distill is optional, so skip safely.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Address PR review comments
- Revert training.py modelopt import-guard change (not needed for hetero
correctness; container-specific compat issue tracked elsewhere).
- Trim verbose docstrings in model_helpers.py (_load_submodule_from_ckpt,
load_nemotron_vlm_ckpt_hetero) and args.py --load-nemotron-checkpoint help.
- hetero_energon.py: drop the WorkerConfig-salting comment and the over-
defensive try/except wrapping get_savable_loader. Match pre-vlm-05 exactly:
one direct call with cache_pool=NoCachePool() and the watchdog kwargs.
- grad_sync.py: replace expensive (buffer.grad_data != 0).any() scans
with a one-bool participation flag set by forward_step from batch.images.
Combine the per-token normalization and the partial-participation
correction into a single scale_gradients call per submodule (was two
separate kernel launches before).
- step.py: call mark_modality_participation in forward_step and
reset_modality_participation at the top of each train_step.
- loop.py: extract the --load-nemotron-checkpoint branch into
load_and_refresh_nemotron_checkpoint helper in model_helpers.py.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Point sanj parity sbatch at clean pre-vlm-05 clone
SANJEEV_REPO now defaults to ${SCRATCH_ROOT}/sanjeev-repos/megatron-lm-clean,
a fresh checkout of sasatheesh/pre-vlm-05 with only the two correctness
changes needed for the parity baseline (model.py calculate_per_token_loss
honors --calculate-per-token-loss; recipe sh passes the flag). All NMFW
debug instrumentation that lived on the old sanjeev-repos/megatron-lm
checkout is dropped from this baseline.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* RADIO: set mtp_num_layers=0 so final layernorm applies
With post_process=False on the RADIO TransformerBlock and
mtp_num_layers defaulting to None, has_final_layernorm_in_this_stage
short-circuits to False and the final layernorm is dropped. The sanj
recipe passes --mtp-num-layers 0, which takes the alternate branch and
keeps the layernorm on the stage that holds layer.num_layers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* RADIO: pass ln_post_impl=TENorm to RADIOViTModel
The RADIO post-encoder layernorm lives on RADIOViTModel as self.ln_post,
applied after the decoder in radio.py:239-240. The ckpt converter writes
RADIO's inner.norm.{weight,bias} into ln_post.{weight,bias}; without
ln_post_impl=TENorm the wrapper leaves self.ln_post=None and the loader
silently drops the ckpt entries (StrictHandling.LOG_UNEXPECTED), so the
vision tower output is the raw decoder hidden state instead of its
post-norm. Sanj-side llava_model.py:309-311 sets the same impl.
Revert previous mtp_num_layers=0 attempt -- that was targeting
TransformerBlock.final_layernorm, which is a different module and not
where the ckpt weight lands.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Revert RADIO layernorm experiments
Both attempts targeted layernorms that aren't actually missing for this
RADIO variant:
- mtp_num_layers=0 (0b5996c) targeted TransformerBlock.final_layernorm,
but that gating returns False here for unrelated reasons.
- ln_post_impl=TENorm (82ce396) builds RADIOViTModel.ln_post, but the
sanj iter_1000 ckpt has no ln_post.* keys -- llava_model.py only sets
ln_post_impl=TENorm for vision_model_type=='radio-g', not for the
cradio variant we run.
Confirmed by 266840's load failure: model requested vision_model.ln_post.*
but dist_checkpointing flagged them as unexpected (not found in ckpt).
The actual sanj parity baseline (266771, num_layers=54) still sits at
iter1=2.898, hetero 266780 at iter1=2.757 -- same pattern as prior runs,
no missing-norm regression.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…der rank (#31) Replace the per-LLM-lane Energon iterators owned by each encoder DP rank with a single multiplexed iterator whose worker pool is sized ``args.num_workers * lanes_per_encoder``. Samples are routed back to their owning LLM lane using the producing worker's ``WorkerConfig.global_worker_id``, which the MIMO multimodal encoder now stamps onto every batch when ``attach_provenance=True``. At scale (encoder_dp small relative to llm_dp), the per-lane construction path issues ``lanes_per_encoder × num_workers`` shard-open events at iterator creation; collapsing to one iterator per encoder rank cuts the open-burst by ``lanes_per_encoder``× and avoids the previous workaround scripts that staggered loader construction across encoder ranks. The reshape preserves bit-wise sample parity with the per-lane path: - ``global_workers = world_size * num_workers`` is invariant under ``(world_size, num_workers) → (world_size/k, num_workers*k)``. ``WebdatasetSharder.split_samples_to_workers`` partitions shards by global worker index over ``global_workers``, so equal global_worker_ids ⇒ equal shards in equal order. - ``WorkerConfig.worker_seed`` hashes only ``(global_worker_id, seed_offset)`` (see ``megatron/energon/worker.py``); ``seed_offset`` is unchanged. - The routed-iterator's worker W on encoder rank E has ``global_worker_id = E * (num_workers * lanes_per_encoder) + W``, which equals the per-lane worker w on lane L=E*lanes_per_encoder + W//num_workers, w = W%num_workers. ``test_hetero_energon.py`` adds unit tests for ``_route_samples_to_lanes`` (round-robin fill, surplus FIFO, lane-offset shift, pull-budget overflow, out-of-range worker id, missing provenance) plus an algebraic parity test asserting the global_worker_id equivalence above and a global_workers-invariant sweep across (encoder_dp, llm_dp, num_workers) shapes. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…o MIMO VLM (#32) 1) Dynamic-resolution RADIO ViT Patchify each image at its native aspect ratio with a token budget, interleave per-tile class tokens, run THD-packed attention through the transformer block. Line-for-line equivalent to sanj's pre-vlm-05 RADIOViTModel.forward dynres branch. - megatron/core/models/vision/radio.py: add dynamic_resolution kwarg, imgs_sizes + packed_seq_params forward kwargs, per-tile apply_pos_enc, interleaved CLS-token concat, packed_seq_params.cu_seqlens shift. - examples/mimo/data/energon_multimodal_provider.py: per-image n_tokens under dynamic_resolution; surface imgs_sizes + PackedSeqParams. - examples/mimo/model_providers/nemotron_moe_vlm.py: RADIOEncoderWrapper accepts dynres kwargs; interleaved-CLS removal mask; per-tile pixel-shuffle. - examples/mimo/training/hetero/step.py: PackedSeqParams cu_seqlens / max_seqlen tensors moved to CUDA before the encoder forward (TE THD attention hangs on H2D-sync of these tensors otherwise). 2) RADIO encoder final_layernorm parity with sanj Sanj's --mtp-num-layers 0 propagates to vision_config.mtp_num_layers via core_transformer_config_from_args. In TransformerBlock.has_final_layernorm_in_this_stage, the else-branch (mtp_num_layers is not None) allocates final_layernorm when the last decoder layer is in this stage — independent of post_process. Without this, hetero's frozen RADIO output magnitude was ~150x larger than sanj's (5344 vs 35.75) because the final LN never ran. The downstream projection + LLM were trained against the LN-normalized magnitude (gamma=1 / bias=0 from ckpt). - examples/mimo/model_providers/nemotron_moe_vlm.py:radio_vision_config sets config.mtp_num_layers = 0. - examples/mimo/utils/model_helpers.py:load_nemotron_vlm_ckpt_hetero switches dist_checkpointing.load to StrictHandling.RETURN_ALL and raises on any non-extra_state key the model requests but the ckpt does not have (mcore's "unexpected" set, opposite of PyTorch's "missing"). Weaker modes silently keep random-init values and masked this bug. Verified at iter 1, GBS=8 (3-node hetero + 2-node sanj parity pair): vision_proj hetero absmax=584 sanj absmax=584 cos=1.000010 combined_embeddings cos=1.000011 iter-1 lm_loss hetero=2.756 sanj=2.750 Δ=+0.22% iter-2/3 lm_loss rel Δ < 1.1% Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- args.py: new --tensorboard-dir flag.
- logging.py: HeteroTrainingLogger creates a torch.utils.tensorboard
SummaryWriter on the language logging rank when --tensorboard-dir is set.
Emits the same scalar keys Megatron's standard training_log uses (lm loss,
learning-rate, grad-norm, batch-size, loss-scale, num-zeros) plus
iteration-time-ms, both per-iter and "vs samples", so TB plots overlay
cleanly against the reference run's logs.
- sbatch_hetero_parity_100step.sh: pass --tensorboard-dir "${RUN_DIR}/tensorboard"
so parity runs auto-log to a per-run TB dir.
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #33's MoE-tracker block gated on args.num_experts + args.moe_router_load_balancing_type, but hetero stores those as args.num_moe_experts and on the TransformerConfig (not args). The gate never fired, so seq_load_balancing_loss never reached TB. Fix: gate on args.num_moe_experts; hardcode track_names to ["seq_load_balancing_loss"] (the only LB type Nemotron6-MoE uses); compute num_moe_layers from args.hybrid_layer_pattern's E-count so the per-iter average over MoE layers matches sanj's training_log output (54L pattern has 27 'E' layers, not 54). Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
#35) When encoder_dp < llm_dp, the routed encoder iterator merges N per-lane batches into one encoder forward via _combine_encoder_batches -> _concat_nested_tensors. After PR #32 (dynres) added a PackedSeqParams dataclass into modality_inputs.images.<encoder>.packed_seq_params, two bugs in this merge path surface together on multi-lane configurations (e.g. 9n GBS=192: ENCODER_DP=8, LLM_DP=32, lanes_per_encoder=4): 1) TypeError: cannot concatenate encoder batch value of type PackedSeqParams. PR #31 only handled torch.Tensor and dict; the dataclass fell through to the final raise. 2) RuntimeError: Sizes of tensors must match except in dimension 0. Per-lane packed image buffers have shape (1, T_lane, C) -- dim 0 is the constant lane batch and T_lane varies. PR #31's blanket torch.cat(present, dim=0) requires every other dim to match and fails on the T axis. The combination of (1) + (2) produced the 0-15-silent / 16-71-NCCL- timeout cascade observed on the 9n parity sbatch: encoder rank 0's loader worker died in torch.cat, lanes 0-3 stalled at the bridge, the remaining LLM ranks reached the DDP allgather and timed out after 10 minutes. Fix: - _concat_packed_seq_params merges N per-lane PackedSeqParams into one set covering the merged flat buffer. cu_seqlens_{q,kv}[_padded] concatenate with running offset = sum of prior lanes' cu_seqlens_q[-1] (the same offset-shift rule used in megatron.energon.task_encoder.multimodal.encoder.py); max_seqlen takes element-wise max; total_tokens sums. qkv_format, local_cp_size, cp_group are asserted equal across lanes. seq_idx is left to PackedSeqParams.__post_init__. - _concat_first_varying_dim concatenates plain tensors along the first dimension whose size differs across lanes, defaulting to dim 0 when all shapes agree. This handles (1, T_lane, C) packed image buffers (dim 1) and (N_images_lane, 2) imgs_sizes (dim 0) uniformly without sibling-key context, and preserves the prior behavior on non-dynres batches. Verified by running the existing 9n GBS=192 parity sbatch end-to-end: 30-iter smoke reaches steady-state iter time around 4.5-6s with a clean lm-loss trajectory; previously the same sbatch crashed inside _combine_encoder_batches before iteration 1.
…ge (#36) PR #35 introduced ``_concat_first_varying_dim`` that selected the concat dim at runtime by looking for the first dim whose size differed across per-lane tensors, falling back to dim 0 when no dim varied. That worked in the common case but fails when two participating lanes happen to produce identically-shaped packed image buffers in the same step: the (1, T, C) buffer falls back to dim-0 cat and becomes (2, T, C) instead of (1, 2T, C). RADIO then splits the buffer using imgs_sizes (correctly cat'd to (2N, 2)) and asserts ``sum(seq_lens) != x.shape[1]`` with an exact 2x ratio — visible in the 1-node standalone smoke at lanes_per_encoder=16 as AssertionError: 15984 != 7992 at radio.py:235 In production hetero training, the rank dies silently in encoder forward; its served LLM lanes stall on bridge recv; the cluster cascades into a 600 s NCCL watchdog timeout that looked like a different bug. The probability of two lanes producing identical (1, T, C) per step grows with ``lanes_per_encoder``: rare at 4 (9n), common at 16 (33n). All "33n hangs" we chased were instances of this. Replace the runtime inference with a schema-aware merger that knows the fixed structure of ``modality_inputs``: packed image buffer (1, T_lane, C) -> torch.cat dim 1 imgs_sizes (N_images, 2) -> torch.cat dim 0 packed_seq_params PackedSeqParams -> _concat_packed_seq_params Anything unrecognized raises a loud ``TypeError`` so a future schema change has to be handled in ``_merge_encoder_inputs`` rather than silently miscompiled by a heuristic. Validated end-to-end: * 1-node standalone (lanes_per_encoder=16, 200 steps): all 8 ranks complete with no AssertionError, no hang. Previously failed at step 27 on rank 5. * 9n GBS=192 production smoke: 25/25 iters, ~5.5 s/iter steady state. * 17n GBS=384 production smoke: 25/25 iters, ~6.2 s/iter steady state. * 33n GBS=768 production smoke: 25/25 iters, ~7.0 s/iter steady state (matches the scaling-study report's 33n target of 7.11 s/iter). Previously hung at iter 5 across multiple attempts.
… Sanjeev LR schedule) (#37) * NMFW-464: production hetero scaling sbatches at 33n / 68n / 100n Adds three production-grade hetero MIMO Nemotron6-MoE VLM training sbatches that mirror Sanjeev's pretrain_3b_nano_vlm_sota_90t_10v.sh schedule: * sbatch_hetero_prod_gbs768_33n_ep8.sh — 33 nodes (1 enc + 32 LLM) * sbatch_hetero_prod_gbs768_68n_ep8.sh — 68 nodes (4 enc + 64 LLM) * sbatch_hetero_prod_gbs768_100n.sh — 100 nodes (4 enc + 96 LLM) Pinned from Sanjeev's baseline: - TRAIN_SAMPLES=122070313 - LR_WARMUP_SAMPLES=1024000 - LR_DECAY_SAMPLES = TRAIN_SAMPLES - LR_WARMUP_SAMPLES = 121046313 - LR_WSD_DECAY_SAMPLES=18310547 - LR_WSD_DECAY_STYLE=minus_sqrt - PACKING_BUFFER_SIZE=128 - NUM_WORKERS=1 - LOG_INTERVAL=100, SAVE_INTERVAL=1000 - --load-nemotron-checkpoint pointing at sasatheesh iter_1000 Deviations from Sanjeev's baseline: - LLM_EP=8 (vs Sanjeev's EP=16) - Hetero topology TP=2 (vs Sanjeev's TP=4); explicit encoder grid - MOE_ROUTER_FORCE_LOAD_BALANCING=0 (natural seq_aux_loss) - No MTP layers - Wall time 4h (Sanjeev's; restartable from save-interval-1000 checkpoints) These sbatches are derived from the scaling-study templates (sbatch_hetero_parity_gbs768_{33n_ep8,68n_ep8,100n}.sh on the ykarnati/nmfw-464-encoder-stall-profiling branch). The timeline-profile instrumentation and the num-distributed-optimizer-instances flag from that branch are dropped here — they are debug-only and not present in the production codebase on ykarnati/nmfw-464-nemotron-vlm-with-hetero-parallel. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * NMFW-464: drop obsolete early-development 9n/8n smoke sbatches Removes five sbatch scripts that were only used during early hetero MIMO development and have no external references in the repo (no docs, no Python code, no CI). They are superseded by the new production sbatches (33n/68n/100n) added in this PR, plus the existing parity templates sbatch_hetero_parity_gbs192.sh and sbatch_hetero_parity_gbs32.sh. Removed: * sbatch_hetero_nemotron_54l_hel_9n.sh 30-iter 9n smoke test. Smoke-only; no production use. * sbatch_hetero_nemotron_54l_hel_9n_text_only.sh text-only data-blend smoke test. exec'd _9n.sh — orphaned by its removal. * sbatch_hetero_nemotron_54l_hel_9n_text_vision.sh 90/10 text-vision blend smoke test. exec'd _9n.sh — orphaned by its removal. * sbatch_hetero_nemotron_54l_hel_9n_parity.sh Standalone 9n Sanjeev-parity reproduction. Functionally superseded by sbatch_hetero_parity_gbs192.sh (also 9n + Sanjeev recipe, paired with sbatch_sanjeev_parity_gbs192.sh). * sbatch_mimo_nemotron_54l_hel_8n_text_only_llm.sh 8n LLM-only text-only variant. exec'd _9n.sh — orphaned by its removal. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * NMFW-464: add concise README for the hetero sbatch directory One table covering the 9n parity sbatch and the three production sbatches (33n / 68n / 100n) with their topology, GBS, and purpose. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Scope
Draft PR for NMFW-464 phase 1 and phase 2 review.
examples/mimo/train_hetero.py;train.pyis intentionally untouched.examples/mimo/training/hetero/with separate topology, runtime, step, scheduler, logging, args, and distributed setup modules.examples/mimo/model_providers/hetero_vlm.pyandexamples/mimo/data/hetero_mock.py.calculate_per_token_loss=Trueenforced.Verification
Local:
python3 -m py_compile examples/mimo/train_hetero.py examples/mimo/training/hetero/*.py examples/mimo/model_providers/hetero_vlm.py examples/mimo/data/hetero_mock.py examples/mimo/utils/hetero.pyuv run black --check examples/mimo/train_hetero.py examples/mimo/training/hetero examples/mimo/model_providers/hetero_vlm.py examples/mimo/data/hetero_mock.py examples/mimo/utils/hetero.pyuv run pylint examples/mimo/train_hetero.py examples/mimo/training/hetero/*.py examples/mimo/model_providers/hetero_vlm.py examples/mimo/data/hetero_mock.py examples/mimo/utils/hetero.pybash -n examples/mimo/scripts/run_hetero_mock_train.sh examples/mimo/scripts/run_hetero_nemotron_20l_mock_train.shgit diff --checkCog batch:
nmfw464-hetero-refactor-0510b116802190tests/unit_tests/models/test_mimo_partition.py->25 passedlm loss=6.287042,grad_norm=2.558lm loss=12.19014,grad_norm=3.389