From 6b6d7b4f665c9b9431102792ee021878e6bb3b62 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 3 Apr 2026 14:57:57 +0000 Subject: [PATCH] Pre-overlap collective bucketing pass for FSDP/DDP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PyTorch's overlap scheduler adds sequential timeline dependencies between all consecutive events on each process group, which prevents its downstream bucketer from merging collectives that were originally independent. This PR adds a pre-pass that merges per-parameter FSDP/DDP collectives before the overlap scheduler runs, so it sees fewer, larger collectives. The pass targets three collective types: forward all-gathers (param-derived), backward reduce-scatters (terminal-derived), and backward all-reduces (terminal-derived, for DDP). The implementation is split into two phases: - Tagging runs on the joint graph (where placeholder metadata is available) and marks eligible collectives via node.meta. Tags survive the fw/bw partition via node_copy's shallow copy. - Bucketing runs on the split fw/bw graphs inside the compiler, reads the tags, and merges collectives using PyTorch's existing merge functions. This PR also fixes a pre-existing bug in _copy_descriptors_and_rename_placeholders where make_fx could nest the output tuple while desc stayed flat, causing get_all_input_and_grad_nodes's zip to silently mismatch output nodes with their descriptors. Results on LLaMA: AG 290→98 (fwd), 225→97 (bwd recomputed), RS 290→194 (bwd). Test plan - example_autoparallel.py passes (compile=False) - example_llama3.py passes - pytest tests/ passes Authored with Claude. --- autoparallel/api.py | 27 +++ autoparallel/apply_sharding.py | 19 +- .../graph_passes/bucket_collectives.py | 213 ++++++++++++++++++ 3 files changed, 257 insertions(+), 2 deletions(-) create mode 100644 autoparallel/graph_passes/bucket_collectives.py diff --git a/autoparallel/api.py b/autoparallel/api.py index 5ecac40e..bdfb5491 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -63,6 +63,10 @@ def _boxed_nop_preserve_node_meta(fx_g, example_inputs): + from autoparallel.graph_passes.bucket_collectives import bucket_collectives + + bucket_collectives(fx_g) + if torch._inductor.config.aten_distributed_optimizations.enable_overlap_scheduling: from torch._inductor.fx_passes.overlap_scheduling import ( schedule_overlap_bucketing_from_inductor_configs, @@ -275,6 +279,18 @@ def __enter__(self): ) torch._inductor.config.comprehensive_padding = False + if self.compiler_fn is compile_fx_inner: + from autoparallel.graph_passes.bucket_collectives import ( + bucket_collectives, + ) + + self.old_post_grad_custom_post_pass = ( + torch._inductor.config.post_grad_custom_post_pass + ) + torch._inductor.config.post_grad_custom_post_pass = ( + lambda gm: bucket_collectives(gm.owning_module) + ) + rescale_grad_comm_cost_for_mp = 1.0 if self.mp_policy is not None: param_size = self.mp_policy.param_dtype.itemsize @@ -312,6 +328,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch._inductor.config.comprehensive_padding = ( self.old_inductor_comprehensive_padding ) + if hasattr(self, "old_post_grad_custom_post_pass"): + torch._inductor.config.post_grad_custom_post_pass = ( + self.old_post_grad_custom_post_pass + ) self.active = None return self.stack.__exit__(exc_type, exc_val, exc_tb) @@ -462,6 +482,13 @@ def _apply_placement_common(self, sharding_placement): self.parallel_gm = parallel_gm update_joint_with_descriptors(self.joint_with_descriptors, parallel_gm) fix_scatter_on_aliased_inputs(parallel_gm.graph) + # Tag FSDP/DDP collectives on the joint graph so the bucketing + # pass can identify them after partitioning into fw/bw subgraphs. + from autoparallel.graph_passes.bucket_collectives import ( + tag_collectives_for_bucketing, + ) + + tag_collectives_for_bucketing(parallel_gm.graph) # Allow DCE to remove unused wait_tensor nodes in the backward graph. # Pushed onto self.stack so it's restored in AutoParallel.__exit__. self.stack.enter_context(_suppress_wait_tensor_side_effect()) diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py index 1766770b..67a237ce 100644 --- a/autoparallel/apply_sharding.py +++ b/autoparallel/apply_sharding.py @@ -310,12 +310,27 @@ def _lower_to_parallel_graph(gm, sharding_placement, local_args): def _copy_descriptors_and_rename_placeholders(source_gm, target_gm): """Copy node descriptors from source graph and rename placeholders to match.""" + from torch.utils._pytree import tree_flatten + for n1, n2 in zip( (n for n in source_gm.graph.nodes if n.op in ("placeholder", "output")), (n for n in target_gm.graph.nodes if n.op in ("placeholder", "output")), ): - n2.meta["desc"] = n1.meta["desc"] - if n2.op == "placeholder": + if n2.op == "output": + # get_all_input_and_grad_nodes iterates with zip(n.args[0], desc), + # so both must be flat. make_fx may nest the output args (e.g., + # ((fw_outs...), (grads...))) while desc from the source is flat. + # Flatten the target output args to match, and copy the flat desc. + flat_desc, _ = tree_flatten(n1.meta["desc"]) + flat_args, _ = tree_flatten(n2.args[0]) + assert len(flat_desc) == len(flat_args), ( + f"Output desc has {len(flat_desc)} leaves but output args has " + f"{len(flat_args)} leaves" + ) + n2.args = (tuple(flat_args),) + n2.meta["desc"] = flat_desc + else: + n2.meta["desc"] = n1.meta["desc"] n2.target = n1.target # node renaming is needed for partitioner as it searches for tangent # nodes. See https://fburl.com/kc4jtc3t for one case where it's used diff --git a/autoparallel/graph_passes/bucket_collectives.py b/autoparallel/graph_passes/bucket_collectives.py new file mode 100644 index 00000000..d1b9821b --- /dev/null +++ b/autoparallel/graph_passes/bucket_collectives.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict + +import torch +from torch._inductor.fx_passes.bucketing import ( + collect_node_descendants, + is_all_gather_into_tensor, + is_all_reduce_tensor, + is_reduce_scatter_tensor, + merge_all_gather_bucket, + merge_all_reduce_bucket, + merge_reduce_scatter_bucket, +) +from torch._inductor.fx_passes.post_grad import stable_topological_sort +from torch.utils._ordered_set import OrderedSet + +from .graph_utils import build_param_derived_set, build_terminal_derived_set + +logger = logging.getLogger(__name__) + +# Meta key used to tag collectives eligible for bucketing. +# Set to "param" for param-derived (forward all-gathers) or +# "terminal" for terminal-derived (backward reduce-scatters / all-reduces). +AP_BUCKET_KEY = "ap_bucket_group" + + +def tag_collectives_for_bucketing(graph: torch.fx.Graph) -> None: + """Tag FSDP/DDP collectives on the joint graph for later bucketing. + + Must run on the joint graph where placeholder metadata is available. + The tags survive partitioning into fw/bw subgraphs via node_copy's + shallow copy of node.meta. + """ + param_derived = build_param_derived_set(graph) + terminal_derived = build_terminal_derived_set(graph) + + n_ag = 0 + n_rs = 0 + n_ar = 0 + for node in graph.nodes: + if is_all_gather_into_tensor(node) and node in param_derived: + node.meta[AP_BUCKET_KEY] = "param" + n_ag += 1 + elif is_reduce_scatter_tensor(node) and node in terminal_derived: + node.meta[AP_BUCKET_KEY] = "terminal" + n_rs += 1 + elif is_all_reduce_tensor(node) and node in terminal_derived: + node.meta[AP_BUCKET_KEY] = "terminal" + n_ar += 1 + + logger.info( + "Tagged collectives for bucketing: AG %d, RS %d, AR %d", + n_ag, + n_rs, + n_ar, + ) + + +def _group_key(node: torch.fx.Node) -> tuple: + """Extract group key for any collective type. + + The key ensures only collectives on the same process group / reduce op / + dtype can be bucketed together — the same constraints enforced by + PyTorch's merge functions. + """ + if is_all_gather_into_tensor(node): + _, group_size, group_name = node.args + return (group_name,) + elif is_reduce_scatter_tensor(node): + _, reduce_op, group_size, group_name = node.args + dtype = node.meta["val"].dtype + return (group_name, reduce_op, dtype) + elif is_all_reduce_tensor(node): + _, reduce_op, group_name = node.args + dtype = node.meta["val"].dtype + return (group_name, reduce_op, dtype) + else: + raise ValueError(f"Unsupported collective type: {node.target}") + + +def _greedy_bucket( + graph: torch.fx.Graph, + coll_nodes: list[torch.fx.Node], + bucket_cap_bytes: int, + node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]], +) -> list[list[torch.fx.Node]]: + """Group collectives into buckets up to bucket_cap_bytes. + + Unlike PyTorch's greedy_bucket_collective_by_mb, this does not require + collectives to be adjacent in the graph — it only requires that no + collective is a descendant of another in the same bucket (to avoid + creating cycles when merged). + """ + if not coll_nodes: + return [] + + groups: dict[tuple, list[torch.fx.Node]] = defaultdict(list) + for node in coll_nodes: + groups[_group_key(node)].append(node) + + buckets: list[list[torch.fx.Node]] = [] + for nodes in groups.values(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_descendants: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_bytes = 0 + + for node in nodes: + if node in cur_bucket_descendants: + continue + + val = node.meta["val"] + out_bytes = val.numel() * val.element_size() + in_val = node.all_input_nodes[0].meta["val"] + in_bytes = in_val.numel() * in_val.element_size() + size_bytes = max(out_bytes, in_bytes) + + if cur_bucket_bytes + size_bytes > bucket_cap_bytes and cur_bucket: + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_bytes = 0 + cur_bucket_descendants = OrderedSet() + + cur_bucket_bytes += size_bytes + cur_bucket.append(node) + cur_bucket_descendants |= node_descendants[node] + + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + + return buckets + + +def bucket_collectives( + gm: torch.fx.GraphModule, + bucket_cap_mb: float = 25.0, +) -> None: + """Bucket FSDP/DDP collectives before the overlap scheduler runs. + + Reads the tags set by tag_collectives_for_bucketing() to identify + which collectives to merge. Merges per-parameter collectives into + larger bucketed collectives so the overlap scheduler sees fewer, + larger ops. + """ + graph = gm.graph + + fsdp_all_gathers: list[torch.fx.Node] = [] + fsdp_reduce_scatters: list[torch.fx.Node] = [] + ddp_all_reduces: list[torch.fx.Node] = [] + + for node in graph.nodes: + bucket_group = node.meta.get(AP_BUCKET_KEY) + if bucket_group is None: + continue + if is_all_gather_into_tensor(node): + fsdp_all_gathers.append(node) + elif is_reduce_scatter_tensor(node): + fsdp_reduce_scatters.append(node) + elif is_all_reduce_tensor(node): + ddp_all_reduces.append(node) + + total = len(fsdp_all_gathers) + len(fsdp_reduce_scatters) + len(ddp_all_reduces) + if total < 2: + return + + node_descendants = collect_node_descendants(graph) + bucket_cap_bytes = int(bucket_cap_mb * 1024 * 1024) + + ag_buckets = _greedy_bucket( + graph, fsdp_all_gathers, bucket_cap_bytes, node_descendants + ) + rs_buckets = _greedy_bucket( + graph, fsdp_reduce_scatters, bucket_cap_bytes, node_descendants + ) + ar_buckets = _greedy_bucket( + graph, ddp_all_reduces, bucket_cap_bytes, node_descendants + ) + + n_merged = sum(len(b) for b in ag_buckets + rs_buckets + ar_buckets) + n_buckets = len(ag_buckets) + len(rs_buckets) + len(ar_buckets) + if n_buckets == 0: + return + + logger.info( + "Bucketing %d collectives into %d buckets " + "(AG: %d->%d, RS: %d->%d, AR: %d->%d)", + n_merged, + n_buckets, + len(fsdp_all_gathers), + len(fsdp_all_gathers) - sum(len(b) for b in ag_buckets) + len(ag_buckets), + len(fsdp_reduce_scatters), + len(fsdp_reduce_scatters) - sum(len(b) for b in rs_buckets) + len(rs_buckets), + len(ddp_all_reduces), + len(ddp_all_reduces) - sum(len(b) for b in ar_buckets) + len(ar_buckets), + ) + + for bucket in ag_buckets: + merge_all_gather_bucket(graph, bucket) + for bucket in rs_buckets: + merge_reduce_scatter_bucket(graph, bucket) + for bucket in ar_buckets: + merge_all_reduce_bucket(graph, bucket) + + # Bucketing can place new nodes (concat, split, etc.) at positions that + # break use-before-def ordering. Re-sort to fix this — same as PyTorch's + # post_grad_passes does after its own FX bucketing. + stable_topological_sort(graph) + gm.recompile()