Skip to content

[graph_trainer] Use separate EP process groups for overlap#3369

Merged
sanketpurandare merged 1 commit into
mainfrom
sanketpurandare/stack/18
May 27, 2026
Merged

[graph_trainer] Use separate EP process groups for overlap#3369
sanketpurandare merged 1 commit into
mainfrom
sanketpurandare/stack/18

Conversation

@sanketpurandare

@sanketpurandare sanketpurandare commented May 15, 2026

Copy link
Copy Markdown
Contributor

Stacked PRs:


[graph_trainer] Use separate EP process groups for overlap

When EP all-to-alls use the same process group object as FSDP or TP collectives, they share a NCCL communicator/stream and can serialize communication that the EP overlap schedule is trying to expose.

Add isolate_ep_process_group_pass as a dedicated EP pass instead of mixing this policy into FSDP bucketing. The pass scans traced collectives, identifies EP process-group users by _c10d_functional.all_to_all_single.default nodes, and compares the resolved process-group object identity against non-EP collective process groups. If an EP process group is object-identical to a non-EP collective group, the pass creates one high-priority EP-only process group over the same EP ranks and rewrites all traced EP all-to-all calls that used the shared source group.

The matching contract is intentionally process-group based because graph-level metadata does not carry DeviceMesh objects. Rank-set equality is only a validator: two distinct process groups may have the same ranks but already have separate NCCL communicators, so rank-set equality alone is not a sufficient reason to rewrite EP. The replacement communicator is always constructed from the EP process group itself, preserving the EP group rank order used by all-to-all split semantics.

FSDP all-gathers continue to use the existing extra FSDP group from overlap_fsdp_ag_rs_pass. Generic MoE dispatch/combine traceback metadata and EP_token_exchange scheduling metadata are not required for EP process-group isolation; the scheduling pass still owns token-exchange wait ordering.

Test Plan:

  • pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py::TestOverlapFsdpAgRsPass



def _pg_rank_set(pg_name: str) -> frozenset[int]:
import torch.distributed as dist

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: looks like we importing locally in 3+ functions, consider moving to the top of the file.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the EP process-group logic into its own graph-trainer pass, so the local imports are no longer spread through the FSDP pass. The remaining distributed imports are contained in the PG resolution/creation helpers that need them.

Comment on lines +224 to +230
def _is_ep_all_to_all(node: fx.Node) -> bool:
custom = node.meta.get("custom", {})
return (
node.op == "call_function"
and "all_to_all_single" in str(node.target)
and isinstance(custom, dict)
and custom.get("EP") in ("dispatch", "combine")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who adds this metadata?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

custom.get("EP") in ("dispatch", "combine") this is not reliable.

In some come case, I see the entire EP is annotated as "compute"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reworked this to identify EP process-group users by the traced _c10d_functional.all_to_all_single.default calls themselves, instead of relying on generic custom["EP"] phase metadata. That keeps PG isolation tied to the collectives issued through the EP mesh rather than to scheduler-only traceback labels.

all-gather to the corresponding extra PG. This separates all-gathers
from reduce-scatters onto different streams, enabling AG/RS overlap in
backward.
backward. If an EP all-to-all resolves to the same process group as an FSDP

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we ever want to enable EP A2A in a separate stream but not a separate stream for AG/RS ? Should this be a separate pass with its own enable toggle?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split this out into isolate_ep_process_group_pass with its own graph-trainer pass hook. It rewrites EP all-to-all collectives when their process group object is shared with non-EP collectives.

pg_mapping: dict[str, str] = {
pg: _get_or_create_extra_fsdp_pg(pg) for pg in source_pg_names
}
ep_source_pg_names = OrderedSet(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can avoid extra O(N) scan of graph nodes by populating ep_source_pg_names inside previous for loop

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The standalone pass now performs one graph scan to classify EP and non-EP process groups, then rewrites only the matching EP all-to-alls. There is no extra FSDP-specific second scan for EP source PG discovery.

custom = node.meta.get("custom", {})
return (
node.op == "call_function"
and "all_to_all_single" in str(node.target)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's do direct target match.

node.target == torch.ops.c10d.all_to_all_single.default

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the predicate to match the actual _c10d_functional.all_to_all_single.default target directly. The pass no longer requires the EP_token_exchange scheduler annotation for PG isolation.

def _is_ep_all_to_all(node: fx.Node) -> bool:
custom = node.meta.get("custom", {})
return (
node.op == "call_function"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

node.op == "call_function" is redundant.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept the node.op guard only as an FX safety check before reading call-function arguments. The semantic selection is the direct c10d all-to-all target; PG isolation no longer depends on the token-exchange annotation.

No-op when the graph has no FSDP all-gathers. Must be applied BEFORE
bucketing passes so bucketed all-gathers inherit the new PG name.
"""
source_pg_names: OrderedSet[str] = OrderedSet()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these fsdp_ag_pg_names? rename if so

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these looks like a mixed of fsdp ag and rs pgs?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed and generalized the collection to distinguish EP all-to-all PGs from non-EP collective PGs. The pass now handles FSDP AG/RS, TP collectives, and other non-EP collectives uniformly.

)


def overlap_fsdp_ag_rs_pass(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update pass name to reflect EP change.

overlap_collectives_pass?
reassign_pg_for_overlap_collectives_pass?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed the pass to isolate_ep_process_group_pass, which reflects the actual behavior: EP all-to-alls are moved to an EP-only process group when they otherwise share a PG with non-EP collectives.

Comment on lines +138 to +139
_EXTRA_FSDP_PG_REGISTRY: dict[str, str] = {}
_EXTRA_EP_PG_REGISTRY: dict[str, str] = {}

@SherlockNoMad SherlockNoMad May 18, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feels like they can be merged?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept the EP PG isolation separate from FSDP bucketing because it is not FSDP-specific. It can apply when EP shares a PG with FSDP, TP, or another non-EP collective source.

@SherlockNoMad SherlockNoMad left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! lgtm overall, just minor comments.

@sanketpurandare sanketpurandare marked this pull request as draft May 22, 2026 19:02
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/15 to main May 22, 2026 19:02
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/18 branch from 4079f4a to d435ffb Compare May 22, 2026 19:03
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/15 May 22, 2026 19:03
@sanketpurandare sanketpurandare marked this pull request as ready for review May 22, 2026 19:03
sanketpurandare added a commit that referenced this pull request May 26, 2026
When EP all-to-alls use the same process group object as FSDP or TP collectives, they share a NCCL communicator/stream and can serialize communication that the EP overlap schedule is trying to expose.

Add isolate_ep_process_group_pass as a dedicated EP pass instead of mixing this policy into FSDP bucketing. The pass scans traced collectives, identifies EP process-group users by _c10d_functional.all_to_all_single.default nodes, and compares the resolved process-group object identity against non-EP collective process groups. If an EP process group is object-identical to a non-EP collective group, the pass creates one high-priority EP-only process group over the same EP ranks and rewrites all traced EP all-to-all calls that used the shared source group.

The matching contract is intentionally process-group based because graph-level metadata does not carry DeviceMesh objects. Rank-set equality is only a validator: two distinct process groups may have the same ranks but already have separate NCCL communicators, so rank-set equality alone is not a sufficient reason to rewrite EP. The replacement communicator is always constructed from the EP process group itself, preserving the EP group rank order used by all-to-all split semantics.

FSDP all-gathers continue to use the existing extra FSDP group from overlap_fsdp_ag_rs_pass. Generic MoE dispatch/combine traceback metadata and EP_token_exchange scheduling metadata are not required for EP process-group isolation; the scheduling pass still owns token-exchange wait ordering.

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py::TestOverlapFsdpAgRsPass

stack-info: PR: #3369, branch: sanketpurandare/stack/18
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/18 branch from d435ffb to cd21979 Compare May 26, 2026 05:15
@sanketpurandare sanketpurandare requested a review from wconstab as a code owner May 26, 2026 05:15
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:28
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/15 to main May 26, 2026 05:28
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/18 branch from cd21979 to 5bf9551 Compare May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:30
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:31
When EP all-to-alls use the same process group object as FSDP or TP collectives, they share a NCCL communicator/stream and can serialize communication that the EP overlap schedule is trying to expose.

Add isolate_ep_process_group_pass as a dedicated EP pass instead of mixing this policy into FSDP bucketing. The pass scans traced collectives, identifies EP process-group users by _c10d_functional.all_to_all_single.default nodes, and compares the resolved process-group object identity against non-EP collective process groups. If an EP process group is object-identical to a non-EP collective group, the pass creates one high-priority EP-only process group over the same EP ranks and rewrites all traced EP all-to-all calls that used the shared source group.

The matching contract is intentionally process-group based because graph-level metadata does not carry DeviceMesh objects. Rank-set equality is only a validator: two distinct process groups may have the same ranks but already have separate NCCL communicators, so rank-set equality alone is not a sufficient reason to rewrite EP. The replacement communicator is always constructed from the EP process group itself, preserving the EP group rank order used by all-to-all split semantics.

FSDP all-gathers continue to use the existing extra FSDP group from overlap_fsdp_ag_rs_pass. Generic MoE dispatch/combine traceback metadata and EP_token_exchange scheduling metadata are not required for EP process-group isolation; the scheduling pass still owns token-exchange wait ordering.

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py::TestOverlapFsdpAgRsPass

stack-info: PR: #3369, branch: sanketpurandare/stack/18
@sanketpurandare sanketpurandare marked this pull request as draft May 27, 2026 03:51
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/18 branch from 5bf9551 to 5707dd2 Compare May 27, 2026 03:51
@sanketpurandare sanketpurandare marked this pull request as ready for review May 27, 2026 03:52
@sanketpurandare sanketpurandare merged commit c59f8c9 into main May 27, 2026
16 of 17 checks passed
saforem2 added a commit to saforem2/torchtitan that referenced this pull request May 27, 2026
… routing

Merged 7 upstream commits (19c567f..af33f76). Documents which
ones needed ezpz replays:

- PR pytorch#3398 (Module subclass refactor): 3 import paths replayed in
  b052f29 — pure import-path swap, class API unchanged.
- PR pytorch#3146 (deterministic MoE routing): inherits transitively; this
  is the upstream fix for the _histc_xpu non-determinism blocker
  we hit on 2026-05-21. --debug.deterministic on MoE+XPU should now
  work.
- PR pytorch#3423 (MoE [7/n] 3D tensors): inherits transitively; doesn't
  touch deepseek_v3 callsites.
- PR pytorch#3105 (FSDP symm_mem): skipped — ezpz has its own apply_fsdp
  and symm_mem is an optional optimization XPU CCL likely doesn't
  support.
- PRs pytorch#3331/pytorch#3369/pytorch#3361: graph_trainer-only no-ops.

Captures two action items: smoke-test before next production push,
and re-try --debug.deterministic on MoE+XPU.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants