NMFW-464 Nemotron VLM hetero mock training #16

Draft
yashaswikarnati wants to merge 44 commits into main from ykarnati/nmfw-464-nemotron-vlm-with-hetero-parallel


yashaswikarnati commented May 10, 2026

Scope

Draft PR for NMFW-464 phase 1 and phase 2 review.

  • Extends the existing hyper communication grid/process-group plumbing to include expert parallel, expert tensor parallel, expert data parallel, and expert all-rank groups for hetero MIMO MoE workflows.
  • Adds hetero mock training in a separate examples/mimo/train_hetero.py; train.py is intentionally untouched.
  • Refactors the hetero loop into examples/mimo/training/hetero/ with separate topology, runtime, step, scheduler, logging, args, and distributed setup modules.
  • Moves model/data construction responsibilities to examples/mimo/model_providers/hetero_vlm.py and examples/mimo/data/hetero_mock.py.
  • Adds Megatron-like training-step behavior: optimizer-step success gates LR scheduler advancement, skipped iteration accounting is logged, consumed samples are tracked, and LR/grad stats are reduced before language-rank logging.
  • Adds Nemotron 20L VLM mock workflow script matching the reference 20-layer setup, with calculate_per_token_loss=True enforced.
  • Wires explicit process groups through the hetero path for MoE/shared experts, loss, grad finalization, and optimizer behavior.
  • Updates MIMO partition tests for MoE EP/ETP/EDP coverage and the current sequence-parallel label/loss-mask behavior.
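
The Megatron-like step behavior described above (optimizer-step success gating scheduler advancement, plus skipped-iteration and consumed-sample accounting) can be sketched roughly as follows. This is a minimal, framework-free illustration and not the code in this PR: `StubOptimizer`, `StubScheduler`, and `train_step` are hypothetical names, and both the success flag returned by `step()` and the choice to count consumed samples on skipped iterations are assumptions.

```python
from dataclasses import dataclass


@dataclass
class StubScheduler:
    """Hypothetical scheduler that records how far it has been advanced."""

    num_steps: int = 0

    def step(self, increment: int) -> None:
        self.num_steps += increment


class StubOptimizer:
    """Hypothetical optimizer whose step() reports success (False means skipped)."""

    def __init__(self, outcomes):
        self._outcomes = iter(outcomes)

    def step(self) -> bool:
        return next(self._outcomes)


def train_step(optimizer, scheduler, global_batch_size, stats):
    """Run one step; advance the scheduler only if the optimizer update succeeded."""
    update_successful = optimizer.step()
    if update_successful:
        scheduler.step(increment=global_batch_size)
    else:
        stats["skipped_iters"] += 1
    # Assumption: samples count as consumed even when the update is skipped.
    stats["consumed_samples"] += global_batch_size
    return update_successful


stats = {"skipped_iters": 0, "consumed_samples": 0}
optimizer = StubOptimizer([True, False, True])
scheduler = StubScheduler()
for _ in range(3):
    train_step(optimizer, scheduler, global_batch_size=8, stats=stats)
```

In this toy run the second step is skipped, so the scheduler advances by two batches while three batches are counted as consumed.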

Verification

Local:

  • python3 -m py_compile examples/mimo/train_hetero.py examples/mimo/training/hetero/*.py examples/mimo/model_providers/hetero_vlm.py examples/mimo/data/hetero_mock.py examples/mimo/utils/hetero.py
  • uv run black --check examples/mimo/train_hetero.py examples/mimo/training/hetero examples/mimo/model_providers/hetero_vlm.py examples/mimo/data/hetero_mock.py examples/mimo/utils/hetero.py
  • uv run pylint examples/mimo/train_hetero.py examples/mimo/training/hetero/*.py examples/mimo/model_providers/hetero_vlm.py examples/mimo/data/hetero_mock.py examples/mimo/utils/hetero.py
  • bash -n examples/mimo/scripts/run_hetero_mock_train.sh examples/mimo/scripts/run_hetero_nemotron_20l_mock_train.sh
  • git diff --check

Cog batch:

  • Run: nmfw464-hetero-refactor-0510b
  • Slurm log: 11680219
  • Result: completed, exit 0
  • Unit tests: tests/unit_tests/models/test_mimo_partition.py -> 25 passed
  • Generic hetero mock: lm loss=6.287042, grad_norm=2.558
  • Nemotron 20L mock: lm loss=12.19014, grad_norm=3.389

Comment thread megatron/core/tokenizers/vision/libraries/multimodal_tokenizer.py
has_bos=True,
has_system_role=True,
)
elif prompt_format == "nemotron6-moe":

Did we need this? Was this there in the other branch?

Comment thread megatron/core/tensor_parallel/cross_entropy.py
Comment thread megatron/core/process_groups_config.py Outdated

return cls(**init_dict)

@classmethod

Why did we need this function now? We just added expert groups? How was this working without expert groups before?

Comment thread examples/mimo/model_providers/nemotron_moe_vlm.py
Comment thread examples/mimo/training/hetero/args.py Outdated
default=MOCK_MODEL_PRESET,
help="Model config preset. The Nemotron preset matches the 20L reference script.",
)
model.add_argument("--hidden-size", type=int, default=128)

Do we need all of these (hidden size, num attention heads, etc.) as part of args? The model provider covers this anyway, and we can just specify the model provider or model preset. I'm not sure "model preset" is the right term; should we call it model provider?

Comment thread examples/mimo/training/hetero/args.py Outdated
model.add_argument(
"--freeze-projection", action="store_true", help="Freeze vision projection params"
)
model.add_argument(

Also, these are specific to Nemotron VLM. I'm not sure whether training-loop args should be separate from model-specific ones; can we group the model-specific ones separately so that the training loop is generic across models?

Comment thread examples/mimo/training/hetero/args.py Outdated
return parser.parse_args()


def apply_model_preset(args: argparse.Namespace) -> None:

Do we need this function?

Comment thread examples/mimo/training/hetero/args.py Outdated
args.image_seq_length = NEMOTRON_20L_IMAGE_SEQ_PER_TILE * args.num_image_tiles


def apply_training_stage(args: argparse.Namespace) -> None:

This should belong to the Nemotron VLM provider?

Comment thread examples/mimo/training/hetero/args.py Outdated
args.training_stage = stage


def resolve_image_token_id(args: argparse.Namespace) -> None:

This is also Nemotron VLM specific? Also, the 20L one is just one specific variant; we will add multiple variants later with different numbers of layers. Also, I'm not sure we need is_nemotron_20l all over the place.

Comment thread examples/mimo/training/hetero/args.py
Comment thread examples/mimo/training/hetero/args.py
from megatron.core import parallel_state


def clear_transformer_engine_env() -> None:

We don't need this.

def initialize_distributed() -> None:
"""Initialize torch.distributed for torchrun."""
clear_transformer_engine_env()
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

Is os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1") required?

elif num_tokens is not None:
loss_acc[1] += float(num_tokens)

dist.all_reduce(loss_acc, op=dist.ReduceOp.SUM, group=language_pg.tp_dp_cp)

Why is this tp_dp_cp? Does this match the existing Megatron-LM train loop?

)

debug_rank("creating MIMO optimizer stats group")
optimizer_stats_group = dist.new_group(ranks=list(range(world_size)), backend="nccl")

Why do we need an optimizer stats group? Also, this seems to span the whole world. Doesn't the MIMO optimizer handle this already?

Comment thread examples/mimo/utils/hetero.py Outdated
NEMOTRON_VISION_ENCODER_KEY = "radio_encoder"


def is_nemotron_20l(args) -> bool:

Do we need this?

loss = fused_vocab_parallel_cross_entropy(logits, labels, self.pg_collection.tp)
else:
- loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
+ loss = tensor_parallel.vocab_parallel_cross_entropy(

Let's not use the unfused loss; let's drop this change to reduce the surface area of changes.

Comment thread megatron/core/models/hybrid/hybrid_layer_allocation.py
Comment thread megatron/core/models/mimo/model/base.py Outdated
if packing_kwargs is not None:
for key in packing_kwargs:
if 'cu_seqlens' in key and packing_kwargs[key] is not None:
packing_kwargs[key] = packing_kwargs[key].to(dtype=torch.int32)

Let's not include the int32 conversion or the format setting here; this should be the data loader's concern.

# Key output for non-last stages so schedule can route to next LM stage
if not self.role.is_last_stage(lang_name):
- return {lang_name: lm_output}
+ return {lang_name: lm_output}, loss_mask

Do we really need to return the loss mask here?


if self.cfg.seq_parallel and embeddings is not None:
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# GPT/Hybrid output layers gather sequence-parallel hidden states

I don't think this is correct. Do we need this? This was working before?


Let's double check.

Comment thread megatron/core/models/mimo/optimizer.py Outdated
opt.zero_grad(set_to_none)

def get_loss_scale(self) -> torch.Tensor:
"""Return the loss scale from the first active optimizer, or one for stubs."""

Don't add new docstrings unless asked.

Comment thread megatron/core/models/mimo/optimizer.py Outdated
self,
module_infos: Dict[str, ModuleOptimizerInfo],
config: OptimizerConfig,
stats_group: Optional[torch.distributed.ProcessGroup] = None,

Do we really need this stats group?

Comment thread megatron/core/models/mimo/optimizer.py Outdated
grid.create_pg(["tp", "ep", "pp"])
grid.create_pg(["dp", "ep"])
grid.create_pg(["tp", "cp", "ep", "pp", "dp"])
grid.create_pg("tp_ep_pp")

Why tp_ep_pp instead of tp, ep, pp?

Comment thread megatron/core/optimizer/__init__.py Outdated
gathered_params_key = [
None for _ in range(torch.distributed.get_world_size(group=param_group_sync_group))
]
torch.distributed.all_gather_object(

Why do we need the param group sync group?

shard_factor = None
seq_dim = None # which dimension holds the token sequence

# MimoModel.forward() passes embeddings in batch-first layout

Is this comment correct? At which part of forward is it B,S,H?

Comment thread megatron/core/models/mimo/partition/utils.py
Comment thread megatron/core/models/mimo/model/base.py Outdated
"""Annotate flat modality outputs with per-sample split sizes for bridge fan-out."""
if (
not isinstance(output, torch.Tensor)
or output.ndim != 2

Why do we need so many checks here, chained with or?

Comment thread megatron/core/models/mimo/model/base.py Outdated
Comment thread megatron/core/models/mimo/model/base.py Outdated
# Non-first stage: receive hidden states from previous LM stage
hidden_states = input_tensors.get(lang_name) if input_tensors else None

if self.partition_adapter is not None:

The `if self.partition_adapter is not None` logic looks like it is replicated in multiple places. What are those places?

Comment thread megatron/core/models/mimo/model/base.py Outdated
input_ids: Input token IDs. Shape: (B, S)
position_ids: Position IDs. Shape: (B, S)
- attention_mask: Attention mask. Shape: (B, S)
+ attention_mask: Accepted for API compatibility. This path currently relies on

We don't have to change this docstring.

Comment thread megatron/core/models/mimo/model/base.py Outdated
"""Return an empty projected activation for text-only non-colocated batches."""
hidden_size = self.config.hidden_size
param = next(submodule.parameters(), None)
if param is not None:

Why do we have so much if/else here? Do we need all of these? Also, when exactly is _empty_modality_output called? In the non-colocated case, what if the submodule is on a different rank?

Comment thread megatron/core/models/mimo/model/base.py Outdated
raise RuntimeError(
f"{encoder_name} inputs are missing, but matching special tokens exist"
)
output = self._empty_modality_output(submodule, input_ids)

Does this empty output work through the bridge communicator?

Comment thread megatron/core/models/mimo/model/base.py Outdated
decoder_input=combined_embeddings,
labels=labels,
- attention_mask=attention_mask,
+ attention_mask=None,

We don't have to change this? Just pass attention_mask=attention_mask?

Comment thread megatron/core/models/mimo/model/base.py Outdated
decoder_input=None,
labels=labels,
- attention_mask=attention_mask,
+ attention_mask=None,

attention_mask=attention_mask?

yashaswikarnati and others added 17 commits May 12, 2026 19:23
* Add 54L Nemotron MoE VLM provider

* Address 54L provider review comments

* Inline static Nemotron provider config

* Hardcode Nemotron language architecture config

…ng loop (#26)

* Add distributed-checkpoint save/load to the hetero MIMO training loop

Adds the standalone `examples/mimo/training/hetero/checkpointing.py` module
plus the CLI surface and loop wiring needed to round-trip MimoModel,
MimoOptimizer (ChainedOptimizer-of-DistributedOptimizers in the MoE recipe)
and the LR/WD scheduler through `megatron.core.dist_checkpointing` without
depending on the `parallel_state` singleton.

Layout stays compatible with `megatron/training/checkpointing.py` output:
`<save>/latest_checkpointed_iteration.txt` plus per-iteration directories
containing `common.pt`, `metadata.json`, `.metadata`, and torch_dist shards.
Common state now carries `args`, `checkpoint_version=3.0`, the LR scheduler
state, and a per-branch `mimo.{branch}.rng_state` ShardedObject; the tracker
read uses a cross-rank MAX reduce to mirror megatron's `read_metadata`.
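
As a rough illustration of the tracker read described above, the sketch below writes and reads `latest_checkpointed_iteration.txt` and then takes the maximum over per-rank values. It is not the loader from this commit: the helper names are hypothetical, the cross-rank MAX reduce is simulated with Python's `max()` rather than a real collective, and only plain integer tracker contents are handled.

```python
import tempfile
from pathlib import Path

TRACKER_FILENAME = "latest_checkpointed_iteration.txt"


def write_tracker(save_dir: Path, iteration: int) -> None:
    """Record the most recently saved iteration in the tracker file."""
    (save_dir / TRACKER_FILENAME).write_text(str(iteration))


def read_local_iteration(save_dir: Path) -> int:
    """Return the iteration this rank sees in the tracker, or 0 if there is none."""
    tracker = save_dir / TRACKER_FILENAME
    if not tracker.is_file():
        return 0
    return int(tracker.read_text().strip())


def agree_on_iteration(per_rank_iterations):
    """Stand-in for an all-reduce with MAX: every rank adopts the largest value."""
    return max(per_rank_iterations)


with tempfile.TemporaryDirectory() as tmp:
    save_dir = Path(tmp)
    write_tracker(save_dir, 3)
    seen = read_local_iteration(save_dir)
    # Pretend a second rank has a stale filesystem view and sees no tracker.
    resolved = agree_on_iteration([seen, 0])
```

Taking the maximum means a rank with a stale view still resumes from the same iteration as the others.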

Fixes three pre-existing dist-ckpt bugs that hetero usage uncovered:
- `megatron/core/ssm/mamba_mixer.py` was calling
  `make_sharded_tensors_for_checkpoint` without passing `tp_group` and
  `dp_cp_group`, which fell back to the parallel_state singleton and
  asserted in hetero mode (gated_delta_net was already correct).
- `MimoOptimizer.sharded_state_dict` now applies
  `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` to each per-branch
  optimizer sub-dict so two modules' identical internal ShardedObject keys
  (e.g. `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide.
- `_get_replica_id` now folds in `tp_rank` so two TP ranks within DP=0
  don't both claim primary writer for the same shard.
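
To show why the per-branch prefix matters, here is a small dictionary-only sketch of the key collision. The `add_prefix` helper is a hypothetical stand-in for `add_prefix_for_sharding`, which operates on sharded objects rather than plain strings; the module names and the `.state` key suffix are made up for illustration.

```python
def add_prefix(state: dict, prefix: str) -> dict:
    """Hypothetical stand-in: return a copy of `state` with every key prefixed."""
    return {prefix + key: value for key, value in state.items()}


# Both branches produce the same inner key (the ".state" suffix is invented here).
inner_key = "chained_0.optimizer.distributed.dp_group_idx_0.state"
per_module = {
    "language": {inner_key: "language-state"},
    "vision": {inner_key: "vision-state"},
}

# Without prefixes, the second module silently overwrites the first.
merged_unprefixed = {}
for sub in per_module.values():
    merged_unprefixed.update(sub)

# With a per-module prefix, both entries survive.
merged_prefixed = {}
for name, sub in per_module.items():
    merged_prefixed.update(add_prefix(sub, f"mimo.{name}."))
```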

Also routes DistributedOptimizer's per-module `param_state_sharding_type`
config string through a new ShardedObject (`_extract_*` helpers) so the
non-rank-0 module owner doesn't lose it when only rank 0's common.pt is
authoritative.

A `_propagate_tp_groups_for_checkpoint` walker stamps `self.tp_group` on
descendants that omit it (e.g. `ExtendedRMSNorm`, RADIO submodules) so the
default `MegatronModule.sharded_state_dict` path doesn't fall through to
`parallel_state.get_tensor_model_parallel_group`.

Validated end-to-end on cw-dfw 8-GPU 20L mock (stage2):
- Save iter 3 (DistributedOptimizer + EP=4 + TP=2 + 2-module Chained)
- Reload iter 3 → resume at iter 4 with cosine LR continuation
  (1.59e-4 → 1.32e-4 → 1.01e-4), losses match prior trajectory.

New flags: `--save`, `--load`, `--save-interval`, `--no-save-optim`,
`--no-load-optim`, `--no-load-scheduler`, `--no-save-rng`, `--no-load-rng`,
`--finetune`, `--dist-ckpt-optim-fully-reshardable`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Adopt MimoOptimizer checkpoint patterns from NVIDIA#4801

Three convergent simplifications to MimoOptimizer's distributed-checkpoint
path, matching Kamran's MimoOptimizer fixes PR:

1. Replace the ShardedObject-based round-trip for
   `param_state_sharding_type` with a metadata stash. The sharding type is
   not per-rank state — it's a load-time interpretation hint that the caller
   supplies via the `metadata` kwarg on `sharded_state_dict()`. We stash that
   metadata in `self._last_sharded_metadata` at save and re-inject the
   sharding type into each per-module sub state-dict during
   `load_state_dict()` for ranks that lost it via dist_checkpointing's
   common-state path (i.e. non-rank-0 module owners in non-colocated
   layouts). Drops `_extract_param_state_sharding_type` /
   `_restore_param_state_sharding_type` along with their ShardedObject keys.

2. `_restore_param_groups` now uses `setdefault('optimizer', {})` before
   writing back `param_groups`. After `_extract_param_groups` deletes
   `param_groups` at save time, the leftover empty `'optimizer'` dict can be
   dropped by the common-state round-trip on ranks whose active module
   wasn't on rank 0 at save. The setdefault makes the restore path tolerant
   of that drop.

3. `_get_replica_id` reorders to `(tp_rank, pp_rank, dp_rank)` to match the
   convention used by `make_sharded_object_for_checkpoint` in
   `megatron/core/transformer/utils.py:168-172`. Dedup math is unchanged —
   `(0, 0, 0)` is still the primary replica — but the order is now
   consistent with the rest of the codebase.
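
The replica-ID dedup in item 3 can be sketched with plain tuples. This is an illustration under assumptions, not the method's actual body: `get_replica_id` and `is_primary_replica` are hypothetical free functions, and the toy layout (TP=2, PP=1, DP=2) is chosen only to make the effect visible.

```python
from itertools import product


def get_replica_id(tp_rank: int, pp_rank: int, dp_rank: int) -> tuple:
    """Hypothetical helper: replica ID in (tp, pp, dp) order."""
    return (tp_rank, pp_rank, dp_rank)


def is_primary_replica(replica_id: tuple) -> bool:
    """Only the all-zero replica writes a replicated object."""
    return all(rank == 0 for rank in replica_id)


# Toy layout: every (tp_rank, pp_rank, dp_rank) combination for TP=2, PP=1, DP=2.
ranks = list(product(range(2), range(1), range(2)))
writers = [r for r in ranks if is_primary_replica(get_replica_id(*r))]

# If tp_rank were left out of the ID, both TP ranks at dp_rank=0 would claim the write.
writers_without_tp = [r for r in ranks if is_primary_replica((r[1], r[2]))]
```

Reordering the tuple does not change which rank is primary, since the check is only whether every component is zero.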

Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer +
ChainedOptimizer + EP=4 + TP=2): save iter 3, reload, resume iter 4 with
cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4) and matching loss
trajectory. Save exit 0, load exit 0.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Document _propagate_tp_groups_for_checkpoint escapees from no-walker run

Disabled `_propagate_tp_groups_for_checkpoint` and re-ran the 20L mock
to enumerate exactly which modules fall through to
`parallel_state.get_tensor_model_parallel_group()` and assert. Confirmed
both branches escape:

- RADIO encoder internals (first failure, reached via
  `nemotron_moe_vlm.RadioEncoder.sharded_state_dict` → HF radio_model
  leaves with no tp_group + no own sharded_state_dict).
- `MambaLayer.__init__` in `megatron/core/ssm/mamba_layer.py` plumbs
  pg_collection to the mixer but never sets `self.tp_group`.
- `ExtendedRMSNorm` at `megatron/core/ssm/mamba_mixer.py:93` never sees
  pg_collection at all.

Fixing each at the source would mean patches across core (Mamba) plus a
partial walk of RADIO's HF wrapper, validated against all existing
non-hetero users of those modules. The walker is the smaller intervention:
one place, hasattr-guarded, applied per branch with the correct pg.

Re-enables the walker (it was already in PR1; this commit only updates
the docstring to record the experiment's findings).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Switch back to ShardedObject for param_state_sharding_type (NVIDIA#4791)

Kamran reverted the metadata-stash approach in NVIDIA#4801 (discussion
r3250847203) and adopted Li Ding's PR NVIDIA#4791 pattern, which is the same
ShardedObject round-trip we had originally. Align our MimoOptimizer with
that final shape:

- Restore `_extract_param_state_sharding_type` / `_restore_param_state_sharding_type`
  helpers. Hooks back into the existing `_iter_optimizer_sub_dicts` loop.
- Add `if not opt_sub: del sub_sd['optimizer']` to `_extract_param_groups`
  (from NVIDIA#4791) so the now-empty `'optimizer'` wrapper doesn't round-trip
  through common-state with undefined behavior on the load side.
- Drop `self._last_sharded_metadata` and the metadata-stash recover path
  from `load_state_dict` / `sharded_state_dict`. The ShardedObject route
  is self-contained and doesn't need caller-state coupling.
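
A dictionary-level sketch of the extract/restore round trip may help here. It combines the empty-wrapper deletion described above with a tolerant restore that recreates the wrapper; the function bodies are hypothetical reconstructions from the commit text, not the actual `_extract_param_groups` / `_restore_param_groups` code, and whether the restore side still uses `setdefault` after this commit is an assumption.

```python
def extract_param_groups(sub_sd: dict):
    """Hypothetical: pull param_groups out and drop the wrapper if it ends up empty."""
    opt_sub = sub_sd.get("optimizer", {})
    param_groups = opt_sub.pop("param_groups", None)
    if not opt_sub and "optimizer" in sub_sd:
        del sub_sd["optimizer"]
    return param_groups


def restore_param_groups(sub_sd: dict, param_groups) -> None:
    """Hypothetical: write param_groups back, recreating the wrapper if it is missing."""
    if param_groups is not None:
        sub_sd.setdefault("optimizer", {})["param_groups"] = param_groups


sub_sd = {"optimizer": {"param_groups": [{"lr": 1e-4}]}, "step": 3}
saved_groups = extract_param_groups(sub_sd)
after_extract = dict(sub_sd)  # snapshot of what would go through the save path
restore_param_groups(sub_sd, saved_groups)
```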

Kept (not in NVIDIA#4791, specific to our non-colocated hetero layout):
- `add_prefix_for_sharding(module_sd, f'mimo.{name}.')` so the two
  branches' identical inner ShardedObject keys (e.g.
  `chained_0.optimizer.distributed.dp_group_idx_0.*`) don't collide.
- `_get_replica_id` returning `(tp_rank, pp_rank, dp_rank)` (from NVIDIA#4801).

Validated on cw-dfw 1-node 8-GPU 20L mock (stage2, DistributedOptimizer
+ ChainedOptimizer + EP=4 + TP=2): save iter 3 exit 0, reload + resume
iter 4 with cosine LR continuation (1.59e-4 → 1.32e-4 → 1.01e-4),
matching loss trajectory across the boundary.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Drop _stamp_tp_group walker; fix the three constructors at the source

Three modules under our hetero save path don't store `self.tp_group` in
their constructors and therefore trip `MegatronModule.sharded_state_dict`'s
parallel_state fallback (`megatron/core/transformer/module.py:85`) in
heterogeneous-parallelism layouts where parallel_state is intentionally
not initialized. Fix them at the source instead of papering over with the
hasattr-guarded walker:

- `megatron/core/models/vision/radio.py:RADIOViTModel.__init__` — already
  extracts `tp_group` at line 129 for the embedder; now also stamps
  `self.tp_group = tp_group`.
- `megatron/core/ssm/mamba_layer.py:MambaLayer.__init__` — takes
  pg_collection and plumbs it into the mixer; now also stores
  `self.tp_group = pg_collection.tp` on the layer itself.
- `megatron/core/ssm/mamba_mixer.py:ExtendedRMSNorm` — adds an
  `__init__(*args, tp_group=None, **kwargs)` override that stores
  `self.tp_group` eagerly, and updates the single call site at line ~369
  to pass `tp_group=self.pg_collection.tp`. The lazy `hasattr` fallback
  inside `sharded_state_dict` is preserved for callers that don't pass
  tp_group.

With these three constructor fixes in place, the
`_propagate_tp_groups_for_checkpoint` walker (and `_stamp_tp_group`
helper) in `examples/mimo/training/hetero/runtime.py` is no longer needed.
Removed entirely.

Validated on cw-dfw 1-node 8-GPU 20L mock with the walker disabled:
- save iter 3 exit 0 (DistributedOptimizer + ChainedOptimizer + EP=4 + TP=2)
- reload iter 3 → resume iter 4-5 with cosine LR continuation
  (1.59e-4 → 1.32e-4 → 1.01e-4), exit 0
- losses match prior runs (iter 1: 12.187, iter 2: 12.190, iter 3: 12.177,
  resume iter 4: 11.817, iter 5: 11.264)

The downstream check `if not hasattr(self, 'tp_group')` in subsequent
descendants (TransformerBlock, TransformerLayer, Attention, MLP,
ColumnParallelLinear) was already satisfied by their own constructors;
verified by reading those files.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

…t_param_groups all-gather (#27)

Two convergent fixes for MIMO + frozen modules under non-colocated
parallelism, cherry-picking the pattern from NVIDIA#4790
(Li Ding, "Fix MIMO optimizer setup for frozen modules") with comments
updated to cite the source and link to NMFW-464 context.

## Problem 1 — placeholder optimizer for all-frozen modules

`get_mimo_optimizer` previously called `get_megatron_optimizer` on every
non-None module, including modules whose params are all frozen on every
rank in the module's group. The most common trigger is
`--training-stage stage1` (the LLaVA projector-only recipe), which sets
`--freeze-vit --freeze-lm` and leaves the language model with zero
trainable parameters on the LLM ranks. The resulting placeholder
DistributedOptimizer either crashes in downstream setup or behaves
silently incorrectly (e.g., LR scheduler advancing an empty group, save
of a degenerate optimizer state).

## Problem 2 — `_get_param_groups` all-gathers over WORLD

`_get_param_groups` reconciles param-group keys across ranks of the same
model via `all_gather_object(params_key, …)` over the global default
group. In non-colocated MIMO, encoder ranks and LLM ranks are disjoint
and own different params (RADIO vs Mamba), so a WORLD-group all-gather
pollutes both branches with the other branch's keys.

## Fixes

- `megatron/core/models/mimo/optimizer.py`: add
  `_module_has_any_trainable_parameters(module, pg_collection)` — an
  all-reduce-MAX over `pg_collection.intra_dist_opt` of the local
  trainable-param count. Gate the `get_megatron_optimizer` call on it.
  When false, leave `info.optimizer = None` so
  `MimoOptimizer.is_stub_optimizer` handles the branch (that path was
  already designed for this; the gating never triggered it before).
- `megatron/core/optimizer/__init__.py`: add an optional
  `process_group=None` kwarg to `_get_param_groups` /
  `_get_param_groups_and_buffers`, plumbed through
  `_get_megatron_emerging_optimizer` and `get_megatron_optimizer`, so the
  cross-rank `all_gather_object` can target a specific group. MIMO passes
  `pg_collection.intra_dist_opt`. Default `None` preserves the current
  WORLD-group behavior for every existing non-MIMO caller.
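
The gating logic in the first fix can be approximated without any distributed runtime. In the sketch below the all-reduce MAX over `pg_collection.intra_dist_opt` is simulated by taking `max()` over per-rank counts; the function names and the string placeholder for an optimizer are hypothetical.

```python
def local_trainable_count(requires_grad_flags):
    """Count parameters on this rank that still require gradients."""
    return sum(1 for flag in requires_grad_flags if flag)


def module_has_any_trainable_parameters(per_rank_flags):
    """Simulated all-reduce MAX across the ranks in the module's group."""
    return max(local_trainable_count(flags) for flags in per_rank_flags) > 0


def build_optimizer_or_none(per_rank_flags):
    """Return a placeholder for a real optimizer, or None for an all-frozen module."""
    if not module_has_any_trainable_parameters(per_rank_flags):
        return None
    return "optimizer"


# stage1-like case: the language model is frozen on all of its ranks.
language_opt = build_optimizer_or_none([[False, False], [False, False]])
# Trainable on at least one rank, so every rank in the group builds an optimizer.
projector_opt = build_optimizer_or_none([[False], [True]])
```

Reducing with MAX rather than checking only local parameters keeps every rank in the group on the same branch, so no rank skips a collective that the others enter.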

## Validation

cw-dfw 1-node 8-GPU 20L mock, stage1 (vit + lm frozen, projector-only):
- 3 iters, exit 0
- grad norm 0.025 / 0.020 / 0.016 (consistent with only projector params
  having live grads)
- The `learning rate: ...` field in the iteration log is now absent. The
  smoke run from the previous PR printed it because the placeholder
  language optimizer was being queried for an LR; with this fix the
  language optimizer is correctly `None` and the logger has nothing to
  query. That's the visible signature of the fix landing.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Reproduces Sanjeev's `examples/multimodal/v3/pretrain_3b_nano_vlm_sota_90t_10v.sh`
(`sasatheesh/megatron-lm!45`, job 202967) on 9 HEL nodes with our hetero MIMO
training loop, modulo three items deferred to follow-up PRs (load_vision_from,
correct_encoder_grad_for_partial_participation, dataloader_save) and MTP.

Layout
- LLM grid: TP=2, EP=16, DP=32 (encoder grid unchanged at TP=1, DP=8, EP=1).
- 9n sbatch wrapper bumps `LLM_TP 4 -> 2`, `LLM_DP 16 -> 32`, `NUM_WORKERS 0 -> 2`.

Model provider (`nemotron_moe_vlm.py`)
- `moe_aux_loss_coeff: 1e-9 -> 1e-4` for non-trivial router pressure.
- `bias_dropout_fusion: False -> True` for args-dump parity (inert under
  add_bias_linear=False but matches Sanjeev).
- Make `mamba_num_groups=8`, `mamba_state_dim=128`, `linear_conv_kernel_dim=4`
  explicit in `nemotron_language_config` (mcore defaults already match; we
  declare to be defensive).
- Pass `share_embeddings_and_output_weights=False` to `MambaModel` for
  explicit `untie_embeddings_and_output_weights=True` parity.

Vision encoder / data path
- Add `--dynamic-resolution / --no-dynamic-resolution`,
  `--dynamic-resolution-min-patches`, `--dynamic-resolution-max-patches` CLI
  flags. Default `dynamic_resolution=True` for `nemotron-moe-vlm-*` providers.
- Pin `args.use_thumbnail=False` and `args.use_tiling=False` under dynamic
  resolution so `DynamicResolutionImageTilingStrategy` does not emit an extra
  thumbnail tile (Sanjeev's run has both False).
- Thread `dynamic_resolution_min/max_patches` and `dynamic_resolution_min/
  max_side` through `VisionConfig` in `energon_multimodal_provider.py`.

Optimizer + WSD scheduler
- Add `--train-samples`, `--lr-warmup-samples`, `--lr-decay-samples`,
  `--lr-wsd-decay-samples`, `--lr-wsd-decay-style`. Extend `--lr-decay-style`
  choices to include `WSD`.
- `validate_args` derives `train_iters = ceil(train_samples / gbs)` and
  enforces that WSD requires both wsd_decay_samples and wsd_decay_style.
- `build_optimizer_param_scheduler` honors the sample-based knobs (taking
  precedence over iter-based) and passes `wsd_decay_steps` /
  `lr_wsd_decay_style` to `OptimizerParamScheduler`.
- `run_hetero` lifts `LR/MIN_LR/WEIGHT_DECAY/LR_DECAY_STYLE` to env-var
  defaults and threads optional sample-based knobs through.
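
The sample-based derivation and the WSD check above can be sketched as follows. This is an illustrative reconstruction, not the PR's `validate_args`: the function names are hypothetical, warmup is omitted, and the `minus_sqrt` tail shape (a coefficient of `1 - sqrt(progress)` over the final decay window) is my reading of the style name, stated here as an assumption.

```python
import math


def derive_train_iters(train_samples: int, global_batch_size: int) -> int:
    """train_iters = ceil(train_samples / gbs)."""
    return math.ceil(train_samples / global_batch_size)


def check_wsd_args(lr_decay_style, wsd_decay_samples, wsd_decay_style) -> None:
    """WSD needs both the decay window length and the decay style."""
    if lr_decay_style == "WSD" and (wsd_decay_samples is None or wsd_decay_style is None):
        raise ValueError("WSD requires both wsd_decay_samples and wsd_decay_style")


def wsd_lr(step, max_lr, min_lr, decay_steps, wsd_decay_steps):
    """Hold max_lr, then apply an assumed minus_sqrt tail over the last wsd_decay_steps."""
    anneal_start = decay_steps - wsd_decay_steps
    if step <= anneal_start:
        return max_lr
    progress = min(1.0, (step - anneal_start) / wsd_decay_steps)
    coeff = 1.0 - math.sqrt(progress)
    return min_lr + coeff * (max_lr - min_lr)


# Toy numbers, not the values from the parity run.
iters = derive_train_iters(train_samples=1000, global_batch_size=768)
stable_lr = wsd_lr(step=50, max_lr=1.0, min_lr=0.1, decay_steps=100, wsd_decay_steps=20)
tail_lr = wsd_lr(step=85, max_lr=1.0, min_lr=0.1, decay_steps=100, wsd_decay_steps=20)
final_lr = wsd_lr(step=100, max_lr=1.0, min_lr=0.1, decay_steps=100, wsd_decay_steps=20)
```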

DDP
- Add `--overlap-param-gather`, `--ddp-num-buckets`, and
  `--ddp-pad-buckets-for-high-nccl-busbw` CLI flags.
- `runtime._resolve_bucket_size` derives DDP bucket_size from
  num_parameters // num_buckets when `--ddp-num-buckets` is set, else honors
  `--ddp-bucket-size`, else returns None for mcore auto-default.
- Vision DDP keeps both overlap knobs forced OFF (partial-participation
  safety: text-only batches leave some encoder DP ranks with zero grads).
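
The bucket-size precedence described for `runtime._resolve_bucket_size` can be written down in a few lines. The signature below is a guess for illustration; only the precedence order (num-buckets first, then explicit bucket size, then `None`) comes from the description above.

```python
from typing import Optional


def resolve_bucket_size(
    num_parameters: int,
    ddp_num_buckets: Optional[int] = None,
    ddp_bucket_size: Optional[int] = None,
) -> Optional[int]:
    """Pick the DDP bucket size; None defers to the library's auto-default."""
    if ddp_num_buckets is not None:
        return num_parameters // ddp_num_buckets
    if ddp_bucket_size is not None:
        return ddp_bucket_size
    return None


from_num_buckets = resolve_bucket_size(1000, ddp_num_buckets=8)
from_bucket_size = resolve_bucket_size(1000, ddp_bucket_size=40)
auto_default = resolve_bucket_size(1000)
```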

Self-contained parity sbatch
- `sbatch_hetero_nemotron_54l_hel_9n_parity.sh`: every training value pinned
  inline (no `${VAR:-default}` fallbacks). GBS is canonical and
  NUM_MICROBATCHES is derived as `GBS / (MBS * LLM_DP) = 24`.
- Uses `megatron-venv-baked-206674.sqsh`; staged tokenizer + post-c-radio-omni
  encoder under `agents-scratch`.
- Sanjeev-faithful values: lr=1.2e-3, min_lr=1.2e-5, wd=0.1, WSD with
  minus_sqrt tail, warmup 1.024M samples, decay 35.6M samples, wsd-decay
  5.49M samples, train_samples 36.62M (~300B tokens, 47684 iters at gbs=768),
  log_interval=100, save_interval=1000, seed=1234, class_token_len=10,
  image_tag_type=internvl, max_num_tiles=1, overlap-grad-reduce,
  overlap-param-gather, ddp-num-buckets=8, ddp-pad-buckets.
- `--load-vision-from` commented out pending PR_load_vision_from.
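
As a worked check of the derivation `NUM_MICROBATCHES = GBS / (MBS * LLM_DP) = 24`, the snippet below plugs in gbs=768 and LLM_DP=32 from this commit message. The micro-batch size of 1 is inferred from those numbers rather than stated in the text, and the helper name is hypothetical.

```python
def derive_num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
    """Number of microbatches per step, requiring an exact division."""
    samples_per_microstep = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_microstep != 0:
        raise ValueError("global batch size must be divisible by MBS * DP")
    return global_batch_size // samples_per_microstep


# MBS=1 is an inference from GBS=768, LLM_DP=32, and the stated result of 24.
num_microbatches = derive_num_microbatches(768, 1, 32)
```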

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Add Nemotron-format ckpt loader for hetero MIMO

- Add load_nemotron_vlm_ckpt_hetero and --load-nemotron-checkpoint flag for
  loading pre-vlm-05 Nemotron-format VLM dist-ckpts from the hetero pipeline.
- After the custom load, refresh the DistributedOptimizer's FP32 main-param
  shards via optimizer.reload_model_params(); the standard load_checkpoint
  path does this automatically, but the custom loader bypasses it. Without
  this the optimizer steps with the model-provider init weights instead of
  the loaded ckpt weights.
- Seed python random + numpy + torch in the hetero entry to match
  Megatron's _set_random_seed (energon's text_packing shuffle uses the
  global random module).
- Add --correct-encoder-grad-for-partial-participation flag (consumed in a
  follow-up commit by grad_sync.py).
- Add --train-samples flag (samples-based budget; --train-iters is derived).
- Fix RADIO pos-embedding bilinear interpolation to align_corners=False to
  match upstream RADIO.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Hetero correctness fixes: fp32 grad reduce + encoder grad participation

- runtime.py: set grad_reduce_in_fp32=True on both language and vision DDP
  configs (mirrors --accumulate-allreduce-grads-in-fp32). The default False
  produces bf16 main_grad, which drifts step-2 weights after Adam.
- grad_sync.py: when only some encoder DP ranks process images in a step,
  scale vision grads post-DP-reduce by encoder_dp_size / participation_count.
  Without this the vision encoder learns at a diluted rate.
- nemotron_moe_vlm.py: set moe_router_fusion=False to match the
  TransformerConfig default. The fused softmax/topk kernel takes a different
  bf16 reduction path, slightly perturbing router probs.
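
The partial-participation correction can be illustrated with scalar gradients. This is a sketch under the assumption that the DP reduce is a plain average over all encoder DP ranks; the function names are hypothetical and do not mirror grad_sync.py.

```python
def dp_average(per_rank_grads):
    """Plain data-parallel average over all encoder DP ranks."""
    return sum(per_rank_grads) / len(per_rank_grads)


def corrected_dp_average(per_rank_grads, participated):
    """Rescale the DP average by encoder_dp_size / participation_count."""
    encoder_dp_size = len(per_rank_grads)
    participation_count = sum(1 for flag in participated if flag)
    reduced = dp_average(per_rank_grads)
    if participation_count == 0:
        # No rank saw images this step; nothing to correct.
        return reduced
    return reduced * encoder_dp_size / participation_count


# Two of four encoder ranks processed images; the others hold zero grads.
grads = [2.0, 4.0, 0.0, 0.0]
flags = [True, True, False, False]
diluted = dp_average(grads)                    # averaged over all 4 ranks
corrected = corrected_dp_average(grads, flags)  # averaged over the 2 participants
```

Here `diluted` is half of `corrected`, which is the diluted learning rate the commit describes.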

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Hetero data pipeline parity

- energon_multimodal_provider.py: forward HF tokenizer's chat_template and
  apply_chat_template through TokenizerAdapter so energon's
  tokenize_and_prepare can find them. Add _supported_kwargs filter on
  VisionConfig so the same recipe args work across energon versions whose
  VisionConfig accepts different kwargs.
- hetero_energon.py: use get_savable_loader (SavableDatasetWrapper sets a
  worker_id_offset and per-worker init that affects step-0 sample order).
  Use the unsalted seed in the single-lane iterator; energon's
  WorkerConfig(rank=lane, world_size=llm_dp) already salts per-rank, so
  adding a +lane offset over-salts and de-aligns sample ordering.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Parity sbatch scripts + broaden modelopt import guard

- Add sbatch_hetero_parity_100step.sh and sbatch_sanjeev_parity_100step.sh
  as paired 150-step train-loss parity drivers (hetero MIMO vs reference
  recipe). Settings match the Sanjeev-202967 hetero arg-parity table.
- training.py: broaden the modelopt distill-plugin import guard from
  ImportError to Exception. Some container builds ship modelopt against a
  transformers version that removed transformers.modeling_utils.Conv1D, so
  importing the distill plugin raises AttributeError during module init.
  Distill is optional, so skip safely.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Address PR review comments

- Revert training.py modelopt import-guard change (not needed for hetero
  correctness; container-specific compat issue tracked elsewhere).
- Trim verbose docstrings in model_helpers.py (_load_submodule_from_ckpt,
  load_nemotron_vlm_ckpt_hetero) and args.py --load-nemotron-checkpoint help.
- hetero_energon.py: drop the WorkerConfig-salting comment and the over-
  defensive try/except wrapping get_savable_loader. Match pre-vlm-05 exactly:
  one direct call with cache_pool=NoCachePool() and the watchdog kwargs.
- grad_sync.py: replace expensive (buffer.grad_data != 0).any() scans
  with a one-bool participation flag set by forward_step from batch.images.
  Combine the per-token normalization and the partial-participation
  correction into a single scale_gradients call per submodule (was two
  separate kernel launches before).
- step.py: call mark_modality_participation in forward_step and
  reset_modality_participation at the top of each train_step.
- loop.py: extract the --load-nemotron-checkpoint branch into
  load_and_refresh_nemotron_checkpoint helper in model_helpers.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Point sanj parity sbatch at clean pre-vlm-05 clone

SANJEEV_REPO now defaults to ${SCRATCH_ROOT}/sanjeev-repos/megatron-lm-clean,
a fresh checkout of sasatheesh/pre-vlm-05 with only the two correctness
changes needed for the parity baseline (model.py calculate_per_token_loss
honors --calculate-per-token-loss; recipe sh passes the flag). All NMFW
debug instrumentation that lived on the old sanjeev-repos/megatron-lm
checkout is dropped from this baseline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* RADIO: set mtp_num_layers=0 so final layernorm applies

With post_process=False on the RADIO TransformerBlock and
mtp_num_layers defaulting to None, has_final_layernorm_in_this_stage
short-circuits to False and the final layernorm is dropped. The sanj
recipe passes --mtp-num-layers 0, which takes the alternate branch and
keeps the layernorm on the stage that holds layer.num_layers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* RADIO: pass ln_post_impl=TENorm to RADIOViTModel

The RADIO post-encoder layernorm lives on RADIOViTModel as self.ln_post,
applied after the decoder in radio.py:239-240. The ckpt converter writes
RADIO's inner.norm.{weight,bias} into ln_post.{weight,bias}; without
ln_post_impl=TENorm the wrapper leaves self.ln_post=None and the loader
silently drops the ckpt entries (StrictHandling.LOG_UNEXPECTED), so the
vision tower output is the raw decoder hidden state instead of its
post-norm. Sanj-side llava_model.py:309-311 sets the same impl.

Revert previous mtp_num_layers=0 attempt -- that was targeting
TransformerBlock.final_layernorm, which is a different module and not
where the ckpt weight lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Revert RADIO layernorm experiments

Both attempts targeted layernorms that aren't actually missing for this
RADIO variant:
- mtp_num_layers=0 (0b5996c) targeted TransformerBlock.final_layernorm,
  but that gating returns False here for unrelated reasons.
- ln_post_impl=TENorm (82ce396) builds RADIOViTModel.ln_post, but the
  sanj iter_1000 ckpt has no ln_post.* keys -- llava_model.py only sets
  ln_post_impl=TENorm for vision_model_type=='radio-g', not for the
  cradio variant we run.

Confirmed by 266840's load failure: model requested vision_model.ln_post.*
but dist_checkpointing flagged them as unexpected (not found in ckpt).

The actual sanj parity baseline (266771, num_layers=54) still sits at
iter1=2.898, hetero 266780 at iter1=2.757 -- same pattern as prior runs,
no missing-norm regression.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…der rank (#31)

Replace the per-LLM-lane Energon iterators owned by each encoder DP rank with
a single multiplexed iterator whose worker pool is sized
``args.num_workers * lanes_per_encoder``. Samples are routed back to their
owning LLM lane using the producing worker's ``WorkerConfig.global_worker_id``,
which the MIMO multimodal encoder now stamps onto every batch when
``attach_provenance=True``.

At scale (encoder_dp small relative to llm_dp), the per-lane construction path
issues ``lanes_per_encoder × num_workers`` shard-open events at iterator
creation; collapsing to one iterator per encoder rank cuts the open-burst by
``lanes_per_encoder``× and avoids the previous workaround scripts that
staggered loader construction across encoder ranks.

The reshape preserves bit-wise sample parity with the per-lane path:

- ``global_workers = world_size * num_workers`` is invariant under
  ``(world_size, num_workers) → (world_size/k, num_workers*k)``.
  ``WebdatasetSharder.split_samples_to_workers`` partitions shards by global
  worker index over ``global_workers``, so equal global_worker_ids ⇒ equal
  shards in equal order.
- ``WorkerConfig.worker_seed`` hashes only ``(global_worker_id, seed_offset)``
  (see ``megatron/energon/worker.py``); ``seed_offset`` is unchanged.
- The routed-iterator's worker W on encoder rank E has
  ``global_worker_id = E * (num_workers * lanes_per_encoder) + W``, which
  equals the per-lane worker w on lane L=E*lanes_per_encoder + W//num_workers,
  w = W%num_workers.
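
The index algebra in the last bullet can be checked with a short sweep. This sketch assumes the per-lane path numbers workers as `lane * num_workers + w`; all function names are hypothetical and independent of the Energon code.

```python
def routed_global_worker_id(E, W, num_workers, lanes_per_encoder):
    """Global id of worker W in the single routed iterator on encoder rank E."""
    return E * (num_workers * lanes_per_encoder) + W


def owning_lane_and_worker(E, W, num_workers, lanes_per_encoder):
    """Map routed worker W back to (LLM lane, per-lane worker index)."""
    return E * lanes_per_encoder + W // num_workers, W % num_workers


def per_lane_global_worker_id(lane, w, num_workers):
    """Assumed numbering in the per-lane path: rank-major, worker-minor."""
    return lane * num_workers + w


def ids_match(encoder_dp, llm_dp, num_workers):
    """True if both layouts enumerate exactly the same global worker ids."""
    lanes_per_encoder = llm_dp // encoder_dp
    seen = []
    for E in range(encoder_dp):
        for W in range(num_workers * lanes_per_encoder):
            lane, w = owning_lane_and_worker(E, W, num_workers, lanes_per_encoder)
            routed = routed_global_worker_id(E, W, num_workers, lanes_per_encoder)
            if routed != per_lane_global_worker_id(lane, w, num_workers):
                return False
            seen.append(routed)
    # Every global worker id must appear exactly once.
    return sorted(seen) == list(range(llm_dp * num_workers))
```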

``test_hetero_energon.py`` adds unit tests for ``_route_samples_to_lanes``
(round-robin fill, surplus FIFO, lane-offset shift, pull-budget overflow,
out-of-range worker id, missing provenance) plus an algebraic parity test
asserting the global_worker_id equivalence above and a global_workers-invariant
sweep across (encoder_dp, llm_dp, num_workers) shapes.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…o MIMO VLM (#32)

1) Dynamic-resolution RADIO ViT
   Patchify each image at its native aspect ratio with a token budget,
   interleave per-tile class tokens, run THD-packed attention through the
   transformer block. Line-for-line equivalent to sanj's pre-vlm-05
   RADIOViTModel.forward dynres branch.

   - megatron/core/models/vision/radio.py: add dynamic_resolution kwarg,
     imgs_sizes + packed_seq_params forward kwargs, per-tile apply_pos_enc,
     interleaved CLS-token concat, packed_seq_params.cu_seqlens shift.
   - examples/mimo/data/energon_multimodal_provider.py: per-image n_tokens
     under dynamic_resolution; surface imgs_sizes + PackedSeqParams.
   - examples/mimo/model_providers/nemotron_moe_vlm.py: RADIOEncoderWrapper
     accepts dynres kwargs; interleaved-CLS removal mask; per-tile
     pixel-shuffle.
   - examples/mimo/training/hetero/step.py: PackedSeqParams cu_seqlens /
     max_seqlen tensors moved to CUDA before the encoder forward (TE THD
     attention hangs on H2D-sync of these tensors otherwise).

2) RADIO encoder final_layernorm parity with sanj
   Sanj's --mtp-num-layers 0 propagates to vision_config.mtp_num_layers via
   core_transformer_config_from_args. In TransformerBlock.has_final_layernorm_in_this_stage,
   the else-branch (mtp_num_layers is not None) allocates final_layernorm
   when the last decoder layer is in this stage — independent of
   post_process. Without this, hetero's frozen RADIO output magnitude was
   ~150x larger than sanj's (5344 vs 35.75) because the final LN never
   ran. The downstream projection + LLM were trained against the
   LN-normalized magnitude (gamma=1 / bias=0 from ckpt).

   - examples/mimo/model_providers/nemotron_moe_vlm.py:radio_vision_config
     sets config.mtp_num_layers = 0.
   - examples/mimo/utils/model_helpers.py:load_nemotron_vlm_ckpt_hetero
     switches dist_checkpointing.load to StrictHandling.RETURN_ALL and
     raises on any non-extra_state key the model requests but the ckpt
     does not have (mcore's "unexpected" set, opposite of PyTorch's
     "missing"). Weaker modes silently keep random-init values and
     masked this bug.
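
A minimal sketch of the stricter key check, written against plain key collections rather than the real dist_checkpointing return values. The helper names and the `extra_state` substring filter are assumptions for illustration.

```python
def unexpected_non_extra_state(requested_keys, ckpt_keys):
    """Keys the model requests that the checkpoint lacks, ignoring extra_state."""
    available = set(ckpt_keys)
    return sorted(
        key
        for key in requested_keys
        if key not in available and "extra_state" not in key
    )


def raise_on_unexpected(requested_keys, ckpt_keys):
    """Fail loudly instead of silently keeping random-init values."""
    bad = unexpected_non_extra_state(requested_keys, ckpt_keys)
    if bad:
        raise KeyError(f"checkpoint is missing requested keys: {bad}")
```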

Verified at iter 1, GBS=8 (3-node hetero + 2-node sanj parity pair):
  vision_proj  hetero absmax=584  sanj absmax=584  cos=1.000010
  combined_embeddings  cos=1.000011
  iter-1 lm_loss  hetero=2.756  sanj=2.750  Δ=+0.22%
  iter-2/3 lm_loss  rel Δ < 1.1%

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- args.py: new --tensorboard-dir flag.
- logging.py: HeteroTrainingLogger creates a torch.utils.tensorboard
  SummaryWriter on the language logging rank when --tensorboard-dir is set.
  Emits the same scalar keys Megatron's standard training_log uses (lm loss,
  learning-rate, grad-norm, batch-size, loss-scale, num-zeros) plus
  iteration-time-ms, both per-iter and "vs samples", so TB plots overlay
  cleanly against the reference run's logs.
- sbatch_hetero_parity_100step.sh: pass --tensorboard-dir "${RUN_DIR}/tensorboard"
  so parity runs auto-log to a per-run TB dir.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #33's MoE-tracker block gated on args.num_experts +
args.moe_router_load_balancing_type, but hetero stores those as
args.num_moe_experts and on the TransformerConfig (not args). The gate
never fired, so seq_load_balancing_loss never reached TB.

Fix: gate on args.num_moe_experts; hardcode track_names to
["seq_load_balancing_loss"] (the only LB type Nemotron6-MoE uses);
compute num_moe_layers from args.hybrid_layer_pattern's E-count so the
per-iter average over MoE layers matches sanj's training_log output
(54L pattern has 27 'E' layers, not 54).
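
Counting MoE layers from the pattern string can be sketched as below. The example pattern is made up for illustration and is not the 54-layer Nemotron pattern; the helper names are hypothetical.

```python
def count_moe_layers(hybrid_layer_pattern: str) -> int:
    """Number of MoE layers, taken as the count of 'E' symbols in the pattern."""
    return hybrid_layer_pattern.count("E")


def average_over_moe_layers(per_layer_losses, hybrid_layer_pattern):
    """Average a tracked aux loss over MoE layers only, not over all layers."""
    num_moe_layers = count_moe_layers(hybrid_layer_pattern)
    if num_moe_layers == 0:
        raise ValueError("pattern contains no MoE layers")
    return sum(per_layer_losses) / num_moe_layers


# Hypothetical 6-layer pattern with 3 MoE layers.
example_pattern = "MEM*EE"
avg = average_over_moe_layers([0.5, 1.0, 1.5], example_pattern)
```

Dividing by the total layer count instead (6 here) would halve the reported value, which is the mismatch the fix avoids.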

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
#35)

When encoder_dp < llm_dp, the routed encoder iterator merges N per-lane
batches into one encoder forward via _combine_encoder_batches ->
_concat_nested_tensors. After PR #32 (dynres) added a PackedSeqParams
dataclass into modality_inputs.images.<encoder>.packed_seq_params, two
bugs in this merge path surface together on multi-lane configurations
(e.g. 9n GBS=192: ENCODER_DP=8, LLM_DP=32, lanes_per_encoder=4):

1) TypeError: cannot concatenate encoder batch value of type
   PackedSeqParams. PR #31 only handled torch.Tensor and dict; the
   dataclass fell through to the final raise.

2) RuntimeError: Sizes of tensors must match except in dimension 0.
   Per-lane packed image buffers have shape (1, T_lane, C) -- dim 0 is
   the constant lane batch and T_lane varies. PR #31's blanket
   torch.cat(present, dim=0) requires every other dim to match and
   fails on the T axis.

The combination of (1) + (2) produced the 0-15-silent / 16-71-NCCL-
timeout cascade observed on the 9n parity sbatch: encoder rank 0's
loader worker died in torch.cat, lanes 0-3 stalled at the bridge, the
remaining LLM ranks reached the DDP allgather and timed out after 10
minutes.

Fix:
  - _concat_packed_seq_params merges N per-lane PackedSeqParams into
    one set covering the merged flat buffer. cu_seqlens_{q,kv}[_padded]
    concatenate with running offset = sum of prior lanes'
    cu_seqlens_q[-1] (the same offset-shift rule used in
    megatron.energon.task_encoder.multimodal.encoder.py); max_seqlen
    takes element-wise max; total_tokens sums. qkv_format,
    local_cp_size, cp_group are asserted equal across lanes. seq_idx
    is left to PackedSeqParams.__post_init__.
  - _concat_first_varying_dim concatenates plain tensors along the
    first dimension whose size differs across lanes, defaulting to
    dim 0 when all shapes agree. This handles (1, T_lane, C) packed
    image buffers (dim 1) and (N_images_lane, 2) imgs_sizes (dim 0)
    uniformly without sibling-key context, and preserves the prior
    behavior on non-dynres batches.
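
The running-offset rule for merging per-lane `cu_seqlens` can be sketched with plain lists. This is an illustration of the offset arithmetic only, not the actual `_concat_packed_seq_params`; the function name is hypothetical.

```python
def merge_cu_seqlens(per_lane_cu_seqlens):
    """Concatenate cumulative sequence lengths with a running offset.

    Each lane's list starts at 0; the offset added to lane k is the sum of
    the final entries (total tokens) of lanes 0..k-1.
    """
    merged = [0]
    offset = 0
    for cu in per_lane_cu_seqlens:
        if not cu or cu[0] != 0:
            raise ValueError("each cu_seqlens must start with 0")
        merged.extend(offset + value for value in cu[1:])
        offset += cu[-1]
    return merged


lane_a = [0, 3, 7]   # two sequences of lengths 3 and 4
lane_b = [0, 5]      # one sequence of length 5
merged = merge_cu_seqlens([lane_a, lane_b])
max_seqlen = max(4, 5)               # max over the per-lane maxima
total_tokens = lane_a[-1] + lane_b[-1]  # total tokens sum across lanes
```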

Verified by running the existing 9n GBS=192 parity sbatch end-to-end:
30-iter smoke reaches steady-state iter time around 4.5-6s with a
clean lm-loss trajectory; previously the same sbatch crashed inside
_combine_encoder_batches before iteration 1.
…ge (#36)

PR #35 introduced ``_concat_first_varying_dim`` that selected the concat
dim at runtime by looking for the first dim whose size differed across
per-lane tensors, falling back to dim 0 when no dim varied. That worked in
the common case but fails when two participating lanes happen to produce
identically-shaped packed image buffers in the same step: the (1, T, C)
buffer falls back to dim-0 cat and becomes (2, T, C) instead of
(1, 2T, C). RADIO then splits the buffer using imgs_sizes (correctly
cat'd to (2N, 2)), and its assertion that ``sum(seq_lens)`` matches
``x.shape[1]`` fails with an exact 2x ratio, visible in the 1-node
standalone smoke at lanes_per_encoder=16 as

  AssertionError: 15984 != 7992 at radio.py:235

In production hetero training, the rank dies silently in encoder forward;
its served LLM lanes stall on bridge recv; the cluster cascades into a
600 s NCCL watchdog timeout that looked like a different bug. The
probability of two lanes producing identical (1, T, C) per step grows
with ``lanes_per_encoder``: rare at 4 (9n), common at 16 (33n). All
"33n hangs" we chased were instances of this.

Replace the runtime inference with a schema-aware merger that knows the
fixed structure of ``modality_inputs``:

  packed image buffer (1, T_lane, C)  -> torch.cat dim 1
  imgs_sizes          (N_images,  2)  -> torch.cat dim 0
  packed_seq_params   PackedSeqParams -> _concat_packed_seq_params

Anything unrecognized raises a loud ``TypeError`` so a future schema
change has to be handled in ``_merge_encoder_inputs`` rather than
silently miscompiled by a heuristic.
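
The difference between the two strategies can be shown on shapes alone. This sketch models tensors as shape tuples; the key names, the channel size of 4, and all function names are hypothetical, and only T=7992 is taken from the failure above.

```python
def first_varying_dim(shapes):
    """Heuristic: first dim whose size differs across lanes, else dim 0."""
    for dim in range(len(shapes[0])):
        if len({shape[dim] for shape in shapes}) > 1:
            return dim
    return 0


def cat_shape(shapes, dim):
    """Resulting shape of concatenating along `dim`."""
    out = list(shapes[0])
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


# Schema-aware alternative: the concat dim is fixed per key.
CONCAT_DIM = {"packed_images": 1, "imgs_sizes": 0}


def schema_cat_shape(key, shapes):
    if key not in CONCAT_DIM:
        raise TypeError(f"unrecognized encoder input: {key}")
    return cat_shape(shapes, CONCAT_DIM[key])


# Two lanes that happen to produce identical packed buffers.
same = [(1, 7992, 4), (1, 7992, 4)]
heuristic_result = cat_shape(same, first_varying_dim(same))
schema_result = schema_cat_shape("packed_images", same)
```

With identical lane shapes the heuristic stacks along dim 0, while the schema-aware path keeps the packed layout along dim 1.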

Validated end-to-end:

  * 1-node standalone (lanes_per_encoder=16, 200 steps): all 8 ranks
    complete with no AssertionError, no hang. Previously failed at
    step 27 on rank 5.
  * 9n GBS=192 production smoke: 25/25 iters, ~5.5 s/iter steady state.
  * 17n GBS=384 production smoke: 25/25 iters, ~6.2 s/iter steady state.
  * 33n GBS=768 production smoke: 25/25 iters, ~7.0 s/iter steady state
    (matches the scaling-study report's 33n target of 7.11 s/iter).
    Previously hung at iter 5 across multiple attempts.
… Sanjeev LR schedule) (#37)

* NMFW-464: production hetero scaling sbatches at 33n / 68n / 100n

Adds three production-grade hetero MIMO Nemotron6-MoE VLM training sbatches
that mirror Sanjeev's pretrain_3b_nano_vlm_sota_90t_10v.sh schedule:

  * sbatch_hetero_prod_gbs768_33n_ep8.sh — 33 nodes (1 enc + 32 LLM)
  * sbatch_hetero_prod_gbs768_68n_ep8.sh — 68 nodes (4 enc + 64 LLM)
  * sbatch_hetero_prod_gbs768_100n.sh    — 100 nodes (4 enc + 96 LLM)

Pinned from Sanjeev's baseline:
  - TRAIN_SAMPLES=122070313
  - LR_WARMUP_SAMPLES=1024000
  - LR_DECAY_SAMPLES = TRAIN_SAMPLES - LR_WARMUP_SAMPLES = 121046313
  - LR_WSD_DECAY_SAMPLES=18310547
  - LR_WSD_DECAY_STYLE=minus_sqrt
  - PACKING_BUFFER_SIZE=128
  - NUM_WORKERS=1
  - LOG_INTERVAL=100, SAVE_INTERVAL=1000
  - --load-nemotron-checkpoint pointing at sasatheesh iter_1000
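
The derived schedule value above follows from simple arithmetic, sketched here with the pinned numbers. The `wsd_fraction` variable is added for illustration and is not a setting in the sbatch scripts.

```python
TRAIN_SAMPLES = 122_070_313
LR_WARMUP_SAMPLES = 1_024_000
LR_WSD_DECAY_SAMPLES = 18_310_547

# The decay budget is whatever remains after warmup.
LR_DECAY_SAMPLES = TRAIN_SAMPLES - LR_WARMUP_SAMPLES

# Illustrative only: share of the total budget spent in the WSD decay tail.
wsd_fraction = LR_WSD_DECAY_SAMPLES / TRAIN_SAMPLES
```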

Deviations from Sanjeev's baseline:
  - LLM_EP=8 (vs Sanjeev's EP=16)
  - Hetero topology TP=2 (vs Sanjeev's TP=4); explicit encoder grid
  - MOE_ROUTER_FORCE_LOAD_BALANCING=0 (natural seq_aux_loss)
  - No MTP layers
  - Wall time 4h (Sanjeev's; restartable from save-interval-1000 checkpoints)

These sbatches are derived from the scaling-study templates
(sbatch_hetero_parity_gbs768_{33n_ep8,68n_ep8,100n}.sh on the
ykarnati/nmfw-464-encoder-stall-profiling branch). The timeline-profile
instrumentation and the num-distributed-optimizer-instances flag from
that branch are dropped here — they are debug-only and not present in
the production codebase on ykarnati/nmfw-464-nemotron-vlm-with-hetero-parallel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* NMFW-464: drop obsolete early-development 9n/8n smoke sbatches

Removes five sbatch scripts that were only used during early hetero MIMO
development and have no external references in the repo (no docs, no
Python code, no CI). They are superseded by the new production sbatches
(33n/68n/100n) added in this PR, plus the existing parity templates
sbatch_hetero_parity_gbs192.sh and sbatch_hetero_parity_gbs32.sh.

Removed:

  * sbatch_hetero_nemotron_54l_hel_9n.sh
      30-iter 9n smoke test. Smoke-only; no production use.
  * sbatch_hetero_nemotron_54l_hel_9n_text_only.sh
      text-only data-blend smoke test. exec'd _9n.sh — orphaned by its
      removal.
  * sbatch_hetero_nemotron_54l_hel_9n_text_vision.sh
      90/10 text-vision blend smoke test. exec'd _9n.sh — orphaned by
      its removal.
  * sbatch_hetero_nemotron_54l_hel_9n_parity.sh
      Standalone 9n Sanjeev-parity reproduction. Functionally superseded
      by sbatch_hetero_parity_gbs192.sh (also 9n + Sanjeev recipe, paired
      with sbatch_sanjeev_parity_gbs192.sh).
  * sbatch_mimo_nemotron_54l_hel_8n_text_only_llm.sh
      8n LLM-only text-only variant. exec'd _9n.sh — orphaned by its
      removal.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* NMFW-464: add concise README for the hetero sbatch directory

One table covering the 9n parity sbatch and the three production sbatches
(33n / 68n / 100n) with their topology, GBS, and purpose.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>