[graph_trainer] Add EP overlap scheduling pass #3328
Expose EP overlap as the user-facing graph_trainer pass for chunked MoE communication scheduling. The pass is opt-in through --compile.passes ep_overlap, defaults to batch splitting transformer blocks, and keeps the lower-level chunk transform as an internal implementation detail used by the overlap scheduler.

The pass first applies chunking to the configured module regions, then classifies copied forward/backward nodes by chunk id and EP dispatch/combine annotations. It splits each EP phase at the last all-to-all launch so wait-suffix nodes can be delayed, then adds stable topological ordering constraints that launch peer chunk communication before waiting on the current chunk. Forward regions schedule dispatch before combine with chunk order 0 then 1; backward autograd regions schedule combine-gradient communication before dispatch-gradient communication with chunk order 1 then 0. Non-MoE transformer blocks keep the chunk transform's topological order, and configurations with no EP all-to-all regions fail fast.

Update MoE traceback annotation setup so the actual AllToAllTokenDispatcher dispatch/combine bodies are annotated, not only the local dispatcher fallback. The annotation helper is idempotent and shared by DeepSeek and Qwen graph trainer parallelization.

Replace the public chunk_modules/chunk_mode compile knobs with ep_overlap_modules/ep_overlap_mode. The trace input preparation hook now only runs when ep_overlap is present in compile.passes, and it marks every pytree tensor that shares the selected dynamic batch/sequence extent with the main input so labels and positional tensors carry the same dynamic annotation when they participate in tracing.

Harden the chunk planning path for whole-block and MoE-only overlap regions. Symbolic shape scalar helpers may be shared across adjacent planned regions, while real tensor compute remains disjoint. Region live-ins are filtered after the copied closure reaches a fixed point, scalar live-out materialization handles guarded even split dimensions explicitly, and FX renaming is used for copied/materialized nodes to avoid duplicate generated names in large multi-region transforms.

Test Plan:
- pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or chunk' -q
- pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/qwen3/parallelize.py torchtitan/experiments/graph_trainer/tests/test_passes.py torchtitan/experiments/graph_trainer/trainer.py
- pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/tests/test_passes.py
- git diff --check

stack-info: PR: #3328, branch: sanketpurandare/stack/14
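As a rough illustration of the "split each EP phase at the last all-to-all launch" step described above, here is a minimal sketch on plain lists of node names. The function name, the predicate, and the node names are hypothetical stand-ins; the real pass operates on FX graph nodes and its code is not shown in this thread.

```python
from typing import Callable, Sequence


def split_at_last_launch(
    phase_nodes: Sequence[str],
    is_all_to_all: Callable[[str], bool],
) -> tuple[list[str], list[str]]:
    """Split one EP phase into (launch_prefix, wait_suffix).

    Everything up to and including the last all-to-all launch stays in the
    prefix; everything after it is the suffix that a scheduler may delay.
    """
    last = -1
    for index, node in enumerate(phase_nodes):
        if is_all_to_all(node):
            last = index
    if last < 0:
        raise ValueError("phase contains no all-to-all launch")
    return list(phase_nodes[: last + 1]), list(phase_nodes[last + 1 :])


# Hypothetical node names for one chunk's dispatch phase.
dispatch_phase = ["permute", "a2a_counts", "a2a_tokens", "wait_tokens", "unpermute"]
prefix, suffix = split_at_last_launch(dispatch_phase, lambda n: n.startswith("a2a_"))
```

With the suffix isolated, an ordering constraint can place the peer chunk's prefix ahead of it, which is how the launch-before-wait order is obtained in this simplified picture.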
6945c23 to 83081fc
tlparse artifacts from the EP overlap graph-pass validation runs:
Verification notes:
These artifacts validate the graph transformation and scheduling structure. They are not performance numbers; the performance/profile runs need to be collected separately on exclusively available GPUs.
83081fc to 6783bda
    chunk0_val: object, chunk1_val: object, original_val: object
) -> bool:
    try:
        half_original = _expr(original_val) // 2
The chunking dataflow is generic over module granularity and the selected logical dimension, but this PR intentionally keeps the chunk count fixed at two because the EP overlap schedule is pairwise. I added a TODO in the pass contract calling out what N-way support would require: N-way provenance, materialization, scalar handling, eager references, and launch/wait scheduling.
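For readers following the snippet above, the two-way contract amounts to checking that both chunks are exactly half of an even extent. A minimal sketch with plain integers standing in for the symbolic expressions; the function name is hypothetical and the real helper's body is not shown in this thread.

```python
def is_even_two_way_split(chunk0: int, chunk1: int, original: int) -> bool:
    """Return True when both chunks are exactly half of an even extent."""
    if original % 2 != 0:
        # An odd extent cannot be split into two equal chunks.
        return False
    half = original // 2
    return chunk0 == half and chunk1 == half
```

An N-way version would need to compare N chunk extents against `original // N`, which is one of the pieces the TODO mentions.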
    dispatch_nodes = [node for node in chunk_nodes if _ep_region(node) == "dispatch"]
    combine_nodes = [node for node in chunk_nodes if _ep_region(node) == "combine"]
Refactored the EP phase helper so it collects dispatch/combine nodes and last all-to-all launches in one pass over the chunk body, then does one boundary-aware pass to mark wait suffixes and partition phases. This removes the repeated comprehensions while keeping the validation checks local.
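A sketch of what a single collection pass of this kind could look like, assuming nodes are visited in graph order. `PhaseScan`, `scan_chunk`, and the tuple-based test nodes are illustrative names only, not the helper that landed in the PR.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Sequence


@dataclass
class PhaseScan:
    dispatch: list[Any] = field(default_factory=list)
    combine: list[Any] = field(default_factory=list)
    # Region name -> index (within the chunk body) of its last all-to-all.
    last_launch: dict[str, int] = field(default_factory=dict)


def scan_chunk(
    chunk_nodes: Sequence[Any],
    region_of: Callable[[Any], Optional[str]],
    is_all_to_all: Callable[[Any], bool],
) -> PhaseScan:
    """Collect dispatch/combine nodes and last launches in one loop."""
    scan = PhaseScan()
    for index, node in enumerate(chunk_nodes):
        region = region_of(node)
        if region == "dispatch":
            scan.dispatch.append(node)
        elif region == "combine":
            scan.combine.append(node)
        else:
            continue  # Node is outside any EP region.
        if is_all_to_all(node):
            # Later launches overwrite earlier ones, so the last one wins.
            scan.last_launch[region] = index
    return scan


# Hypothetical nodes: (name, EP region or None, is all-to-all launch).
nodes = [
    ("route", None, False),
    ("d_a2a", "dispatch", True),
    ("d_wait", "dispatch", False),
    ("experts", None, False),
    ("c_a2a", "combine", True),
    ("c_wait", "combine", False),
]
scan = scan_chunk(nodes, lambda n: n[1], lambda n: n[2])
```

The wait boundaries are only known once the loop finishes, which is why a second, boundary-aware pass is still needed to mark wait suffixes.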
    if not dispatch_nodes and not combine_nodes:
        return None
We should probably split up the dense vs. MoE paths. They can share a lot of logic, but it's a bit hard to reason about both at the same time.
The user-facing control for this is ep_overlap_modules: layers.* selects transformer-block roots, while layers.*.moe selects only MoE roots. I also made the EP-overlap chunk phase skip matched roots with no EP dispatch/combine metadata, so transformer-block mode no longer chunks dense/non-EP layers unnecessarily. Regions with MoE dispatch/combine metadata go through EP scheduling.
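To make the difference between the two patterns concrete, here is a small sketch that matches patterns against module names one dotted segment at a time. This matching rule is an assumption made for illustration; the thread does not show how ep_overlap_modules patterns are actually resolved, and `matches_root` and the module names are hypothetical.

```python
from fnmatch import fnmatchcase


def matches_root(pattern: str, module_fqn: str) -> bool:
    """Match a dotted pattern segment by segment (assumed semantics)."""
    pattern_parts = pattern.split(".")
    name_parts = module_fqn.split(".")
    if len(pattern_parts) != len(name_parts):
        return False
    return all(fnmatchcase(name, pat) for pat, name in zip(pattern_parts, name_parts))


# Hypothetical module names from a small MoE model.
modules = [
    "layers.0",
    "layers.0.attention",
    "layers.0.moe",
    "layers.1",
    "layers.1.moe",
    "norm",
]
block_roots = [m for m in modules if matches_root("layers.*", m)]
moe_roots = [m for m in modules if matches_root("layers.*.moe", m)]
```

Under this assumption, `layers.*` yields whole transformer blocks and `layers.*.moe` yields only the MoE submodules, mirroring the two modes described above.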
        [node for node in chunk_nodes if order[node] <= first_idx],
        [node for node in chunk_nodes if first_idx < order[node] <= second_idx],
        [node for node in chunk_nodes if order[node] > second_idx],
Please piggyback off the same loops everywhere.
Folded the repeated phase scans into the same collection/partition flow used by the EP scheduler. The helper still needs two conceptual passes because the wait boundaries are known only after seeing the last all-to-all in dispatch/combine, but it no longer rebuilds the same filtered lists multiple times.
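For reference, the three comprehensions from the snippet above can be folded into one loop along these lines. This is only a sketch: it assumes `first_idx <= second_idx`, and the function name is hypothetical rather than taken from the PR.

```python
def partition_by_boundaries(chunk_nodes, order, first_idx, second_idx):
    """Split nodes into (head, middle, tail) around two boundary positions.

    Assumes first_idx <= second_idx, so the three ranges do not overlap.
    """
    head, middle, tail = [], [], []
    for node in chunk_nodes:
        position = order[node]
        if position <= first_idx:
            head.append(node)
        elif position <= second_idx:
            middle.append(node)
        else:
            tail.append(node)
    return head, middle, tail


# Hypothetical chunk body with its topological positions.
names = list("abcdef")
order = {name: index for index, name in enumerate(names)}
head, middle, tail = partition_by_boundaries(names, order, first_idx=1, second_idx=3)
```

Each node is inspected once, and the relative order inside each part is the same as in `chunk_nodes`.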
Add the scheduling layer for graph_trainer's public ep_overlap pass. The previous commit creates the correctness-preserving chunked graph; this commit consumes that chunked graph after FSDP bucketing and reorders EP dispatch/combine phases so paired chunk all-to-all launches can run before their delayed wait suffixes.

The scheduling logic lives in ep_overlap_pass.py and is intentionally separate from chunking. Its input contract is an already chunked graph with exactly two chunks per selected region, chunk body nodes tagged with chunk_id and chunked_region metadata, and MoE dispatcher regions annotated with custom EP metadata for dispatch and combine. The scheduler validates canonical chunk order, extracts dispatch/combine phases, annotates delayed wait suffixes with EP_wait=True, and applies a priority topological sort. It does not duplicate nodes, delete nodes, split tensors, concatenate tensors, or otherwise change tensor values.

Wire the public ep_overlap pipeline as two stages: chunking runs after CPU offload and SAC/rematerialization, while scheduling runs after downstream FSDP bucketing/reordering. This keeps chunk planning aligned with the final recompute graph and lets bucketing place shared full waits before the overlap scheduler establishes the final EP launch/wait order.

The distributed DSV3 bitwise numerics tests introduced in the previous commit now cover this composed public ep_overlap path as well. Since this scheduling pass is value-preserving, those tests validate that adding the scheduler on top of graph chunking does not change the eager-chunk-equivalent loss/gradient behavior. Existing tlparse validation artifacts show the intended rank-code order: dispatch chunk0 launch, dispatch chunk1 launch, delayed dispatch waits, combine chunk0 launch, combine chunk1 launch, and delayed combine waits annotated with EP_wait=True.

This pass stack relies on pending PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545

Test Plan:
- pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or chunk' -q
  - 8 passed, 73 deselected, 2 subtests passed
- pytest torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap -q
  - 3 passed, 20 deselected

stack-info: PR: #3328, branch: sanketpurandare/stack/14
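The "priority topological sort" mentioned in this commit message can be pictured as Kahn's algorithm with a heap of ready nodes keyed by (priority, original position). The sketch below is an assumption about the general technique, not the code in ep_overlap_pass.py; the node names and the priority rule are hypothetical.

```python
import heapq


def priority_topological_sort(nodes, deps, priority):
    """Kahn's algorithm; ready nodes are popped by (priority, original index)."""
    position = {node: index for index, node in enumerate(nodes)}
    remaining = {node: len(deps.get(node, ())) for node in nodes}
    users = {node: [] for node in nodes}
    for node in nodes:
        for dep in deps.get(node, ()):
            users[dep].append(node)

    ready = [(priority(n), position[n], n) for n in nodes if remaining[n] == 0]
    heapq.heapify(ready)
    result = []
    while ready:
        _, _, node = heapq.heappop(ready)
        result.append(node)
        for user in users[node]:
            remaining[user] -= 1
            if remaining[user] == 0:
                heapq.heappush(ready, (priority(user), position[user], user))
    if len(result) != len(nodes):
        raise ValueError("dependency cycle detected")
    return result


# Hypothetical dispatch phase for two chunks, listed in original graph order.
nodes = ["d0_launch", "d0_wait", "d0_tail", "d1_launch", "d1_wait", "d1_tail"]
deps = {
    "d0_wait": ["d0_launch"],
    "d0_tail": ["d0_wait"],
    "d1_wait": ["d1_launch"],
    "d1_tail": ["d1_wait"],
}
# Launches get the highest priority (lowest number); everything else is delayed.
schedule = priority_topological_sort(
    nodes, deps, lambda n: 0 if n.endswith("_launch") else 1
)
```

Breaking ties by original position keeps the sort stable, so nodes that the priority does not distinguish stay in their original relative order. Every node appears exactly once in the result, consistent with a value-preserving reorder.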
6783bda to 632a794
6c26f2c to 7157513
Add the scheduling layer for graph_trainer's public ep_overlap pass. Graph and eager chunking producers create correctness-preserving chunked graphs; this commit consumes their shared chunk metadata after FSDP bucketing and reorders EP token-exchange regions to expose communication/compute overlap without changing tensor values.

The scheduling logic lives in ep_overlap_pass.py and is intentionally separate from chunking. Its input contract is an already chunked graph with exactly two chunks per selected region, chunk body nodes tagged with chunk_id and chunked_region metadata, and true EP token-exchange markers on _c10d_functional.all_to_all_single.default nodes via custom[EP_token_exchange]. Generic MoE dispatch/combine traceback annotations are sanitized and are not scheduling markers.

The scheduler validates that both chunks have the same token-exchange signature, annotates each token-exchange wait with EP_token_exchange_wait, and emits a wait-gated schedule. For each token-exchange pair it emits the launch dependency closure in chunk order, then emits ready non-wait filler work that can run while the previous token exchange is outstanding, and only then allows the required wait/tail work. Forward uses chunk order 0 then 1; backward uses chunk order 1 then 0. All nodes are emitted exactly once through a global topological emitter, and the pass does not duplicate, delete, split, concatenate, or otherwise compute tensor values.

Wire the public ep_overlap pipeline as graph/eager chunking, FSDP process-group isolation and bucketing, EP scheduling, chunk-symbol concretization, optional async TP fusion, and DTensor metadata cleanup. Add focused scheduler tests and H100 integration coverage for graph and eager chunking across SDPA and FlexAttention debug-model variants. The distributed bitwise numerics tests introduced with graph chunking cover this composed public ep_overlap path because scheduling is value-preserving.

This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545
- HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837
- FlexAttention chunked unbacked input extents: pytorch/pytorch#183838
- FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839
- Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840
- Inductor TP slice/cat collective fusion: pytorch/pytorch#184833

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
- pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_seq --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_seq --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8

stack-info: PR: #3328, branch: sanketpurandare/stack/14
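The wait-gated schedule described in this commit message can be sketched for a single token-exchange pair as three steps: launch closures in chunk order, then ready non-wait filler, then the remaining wait/tail work. The code below is a simplified illustration under stated assumptions: `nodes` is already in topological order, launch closures contain no waits, and all function and node names are hypothetical rather than taken from ep_overlap_pass.py.

```python
def emit_closure(node, deps, emitted, schedule):
    """Emit `node` after any of its dependencies not yet emitted."""
    if node in emitted:
        return
    for dep in deps.get(node, ()):
        emit_closure(dep, deps, emitted, schedule)
    emitted.add(node)
    schedule.append(node)


def wait_gated_schedule(nodes, deps, launches, waits, chunk_order):
    """Schedule one token-exchange pair.

    nodes: all nodes in their original (topological) order.
    launches: chunk id -> all-to-all launch node for that chunk.
    waits: set of nodes that wait on a token exchange.
    """
    emitted, schedule = set(), []
    # 1. Launch dependency closures, in the requested chunk order.
    for chunk in chunk_order:
        emit_closure(launches[chunk], deps, emitted, schedule)
    # 2. Filler: non-wait nodes whose dependencies are all emitted already.
    #    One pass suffices because `nodes` is in topological order.
    for node in nodes:
        if node in emitted or node in waits:
            continue
        if all(dep in emitted for dep in deps.get(node, ())):
            emit_closure(node, deps, emitted, schedule)
    # 3. Only now release the waits and the work that depends on them.
    for node in nodes:
        emit_closure(node, deps, emitted, schedule)
    return schedule


# Hypothetical two-chunk region with one independent filler node.
nodes = ["x0", "a2a0", "wait0", "mlp0", "x1", "a2a1", "wait1", "mlp1", "shared_expert"]
deps = {
    "a2a0": ["x0"], "wait0": ["a2a0"], "mlp0": ["wait0"],
    "a2a1": ["x1"], "wait1": ["a2a1"], "mlp1": ["wait1"],
    "shared_expert": ["x0"],
}
launches = {0: "a2a0", 1: "a2a1"}
waits = {"wait0", "wait1"}

forward = wait_gated_schedule(nodes, deps, launches, waits, chunk_order=(0, 1))
backward = wait_gated_schedule(nodes, deps, launches, waits, chunk_order=(1, 0))
```

In this sketch only the launch step honors `chunk_order`; the tail is released in original order, which is a simplification. The `emitted` set is what guarantees each node is emitted exactly once.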
7157513 to 42d4319
42d4319 to 414f3aa
629d28c to a0461e9