Skip to content
150 changes: 137 additions & 13 deletions megatron/core/models/mimo/comm/colocated_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,17 @@ def __init__(
elif self.dest_dp_size > self.src_dp_size:
self.direction = BridgeDirection.FAN_OUT
self.scale = self.dest_dp_size // self.src_dp_size
self.gather_group_ranks = self._build_gather_groups(
# Fan-out gather groups must be split by dest CP level: every
# world rank must land in exactly one subgroup, so a single
# pooled group per (src_dp, dest_tp) would orphan cp>0 ranks.
# With dest_cp_size == 1 this collapses to the original
# (src_dp, dest_tp) product.
self.gather_group_ranks = self._build_fan_out_gather_groups(
iter_size=self.src_dp_size,
sibling_tp_size=self.dest_tp_size,
scale=self.scale,
rank_to_pos=self.rank_to_dest_pos,
cp_size=self.dest_cp_size,
rank_to_coords=self.rank_to_dest_coords,
)
self.gather_pg, _ = dist.new_subgroups_by_enumeration(
self.gather_group_ranks, backend='nccl'
Expand Down Expand Up @@ -144,14 +150,30 @@ def _validate_grids(self):
f"{name} PP must be 1 for ColocatedBridgeCommunicator, got {pp_size}"
)

# CP>1 corrupts dp_idx when we iterate get_rank_enum(['tp']) groups.
for name, grid in [("src", self.src_grid), ("dest", self.dest_grid)]:
if 'cp' in grid.dim_names:
cp_size = grid.shape[grid.dim_names.index('cp')]
if cp_size != 1:
raise ValueError(
f"{name} CP must be 1 for ColocatedBridgeCommunicator, got {cp_size}"
)
# Source (encoder) must have CP=1. Dest (LLM) may have CP>1 — the LLM
# shards sequence across CP ranks via PartitionAdapter after receiving
# the bridge's full-sequence output. The bridge's backward path
# reduces partial-sequence gradients across dest CP siblings before
# returning the full-sequence gradient to the encoder.
if 'cp' in self.src_grid.dim_names:
src_cp_size = self.src_grid.shape[self.src_grid.dim_names.index('cp')]
if src_cp_size != 1:
raise ValueError(
f"Source CP must be 1 for ColocatedBridgeCommunicator, got {src_cp_size}"
)

# _build_rank_mappings assumes that _gen_rank_enum(['tp']) yields cp
# varying fastest for fixed dp — true only when cp appears BEFORE dp
# in dim_names (reversed, cp becomes an inner loop around dp). If the
# caller reverses them, dp_idx advances at the wrong cp level and
# rank_to_dest_coords is silently wrong. Guard explicitly.
if 'cp' in self.dest_grid.dim_names:
dim_names = self.dest_grid.dim_names
if dim_names.index('cp') > dim_names.index('dp'):
raise ValueError(
f"dest_grid dim_names must have 'cp' before 'dp' "
f"(e.g. ['tp','cp','pp','dp']); got {dim_names}"
)

src_dp = self.src_grid.shape[self.src_grid.dim_names.index('dp')]
dest_dp = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')]
Expand All @@ -165,20 +187,68 @@ def _extract_parallelism_info(self):
self.src_dp_size = self.src_grid.shape[self.src_grid.dim_names.index('dp')]
self.dest_tp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('tp')]
self.dest_dp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('dp')]
if 'cp' in self.dest_grid.dim_names:
self.dest_cp_size = self.dest_grid.shape[self.dest_grid.dim_names.index('cp')]
else:
self.dest_cp_size = 1
# Reuse the existing CP group from dest_grid (caller creates it via
# grid.create_pg(['cp'])). None when dest CP=1. Used in backward to
# reduce sequence-sharded gradients across CP siblings before returning
# to the encoder — see the backward docstring for the math.
self.dest_cp_pg: Optional[dist.ProcessGroup] = (
self.dest_grid.get_pg('cp') if self.dest_cp_size > 1 else None
)

@staticmethod
def _get_rank_dim_coord(rank: int, grid: HyperCommGrid, dim_name: str) -> int:
"""Extract a rank's coordinate for a specific grid dimension."""
dim_idx = grid.dim_names.index(dim_name)
temp = rank - grid.rank_offset
for i in range(dim_idx):
temp //= grid.shape[i]
return temp % grid.shape[dim_idx]

def _build_rank_mappings(self):
self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {}
# rank_to_dest_pos: canonical (cp_idx=0, pp_idx=0) rank per (dp, tp)
# slot — preserves the one-entry-per-slot contract downstream group
# construction depends on.
self.rank_to_dest_pos: Dict[int, Tuple[int, int]] = {}
# rank_to_dest_coords: every dest rank's full (dp_idx, tp_idx, cp_idx)
# for PP stage 0. Used to build per-CP-level fan-out groups and to
# drive the intra-CP gradient reduction in backward.
self.rank_to_dest_coords: Dict[int, Tuple[int, int, int]] = {}

src_tp_groups = self.src_grid.get_rank_enum(['tp'])
for dp_idx, tp_group in enumerate(src_tp_groups):
for tp_idx, rank in enumerate(tp_group):
self.rank_to_src_pos[rank] = (dp_idx, tp_idx)

# Dest iteration: get_rank_enum(['tp']) returns dp*pp*cp tp-groups.
# We advance dp_idx only after the final cp level of each dp; every
# (cp, pp=0) rank still records its coords so fan-out backward can
# locate CP siblings.
dest_has_pp = 'pp' in self.dest_grid.dim_names
dest_has_cp = 'cp' in self.dest_grid.dim_names
dest_tp_groups = self.dest_grid.get_rank_enum(['tp'])
for dp_idx, tp_group in enumerate(dest_tp_groups):
dp_idx = 0
for tp_group in dest_tp_groups:
if dest_has_pp:
pp_coord = self._get_rank_dim_coord(tp_group[0], self.dest_grid, 'pp')
if pp_coord != 0:
continue
cp_coord = (
self._get_rank_dim_coord(tp_group[0], self.dest_grid, 'cp') if dest_has_cp else 0
)
# get_rank_enum yields cp varying fastest for fixed dp (with
# pp=0 filtered). All cp levels of one dp share dp_idx; we
# advance dp_idx only after the final cp level of that dp.
for tp_idx, rank in enumerate(tp_group):
self.rank_to_dest_pos[rank] = (dp_idx, tp_idx)
self.rank_to_dest_coords[rank] = (dp_idx, tp_idx, cp_coord)
if cp_coord == 0:
self.rank_to_dest_pos[rank] = (dp_idx, tp_idx)
if cp_coord == self.dest_cp_size - 1:
dp_idx += 1

@staticmethod
def _build_gather_groups(
Expand Down Expand Up @@ -208,6 +278,39 @@ def _build_gather_groups(
groups.append(group_ranks)
return groups

@staticmethod
def _build_fan_out_gather_groups(
iter_size: int,
sibling_tp_size: int,
scale: int,
cp_size: int,
rank_to_coords: Dict[int, Tuple[int, int, int]],
) -> List[List[int]]:
"""Build fan-out gather groups split per (src_dp, dest_tp, dest_cp).

Splitting by ``cp_idx`` is required because
``new_subgroups_by_enumeration`` demands every world rank land in
exactly one subgroup — a single pooled group per (src_dp, dest_tp)
would leave cp>0 ranks orphaned. After the CP reduction in backward,
each cp-level's all-gather produces the same full-batch gradient.
When ``cp_size == 1`` this degenerates to the original
(src_dp, dest_tp) product.
"""
coords_to_rank: Dict[Tuple[int, int, int], int] = {
coords: rank for rank, coords in rank_to_coords.items()
}
groups: List[List[int]] = []
for iter_idx in range(iter_size):
sibling_dp_indices = range(iter_idx * scale, (iter_idx + 1) * scale)
for sibling_tp_idx in range(sibling_tp_size):
for cp_idx in range(cp_size):
group_ranks = [
coords_to_rank[(sibling_dp_idx, sibling_tp_idx, cp_idx)]
for sibling_dp_idx in sibling_dp_indices
]
groups.append(group_ranks)
return groups

def is_fan_in(self) -> bool:
"""True if src DP > dest DP (forward all-gathers)."""
return self.direction is BridgeDirection.FAN_IN
Expand All @@ -229,7 +332,11 @@ def get_slice_info(self, batch_size: int) -> SliceInfo:
return SliceInfo(start=0, size=batch_size)
self._check_divisible(batch_size)
if self.direction is BridgeDirection.FAN_OUT:
dp_idx = self.rank_to_dest_pos[self.current_rank][0]
# rank_to_dest_pos only tracks cp=0 canonical slots; CP>0 dest
# ranks must slice the same batch slot as their cp=0 sibling so
# the intra-CP all_reduce in backward sees matching shapes.
# rank_to_dest_coords has an entry per (dp, tp, cp).
dp_idx = self.rank_to_dest_coords[self.current_rank][0]
else: # FAN_IN
dp_idx = self.rank_to_src_pos[self.current_rank][0]
slot = dp_idx % self.scale
Expand Down Expand Up @@ -296,10 +403,27 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
activation. Instead we all-gather across the fan-out sibling
group, reconstructing the full src-batch gradient (symmetric
with the fan-in forward's all-gather).

When the dest grid has CP>1, the LLM's PartitionAdapter.shard slices
sequence via ``index_select`` whose autograd adjoint is a scatter /
zero-pad. So the grad flowing into this backward is already zero-
padded along the sequence dimension — each CP rank holds only its
own sequence chunks, zeros elsewhere. We therefore run an intra-CP
``all_reduce(SUM)`` on the incoming gradient *before* the direction-
specific op. After the reduction every CP sibling holds the same
full-sequence gradient, and the downstream narrow (fan-in) or
all-gather (fan-out) proceeds exactly as in the CP=1 case.
"""
comm = ctx.comm
batch_dim = ctx.batch_dim

# CP sequence-reduction step. Runs for both fan-in and fan-out when
# dest CP>1. See the docstring for the math. Clone first so the in-
# place all_reduce does not mutate the tensor autograd passed in.
if comm.dest_cp_pg is not None:
grad_output = grad_output.contiguous().clone()
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=comm.dest_cp_pg)

if comm.direction is BridgeDirection.FAN_OUT:
return _all_gather_along_batch_dim(grad_output, comm.gather_pg, batch_dim), None

Expand Down
2 changes: 1 addition & 1 deletion megatron/core/models/mimo/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _apply_context_parallel(
batch["attention_mask"] = attention_mask

if packed_seq_params is None or getattr(packed_seq_params, 'qkv_format', 'sbhd') == 'sbhd':
batch = get_batch_on_this_cp_rank(batch)
batch = get_batch_on_this_cp_rank(batch, cp_group=self.cfg.cp_group)
else:
assert _HAVE_TEX and is_te_min_version("1.10.0"), (
"Please update Transformer Engine to >= 1.10 "
Expand Down
6 changes: 5 additions & 1 deletion tests/unit_tests/models/test_mimo_1f1b_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,11 @@ def get_mimo_model(
module_to_grid_map=module_to_grid_map,
)

mimo_model = MimoModel(mimo_config)
# Forward language_pg.cp/tp so PartitionAdapter binds to the test's CP
# group directly. Test grids skip parallel_state.initialize_model_parallel,
# so leaving these None routes through the uninitialised global and trips
# 'context parallel group is not initialized' under CP>1.
mimo_model = MimoModel(mimo_config, cp_group=language_pg.cp, tp_group=language_pg.tp)
mimo_model.to(torch.device("cuda")).to(torch.bfloat16)

# Wrap with DDP (caller may override e.g. for heterogeneous-DP scaling).
Expand Down
Loading