[graph_trainer] Use separate EP process groups for overlap#3369
def _pg_rank_set(pg_name: str) -> frozenset[int]:
    import torch.distributed as dist
nit: it looks like we are importing locally in 3+ functions; consider moving the import to the top of the file.
Moved the EP process-group logic into its own graph-trainer pass, so the local imports are no longer spread through the FSDP pass. The remaining distributed imports are contained in the PG resolution/creation helpers that need them.
def _is_ep_all_to_all(node: fx.Node) -> bool:
    custom = node.meta.get("custom", {})
    return (
        node.op == "call_function"
        and "all_to_all_single" in str(node.target)
        and isinstance(custom, dict)
        and custom.get("EP") in ("dispatch", "combine")
custom.get("EP") in ("dispatch", "combine") is not reliable.
In some cases, I see the entire EP annotated as "compute".
Reworked this to identify EP process-group users by the traced _c10d_functional.all_to_all_single.default calls themselves, instead of relying on generic custom["EP"] phase metadata. That keeps PG isolation tied to the collectives issued through the EP mesh rather than to scheduler-only traceback labels.
    all-gather to the corresponding extra PG. This separates all-gathers
    from reduce-scatters onto different streams, enabling AG/RS overlap in
    backward.
    backward. If an EP all-to-all resolves to the same process group as an FSDP
Would we ever want to put EP A2A on a separate stream but not use a separate stream for AG/RS? Should this be a separate pass with its own enable toggle?
Split this out into isolate_ep_process_group_pass with its own graph-trainer pass hook. It rewrites EP all-to-all collectives when their process group object is shared with non-EP collectives.
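A rough sketch of the rewrite step described in this reply, written without torch so it runs anywhere. The Node class, the A2A label, the isolate_ep_pg function, and the shared_to_ep_only mapping are hypothetical stand-ins, and the sketch assumes the process-group name is the last positional argument of each collective call; it is not the pass's actual code.

```python
from dataclasses import dataclass, replace

A2A = "all_to_all_single"  # stand-in label for the traced EP all-to-all target


@dataclass(frozen=True)
class Node:
    """Minimal stand-in for a traced call node (hypothetical, not torch.fx.Node)."""
    target: str
    args: tuple  # assumption: the process-group name is the last argument


def isolate_ep_pg(nodes: list[Node], shared_to_ep_only: dict[str, str]) -> list[Node]:
    """Return nodes with EP all-to-alls moved off shared PGs; others are untouched."""
    out = []
    for node in nodes:
        pg_name = node.args[-1] if node.args else None
        if node.target == A2A and pg_name in shared_to_ep_only:
            # Swap only the PG-name argument; all other arguments stay as traced.
            node = replace(node, args=(*node.args[:-1], shared_to_ep_only[pg_name]))
        out.append(node)
    return out
```

Because the rewrite only touches all-to-all nodes, the all-gather and reduce-scatter nodes keep whatever process group the FSDP pass assigned, which is what lets the two passes be toggled independently.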
    pg_mapping: dict[str, str] = {
        pg: _get_or_create_extra_fsdp_pg(pg) for pg in source_pg_names
    }
    ep_source_pg_names = OrderedSet(
nit: we can avoid an extra O(N) scan of the graph nodes by populating ep_source_pg_names inside the previous for loop.
The standalone pass now performs one graph scan to classify EP and non-EP process groups, then rewrites only the matching EP all-to-alls. There is no extra FSDP-specific second scan for EP source PG discovery.
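The single classification scan mentioned here could look roughly like the following. This is a sketch under stated assumptions: Node, A2A, COLLECTIVES, and classify_pgs are hypothetical names, string labels stand in for the real operator targets, and the process-group name is assumed to be the last positional argument of each collective.

```python
from dataclasses import dataclass

A2A = "all_to_all_single"  # stand-in label for the EP all-to-all target
COLLECTIVES = {A2A, "all_gather_into_tensor", "reduce_scatter_tensor"}


@dataclass(frozen=True)
class Node:
    """Minimal stand-in for a graph node (hypothetical, not torch.fx.Node)."""
    op: str
    target: str
    args: tuple


def classify_pgs(nodes: list[Node]) -> tuple[list[str], list[str]]:
    """One pass over the graph: split PG names into EP users and non-EP users."""
    ep_pgs: dict[str, None] = {}      # dicts preserve insertion order
    non_ep_pgs: dict[str, None] = {}
    for node in nodes:
        if node.op != "call_function" or node.target not in COLLECTIVES:
            continue
        pg_name = node.args[-1]  # assumption: PG name is the last argument
        if node.target == A2A:
            ep_pgs[pg_name] = None
        else:
            non_ep_pgs[pg_name] = None
    return list(ep_pgs), list(non_ep_pgs)
```

Both collections come out of the same loop, so no second scan is needed to discover the EP source groups.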
    custom = node.meta.get("custom", {})
    return (
        node.op == "call_function"
        and "all_to_all_single" in str(node.target)
Let's do a direct target match:
node.target == torch.ops.c10d.all_to_all_single.default
Updated the predicate to match the actual _c10d_functional.all_to_all_single.default target directly. The pass no longer requires the EP_token_exchange scheduler annotation for PG isolation.
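To illustrate why a direct target match is stricter than the substring check, here is a minimal sketch. It does not import torch: OpOverload, Node, ALL_TO_ALL_SINGLE, and LOOKALIKE are hypothetical stand-ins for the FX node type and for operator objects such as _c10d_functional.all_to_all_single.default, so it shows only the matching logic, not the pass's real predicate.

```python
from dataclasses import dataclass
from typing import Any


class OpOverload:
    """Stand-in for an operator object; only its identity and name matter here."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name


# Hypothetical operator objects used for illustration only.
ALL_TO_ALL_SINGLE = OpOverload("_c10d_functional.all_to_all_single.default")
LOOKALIKE = OpOverload("custom.fused_all_to_all_single_quant.default")


@dataclass
class Node:
    op: str
    target: Any


def is_a2a_by_substring(node: Node) -> bool:
    # Matches any target whose printed name contains the substring.
    return node.op == "call_function" and "all_to_all_single" in str(node.target)


def is_a2a_by_target(node: Node) -> bool:
    # Matches only the one operator object we care about.
    return node.op == "call_function" and node.target == ALL_TO_ALL_SINGLE
```

The substring version accepts the lookalike operator, while the direct comparison accepts only the intended target.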
def _is_ep_all_to_all(node: fx.Node) -> bool:
    custom = node.meta.get("custom", {})
    return (
        node.op == "call_function"
node.op == "call_function" is redundant.
Kept the node.op guard only as an FX safety check before reading call-function arguments. The semantic selection is the direct c10d all-to-all target; PG isolation no longer depends on the token-exchange annotation.
    No-op when the graph has no FSDP all-gathers. Must be applied BEFORE
    bucketing passes so bucketed all-gathers inherit the new PG name.
    """
    source_pg_names: OrderedSet[str] = OrderedSet()
are these fsdp_ag_pg_names? rename if so
These look like a mix of FSDP AG and RS PGs?
Renamed and generalized the collection to distinguish EP all-to-all PGs from non-EP collective PGs. The pass now handles FSDP AG/RS, TP collectives, and other non-EP collectives uniformly.
)


def overlap_fsdp_ag_rs_pass(
update pass name to reflect EP change.
overlap_collectives_pass?
reassign_pg_for_overlap_collectives_pass?
Renamed the pass to isolate_ep_process_group_pass, which reflects the actual behavior: EP all-to-alls are moved to an EP-only process group when they otherwise share a PG with non-EP collectives.
_EXTRA_FSDP_PG_REGISTRY: dict[str, str] = {}
_EXTRA_EP_PG_REGISTRY: dict[str, str] = {}
feels like they can be merged?
Kept the EP PG isolation separate from FSDP bucketing because it is not FSDP-specific. It can apply when EP shares a PG with FSDP, TP, or another non-EP collective source.
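One way to picture why the two registries stay separate is a get-or-create cache keyed by the source PG name. The sketch below is illustrative only: get_or_create_extra_pg and the create callback are hypothetical, and the callback stands in for whatever actually constructs the new process group.

```python
from typing import Callable

# Separate caches mirroring the two registries in the diff:
# source PG name -> extra PG name.
_EXTRA_FSDP_PG_REGISTRY: dict[str, str] = {}
_EXTRA_EP_PG_REGISTRY: dict[str, str] = {}


def get_or_create_extra_pg(
    registry: dict[str, str],
    source_pg_name: str,
    create: Callable[[str], str],
) -> str:
    """Create the extra PG for a source PG once, then reuse it on later calls."""
    if source_pg_name not in registry:
        registry[source_pg_name] = create(source_pg_name)
    return registry[source_pg_name]
```

With a single merged registry keyed only by the source name, a source group used by both FSDP and EP would resolve to one extra group for both roles, which would put the two kinds of collectives back on a shared communicator.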
SherlockNoMad left a comment
thanks! lgtm overall, just minor comments.
When EP all-to-alls use the same process group object as FSDP or TP collectives, they share a NCCL communicator/stream and can serialize communication that the EP overlap schedule is trying to expose.

Add isolate_ep_process_group_pass as a dedicated EP pass instead of mixing this policy into FSDP bucketing. The pass scans traced collectives, identifies EP process-group users by _c10d_functional.all_to_all_single.default nodes, and compares the resolved process-group object identity against non-EP collective process groups. If an EP process group is object-identical to a non-EP collective group, the pass creates one high-priority EP-only process group over the same EP ranks and rewrites all traced EP all-to-all calls that used the shared source group.

The matching contract is intentionally process-group based because graph-level metadata does not carry DeviceMesh objects. Rank-set equality is only a validator: two distinct process groups may have the same ranks but already have separate NCCL communicators, so rank-set equality alone is not a sufficient reason to rewrite EP. The replacement communicator is always constructed from the EP process group itself, preserving the EP group rank order used by all-to-all split semantics.

FSDP all-gathers continue to use the existing extra FSDP group from overlap_fsdp_ag_rs_pass. Generic MoE dispatch/combine traceback metadata and EP_token_exchange scheduling metadata are not required for EP process-group isolation; the scheduling pass still owns token-exchange wait ordering.

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py::TestOverlapFsdpAgRsPass

stack-info: PR: #3369, branch: sanketpurandare/stack/18
… routing
Merged 7 upstream commits (19c567f..af33f76). Documents which ones needed ezpz replays:
- PR pytorch#3398 (Module subclass refactor): 3 import paths replayed in b052f29; pure import-path swap, class API unchanged.
- PR pytorch#3146 (deterministic MoE routing): inherits transitively; this is the upstream fix for the _histc_xpu non-determinism blocker we hit on 2026-05-21. --debug.deterministic on MoE+XPU should now work.
- PR pytorch#3423 (MoE [7/n] 3D tensors): inherits transitively; doesn't touch deepseek_v3 callsites.
- PR pytorch#3105 (FSDP symm_mem): skipped; ezpz has its own apply_fsdp, and symm_mem is an optional optimization that XPU CCL likely doesn't support.
- PRs pytorch#3331/pytorch#3369/pytorch#3361: graph_trainer-only no-ops.
Captures two action items: smoke-test before next production push, and re-try --debug.deterministic on MoE+XPU.
Stacked PRs:
[graph_trainer] Use separate EP process groups for overlap
When EP all-to-alls use the same process group object as FSDP or TP collectives, they share a NCCL communicator/stream and can serialize communication that the EP overlap schedule is trying to expose.
Add isolate_ep_process_group_pass as a dedicated EP pass instead of mixing this policy into FSDP bucketing. The pass scans traced collectives, identifies EP process-group users by _c10d_functional.all_to_all_single.default nodes, and compares the resolved process-group object identity against non-EP collective process groups. If an EP process group is object-identical to a non-EP collective group, the pass creates one high-priority EP-only process group over the same EP ranks and rewrites all traced EP all-to-all calls that used the shared source group.
The matching contract is intentionally process-group based because graph-level metadata does not carry DeviceMesh objects. Rank-set equality is only a validator: two distinct process groups may have the same ranks but already have separate NCCL communicators, so rank-set equality alone is not a sufficient reason to rewrite EP. The replacement communicator is always constructed from the EP process group itself, preserving the EP group rank order used by all-to-all split semantics.
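The distinction between object identity and rank-set equality can be sketched as follows. This is a toy model under stated assumptions: ProcessGroup, shares_communicator, make_ep_only_pg, and the "_ep_only" name suffix are all hypothetical, and each ProcessGroup instance simply models one communicator rather than wrapping a real one.

```python
class ProcessGroup:
    """Stand-in for a process group; each instance models one communicator."""

    def __init__(self, name: str, ranks: list[int]) -> None:
        self.name = name
        self.ranks = list(ranks)  # order matters for all-to-all split semantics


def shares_communicator(ep_pg: ProcessGroup, non_ep_pgs: list[ProcessGroup]) -> bool:
    # Trigger on object identity only: equal rank sets on distinct objects
    # already imply separate communicators, so no rewrite is needed.
    return any(ep_pg is other for other in non_ep_pgs)


def make_ep_only_pg(ep_pg: ProcessGroup) -> ProcessGroup:
    # Build the replacement from the EP group's own rank list, keeping its order.
    new_pg = ProcessGroup(f"{ep_pg.name}_ep_only", ep_pg.ranks)
    # Rank-set equality is used here purely as a sanity check.
    assert frozenset(new_pg.ranks) == frozenset(ep_pg.ranks)
    return new_pg


shared = ProcessGroup("shared", [0, 2, 4, 6])
twin = ProcessGroup("twin", [0, 2, 4, 6])  # same ranks, different object
```

In this model, shares_communicator(shared, [shared]) is true and shares_communicator(shared, [twin]) is false, even though both groups cover the same ranks.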
FSDP all-gathers continue to use the existing extra FSDP group from overlap_fsdp_ag_rs_pass. Generic MoE dispatch/combine traceback metadata and EP_token_exchange scheduling metadata are not required for EP process-group isolation; the scheduling pass still owns token-exchange wait ordering.
Test Plan: