NMFW-17: Add ColocatedBridgeCommunicator for heterogeneous TP/DP MIMO training#1

Closed
yashaswikarnati wants to merge 3 commits into main from ykarnati/nmfw-17-colocated-colocated-bridge-communicator

@yashaswikarnati

Summary

  • ColocatedBridgeCommunicator: autograd-aware fan-in/fan-out/equal-DP communication between encoder and LLM with different TP/DP on same ranks
  • MimoModel colocated forward path with config, role, and optimizer support
  • 3 test files: communicator unit tests (11), multi-iteration correctness (9 checks x 3 iters x 3 configs), e2e VLM with MimoOptimizer
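
As a rough illustration of the three communication cases named above, the sketch below classifies a source/destination pair by DP size. It follows the test names in this PR (fan-in is TP2/DP4 to TP2/DP2, so the source DP is larger). The helper names `classify_dp_direction` and `dest_local_batch`, the even-divisibility requirement, and the assumption that the global batch is conserved are all illustrative and are not taken from the diff.

```python
def classify_dp_direction(src_dp_size: int, dest_dp_size: int) -> str:
    """Classify the DP relationship between source (encoder) and destination (LLM).

    Hypothetical helper: the labels mirror the PR's test ids, but the
    divisibility check is an assumption, not something shown in the diff.
    """
    if src_dp_size <= 0 or dest_dp_size <= 0:
        raise ValueError("DP sizes must be positive")
    if src_dp_size == dest_dp_size:
        return "equal"
    larger = max(src_dp_size, dest_dp_size)
    smaller = min(src_dp_size, dest_dp_size)
    if larger % smaller != 0:
        raise ValueError("assumed: the larger DP size is a multiple of the smaller")
    # More source DP replicas than destination replicas: shards are gathered.
    return "fan_in" if src_dp_size > dest_dp_size else "fan_out"


def dest_local_batch(src_local_batch: int, src_dp_size: int, dest_dp_size: int) -> int:
    """Per-rank batch on the destination side, assuming the global batch is conserved."""
    global_batch = src_local_batch * src_dp_size
    if global_batch % dest_dp_size != 0:
        raise ValueError("global batch does not split evenly across destination DP ranks")
    return global_batch // dest_dp_size
```

Under these assumptions, a fan-in from DP4 to DP2 doubles the per-rank batch on the LLM side, and a fan-out halves it.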

Test commands (8 GPUs, run individually)

uv run python -m torch.distributed.run --nproc_per_node=8 -m pytest tests/unit_tests/models/test_mimo_colocated_communicator.py -v
uv run python -m torch.distributed.run --nproc_per_node=8 -m pytest "tests/unit_tests/models/test_mimo_colocated_correctness.py::TestColocatedCorrectness::test_correctness[fan_in]" -v
uv run python -m torch.distributed.run --nproc_per_node=8 -m pytest tests/unit_tests/models/test_mimo_colocated_e2e.py -v

Linear: NMFW-17

🤖 Generated with Claude Code

Comment thread megatron/core/models/mimo/model/base.py Outdated
packing_kwargs: Optional[dict] = None,
):
"""Forward pass for colocated mode: encoder and LLM on same ranks, different TP/DP."""
packed_seq_params = None

Let's not worry about sequence packing for now.

Comment thread megatron/core/models/mimo/model/base.py Outdated
)

# 4. Optional partition adapter
if self.partition_adapter is not None:

Also, don't worry about the partition adapter yet.

Comment thread megatron/core/models/mimo/model/base.py Outdated
packing_kwargs: Optional[dict] = None,
):
"""Forward pass for colocated mode: encoder and LLM on same ranks, different TP/DP."""
packed_seq_params = None

Also, this function seems a little verbose: did we copy almost the whole thing and just add the colocated comms step?



@dataclass
class ColocatedCommConfig:

Do we really need a separate ColocatedCommConfig? Also, the module-to-grid map seems to be duplicated in two places: in the MIMO model config and here. Suggest simpler and cleaner alternatives.

self._extract_parallelism_info()
self._build_rank_mappings()

self.all_gather_pg: Optional[dist.ProcessGroup] = None

When can this be None?

)

def _extract_parallelism_info(self):
self.src_tp_size = self.src_grid.shape[self.src_grid.dim_names.index('tp')]

Can we use the process group for this, e.g. pg.size()?
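
For reference, the quoted line reads the TP size from the grid's shape by looking up the index of the 'tp' dimension name. Below is a minimal sketch of that lookup. `ToyGrid` is a hypothetical stand-in that mimics only the `shape` and `dim_names` attributes seen in the diff; it is not the real HyperCommGrid class.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyGrid:
    """Hypothetical stand-in exposing only `shape` and `dim_names`."""

    shape: Tuple[int, ...]
    dim_names: Tuple[str, ...]


def dim_size(grid: ToyGrid, name: str) -> int:
    """Return the extent of the named dimension, as the quoted line does for 'tp'."""
    return grid.shape[grid.dim_names.index(name)]


grid = ToyGrid(shape=(2, 4), dim_names=("tp", "dp"))
tp_size = dim_size(grid, "tp")  # same lookup pattern as src_tp_size in the diff
dp_size = dim_size(grid, "dp")
```

If a TP process group has already been created, querying its size (for example with `torch.distributed.get_world_size(group)`) should give the same number without indexing into the grid, which is presumably what the comment is asking for.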

"""Config for colocated modules with different TP/DP on same ranks."""

module_to_grid_map: Dict[str, 'HyperCommGrid'] = field(default_factory=dict)
topology: Dict[str, list] = field(default_factory=dict)

What do we need topology for?

Comment thread megatron/core/models/mimo/optimizer.py Outdated
from megatron.core.optimizer import get_megatron_optimizer

grid_map = mimo_model.mimo_config.module_to_grid_map
if grid_map is None and mimo_model.mimo_config.colocated_comm_config is not None:

This seems a little redundant; doesn't it create two sources of truth?

@@ -0,0 +1,348 @@
# Colocated MIMO Correctness Testing Design

Don't push local planning docs here; nothing from docs/plans unless explicitly asked.

self.dp_scale_factor = self.src_dp_size / self.dest_dp_size

def _build_rank_mappings(self):
self.rank_to_src_pos: Dict[int, Tuple[int, int]] = {}

What is this variable storing here?
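
The type annotation, Dict[int, Tuple[int, int]], suggests a map from a global rank to a pair of coordinates. One plausible reading, which the diff shown here does not confirm, is (dp_index, tp_index) within the source grid. The sketch below builds such a map under that assumption, and also assumes TP ranks are contiguous; `build_rank_positions` and `rank_offset` are hypothetical names.

```python
from typing import Dict, Tuple


def build_rank_positions(
    tp_size: int, dp_size: int, rank_offset: int = 0
) -> Dict[int, Tuple[int, int]]:
    """Map each global rank to an assumed (dp_index, tp_index) position.

    Assumption: TP is the fastest-varying dimension, so ranks
    [0, tp_size) form the first TP group, and so on.
    """
    positions: Dict[int, Tuple[int, int]] = {}
    for dp_idx in range(dp_size):
        for tp_idx in range(tp_size):
            rank = rank_offset + dp_idx * tp_size + tp_idx
            positions[rank] = (dp_idx, tp_idx)
    return positions


# TP2/DP4 on 8 ranks: rank 5 sits in DP replica 2, TP slot 1.
src_positions = build_rank_positions(tp_size=2, dp_size=4)
```

If this reading is right, a short docstring or a more descriptive name on the attribute would answer the question directly in the code.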

Comment thread megatron/core/models/mimo/model/base.py Outdated
if mimo_config.module_to_grid_map:
self.colocated_comms = {}
if mimo_config.colocated_comm_config is not None:
self.role = RankRole.colocated(modality_names + [MIMO_LANGUAGE_MODULE_KEY])

This seems a little flaky; can we also build this from the grid map?

Comment thread megatron/core/models/mimo/model/base.py Outdated
)
return modality_embeddings

def _forward_colocated(

Would unified mode still make sense, or will that break things now? Can we cleanly support both cases under one colocated umbrella: colocated with custom process groups, and legacy unified with process groups taken from parallel state?

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/nmfw-17-colocated-colocated-bridge-communicator branch from c12e3db to a8122d5 Compare March 26, 2026 20:48
yashaswikarnati and others added 2 commits March 27, 2026 20:28
… (NMFW-17)

COLOCATED mode replaces UNIFIED — covers both legacy (no grid map) and
heterogeneous TP/DP on shared ranks. Auto-detects colocated from grid overlap.

Core: ColocatedBridgeCommunicator with fan-in/fan-out/equal-DP autograd.
Model: _forward_all_modules with optional colocated communication.
Tests: communicator unit tests, multi-iteration correctness, e2e VLM.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Three-phase execution for colocated encoder PP=1 + LLM PP>1:
- Phase 1: Encoder forward + communicate on full batch (all ranks sync)
- Phase 2: LLM 1F1B pipeline with detached encoder embeddings
- Phase 3: Encoder backward on full batch (all ranks sync)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/nmfw-17-colocated-colocated-bridge-communicator branch from a8122d5 to d834b76 Compare March 28, 2026 03:29
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yashaswikarnati

Status Update (NMFW-17 + NMFW-19)

Implementation Complete — 19 tests passing

PR structure:

  • Commit 1 (NMFW-17): PP=1 ColocatedBridgeCommunicator — fan-in/fan-out/equal-DP autograd, COLOCATED mode replaces UNIFIED, auto-detect from grid overlap
  • Commit 2 (NMFW-19): PP>1 three-phase schedule — encoder batch → LLM 1F1B pipeline → encoder backward
  • Commit 3: Moved correctness tests to PR #2 (NMFW-50: Colocated correctness tests)
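
The "auto-detect from grid overlap" idea in Commit 1 can be sketched as a check on whether any two modules share ranks. This is only an illustration: `is_colocated` is a hypothetical helper, the rank sets are passed in directly rather than read from a grid object, and the real check may be stricter (for example, requiring identical rank sets).

```python
from typing import Dict, Iterable, Set


def is_colocated(module_ranks: Dict[str, Iterable[int]]) -> bool:
    """Return True if any two modules share at least one global rank."""
    seen: Set[int] = set()
    for ranks in module_ranks.values():
        rank_set = set(ranks)
        if seen & rank_set:
            return True  # overlap with a previously visited module
        seen |= rank_set
    return False


# Encoder and LLM both span ranks 0..7: colocated.
shared = is_colocated({"encoder": range(8), "llm": range(8)})
# Encoder on ranks 0..3, LLM on ranks 4..7: disjoint, not colocated.
disjoint = is_colocated({"encoder": range(0, 4), "llm": range(4, 8)})
```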

PP=1 tests: communicator (11), e2e VLM + MimoOptimizer (1)
PP>1 tests: fan-in TP2/DP4→TP2/DP2/PP2 (1), equal-DP TP4/DP2→TP2/DP2/PP2 (1), grad accumulation 6mb (1), extreme TP1/DP8→TP4/DP1/PP2 (1)
Correctness tests: moved to PR #2

PP>1 Design Summary

  • Phase 1: One encoder forward + one communicate on full batch (all ranks sync)
  • Phase 2: 1F1B pipeline for LLM with detached encoder embeddings sliced per microbatch
  • Phase 3: Broadcast gradient from PP stage 0 → 1+, one encoder backward (all ranks sync)
  • Detach prevents encoder TP all-reduce (which may cross PP stages) from running inside staggered pipeline
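
The three phases above can be illustrated with a toy scalar model in plain Python, with gradients written out by hand. Here the "encoder" is e = w * x and the "LLM" loss is the sum of v * e. These formulas and the `three_phase_step` helper are assumptions made for illustration only; there is no real pipeline, communication, or autograd in this sketch.

```python
from typing import List, Tuple


def three_phase_step(
    w: float, v: float, xs: List[float], num_microbatches: int
) -> Tuple[float, float]:
    """Toy model of the schedule: returns (total_loss, dL/dw)."""
    if num_microbatches <= 0 or len(xs) % num_microbatches != 0:
        raise ValueError("batch must split evenly into microbatches")

    # Phase 1: one encoder forward over the full batch.
    embeddings = [w * x for x in xs]

    # Phase 2: the LLM consumes "detached" embedding slices, one per
    # microbatch, and records dL/d(embedding) without touching the encoder.
    mb_size = len(xs) // num_microbatches
    grad_embeddings = [0.0] * len(xs)
    total_loss = 0.0
    for mb in range(num_microbatches):
        for i in range(mb * mb_size, (mb + 1) * mb_size):
            total_loss += v * embeddings[i]
            grad_embeddings[i] = v  # d(v * e)/de

    # Phase 3: one encoder backward over the full batch, using the
    # accumulated embedding gradients (chain rule: dL/dw = sum(dL/de * x)).
    grad_w = sum(g * x for g, x in zip(grad_embeddings, xs))
    return total_loss, grad_w


loss, grad_w = three_phase_step(w=2.0, v=3.0, xs=[1.0, 2.0, 3.0, 4.0], num_microbatches=2)
```

In this toy setting the result does not depend on the number of microbatches, since the encoder runs forward once and backward once regardless of how Phase 2 slices the batch.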
