Add PP>1 support for LLM in colocated MIMO training (NMFW-19) #9

Draft
yashaswikarnati wants to merge 11 commits into `ykarnati/nmfw-17-colocated-bridge-pp1` from `ykarnati/nmfw-19-colocated-pp-support`

@yashaswikarnati
Owner

Summary

Stacked on top of NMFW-17 / PR A. This PR adds PP>1 support for the language model in colocated MIMO training. The encoder stays at PP=1 on all ranks; the LLM runs a 1F1B pipeline over microbatches of the precomputed encoder embeddings.

Base: `ykarnati/nmfw-17-colocated-bridge-pp1` (PR A) — review only the delta introduced here.

Three-phase execution

  1. Encoder forward + communicate — full batch across all ranks via `encode_and_communicate()`; collective TP/DP transform happens once.
  2. LLM 1F1B pipeline — encoder embeddings are detached and sliced into microbatches; the stock `forward_backward_pipelining_without_interleaving` runs.
  3. Encoder backward — encoder grad is broadcast from PP rank 0 to the other PP ranks so all ranks participate in the encoder backward collective; encoder backward runs on the full batch.
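
The three phases can be illustrated on a single process with a toy scalar model. This is only a sketch: the linear "encoder" and "LLM", the hand-written gradients, and the names `three_phase_grads` and `end_to_end_grads` are assumptions made for illustration, not code from this PR, and no pipeline or collective communication is modelled. It shows that accumulating the gradient with respect to the detached embeddings over microbatches, followed by a single encoder backward on the full batch, yields the same gradients as differentiating the full-batch loss directly.

```python
def three_phase_grads(xs, ys, w_enc, w_llm, num_microbatches):
    """Toy single-process analogue of the three-phase schedule.

    encoder: e_i = w_enc * x_i            (runs once on the full batch)
    llm:     loss = sum_i (w_llm * e_i - y_i) ** 2
    Assumes len(xs) is divisible by num_microbatches.
    """
    # Phase 1: encoder forward on the full batch; the result is treated as detached.
    embeddings = [w_enc * x for x in xs]

    # Phase 2: LLM forward/backward per microbatch, accumulating the LLM weight
    # grad and recording the grad w.r.t. each detached embedding.
    mb_size = len(xs) // num_microbatches
    grad_w_llm = 0.0
    grad_embeddings = [0.0] * len(xs)
    for mb in range(num_microbatches):
        for i in range(mb * mb_size, (mb + 1) * mb_size):
            residual = w_llm * embeddings[i] - ys[i]
            grad_w_llm += 2.0 * residual * embeddings[i]
            grad_embeddings[i] = 2.0 * residual * w_llm

    # Phase 3: one encoder backward on the full batch using the stored grads.
    grad_w_enc = sum(g * x for g, x in zip(grad_embeddings, xs))
    return grad_w_enc, grad_w_llm


def end_to_end_grads(xs, ys, w_enc, w_llm):
    """Reference: differentiate the full-batch loss directly."""
    grad_w_enc = sum(2.0 * (w_llm * w_enc * x - y) * w_llm * x for x, y in zip(xs, ys))
    grad_w_llm = sum(2.0 * (w_llm * w_enc * x - y) * w_enc * x for x, y in zip(xs, ys))
    return grad_w_enc, grad_w_llm


xs, ys = [1.0, 2.0, 3.0, 4.0], [0.5, 1.0, 2.0, 3.5]
split = three_phase_grads(xs, ys, w_enc=0.5, w_llm=2.0, num_microbatches=2)
reference = end_to_end_grads(xs, ys, w_enc=0.5, w_llm=2.0)
```

In the real schedule the Phase 2 gradient with respect to the embeddings exists only on the first LLM PP stage, which is why it is broadcast to the other PP ranks before Phase 3.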

Code changes

  • `MimoModel.__init__` detects LLM PP>1 from `module_to_grid_map` and overrides language `ModuleStageInfo` so stage flags reflect PP position.
  • `MimoModel.forward` routes non-first PP stages through `_forward_language_module` (receives hidden states via P2P, returns plain tensor instead of dict).
  • `_forward_all_modules` takes an optional `encoder_embeddings` dict to skip encoder forward per microbatch.
  • `encode_and_communicate()` encapsulates encoder forward + `_apply_colocated_comms`, used by Phase 1.
  • `colocated_schedule.py` implements `colocated_forward_backward_with_pp` with the three-phase driver, encoder-grad broadcast, and microbatch construction helpers.
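
As a rough illustration of how stage flags can follow PP position, the sketch below derives them from a rank's index in its PP group. `StageInfo`, `stage_info_from_pp`, and `has_pp` are hypothetical names standing in for the PR's `ModuleStageInfo` handling; only the first/last-stage logic is taken from the description above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageInfo:
    """Hypothetical stand-in for the PR's ModuleStageInfo."""
    is_first_stage: bool
    is_last_stage: bool


def stage_info_from_pp(pp_rank: int, pp_size: int) -> StageInfo:
    """Derive stage flags from a rank's position in its PP group."""
    if not 0 <= pp_rank < pp_size:
        raise ValueError(f"pp_rank {pp_rank} out of range for pp_size {pp_size}")
    return StageInfo(is_first_stage=pp_rank == 0, is_last_stage=pp_rank == pp_size - 1)


def has_pp(info: StageInfo) -> bool:
    """A module is pipelined unless this rank is both its first and last stage."""
    return not (info.is_first_stage and info.is_last_stage)
```

With PP=1 every rank is both first and last stage, so `has_pp` is false and the existing PP=1 path applies unchanged.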

Test plan

  • `tests/unit_tests/models/test_mimo_colocated_pp.py` — fan-in, equal-DP, and grad-accumulation cases at LLM PP=2 / encoder PP=1 on 8 GPUs.
  • All PR A tests continue to pass on this branch (no regressions in PP=1 path).

```bash
uv run python -m torch.distributed.run --nproc_per_node=8 -m pytest tests/unit_tests/models/test_mimo_colocated_pp.py -v
```

Review notes

This PR layers on top of PR A — please review the bridge primitive there first. All changes here are PP>1 specific: new schedule file, PP-aware hunks in `base.py`, and the PP test.

🤖 Generated with Claude Code

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/nmfw-19-colocated-pp-support branch from dc3b45e to 32e2cab Compare April 21, 2026 17:42
raise ValueError(
f"{name} PP must be 1 for ColocatedBridgeCommunicator, got {pp_size}"
)
# Src (encoder) must be PP=1: encode_and_communicate runs on every
Owner Author

remove verbose comment

Comment thread megatron/core/models/mimo/model/base.py Outdated
# in TP/DP within those ranks.
self._build_colocated_communicators()

# Detect LLM PP>1 for three-phase colocated execution
Owner Author

Can this be made concise and folded into the RankRole build?

Comment thread megatron/core/models/mimo/model/base.py Outdated
# Apply colocated communication if configured (no-op when colocated_comms is empty)
if self.colocated_comms:
modality_embeddings = self._apply_colocated_comms(modality_embeddings)
for modality_name, submodule in self.modality_submodules.items():
Owner Author

Can we reuse `encode_and_communicate` above?

return output_tensor, partial(_loss_func, cached['loss_mask'])

# Defer finalize until AFTER Phase 3. The inner PP schedule would call
# ``config.finalize_model_grads_func`` at end-of-schedule, which runs
Owner Author

This comment carries useful info but could be more concise: state briefly why we do this.

# then invoke the original finalize once after Phase 3 so the single
# DP reduction covers both the LLM grads from Phase 2 and the encoder
# grads from Phase 3.
original_finalize = mimo_model.config.finalize_model_grads_func
Owner Author

Can this be cleanly managed by a context manager?

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/nmfw-19-colocated-pp-support branch from 29739f8 to 5af7002 Compare April 21, 2026 22:36
self._pp_size = pp_size
self._pp_group = MockProcessGroup(pp_rank, pp_size)

@property
Owner Author

Why do we need this change, and the `shape` property?

Comment thread megatron/core/models/mimo/model/base.py Outdated
self.lm_has_pp = lang_info is not None and not (
lang_info.is_first_stage and lang_info.is_last_stage
)
self.lm_is_first_pp_stage = lang_info is None or lang_info.is_first_stage
Owner Author

Where is `self.lm_is_first_pp_stage` used?

yashaswikarnati and others added 11 commits April 22, 2026 03:06
Three-phase execution for colocated encoder PP=1 + LLM PP>1:
  - Phase 1: Encoder forward + bridge communicate on the full batch, with
    all ranks participating in the collective.
  - Phase 2: 1F1B LLM pipeline over microbatch slices of the detached
    encoder embeddings.
  - Phase 3: Encoder backward on the full batch, with the encoder gradient
    broadcast from PP rank 0 to the other PP ranks first.

Changes:
  - MimoModel detects LLM PP>1 from module_to_grid_map and overrides the
    language ModuleStageInfo so is_first_stage / is_last_stage reflect PP
    position.
  - MimoModel.forward routes non-first PP stages through
    _forward_language_module using P2P hidden states and unwraps the dict
    return for the schedule.
  - _forward_all_modules accepts a precomputed encoder_embeddings dict to
    skip encoder forward inside each LLM microbatch iteration.
  - New encode_and_communicate() helper runs encoder forward + bridge
    transform; used by Phase 1 and reused by the 3-phase schedule.
  - colocated_schedule.py implements colocated_forward_backward_with_pp
    which drives the three phases and broadcasts encoder gradients.

Tests:
  - test_mimo_colocated_pp.py: fan-in, equal-DP, and grad-accumulation
    cases at LLM PP=2 / encoder PP=1 on 8 GPUs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The three-phase colocated schedule used to let the inner 1F1B PP schedule
invoke ``config.finalize_model_grads_func`` at end-of-schedule. At that
point the encoder has zero grads (Phase 3 has not run), so its DDP
``finish_grad_sync`` all-reduces zeros and the subsequent Phase 3 encoder
grads stay local to each rank — Adam then steps on un-reduced encoder
grads and diverges from an equal-DP reference. ``finish_grad_sync`` is
not idempotent, so a second post-Phase-3 call is unsafe; instead, swap
in a no-op during the PP schedule and invoke the user-provided finalize
once after Phase 3 so a single DP reduction covers both LLM (Phase 2)
and encoder (Phase 3) grads.

Tests: extend ``test_mimo_colocated_pp.py`` with a real post-step weight
oracle that compares the PP>1 dist run against an equal-DP PP=1 reference
(identity bridge). Adds PP-aware LLM weight reshaping so both models
start from identical state, runs one Adam step on each via their
respective schedules, then asserts shard-wise encoder equality within
bf16 tolerance. Parametrized for ``num_mb == pp`` and ``num_mb > pp``
to cover single-microbatch-per-stage and grad-accumulation-across-views
cases. Existing smoke tests also gain ``grad_norm > 0`` assertions and
params-changed snapshots to catch silently-zeroed encoder grads.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
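
The ordering problem described in this commit can be reproduced with a small simulation. This is a sketch under simplifying assumptions: one scalar encoder gradient per DP rank, a mean all-reduce standing in for the DDP gradient sync, and hypothetical helper names (`dp_all_reduce_mean`, `run_step`). It shows that reducing before Phase 3 leaves each rank with its own local gradient, while a single reduction after Phase 3 gives every rank the same value.

```python
def dp_all_reduce_mean(per_rank_grads):
    """Simulate a DP all-reduce (mean) over one scalar grad per rank."""
    mean = sum(per_rank_grads) / len(per_rank_grads)
    return [mean] * len(per_rank_grads)


def run_step(local_encoder_grads, finalize_after_phase3):
    """Return the encoder grad each DP rank would hand to the optimizer."""
    grads = [0.0] * len(local_encoder_grads)  # encoder grads before Phase 3
    if not finalize_after_phase3:
        grads = dp_all_reduce_mean(grads)     # reduces zeros: has no effect
    # Phase 3: encoder backward accumulates each rank's local grad.
    grads = [g + local for g, local in zip(grads, local_encoder_grads)]
    if finalize_after_phase3:
        grads = dp_all_reduce_mean(grads)     # single reduction covers Phase 3
    return grads


early = run_step([1.0, 3.0], finalize_after_phase3=False)    # [1.0, 3.0]
deferred = run_step([1.0, 3.0], finalize_after_phase3=True)  # [2.0, 2.0]
```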
The PP>1 schedule orchestrates the LLM pipeline; the bridge only needs
src (encoder) PP=1 since encode_and_communicate runs on every rank
synchronously. For fan-in, gather groups are keyed by src position so
each rank lands in exactly one group regardless of its llm_pp index; the
EQUAL path does no collective at all.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

The ValueError message is self-documenting; schedule rationale belongs
in colocated_schedule.py, not the communicator.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

RankRole.colocated now accepts the grid map and derives per-module
PP stage info from each grid's pp group, removing the post-build
mutation of self.role.modules from MimoModel.__init__.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Collapse the inlined modality-forward block (duplicating
encode_and_communicate) into a single call, dropping redundant
per-modality debug logs.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Collapse the finish_grad_sync exposition into three lines that keep
the why (encoder grads don't exist yet) without the DDP-internal
detail.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Introduce _deferred_finalize contextmanager that yields the original
callable, so the post-Phase-3 invocation keeps access to it while the
swap/restore logic is encapsulated.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
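
A context manager of the kind this commit describes might look roughly like the following. It is a minimal sketch, not the PR's `_deferred_finalize`: the function name `deferred_finalize`, the `SimpleNamespace` config, and the `calls` list are stand-ins for illustration, and only the attribute name `finalize_model_grads_func` comes from the code under review.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def deferred_finalize(config):
    """Swap config.finalize_model_grads_func for a no-op and yield the original."""
    original = config.finalize_model_grads_func
    config.finalize_model_grads_func = lambda *args, **kwargs: None
    try:
        yield original
    finally:
        # Restore the hook even if the wrapped schedule raises.
        config.finalize_model_grads_func = original


calls = []
config = SimpleNamespace(
    finalize_model_grads_func=lambda *args, **kwargs: calls.append("finalize")
)

with deferred_finalize(config) as original_finalize:
    config.finalize_model_grads_func()  # what the inner PP schedule would call: no-op
    calls.append("phase3")              # stand-in for the encoder backward
    original_finalize()                 # single real finalize after Phase 3
```

Yielding the original callable keeps the swap and restore in one place while still letting the caller decide exactly when the real finalize runs.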
…PR-10 oracle infra

Delete the smoke test (run_colocated_pp_test) and four smoke test
methods: the PP weight oracle already subsumes them. Delete the
DataIterator, _build_pp_oracle_model, grid/spec helpers, param-copy
helpers, shard-match helper, and the shared-batch generator — they are
1:1 duplicates of infra already exported from
test_mimo_1f1b_schedule.py and test_mimo_colocated_correctness.py.

Keep only the new piece specific to PP>1: _copy_llm_params_pp_aware,
with the unused same-TP fallback all-gather branch removed and the
docstring tightened. _run_pp_weight_oracle is rewritten as a short
driver built on the imported helpers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

- Drop unused lm_is_first_pp_stage on MimoModel and tighten the PP-flag
  derivation; lm_has_pp is the only flag actually consumed by the
  three-phase schedule.
- Drop MockGrid.shape + _pp_rank/_pp_size in test_mimo_model: unused
  after _colocated switched to grid.get_pg('pp') for PP-stage derivation.
- Tighten the Phase-1 detach comment in colocated_schedule to a single
  "why".
- Add PP-aware LLM weight parity check
  (_assert_llm_weights_match_pp_aware) alongside the existing encoder
  check in test_mimo_colocated_pp; dist PP>1 LLM shards must match the
  PP=1 ref via the same layer-index remap used for init.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
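
The layer-index remap mentioned in this commit can be sketched as below. This assumes layers are split evenly and contiguously across PP stages and that parameter names contain a `layers.<index>.` segment; both are assumptions for illustration, and `global_layer_index` and `remap_param_name` are hypothetical helpers rather than the PR's `_copy_llm_params_pp_aware`.

```python
def global_layer_index(local_idx: int, pp_rank: int, num_layers: int, pp_size: int) -> int:
    """Map a stage-local layer index to its index in the PP=1 reference model.

    Assumes an even, contiguous split of layers across PP stages.
    """
    if num_layers % pp_size != 0:
        raise ValueError("this sketch only handles an even layer split")
    layers_per_stage = num_layers // pp_size
    if not 0 <= local_idx < layers_per_stage:
        raise ValueError(f"local_idx {local_idx} out of range")
    return pp_rank * layers_per_stage + local_idx


def remap_param_name(name: str, pp_rank: int, num_layers: int, pp_size: int) -> str:
    """Rewrite the first 'layers.<local>' segment of a parameter name to the global index."""
    parts = name.split(".")
    for i, part in enumerate(parts[:-1]):
        if part == "layers" and parts[i + 1].isdigit():
            local_idx = int(parts[i + 1])
            parts[i + 1] = str(global_layer_index(local_idx, pp_rank, num_layers, pp_size))
            break
    return ".".join(parts)
```

For example, with 4 layers and PP=2, local layer 0 on PP rank 1 corresponds to layer 2 of the PP=1 reference, so the same mapping can serve both for copying initial weights and for comparing them after the optimizer step.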
PR #10 replaced the ad-hoc gradient_reduce_div_factor DDP knob with
calculate_per_token_loss=True on both sub-model configs plus a custom
finalize hook that divides grads by the global valid-token count. The
three-phase PP schedule now has to forward the schedule's total_num_tokens
to the deferred finalize call, otherwise the hook's assertion fails and
per-token normalization never happens on the encoder/LLM grads.

* _loss_func now returns the 3-tuple (local_sum, local_num_tokens,
  log_dict) contract the schedule expects when per-token loss is on.
* _deferred_finalize swaps the finalize hook with a capturing stub that
  records the num_tokens the inner schedule would have passed; after
  Phase 3, we invoke the original finalize with the captured value.

test_mimo_colocated_pp: adopt per-token-loss wiring, add PP broadcast

_wire_training_hooks from the PR #10 correctness test only all-reduces
num_tokens over the LLM DP group. With LLM PP>1, non-last PP stages see
num_tokens=0 from the inner schedule (loss runs only on the last stage),
so the DP sum would land at N_last_stage instead of N_global and
encoder/LLM grads would end up scaled differently per PP stage.
_wire_pp_training_hooks broadcasts num_tokens from the last LLM PP rank
first, then all-reduces across DP — every rank arrives at the same
N_global. The PP test also drops the removed gradient_reduce_div_factor
kwarg, switches both models to fp32 / no-bias / no-dropout for exact
comparison, and uses the 3-tuple loss shape on the ref forward path.
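
The token-count issue described in this commit can be shown with a toy rank layout. The sketch below is a simulation only: the nested list `local[pp][dp]` and the two helper functions are assumptions for illustration and do not call any distributed API. It contrasts a DP-only all-reduce with a broadcast from the last PP stage followed by a DP all-reduce.

```python
def dp_only_totals(local_tokens):
    """All-reduce (sum) num_tokens over DP within each PP stage.

    local_tokens[pp][dp] is the token count seen by rank (pp, dp).
    """
    return [[sum(stage)] * len(stage) for stage in local_tokens]


def pp_broadcast_then_dp_totals(local_tokens):
    """Broadcast counts from the last PP stage along PP, then all-reduce over DP."""
    last_stage = local_tokens[-1]
    broadcast = [list(last_stage) for _ in local_tokens]
    return dp_only_totals(broadcast)


# LLM PP=2, DP=2: only the last stage computes the loss, so only it sees tokens.
local = [[0, 0], [120, 136]]
dp_only = dp_only_totals(local)             # [[0, 0], [256, 256]]
fixed = pp_broadcast_then_dp_totals(local)  # [[256, 256], [256, 256]]
```

In this toy layout the DP-only reduction leaves first-stage ranks with a different total from last-stage ranks, whereas broadcasting first gives every rank the same global count to divide gradients by.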
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/nmfw-19-colocated-pp-support branch from 5af7002 to ed63e63 Compare April 22, 2026 04:33