diff --git a/autoparallel/api.py b/autoparallel/api.py index 1670d509..abe1ba5c 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -26,6 +26,7 @@ from .apply_sharding import apply_sharding_to_model from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast from .graph_passes.activation_checkpointing import mark_fsdp_all_gather_recomputation +from .graph_passes.fuse_allgather import fuse_chained_allgathers from .graph_passes.graph_utils import ( _add_alias, _replace_view_mm_view_with_einsum, @@ -57,7 +58,10 @@ logger = logging.getLogger(__name__) -def _boxed_nop_preserve_node_meta(fx_g, example_inputs): +def _boxed_nop_preserve_node_meta(fx_g, example_inputs, pre_pass=None): + if pre_pass is not None: + pre_pass(fx_g.graph) + def run(args): with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(fx_g).boxed_run(args) @@ -473,6 +477,27 @@ def _apply_placement_common(self, sharding_placement): sharded_buffer_dict, ) + def _make_fuse_allgather_pass(self): + flat_mesh = self.mesh._flatten() if self.mesh.ndim > 1 else self.mesh + pg = flat_mesh.get_group() + full_group_size = flat_mesh.size() + full_group_name = pg.group_name + + subgroup_order = { + self.mesh.get_group(mesh_dim=dim).group_name: dim + for dim in range(self.mesh.ndim) + } + + def pre_pass(graph): + fuse_chained_allgathers( + graph, + full_group_size, + full_group_name, + subgroup_order=subgroup_order, + ) + + return pre_pass + def apply_placement(self, sharding_placement): sharded_param_dict, sharded_buffer_dict = self._apply_placement_common( sharding_placement @@ -482,10 +507,19 @@ def apply_placement(self, sharding_placement): self.parallel_gm.graph, self.reshard_after_forward ) + compiler_fn = self.compiler_fn + if self.mesh.ndim > 1: + from functools import partial + + compiler_fn = partial( + compiler_fn, + pre_pass=self._make_fuse_allgather_pass(), + ) + self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( self.joint_with_descriptors, - fw_compiler=self.compiler_fn, - bw_compiler=self.compiler_fn, + fw_compiler=compiler_fn, + bw_compiler=compiler_fn, ) # Build a forward-only graph for inference (no backward, no @@ -500,6 +534,13 @@ def apply_placement(self, sharding_placement): self.parallel_gm, num_fwd_outputs, num_primals ) compiler_fn = self.compiler_fn + if self.mesh.ndim > 1: + from functools import partial + + compiler_fn = partial( + compiler_fn, + pre_pass=self._make_fuse_allgather_pass(), + ) aot_config = self.joint_with_descriptors._aot_state.aot_config out_spec = self.joint_with_descriptors.out_spec diff --git a/autoparallel/graph_passes/fuse_allgather.py b/autoparallel/graph_passes/fuse_allgather.py new file mode 100644 index 00000000..1ddf1bf4 --- /dev/null +++ b/autoparallel/graph_passes/fuse_allgather.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch + +logger: logging.Logger = logging.getLogger(__name__) + + +def _is_all_gather(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + ) + + +def _is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.wait_tensor.default + ) + + +def _is_nontrivial_dim_reorder(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + if node.target == torch.ops.aten.t.default: + return True + if node.target == torch.ops.aten.transpose.int: + return node.args[1] != node.args[2] + if node.target == torch.ops.aten.permute.default and isinstance( + node.args[1], (list, tuple) + ): + dims = list(node.args[1]) + return dims != list(range(len(dims))) + return False + + +def _is_identity_view_chain(start: torch.fx.Node, end: torch.fx.Node) -> bool: + """Check that the view-op chain from start to end composes to the identity. + + Walks forward from ``start`` through single-user view ops and verifies + that the composed transformation doesn't change the data layout. + Uses FakeTensor metadata: if the output of ``start`` and the input of + ``end`` have the same shape and stride, the chain is an identity + (no data rearrangement, just metadata changes that cancel). + + Only allows ops that are true views (no data copy, no element removal): + permute, transpose, t, view, reshape, expand, unsqueeze, squeeze. + Rejects slice (can drop elements) and any non-view op. + + Returns False for empty chains or chains with no non-trivial dimension + reorder, since consecutive allgathers on different subgroups have + incompatible rank orderings without explicit layout reconciliation. + """ + _ALLOWED_VIEW_OPS = frozenset( + { + torch.ops.aten.permute.default, + torch.ops.aten.transpose.int, + torch.ops.aten.t.default, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + torch.ops.aten.expand.default, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + } + ) + + # Reject empty chains: no view ops means no layout reconciliation. + users = list(start.users.keys()) + if len(users) == 1 and users[0] is end: + return False + + start_val = start.meta.get("val") + if start_val is None: + return False + start_stride = start_val.stride() + + # Walk forward from start to end, verifying all intermediate ops are + # allowed views and that some op actually reorders dimensions. + node = start + saw_dim_reorder = False + while node is not end: + users = list(node.users.keys()) + if len(users) != 1: + return False + node = users[0] + if node is end: + break + if node.op != "call_function" or node.target not in _ALLOWED_VIEW_OPS: + return False + if _is_nontrivial_dim_reorder(node): + saw_dim_reorder = True + + if not saw_dim_reorder: + return False + + # Verify the composed transformation is identity via FakeTensor metadata. + ag2_input = end.args[0] + end_val = ( + ag2_input.meta.get("val") if isinstance(ag2_input, torch.fx.Node) else None + ) + + if end_val is None: + return False + if start_val.shape != end_val.shape: + return False + if start_stride != end_val.stride(): + return False + return True + + +def fuse_chained_allgathers( + graph: torch.fx.Graph, + full_group_size: int, + full_group_name: str, + subgroup_order: dict[str, int] | None = None, +) -> int: + """Fuse consecutive allgather chains on different subgroups into a single allgather. + + Detects chains of two allgathers on different process groups connected + through single-user view ops that compose to the identity:: + + ag1 = all_gather(x, size1, pg1) + wait1 = wait_tensor(ag1) + ... = identity_view_ops(wait1) + ag2 = all_gather(..., size2, pg2) + wait2 = wait_tensor(ag2) + + and replaces them with:: + + full_ag = all_gather(x, size1 * size2, full_pg) + full_wait = wait_tensor(full_ag) + + Requirements: + - The two group sizes must multiply to ``full_group_size``. + - Every node between the two allgathers must have exactly one user. + - The view ops between them must compose to the identity (verified + via FakeTensor shape and stride metadata). + - Both allgathers must have the same dtype. + - When ``subgroup_order`` is provided, both process groups must be in + that mapping and appear in descending mesh-dim order. + + Returns the number of fusions performed. + """ + fusions = 0 + all_nodes = list(graph.nodes) + + for ag2 in all_nodes: + if not _is_all_gather(ag2): + continue + + # Walk ag2's input backward through single-user nodes to find wait1. + node = ag2.args[0] + if not isinstance(node, torch.fx.Node): + continue + + # Find the wait_tensor that starts the chain. + wait1 = node + while not _is_wait_tensor(wait1): + if len(wait1.users) != 1: + break + if len(wait1.args) == 0: + break + inp = wait1.args[0] + if not isinstance(inp, torch.fx.Node): + break + wait1 = inp + + if not _is_wait_tensor(wait1): + continue + if len(wait1.users) != 1: + continue + + ag1 = wait1.args[0] + if not isinstance(ag1, torch.fx.Node) or not _is_all_gather(ag1): + continue + if len(ag1.users) != 1: + continue + + # Validate that the view chain between wait1 and ag2 is identity. + if not _is_identity_view_chain(wait1, ag2): + continue + + # Validate group sizes. + ag1_group_size = ag1.args[1] + ag2_group_size = ag2.args[1] + if ag1_group_size * ag2_group_size != full_group_size: + continue + + # Validate group names. + ag1_group = ag1.args[2] + ag2_group = ag2.args[2] + assert isinstance(ag1_group, str) + assert isinstance(ag2_group, str) + if ag1_group == ag2_group: + continue + if subgroup_order is not None: + if ag1_group not in subgroup_order or ag2_group not in subgroup_order: + continue + if subgroup_order[ag1_group] >= subgroup_order[ag2_group]: + continue + + # Validate matching dtype. + ag1_val = ag1.meta.get("val") + ag2_val = ag2.meta.get("val") + if ( + ag1_val is not None + and ag2_val is not None + and ag1_val.dtype != ag2_val.dtype + ): + continue + + # Find wait2. + wait2 = None + for user in ag2.users: + if _is_wait_tensor(user): + wait2 = user + break + if wait2 is None: + continue + + # Build the fused allgather. + original_input = ag1.args[0] + + with graph.inserting_before(ag2): + full_ag = graph.call_function( + torch.ops._c10d_functional.all_gather_into_tensor.default, + args=(original_input, full_group_size, full_group_name), + ) + full_ag.meta.update(ag2.meta) + + full_wait = graph.call_function( + torch.ops._c10d_functional.wait_tensor.default, + args=(full_ag,), + ) + full_wait.meta.update(wait2.meta) + + wait2.replace_all_uses_with(full_wait) + fusions += 1 + + logger.debug( + "Fused ag(%s, gs=%d, pg=%s) + ag(gs=%d, pg=%s) -> ag(gs=%d, pg=%s)", + original_input, + ag1_group_size, + ag1_group, + ag2_group_size, + ag2_group, + full_group_size, + full_group_name, + ) + + if fusions > 0: + graph.eliminate_dead_code() + logger.info( + "Fused %d chained allgather pairs into full-mesh allgathers", fusions + ) + + return fusions diff --git a/tests/test_fuse_allgather.py b/tests/test_fuse_allgather.py new file mode 100644 index 00000000..90a26c0d --- /dev/null +++ b/tests/test_fuse_allgather.py @@ -0,0 +1,342 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for chained allgather fusion pass. + +These tests build minimal FX graphs that mimic chained allgather patterns +without running the full AutoParallel pipeline or needing real process groups. +""" + +import pytest +import torch +import torch.fx + +from autoparallel.api import _suppress_wait_tensor_side_effect +from autoparallel.graph_passes.fuse_allgather import fuse_chained_allgathers + +AG = torch.ops._c10d_functional.all_gather_into_tensor.default +WAIT = torch.ops._c10d_functional.wait_tensor.default +PERMUTE = torch.ops.aten.permute.default + + +@pytest.fixture(autouse=True) +def suppress_wait_side_effect(): + """Allow DCE to remove wait_tensor nodes, matching the runtime environment.""" + with _suppress_wait_tensor_side_effect(): + yield + + +def _count_ops(graph, target): + return len(graph.find_nodes(op="call_function", target=target)) + + +def _add_placeholder(graph, name, shape): + node = graph.placeholder(name) + node.meta["val"] = torch.empty(*shape) + return node + + +def _add_all_gather(graph, input_node, group_size, group_name): + in_shape = input_node.meta["val"].shape + out_shape = (in_shape[0] * group_size, *in_shape[1:]) + node = graph.call_function(AG, args=(input_node, group_size, group_name)) + node.meta["val"] = torch.empty(*out_shape, dtype=input_node.meta["val"].dtype) + return node + + +def _add_wait_tensor(graph, input_node): + node = graph.call_function(WAIT, args=(input_node,)) + node.meta["val"] = input_node.meta["val"].clone() + return node + + +def _add_permute(graph, input_node, dims): + node = graph.call_function(PERMUTE, args=(input_node, dims)) + in_val = input_node.meta["val"] + node.meta["val"] = in_val.permute(dims) + return node + + +def _add_chained_allgather(graph, input_node, size1=16, pg1="dp", size2=8, pg2="tp"): + """Build: input -> ag1 -> wait -> permute([1,0]) -> permute([1,0]) -> ag2 -> wait.""" + ag1 = _add_all_gather(graph, input_node, size1, pg1) + wait1 = _add_wait_tensor(graph, ag1) + p1 = _add_permute(graph, wait1, [1, 0]) + p2 = _add_permute(graph, p1, [1, 0]) + ag2 = _add_all_gather(graph, p2, size2, pg2) + wait2 = _add_wait_tensor(graph, ag2) + return wait2 + + +def test_basic_fusion(): + """Detects chained allgathers and fuses into a single full-mesh allgather.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x) + graph.output((result,)) + + assert _count_ops(graph, AG) == 2 + assert _count_ops(graph, WAIT) == 2 + assert _count_ops(graph, PERMUTE) == 2 + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 1 + assert _count_ops(graph, AG) == 1 + assert _count_ops(graph, WAIT) == 1 + assert _count_ops(graph, PERMUTE) == 0 + + ag_node = graph.find_nodes(op="call_function", target=AG)[0] + assert ag_node.args[1] == 128 + assert ag_node.args[2] == "full" + + +def test_multiple_chains(): + """Multiple independent chains in the same graph are all fused.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + y = _add_placeholder(graph, "y", (8, 4096)) + r1 = _add_chained_allgather(graph, x) + r2 = _add_chained_allgather(graph, y) + graph.output((r1, r2)) + + assert _count_ops(graph, AG) == 4 + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 2 + assert _count_ops(graph, AG) == 2 + assert _count_ops(graph, WAIT) == 2 + + +def test_no_fusion_wrong_permute(): + """Non-[1,0] permutes prevent fusion (view chain doesn't trace through).""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + p1 = _add_permute(graph, wait1, [1, 0]) + p2 = _add_permute(graph, p1, [0, 1]) # identity, not [1, 0] + ag2 = _add_all_gather(graph, p2, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_no_fusion_wait_multiple_users(): + """If the first wait has other users, the intermediate result is consumed elsewhere.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + p1 = _add_permute(graph, wait1, [1, 0]) + p2 = _add_permute(graph, p1, [1, 0]) + ag2 = _add_all_gather(graph, p2, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + # wait1 also used directly in output + graph.output((wait2, wait1)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_no_fusion_group_size_mismatch(): + """If size1 * size2 != full_group_size, no fusion occurs.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=16, size2=8) + graph.output((result,)) + + # Wrong full_group_size + fusions = fuse_chained_allgathers(graph, full_group_size=64, full_group_name="full") + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_no_fusion_same_group(): + """Two allgathers on the same process group are not fused.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=4, pg1="dp", size2=4, pg2="dp") + graph.output((result,)) + + fusions = fuse_chained_allgathers(graph, full_group_size=16, full_group_name="full") + + assert fusions == 0 + + +def test_subgroup_order_validation(): + """When subgroup_order is provided, only matching groups in valid order fuse.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=16, pg1="dp", size2=8, pg2="tp") + graph.output((result,)) + + # Unknown subgroup names — should not fuse + fusions = fuse_chained_allgathers( + graph, + full_group_size=128, + full_group_name="full", + subgroup_order={"other1": 0, "other2": 1}, + ) + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + # Correct subgroup order — should fuse + fusions = fuse_chained_allgathers( + graph, + full_group_size=128, + full_group_name="full", + subgroup_order={"dp": 0, "tp": 1}, + ) + assert fusions == 1 + assert _count_ops(graph, AG) == 1 + + +def test_reversed_subgroup_order_does_not_fuse(): + """Only descending mesh-dim allgather order is fuseable.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=8, pg1="tp", size2=16, pg2="dp") + graph.output((result,)) + + fusions = fuse_chained_allgathers( + graph, + full_group_size=128, + full_group_name="full", + subgroup_order={"dp": 0, "tp": 1}, + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_with_cast_before_allgather(): + """The cast before the first allgather is preserved and becomes the fused ag's input.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + # Simulate dtype cast + cast = graph.call_function( + torch.ops.prims.convert_element_type.default, args=(x, torch.bfloat16) + ) + cast.meta["val"] = x.meta["val"].to(torch.bfloat16) + result = _add_chained_allgather(graph, cast) + graph.output((result,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 1 + assert _count_ops(graph, AG) == 1 + + # Cast should still be present + cast_count = _count_ops(graph, torch.ops.prims.convert_element_type.default) + assert cast_count == 1 + + # The allgather input should be the cast output + ag_node = graph.find_nodes(op="call_function", target=AG)[0] + assert ag_node.args[0].target == torch.ops.prims.convert_element_type.default + + +def test_standalone_allgather_untouched(): + """A single allgather without a chain is not affected.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag = _add_all_gather(graph, x, 8, "tp") + wait = _add_wait_tensor(graph, ag) + graph.output((wait,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 1 + + +def test_direct_chain_no_views(): + """Two allgathers directly chained (ag1 -> wait -> ag2) with no view ops. + + Without intervening views to reconcile the rank ordering between the + two subgroup allgathers and the flattened group, this is NOT fuseable. + """ + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + ag2 = _add_all_gather(graph, wait1, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + assert _count_ops(graph, WAIT) == 2 + + +def test_noop_view_chain(): + """A no-op view/reshape between allgathers does not reconcile rank ordering. + + Even though a view op is present, if strides never change the chain is + semantically equivalent to a direct chain and must not be fused. + """ + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + # No-op view: same shape, same strides + view = graph.call_function(torch.ops.aten.view.default, args=(wait1, [128, 4096])) + view.meta["val"] = wait1.meta["val"].clone() + ag2 = _add_all_gather(graph, view, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_unsqueeze_squeeze_chain(): + """A temporary stride change without dimension reorder is not fuseable.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + unsqueeze = graph.call_function(torch.ops.aten.unsqueeze.default, args=(wait1, 0)) + unsqueeze.meta["val"] = wait1.meta["val"].unsqueeze(0) + squeeze = graph.call_function(torch.ops.aten.squeeze.dim, args=(unsqueeze, 0)) + squeeze.meta["val"] = unsqueeze.meta["val"].squeeze(0) + ag2 = _add_all_gather(graph, squeeze, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2