diff --git a/autoparallel/api.py b/autoparallel/api.py index 5ecac40e..bdfb5491 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -63,6 +63,10 @@ def _boxed_nop_preserve_node_meta(fx_g, example_inputs): + from autoparallel.graph_passes.bucket_collectives import bucket_collectives + + bucket_collectives(fx_g) + if torch._inductor.config.aten_distributed_optimizations.enable_overlap_scheduling: from torch._inductor.fx_passes.overlap_scheduling import ( schedule_overlap_bucketing_from_inductor_configs, @@ -275,6 +279,18 @@ def __enter__(self): ) torch._inductor.config.comprehensive_padding = False + if self.compiler_fn is compile_fx_inner: + from autoparallel.graph_passes.bucket_collectives import ( + bucket_collectives, + ) + + self.old_post_grad_custom_post_pass = ( + torch._inductor.config.post_grad_custom_post_pass + ) + torch._inductor.config.post_grad_custom_post_pass = ( + lambda gm: bucket_collectives(gm.owning_module) + ) + rescale_grad_comm_cost_for_mp = 1.0 if self.mp_policy is not None: param_size = self.mp_policy.param_dtype.itemsize @@ -312,6 +328,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch._inductor.config.comprehensive_padding = ( self.old_inductor_comprehensive_padding ) + if hasattr(self, "old_post_grad_custom_post_pass"): + torch._inductor.config.post_grad_custom_post_pass = ( + self.old_post_grad_custom_post_pass + ) self.active = None return self.stack.__exit__(exc_type, exc_val, exc_tb) @@ -462,6 +482,13 @@ def _apply_placement_common(self, sharding_placement): self.parallel_gm = parallel_gm update_joint_with_descriptors(self.joint_with_descriptors, parallel_gm) fix_scatter_on_aliased_inputs(parallel_gm.graph) + # Tag FSDP/DDP collectives on the joint graph so the bucketing + # pass can identify them after partitioning into fw/bw subgraphs. + from autoparallel.graph_passes.bucket_collectives import ( + tag_collectives_for_bucketing, + ) + + tag_collectives_for_bucketing(parallel_gm.graph) # Allow DCE to remove unused wait_tensor nodes in the backward graph. # Pushed onto self.stack so it's restored in AutoParallel.__exit__. self.stack.enter_context(_suppress_wait_tensor_side_effect()) diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py index 1766770b..67a237ce 100644 --- a/autoparallel/apply_sharding.py +++ b/autoparallel/apply_sharding.py @@ -310,12 +310,27 @@ def _lower_to_parallel_graph(gm, sharding_placement, local_args): def _copy_descriptors_and_rename_placeholders(source_gm, target_gm): """Copy node descriptors from source graph and rename placeholders to match.""" + from torch.utils._pytree import tree_flatten + for n1, n2 in zip( (n for n in source_gm.graph.nodes if n.op in ("placeholder", "output")), (n for n in target_gm.graph.nodes if n.op in ("placeholder", "output")), ): - n2.meta["desc"] = n1.meta["desc"] - if n2.op == "placeholder": + if n2.op == "output": + # get_all_input_and_grad_nodes iterates with zip(n.args[0], desc), + # so both must be flat. make_fx may nest the output args (e.g., + # ((fw_outs...), (grads...))) while desc from the source is flat. + # Flatten the target output args to match, and copy the flat desc. + flat_desc, _ = tree_flatten(n1.meta["desc"]) + flat_args, _ = tree_flatten(n2.args[0]) + assert len(flat_desc) == len(flat_args), ( + f"Output desc has {len(flat_desc)} leaves but output args has " + f"{len(flat_args)} leaves" + ) + n2.args = (tuple(flat_args),) + n2.meta["desc"] = flat_desc + else: + n2.meta["desc"] = n1.meta["desc"] n2.target = n1.target # node renaming is needed for partitioner as it searches for tangent # nodes. See https://fburl.com/kc4jtc3t for one case where it's used diff --git a/autoparallel/graph_passes/bucket_collectives.py b/autoparallel/graph_passes/bucket_collectives.py new file mode 100644 index 00000000..d1b9821b --- /dev/null +++ b/autoparallel/graph_passes/bucket_collectives.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict + +import torch +from torch._inductor.fx_passes.bucketing import ( + collect_node_descendants, + is_all_gather_into_tensor, + is_all_reduce_tensor, + is_reduce_scatter_tensor, + merge_all_gather_bucket, + merge_all_reduce_bucket, + merge_reduce_scatter_bucket, +) +from torch._inductor.fx_passes.post_grad import stable_topological_sort +from torch.utils._ordered_set import OrderedSet + +from .graph_utils import build_param_derived_set, build_terminal_derived_set + +logger = logging.getLogger(__name__) + +# Meta key used to tag collectives eligible for bucketing. +# Set to "param" for param-derived (forward all-gathers) or +# "terminal" for terminal-derived (backward reduce-scatters / all-reduces). +AP_BUCKET_KEY = "ap_bucket_group" + + +def tag_collectives_for_bucketing(graph: torch.fx.Graph) -> None: + """Tag FSDP/DDP collectives on the joint graph for later bucketing. + + Must run on the joint graph where placeholder metadata is available. + The tags survive partitioning into fw/bw subgraphs via node_copy's + shallow copy of node.meta. + """ + param_derived = build_param_derived_set(graph) + terminal_derived = build_terminal_derived_set(graph) + + n_ag = 0 + n_rs = 0 + n_ar = 0 + for node in graph.nodes: + if is_all_gather_into_tensor(node) and node in param_derived: + node.meta[AP_BUCKET_KEY] = "param" + n_ag += 1 + elif is_reduce_scatter_tensor(node) and node in terminal_derived: + node.meta[AP_BUCKET_KEY] = "terminal" + n_rs += 1 + elif is_all_reduce_tensor(node) and node in terminal_derived: + node.meta[AP_BUCKET_KEY] = "terminal" + n_ar += 1 + + logger.info( + "Tagged collectives for bucketing: AG %d, RS %d, AR %d", + n_ag, + n_rs, + n_ar, + ) + + +def _group_key(node: torch.fx.Node) -> tuple: + """Extract group key for any collective type. + + The key ensures only collectives on the same process group / reduce op / + dtype can be bucketed together — the same constraints enforced by + PyTorch's merge functions. + """ + if is_all_gather_into_tensor(node): + _, group_size, group_name = node.args + return (group_name,) + elif is_reduce_scatter_tensor(node): + _, reduce_op, group_size, group_name = node.args + dtype = node.meta["val"].dtype + return (group_name, reduce_op, dtype) + elif is_all_reduce_tensor(node): + _, reduce_op, group_name = node.args + dtype = node.meta["val"].dtype + return (group_name, reduce_op, dtype) + else: + raise ValueError(f"Unsupported collective type: {node.target}") + + +def _greedy_bucket( + graph: torch.fx.Graph, + coll_nodes: list[torch.fx.Node], + bucket_cap_bytes: int, + node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]], +) -> list[list[torch.fx.Node]]: + """Group collectives into buckets up to bucket_cap_bytes. + + Unlike PyTorch's greedy_bucket_collective_by_mb, this does not require + collectives to be adjacent in the graph — it only requires that no + collective is a descendant of another in the same bucket (to avoid + creating cycles when merged). + """ + if not coll_nodes: + return [] + + groups: dict[tuple, list[torch.fx.Node]] = defaultdict(list) + for node in coll_nodes: + groups[_group_key(node)].append(node) + + buckets: list[list[torch.fx.Node]] = [] + for nodes in groups.values(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_descendants: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_bytes = 0 + + for node in nodes: + if node in cur_bucket_descendants: + continue + + val = node.meta["val"] + out_bytes = val.numel() * val.element_size() + in_val = node.all_input_nodes[0].meta["val"] + in_bytes = in_val.numel() * in_val.element_size() + size_bytes = max(out_bytes, in_bytes) + + if cur_bucket_bytes + size_bytes > bucket_cap_bytes and cur_bucket: + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_bytes = 0 + cur_bucket_descendants = OrderedSet() + + cur_bucket_bytes += size_bytes + cur_bucket.append(node) + cur_bucket_descendants |= node_descendants[node] + + if len(cur_bucket) > 1: + buckets.append(cur_bucket) + + return buckets + + +def bucket_collectives( + gm: torch.fx.GraphModule, + bucket_cap_mb: float = 25.0, +) -> None: + """Bucket FSDP/DDP collectives before the overlap scheduler runs. + + Reads the tags set by tag_collectives_for_bucketing() to identify + which collectives to merge. Merges per-parameter collectives into + larger bucketed collectives so the overlap scheduler sees fewer, + larger ops. + """ + graph = gm.graph + + fsdp_all_gathers: list[torch.fx.Node] = [] + fsdp_reduce_scatters: list[torch.fx.Node] = [] + ddp_all_reduces: list[torch.fx.Node] = [] + + for node in graph.nodes: + bucket_group = node.meta.get(AP_BUCKET_KEY) + if bucket_group is None: + continue + if is_all_gather_into_tensor(node): + fsdp_all_gathers.append(node) + elif is_reduce_scatter_tensor(node): + fsdp_reduce_scatters.append(node) + elif is_all_reduce_tensor(node): + ddp_all_reduces.append(node) + + total = len(fsdp_all_gathers) + len(fsdp_reduce_scatters) + len(ddp_all_reduces) + if total < 2: + return + + node_descendants = collect_node_descendants(graph) + bucket_cap_bytes = int(bucket_cap_mb * 1024 * 1024) + + ag_buckets = _greedy_bucket( + graph, fsdp_all_gathers, bucket_cap_bytes, node_descendants + ) + rs_buckets = _greedy_bucket( + graph, fsdp_reduce_scatters, bucket_cap_bytes, node_descendants + ) + ar_buckets = _greedy_bucket( + graph, ddp_all_reduces, bucket_cap_bytes, node_descendants + ) + + n_merged = sum(len(b) for b in ag_buckets + rs_buckets + ar_buckets) + n_buckets = len(ag_buckets) + len(rs_buckets) + len(ar_buckets) + if n_buckets == 0: + return + + logger.info( + "Bucketing %d collectives into %d buckets " + "(AG: %d->%d, RS: %d->%d, AR: %d->%d)", + n_merged, + n_buckets, + len(fsdp_all_gathers), + len(fsdp_all_gathers) - sum(len(b) for b in ag_buckets) + len(ag_buckets), + len(fsdp_reduce_scatters), + len(fsdp_reduce_scatters) - sum(len(b) for b in rs_buckets) + len(rs_buckets), + len(ddp_all_reduces), + len(ddp_all_reduces) - sum(len(b) for b in ar_buckets) + len(ar_buckets), + ) + + for bucket in ag_buckets: + merge_all_gather_bucket(graph, bucket) + for bucket in rs_buckets: + merge_reduce_scatter_bucket(graph, bucket) + for bucket in ar_buckets: + merge_all_reduce_bucket(graph, bucket) + + # Bucketing can place new nodes (concat, split, etc.) at positions that + # break use-before-def ordering. Re-sort to fix this — same as PyTorch's + # post_grad_passes does after its own FX bucketing. + stable_topological_sort(graph) + gm.recompile()