
[graph_trainer] Add graph EP chunking pass #3325

Open
sanketpurandare wants to merge 1 commit into sanketpurandare/stack/17 from sanketpurandare/stack/12

Conversation

@sanketpurandare

@sanketpurandare sanketpurandare commented May 12, 2026

Contributor

Stacked PRs:


[graph_trainer] Add graph EP chunking pass

Add the graph-level chunking transform used by graph_trainer's EP overlap pipeline. The public opt-in remains --compile.passes ep_overlap, but this commit intentionally only introduces the correctness-preserving chunked dataflow. It does not implement EP communication-overlap scheduling.

The chunking pass lives in ep_chunk_pass.py with shared metadata and symbolic-shape helpers in ep_pass_utils.py. It uses the unbacked SymInt introduced by the trace-input dynamic-shape preparation as the source of truth for chunk influence. The pass plans selected forward/backward module regions, splits live-ins whose fake shapes carry the selected symbol, rewrites one chunk body in place, copies the peer chunk body, and materializes escaping full values. Tensor live-outs carrying the selected dimension are concatenated. Backward parameter-gradient live-outs can be accumulated before grad communication as a performance optimization; --compile.ep_overlap_disable_early_grad_accumulation disables that optimization for strict eager-chunk bitwise comparisons.
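
One plausible reason a strict bitwise comparison needs early gradient accumulation turned off is that floating-point addition is not associative, so adding the same chunk gradients in a different order can change the low bits. The sketch below is illustrative only and is not taken from the PR: plain Python floats stand in for gradient contributions, and the two groupings are assumptions rather than the orders the pass actually uses.

```python
# Illustrative only: Python floats stand in for per-chunk gradient contributions.
# The two groupings are hypothetical; the PR does not spell out the exact orders.
contributions = [0.1, 0.2, 0.3]

# Grouping A: combine the first two contributions, then add the third.
grouped_early = (contributions[0] + contributions[1]) + contributions[2]

# Grouping B: combine the last two contributions, then add the first.
grouped_late = contributions[0] + (contributions[1] + contributions[2])

# Mathematically equal, but the rounded results differ in the last bit.
print(grouped_early == grouped_late)  # False
```

The difference is tiny, but a bitwise test compares every bit, so any reordering of the sum is enough to make it fail.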

The pass runs after memory policy, CPU offload, and SAC/rematerialization, and before FSDP bucketing. It does not special-case offload or remat nodes; those interactions are represented as ordinary dataflow through live-ins and live-outs. Forward buffer mutations are preserved as chunk-local mutation chains so MoE accounting such as tokens-per-expert remains semantically visible. Chunk-derived symbolic shapes are kept through scheduling and then concretized by concretize_ep_chunk_symbolic_shapes_pass before backend codegen.

concretize_ep_chunk_symbolic_shapes_pass also walks FX GraphModules referenced by node arguments, not only the root graph. Higher-order ops such as FlexAttention carry subgraphs this way, and their placeholders can retain symbolic tensor metadata or scalar example values after chunking. Leaving those nested symbols stale violates the pass invariant that no chunk symbol reaches backend codegen. The pass applies the same argument, metadata, false-guard validation, dead scalar plumbing cleanup, lint, and recompile steps per referenced GraphModule.
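
The nested traversal described above can be sketched roughly as follows. This is a stand-in under stated assumptions, not the pass itself: `ToyGraphModule` and `iter_graph_modules` are hypothetical names, and a plain list of node arguments replaces a real FX graph. It only shows the shape of the walk: visit the root, then every graph module reachable through node arguments, each exactly once.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Set


@dataclass(eq=False)
class ToyGraphModule:
    """Hypothetical stand-in for an FX GraphModule (not the real class)."""

    name: str
    # Flattened node arguments; some may be nested graph modules.
    node_args: List[object] = field(default_factory=list)


def iter_graph_modules(
    root: ToyGraphModule, _seen: Optional[Set[int]] = None
) -> Iterator[ToyGraphModule]:
    """Yield the root and every nested graph module once, depth first."""
    seen = set() if _seen is None else _seen
    if id(root) in seen:
        return  # Already visited through another argument.
    seen.add(id(root))
    yield root
    for arg in root.node_args:
        if isinstance(arg, ToyGraphModule):
            yield from iter_graph_modules(arg, seen)


# A higher-order op may reference the same subgraph more than once.
score_mod = ToyGraphModule("score_mod")
mask_mod = ToyGraphModule("mask_mod")
root = ToyGraphModule("root", [1, score_mod, mask_mod, score_mod])
visited = [gm.name for gm in iter_graph_modules(root)]
print(visited)  # ['root', 'score_mod', 'mask_mod']
```

In the real pass, each visited module would then go through the same argument, metadata, validation, cleanup, lint, and recompile steps listed above.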

The eager wrapper remains the numerical reference: selected modules are split into two chunks, the original module forward runs once per chunk, and outputs are combined back. Graph chunking is expected to be bitwise equivalent to eager chunking when early grad accumulation is disabled, including model state such as MoE expert-bias buffers.
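
As a minimal sketch of the eager reference described above: split the input into two chunks, run the original forward once per chunk, and concatenate the results. Everything here is an assumption made for illustration: nested Python lists stand in for tensors, `run_chunked` and `scale_rows` are hypothetical names, and the split point `len(batch) // 2` is chosen to mirror the `floor(u0/2)` split mentioned elsewhere in this conversation.

```python
from typing import Callable, List, Sequence

Row = List[float]


def run_chunked(
    forward: Callable[[Sequence[Row]], List[Row]], batch: Sequence[Row]
) -> List[Row]:
    """Split `batch` in two, run `forward` on each chunk, and concatenate."""
    half = len(batch) // 2  # First chunk gets floor(n / 2) rows.
    first, second = batch[:half], batch[half:]
    return forward(first) + forward(second)


def scale_rows(rows: Sequence[Row]) -> List[Row]:
    """Toy row-wise 'module forward': doubles every element."""
    return [[2.0 * x for x in row] for row in rows]


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
chunked = run_chunked(scale_rows, batch)
full = scale_rows(batch)
print(chunked == full)  # True
```

The equality holds here because the toy forward treats each row independently; a module that mixes information across the chunked dimension would not be equivalent to its unchunked form.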

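This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:

  • FakeTensor folded matmul: pytorch/pytorch#183397
  • ProxyTensor SDPA tracing: pytorch/pytorch#183398
  • Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
  • Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
  • DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545
  • HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837
  • FlexAttention chunked unbacked input extents: pytorch/pytorch#183838
  • FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839
  • Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840
  • Inductor TP slice/cat collective fusion: pytorch/pytorch#184833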

Test Plan:

  • pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
  • pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or fake_pg or fsdp_dense or graph_ep or concretize_ep_chunk_symbolic_shapes or moe_efsdp_bucket or nested_scalar'
  • pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
  • pytest -q torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py::TestDSv3BitwiseDeterministic::test_ep_chunk_matches_eager_chunking_bitwise
  • Forced graph_trainer DeepSeek V3 distributed loss comparisons for transformer_batch, moe_seq, and moe_batch; all cases returned True and matched all 20 losses exactly.
  • graph_trainer_h100 integration smokes: SDPA regional graph transformer_batch, SDPA regional graph moe_batch, SDPA regional graph moe_seq, SDPA regional eager transformer_batch, FlexAttention regional graph moe_seq.
  • pre-commit run --files torchtitan/experiments/graph_trainer/ep_pass_utils.py torchtitan/experiments/graph_trainer/tests/test_passes.py, with pyrefly-check skipped because it currently mutates and fails on unrelated torchtitan/models/flux/inference/sampling.py suppressions.

sanketpurandare added a commit that referenced this pull request May 12, 2026
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize every escaping value back to the original full-value form.

The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available.

Wire the passes behind compile.chunk_batch_modules and compile.chunk_seq_modules, mark requested input dimensions as unbacked with explicit hints before tracing, copy chunk metadata onto placeholders, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization.

Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, and DeepSeek V3 debug-model bitwise equivalence against eager chunking for transformer-block batch chunking, MoE batch chunking, and MoE sequence chunking.

Test Plan:
- pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q
- SKIP=pyrefly-check pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/make_fx_tracer.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/tests/test_passes.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/models/deepseek_v3/model.py

stack-info: PR: #3325, branch: sanketpurandare/stack/12
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/12 branch from da44cd0 to 82df199 Compare May 12, 2026 02:17
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 12, 2026
@sanketpurandare sanketpurandare marked this pull request as draft May 12, 2026 02:22
sanketpurandare added a commit that referenced this pull request May 12, 2026
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize every escaping value back to the original full-value form.

For supported regions, the graph rewrites are intended to be bitwise equivalent to eager module-level chunking: eager execution splits the selected module input, calls the original module forward on each chunk, and cats the outputs back together. The graph pass implements the same semantics while preserving metadata needed for later scheduling experiments.

The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available.

Wire the passes behind compile.chunk_batch_modules and compile.chunk_seq_modules, mark requested input dimensions as unbacked with explicit hints before tracing, copy chunk metadata onto placeholders, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization.

Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, and DeepSeek V3 debug-model bitwise equivalence against eager chunking for transformer-block batch chunking, MoE batch chunking, and MoE sequence chunking.

Test Plan:
- Chunk pass unit tests: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_batch_forward_region_semantics torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_pass_guardrails_and_pipeline_order torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_batch_backward_cats_activation_grad_and_sums_param_grad torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_batch_on_traced_toy_model -q
- DeepSeek V3 bitwise eager-chunking equivalence: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_dsv3_debug_chunk_pass_matches_eager_chunking_bitwise -q
- Full chunk pass suite: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q
- Lint except unrelated Pyrefly dataloader environment issue: SKIP=pyrefly-check pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/make_fx_tracer.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/tests/test_passes.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/models/deepseek_v3/model.py

stack-info: PR: #3325, branch: sanketpurandare/stack/12
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/12 branch from 82df199 to 8d6f30e Compare May 12, 2026 02:22
@sanketpurandare sanketpurandare marked this pull request as ready for review May 12, 2026 02:22
@SherlockNoMad

Contributor

It would be easier for me to review if you could share a tlparse (tlp) report, so I can diff the graph before and after the pass.

Comment thread torchtitan/models/deepseek_v3/model.py Outdated
Comment on lines +104 to +105
bsz = x.size(0)
seqlen = x.size(1)

Contributor

hmm... I am surprised this rewrite is needed?

Contributor Author

Rechecked with the current PyTorch fixes and reverted this back to bsz, seqlen, _ = x.size(). The focused chunk tests, including the DeepSeekV3 bitwise chunking subtests, still pass.

Comment on lines +490 to +497
for node, arg in zip(
    (n for n in traced.graph.nodes if n.op == "placeholder"),
    unwrapped_args,
    strict=True,
):
    chunk_dims = getattr(arg, "_torchtitan_chunk_dims", None)
    if chunk_dims is not None:
        node.meta["torchtitan_chunk_dims"] = dict(chunk_dims)

Contributor

Model-specific or pass-specific optimizations shouldn't live in the tracer.
We need to keep the tracer general and portable.

Contributor

can this be a preprocessing pass?

Contributor Author

Moved this out of the generic tracer. The chunk-specific placeholder metadata import now lives in import_chunk_dim_metadata_pass, which is inserted immediately before the selected chunk pass in compile_time_passes.

Comment on lines +176 to +184
if config.compile.chunk_batch_modules:
    passes.append(
        functools.partial(
            chunk_batch_pass,
            module_patterns=config.compile.chunk_batch_modules,
            num_static_inputs=traced_result.num_static_inputs,
        )
    )
if config.compile.chunk_seq_modules:

Contributor

are these two mutually exclusive?
if so, use elif

Contributor Author

Simplified this by replacing the two independently enabled lists with one chunk_modules list plus one chunk_mode selector. Pass construction now dispatches to either batch or seq chunking from that single mode, and it validates that selected chunk roots are compatible with downstream transformer-block scheduling boundaries.
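
A rough sketch of the single-selector dispatch described in this reply, under stated assumptions: the two pass functions below are hypothetical stand-ins that only record their inputs, `build_chunk_pass` is an invented helper name, and the real code also forwards arguments such as `num_static_inputs` and validates chunk roots, which is omitted here.

```python
import functools
from typing import Callable, Dict, Optional, Sequence, Tuple

PassResult = Tuple[str, str, Tuple[str, ...]]


# Hypothetical stand-ins for the real passes; they only record what they received.
def chunk_batch_pass(graph: str, *, module_patterns: Tuple[str, ...]) -> PassResult:
    return ("batch", graph, module_patterns)


def chunk_seq_pass(graph: str, *, module_patterns: Tuple[str, ...]) -> PassResult:
    return ("seq", graph, module_patterns)


_CHUNK_PASSES: Dict[str, Callable[..., PassResult]] = {
    "batch": chunk_batch_pass,
    "seq": chunk_seq_pass,
}


def build_chunk_pass(
    chunk_modules: Sequence[str], chunk_mode: str
) -> Optional[Callable[[str], PassResult]]:
    """Return one configured pass for the selected mode, or None if disabled."""
    if not chunk_modules:
        return None  # Chunking was not requested.
    if chunk_mode not in _CHUNK_PASSES:
        raise ValueError(
            f"unknown chunk_mode {chunk_mode!r}; expected one of {sorted(_CHUNK_PASSES)}"
        )
    return functools.partial(
        _CHUNK_PASSES[chunk_mode], module_patterns=tuple(chunk_modules)
    )


selected = build_chunk_pass(["layers.*"], "seq")
result = selected("toy_graph")
print(result)  # ('seq', 'toy_graph', ('layers.*',))
```

Because the mode is a single value, the mutually exclusive case raised in the review cannot occur by construction, and an unknown mode fails early with a clear error.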

    return fqn == root_fqn or fqn.startswith(root_fqn + ".")


def _node_index(node: torch.fx.Node) -> int | None:

Contributor

This is inefficient if we need to call it multiple times.
Better to build the node-to-index map once and read from the cache.

Contributor Author

Fixed. I removed the repeated _node_index / _is_static_placeholder scan from common_utils.py; chunk_passes.py now computes the static placeholder set once and reuses it while deriving static nodes and chunkable live-ins.
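
The compute-once pattern discussed in this thread can be sketched as below. This is an illustration under assumptions, not the code in chunk_passes.py: strings stand in for FX nodes, the helper names are hypothetical, and it assumes that static inputs are the leading `num_static_inputs` placeholders.

```python
from typing import Dict, FrozenSet, Hashable, Sequence


def build_node_index(nodes: Sequence[Hashable]) -> Dict[Hashable, int]:
    """Map each node to its position with a single scan of the graph."""
    return {node: position for position, node in enumerate(nodes)}


def static_placeholder_set(
    placeholders: Sequence[Hashable], num_static_inputs: int
) -> FrozenSet[Hashable]:
    """Assumption: the first `num_static_inputs` placeholders are static."""
    return frozenset(placeholders[:num_static_inputs])


# Strings stand in for FX nodes in this toy graph.
nodes = ["param_w", "param_b", "x", "mm", "add", "output"]
node_index = build_node_index(nodes)
static = static_placeholder_set(nodes[:3], num_static_inputs=2)

# Later queries are dictionary or set lookups instead of repeated graph scans.
print(node_index["mm"])   # 3
print("x" in static)      # False
```

Both structures are built in one pass over the nodes, so each later lookup costs constant time on average rather than a fresh scan.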

Comment on lines +943 to +949
logger.info(
    "Applied chunk_%s to %d regions (%d materialized live-outs): %s",
    mode,
    len(regions),
    transformed,
    module_patterns,
)

Contributor

I would add 10x more logger.debug(...) to make this pass more maintainable.

Contributor Author

Added targeted logger.debug(...) coverage around placeholder metadata import, matched regions, per-region plan sizes, and live-out materialization counts. I also cleaned up the pass structure and comments so the logs map to the main pass phases without dumping whole graphs.

@sanketpurandare sanketpurandare marked this pull request as draft May 12, 2026 06:23
sanketpurandare added a commit that referenced this pull request May 12, 2026
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize escaping values back to the original full-value form when needed.

For supported regions, the graph rewrites are intended to be bitwise equivalent to eager module-level chunking: eager execution splits the selected module input, calls the original module forward on each chunk, and cats the outputs back together. The graph pass implements the same semantics while preserving metadata needed for later scheduling experiments.

The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. For non-invertible saved bookkeeping values such as MoE per-expert token counts, materialized forward live-outs also retain exact per-chunk provenance so chunked backward compute can consume the same values as eager chunking. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available.

Wire the passes behind compile.chunk_modules and compile.chunk_mode, mark the requested input dimension as unbacked with an explicit hint before tracing, import chunk dimension annotations in a chunk-specific preprocessing pass, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization. Validate selected chunk roots against downstream transformer-block scheduling boundaries, and keep user-local Codex workflow state under .codex/local/.

Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, downstream boundary validation, and DeepSeek V3 debug-model bitwise equivalence against eager chunking. The DeepSeek coverage checks every debug-model transformer block for batch chunking and every debug-model MoE block for batch and sequence chunking in a single graph pass.

Test Plan:

- Chunk pass focused suite: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k chunk -q

- Full chunk pass class: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q

- Lint and type checking: pre-commit run --files .gitignore torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/experiments/graph_trainer/tests/test_passes.py

stack-info: PR: #3325, branch: sanketpurandare/stack/12
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/12 branch from 8d6f30e to 753c321 Compare May 12, 2026 06:23
@sanketpurandare sanketpurandare marked this pull request as ready for review May 12, 2026 06:24
@sanketpurandare

Contributor Author

Added a Manifold-backed tlparse artifact here:

https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/sanketpurandare/a00c7cc7-93ab-41cb-a09d-c7de1e0f0d97/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

The useful entries to compare are before_chunk_batch_pass and after_chunk_batch_pass; this is a minimal traced graph that shows the pass inserting split, duplicated chunk compute, and cat materialization.

@sanketpurandare

Contributor Author

Added two DeepSeekV3 debug-model tlparse artifacts as well:

Transformer block batch chunking (layers.*):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/sanketpurandare/2ccf18d7-c6e7-417f-bc45-55bb2517d7a5/dsv3_transformer_chunk_batch/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

MoE block batch chunking (layers.*.moe):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/sanketpurandare/1fc9547b-f670-42ac-bb9b-50fcb930df78/dsv3_moe_chunk_batch/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

I checked the local trace logs before posting these. Both contain before_chunk_batch_pass / after_chunk_batch_pass. The transformer trace chunks layers.0 through layers.5; the MoE trace chunks layers.1.moe through layers.5.moe (layers.0 has feed_forward, not moe). The after graphs show the expected dynamic batch split with floor(u0/2), duplicated _chunk0 / _chunk1 compute, _chunk_cat full-shape materialization, and _chunk_tuple provenance.
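
To make the region selection above concrete, here is a small sketch of how patterns such as layers.* and layers.*.moe could pick chunk roots. The segment-wise matching rule and the helper names `matches_root` and `in_region` are assumptions made for illustration; only the body of `in_region` mirrors the prefix check visible in the PR diff, and the module names are a toy list rather than the DeepSeek V3 debug model.

```python
from fnmatch import fnmatchcase
from typing import List


def matches_root(fqn: str, pattern: str) -> bool:
    """Assumed rule: match dot-separated segments one by one, so '*' never spans a dot."""
    fqn_parts = fqn.split(".")
    pattern_parts = pattern.split(".")
    return len(fqn_parts) == len(pattern_parts) and all(
        fnmatchcase(part, pat) for part, pat in zip(fqn_parts, pattern_parts)
    )


def in_region(fqn: str, root_fqn: str) -> bool:
    """A module belongs to a region if it is the root or nested under it."""
    return fqn == root_fqn or fqn.startswith(root_fqn + ".")


module_fqns: List[str] = [
    "layers.0",
    "layers.0.feed_forward",
    "layers.1",
    "layers.1.moe",
    "layers.1.moe.experts",
    "layers.10",
]

block_roots = [fqn for fqn in module_fqns if matches_root(fqn, "layers.*")]
moe_roots = [fqn for fqn in module_fqns if matches_root(fqn, "layers.*.moe")]
print(block_roots)  # ['layers.0', 'layers.1', 'layers.10']
print(moe_roots)    # ['layers.1.moe']
```

The trailing dot in the prefix check matters: without it, layers.10 would be treated as part of the layers.1 region.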

@sanketpurandare sanketpurandare marked this pull request as draft May 12, 2026 06:57
sanketpurandare added a commit that referenced this pull request May 12, 2026
stack-info: PR: #3325, branch: sanketpurandare/stack/12
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/17 to main May 26, 2026 05:28
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/12 branch from 1969fa6 to a7f75fe Compare May 26, 2026 05:29
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/17 May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:30
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/17 to main May 26, 2026 05:31
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/17 May 26, 2026 05:31
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:31
@sanketpurandare sanketpurandare marked this pull request as draft May 27, 2026 03:51
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/17 to main May 27, 2026 03:51
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/12 branch from a7f75fe to 629d28c Compare May 27, 2026 03:51
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/17 May 27, 2026 03:52
@sanketpurandare sanketpurandare marked this pull request as ready for review May 27, 2026 03:52
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/17 branch from c7d1c78 to c719e42 Compare May 27, 2026 04:41
sanketpurandare added a commit that referenced this pull request May 27, 2026
Add the graph-level chunking transform used by graph_trainer's EP overlap pipeline. The public opt-in remains --compile.passes ep_overlap, but this commit intentionally only introduces the correctness-preserving chunked dataflow. It does not implement EP communication-overlap scheduling.

The chunking pass lives in ep_chunk_pass.py with shared metadata and symbolic-shape helpers in ep_pass_utils.py. It uses the unbacked SymInt introduced by the trace-input dynamic-shape preparation as the source of truth for chunk influence. The pass plans selected forward/backward module regions, splits live-ins whose fake shapes carry the selected symbol, rewrites one chunk body in place, copies the peer chunk body, and materializes escaping full values. Tensor live-outs carrying the selected dimension are concatenated. Backward parameter-gradient live-outs can be accumulated before grad communication as a performance optimization; --compile.ep_overlap_disable_early_grad_accumulation disables that optimization for strict eager-chunk bitwise comparisons.

The pass runs after memory policy, CPU offload, and SAC/rematerialization, and before FSDP bucketing. It does not special-case offload or remat nodes; those interactions are represented as ordinary dataflow through live-ins and live-outs. Forward buffer mutations are preserved as chunk-local mutation chains so MoE accounting such as tokens-per-expert remains semantically visible. Chunk-derived symbolic shapes are kept through scheduling and then concretized by concretize_ep_chunk_symbolic_shapes_pass before backend codegen.

The eager wrapper remains the numerical reference: selected modules are split into two chunks, the original module forward runs once per chunk, and outputs are combined back. Graph chunking is expected to be bitwise equivalent to eager chunking when early grad accumulation is disabled, including model state such as MoE expert-bias buffers.

This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545
- HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837
- FlexAttention chunked unbacked input extents: pytorch/pytorch#183838
- FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839
- Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840
- Inductor TP slice/cat collective fusion: pytorch/pytorch#184833

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
- pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
- python torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py TestDSv3BitwiseDeterministic.test_ep_overlap_graph_matches_eager_chunking

stack-info: PR: #3325, branch: sanketpurandare/stack/12
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/12 branch from 629d28c to a0461e9 Compare May 27, 2026 04:41
@sanketpurandare sanketpurandare marked this pull request as draft May 27, 2026 15:28
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/17 to main May 27, 2026 15:28
    return gm


def remove_dead_code_pass(

Contributor

already in main.

Comment on lines +417 to +436
    # TODO: Disabled due to upstream PyTorch nightly DTensor regression.
    # Sharding propagation fails for aten.mm.default with mixed dtypes
    # (bf16 activations, f32 weights) on the TP mesh. Re-enable once fixed.
    @unittest.skip("upstream DTensor mixed-dtype sharding propagation regression")
    def test_moe_dsv3_ep_overlap_aot_fx_trace_vs_eager_chunked(self):
        self.assertTrue(_run_deepseek_v3_ep_overlap_loss_compare())

    # TODO: Disabled due to upstream PyTorch nightly DTensor regression.
    # Sharding propagation fails for aten.mm.default with mixed dtypes
    # (bf16 activations, f32 weights) on the TP mesh. Re-enable once fixed.
    @unittest.skip("upstream DTensor mixed-dtype sharding propagation regression")
    def test_moe_dsv3_ep_overlap_moe_seq_aot_fx_trace_vs_eager_chunked(self):
        self.assertTrue(_run_deepseek_v3_ep_overlap_moe_seq_loss_compare())

    # TODO: Disabled due to upstream PyTorch nightly DTensor regression.
    # Sharding propagation fails for aten.mm.default with mixed dtypes
    # (bf16 activations, f32 weights) on the TP mesh. Re-enable once fixed.
    @unittest.skip("upstream DTensor mixed-dtype sharding propagation regression")
    def test_moe_dsv3_ep_overlap_moe_batch_aot_fx_trace_vs_eager_chunked(self):
        self.assertTrue(_run_deepseek_v3_ep_overlap_moe_batch_loss_compare())

Contributor

are these addressed already?

    return fqn == root_fqn or fqn.startswith(root_fqn + ".")


def _tensor_meta(node: torch.fx.Node) -> torch.Tensor | None:

Contributor

val can be list/tuple/sym_int...

    return val if isinstance(val, torch.Tensor) else None


def _free_symbols(value: object) -> frozenset[object]:

Contributor

Needs a docstring. What are free symbols? u1, u2, ...?

    return frozenset(free_symbols(value))


def _dynamic_dim_symbols(val: torch.Tensor, dim: int) -> frozenset[object]:

Contributor

Needs a docstring.

    remove_identity_slice_pass,
    normalize_view_ops_as_reshape,
]
if "ep_overlap" in config.compile.passes:

Contributor

hmm... I originally intended to deprecate config.compile.passes,
but this is a new way to use it.

I would rename config.compile.passes to config.compile.extra_passes.

Nit, fine as a follow-up.

Comment on lines +150 to +152
compile_ep_overlap_chunk_dim: str = "batch",
compile_ep_overlap_module_fqn: str = "layers.*",
compile_ep_overlap_disable_early_grad_accumulation: bool = False,

Contributor

This is probably a sign that ep_overlap should have its own Config:

compile.ep_overlap.*
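
One way the suggested compile.ep_overlap.* grouping might look, sketched with plain dataclasses. The class names `EPOverlapConfig` and `CompileConfig` are hypothetical and this is not torchtitan's actual config system; the three field defaults are taken from the flat parameters shown in the diff above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class EPOverlapConfig:
    """Hypothetical nested section; defaults mirror the flat parameters in the diff."""

    chunk_dim: str = "batch"
    module_fqn: str = "layers.*"
    disable_early_grad_accumulation: bool = False


@dataclass
class CompileConfig:
    """Hypothetical parent section holding the pass list and the nested options."""

    passes: List[str] = field(default_factory=list)
    ep_overlap: EPOverlapConfig = field(default_factory=EPOverlapConfig)


config = CompileConfig(passes=["ep_overlap"])
# Strict eager-chunk bitwise comparisons would flip this one nested option.
config.ep_overlap.disable_early_grad_accumulation = True
print(config.ep_overlap.chunk_dim)  # batch
```

Grouping the options this way keeps the EP-overlap settings together instead of spreading prefixed names across the compile section.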

@SherlockNoMad

Contributor

Needs some cases in integration_tests.py.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Graph chunking for EP overlap.

@SherlockNoMad SherlockNoMad Jun 9, 2026

Contributor

This pass is very complicated, and is only maintainable by you / Claude.

I am fine with landing this because it's a standalone graph pass, its complexity is well encapsulated, and if things don't work, people can and will turn it off.

My point is: I have seen your profiler trace, and I trust this will work.

So I am approving this as a black box owned by you.

Add the graph-level chunking transform used by graph_trainer's EP overlap pipeline. The public opt-in remains --compile.passes ep_overlap, but this commit intentionally only introduces the correctness-preserving chunked dataflow. It does not implement EP communication-overlap scheduling.

The chunking pass lives in ep_chunk_pass.py with shared metadata and symbolic-shape helpers in ep_pass_utils.py. It uses the unbacked SymInt introduced by the trace-input dynamic-shape preparation as the source of truth for chunk influence. The pass plans selected forward/backward module regions, splits live-ins whose fake shapes carry the selected symbol, rewrites one chunk body in place, copies the peer chunk body, and materializes escaping full values. Tensor live-outs carrying the selected dimension are concatenated. Backward parameter-gradient live-outs can be accumulated before grad communication as a performance optimization; --compile.ep_overlap_disable_early_grad_accumulation disables that optimization for strict eager-chunk bitwise comparisons.
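
As a rough illustration of "splits live-ins whose fake shapes carry the selected symbol", the sketch below plans splits over plain shape tuples using sympy. It is not the code in ep_chunk_pass.py: `plan_live_in_splits`, the live-in names, and the shapes are all hypothetical, and treating a derived extent such as `8 * u0` as carrying the symbol is an assumption.

```python
import sympy


def plan_live_in_splits(live_in_shapes, chunk_symbol):
    """Map each live-in name to the first dim whose extent mentions chunk_symbol.

    Live-ins whose shapes do not mention the symbol are left out of the plan,
    meaning they would be passed to both chunks unsplit.
    """
    plan = {}
    for name, shape in live_in_shapes.items():
        for dim, extent in enumerate(shape):
            # sympify turns plain ints into sympy Integers with no free symbols.
            if chunk_symbol in sympy.sympify(extent).free_symbols:
                plan[name] = dim
                break
    return plan


# Stand-in for the unbacked SymInt (torch usually names these u0, u1, ...).
u0 = sympy.Symbol("u0", integer=True, positive=True)

live_ins = {
    "hidden": (u0, 2048),           # carries the symbol directly
    "router_logits": (8 * u0, 64),  # hypothetical derived extent
    "weight": (2048, 2048),         # static shape, not split
}
plan = plan_live_in_splits(live_ins, u0)
print(plan)
```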

The pass runs after memory policy, CPU offload, and SAC/rematerialization, and before FSDP bucketing. It does not special-case offload or remat nodes; those interactions are represented as ordinary dataflow through live-ins and live-outs. Forward buffer mutations are preserved as chunk-local mutation chains so MoE accounting such as tokens-per-expert remains semantically visible. Chunk-derived symbolic shapes are kept through scheduling and then concretized by concretize_ep_chunk_symbolic_shapes_pass before backend codegen.
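
The ordering constraints in the paragraph above can be written down as a small check. This is only a sketch: the pass names are hypothetical labels, not the identifiers graph_trainer registers, and the position of scheduling relative to FSDP bucketing is an arbitrary choice because the description does not state it.

```python
# Hypothetical labels for the stages named in the description.
PASS_ORDER = [
    "memory_policy",
    "cpu_offload",
    "sac_rematerialization",
    "ep_chunk",
    "fsdp_bucketing",
    "scheduling",  # placement relative to fsdp_bucketing is an assumption
    "concretize_ep_chunk_symbolic_shapes",
    "backend_codegen",
]

# (earlier, later) pairs taken from the description.
MUST_PRECEDE = [
    ("memory_policy", "ep_chunk"),
    ("cpu_offload", "ep_chunk"),
    ("sac_rematerialization", "ep_chunk"),
    ("ep_chunk", "fsdp_bucketing"),
    ("ep_chunk", "scheduling"),
    ("scheduling", "concretize_ep_chunk_symbolic_shapes"),
    ("concretize_ep_chunk_symbolic_shapes", "backend_codegen"),
]


def ordering_violations(order, constraints):
    """Return the (earlier, later) pairs that the given order breaks."""
    position = {name: index for index, name in enumerate(order)}
    return [(a, b) for a, b in constraints if position[a] >= position[b]]


print(ordering_violations(PASS_ORDER, MUST_PRECEDE))
```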

concretize_ep_chunk_symbolic_shapes_pass also walks FX GraphModules referenced by node arguments, not only the root graph. Higher-order ops such as FlexAttention carry subgraphs this way, and their placeholders can retain symbolic tensor metadata or scalar example values after chunking. Leaving those nested symbols stale violates the pass invariant that no chunk symbol reaches backend codegen. The pass applies the same argument, metadata, false-guard validation, dead scalar plumbing cleanup, lint, and recompile steps per referenced GraphModule.
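
A minimal sketch of the traversal this implies, assuming nested GraphModules are reached through `get_attr` nodes that feed other nodes. `iter_graph_modules`, `resolve_attr`, and `apply_subgraph` are hypothetical helpers written for this example, and the hand-built outer graph only imitates how a higher-order op references a subgraph; the real pass may discover subgraphs differently.

```python
import torch
import torch.fx as fx


def resolve_attr(module, qualified_name):
    """Follow a dotted get_attr target from the owning module."""
    obj = module
    for part in qualified_name.split("."):
        obj = getattr(obj, part)
    return obj


def iter_graph_modules(root):
    """Yield root, then every GraphModule reachable via used get_attr nodes."""
    seen = set()
    stack = [root]
    while stack:
        gm = stack.pop()
        if id(gm) in seen:
            continue
        seen.add(id(gm))
        yield gm
        for node in gm.graph.nodes:
            if node.op == "get_attr" and node.users:
                target = resolve_attr(gm, node.target)
                if isinstance(target, fx.GraphModule):
                    stack.append(target)


def apply_subgraph(subgraph, x):
    """Stand-in for a higher-order op that takes a subgraph argument."""
    return subgraph(x)


class Body(torch.nn.Module):
    def forward(self, x):
        return x + 1


inner = fx.symbolic_trace(Body())
holder = torch.nn.Module()
holder.body = inner

graph = fx.Graph()
x = graph.placeholder("x")
body_ref = graph.get_attr("body")
out = graph.call_function(apply_subgraph, (body_ref, x))
graph.output(out)
outer = fx.GraphModule(holder, graph)

visited = list(iter_graph_modules(outer))
print(len(visited))
```

A per-module step (argument rewrite, metadata update, lint, recompile) would then run once for each yielded GraphModule rather than only for the root.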

The eager wrapper remains the numerical reference: selected modules are split into two chunks, the original module forward runs once per chunk, and outputs are combined back. Graph chunking is expected to be bitwise equivalent to eager chunking when early grad accumulation is disabled, including model state such as MoE expert-bias buffers.
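
The eager reference described above can be sketched as a small wrapper module. This is an illustration, not the wrapper used in graph_trainer: `EagerChunkWrapper` is a hypothetical name, it handles a single tensor input and output, and it assumes the chunked dimension has extent of at least 2.

```python
import torch


class EagerChunkWrapper(torch.nn.Module):
    """Run the wrapped module once per half of the input, then concatenate."""

    def __init__(self, module, chunk_dim=0):
        super().__init__()
        self.module = module
        self.chunk_dim = chunk_dim

    def forward(self, x):
        # torch.chunk returns two views when the extent along chunk_dim is >= 2.
        first, second = torch.chunk(x, 2, dim=self.chunk_dim)
        outputs = [self.module(first), self.module(second)]
        return torch.cat(outputs, dim=self.chunk_dim)


torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
wrapped = EagerChunkWrapper(layer, chunk_dim=0)
x = torch.randn(6, 4)
y = wrapped(x)
print(y.shape)
```

Comparing against the chunked eager run, rather than against an unchunked forward, matters because kernels may choose different reduction orders for different input sizes, so chunked and unchunked results need not agree bit for bit.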

This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
- FakeTensor folded matmul: pytorch/pytorch#183397
- ProxyTensor SDPA tracing: pytorch/pytorch#183398
- Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495
- Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544
- DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545
- HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837
- FlexAttention chunked unbacked input extents: pytorch/pytorch#183838
- FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839
- Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840
- Inductor TP slice/cat collective fusion: pytorch/pytorch#184833

Test Plan:
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py
- pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or fake_pg or fsdp_dense or graph_ep or concretize_ep_chunk_symbolic_shapes or moe_efsdp_bucket or nested_scalar'
- pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap
- pytest -q torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py::TestDSv3BitwiseDeterministic::test_ep_chunk_matches_eager_chunking_bitwise
- Forced graph_trainer DeepSeek V3 distributed loss comparisons for transformer_batch, moe_seq, and moe_batch; all cases returned True and matched all 20 losses exactly.
- graph_trainer_h100 integration smokes: SDPA regional graph transformer_batch, SDPA regional graph moe_batch, SDPA regional graph moe_seq, SDPA regional eager transformer_batch, FlexAttention regional graph moe_seq.
- pre-commit run --files torchtitan/experiments/graph_trainer/ep_pass_utils.py torchtitan/experiments/graph_trainer/tests/test_passes.py, with pyrefly-check skipped because it currently mutates and fails on unrelated torchtitan/models/flux/inference/sampling.py suppressions.

stack-info: PR: #3325, branch: sanketpurandare/stack/12

Labels

ciflow/8gpu, CLA Signed

2 participants