NMFW-17: Add ColocatedBridgeCommunicator for heterogeneous TP/DP MIMO training#1
Conversation
| packing_kwargs: Optional[dict] = None, | ||
| ): | ||
| """Forward pass for colocated mode: encoder and LLM on same ranks, different TP/DP.""" | ||
| packed_seq_params = None |
There was a problem hiding this comment.
lets not worry about sequence packing for now
| ) | ||
|
|
||
| # 4. Optional partition adapter | ||
| if self.partition_adapter is not None: |
There was a problem hiding this comment.
also dont worry about partition adapter yet
| packing_kwargs: Optional[dict] = None, | ||
| ): | ||
| """Forward pass for colocated mode: encoder and LLM on same ranks, different TP/DP.""" | ||
| packed_seq_params = None |
There was a problem hiding this comment.
also this function seems a little verbose, we almost copied the whole thing and just added the apply colocated comms?
|
|
||
|
|
||
| @dataclass | ||
| class ColocatedCommConfig: |
There was a problem hiding this comment.
do we really need seperate ColocatedCommConfig ? also module to grid map seems to be replicated at both places ? mimo model config and here? suggest simpler and cleaner alternatives
| self._extract_parallelism_info() | ||
| self._build_rank_mappings() | ||
|
|
||
| self.all_gather_pg: Optional[dist.ProcessGroup] = None |
There was a problem hiding this comment.
when can this be None ?
| ) | ||
|
|
||
| def _extract_parallelism_info(self): | ||
| self.src_tp_size = self.src_grid.shape[self.src_grid.dim_names.index('tp')] |
There was a problem hiding this comment.
can we use pg for this? pg.size() ?
| """Config for colocated modules with different TP/DP on same ranks.""" | ||
|
|
||
| module_to_grid_map: Dict[str, 'HyperCommGrid'] = field(default_factory=dict) | ||
| topology: Dict[str, list] = field(default_factory=dict) |
There was a problem hiding this comment.
what do we need topology for ?
| from megatron.core.optimizer import get_megatron_optimizer | ||
|
|
||
| grid_map = mimo_model.mimo_config.module_to_grid_map | ||
| if grid_map is None and mimo_model.mimo_config.colocated_comm_config is not None: |
There was a problem hiding this comment.
this seems a little redundant and two sources of truth ?
| @@ -0,0 +1,348 @@ | |||
| # Colocated MIMO Correctness Testing Design | |||
There was a problem hiding this comment.
dont push local planning here, nothing from docs/plans unless explicitly asked
| self.dp_scale_factor = self.src_dp_size / self.dest_dp_size | ||
|
|
||
| def _build_rank_mappings(self): | ||
| self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {} |
There was a problem hiding this comment.
what is this var storing here ?
| if mimo_config.module_to_grid_map: | ||
| self.colocated_comms = {} | ||
| if mimo_config.colocated_comm_config is not None: | ||
| self.role = RankRole.colocated(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) |
There was a problem hiding this comment.
this seems a little flaky, can we also build this from grid map ?
| ) | ||
| return modality_embeddings | ||
|
|
||
| def _forward_colocated( |
There was a problem hiding this comment.
would unified mode still makes sense ? or will that break things now ? can we just have colocated custom process groups and the legacy unified with process groups used from parallel state cleanly supported in colocated umbrella?
c12e3db to
a8122d5
Compare
… (NMFW-17) COLOCATED mode replaces UNIFIED — covers both legacy (no grid map) and heterogeneous TP/DP on shared ranks. Auto-detects colocated from grid overlap. Core: ColocatedBridgeCommunicator with fan-in/fan-out/equal-DP autograd. Model: _forward_all_modules with optional colocated communication. Tests: communicator unit tests, multi-iteration correctness, e2e VLM. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Three-phase execution for colocated encoder PP=1 + LLM PP>1: - Phase 1: Encoder forward + communicate on full batch (all ranks sync) - Phase 2: LLM 1F1B pipeline with detached encoder embeddings - Phase 3: Encoder backward on full batch (all ranks sync) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
a8122d5 to
d834b76
Compare
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Status Update (NMFW-17 + NMFW-19)Implementation Complete — 19 tests passingPR structure:
PP=1 tests: communicator (11), e2e VLM + MimoOptimizer (1) PP>1 Design Summary
|
Summary
Test commands (8 GPUs, run individually)
Linear: NMFW-17
🤖 Generated with Claude Code