diff --git a/autoparallel/api.py b/autoparallel/api.py index eca79884..8f2eea55 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -449,9 +449,6 @@ def _apply_placement_common(self, sharding_placement): view_to_reshape(parallel_gm) functionalize_fresh_index_put_mutations(parallel_gm) - mark_fsdp_all_gather_recomputation( - parallel_gm.graph, self.reshard_after_forward - ) t_ac = time.perf_counter() # now rename input/param/tangent/output/grad_param/grad_input nodes following # our convention @@ -483,6 +480,10 @@ def apply_placement(self, sharding_placement): sharding_placement ) + mark_fsdp_all_gather_recomputation( + self.parallel_gm.graph, self.reshard_after_forward + ) + self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( self.joint_with_descriptors, fw_compiler=self.compiler_fn, diff --git a/autoparallel/compile.py b/autoparallel/compile.py index 256185f1..3c73a66a 100644 --- a/autoparallel/compile.py +++ b/autoparallel/compile.py @@ -12,6 +12,13 @@ from .graph_passes.activation_checkpointing import ac_joint_pass +_INDUCTOR_OVERLAP_PATCHES = { + "aten_distributed_optimizations.enable_overlap_scheduling": True, + "aten_distributed_optimizations.collective_bucketing": True, + "aten_distributed_optimizations.insert_overlap_deps": True, + "aten_distributed_optimizations.max_compute_pre_fetch": 10, +} + def _make_ac_joint_pass( ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto", @@ -44,23 +51,13 @@ def autoparallel_backend( overlap_scheduling: Enable comm/compute overlap scheduling. """ functorch_patches = {} - inductor_patches = {} + inductor_patches = _INDUCTOR_OVERLAP_PATCHES if overlap_scheduling else {} if enable_ac: functorch_patches["joint_custom_pass"] = _make_ac_joint_pass( ac_stage_size_in_GiB ) - if overlap_scheduling: - inductor_patches.update( - { - "aten_distributed_optimizations.enable_overlap_scheduling": True, - "aten_distributed_optimizations.collective_bucketing": True, - "aten_distributed_optimizations.insert_overlap_deps": True, - "aten_distributed_optimizations.max_compute_pre_fetch": 10, - } - ) - def backend(gm, example_inputs): with ( torch._functorch.config.patch(functorch_patches),