Add PP>1 support for LLM in colocated MIMO training (NMFW-19) #9
Draft
yashaswikarnati wants to merge 11 commits into
dc3b45e to 32e2cab
yashaswikarnati commented Apr 21, 2026

    raise ValueError(
        f"{name} PP must be 1 for ColocatedBridgeCommunicator, got {pp_size}"
    )
    # Src (encoder) must be PP=1: encode_and_communicate runs on every

Remove the verbose comment.
yashaswikarnati commented Apr 21, 2026

    # in TP/DP within those ranks.
    self._build_colocated_communicators()

    # Detect LLM PP>1 for three-phase colocated execution

Can this be made concise and moved into the RankRole build?
yashaswikarnati commented Apr 21, 2026

    # Apply colocated communication if configured (no-op when colocated_comms is empty)
    if self.colocated_comms:
        modality_embeddings = self._apply_colocated_comms(modality_embeddings)
    for modality_name, submodule in self.modality_submodules.items():

Can we reuse encode_and_communicate above?
yashaswikarnati commented Apr 21, 2026

    return output_tensor, partial(_loss_func, cached['loss_mask'])

    # Defer finalize until AFTER Phase 3. The inner PP schedule would call
    # ``config.finalize_model_grads_func`` at end-of-schedule, which runs

This comment carries useful information, but it can state more concisely why we do this.
yashaswikarnati commented Apr 21, 2026

    # then invoke the original finalize once after Phase 3 so the single
    # DP reduction covers both the LLM grads from Phase 2 and the encoder
    # grads from Phase 3.
    original_finalize = mimo_model.config.finalize_model_grads_func

Can this be cleanly managed by a context manager?
29739f8 to 5af7002
yashaswikarnati commented Apr 21, 2026

        self._pp_size = pp_size
        self._pp_group = MockProcessGroup(pp_rank, pp_size)

    @property

Why do we need this change? And the shape property?
yashaswikarnati commented Apr 21, 2026

    self.lm_has_pp = lang_info is not None and not (
        lang_info.is_first_stage and lang_info.is_last_stage
    )
    self.lm_is_first_pp_stage = lang_info is None or lang_info.is_first_stage

Where is self.lm_is_first_pp_stage used?
Three-phase execution for colocated encoder PP=1 + LLM PP>1:
- Phase 1: Encoder forward + bridge communicate on the full batch, with
all ranks participating in the collective.
- Phase 2: 1F1B LLM pipeline over microbatch slices of the detached
encoder embeddings.
- Phase 3: Encoder backward on the full batch, with the encoder gradient
broadcast from PP rank 0 to the other PP ranks first.
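The three phases above can be sketched on a toy scalar model in plain Python. This is only an illustration of the gradient flow, not the code in colocated_schedule.py: the function names, the scalar "encoder" and "LLM" weights, and the squared-error loss are all assumptions made for the example, and no pipeline parallelism or collectives are modeled.

```python
def three_phase_grads(w_enc, w_llm, xs, ys, num_microbatches):
    """Toy three-phase gradient computation (hypothetical, scalar model)."""
    if len(xs) % num_microbatches != 0:
        raise ValueError("batch size must be divisible by num_microbatches")

    # Phase 1: encoder forward on the full batch. The results are plain
    # values, standing in for the detached encoder embeddings.
    hidden = [w_enc * x for x in xs]

    # Phase 2: LLM forward/backward over microbatch slices of the detached
    # embeddings. Accumulate the LLM weight gradient and record the gradient
    # with respect to each embedding.
    grad_w_llm = 0.0
    grad_hidden = [0.0] * len(xs)
    mb_size = len(xs) // num_microbatches
    for mb in range(num_microbatches):
        for i in range(mb * mb_size, (mb + 1) * mb_size):
            residual = w_llm * hidden[i] - ys[i]
            grad_w_llm += residual * hidden[i]
            grad_hidden[i] = residual * w_llm

    # Phase 3: encoder backward on the full batch, driven by the embedding
    # gradients collected in Phase 2.
    grad_w_enc = sum(g * x for g, x in zip(grad_hidden, xs))
    return grad_w_enc, grad_w_llm


def loss(w_enc, w_llm, xs, ys):
    """End-to-end reference loss for the same toy model."""
    return sum(0.5 * (w_llm * w_enc * x - y) ** 2 for x, y in zip(xs, ys))
```

In this toy setting the encoder gradient from Phase 3 matches the gradient of the end-to-end loss, and it does not depend on how the batch is sliced into microbatches.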
Changes:
- MimoModel detects LLM PP>1 from module_to_grid_map and overrides the
language ModuleStageInfo so is_first_stage / is_last_stage reflect PP
position.
- MimoModel.forward routes non-first PP stages through
_forward_language_module using P2P hidden states and unwraps the dict
return for the schedule.
- _forward_all_modules accepts a precomputed encoder_embeddings dict to
skip encoder forward inside each LLM microbatch iteration.
- New encode_and_communicate() helper runs encoder forward + bridge
transform; used by Phase 1 and reused by the 3-phase schedule.
- colocated_schedule.py implements colocated_forward_backward_with_pp
which drives the three phases and broadcasts encoder gradients.
Tests:
- test_mimo_colocated_pp.py: fan-in, equal-DP, and grad-accumulation
cases at LLM PP=2 / encoder PP=1 on 8 GPUs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The three-phase colocated schedule used to let the inner 1F1B PP schedule invoke ``config.finalize_model_grads_func`` at end-of-schedule. At that point the encoder has zero grads (Phase 3 has not run), so its DDP ``finish_grad_sync`` all-reduces zeros and the subsequent Phase 3 encoder grads stay local to each rank; Adam then steps on un-reduced encoder grads and diverges from an equal-DP reference.

``finish_grad_sync`` is not idempotent, so a second post-Phase-3 call is unsafe; instead, swap in a no-op during the PP schedule and invoke the user-provided finalize once after Phase 3 so a single DP reduction covers both LLM (Phase 2) and encoder (Phase 3) grads.

Tests: extend ``test_mimo_colocated_pp.py`` with a real post-step weight oracle that compares the PP>1 dist run against an equal-DP PP=1 reference (identity bridge). Adds PP-aware LLM weight reshaping so both models start from identical state, runs one Adam step on each via their respective schedules, then asserts shard-wise encoder equality within bf16 tolerance. Parametrized for ``num_mb == pp`` and ``num_mb > pp`` to cover single-microbatch-per-stage and grad-accumulation-across-views cases. Existing smoke tests also gain ``grad_norm > 0`` assertions and params-changed snapshots to catch silently-zeroed encoder grads.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The PP>1 schedule orchestrates the LLM pipeline; the bridge only needs src (encoder) PP=1 since encode_and_communicate runs on every rank synchronously. For fan-in, gather groups are keyed by src position so each rank lands in exactly one group regardless of its llm_pp index; the EQUAL path does no collective at all. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The ValueError message is self-documenting; schedule rationale belongs in colocated_schedule.py, not the communicator. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
RankRole.colocated now accepts the grid map and derives per-module PP stage info from each grid's pp group, removing the post-build mutation of self.role.modules from MimoModel.__init__. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
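A minimal sketch of deriving first/last-stage flags from a PP group's rank and size, assuming the usual convention that rank 0 is the first stage and rank size-1 is the last. `StageInfo` and `stage_info_from_pp` are hypothetical stand-ins for illustration, not the actual ModuleStageInfo or RankRole code; `has_pp` mirrors the `lm_has_pp` expression from the reviewed hunk.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StageInfo:
    """Hypothetical stand-in for a per-module PP stage record."""
    is_first_stage: bool
    is_last_stage: bool


def stage_info_from_pp(pp_rank: int, pp_size: int) -> StageInfo:
    """Derive stage flags from a rank's position in its PP group."""
    if not 0 <= pp_rank < pp_size:
        raise ValueError(f"pp_rank {pp_rank} out of range for pp_size {pp_size}")
    return StageInfo(
        is_first_stage=(pp_rank == 0),
        is_last_stage=(pp_rank == pp_size - 1),
    )


def has_pp(info: Optional[StageInfo]) -> bool:
    """True only when the module is present and spans more than one stage."""
    return info is not None and not (info.is_first_stage and info.is_last_stage)
```

With PP=1 the single rank is both first and last stage, so `has_pp` is false; with PP=2 it is true on both ranks.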
Collapse the inlined modality-forward block (duplicating encode_and_communicate) into a single call, dropping redundant per-modality debug logs. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Collapse the finish_grad_sync exposition into three lines that keep the why (encoder grads don't exist yet) without the DDP-internal detail. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Introduce _deferred_finalize contextmanager that yields the original callable, so the post-Phase-3 invocation keeps access to it while the swap/restore logic is encapsulated. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
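The swap/restore pattern described here could look roughly like the following. This is a sketch under assumptions, not the PR's `_deferred_finalize`: the `config` object is a plain namespace, the finalize callable takes no required arguments, and the original is invoked inside the `with` block after a stand-in for Phase 3.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def deferred_finalize(config):
    """Swap in a no-op finalize and yield the original callable."""
    original = config.finalize_model_grads_func
    config.finalize_model_grads_func = lambda *args, **kwargs: None
    try:
        yield original
    finally:
        # Restore the user-provided hook even if the schedule raises.
        config.finalize_model_grads_func = original


calls = []


def user_finalize(*args, **kwargs):
    calls.append("finalize")


config = SimpleNamespace(finalize_model_grads_func=user_finalize)

with deferred_finalize(config) as finalize:
    # Stands in for the inner PP schedule's end-of-schedule call: a no-op here.
    config.finalize_model_grads_func()
    # Stands in for the Phase 3 encoder backward.
    calls.append("phase3_encoder_backward")
    # Single finalize after Phase 3.
    finalize()
```

The `try`/`finally` keeps the restore logic in one place, which is the main reason to prefer a context manager over a manual swap.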
…PR-10 oracle infra

Delete the smoke test (run_colocated_pp_test) and four smoke test methods: the PP weight oracle already subsumes them.

Delete the DataIterator, _build_pp_oracle_model, grid/spec helpers, param-copy helpers, shard-match helper, and the shared-batch generator; they are 1:1 duplicates of infra already exported from test_mimo_1f1b_schedule.py and test_mimo_colocated_correctness.py.

Keep only the new piece specific to PP>1: _copy_llm_params_pp_aware, with the unused same-TP fallback all-gather branch removed and the docstring tightened. _run_pp_weight_oracle is rewritten as a short driver built on the imported helpers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Drop unused lm_is_first_pp_stage on MimoModel and tighten the PP-flag
derivation; lm_has_pp is the only flag actually consumed by the
three-phase schedule.
- Drop MockGrid.shape + _pp_rank/_pp_size in test_mimo_model: unused
after _colocated switched to grid.get_pg('pp') for PP-stage derivation.
- Tighten the Phase-1 detach comment in colocated_schedule to a single
"why".
- Add PP-aware LLM weight parity check
(_assert_llm_weights_match_pp_aware) alongside the existing encoder
check in test_mimo_colocated_pp; dist PP>1 LLM shards must match the
PP=1 ref via the same layer-index remap used for init.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
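One way to picture the layer-index remap between a PP>1 shard and a PP=1 reference is sketched below. It assumes layers are split evenly and contiguously across PP stages and that parameter names carry the layer index right after a `layers` component; both the helper names and the example parameter name are hypothetical, and the actual remap in the test may differ.

```python
def local_to_global_layer(local_idx, pp_rank, pp_size, num_layers):
    """Map a stage-local layer index to its index in the PP=1 model."""
    if num_layers % pp_size != 0:
        raise ValueError("assumes num_layers is divisible by pp_size")
    layers_per_stage = num_layers // pp_size
    if not 0 <= local_idx < layers_per_stage:
        raise IndexError(f"local layer index {local_idx} out of range")
    return pp_rank * layers_per_stage + local_idx


def remap_param_name(name, pp_rank, pp_size, num_layers):
    """Rewrite the layer index in a dotted parameter name, if it has one."""
    parts = name.split(".")
    if "layers" not in parts:
        return name  # e.g. embeddings or final norm: no per-layer index
    pos = parts.index("layers") + 1
    parts[pos] = str(
        local_to_global_layer(int(parts[pos]), pp_rank, pp_size, num_layers)
    )
    return ".".join(parts)
```

For example, with 4 layers and PP=2, local layer 1 on PP rank 1 corresponds to layer 3 of the PP=1 reference.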
PR #10 replaced the ad-hoc gradient_reduce_div_factor DDP knob with calculate_per_token_loss=True on both sub-model configs plus a custom finalize hook that divides grads by the global valid-token count. The three-phase PP schedule now has to forward the schedule's total_num_tokens to the deferred finalize call, otherwise the hook's assertion fails and per-token normalization never happens on the encoder/LLM grads.

* _loss_func now returns the 3-tuple (local_sum, local_num_tokens, log_dict) contract the schedule expects when per-token loss is on.
* _deferred_finalize swaps the finalize hook with a capturing stub that records the num_tokens the inner schedule would have passed; after Phase 3, we invoke the original finalize with the captured value.

test_mimo_colocated_pp: adopt per-token-loss wiring, add PP broadcast

_wire_training_hooks from the PR #10 correctness test only all-reduces num_tokens over the LLM DP group. With LLM PP>1, non-last PP stages see num_tokens=0 from the inner schedule (loss runs only on the last stage), so the DP sum would land at N_last_stage instead of N_global and encoder/LLM grads would end up scaled differently per PP stage. _wire_pp_training_hooks broadcasts num_tokens from the last LLM PP rank first, then all-reduces across DP; every rank arrives at the same N_global.

The PP test also drops the removed gradient_reduce_div_factor kwarg, switches both models to fp32 / no-bias / no-dropout for exact comparison, and uses the 3-tuple loss shape on the ref forward path.
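The token-count issue can be illustrated with a small single-process simulation. This is not the hook code from the PR: the collectives are replaced by list operations on a hypothetical `tokens[pp_rank][dp_rank]` grid, and the function names are made up for the example.

```python
def dp_only_sum(tokens):
    """Simulate an all-reduce (sum) over the DP group within each PP stage."""
    return [[sum(stage)] * len(stage) for stage in tokens]


def pp_broadcast_then_dp_sum(tokens):
    """Broadcast counts from the last PP stage, then sum across DP."""
    last_stage = tokens[-1]
    broadcasted = [list(last_stage) for _ in tokens]
    return dp_only_sum(broadcasted)


# PP=2, DP=2. Only the last PP stage computes the loss, so only it sees
# non-zero token counts from the inner schedule.
tokens = [
    [0, 0],  # PP stage 0
    [7, 5],  # PP stage 1 (last)
]

dp_only = dp_only_sum(tokens)
fixed = pp_broadcast_then_dp_sum(tokens)
```

In this simulation the DP-only reduction leaves the two PP stages with different counts, while broadcasting from the last stage first gives every rank the same total of 12.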
5af7002 to ed63e63
Summary
Stacked on top of NMFW-17 / PR A. This PR adds PP>1 support for the language model in colocated MIMO training. The encoder stays at PP=1 on all ranks; the LLM runs a 1F1B pipeline over microbatches of the precomputed encoder embeddings.
Base: `ykarnati/nmfw-17-colocated-bridge-pp1` (PR A) — review only the delta introduced here.
Three-phase execution
Code changes
Test plan
```bash
uv run python -m torch.distributed.run --nproc_per_node=8 -m pytest tests/unit_tests/models/test_mimo_colocated_pp.py -v
```
Review notes
This PR layers on top of PR A — please review the bridge primitive there first. All changes here are PP>1 specific: new schedule file, PP-aware hunks in `base.py`, and the PP test.
🤖 Generated with Claude Code