[Fork review] NMFW-17: dest CP>1 support on top of pr-a#11
Draft
yashaswikarnati wants to merge 7 commits into
Draft
[Fork review] NMFW-17: dest CP>1 support on top of pr-a#11yashaswikarnati wants to merge 7 commits into
yashaswikarnati wants to merge 7 commits into
Conversation
Allows the LLM (dest) side to use context parallelism while the encoder
(src) stays CP=1.
Correctness path:
- PartitionAdapter.shard shards sequence via index_select whose autograd
adjoint is a scatter (zero-pad). So the grad flowing into the bridge's
backward is already zero-padded along the sequence dimension — each
dest CP rank holds only its own sequence chunks with zeros elsewhere.
To return a full-sequence gradient to the encoder we run an intra-CP
all_reduce(SUM) on grad_output before the direction-specific op.
After the reduction every CP sibling holds the same full-sequence
gradient, and the downstream narrow (fan-in) or all-gather (fan-out)
proceeds exactly as in the CP=1 case.
- The intra-CP process group is reused from dest_grid.get_pg('cp'); no
new PG is created by the communicator.
- Fan-out gather groups are now split per (src_dp_idx, dest_tp_idx,
cp_idx) rather than pooled per (src_dp_idx, dest_tp_idx). This is
required because new_subgroups_by_enumeration demands every world
rank land in exactly one subgroup — a single pooled group would
leave cp>0 ranks orphaned. After the CP reduction each cp-level's
all-gather produces the same full-batch gradient.
- _build_rank_mappings filters dest tp_groups on cp_coord==0 when
advancing dp_idx (mirror of the existing pp_coord==0 filter), so
dp_idx correctly indexes the DP dimension regardless of CP size.
Full (dp, tp, cp) coords are stored in rank_to_dest_coords for every
pp=0 dest rank.
Validation:
- Encoder (src) CP must remain 1. Rejected in _validate_grids.
- Dest CP>1 is allowed; dest_cp_pg is populated from dest_grid.
Tests:
- New parameterized rank-mapping cases for CP=2 (fan-in and fan-out).
- New fan-out gather group test verifying per-CP-level subgroups and
full world-rank coverage.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two new test cases exercise the real PartitionAdapter.shard path end-to- end — no emulation of the zero-padded sequence backward. The LLM's TransformerConfig now picks context_parallel_size from the llm_grid's CP group, so CP=2 drives index_select sharding in the MIMO forward and zero-pad in backward exactly as in production. - test_colocated_fan_in_cp2_8gpu: encoder TP2/DP4, LLM TP2/DP2/CP2. Covers fan-in forward all-gather and intra-CP all_reduce before the batch narrow in backward. - test_colocated_fan_out_cp2_8gpu: encoder TP4/DP2, LLM TP1/DP4/CP2. Covers fan-out narrow and the per-CP-level DP all-gather in backward, preceded by intra-CP all_reduce. Both seq_length values are divisible by 2*cp to satisfy the PartitionAdapter causal-load-balancing shard factor. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PartitionAdapter._apply_context_parallel stored a cp_group on its config but then called get_batch_on_this_cp_rank without forwarding it, so the helper fell back to parallel_state.get_context_parallel_ world_size() and asserted 'context parallel group is not initialized' whenever the model was built with an explicit cp_group (e.g. MIMO colocated paths that derive CP from a HyperCommGrid instead of the global parallel state). Pass self.cfg.cp_group so the sbhd split uses the caller-provided group and works in configs that never initialize global parallel state. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
rank_to_dest_pos is the canonical (cp=0) view of dest ranks and is used by downstream per-slot group construction. _get_fan_out_slice_ info was reading it to compute each rank's batch slice, so CP>0 dest ranks missed the lookup and fell back to 'return the whole batch' -- their forward activation then had the wrong batch size, and backward collectives across fan-out and CP siblings deadlocked on mismatched shapes. Use rank_to_dest_coords (every cp level) and read dp_idx off the full coord. Batch slice is a dp-only choice; all CP-siblings of a (dp, tp) slot must pick the same slot so the intra-CP all_reduce in backward sees matching shapes. rank_to_dest_pos is left untouched since slot- indexed group construction still requires one entry per (dp, tp). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The CP=2 e2e tests build the LLM via a custom HyperCommGrid and never call megatron.core.parallel_state.initialize_model_parallel. MimoModel defaults cp_group/tp_group to None and then routes through the global parallel_state, tripping 'context parallel group is not initialized' during PartitionConfig.from_mp_config. Forward language_pg.cp and language_pg.tp (both already built by the test's grid) so the PartitionAdapter binds to the test's CP group directly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…W-17) Two review follow-ups: 1. _build_rank_mappings assumed cp varies fastest for fixed dp, which holds only when 'cp' appears before 'dp' in dim_names. Add an explicit validation in _validate_grids rejecting reversed ordering with a descriptive error. All existing callers use the standard ['tp','cp','pp','dp'] layout, so this is defensive. 2. Add a directed unit test (test_cp_backward_reduces_partial_seq_grads) that feeds a PartitionAdapter-style zero-padded gradient into the bridge's backward and asserts the returned input grad is the full (summed) sequence. If the intra-CP all_reduce(SUM) in backward were a no-op, each CP rank would return only its own sequence-chunk grad and this test would fail. The previous E2E CP tests would all pass despite such a regression, so this closes the oracle gap. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Closes the main gap from the PR-10 AXIOM audit applied to PR-11: the existing CP>1 E2E tests only verify training runs (loss decreases), not that the encoder gradient is mathematically correct under CP>1. Adds test_mimo_colocated_correctness_cp.py — applies the equal-DP CP=1 reference oracle from test_mimo_colocated_correctness.py to a heterogeneous-DP dist model with llm_cp>1. The reference uses cp=1 with the SAME encoder TP/DP layout as dist (bridge is identity, shards align 1:1), so post-step encoder weights compare directly. Catches: missing intra-CP all_reduce, dp-only loss reduction (under-counts tokens by cp_size), and broken fan-out per-CP-level gather groups — any of which silently scales the encoder grad. Also lands the supporting fixes uncovered during the audit: - loss_func reduces (num, den) over dp*cp instead of dp so the per-token grad factor stays 1/global_den under cp>1. - Adds fan-out CP backward unit test (Test 5b) — the existing CP backward test only covered fan-in; fan-out exercises a different code path (intra-CP reduce + per-CP-level all-gather). - Adds negative test for dest_grid dim_names ordering — the dp_idx-advancement guard was unverified, so a refactor could silently delete it. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
7711ed2 to
c012d6c
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fork-internal PR for reviewing the pr-c CP>1 diff. Base is pr-a (
ykarnati/nmfw-17-colocated-bridge-pp1), so the diff shown here is only the CP-specific changes.Commits (newest → oldest)
7711ed28f2— Review follow-ups: dim ordering guard + directed CP backward gradient oracle testc89d14153d— Pass cp_group/tp_group to MimoModel in colocated e2e tests44cf308696— Fix fan-out forward slice for CP>1 dest17d37b7a54— Thread cp_group through PartitionAdapter sbhd patheb8fbc82be— CP=2 E2E tests (fan-in and fan-out)503e4062e1— Dest CP>1 support in ColocatedBridgeCommunicatorDesign summary
PartitionAdapter.shardusesindex_select→ backward is zero-pad. The bridge's backward intra-CPall_reduce(SUM)ongrad_outputreconstructs the full-sequence grad before the direction-specific op(src_dp, dest_tp, cp_idx)so every world rank lands in exactly one subgrouprank_to_dest_poskeeps canonical cp=0 ranks per(dp, tp); full(dp, tp, cp)is inrank_to_dest_coordsdest_grid.get_pg('cp')— no new PG createdTest status (all green on 8x H100)
test_mimo_colocated_communicator.py(17 tests, incl. new CP=2 cases)test_mimo_colocated_correctness.py(3 CP=1 cases + 1 directed CP backward oracle)test_mimo_colocated_e2e.py::test_colocated_fan_in_cp2_8gputest_mimo_colocated_e2e.py::test_colocated_fan_out_cp2_8gputest_mimo_colocated_e2e.py::test_colocated_fan_in_8gpu(CP=1 regression — still passes)Review history
Went through three independent reviews. The remaining follow-ups were:
cpmust appear beforedpindim_names) — added in7711ed28f27711ed28f2(test_cp_backward_reduces_partial_seq_grads)🤖 Generated with Claude Code