[graph_trainer] Add graph EP chunking pass#3325
Conversation
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize every escaping value back to the original full-value form. The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available. Wire the passes behind compile.chunk_batch_modules and compile.chunk_seq_modules, mark requested input dimensions as unbacked with explicit hints before tracing, copy chunk metadata onto placeholders, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization. Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, and DeepSeek V3 debug-model bitwise equivalence against eager chunking for transformer-block batch chunking, MoE batch chunking, and MoE sequence chunking. Test Plan: - pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q - SKIP=pyrefly-check pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/make_fx_tracer.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/tests/test_passes.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/models/deepseek_v3/model.py stack-info: PR: #3325, branch: sanketpurandare/stack/12
da44cd0 to
82df199
Compare
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize every escaping value back to the original full-value form. For supported regions, the graph rewrites are intended to be bitwise equivalent to eager module-level chunking: eager execution splits the selected module input, calls the original module forward on each chunk, and cats the outputs back together. The graph pass implements the same semantics while preserving metadata needed for later scheduling experiments. The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available. Wire the passes behind compile.chunk_batch_modules and compile.chunk_seq_modules, mark requested input dimensions as unbacked with explicit hints before tracing, copy chunk metadata onto placeholders, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization. Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, and DeepSeek V3 debug-model bitwise equivalence against eager chunking for transformer-block batch chunking, MoE batch chunking, and MoE sequence chunking. Test Plan: - Chunk pass unit tests: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_batch_forward_region_semantics torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_pass_guardrails_and_pipeline_order torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_batch_backward_cats_activation_grad_and_sums_param_grad torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_chunk_batch_on_traced_toy_model -q - DeepSeek V3 bitwise eager-chunking equivalence: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses::test_dsv3_debug_chunk_pass_matches_eager_chunking_bitwise -q - Full chunk pass suite: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q - Lint except unrelated Pyrefly dataloader environment issue: SKIP=pyrefly-check pre-commit run --files torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/make_fx_tracer.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/tests/test_passes.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/models/deepseek_v3/model.py stack-info: PR: #3325, branch: sanketpurandare/stack/12
82df199 to
8d6f30e
Compare
|
would be easier for me to review if you can show a tlp, so I can diff the graph before and after the pass. |
| bsz = x.size(0) | ||
| seqlen = x.size(1) |
There was a problem hiding this comment.
hmm... I am surprised this rewrite is needed?
There was a problem hiding this comment.
Rechecked with the current PyTorch fixes and reverted this back to bsz, seqlen, _ = x.size(). The focused chunk tests, including the DeepSeekV3 bitwise chunking subtests, still pass.
| for node, arg in zip( | ||
| (n for n in traced.graph.nodes if n.op == "placeholder"), | ||
| unwrapped_args, | ||
| strict=True, | ||
| ): | ||
| chunk_dims = getattr(arg, "_torchtitan_chunk_dims", None) | ||
| if chunk_dims is not None: | ||
| node.meta["torchtitan_chunk_dims"] = dict(chunk_dims) |
There was a problem hiding this comment.
model specific or pass specific optimization shouldn't live in tracer.
we need to keep tracer general and portable.
There was a problem hiding this comment.
can this be a preprocessing pass?
There was a problem hiding this comment.
Moved this out of the generic tracer. The chunk-specific placeholder metadata import now lives in import_chunk_dim_metadata_pass, which is inserted immediately before the selected chunk pass in compile_time_passes.
| if config.compile.chunk_batch_modules: | ||
| passes.append( | ||
| functools.partial( | ||
| chunk_batch_pass, | ||
| module_patterns=config.compile.chunk_batch_modules, | ||
| num_static_inputs=traced_result.num_static_inputs, | ||
| ) | ||
| ) | ||
| if config.compile.chunk_seq_modules: |
There was a problem hiding this comment.
are these two mutually exclusive?
if so, use elif
There was a problem hiding this comment.
Simplified this by replacing the two independently enabled lists with one chunk_modules list plus one chunk_mode selector. Pass construction now dispatches to either batch or seq chunking from that single mode, and it validates that selected chunk roots are compatible with downstream transformer-block scheduling boundaries.
| return fqn == root_fqn or fqn.startswith(root_fqn + ".") | ||
|
|
||
|
|
||
| def _node_index(node: torch.fx.Node) -> int | None: |
There was a problem hiding this comment.
this is inefficient if we need to call it multiple times.
better build the node to index once, and read from cache.
There was a problem hiding this comment.
Fixed. I removed the repeated _node_index / _is_static_placeholder scan from common_utils.py; chunk_passes.py now computes the static placeholder set once and reuses it while deriving static nodes and chunkable live-ins.
| logger.info( | ||
| "Applied chunk_%s to %d regions (%d materialized live-outs): %s", | ||
| mode, | ||
| len(regions), | ||
| transformed, | ||
| module_patterns, | ||
| ) |
There was a problem hiding this comment.
I would add 10x more logger.debug(...) to make this pass more maintainable.
There was a problem hiding this comment.
Added targeted logger.debug(...) coverage around placeholder metadata import, matched regions, per-region plan sizes, and live-out materialization counts. I also cleaned up the pass structure and comments so the logs map to the main pass phases without dumping whole graphs.
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize escaping values back to the original full-value form when needed. For supported regions, the graph rewrites are intended to be bitwise equivalent to eager module-level chunking: eager execution splits the selected module input, calls the original module forward on each chunk, and cats the outputs back together. The graph pass implements the same semantics while preserving metadata needed for later scheduling experiments. The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. For non-invertible saved bookkeeping values such as MoE per-expert token counts, materialized forward live-outs also retain exact per-chunk provenance so chunked backward compute can consume the same values as eager chunking. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available. Wire the passes behind compile.chunk_modules and compile.chunk_mode, mark the requested input dimension as unbacked with an explicit hint before tracing, import chunk dimension annotations in a chunk-specific preprocessing pass, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization. Validate selected chunk roots against downstream transformer-block scheduling boundaries, and keep user-local Codex workflow state under .codex/local/. Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, downstream boundary validation, and DeepSeek V3 debug-model bitwise equivalence against eager chunking. The DeepSeek coverage checks every debug-model transformer block for batch chunking and every debug-model MoE block for batch and sequence chunking in a single graph pass. Test Plan: - Chunk pass focused suite: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k chunk -q - Full chunk pass class: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q - Lint and type checking: pre-commit run --files .gitignore torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/experiments/graph_trainer/tests/test_passes.py stack-info: PR: #3325, branch: sanketpurandare/stack/12
8d6f30e to
753c321
Compare
|
Added a Manifold-backed tlparse artifact here: The useful entries to compare are |
|
Added two DeepSeekV3 debug-model tlparse artifacts as well: Transformer block batch chunking ( MoE block batch chunking ( I checked the local trace logs before posting these. Both contain |
Add optional chunk_batch and chunk_seq passes for aot_fx_trace graph_trainer execution. The passes select module regions by FQN pattern, split explicitly annotated activation live-ins into two chunks, duplicate only the activation-dependent descendant closure, and materialize escaping values back to the original full-value form when needed. For supported regions, the graph rewrites are intended to be bitwise equivalent to eager module-level chunking: eager execution splits the selected module input, calls the original module forward on each chunk, and cats the outputs back together. The graph pass implements the same semantics while preserving metadata needed for later scheduling experiments. The implementation keeps static model state and static-derived parameter prep shared, so parameter-only all-gather/cast paths are not duplicated. Tensor live-outs that still carry the chunk dimension are concatenated, reduced outputs such as parameter gradients are added, and symbolic shape live-outs derived from chunked dims are handled consistently. For non-invertible saved bookkeeping values such as MoE per-expert token counts, materialized forward live-outs also retain exact per-chunk provenance so chunked backward compute can consume the same values as eager chunking. Sequence chunking rejects attention-containing regions until a full-K/V rewrite is available. Wire the passes behind compile.chunk_modules and compile.chunk_mode, mark the requested input dimension as unbacked with an explicit hint before tracing, import chunk dimension annotations in a chunk-specific preprocessing pass, include the configuration in precompile fingerprints, and run the passes after CPU offload insertion and before SAC rematerialization. Validate selected chunk roots against downstream transformer-block scheduling boundaries, and keep user-local Codex workflow state under .codex/local/. Add focused FX tests for forward and backward materialization, guardrails, pass ordering, traced ToyModel execution, downstream boundary validation, and DeepSeek V3 debug-model bitwise equivalence against eager chunking. The DeepSeek coverage checks every debug-model transformer block for batch chunking and every debug-model MoE block for batch and sequence chunking in a single graph pass. Test Plan: - Chunk pass focused suite: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py -k chunk -q - Full chunk pass class: pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestChunkPasses -q - Lint and type checking: pre-commit run --files .gitignore torchtitan/experiments/graph_trainer/chunk_passes.py torchtitan/experiments/graph_trainer/common_utils.py torchtitan/experiments/graph_trainer/configs.py torchtitan/experiments/graph_trainer/passes.py torchtitan/experiments/graph_trainer/precompile.py torchtitan/experiments/graph_trainer/trainer.py torchtitan/experiments/graph_trainer/tests/test_passes.py stack-info: PR: #3325, branch: sanketpurandare/stack/12
1969fa6 to
a7f75fe
Compare
a7f75fe to
629d28c
Compare
c7d1c78 to
c719e42
Compare
Add the graph-level chunking transform used by graph_trainer's EP overlap pipeline. The public opt-in remains --compile.passes ep_overlap, but this commit intentionally only introduces the correctness-preserving chunked dataflow. It does not implement EP communication-overlap scheduling. The chunking pass lives in ep_chunk_pass.py with shared metadata and symbolic-shape helpers in ep_pass_utils.py. It uses the unbacked SymInt introduced by the trace-input dynamic-shape preparation as the source of truth for chunk influence. The pass plans selected forward/backward module regions, splits live-ins whose fake shapes carry the selected symbol, rewrites one chunk body in place, copies the peer chunk body, and materializes escaping full values. Tensor live-outs carrying the selected dimension are concatenated. Backward parameter-gradient live-outs can be accumulated before grad communication as a performance optimization; --compile.ep_overlap_disable_early_grad_accumulation disables that optimization for strict eager-chunk bitwise comparisons. The pass runs after memory policy, CPU offload, and SAC/rematerialization, and before FSDP bucketing. It does not special-case offload or remat nodes; those interactions are represented as ordinary dataflow through live-ins and live-outs. Forward buffer mutations are preserved as chunk-local mutation chains so MoE accounting such as tokens-per-expert remains semantically visible. Chunk-derived symbolic shapes are kept through scheduling and then concretized by concretize_ep_chunk_symbolic_shapes_pass before backend codegen. The eager wrapper remains the numerical reference: selected modules are split into two chunks, the original module forward runs once per chunk, and outputs are combined back. Graph chunking is expected to be bitwise equivalent to eager chunking when early grad accumulation is disabled, including model state such as MoE expert-bias buffers. This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths: - FakeTensor folded matmul: pytorch/pytorch#183397 - ProxyTensor SDPA tracing: pytorch/pytorch#183398 - Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495 - Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544 - DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545 - HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837 - FlexAttention chunked unbacked input extents: pytorch/pytorch#183838 - FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839 - Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840 - Inductor TP slice/cat collective fusion: pytorch/pytorch#184833 Test Plan: - pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py - pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap - python torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py TestDSv3BitwiseDeterministic.test_ep_overlap_graph_matches_eager_chunking stack-info: PR: #3325, branch: sanketpurandare/stack/12
629d28c to
a0461e9
Compare
| return gm | ||
|
|
||
|
|
||
| def remove_dead_code_pass( |
| # TODO: Disabled due to upstream PyTorch nightly DTensor regression. | ||
| # Sharding propagation fails for aten.mm.default with mixed dtypes | ||
| # (bf16 activations, f32 weights) on the TP mesh. Re-enable once fixed. | ||
| @unittest.skip("upstream DTensor mixed-dtype sharding propagation regression") | ||
| def test_moe_dsv3_ep_overlap_aot_fx_trace_vs_eager_chunked(self): | ||
| self.assertTrue(_run_deepseek_v3_ep_overlap_loss_compare()) | ||
|
|
||
| # TODO: Disabled due to upstream PyTorch nightly DTensor regression. | ||
| # Sharding propagation fails for aten.mm.default with mixed dtypes | ||
| # (bf16 activations, f32 weights) on the TP mesh. Re-enable once fixed. | ||
| @unittest.skip("upstream DTensor mixed-dtype sharding propagation regression") | ||
| def test_moe_dsv3_ep_overlap_moe_seq_aot_fx_trace_vs_eager_chunked(self): | ||
| self.assertTrue(_run_deepseek_v3_ep_overlap_moe_seq_loss_compare()) | ||
|
|
||
| # TODO: Disabled due to upstream PyTorch nightly DTensor regression. | ||
| # Sharding propagation fails for aten.mm.default with mixed dtypes | ||
| # (bf16 activations, f32 weights) on the TP mesh. Re-enable once fixed. | ||
| @unittest.skip("upstream DTensor mixed-dtype sharding propagation regression") | ||
| def test_moe_dsv3_ep_overlap_moe_batch_aot_fx_trace_vs_eager_chunked(self): | ||
| self.assertTrue(_run_deepseek_v3_ep_overlap_moe_batch_loss_compare()) |
There was a problem hiding this comment.
are these addressed already?
| return fqn == root_fqn or fqn.startswith(root_fqn + ".") | ||
|
|
||
|
|
||
| def _tensor_meta(node: torch.fx.Node) -> torch.Tensor | None: |
There was a problem hiding this comment.
val can be list/tuple/sym_int...
| return val if isinstance(val, torch.Tensor) else None | ||
|
|
||
|
|
||
| def _free_symbols(value: object) -> frozenset[object]: |
There was a problem hiding this comment.
docstring, whats' free symbols? u1, u2... ?
| return frozenset(free_symbols(value)) | ||
|
|
||
|
|
||
| def _dynamic_dim_symbols(val: torch.Tensor, dim: int) -> frozenset[object]: |
| remove_identity_slice_pass, | ||
| normalize_view_ops_as_reshape, | ||
| ] | ||
| if "ep_overlap" in config.compile.passes: |
There was a problem hiding this comment.
hmm... I originally intend to deprecate config.compile.passes..
but this is a new way to use it.
I would rename config.compile.passes to config.compile.extra_passes
nit, fine as follow up
| compile_ep_overlap_chunk_dim: str = "batch", | ||
| compile_ep_overlap_module_fqn: str = "layers.*", | ||
| compile_ep_overlap_disable_early_grad_accumulation: bool = False, |
There was a problem hiding this comment.
this is probably a sign for ep_overlap to have its own Config.
compile.ep_overlap.*
|
Need some case in integration_tests.py |
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """Graph chunking for EP overlap. |
There was a problem hiding this comment.
This pass is very complicated... and is only maintainable by you / claude.
I am fine with landing this, coz it's a standalone graph pass, and it's complexity is well encapsulated, and if thing doesn't work, ppl can and will turn it off.
but my point is, I have seen your profiler trace, and I trust this would work.
so I am approving this as a blackbox owned by you.
Add the graph-level chunking transform used by graph_trainer's EP overlap pipeline. The public opt-in remains --compile.passes ep_overlap, but this commit intentionally only introduces the correctness-preserving chunked dataflow. It does not implement EP communication-overlap scheduling. The chunking pass lives in ep_chunk_pass.py with shared metadata and symbolic-shape helpers in ep_pass_utils.py. It uses the unbacked SymInt introduced by the trace-input dynamic-shape preparation as the source of truth for chunk influence. The pass plans selected forward/backward module regions, splits live-ins whose fake shapes carry the selected symbol, rewrites one chunk body in place, copies the peer chunk body, and materializes escaping full values. Tensor live-outs carrying the selected dimension are concatenated. Backward parameter-gradient live-outs can be accumulated before grad communication as a performance optimization; --compile.ep_overlap_disable_early_grad_accumulation disables that optimization for strict eager-chunk bitwise comparisons. The pass runs after memory policy, CPU offload, and SAC/rematerialization, and before FSDP bucketing. It does not special-case offload or remat nodes; those interactions are represented as ordinary dataflow through live-ins and live-outs. Forward buffer mutations are preserved as chunk-local mutation chains so MoE accounting such as tokens-per-expert remains semantically visible. Chunk-derived symbolic shapes are kept through scheduling and then concretized by concretize_ep_chunk_symbolic_shapes_pass before backend codegen. concretize_ep_chunk_symbolic_shapes_pass also walks FX GraphModules referenced by node arguments, not only the root graph. Higher-order ops such as FlexAttention carry subgraphs this way, and their placeholders can retain symbolic tensor metadata or scalar example values after chunking. Leaving those nested symbols stale violates the pass invariant that no chunk symbol reaches backend codegen. The pass applies the same argument, metadata, false-guard validation, dead scalar plumbing cleanup, lint, and recompile steps per referenced GraphModule. The eager wrapper remains the numerical reference: selected modules are split into two chunks, the original module forward runs once per chunk, and outputs are combined back. Graph chunking is expected to be bitwise equivalent to eager chunking when early grad accumulation is disabled, including model state such as MoE expert-bias buffers. This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths: - FakeTensor folded matmul: pytorch/pytorch#183397 - ProxyTensor SDPA tracing: pytorch/pytorch#183398 - Inductor bucketing trace isolation from ambient unbacked symbols: pytorch/pytorch#183495 - Inductor collective bucketing with hinted unbacked SymInts: pytorch/pytorch#183544 - DTensor sharding padding for hinted even unbacked shards: pytorch/pytorch#183545 - HOP fake traces with discarded unbacked symbols: pytorch/pytorch#183837 - FlexAttention chunked unbacked input extents: pytorch/pytorch#183838 - FakeTensor trace metadata for hinted symbolic storage: pytorch/pytorch#183839 - Inductor symbolic stride ordering with unbacked hints: pytorch/pytorch#183840 - Inductor TP slice/cat collective fusion: pytorch/pytorch#184833 Test Plan: - pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py - pytest -q torchtitan/experiments/graph_trainer/tests/test_passes.py -k 'ep_overlap or fake_pg or fsdp_dense or graph_ep or concretize_ep_chunk_symbolic_shapes or moe_efsdp_bucket or nested_scalar' - pytest -q torchtitan/experiments/graph_trainer/tests/test_numerics.py -k ep_overlap - pytest -q torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py::TestDSv3BitwiseDeterministic::test_ep_chunk_matches_eager_chunking_bitwise - Forced graph_trainer DeepSeek V3 distributed loss comparisons for transformer_batch, moe_seq, and moe_batch; all cases returned True and matched all 20 losses exactly. - graph_trainer_h100 integration smokes: SDPA regional graph transformer_batch, SDPA regional graph moe_batch, SDPA regional graph moe_seq, SDPA regional eager transformer_batch, FlexAttention regional graph moe_seq. - pre-commit run --files torchtitan/experiments/graph_trainer/ep_pass_utils.py torchtitan/experiments/graph_trainer/tests/test_passes.py, with pyrefly-check skipped because it currently mutates and fails on unrelated torchtitan/models/flux/inference/sampling.py suppressions. stack-info: PR: #3325, branch: sanketpurandare/stack/12
Stacked PRs:
[graph_trainer] Add graph EP chunking pass
Add the graph-level chunking transform used by graph_trainer's EP overlap pipeline. The public opt-in remains --compile.passes ep_overlap, but this commit intentionally only introduces the correctness-preserving chunked dataflow. It does not implement EP communication-overlap scheduling.
The chunking pass lives in ep_chunk_pass.py with shared metadata and symbolic-shape helpers in ep_pass_utils.py. It uses the unbacked SymInt introduced by the trace-input dynamic-shape preparation as the source of truth for chunk influence. The pass plans selected forward/backward module regions, splits live-ins whose fake shapes carry the selected symbol, rewrites one chunk body in place, copies the peer chunk body, and materializes escaping full values. Tensor live-outs carrying the selected dimension are concatenated. Backward parameter-gradient live-outs can be accumulated before grad communication as a performance optimization; --compile.ep_overlap_disable_early_grad_accumulation disables that optimization for strict eager-chunk bitwise comparisons.
The pass runs after memory policy, CPU offload, and SAC/rematerialization, and before FSDP bucketing. It does not special-case offload or remat nodes; those interactions are represented as ordinary dataflow through live-ins and live-outs. Forward buffer mutations are preserved as chunk-local mutation chains so MoE accounting such as tokens-per-expert remains semantically visible. Chunk-derived symbolic shapes are kept through scheduling and then concretized by concretize_ep_chunk_symbolic_shapes_pass before backend codegen.
concretize_ep_chunk_symbolic_shapes_pass also walks FX GraphModules referenced by node arguments, not only the root graph. Higher-order ops such as FlexAttention carry subgraphs this way, and their placeholders can retain symbolic tensor metadata or scalar example values after chunking. Leaving those nested symbols stale violates the pass invariant that no chunk symbol reaches backend codegen. The pass applies the same argument, metadata, false-guard validation, dead scalar plumbing cleanup, lint, and recompile steps per referenced GraphModule.
The eager wrapper remains the numerical reference: selected modules are split into two chunks, the original module forward runs once per chunk, and outputs are combined back. Graph chunking is expected to be bitwise equivalent to eager chunking when early grad accumulation is disabled, including model state such as MoE expert-bias buffers.
This pass stack relies on PyTorch support for hinted unbacked symbolic dimensions in the tracing and distributed compiler paths:
Test Plan: