
[graph_trainer] Support hinted symbolic input dims in tracing#3362

Open
sanketpurandare wants to merge 1 commit into main from sanketpurandare/stack/16


@sanketpurandare sanketpurandare commented May 15, 2026

Contributor

Stacked PRs:


[graph_trainer] Support hinted symbolic input dims in tracing

Add the graph_trainer tracing prerequisites needed by later EP chunking work, without introducing the chunking pass itself. minimal_fx_tracer now accepts both mark_unbacked and mark_dynamic metadata on plain tensor inputs, builds the corresponding symbolic context during fakeification, and keeps wrapper-subclass inputs rejected so DTensor-style layouts do not silently lose their metadata.

Make RoPE shape checks symbolic-shape friendly by replacing Python shape asserts with compact torch._check-based validation. This keeps the runtime checks but avoids specializing symbolic batch or sequence dimensions during tracing.

Make full-Inductor FX-to-FX canonicalization tolerate fresh intermediate unbacked scalar symbols, and keep MoE split-size CPU copies synchronous under compile/non-strict tracing so traced split sizes cannot race stale CPU reads.

Test Plan:

  • pytest -q torchtitan/experiments/graph_trainer/tests/test_trace_module.py -k 'mark_dynamic_batch_and_seq_dims_with_rope or dtensor_mark_unbacked_rejected'

@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/15 branch from 07f3629 to 1ee37b9 Compare May 15, 2026 03:03
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from 60ac042 to 3178797 Compare May 15, 2026 03:03
@meta-cla meta-cla Bot added the CLA Signed label May 15, 2026
@sanketpurandare sanketpurandare marked this pull request as draft May 15, 2026 09:36
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/15 to main May 15, 2026 09:36
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch 2 times, most recently from d7fa96e to 0fb48ed Compare May 15, 2026 09:36
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/18 May 15, 2026 09:36
@sanketpurandare sanketpurandare marked this pull request as ready for review May 15, 2026 09:37


def _wrapper_subclass_has_mark_unbacked(tensor: torch.Tensor) -> bool:
def _tensor_has_mark_dynamic(tensor: torch.Tensor) -> bool:

Contributor

Should this also consider _dynamo_dynamic_range? Previously, the minimal tracer seems to have treated this as "mark_dynamic" too.

Contributor Author

Added support for both mark_dynamic and mark_unbacked metadata paths in the minimal tracer. The current helper imports the Dynamo dynamic/unbacked annotations into the StatelessSymbolicContext, including the unbacked bounds used by hinted symbolic dimensions.
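
For readers following the thread, one way a detection helper could account for both annotations is sketched below. This is an illustration under stated assumptions, not the PR's code: the helper name is hypothetical, the attribute names _dynamo_dynamic_indices and _dynamo_dynamic_range are assumed to be what mark_dynamic leaves on a tensor, and a plain namespace object stands in for a tensor so the snippet runs without torch.

```python
from types import SimpleNamespace


def has_mark_dynamic(tensor) -> bool:
    """Return True if the object carries any mark_dynamic-style annotation.

    Hypothetical helper: checks both the per-dim index set and the optional
    min/max range set, treating a missing attribute as "not marked".
    """
    indices = getattr(tensor, "_dynamo_dynamic_indices", None)
    ranges = getattr(tensor, "_dynamo_dynamic_range", None)
    return bool(indices) or bool(ranges)


# Stand-ins for tensors. On a real tensor these attributes are assumed to be
# set by the Dynamo marking APIs rather than by hand.
plain = SimpleNamespace()
marked = SimpleNamespace(_dynamo_dynamic_indices={0, 1})
ranged_only = SimpleNamespace(_dynamo_dynamic_range={(0, 2, 4096)})

print(has_mark_dynamic(plain), has_mark_dynamic(marked), has_mark_dynamic(ranged_only))
```

Checking the range attribute separately means an input that carries only range metadata is still treated as marked, which is the case the reviewer's question raises.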

dynamic_sizes[dim] = DimDynamic.UNBACKED
constraint_sizes[dim] = RelaxedUnspecConstraint(warn_only=False)
elif dim in marked_dynamic_indices:
dynamic_sizes[dim] = DimDynamic.DYNAMIC

Contributor

Should we update constraint_sizes if the dynamic dim has min/max hints?

Contributor Author

The implementation keeps strict unbacked dimensions constrained with RelaxedUnspecConstraint and passes the Dynamo-provided unbacked bounds into the symbolic context. The min/max hint path is preserved through unbacked_bounds rather than specializing the traced graph to concrete sizes.
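
The per-dimension decision described above can be sketched without torch as follows. Everything here is a stand-in: DimPolicy mimics the DimDynamic names from the reviewed snippet, classify_dims is a hypothetical function, and representing unbacked_bounds as a dict from dim to a (min, max) pair is an assumption about its shape. The ordering (unbacked checked before dynamic) follows the if/elif order that the snippet appears to use.

```python
from enum import Enum, auto


class DimPolicy(Enum):
    """Stand-in for torch's DimDynamic; not the real class."""
    STATIC = auto()
    DYNAMIC = auto()
    UNBACKED = auto()


def classify_dims(ndim, unbacked_indices, dynamic_indices, unbacked_bounds=None):
    """Return (policies, bounds) with one entry per dimension.

    A dim marked unbacked takes precedence over a dynamic mark. Bounds are
    carried alongside the policy and are never used to pick a concrete size.
    """
    unbacked_bounds = unbacked_bounds or {}
    policies, bounds = [], []
    for dim in range(ndim):
        if dim in unbacked_indices:
            policies.append(DimPolicy.UNBACKED)
            bounds.append(unbacked_bounds.get(dim))  # (min, max) or None
        elif dim in dynamic_indices:
            policies.append(DimPolicy.DYNAMIC)
            bounds.append(None)
        else:
            policies.append(DimPolicy.STATIC)
            bounds.append(None)
    return policies, bounds


# Dim 0 is marked both ways, dim 1 is dynamic only, dim 2 is unmarked.
policies, bounds = classify_dims(
    ndim=3,
    unbacked_indices={0},
    dynamic_indices={0, 1},
    unbacked_bounds={0: (1, 8192)},
)
print(policies, bounds)
```

In this sketch the min/max hint travels with the unbacked dim as metadata, which is the behavior the reply describes: the bound narrows what the symbol may be without turning it into a constant.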

traced = minimal_fx_tracer(forward)(
x, xq, xk, freqs_cis, rope_cache, positions
)
self.assertTrue(

Contributor

Test with a shape different from the one that was traced?

Contributor Author

Added tracing coverage for marked dynamic inputs and runtime assertion materialization. The traced graph keeps the input dimension symbolic and emits ShapeEnv runtime checks, so the test now verifies behavior beyond the exact concrete size used during capture.

]


def _check_shape_equal(actual, expected, context: str) -> None:

Contributor

Can we restrict this change to graph_trainer only in some way? It seems necessary only for the minimal fx tracer?

Contributor Author

Restricted the behavior to the graph-trainer/minimal-tracer use case by keeping the symbolic shape support in the tracer path. The RoPE change is limited to replacing Python shape equality assertions with torch._check, which is the PyTorch-native way to let symbolic sizes flow without weakening eager/runtime validation.

@sanketpurandare sanketpurandare marked this pull request as draft May 22, 2026 19:02
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/18 to main May 22, 2026 19:02
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from 0fb48ed to b243510 Compare May 22, 2026 19:03
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/18 May 22, 2026 19:03
@sanketpurandare sanketpurandare marked this pull request as ready for review May 22, 2026 19:03
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:28
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/18 to main May 26, 2026 05:28
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from b243510 to bd80c98 Compare May 26, 2026 05:29
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/18 May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:29
@sanketpurandare sanketpurandare marked this pull request as draft May 26, 2026 05:30
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/18 to main May 26, 2026 05:30
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/18 May 26, 2026 05:31
@sanketpurandare sanketpurandare marked this pull request as ready for review May 26, 2026 05:31
@sanketpurandare sanketpurandare marked this pull request as draft May 27, 2026 03:51
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/18 to main May 27, 2026 03:51
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from bd80c98 to 7a084ea Compare May 27, 2026 03:51
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/18 May 27, 2026 03:51
@sanketpurandare sanketpurandare marked this pull request as ready for review May 27, 2026 03:52
sanketpurandare added a commit that referenced this pull request May 27, 2026
stack-info: PR: #3362, branch: sanketpurandare/stack/16
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from 7a084ea to dd9fd88 Compare May 27, 2026 04:41
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/18 to main May 27, 2026 04:41
@sanketpurandare sanketpurandare marked this pull request as draft May 27, 2026 15:28
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from dd9fd88 to 1a4e569 Compare May 27, 2026 15:28
@sanketpurandare sanketpurandare marked this pull request as ready for review May 27, 2026 15:29
@sanketpurandare sanketpurandare marked this pull request as draft June 4, 2026 21:36
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from 1a4e569 to adf04c0 Compare June 4, 2026 21:36
@sanketpurandare sanketpurandare marked this pull request as ready for review June 4, 2026 21:36
@SherlockNoMad

SherlockNoMad commented Jun 9, 2026

Contributor

There is some coupling between the tracer and the EP_overlap passes...

This needs to be called out in readme.md / claude.md ... together with clear instructions on how to use EP overlap, and the docs should cover its composability and limitations.

@sanketpurandare sanketpurandare marked this pull request as draft June 10, 2026 23:49
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/16 branch from adf04c0 to 4bec890 Compare June 10, 2026 23:50
@sanketpurandare sanketpurandare marked this pull request as ready for review June 10, 2026 23:50

Labels

ciflow/8gpu, CLA Signed


3 participants