[graph_trainer] Support hinted symbolic input dims in tracing#3362
[graph_trainer] Support hinted symbolic input dims in tracing#3362sanketpurandare wants to merge 1 commit into
Conversation
07f3629 to
1ee37b9
Compare
60ac042 to
3178797
Compare
d7fa96e to
0fb48ed
Compare
|
|
||
|
|
||
| def _wrapper_subclass_has_mark_unbacked(tensor: torch.Tensor) -> bool: | ||
| def _tensor_has_mark_dynamic(tensor: torch.Tensor) -> bool: |
There was a problem hiding this comment.
Should this also consider _dynamo_dynamic_range ? Previously, minimal tracer seems to have considered this too as "mark_dynamic"
There was a problem hiding this comment.
Added support for both mark_dynamic and mark_unbacked metadata paths in the minimal tracer. The current helper imports the Dynamo dynamic/unbacked annotations into the StatelessSymbolicContext, including the unbacked bounds used by hinted symbolic dimensions.
| dynamic_sizes[dim] = DimDynamic.UNBACKED | ||
| constraint_sizes[dim] = RelaxedUnspecConstraint(warn_only=False) | ||
| elif dim in marked_dynamic_indices: | ||
| dynamic_sizes[dim] = DimDynamic.DYNAMIC |
There was a problem hiding this comment.
Should we update constraint_sizes if dynamic has max/min hints?
There was a problem hiding this comment.
The implementation keeps strict unbacked dimensions constrained with RelaxedUnspecConstraint and passes the Dynamo-provided unbacked bounds into the symbolic context. The min/max hint path is preserved through unbacked_bounds rather than specializing the traced graph to concrete sizes.
| traced = minimal_fx_tracer(forward)( | ||
| x, xq, xk, freqs_cis, rope_cache, positions | ||
| ) | ||
| self.assertTrue( |
There was a problem hiding this comment.
Test with a shape different than what was traced?
There was a problem hiding this comment.
Added tracing coverage for marked dynamic inputs and runtime assertion materialization. The traced graph keeps the input dimension symbolic and emits ShapeEnv runtime checks, so the test now verifies behavior beyond the exact concrete size used during capture.
| ] | ||
|
|
||
|
|
||
| def _check_shape_equal(actual, expected, context: str) -> None: |
There was a problem hiding this comment.
Can we restrict this change to graph_trainer only in some way? It seems necessary only for the minimal fx tracer?
There was a problem hiding this comment.
Restricted the behavior to the graph-trainer/minimal-tracer use case by keeping the symbolic shape support in the tracer path. The RoPE change is limited to replacing Python shape equality assertions with torch._check, which is the PyTorch-native way to let symbolic sizes flow without weakening eager/runtime validation.
0fb48ed to
b243510
Compare
b243510 to
bd80c98
Compare
bd80c98 to
7a084ea
Compare
Add the graph_trainer tracing prerequisites needed by later EP chunking work, without introducing the chunking pass itself. minimal_fx_tracer now accepts both mark_unbacked and mark_dynamic metadata on plain tensor inputs, builds the corresponding symbolic context during fakeification, and keeps wrapper-subclass inputs rejected so DTensor-style layouts do not silently lose their metadata. Make RoPE shape checks symbolic-shape friendly by replacing Python shape asserts with compact torch._check-based validation. This keeps the runtime checks but avoids specializing symbolic batch or sequence dimensions during tracing. Make full-Inductor FX-to-FX canonicalization tolerate fresh intermediate unbacked scalar symbols, and keep MoE split-size CPU copies synchronous under compile/non-strict tracing so traced split sizes cannot race stale CPU reads. Test Plan: - pytest -q torchtitan/experiments/graph_trainer/tests/test_trace_module.py -k 'mark_dynamic_batch_and_seq_dims_with_rope or dtensor_mark_unbacked_rejected' stack-info: PR: #3362, branch: sanketpurandare/stack/16
7a084ea to
dd9fd88
Compare
dd9fd88 to
1a4e569
Compare
1a4e569 to
adf04c0
Compare
|
There are some coupling between tracer and EP_overlap passes... This needs to be called out in readme.md / claude.md ... together with clear instruction on how to use ep overlap, and document it's composability and limitations. |
Add the graph_trainer tracing prerequisites needed by later EP chunking work, without introducing the chunking pass itself. minimal_fx_tracer now accepts both mark_unbacked and mark_dynamic metadata on plain tensor inputs, builds the corresponding symbolic context during fakeification, and keeps wrapper-subclass inputs rejected so DTensor-style layouts do not silently lose their metadata. Make RoPE shape checks symbolic-shape friendly by replacing Python shape asserts with compact torch._check-based validation. This keeps the runtime checks but avoids specializing symbolic batch or sequence dimensions during tracing. Make full-Inductor FX-to-FX canonicalization tolerate fresh intermediate unbacked scalar symbols, and keep MoE split-size CPU copies synchronous under compile/non-strict tracing so traced split sizes cannot race stale CPU reads. Test Plan: - pytest -q torchtitan/experiments/graph_trainer/tests/test_trace_module.py -k 'mark_dynamic_batch_and_seq_dims_with_rope or dtensor_mark_unbacked_rejected' stack-info: PR: #3362, branch: sanketpurandare/stack/16
adf04c0 to
4bec890
Compare
Stacked PRs:
[graph_trainer] Support hinted symbolic input dims in tracing
Add the graph_trainer tracing prerequisites needed by later EP chunking work, without introducing the chunking pass itself. minimal_fx_tracer now accepts both mark_unbacked and mark_dynamic metadata on plain tensor inputs, builds the corresponding symbolic context during fakeification, and keeps wrapper-subclass inputs rejected so DTensor-style layouts do not silently lose their metadata.
Make RoPE shape checks symbolic-shape friendly by replacing Python shape asserts with compact torch._check-based validation. This keeps the runtime checks but avoids specializing symbolic batch or sequence dimensions during tracing.
Make full-Inductor FX-to-FX canonicalization tolerate fresh intermediate unbacked scalar symbols, and keep MoE split-size CPU copies synchronous under compile/non-strict tracing so traced split sizes cannot race stale CPU reads.
Test Plan: