Skip to content

[Fork review] NMFW-17: dest CP>1 support on top of pr-a#11

Draft
yashaswikarnati wants to merge 7 commits into
ykarnati/nmfw-17-colocated-bridge-pp1from
ykarnati/nmfw-17-colocated-cp
Draft

[Fork review] NMFW-17: dest CP>1 support on top of pr-a#11
yashaswikarnati wants to merge 7 commits into
ykarnati/nmfw-17-colocated-bridge-pp1from
ykarnati/nmfw-17-colocated-cp

Conversation

@yashaswikarnati
Copy link
Copy Markdown
Owner

Fork-internal PR for reviewing the pr-c CP>1 diff. Base is pr-a (ykarnati/nmfw-17-colocated-bridge-pp1), so the diff shown here is only the CP-specific changes.

Commits (newest → oldest)

  • 7711ed28f2 — Review follow-ups: dim ordering guard + directed CP backward gradient oracle test
  • c89d14153d — Pass cp_group/tp_group to MimoModel in colocated e2e tests
  • 44cf308696 — Fix fan-out forward slice for CP>1 dest
  • 17d37b7a54 — Thread cp_group through PartitionAdapter sbhd path
  • eb8fbc82be — CP=2 E2E tests (fan-in and fan-out)
  • 503e4062e1 — Dest CP>1 support in ColocatedBridgeCommunicator

Design summary

  • Encoder (src) must stay CP=1; LLM (dest) may have CP>1
  • PartitionAdapter.shard uses index_select → backward is zero-pad. The bridge's backward intra-CP all_reduce(SUM) on grad_output reconstructs the full-sequence grad before the direction-specific op
  • Fan-out gather groups split per (src_dp, dest_tp, cp_idx) so every world rank lands in exactly one subgroup
  • rank_to_dest_pos keeps canonical cp=0 ranks per (dp, tp); full (dp, tp, cp) is in rank_to_dest_coords
  • Reuses dest_grid.get_pg('cp') — no new PG created

Test status (all green on 8x H100)

  • test_mimo_colocated_communicator.py (17 tests, incl. new CP=2 cases)
  • test_mimo_colocated_correctness.py (3 CP=1 cases + 1 directed CP backward oracle)
  • test_mimo_colocated_e2e.py::test_colocated_fan_in_cp2_8gpu
  • test_mimo_colocated_e2e.py::test_colocated_fan_out_cp2_8gpu
  • test_mimo_colocated_e2e.py::test_colocated_fan_in_8gpu (CP=1 regression — still passes)

Review history

Went through three independent reviews. The remaining follow-ups were:

  1. Dim ordering validation (cp must appear before dp in dim_names) — added in 7711ed28f2
  2. Directed gradient oracle for intra-CP all_reduce — added in 7711ed28f2 (test_cp_backward_reduces_partial_seq_grads)

🤖 Generated with Claude Code

yashaswikarnati and others added 7 commits April 21, 2026 15:13
Allows the LLM (dest) side to use context parallelism while the encoder
(src) stays CP=1.

Correctness path:

- PartitionAdapter.shard shards sequence via index_select whose autograd
  adjoint is a scatter (zero-pad). So the grad flowing into the bridge's
  backward is already zero-padded along the sequence dimension — each
  dest CP rank holds only its own sequence chunks with zeros elsewhere.
  To return a full-sequence gradient to the encoder we run an intra-CP
  all_reduce(SUM) on grad_output before the direction-specific op.
  After the reduction every CP sibling holds the same full-sequence
  gradient, and the downstream narrow (fan-in) or all-gather (fan-out)
  proceeds exactly as in the CP=1 case.

- The intra-CP process group is reused from dest_grid.get_pg('cp'); no
  new PG is created by the communicator.

- Fan-out gather groups are now split per (src_dp_idx, dest_tp_idx,
  cp_idx) rather than pooled per (src_dp_idx, dest_tp_idx). This is
  required because new_subgroups_by_enumeration demands every world
  rank land in exactly one subgroup — a single pooled group would
  leave cp>0 ranks orphaned. After the CP reduction each cp-level's
  all-gather produces the same full-batch gradient.

- _build_rank_mappings filters dest tp_groups on cp_coord==0 when
  advancing dp_idx (mirror of the existing pp_coord==0 filter), so
  dp_idx correctly indexes the DP dimension regardless of CP size.
  Full (dp, tp, cp) coords are stored in rank_to_dest_coords for every
  pp=0 dest rank.

Validation:
- Encoder (src) CP must remain 1. Rejected in _validate_grids.
- Dest CP>1 is allowed; dest_cp_pg is populated from dest_grid.

Tests:
- New parameterized rank-mapping cases for CP=2 (fan-in and fan-out).
- New fan-out gather group test verifying per-CP-level subgroups and
  full world-rank coverage.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two new test cases exercise the real PartitionAdapter.shard path end-to-
end — no emulation of the zero-padded sequence backward. The LLM's
TransformerConfig now picks context_parallel_size from the llm_grid's CP
group, so CP=2 drives index_select sharding in the MIMO forward and
zero-pad in backward exactly as in production.

- test_colocated_fan_in_cp2_8gpu: encoder TP2/DP4, LLM TP2/DP2/CP2.
  Covers fan-in forward all-gather and intra-CP all_reduce before the
  batch narrow in backward.
- test_colocated_fan_out_cp2_8gpu: encoder TP4/DP2, LLM TP1/DP4/CP2.
  Covers fan-out narrow and the per-CP-level DP all-gather in backward,
  preceded by intra-CP all_reduce.

Both seq_length values are divisible by 2*cp to satisfy the
PartitionAdapter causal-load-balancing shard factor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PartitionAdapter._apply_context_parallel stored a cp_group on its
config but then called get_batch_on_this_cp_rank without forwarding
it, so the helper fell back to parallel_state.get_context_parallel_
world_size() and asserted 'context parallel group is not initialized'
whenever the model was built with an explicit cp_group (e.g. MIMO
colocated paths that derive CP from a HyperCommGrid instead of the
global parallel state).

Pass self.cfg.cp_group so the sbhd split uses the caller-provided
group and works in configs that never initialize global parallel
state.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
rank_to_dest_pos is the canonical (cp=0) view of dest ranks and is
used by downstream per-slot group construction. _get_fan_out_slice_
info was reading it to compute each rank's batch slice, so CP>0 dest
ranks missed the lookup and fell back to 'return the whole batch' --
their forward activation then had the wrong batch size, and backward
collectives across fan-out and CP siblings deadlocked on mismatched
shapes.

Use rank_to_dest_coords (every cp level) and read dp_idx off the full
coord. Batch slice is a dp-only choice; all CP-siblings of a (dp, tp)
slot must pick the same slot so the intra-CP all_reduce in backward
sees matching shapes. rank_to_dest_pos is left untouched since slot-
indexed group construction still requires one entry per (dp, tp).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The CP=2 e2e tests build the LLM via a custom HyperCommGrid and never
call megatron.core.parallel_state.initialize_model_parallel. MimoModel
defaults cp_group/tp_group to None and then routes through the global
parallel_state, tripping 'context parallel group is not initialized'
during PartitionConfig.from_mp_config.

Forward language_pg.cp and language_pg.tp (both already built by the
test's grid) so the PartitionAdapter binds to the test's CP group
directly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…W-17)

Two review follow-ups:

1. _build_rank_mappings assumed cp varies fastest for fixed dp, which
   holds only when 'cp' appears before 'dp' in dim_names. Add an
   explicit validation in _validate_grids rejecting reversed ordering
   with a descriptive error. All existing callers use the standard
   ['tp','cp','pp','dp'] layout, so this is defensive.

2. Add a directed unit test (test_cp_backward_reduces_partial_seq_grads)
   that feeds a PartitionAdapter-style zero-padded gradient into the
   bridge's backward and asserts the returned input grad is the full
   (summed) sequence. If the intra-CP all_reduce(SUM) in backward were
   a no-op, each CP rank would return only its own sequence-chunk grad
   and this test would fail. The previous E2E CP tests would all pass
   despite such a regression, so this closes the oracle gap.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Closes the main gap from the PR-10 AXIOM audit applied to PR-11: the
existing CP>1 E2E tests only verify training runs (loss decreases),
not that the encoder gradient is mathematically correct under CP>1.

Adds test_mimo_colocated_correctness_cp.py — applies the equal-DP
CP=1 reference oracle from test_mimo_colocated_correctness.py to a
heterogeneous-DP dist model with llm_cp>1. The reference uses cp=1
with the SAME encoder TP/DP layout as dist (bridge is identity,
shards align 1:1), so post-step encoder weights compare directly.
Catches: missing intra-CP all_reduce, dp-only loss reduction
(under-counts tokens by cp_size), and broken fan-out per-CP-level
gather groups — any of which silently scales the encoder grad.

Also lands the supporting fixes uncovered during the audit:
- loss_func reduces (num, den) over dp*cp instead of dp so the
  per-token grad factor stays 1/global_den under cp>1.
- Adds fan-out CP backward unit test (Test 5b) — the existing CP
  backward test only covered fan-in; fan-out exercises a different
  code path (intra-CP reduce + per-CP-level all-gather).
- Adds negative test for dest_grid dim_names ordering — the
  dp_idx-advancement guard was unverified, so a refactor could
  silently delete it.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/nmfw-17-colocated-cp branch from 7711ed2 to c012d6c Compare April 21, 2026 17:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant