diff --git a/megatron/core/models/mimo/comm/colocated_communicator.py b/megatron/core/models/mimo/comm/colocated_communicator.py index f730ea7e223..9e6f326e5c5 100644 --- a/megatron/core/models/mimo/comm/colocated_communicator.py +++ b/megatron/core/models/mimo/comm/colocated_communicator.py @@ -95,11 +95,17 @@ def __init__( elif self.dest_dp_size > self.src_dp_size: self.direction = BridgeDirection.FAN_OUT self.scale = self.dest_dp_size // self.src_dp_size - self.gather_group_ranks = self._build_gather_groups( + # Fan-out gather groups must be split by dest CP level: every + # world rank must land in exactly one subgroup, so a single + # pooled group per (src_dp, dest_tp) would orphan cp>0 ranks. + # With dest_cp_size == 1 this collapses to the original + # (src_dp, dest_tp) product. + self.gather_group_ranks = self._build_fan_out_gather_groups( iter_size=self.src_dp_size, sibling_tp_size=self.dest_tp_size, scale=self.scale, - rank_to_pos=self.rank_to_dest_pos, + cp_size=self.dest_cp_size, + rank_to_coords=self.rank_to_dest_coords, ) self.gather_pg, _ = dist.new_subgroups_by_enumeration( self.gather_group_ranks, backend='nccl' @@ -144,14 +150,30 @@ def _validate_grids(self): f"{name} PP must be 1 for ColocatedBridgeCommunicator, got {pp_size}" ) - # CP>1 corrupts dp_idx when we iterate get_rank_enum(['tp']) groups. - for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]: - if 'cp' in grid.dim_names: - cp_size = grid.shape[grid.dim_names.index('cp')] - if cp_size != 1: - raise ValueError( - f"{name} CP must be 1 for ColocatedBridgeCommunicator, got {cp_size}" - ) + # Source (encoder) must have CP=1. Dest (LLM) may have CP>1 — the LLM + # shards sequence across CP ranks via PartitionAdapter after receiving + # the bridge's full-sequence output. The bridge's backward path + # reduces partial-sequence gradients across dest CP siblings before + # returning the full-sequence gradient to the encoder. 
+ if 'cp' in self.src_grid.dim_names: + src_cp_size = self.src_grid.shape[self.src_grid.dim_names.index('cp')] + if src_cp_size != 1: + raise ValueError( + f"Source CP must be 1 for ColocatedBridgeCommunicator, got {src_cp_size}" + ) + + # _build_rank_mappings assumes that _gen_rank_enum(['tp']) yields cp + # varying fastest for fixed dp — true only when cp appears BEFORE dp + # in dim_names (reversed, cp becomes an inner loop around dp). If the + # caller reverses them, dp_idx advances at the wrong cp level and + # rank_to_dest_coords is silently wrong. Guard explicitly. + if 'cp' in self.dest_grid.dim_names: + dim_names = self.dest_grid.dim_names + if dim_names.index('cp') > dim_names.index('dp'): + raise ValueError( + f"dest_grid dim_names must have 'cp' before 'dp' " + f"(e.g. ['tp','cp','pp','dp']); got {dim_names}" + ) src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')] dest_dp = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] @@ -165,20 +187,68 @@ def _extract_parallelism_info(self): self.src_dp_size = self.src_grid.shape[self.src_grid.dim_names.index('dp')] self.dest_tp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('tp')] self.dest_dp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')] + if 'cp' in self.dest_grid.dim_names: + self.dest_cp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('cp')] + else: + self.dest_cp_size = 1 + # Reuse the existing CP group from dest_grid (caller creates it via + # grid.create_pg(['cp'])). None when dest CP=1. Used in backward to + # reduce sequence-sharded gradients across CP siblings before returning + # to the encoder — see the backward docstring for the math. 
+ self.dest_cp_pg: Optional[dist.ProcessGroup] = ( + self.dest_grid.get_pg('cp') if self.dest_cp_size > 1 else None + ) + + @staticmethod + def _get_rank_dim_coord(rank: int, grid: HyperCommGrid, dim_name: str) -> int: + """Extract a rank's coordinate for a specific grid dimension.""" + dim_idx = grid.dim_names.index(dim_name) + temp = rank - grid.rank_offset + for i in range(dim_idx): + temp //= grid.shape[i] + return temp % grid.shape[dim_idx] def _build_rank_mappings(self): self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {} + # rank_to_dest_pos: canonical (cp_idx=0, pp_idx=0) rank per (dp, tp) + # slot — preserves the one-entry-per-slot contract downstream group + # construction depends on. self.rank_to_dest_pos: Dict[int, Tuple[int, int]] = {} + # rank_to_dest_coords: every dest rank's full (dp_idx, tp_idx, cp_idx) + # for PP stage 0. Used to build per-CP-level fan-out groups and to + # drive the intra-CP gradient reduction in backward. + self.rank_to_dest_coords: Dict[int, Tuple[int, int, int]] = {} src_tp_groups = self.src_grid.get_rank_enum(['tp']) for dp_idx, tp_group in enumerate(src_tp_groups): for tp_idx, rank in enumerate(tp_group): self.rank_to_src_pos[rank] = (dp_idx, tp_idx) + # Dest iteration: get_rank_enum(['tp']) returns dp*pp*cp tp-groups. + # We advance dp_idx only after the final cp level of each dp; every + # (cp, pp=0) rank still records its coords so fan-out backward can + # locate CP siblings. 
+ dest_has_pp = 'pp' in self.dest_grid.dim_names + dest_has_cp = 'cp' in self.dest_grid.dim_names dest_tp_groups = self.dest_grid.get_rank_enum(['tp']) - for dp_idx, tp_group in enumerate(dest_tp_groups): + dp_idx = 0 + for tp_group in dest_tp_groups: + if dest_has_pp: + pp_coord = self._get_rank_dim_coord(tp_group[0], self.dest_grid, 'pp') + if pp_coord != 0: + continue + cp_coord = ( + self._get_rank_dim_coord(tp_group[0], self.dest_grid, 'cp') if dest_has_cp else 0 + ) + # get_rank_enum yields cp varying fastest for fixed dp (with + # pp=0 filtered). All cp levels of one dp share dp_idx; we + # advance dp_idx only after the final cp level of that dp. for tp_idx, rank in enumerate(tp_group): - self.rank_to_dest_pos[rank] = (dp_idx, tp_idx) + self.rank_to_dest_coords[rank] = (dp_idx, tp_idx, cp_coord) + if cp_coord == 0: + self.rank_to_dest_pos[rank] = (dp_idx, tp_idx) + if cp_coord == self.dest_cp_size - 1: + dp_idx += 1 @staticmethod def _build_gather_groups( @@ -208,6 +278,39 @@ def _build_gather_groups( groups.append(group_ranks) return groups + @staticmethod + def _build_fan_out_gather_groups( + iter_size: int, + sibling_tp_size: int, + scale: int, + cp_size: int, + rank_to_coords: Dict[int, Tuple[int, int, int]], + ) -> List[List[int]]: + """Build fan-out gather groups split per (src_dp, dest_tp, dest_cp). + + Splitting by ``cp_idx`` is required because + ``new_subgroups_by_enumeration`` demands every world rank land in + exactly one subgroup — a single pooled group per (src_dp, dest_tp) + would leave cp>0 ranks orphaned. After the CP reduction in backward, + each cp-level's all-gather produces the same full-batch gradient. + When ``cp_size == 1`` this degenerates to the original + (src_dp, dest_tp) product. 
+ """ + coords_to_rank: Dict[Tuple[int, int, int], int] = { + coords: rank for rank, coords in rank_to_coords.items() + } + groups: List[List[int]] = [] + for iter_idx in range(iter_size): + sibling_dp_indices = range(iter_idx * scale, (iter_idx + 1) * scale) + for sibling_tp_idx in range(sibling_tp_size): + for cp_idx in range(cp_size): + group_ranks = [ + coords_to_rank[(sibling_dp_idx, sibling_tp_idx, cp_idx)] + for sibling_dp_idx in sibling_dp_indices + ] + groups.append(group_ranks) + return groups + def is_fan_in(self) -> bool: """True if src DP > dest DP (forward all-gathers).""" return self.direction is BridgeDirection.FAN_IN @@ -229,7 +332,11 @@ def get_slice_info(self, batch_size: int) -> SliceInfo: return SliceInfo(start=0, size=batch_size) self._check_divisible(batch_size) if self.direction is BridgeDirection.FAN_OUT: - dp_idx = self.rank_to_dest_pos[self.current_rank][0] + # rank_to_dest_pos only tracks cp=0 canonical slots; CP>0 dest + # ranks must slice the same batch slot as their cp=0 sibling so + # the intra-CP all_reduce in backward sees matching shapes. + # rank_to_dest_coords has an entry per (dp, tp, cp). + dp_idx = self.rank_to_dest_coords[self.current_rank][0] else: # FAN_IN dp_idx = self.rank_to_src_pos[self.current_rank][0] slot = dp_idx % self.scale @@ -296,10 +403,27 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: activation. Instead we all-gather across the fan-out sibling group, reconstructing the full src-batch gradient (symmetric with the fan-in forward's all-gather). + + When the dest grid has CP>1, the LLM's PartitionAdapter.shard slices + sequence via ``index_select`` whose autograd adjoint is a scatter / + zero-pad. So the grad flowing into this backward is already zero- + padded along the sequence dimension — each CP rank holds only its + own sequence chunks, zeros elsewhere. We therefore run an intra-CP + ``all_reduce(SUM)`` on the incoming gradient *before* the direction- + specific op. 
After the reduction every CP sibling holds the same + full-sequence gradient, and the downstream narrow (fan-in) or + all-gather (fan-out) proceeds exactly as in the CP=1 case. """ comm = ctx.comm batch_dim = ctx.batch_dim + # CP sequence-reduction step. Runs for both fan-in and fan-out when + # dest CP>1. See the docstring for the math. Clone first so the in- + # place all_reduce does not mutate the tensor autograd passed in. + if comm.dest_cp_pg is not None: + grad_output = grad_output.contiguous().clone() + dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=comm.dest_cp_pg) + if comm.direction is BridgeDirection.FAN_OUT: return _all_gather_along_batch_dim(grad_output, comm.gather_pg, batch_dim), None diff --git a/megatron/core/models/mimo/partition/utils.py b/megatron/core/models/mimo/partition/utils.py index 0b43e5548ff..f7bb8e578d1 100644 --- a/megatron/core/models/mimo/partition/utils.py +++ b/megatron/core/models/mimo/partition/utils.py @@ -235,7 +235,7 @@ def _apply_context_parallel( batch["attention_mask"] = attention_mask if packed_seq_params is None or getattr(packed_seq_params, 'qkv_format', 'sbhd') == 'sbhd': - batch = get_batch_on_this_cp_rank(batch) + batch = get_batch_on_this_cp_rank(batch, cp_group=self.cfg.cp_group) else: assert _HAVE_TEX and is_te_min_version("1.10.0"), ( "Please update Transformer Engine to >= 1.10 " diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 8a86cc992fd..2175ebef990 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -340,7 +340,11 @@ def get_mimo_model( module_to_grid_map=module_to_grid_map, ) - mimo_model = MimoModel(mimo_config) + # Forward language_pg.cp/tp so PartitionAdapter binds to the test's CP + # group directly. 
Test grids skip parallel_state.initialize_model_parallel, + # so leaving these None routes through the uninitialised global and trips + # 'context parallel group is not initialized' under CP>1. + mimo_model = MimoModel(mimo_config, cp_group=language_pg.cp, tp_group=language_pg.tp) mimo_model.to(torch.device("cuda")).to(torch.bfloat16) # Wrap with DDP (caller may override e.g. for heterogeneous-DP scaling). diff --git a/tests/unit_tests/models/test_mimo_colocated_communicator.py b/tests/unit_tests/models/test_mimo_colocated_communicator.py index c19afc82e93..7f1b55f99d6 100644 --- a/tests/unit_tests/models/test_mimo_colocated_communicator.py +++ b/tests/unit_tests/models/test_mimo_colocated_communicator.py @@ -146,6 +146,64 @@ def test_rank_mappings_with_rank_offset(self): assert comm.rank_to_src_pos == {4: (0, 0), 5: (0, 1), 6: (1, 0), 7: (1, 1)} assert comm.rank_to_dest_pos == {4: (0, 0), 5: (1, 0), 6: (2, 0), 7: (3, 0)} + @pytest.mark.parametrize( + "src_tp, src_dp, dest_tp, dest_dp, dest_cp, expected_dest_pos, expected_dest_coords", + [ + # Fan-in with dest CP=2: TP2/DP4 → TP2/DP2/CP2. rank_to_dest_pos + # only holds canonical (cp=0) ranks per (dp, tp); full (dp, tp, cp) + # is stored in rank_to_dest_coords. + ( + 2, + 4, + 2, + 2, + 2, + {0: (0, 0), 1: (0, 1), 4: (1, 0), 5: (1, 1)}, + { + 0: (0, 0, 0), + 1: (0, 1, 0), + 2: (0, 0, 1), + 3: (0, 1, 1), + 4: (1, 0, 0), + 5: (1, 1, 0), + 6: (1, 0, 1), + 7: (1, 1, 1), + }, + ), + # Fan-out with dest CP=2: TP4/DP2 → TP1/DP4/CP2. 
+ ( + 4, + 2, + 1, + 4, + 2, + {0: (0, 0), 2: (1, 0), 4: (2, 0), 6: (3, 0)}, + { + 0: (0, 0, 0), + 1: (0, 0, 1), + 2: (1, 0, 0), + 3: (1, 0, 1), + 4: (2, 0, 0), + 5: (2, 0, 1), + 6: (3, 0, 0), + 7: (3, 0, 1), + }, + ), + ], + ids=["fan_in_cp2", "fan_out_cp2"], + ) + def test_rank_mappings_with_cp( + self, src_tp, src_dp, dest_tp, dest_dp, dest_cp, expected_dest_pos, expected_dest_coords + ): + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, cp=dest_cp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid) + + assert comm.rank_to_dest_pos == expected_dest_pos + assert comm.rank_to_dest_coords == expected_dest_coords + assert comm.dest_cp_size == dest_cp + assert comm.dest_cp_pg is not None + # ── Test 2: All-gather groups ────────────────────────────────────────────────── @@ -184,6 +242,27 @@ def test_fan_out_gather_groups(self): assert comm.gather_group_ranks == [[0, 2], [1, 3], [4, 6], [5, 7]] assert comm.gather_pg is not None + def test_fan_out_gather_groups_with_cp(self): + """Fan-out with dest CP=2: each (src_dp, dest_tp) slot splits into + per-cp-level groups so every world rank lands in exactly one subgroup. + + src=(tp=4, dp=2), dest=(tp=1, dp=4, cp=2), scale=2. Expected groups + (one per src_dp × dest_tp × cp_idx, scale=2 ranks each): + src_dp=0, cp=0: dest_dp=[0,1] → [rank 0, rank 2] + src_dp=0, cp=1: dest_dp=[0,1] → [rank 1, rank 3] + src_dp=1, cp=0: dest_dp=[2,3] → [rank 4, rank 6] + src_dp=1, cp=1: dest_dp=[2,3] → [rank 5, rank 7] + """ + src_grid = create_hypercomm_grid(tp=4, dp=2) + dest_grid = create_hypercomm_grid(tp=1, cp=2, dp=4) + comm = make_comm(src_grid, dest_grid) + + assert comm.gather_group_ranks == [[0, 2], [1, 3], [4, 6], [5, 7]] + # Every world rank must appear exactly once across all fan-out groups. 
+ flat = [r for g in comm.gather_group_ranks for r in g] + assert sorted(flat) == list(range(8)) + assert comm.gather_pg is not None + # ── Test 3b: _validate_grids negative tests ─────────────────────────────────── @@ -245,9 +324,37 @@ def test_dest_pp_gt_one_rejected(self): make_comm(src_grid, dest_grid) def test_cp_gt_one_rejected(self): + # Source CP>1 is explicitly disallowed; only dest CP>1 is supported. src_grid = create_hypercomm_grid(tp=2, cp=2, dp=2) dest_grid = create_hypercomm_grid(tp=4, dp=2) - with pytest.raises(ValueError, match="CP must be 1"): + with pytest.raises(ValueError, match="Source CP must be 1"): + make_comm(src_grid, dest_grid) + + def test_dest_cp_after_dp_in_dim_names_rejected(self): + """Dest ``dim_names`` with ``cp`` *after* ``dp`` must be rejected. + + ``_build_rank_mappings`` relies on ``get_rank_enum(['tp'])`` yielding + cp varying fastest for fixed dp. That only holds when ``cp`` appears + before ``dp`` in dim_names. If the ordering is reversed, dp_idx would + advance at the wrong cp level and ``rank_to_dest_coords`` would be + silently wrong — a latent bug hidden behind a guard. This negative + test makes sure the guard actually fires so the guard can't be + refactored away without a test failure. + """ + if dist.get_world_size() < 8: + pytest.skip("requires at least 8 ranks") + src_grid = create_hypercomm_grid(tp=2, dp=4) + # Reversed order: dp before cp. ``create_hypercomm_grid`` hardcodes the + # canonical ordering, so build the broken grid directly. 
+ dest_grid = HyperCommGrid( + shape=[1, 4, 1, 2], + dim_names=["tp", "dp", "pp", "cp"], + backend="nccl", + ) + dest_grid.create_pg(["tp"]) + dest_grid.create_pg(["cp"]) + _active_grids.append(dest_grid) + with pytest.raises(ValueError, match="must have 'cp' before 'dp'"): make_comm(src_grid, dest_grid) def test_dp_not_divisible(self): @@ -536,7 +643,105 @@ def test_fan_out_backward_equals_concat_of_sibling_grads( ) torch.testing.assert_close(input_tensor.grad, expected, rtol=0, atol=0) - # ── Test 5: equal DP is a pure identity forward and backward ──────────── + # ── Test 5: dest CP>1 backward reconstructs full-seq grad via intra-CP reduce ─ + @pytest.mark.parametrize( + "src_tp,src_dp,dest_tp,dest_dp,dest_cp", [(1, 8, 1, 4, 2)], ids=["fan_in_cp2"] + ) + def test_cp_backward_reduces_partial_seq_grads( + self, src_tp, src_dp, dest_tp, dest_dp, dest_cp + ): + """Bridge backward must intra-CP all_reduce(SUM) before the fan op. + + PartitionAdapter.shard uses index_select whose autograd adjoint is + zero-pad: each CP rank's grad at the bridge boundary covers only + its own two of the 2*CP sequence chunks, zeros elsewhere. Without an intra-CP + all_reduce, every CP sibling would return only its own sequence + chunks and upstream gradients would lose information. 
+ """ + dim_mapping = {'b': 0, 's': 1, 'h': 2} + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, cp=dest_cp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + B_local, S, H = self.B_PER_RANK, 2 * dest_cp * 2, self.H + t = torch.full( + (B_local, S, H), float(dist.get_rank()), device='cuda' + ).requires_grad_() + out = comm.communicate(t) + assert out.shape == (B_local * comm.scale, S, H) + + cp_rank = dest_grid.get_pg("cp").rank() + chunk = S // (2 * dest_cp) + mask = torch.zeros(S, device='cuda') + mask[cp_rank * chunk : (cp_rank + 1) * chunk] = 1.0 + mask[(2 * dest_cp - 1 - cp_rank) * chunk : (2 * dest_cp - cp_rank) * chunk] = 1.0 + grad_output = mask.view(1, S, 1).expand(B_local * comm.scale, S, H).contiguous() + + out.backward(grad_output.to(dtype=out.dtype)) + + expected = torch.ones(B_local, S, H, device='cuda', dtype=t.grad.dtype) + torch.testing.assert_close(t.grad, expected, rtol=0, atol=1e-6) + + # ── Test 5b: dest CP>1 fan-out backward reconstructs full-seq grad ────── + @pytest.mark.parametrize( + "src_tp,src_dp,dest_tp,dest_dp,dest_cp", [(4, 2, 1, 4, 2)], ids=["fan_out_cp2"] + ) + def test_cp_fan_out_backward_reduces_partial_seq_grads( + self, src_tp, src_dp, dest_tp, dest_dp, dest_cp + ): + """Fan-out companion to ``test_cp_backward_reduces_partial_seq_grads``. + + Test 5 only covers fan-in (post-CP-reduce op is ``narrow``). Fan-out + takes a different code path: after the intra-CP ``all_reduce`` the + backward runs an ``all_gather`` across the per-CP-level sibling group + built by ``_build_fan_out_gather_groups``. This test feeds the same + PartitionAdapter-style zero-padded gradient pattern but through the + fan-out direction and verifies the returned input grad is the full + (all-ones) gradient across both the sequence AND the gathered batch. 
+ + Four regressions this catches: + * intra-CP ``all_reduce`` degraded to no-op → gradient stays + per-CP-rank sparse (ones only in this rank's chunks). + * fan-out gather groups **not** split per CP level (every world + rank lands in a single pooled group) → the CP ranks end up in + each other's gather group, duplicating values on the batch dim. + * wrong CP group (e.g. accidentally using ``dp_cp``) → the reduce + covers too many ranks and gradients get inflated. + * all-gather ordering wrong → values land in the wrong batch slot, + so the exact ``ones`` oracle fails. + """ + dim_mapping = {'b': 0, 's': 1, 'h': 2} + src_grid = create_hypercomm_grid(tp=src_tp, dp=src_dp) + dest_grid = create_hypercomm_grid(tp=dest_tp, cp=dest_cp, dp=dest_dp) + comm = make_comm(src_grid, dest_grid, dim_mapping=dim_mapping) + + B_full, S, H = self.B_PER_RANK * comm.scale, 2 * dest_cp * 2, self.H + # TP-replicated input (bridge contract): seed identically on every rank. + torch.manual_seed(42) + t = torch.ones(B_full, S, H, device='cuda', requires_grad=True) + out = comm.communicate(t) + assert out.shape == (self.B_PER_RANK, S, H) + + cp_rank = dest_grid.get_pg("cp").rank() + chunk = S // (2 * dest_cp) + # Mask pattern mirroring ``get_batch_on_this_cp_rank``: this CP rank + # owns chunk ``cp_rank`` and chunk ``2*cp_size - 1 - cp_rank``. After + # summing across CP ranks the mask becomes all ones. + mask = torch.zeros(S, device='cuda') + mask[cp_rank * chunk : (cp_rank + 1) * chunk] = 1.0 + mask[(2 * dest_cp - 1 - cp_rank) * chunk : (2 * dest_cp - cp_rank) * chunk] = 1.0 + grad_output = mask.view(1, S, 1).expand(self.B_PER_RANK, S, H).contiguous() + + out.backward(grad_output.to(dtype=out.dtype)) + + # Expected flow: intra-CP all_reduce → full-seq ones on every CP rank, + # then fan-out all-gather across the (src_dp, dest_tp, cp) sibling + # group concatenates scale=2 copies of ones along the batch dim, + # yielding ones(B_full, S, H) on every src rank. 
+ expected = torch.ones(B_full, S, H, device='cuda', dtype=t.grad.dtype) + torch.testing.assert_close(t.grad, expected, rtol=0, atol=1e-6) + + # ── Test 6: equal DP is a pure identity forward and backward ──────────── @pytest.mark.parametrize( "src_tp,src_dp,dest_tp,dest_dp", [(4, 2, 4, 2)], ids=["tp4_dp2"] ) diff --git a/tests/unit_tests/models/test_mimo_colocated_correctness_cp.py b/tests/unit_tests/models/test_mimo_colocated_correctness_cp.py new file mode 100644 index 00000000000..d396d8a1e92 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_colocated_correctness_cp.py @@ -0,0 +1,274 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Gradient-scaling correctness for colocated MimoModel with dest CP>1. + +Extends the equal-DP-reference oracle from +``test_mimo_colocated_correctness.py`` to a heterogeneous-DP dist model +whose LLM has ``cp > 1``. The reference is the same equal-DP CP=1 +config used in the PR-10 oracle (bridge is identity passthrough, +encoder TP layout matches dist's encoder shard-for-shard); the dist +side adds CP>1 on the LLM, so this test specifically exercises: + +* ``loss_func`` reducing ``(num, den)`` over ``dp * cp`` so the per-token + grad factor stays ``1 / global_den`` instead of being scaled by + ``cp_size``. Reduce over plain DP and the encoder grad would be inflated + by a factor of ``cp_size``. +* ``ColocatedBridgeCommunicator`` backward's intra-CP ``all_reduce``, + which reconstructs the full-sequence gradient from the zero-padded + per-CP-rank grad produced by ``PartitionAdapter.shard``'s + ``index_select`` adjoint. Without it the encoder receives only the + current CP rank's sequence chunk, dropping the rest. +* For fan-out: the per-CP-level gather groups built by + ``_build_fan_out_gather_groups`` (a single pooled group per + ``(src_dp, dest_tp)`` would orphan ``cp>0`` ranks). 
+ +If any of those is wrong the encoder gradient gets a non-unit factor of +``cp_size`` (or worse, drops sequence content), and one Adam step is +enough to make the encoder shards diverge from the CP=1 reference. + +Run with:: + + uv run python -m torch.distributed.run --nproc_per_node=8 \\ + -m pytest tests/unit_tests/models/test_mimo_colocated_correctness_cp.py -v -s +""" + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.models.mimo.optimizer import get_mimo_optimizer +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.transformer.enums import ModelType +from tests.unit_tests.models.test_mimo_1f1b_schedule import ( + create_all_embedding_groups, + create_hypercomm_grid, + destroy_all_grids, + get_mimo_model, +) +from tests.unit_tests.models.test_mimo_colocated_correctness import ( + _assert_encoder_weights_match, + _copy_ref_params_to_dist, + _generate_and_broadcast_global_batches, + _run_forward_backward, + _set_deterministic_env, + _slice_global_batch_by_dp, + _slice_global_batch_for_dist, + _wire_training_hooks, +) +from tests.unit_tests.test_utilities import Utils + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse("2.3.0"), + reason="Requires PyTorch 2.3+", +) +class TestColocatedCPCorrectness: + """Equal-DP CP=1 reference oracle for heterogeneous-DP dist with LLM CP>1.""" + + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def teardown_method(self): + torch.use_deterministic_algorithms(False) + destroy_all_grids() + + @pytest.mark.parametrize( + "enc_tp,enc_dp,llm_tp,llm_dp,llm_cp", + [(2, 4, 2, 2, 2), (4, 2, 1, 4, 2)], + ids=["fan_in_cp2", "fan_out_cp2"], + ) + def test_cp_dist_matches_cp1_reference_post_step_weights( 
+ self, enc_tp, enc_dp, llm_tp, llm_dp, llm_cp + ): + """Hetero-DP+CP>1 dist post-step encoder weights match equal-DP CP=1 ref. + + Both sides use ``gradient_reduce_div_factor=1`` and the num+den + global-mean CE so the DDP reduction is a pure SUM and the + aggregate grad on every encoder shard equals the DP=1 gradient. + Encoder TP and per-rank batch are matched between dist and ref so + encoder shards line up 1:1 for direct comparison. + + On the dist side the LLM additionally runs CP>1, so the encoder- + bound gradient must survive three transforms unchanged: + * loss reduction over ``dp*cp`` (not just dp), + * bridge backward intra-CP all_reduce(SUM), + * (fan-out only) per-CP-level fan-out gather groups. + + Any of those mis-scoped scales the encoder grad by ``cp_size`` + (or drops sequence content), and one Adam step makes shards + diverge from the ref beyond bf16 rounding. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + _set_deterministic_env() + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + encoder_name = "images" + # seq_length must be divisible by 2*llm_cp — PartitionAdapter's + # causal load-balancing splits each sequence into 2*cp chunks. + hidden_size, seq_length, vocab_size = 256, 64, 1000 + micro_batch_size = 2 + num_microbatches = 1 + + global_batch_size = micro_batch_size * max(enc_dp, llm_dp) + + # Dist: heterogeneous TP/DP, llm_cp>1. Ref: equal-DP uniform with the + # SAME encoder TP/DP as dist so the bridge is identity and encoder + # shards align 1:1 for direct comparison. Ref keeps cp=1; CP>1 lives + # only on the dist side because that is the path under audit. 
+ dist_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) + dist_llm_grid = create_hypercomm_grid( + offset=0, tp=llm_tp, cp=llm_cp, pp=1, dp=llm_dp + ) + ref_enc_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) + ref_llm_grid = create_hypercomm_grid(offset=0, tp=enc_tp, cp=1, pp=1, dp=enc_dp) + create_all_embedding_groups( + [dist_enc_grid, dist_llm_grid, ref_enc_grid, ref_llm_grid] + ) + + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, + bucket_size=10000, + use_distributed_optimizer=True, + gradient_reduce_div_factor=1, + ) + + torch.manual_seed(12345) + dist_mimo, _, _, dist_language_pg, dist_vision_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + hidden_size=hidden_size, + num_layers=2, + vocab_size=vocab_size, + seq_len=seq_length, + ddp_config=ddp_config, + ) + dist_mimo.model_type = ModelType.encoder_or_decoder + + torch.manual_seed(12345) + ref_mimo, _, _, ref_language_pg, ref_vision_pg = get_mimo_model( + encoder_name=encoder_name, + encoder_grid=ref_enc_grid, + llm_grid=ref_llm_grid, + hidden_size=hidden_size, + num_layers=2, + vocab_size=vocab_size, + seq_len=seq_length, + ddp_config=ddp_config, + ) + ref_mimo.model_type = ModelType.encoder_or_decoder + + # Encoder TP layouts match between dist and ref → shard-to-shard + # copy. LLM TP differs (and dist additionally has CP, but CP does + # not reshape weights), so the helper all-gathers ref's shards + # across ref's LLM TP group and re-slices for dist's LLM TP group. 
+ _copy_ref_params_to_dist( + ref_mimo.modality_submodules[encoder_name].module, + dist_mimo.modality_submodules[encoder_name].module, + ref_enc_grid.get_pg("tp"), + dist_enc_grid.get_pg("tp"), + ) + _copy_ref_params_to_dist( + ref_mimo.language_model.module, + dist_mimo.language_model.module, + ref_llm_grid.get_pg("tp"), + dist_llm_grid.get_pg("tp"), + ) + + _wire_training_hooks(dist_mimo, dist_language_pg, dist_vision_pg) + _wire_training_hooks(ref_mimo, ref_language_pg, ref_vision_pg) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + dist_optimizer = get_mimo_optimizer(dist_mimo, opt_config) + ref_optimizer = get_mimo_optimizer(ref_mimo, opt_config) + + torch.manual_seed(99999) + global_batches = _generate_and_broadcast_global_batches( + global_mbs=global_batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + vocab_size=vocab_size, + encoder_name=encoder_name, + num_batches=num_microbatches, + ) + # Dist: pre-slice along the larger DP side; forward_step further + # slices the encoder/LLM side as needed. CP does not affect the + # batch dim so the helper is reused unchanged. + dist_batches = [ + _slice_global_batch_for_dist(b, dist_enc_grid, dist_llm_grid) + for b in global_batches + ] + # Ref is equal-DP (enc_dp == llm_dp) so the dist helper would + # return the full batch; slice explicitly so each rank sees the + # same per-rank encoder batch as dist's encoder. 
+ ref_batches = [ + _slice_global_batch_by_dp(b, ref_enc_grid.get_pg("dp")) + for b in global_batches + ] + ref_per_rank_batch_size = global_batch_size // enc_dp + + dist_optimizer.zero_grad() + _run_forward_backward( + mimo_model=dist_mimo, + batches=dist_batches, + enc_grid=dist_enc_grid, + llm_grid=dist_llm_grid, + encoder_name=encoder_name, + language_pg=dist_language_pg, + micro_batch_size=micro_batch_size, + seq_length=seq_length, + num_microbatches=num_microbatches, + ) + dist_success, dist_grad_norm, _ = dist_optimizer.step() + assert dist_success, "Dist optimizer step failed" + assert dist_grad_norm is not None and dist_grad_norm > 0, ( + f"Dist grad_norm={dist_grad_norm} — encoder grads may have been " + "silently zeroed by wrong CP scaling" + ) + + ref_optimizer.zero_grad() + _run_forward_backward( + mimo_model=ref_mimo, + batches=ref_batches, + enc_grid=ref_enc_grid, + llm_grid=ref_llm_grid, + encoder_name=encoder_name, + language_pg=ref_language_pg, + micro_batch_size=ref_per_rank_batch_size, + seq_length=seq_length, + num_microbatches=num_microbatches, + ) + ref_success, ref_grad_norm, _ = ref_optimizer.step() + assert ref_success, "Ref optimizer step failed" + assert ref_grad_norm is not None and ref_grad_norm > 0, ( + f"Ref grad_norm={ref_grad_norm}" + ) + + # Loose-ish tolerance because dist and ref differ on the LLM side + # (different llm_tp, additionally cp>1 on dist) — bf16 accumulation + # noise from the LLM forward propagates into each model's encoder + # gradient. Mirrors the tolerances in the CP=1 correctness test. 
+ _assert_encoder_weights_match( + ref_mimo.modality_submodules[encoder_name].module, + dist_mimo.modality_submodules[encoder_name].module, + rtol=1e-3, + atol=1e-3, + ) diff --git a/tests/unit_tests/models/test_mimo_colocated_e2e.py b/tests/unit_tests/models/test_mimo_colocated_e2e.py index 2282ca7ca23..923f2ae3b32 100644 --- a/tests/unit_tests/models/test_mimo_colocated_e2e.py +++ b/tests/unit_tests/models/test_mimo_colocated_e2e.py @@ -61,20 +61,25 @@ logger = logging.getLogger(__name__) -def loss_func(loss_mask, llm_dp_pg, output_tensor): - """Global-mean CE across the LLM DP group via all-reduced ``(num, den)``. +def loss_func(loss_mask, llm_dp_cp_pg, output_tensor): + """Global-mean CE across the LLM DP*CP group via all-reduced ``(num, den)``. ``output_tensor`` is the per-token CE from ``GPTModel.compute_language_model_loss`` with shape ``[b, s]`` (the "actual Megatron cross entropy"). Masking ignored tokens, all-reducing - numerator and valid-token count across the LLM DP group, and dividing - is the exact distributed equivalent of full-batch + numerator and valid-token count across the LLM DP*CP group, and + dividing is the exact distributed equivalent of full-batch ``F.cross_entropy(..., reduction='mean')``. - The all-reduce is scoped to the **LLM DP group only** — never the - full dp*tp group. All TP ranks within a DP replica already hold the - identical per-token loss (since the LLM's output is TP-replicated), - so summing over TP peers would double-count. + The all-reduce is scoped to ``dp * cp`` — never the full dp*cp*tp + group. All TP ranks within a (dp, cp) slot already hold the identical + per-token loss (the LLM's output is TP-replicated), so summing over + TP peers would double-count. CP ranks, however, hold *disjoint* + sequence chunks of the same batch, so summing over CP is mandatory + for the per-token grad factor to equal ``1 / global_den`` — without + it, CP>1 under-counts tokens and every encoder gradient is scaled + by ``cp_size``. 
When ``cp == 1`` this PG is equivalent to the DP + group, so existing CP=1 tests are unaffected. """ if output_tensor is None: return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} @@ -82,8 +87,8 @@ def loss_func(loss_mask, llm_dp_pg, output_tensor): masked = output_tensor.float() * loss_mask.float() local_num = masked.sum() local_den = loss_mask.float().sum() - dist.all_reduce(local_num, group=llm_dp_pg) - dist.all_reduce(local_den, group=llm_dp_pg) + dist.all_reduce(local_num, group=llm_dp_cp_pg) + dist.all_reduce(local_den, group=llm_dp_cp_pg) # clamp_min(1.0) guards against the pathological "all tokens masked" # batch. In normal training local_den > 0 on every rank, but a CE loss # that divides by zero crashes the entire step; better to return 0. @@ -94,11 +99,13 @@ def loss_func(loss_mask, llm_dp_pg, output_tensor): def forward_step(data_iterator, model, encoder_grid, llm_grid, encoder_name): """Forward step with data slicing for heterogeneous DP.""" batch = next(data_iterator) if data_iterator is not None else {'input_ids': None} - llm_dp_pg = llm_grid.get_pg("dp") + # Reduce the (num, den) loss statistics over dp*cp so CP>1 does not + # under-count tokens. Equivalent to the plain dp group when cp=1. 
+ llm_dp_cp_pg = llm_grid.get_pg(["dp", "cp"]) if batch.get('input_ids') is None: output_tensor, loss_mask = model(**batch) - return output_tensor, partial(loss_func, loss_mask, llm_dp_pg) + return output_tensor, partial(loss_func, loss_mask, llm_dp_cp_pg) encoder_dp = encoder_grid.get_pg("dp").size() llm_dp = llm_grid.get_pg("dp").size() @@ -135,7 +142,7 @@ def forward_step(data_iterator, model, encoder_grid, llm_grid, encoder_name): batch[key] = batch[key][start : start + slice_size].contiguous() output_tensor, loss_mask = model(**batch) - return output_tensor, partial(loss_func, loss_mask, llm_dp_pg) + return output_tensor, partial(loss_func, loss_mask, llm_dp_cp_pg) def run_colocated_test( @@ -143,6 +150,7 @@ def run_colocated_test( encoder_dp, llm_tp, llm_dp, + llm_cp=1, hidden_size=256, num_layers=2, vocab_size=1000, @@ -158,9 +166,10 @@ def run_colocated_test( encoder_name = "images" - # Both grids at offset=0 (colocated on same ranks) + # Both grids at offset=0 (colocated on same ranks). Encoder stays CP=1; + # LLM CP follows the parameter so tests can exercise the CP>1 backward. encoder_grid = create_hypercomm_grid(offset=0, tp=encoder_tp, cp=1, pp=1, dp=encoder_dp) - llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=1, pp=1, dp=llm_dp) + llm_grid = create_hypercomm_grid(offset=0, tp=llm_tp, cp=llm_cp, pp=1, dp=llm_dp) # dist.new_group is a collective — create all embedding PGs up front. create_all_embedding_groups([encoder_grid, llm_grid]) @@ -410,3 +419,53 @@ def test_colocated_fan_out_grad_accumulation_8gpu(self): micro_batch_size=2, num_microbatches=4, ) + + def test_colocated_fan_in_cp2_8gpu(self): + """Encoder TP2/DP4, LLM TP2/DP2/CP2 — fan-in with dest CP=2. + + Exercises the real PartitionAdapter.shard path: the LLM's + context_parallel_size=2 config drives sequence sharding via + index_select, whose backward is zero-pad. 
The bridge communicator's + backward must therefore intra-CP all_reduce to return a full-sequence + gradient to the encoder — covered by this end-to-end path. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_test( + encoder_tp=2, + encoder_dp=4, + llm_tp=2, + llm_dp=2, + llm_cp=2, + hidden_size=256, + num_layers=2, + vocab_size=1000, + # seq_length must be divisible by 2*cp=4 (PartitionAdapter + # causal-load-balancing splits sequence into 2*cp chunks). + seq_length=64, + micro_batch_size=2, + num_microbatches=2, + ) + + def test_colocated_fan_out_cp2_8gpu(self): + """Encoder TP4/DP2, LLM TP1/DP4/CP2 — fan-out with dest CP=2. + + Complements the fan-in CP test: exercises the per-CP-level fan-out + gather groups and the backward CP reduction for the narrow-adjoint + path. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + run_colocated_test( + encoder_tp=4, + encoder_dp=2, + llm_tp=1, + llm_dp=4, + llm_cp=2, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=4, + num_microbatches=2, + )