From b24cf466325f0bb92b646f189a4b4efbcbd6bb2c Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Sun, 3 May 2026 14:26:17 -0700 Subject: [PATCH] Move FSDP recompute tagging to the placement compile path This moves mark_fsdp_all_gather_recomputation out of _apply_placement_common and into apply_placement, after the sharded graph has been cleaned up, traced, converted from view to reshape, functionalized for fresh index_put_ mutations, written back to joint descriptors, and prepared for AOT compilation. The common placement helper now only builds and normalizes the parallel graph, while the training compile path applies the FSDP all-gather recomputation tags immediately before invoking aot_compile_joint_with_descriptors. Keeping the tag insertion at the apply_placement boundary makes the graph mutation order explicit: graph rewrites that affect structure happen first, descriptor state is refreshed, wait_tensor DCE behavior is installed, and then recompute metadata is added to the graph that the joint compiler consumes. This avoids mixing placement graph construction with compile-time recompute metadata and keeps the common helper usable for future placement flows that should not eagerly stamp FSDP recompute tags. The compile backend behavior is otherwise unchanged, but the Inductor overlap-scheduling patch set is now centralized in _INDUCTOR_OVERLAP_PATCHES and selected directly when overlap_scheduling is enabled. That keeps autoparallel_backend focused on installing optional functorch AC and Inductor overlap config patches around compile_fx without rebuilding the same overlap dictionary on each backend construction. Authored with Claude. stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/443, branch: sanketpurandare/stack/5 --- autoparallel/api.py | 7 ++++--- autoparallel/compile.py | 19 ++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index eca79884..8f2eea55 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -449,9 +449,6 @@ def _apply_placement_common(self, sharding_placement): view_to_reshape(parallel_gm) functionalize_fresh_index_put_mutations(parallel_gm) - mark_fsdp_all_gather_recomputation( - parallel_gm.graph, self.reshard_after_forward - ) t_ac = time.perf_counter() # now rename input/param/tangent/output/grad_param/grad_input nodes following # our convention @@ -483,6 +480,10 @@ def apply_placement(self, sharding_placement): sharding_placement ) + mark_fsdp_all_gather_recomputation( + self.parallel_gm.graph, self.reshard_after_forward + ) + self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( self.joint_with_descriptors, fw_compiler=self.compiler_fn, diff --git a/autoparallel/compile.py b/autoparallel/compile.py index 256185f1..3c73a66a 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -12,6 +12,13 @@ from .graph_passes.activation_checkpointing import ac_joint_pass +_INDUCTOR_OVERLAP_PATCHES = { + "aten_distributed_optimizations.enable_overlap_scheduling": True, + "aten_distributed_optimizations.collective_bucketing": True, + "aten_distributed_optimizations.insert_overlap_deps": True, + "aten_distributed_optimizations.max_compute_pre_fetch": 10, +} + def _make_ac_joint_pass( ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto", @@ -44,23 +51,13 @@ def autoparallel_backend( overlap_scheduling: Enable comm/compute overlap scheduling. """ functorch_patches = {} - inductor_patches = {} + inductor_patches = _INDUCTOR_OVERLAP_PATCHES if overlap_scheduling else {} if enable_ac: functorch_patches["joint_custom_pass"] = _make_ac_joint_pass( ac_stage_size_in_GiB ) - if overlap_scheduling: - inductor_patches.update( - { - "aten_distributed_optimizations.enable_overlap_scheduling": True, - "aten_distributed_optimizations.collective_bucketing": True, - "aten_distributed_optimizations.insert_overlap_deps": True, - "aten_distributed_optimizations.max_compute_pre_fetch": 10, - } - ) - def backend(gm, example_inputs): with ( torch._functorch.config.patch(functorch_patches),