
[graph_trainer] Add EP overlap scheduling pass#3328

Open
sanketpurandare wants to merge 1 commit into sanketpurandare/stack/12 from sanketpurandare/stack/14

Conversation

@sanketpurandare

@sanketpurandare sanketpurandare commented May 12, 2026

Contributor

Stacked PRs:


[graph_trainer] Add EP overlap scheduling pass

Add the scheduling layer for graph_trainer's public ep_overlap pass. Graph and eager chunking producers create correctness-preserving chunked graphs; this commit consumes their shared chunk metadata after FSDP bucketing and reorders EP token-exchange regions to expose communication/compute overlap without changing tensor values.

The scheduling logic lives in ep_overlap_pass.py and is intentionally separate from chunking. Its input contract is an already chunked graph with exactly two chunks per selected region, chunk body nodes tagged with chunk_id and chunked_region metadata, and true EP token-exchange markers on _c10d_functional.all_to_all_single.default nodes via custom[EP_token_exchange]. Generic MoE dispatch/combine traceback annotations are sanitized and are not scheduling markers.
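A minimal sketch of the input contract above, checked on toy node records rather than real FX nodes. The metadata keys (`chunk_id`, `chunked_region`, `custom[EP_token_exchange]`) and the all-to-all target come from the description; `validate_chunk_contract`, the dict-based node layout, and the marker value `"dispatch"` are assumptions for illustration and are not taken from ep_overlap_pass.py.

```python
from collections import defaultdict

A2A_TARGET = "_c10d_functional.all_to_all_single.default"


def validate_chunk_contract(nodes):
    """Check a toy node list against the scheduler's stated input contract.

    Each node is a dict with a "target" string and a "meta" dict (hypothetical
    stand-ins for FX node fields). Returns region -> token-exchange signature.
    """
    chunk_ids = defaultdict(set)
    signatures = defaultdict(list)  # (region, chunk_id) -> marker values in order
    for node in nodes:
        meta = node["meta"]
        region = meta.get("chunked_region")
        if region is None:
            continue  # node is outside every chunk body
        chunk_id = meta["chunk_id"]
        chunk_ids[region].add(chunk_id)
        marker = meta.get("custom", {}).get("EP_token_exchange")
        if marker is not None:
            # Only real all-to-all nodes may carry the scheduling marker.
            if node["target"] != A2A_TARGET:
                raise ValueError(f"EP_token_exchange on non-all-to-all node in {region}")
            signatures[(region, chunk_id)].append(marker)

    result = {}
    for region, ids in chunk_ids.items():
        # Exactly two chunks per selected region.
        if ids != {0, 1}:
            raise ValueError(f"{region}: expected chunk ids {{0, 1}}, got {sorted(ids)}")
        # Both chunks must expose the same sequence of token exchanges.
        if signatures[(region, 0)] != signatures[(region, 1)]:
            raise ValueError(f"{region}: chunks disagree on token-exchange signature")
        result[region] = signatures[(region, 0)]
    return result
```

Nodes without `chunked_region` are skipped, which mirrors the statement that generic dispatch/combine annotations are not scheduling markers.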

The scheduler validates that both chunks have the same token-exchange signature, annotates each token-exchange wait with EP_token_exchange_wait, and emits a wait-gated schedule. For each token-exchange pair it emits the launch dependency closure in chunk order, then emits ready non-wait filler work that can run while the previous token exchange is outstanding, and only then allows the required wait/tail work. Forward uses chunk order 0 then 1; backward uses chunk order 1 then 0. All nodes are emitted exactly once through a global topological emitter, and the pass does not duplicate, delete, split, concatenate, or otherwise compute tensor values.
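The emission order described above can be sketched on a toy dependency graph. This is an illustration under stated assumptions, not the pass itself: `Node`, `wait_gated_order`, the `kind` labels, and all node names are hypothetical, and the real scheduler works on FX nodes and metadata rather than on name strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    deps: tuple = ()
    kind: str = "compute"  # hypothetical labels: "launch", "wait", or "compute"


def wait_gated_order(nodes, launch_groups):
    """Reorder `nodes` (given in a valid topological order) by name.

    `launch_groups` lists, per token-exchange pair, the launch node names in
    the desired chunk order (0 then 1 for forward, 1 then 0 for backward).
    """
    by_name = {n.name: n for n in nodes}
    emitted, done = [], set()

    def emit(name):
        # Emit a node after its dependencies, and at most once.
        if name in done:
            return
        for dep in by_name[name].deps:
            emit(dep)
        done.add(name)
        emitted.append(name)

    def emit_ready_filler():
        # Emit compute nodes whose inputs are all available already.
        progress = True
        while progress:
            progress = False
            for n in nodes:
                if n.name in done or n.kind != "compute":
                    continue
                if all(d in done for d in n.deps):
                    emit(n.name)
                    progress = True

    for group in launch_groups:
        for launch in group:
            emit(launch)        # launch dependency closure, in chunk order
        emit_ready_filler()     # work that can overlap the outstanding exchange
    for n in nodes:
        emit(n.name)            # remaining waits and tail work
    return emitted


graph = [
    Node("pre0"), Node("a2a0", ("pre0",), "launch"),
    Node("wait0", ("a2a0",), "wait"), Node("expert0", ("wait0",)),
    Node("pre1"), Node("a2a1", ("pre1",), "launch"),
    Node("wait1", ("a2a1",), "wait"), Node("expert1", ("wait1",)),
    Node("side"), Node("out", ("expert0", "expert1")),
]
forward = wait_gated_order(graph, [["a2a0", "a2a1"]])
# forward == ["pre0", "a2a0", "pre1", "a2a1", "side",
#             "wait0", "expert0", "wait1", "expert1", "out"]
```

In the toy result both launches and the independent `side` work precede the first wait, and every node appears exactly once. Passing `[["a2a1", "a2a0"]]` gives the backward-style chunk order.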

Wire the public ep_overlap pipeline as graph/eager chunking, FSDP process-group isolation and bucketing, EP scheduling, chunk-symbol concretization, optional async TP fusion, and DTensor metadata cleanup. Add focused scheduler tests and H100 integration coverage for graph and eager chunking across SDPA and FlexAttention debug-model variants. The distributed bitwise numerics tests introduced with graph chunking cover this composed public ep_overlap path because scheduling is value-preserving.

This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:

Test Plan:

  • pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
  • pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_transformer_batch --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_batch --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_seq --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_transformer_batch --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_batch --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_seq --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8
  • PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8

@meta-cla meta-cla Bot added the CLA Signed label May 12, 2026
sanketpurandare added a commit that referenced this pull request May 12, 2026
Expose EP overlap as the user-facing graph_trainer pass for chunked MoE communication scheduling. The pass is opt-in through --compile.passes ep_overlap, defaults to batch splitting transformer blocks, and keeps the lower-level chunk transform as an internal implementation detail used by the overlap scheduler.

The pass first applies chunking to the configured module regions, then classifies copied forward/backward nodes by chunk id and EP dispatch/combine annotations. It splits each EP phase at the last all-to-all launch so wait-suffix nodes can be delayed, then adds stable topological ordering constraints that launch peer chunk communication before waiting on the current chunk. Forward regions schedule dispatch before combine with chunk order 0 then 1; backward autograd regions schedule combine-gradient communication before dispatch-gradient communication with chunk order 1 then 0. Non-MoE transformer blocks keep the chunk transform's topological order, and configurations with no EP all-to-all regions fail fast.

Update MoE traceback annotation setup so the actual AllToAllTokenDispatcher dispatch/combine bodies are annotated, not only the local dispatcher fallback. The annotation helper is idempotent and shared by DeepSeek and Qwen graph trainer parallelization.

Replace the public chunk_modules/chunk_mode compile knobs with ep_overlap_modules/ep_overlap_mode. The trace input preparation hook now only runs when ep_overlap is present in compile.passes, and it marks every pytree tensor that shares the selected dynamic batch/sequence extent with the main input so labels and positional tensors carry the same dynamic annotation when they participate in tracing.

Harden the chunk planning path for whole-block and MoE-only overlap regions. Symbolic shape scalar helpers may be shared across adjacent planned regions, while real tensor compute remains disjoint. Region live-ins are filtered after the copied closure reaches a fixed point, scalar live-out materialization handles guarded even split dimensions explicitly, and FX renaming is used for copied/materialized nodes to avoid duplicate generated names in large multi-region transforms.

Test Plan:

- pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or chunk' -q

- pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/qwen3/parallelize.py torchtitan/experiments/graph_trainer/tests/test_passes.py torchtitan/experiments/graph_trainer/trainer.py

- pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/tests/test_passes.py

- git diff --check

stack-info: PR: #3328, branch: sanketpurandare/stack/14
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/14 branch from 6945c23 to 83081fc May 12, 2026 08:41
@sanketpurandare

Contributor Author

tlparse artifacts from the EP overlap graph-pass validation runs:

Verification notes:

  • Transformer-block default mode applied chunk_batch to 12 layer regions and ep_overlap to 10 EP chunked regions. That matches the debug model shape: dense/non-MoE layers can be chunked but do not have EP all-to-all scheduling phases.
  • MoE-only mode applied chunk_batch to 10 MoE regions and ep_overlap to 10 EP chunked regions.
  • The MoE-only graph diff shows the expected EP duplication from chunking: all_to_all_single.default: 25 -> 50, histc.default: 10 -> 20, and wait_tensor.default: 291 -> 366.
  • Generated rank0 code shows the intended launch/wait ordering for a MoE layer: chunk 0 dispatch launch, chunk 1 dispatch launch, then dispatch waits/compute, followed by chunk 0 combine launch, chunk 1 combine launch, then combine waits. The delayed wait-suffix nodes are tagged with EP_wait=True after the relevant all-to-all boundaries.

These artifacts validate the graph transformation and scheduling structure. They are not performance numbers; the performance/profile runs need to be collected separately on exclusively available GPUs.

@sanketpurandare sanketpurandare marked this pull request as draft May 12, 2026 09:13
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/12 to main May 12, 2026 09:13
sanketpurandare added a commit that referenced this pull request May 12, 2026
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/14 branch from 83081fc to 6783bda May 12, 2026 09:13
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/12 May 12, 2026 09:13
@sanketpurandare sanketpurandare marked this pull request as ready for review May 12, 2026 09:13

@xmfan xmfan left a comment

Member

please include both a tlparse and a perfetto trace of before vs after

    chunk0_val: object, chunk1_val: object, original_val: object
) -> bool:
    try:
        half_original = _expr(original_val) // 2

Member

configurable chunk size?

Contributor Author

The chunking dataflow is generic over module granularity and the selected logical dimension, but this PR intentionally keeps the chunk count fixed at two because the EP overlap schedule is pairwise. I added a TODO in the pass contract calling out what N-way support would require: N-way provenance, materialization, scalar handling, eager references, and launch/wait scheduling.
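For readers following the thread, here is a rough sketch of what a fixed two-way even-split check could look like. It swaps in plain integers for the symbolic expressions the reviewed code handles through `_expr`, so it is not the PR's implementation; `is_even_two_way_split` and its exact conditions are assumptions.

```python
def is_even_two_way_split(
    chunk0_val: object, chunk1_val: object, original_val: object
) -> bool:
    """Return True if both chunk extents equal exactly half of the original.

    Assumption: values are plain ints (or int-convertible); the reviewed code
    operates on symbolic expressions instead.
    """
    try:
        original = int(original_val)
        half_original = original // 2
        return (
            original % 2 == 0
            and int(chunk0_val) == half_original
            and int(chunk1_val) == half_original
        )
    except (TypeError, ValueError):
        # Values that cannot be interpreted as integers are not a valid split.
        return False
```

An N-way version would need to compare N chunk extents against `original // N`, which is one of the pieces the TODO mentioned above would have to cover.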

Comment on lines +1881 to +1882
dispatch_nodes = [node for node in chunk_nodes if _ep_region(node) == "dispatch"]
combine_nodes = [node for node in chunk_nodes if _ep_region(node) == "combine"]

Member

1 loop

Contributor Author

Refactored the EP phase helper so it collects dispatch/combine nodes and last all-to-all launches in one pass over the chunk body, then does one boundary-aware pass to mark wait suffixes and partition phases. This removes the repeated comprehensions while keeping the validation checks local.

Comment on lines +1883 to +1884
if not dispatch_nodes and not combine_nodes:
    return None

Member

we should probably split up the dense vs moe paths, they can share a lot of logic, but it's a bit hard to reason about both at the same time

Contributor Author

The user-facing control for this is ep_overlap_modules: layers.* selects transformer-block roots, while layers.*.moe selects only MoE roots. I also made the EP-overlap chunk phase skip matched roots with no EP dispatch/combine metadata, so transformer-block mode no longer chunks dense/non-EP layers unnecessarily. Regions with MoE dispatch/combine metadata go through EP scheduling.
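To make the two selection patterns concrete, here is a small sketch of how `layers.*` and `layers.*.moe` could pick different roots. The matching rule is an assumption (each `*` matches exactly one dotted path segment), and `pattern_to_regex`, `select_roots`, and the module names are hypothetical; the PR does not show how `ep_overlap_modules` is actually resolved.

```python
import re


def pattern_to_regex(pattern: str) -> re.Pattern:
    # Assumption: "*" matches exactly one dotted path segment.
    parts = [
        r"[^.]+" if part == "*" else re.escape(part)
        for part in pattern.split(".")
    ]
    return re.compile(r"\.".join(parts))


def select_roots(module_names, pattern):
    """Return the module names that match the whole pattern."""
    regex = pattern_to_regex(pattern)
    return [name for name in module_names if regex.fullmatch(name)]


names = [
    "layers.0", "layers.0.attention", "layers.0.moe",
    "layers.1", "layers.1.moe", "norm",
]
block_roots = select_roots(names, "layers.*")     # ["layers.0", "layers.1"]
moe_roots = select_roots(names, "layers.*.moe")   # ["layers.0.moe", "layers.1.moe"]
```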

Comment on lines +1917 to +1919
[node for node in chunk_nodes if order[node] <= first_idx],
[node for node in chunk_nodes if first_idx < order[node] <= second_idx],
[node for node in chunk_nodes if order[node] > second_idx],

Member

please piggy back off the same loops everywhere

Contributor Author

Folded the repeated phase scans into the same collection/partition flow used by the EP scheduler. The helper still needs two conceptual passes because the wait boundaries are known only after seeing the last all-to-all in dispatch/combine, but it no longer rebuilds the same filtered lists multiple times.
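A minimal sketch of the two conceptual passes described in this reply, using toy string nodes. The names `chunk_nodes`, `order`, `first_idx`, and `second_idx` follow the reviewed snippet; `collect_boundaries`, `partition_by_boundaries`, and the `region_of` / `is_all_to_all` callbacks are hypothetical helpers, not the refactored code.

```python
def collect_boundaries(chunk_nodes, order, region_of, is_all_to_all):
    """Pass 1: find the index of the last all-to-all launch in each EP phase."""
    last = {"dispatch": None, "combine": None}
    for node in chunk_nodes:
        region = region_of(node)
        if region in last and is_all_to_all(node):
            idx = order[node]
            if last[region] is None or idx > last[region]:
                last[region] = idx
    return last["dispatch"], last["combine"]


def partition_by_boundaries(chunk_nodes, order, first_idx, second_idx):
    """Pass 2: split nodes into three phases with a single scan."""
    head, middle, tail = [], [], []
    for node in chunk_nodes:
        idx = order[node]
        if idx <= first_idx:
            head.append(node)
        elif idx <= second_idx:
            middle.append(node)
        else:
            tail.append(node)
    return head, middle, tail


chunk_nodes = ["route", "a2a_d", "wait_d", "experts", "a2a_c", "wait_c", "sum"]
order = {name: i for i, name in enumerate(chunk_nodes)}
regions = {"route": "dispatch", "a2a_d": "dispatch", "wait_d": "dispatch",
           "a2a_c": "combine", "wait_c": "combine"}

first_idx, second_idx = collect_boundaries(
    chunk_nodes, order, regions.get, lambda n: n.startswith("a2a")
)
phases = partition_by_boundaries(chunk_nodes, order, first_idx, second_idx)
# phases == (["route", "a2a_d"], ["wait_d", "experts", "a2a_c"], ["wait_c", "sum"])
```

The single `if/elif/else` scan produces the same three lists as the three comprehensions in the reviewed lines, without filtering `chunk_nodes` three times.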

@sanketpurandare sanketpurandare marked this pull request as draft May 13, 2026 14:16
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/12 to main May 13, 2026 14:16
sanketpurandare added a commit that referenced this pull request May 13, 2026
Add the scheduling layer for graph_trainer's public ep_overlap pass. The previous commit creates the correctness-preserving chunked graph; this commit consumes that chunked graph after FSDP bucketing and reorders EP dispatch/combine phases so paired chunk all-to-all launches can run before their delayed wait suffixes.

The scheduling logic lives in ep_overlap_pass.py and is intentionally separate from chunking. Its input contract is an already chunked graph with exactly two chunks per selected region, chunk body nodes tagged with chunk_id and chunked_region metadata, and MoE dispatcher regions annotated with custom EP metadata for dispatch and combine. The scheduler validates canonical chunk order, extracts dispatch/combine phases, annotates delayed wait suffixes with EP_wait=True, and applies a priority topological sort. It does not duplicate nodes, delete nodes, split tensors, concatenate tensors, or otherwise change tensor values.

Wire the public ep_overlap pipeline as two stages: chunking runs after CPU offload and SAC/rematerialization, while scheduling runs after downstream FSDP bucketing/reordering. This keeps chunk planning aligned with the final recompute graph and lets bucketing place shared full waits before the overlap scheduler establishes the final EP launch/wait order.

The distributed DSV3 bitwise numerics tests introduced in the previous commit now cover this composed public ep_overlap path as well. Since this scheduling pass is value-preserving, those tests validate that adding the scheduler on top of graph chunking does not change the eager-chunk-equivalent loss/gradient behavior. Existing tlparse validation artifacts show the intended rank-code order: dispatch chunk0 launch, dispatch chunk1 launch, delayed dispatch waits, combine chunk0 launch, combine chunk1 launch, and delayed combine waits annotated with EP_wait=True.

This pass stack relies on pending PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545

Test Plan:
- pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or chunk' -q
  - 8 passed, 73 deselected, 2 subtests passed
- pytest torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap -q
  - 3 passed, 20 deselected

stack-info: PR: #3328, branch: sanketpurandare/stack/14
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/14 branch from 6783bda to 632a794 May 13, 2026 14:16
@sanketpurandare sanketpurandare changed the title from "[graph_trainer] Add EP overlap pass" to "[graph_trainer] Add EP overlap scheduling pass" May 13, 2026
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/12 May 13, 2026 14:16
@sanketpurandare sanketpurandare marked this pull request as ready for review May 13, 2026 14:16
@sanketpurandare sanketpurandare marked this pull request as draft May 13, 2026 23:25
@sanketpurandare sanketpurandare marked this pull request as ready for review May 15, 2026 09:37
@sanketpurandare sanketpurandare marked this pull request as draft May 22, 2026 19:03
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/12 to main May 22, 2026 19:03
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/14 branch from 6c26f2c to 7157513 May 22, 2026 19:03
sanketpurandare added a commit that referenced this pull request May 22, 2026
Add the scheduling layer for graph_trainer's public ep_overlap pass. Graph and eager chunking producers create correctness-preserving chunked graphs; this commit consumes their shared chunk metadata after FSDP bucketing and reorders EP token-exchange regions to expose communication/compute overlap without changing tensor values.

The scheduling logic lives in ep_overlap_pass.py and is intentionally separate from chunking. Its input contract is an already chunked graph with exactly two chunks per selected region, chunk body nodes tagged with chunk_id and chunked_region metadata, and true EP token-exchange markers on _c10d_functional.all_to_all_single.default nodes via custom[EP_token_exchange]. Generic MoE dispatch/combine traceback annotations are sanitized and are not scheduling markers.

The scheduler validates that both chunks have the same token-exchange signature, annotates each token-exchange wait with EP_token_exchange_wait, and emits a wait-gated schedule. For each token-exchange pair it emits the launch dependency closure in chunk order, then emits ready non-wait filler work that can run while the previous token exchange is outstanding, and only then allows the required wait/tail work. Forward uses chunk order 0 then 1; backward uses chunk order 1 then 0. All nodes are emitted exactly once through a global topological emitter, and the pass does not duplicate, delete, split, concatenate, or otherwise compute tensor values.

Wire the public ep_overlap pipeline as graph/eager chunking, FSDP process-group isolation and bucketing, EP scheduling, chunk-symbol concretization, optional async TP fusion, and DTensor metadata cleanup. Add focused scheduler tests and H100 integration coverage for graph and eager chunking across SDPA and FlexAttention debug-model variants. The distributed bitwise numerics tests introduced with graph chunking cover this composed public ep_overlap path because scheduling is value-preserving.

This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545
- HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837
- FlexAttention chunked unbacked input extents: pytorch/pytorch#183838
- FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839
- Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840
- Inductor TP slice/cat collective fusion: pytorch/pytorch#184833

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
- pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_seq --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_seq --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8

stack-info: PR: #3328, branch: sanketpurandare/stack/14
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/12 May 22, 2026 19:03
@sanketpurandare sanketpurandare marked this pull request as ready for review May 22, 2026 19:03
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:29
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/12 to main May 26, 2026 05:29
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/14 branch from 7157513 to 42d4319 May 26, 2026 05:29
sanketpurandare added a commit that referenced this pull request May 26, 2026
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/12 May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:31
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/12 to main May 26, 2026 05:31
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/12 May 26, 2026 05:31
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:31
@sanketpurandare sanketpurandare marked this pull request as draft May 27, 2026 03:51
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/12 to main May 27, 2026 03:51
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/14 branch from 42d4319 to 414f3aa May 27, 2026 03:51
sanketpurandare added a commit that referenced this pull request May 27, 2026
Add the scheduling layer for graph_trainer's public ep_overlap pass. Graph and eager chunking producers create correctness-preserving chunked graphs; this commit consumes their shared chunk metadata after FSDP bucketing and reorders EP token-exchange regions to expose communication/compute overlap without changing tensor values.

The scheduling logic lives in ep_overlap_pass.py and is intentionally separate from chunking. Its input contract is an already chunked graph with exactly two chunks per selected region, chunk body nodes tagged with chunk_id and chunked_region metadata, and true EP token-exchange markers on _c10d_functional.all_to_all_single.default nodes via custom[EP_token_exchange]. Generic MoE dispatch/combine traceback annotations are sanitized and are not scheduling markers.
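
As a rough illustration of that input contract, the check below runs over plain dictionaries rather than real FX nodes. The metadata keys (`chunk_id`, `chunked_region`, `custom["EP_token_exchange"]`) and the op name come from the description above. The dictionary layout, the helper names `check_chunk_contract` and `node`, and the marker value `"dispatch"` are assumptions made for this sketch, not code from `ep_overlap_pass.py`. It also compares the per-chunk exchange signature, which the scheduler is described as validating.

```python
from collections import defaultdict

# Op name taken from the PR text; everything else is a hypothetical stand-in.
A2A_TARGET = "_c10d_functional.all_to_all_single.default"


def check_chunk_contract(nodes):
    """Validate dict 'nodes' shaped like {'name', 'target', 'meta'}.

    Returns {region: token-exchange signature} when the contract holds.
    """
    chunk_ids = defaultdict(set)
    signatures = defaultdict(list)  # (region, chunk_id) -> ordered marker values
    for n in nodes:
        meta = n["meta"]
        marker = meta.get("custom", {}).get("EP_token_exchange")
        if marker is not None and n["target"] != A2A_TARGET:
            raise ValueError(f"{n['name']}: marker on a non-all-to-all node")
        region = meta.get("chunked_region")
        if region is None:
            continue  # node lies outside every chunked region
        if "chunk_id" not in meta:
            raise ValueError(f"{n['name']}: chunked_region without chunk_id")
        chunk_ids[region].add(meta["chunk_id"])
        if marker is not None:
            signatures[(region, meta["chunk_id"])].append(marker)
    result = {}
    for region, ids in chunk_ids.items():
        if ids != {0, 1}:
            raise ValueError(f"{region}: expected chunks {{0, 1}}, got {sorted(ids)}")
        if signatures[(region, 0)] != signatures[(region, 1)]:
            raise ValueError(f"{region}: chunks disagree on token exchanges")
        result[region] = signatures[(region, 0)]
    return result


def node(name, target, region, chunk_id, marker=None):
    """Build one toy node; the marker value is purely illustrative."""
    meta = {"chunked_region": region, "chunk_id": chunk_id}
    if marker is not None:
        meta["custom"] = {"EP_token_exchange": marker}
    return {"name": name, "target": target, "meta": meta}


graph = [
    node("a2a_c0", A2A_TARGET, "moe0", 0, "dispatch"),
    node("mm_c0", "aten.mm.default", "moe0", 0),
    node("a2a_c1", A2A_TARGET, "moe0", 1, "dispatch"),
    node("mm_c1", "aten.mm.default", "moe0", 1),
]
contract = check_chunk_contract(graph)
```

Dropping either chunk, or tagging a non-all-to-all node with the marker, makes the check raise `ValueError` instead of returning.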

The scheduler validates that both chunks have the same token-exchange signature, annotates each token-exchange wait with EP_token_exchange_wait, and emits a wait-gated schedule. For each token-exchange pair it emits the launch dependency closure in chunk order, then emits ready non-wait filler work that can run while the previous token exchange is outstanding, and only then allows the required wait/tail work. Forward uses chunk order 0 then 1; backward uses chunk order 1 then 0. All nodes are emitted exactly once through a global topological emitter, and the pass does not duplicate, delete, split, concatenate, or otherwise compute tensor values.
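
The wait-gated emission described above can be sketched on a toy dependency graph. This is a simplified reading of the paragraph, not the pass itself: the `Node` class, the `kind` labels, the function `wait_gated_schedule`, and the node names are all hypothetical, and the real pass works on FX graph nodes and handles the backward chunk order as well.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    deps: list = field(default_factory=list)
    kind: str = "compute"  # "compute", "launch", or "wait" (hypothetical labels)


def wait_gated_schedule(nodes, exchanges):
    """Reorder `nodes` so compute can run while a token exchange is in flight.

    nodes: list of Node in a valid topological order.
    exchanges: (launch_name, wait_name) pairs in the desired chunk order.
    """
    by_name = {n.name: n for n in nodes}
    emitted, order = set(), []

    def emit(name):
        # Emit a node after its dependency closure, each node exactly once.
        if name in emitted:
            return
        for dep in by_name[name].deps:
            emit(dep)
        emitted.add(name)
        order.append(name)

    def emit_ready_filler():
        # Emit compute nodes whose inputs are all available; waits stay gated.
        progress = True
        while progress:
            progress = False
            for n in nodes:
                if n.name in emitted or n.kind != "compute":
                    continue
                if all(d in emitted for d in n.deps):
                    emit(n.name)
                    progress = True

    for launch, wait in exchanges:
        emit(launch)         # launch plus everything it depends on
        emit_ready_filler()  # work that overlaps the outstanding exchange
        emit(wait)           # only now block on the exchange
    for n in nodes:
        emit(n.name)         # tail work, still in topological order
    return order


toy = [
    Node("route0"),
    Node("a2a0", ["route0"], "launch"),
    Node("wait0", ["a2a0"], "wait"),
    Node("expert0", ["wait0"]),
    Node("route1"),
    Node("a2a1", ["route1"], "launch"),
    Node("wait1", ["a2a1"], "wait"),
    Node("expert1", ["wait1"]),
]
order = wait_gated_schedule(toy, [("a2a0", "wait0"), ("a2a1", "wait1")])
# order: route0, a2a0, route1, wait0, a2a1, expert0, wait1, expert1
```

In this toy ordering, `route1` runs while chunk 0's exchange is outstanding and `expert0` runs while chunk 1's exchange is outstanding. No node is added, dropped, or duplicated, which matches the value-preserving claim.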

Wire the public ep_overlap pipeline as graph/eager chunking, FSDP process-group isolation and bucketing, EP scheduling, chunk-symbol concretization, optional async TP fusion, and DTensor metadata cleanup. Add focused scheduler tests and H100 integration coverage for graph and eager chunking across SDPA and FlexAttention debug-model variants. The distributed bitwise numerics tests introduced with graph chunking cover this composed public ep_overlap path because scheduling is value-preserving.

This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545
- HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837
- FlexAttention chunked unbacked input extents: pytorch/pytorch#183838
- FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839
- Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840
- Inductor TP slice/cat collective fusion: pytorch/pytorch#184833

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
- pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_moe_seq --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_moe_seq --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_sdpa_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8
- PYTHONPATH=$PWD python torchtitan/experiments/graph_trainer/tests/integration_tests.py <out_dir> --test_suite graph_trainer_h100 --test_name aot_fx_trace_deepseek_v3_flexattn_full_inductor_ep_overlap_eager_transformer_batch --ngpu 8

stack-info: PR: #3328, branch: sanketpurandare/stack/14
@sanketpurandare changed the base branch from main to sanketpurandare/stack/12 May 27, 2026 03:52
@sanketpurandare marked this pull request as ready for review May 27, 2026 03:52
@sanketpurandare force-pushed the sanketpurandare/stack/12 branch from 629d28c to a0461e9 Compare May 27, 2026 04:41
Labels: ciflow/8gpu, CLA Signed
