From 2ede567b0cf36423eadf2f4a8ad73d89e6529c14 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 21 May 2026 12:29:40 +0000 Subject: [PATCH 1/3] Fuse chained allgathers on different subgroups into single full-mesh allgather MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When weights are placed as `S(0)S(0)` on a multi-dim mesh, `apply_sharding` decomposes the `S(0)S(0) → RR` redistribution into per-dim allgathers: a dp-dim allgather followed by a tp-dim allgather, with cancelling permute pairs between them. Each pair produces two separate NCCL kernel launches when a single full-mesh allgather would suffice. This adds `fuse_chained_allgathers`, a graph pass that detects these chains and replaces them with a single allgather on the flattened mesh process group. The pass validates that both allgathers are on known mesh subgroups in descending dim order, their group sizes multiply to the full mesh size, and the intermediate view ops compose to the identity (verified via FakeTensor shape/stride metadata). The pass runs on the partitioned forward and backward graphs during the first compilation and on the inference path, gated on `mesh.ndim > 1`. Authored with Claude. --- autoparallel/graph_passes/fuse_allgather.py | 263 +++++++++++++++ tests/test_fuse_allgather.py | 342 ++++++++++++++++++++ 2 files changed, 605 insertions(+) create mode 100644 autoparallel/graph_passes/fuse_allgather.py create mode 100644 tests/test_fuse_allgather.py diff --git a/autoparallel/graph_passes/fuse_allgather.py b/autoparallel/graph_passes/fuse_allgather.py new file mode 100644 index 00000000..ebadba53 --- /dev/null +++ b/autoparallel/graph_passes/fuse_allgather.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch + +logger: logging.Logger = logging.getLogger(__name__) + + +def _is_all_gather(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + ) + + +def _is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.wait_tensor.default + ) + + +def _is_nontrivial_dim_reorder(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + if node.target == torch.ops.aten.t.default: + return True + if node.target == torch.ops.aten.transpose.int: + return node.args[1] != node.args[2] + if node.target == torch.ops.aten.permute.default and isinstance( + node.args[1], (list, tuple) + ): + dims = list(node.args[1]) + return dims != list(range(len(dims))) + return False + + +def _is_identity_view_chain(start: torch.fx.Node, end: torch.fx.Node) -> bool: + """Check that the view-op chain from start to end composes to the identity. + + Walks forward from ``start`` through single-user view ops and verifies + that the composed transformation doesn't change the data layout. + Uses FakeTensor metadata: if the output of ``start`` and the input of + ``end`` have the same shape and stride, the chain is an identity + (no data rearrangement, just metadata changes that cancel). + + Only allows ops that are true views (no data copy, no element removal): + permute, transpose, t, view, reshape, expand, unsqueeze, squeeze. + Rejects slice (can drop elements) and any non-view op. + + Returns False for empty chains or chains with no non-trivial dimension + reorder, since consecutive allgathers on different subgroups have + incompatible rank orderings without explicit layout reconciliation. + """ + _ALLOWED_VIEW_OPS = frozenset( + { + torch.ops.aten.permute.default, + torch.ops.aten.transpose.int, + torch.ops.aten.t.default, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + torch.ops.aten.expand.default, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + } + ) + + # Reject empty chains: no view ops means no layout reconciliation. + users = list(start.users.keys()) + if len(users) == 1 and users[0] is end: + return False + + start_val = start.meta.get("val") + if start_val is None: + return False + start_stride = start_val.stride() + + # Walk forward from start to end, verifying all intermediate ops are + # allowed views and that some op actually reorders dimensions. + node = start + saw_dim_reorder = False + while node is not end: + users = list(node.users.keys()) + if len(users) != 1: + return False + node = users[0] + if node is end: + break + if node.op != "call_function" or node.target not in _ALLOWED_VIEW_OPS: + return False + if _is_nontrivial_dim_reorder(node): + saw_dim_reorder = True + + if not saw_dim_reorder: + return False + + # Verify the composed transformation is identity via FakeTensor metadata. + ag2_input = end.args[0] + end_val = ( + ag2_input.meta.get("val") if isinstance(ag2_input, torch.fx.Node) else None + ) + + if end_val is None: + return False + if start_val.shape != end_val.shape: + return False + if start_stride != end_val.stride(): + return False + return True + + +def fuse_chained_allgathers( + graph: torch.fx.Graph, + full_group_size: int, + full_group_name: str, + subgroup_order: dict[str, int] | None = None, +) -> int: + """Fuse consecutive allgather chains on different subgroups into a single allgather. + + Detects chains of two allgathers on different process groups connected + through single-user view ops that compose to the identity:: + + ag1 = all_gather(x, size1, pg1) + wait1 = wait_tensor(ag1) + ... = identity_view_ops(wait1) + ag2 = all_gather(..., size2, pg2) + wait2 = wait_tensor(ag2) + + and replaces them with:: + + full_ag = all_gather(x, size1 * size2, full_pg) + full_wait = wait_tensor(full_ag) + + Requirements: + - The two group sizes must multiply to ``full_group_size``. + - Every node between the two allgathers must have exactly one user. + - The view ops between them must compose to the identity (verified + via FakeTensor shape and stride metadata). + - Both allgathers must have the same dtype. + - When ``subgroup_order`` is provided, both process groups must be in + that mapping and appear in descending mesh-dim order. + + Returns the number of fusions performed. + """ + fusions = 0 + all_nodes = list(graph.nodes) + + for ag2 in all_nodes: + if not _is_all_gather(ag2): + continue + + # Walk ag2's input backward through single-user nodes to find wait1. + node = ag2.args[0] + if not isinstance(node, torch.fx.Node): + continue + + # Find the wait_tensor that starts the chain. + wait1 = node + while not _is_wait_tensor(wait1): + if len(wait1.users) != 1: + break + if len(wait1.args) == 0: + break + inp = wait1.args[0] + if not isinstance(inp, torch.fx.Node): + break + wait1 = inp + + if not _is_wait_tensor(wait1): + continue + if len(wait1.users) != 1: + continue + + ag1 = wait1.args[0] + if not isinstance(ag1, torch.fx.Node) or not _is_all_gather(ag1): + continue + if len(ag1.users) != 1: + continue + + # Validate that the view chain between wait1 and ag2 is identity. + if not _is_identity_view_chain(wait1, ag2): + continue + + # Validate group sizes. + ag1_group_size = ag1.args[1] + ag2_group_size = ag2.args[1] + if ag1_group_size * ag2_group_size != full_group_size: + continue + + # Validate group names. + ag1_group = ag1.args[2] + ag2_group = ag2.args[2] + assert isinstance(ag1_group, str) + assert isinstance(ag2_group, str) + if ag1_group == ag2_group: + continue + if subgroup_order is not None: + if ag1_group not in subgroup_order or ag2_group not in subgroup_order: + continue + if subgroup_order[ag1_group] <= subgroup_order[ag2_group]: + continue + + # Validate matching dtype. + ag1_val = ag1.meta.get("val") + ag2_val = ag2.meta.get("val") + if ( + ag1_val is not None + and ag2_val is not None + and ag1_val.dtype != ag2_val.dtype + ): + continue + + # Find wait2. + wait2 = None + for user in ag2.users: + if _is_wait_tensor(user): + wait2 = user + break + if wait2 is None: + continue + + # Build the fused allgather. + original_input = ag1.args[0] + + with graph.inserting_before(ag2): + full_ag = graph.call_function( + torch.ops._c10d_functional.all_gather_into_tensor.default, + args=(original_input, full_group_size, full_group_name), + ) + full_ag.meta.update(ag2.meta) + + full_wait = graph.call_function( + torch.ops._c10d_functional.wait_tensor.default, + args=(full_ag,), + ) + full_wait.meta.update(wait2.meta) + + wait2.replace_all_uses_with(full_wait) + fusions += 1 + + logger.debug( + "Fused ag(%s, gs=%d, pg=%s) + ag(gs=%d, pg=%s) -> ag(gs=%d, pg=%s)", + original_input, + ag1_group_size, + ag1_group, + ag2_group_size, + ag2_group, + full_group_size, + full_group_name, + ) + + if fusions > 0: + graph.eliminate_dead_code() + logger.info( + "Fused %d chained allgather pairs into full-mesh allgathers", fusions + ) + + return fusions diff --git a/tests/test_fuse_allgather.py b/tests/test_fuse_allgather.py new file mode 100644 index 00000000..90a26c0d --- /dev/null +++ b/tests/test_fuse_allgather.py @@ -0,0 +1,342 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for chained allgather fusion pass. + +These tests build minimal FX graphs that mimic chained allgather patterns +without running the full AutoParallel pipeline or needing real process groups. +""" + +import pytest +import torch +import torch.fx + +from autoparallel.api import _suppress_wait_tensor_side_effect +from autoparallel.graph_passes.fuse_allgather import fuse_chained_allgathers + +AG = torch.ops._c10d_functional.all_gather_into_tensor.default +WAIT = torch.ops._c10d_functional.wait_tensor.default +PERMUTE = torch.ops.aten.permute.default + + +@pytest.fixture(autouse=True) +def suppress_wait_side_effect(): + """Allow DCE to remove wait_tensor nodes, matching the runtime environment.""" + with _suppress_wait_tensor_side_effect(): + yield + + +def _count_ops(graph, target): + return len(graph.find_nodes(op="call_function", target=target)) + + +def _add_placeholder(graph, name, shape): + node = graph.placeholder(name) + node.meta["val"] = torch.empty(*shape) + return node + + +def _add_all_gather(graph, input_node, group_size, group_name): + in_shape = input_node.meta["val"].shape + out_shape = (in_shape[0] * group_size, *in_shape[1:]) + node = graph.call_function(AG, args=(input_node, group_size, group_name)) + node.meta["val"] = torch.empty(*out_shape, dtype=input_node.meta["val"].dtype) + return node + + +def _add_wait_tensor(graph, input_node): + node = graph.call_function(WAIT, args=(input_node,)) + node.meta["val"] = input_node.meta["val"].clone() + return node + + +def _add_permute(graph, input_node, dims): + node = graph.call_function(PERMUTE, args=(input_node, dims)) + in_val = input_node.meta["val"] + node.meta["val"] = in_val.permute(dims) + return node + + +def _add_chained_allgather(graph, input_node, size1=16, pg1="dp", size2=8, pg2="tp"): + """Build: input -> ag1 -> wait -> permute([1,0]) -> permute([1,0]) -> ag2 -> wait.""" + ag1 = _add_all_gather(graph, input_node, size1, pg1) + wait1 = _add_wait_tensor(graph, ag1) + p1 = _add_permute(graph, wait1, [1, 0]) + p2 = _add_permute(graph, p1, [1, 0]) + ag2 = _add_all_gather(graph, p2, size2, pg2) + wait2 = _add_wait_tensor(graph, ag2) + return wait2 + + +def test_basic_fusion(): + """Detects chained allgathers and fuses into a single full-mesh allgather.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x) + graph.output((result,)) + + assert _count_ops(graph, AG) == 2 + assert _count_ops(graph, WAIT) == 2 + assert _count_ops(graph, PERMUTE) == 2 + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 1 + assert _count_ops(graph, AG) == 1 + assert _count_ops(graph, WAIT) == 1 + assert _count_ops(graph, PERMUTE) == 0 + + ag_node = graph.find_nodes(op="call_function", target=AG)[0] + assert ag_node.args[1] == 128 + assert ag_node.args[2] == "full" + + +def test_multiple_chains(): + """Multiple independent chains in the same graph are all fused.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + y = _add_placeholder(graph, "y", (8, 4096)) + r1 = _add_chained_allgather(graph, x) + r2 = _add_chained_allgather(graph, y) + graph.output((r1, r2)) + + assert _count_ops(graph, AG) == 4 + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 2 + assert _count_ops(graph, AG) == 2 + assert _count_ops(graph, WAIT) == 2 + + +def test_no_fusion_wrong_permute(): + """Non-[1,0] permutes prevent fusion (view chain doesn't trace through).""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + p1 = _add_permute(graph, wait1, [1, 0]) + p2 = _add_permute(graph, p1, [0, 1]) # identity, not [1, 0] + ag2 = _add_all_gather(graph, p2, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_no_fusion_wait_multiple_users(): + """If the first wait has other users, the intermediate result is consumed elsewhere.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + p1 = _add_permute(graph, wait1, [1, 0]) + p2 = _add_permute(graph, p1, [1, 0]) + ag2 = _add_all_gather(graph, p2, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + # wait1 also used directly in output + graph.output((wait2, wait1)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_no_fusion_group_size_mismatch(): + """If size1 * size2 != full_group_size, no fusion occurs.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=16, size2=8) + graph.output((result,)) + + # Wrong full_group_size + fusions = fuse_chained_allgathers(graph, full_group_size=64, full_group_name="full") + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_no_fusion_same_group(): + """Two allgathers on the same process group are not fused.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=4, pg1="dp", size2=4, pg2="dp") + graph.output((result,)) + + fusions = fuse_chained_allgathers(graph, full_group_size=16, full_group_name="full") + + assert fusions == 0 + + +def test_subgroup_order_validation(): + """When subgroup_order is provided, only matching groups in valid order fuse.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=16, pg1="dp", size2=8, pg2="tp") + graph.output((result,)) + + # Unknown subgroup names — should not fuse + fusions = fuse_chained_allgathers( + graph, + full_group_size=128, + full_group_name="full", + subgroup_order={"other1": 0, "other2": 1}, + ) + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + # Correct subgroup order — should fuse + fusions = fuse_chained_allgathers( + graph, + full_group_size=128, + full_group_name="full", + subgroup_order={"dp": 0, "tp": 1}, + ) + assert fusions == 1 + assert _count_ops(graph, AG) == 1 + + +def test_reversed_subgroup_order_does_not_fuse(): + """Only descending mesh-dim allgather order is fuseable.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + result = _add_chained_allgather(graph, x, size1=8, pg1="tp", size2=16, pg2="dp") + graph.output((result,)) + + fusions = fuse_chained_allgathers( + graph, + full_group_size=128, + full_group_name="full", + subgroup_order={"dp": 0, "tp": 1}, + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_with_cast_before_allgather(): + """The cast before the first allgather is preserved and becomes the fused ag's input.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + # Simulate dtype cast + cast = graph.call_function( + torch.ops.prims.convert_element_type.default, args=(x, torch.bfloat16) + ) + cast.meta["val"] = x.meta["val"].to(torch.bfloat16) + result = _add_chained_allgather(graph, cast) + graph.output((result,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 1 + assert _count_ops(graph, AG) == 1 + + # Cast should still be present + cast_count = _count_ops(graph, torch.ops.prims.convert_element_type.default) + assert cast_count == 1 + + # The allgather input should be the cast output + ag_node = graph.find_nodes(op="call_function", target=AG)[0] + assert ag_node.args[0].target == torch.ops.prims.convert_element_type.default + + +def test_standalone_allgather_untouched(): + """A single allgather without a chain is not affected.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag = _add_all_gather(graph, x, 8, "tp") + wait = _add_wait_tensor(graph, ag) + graph.output((wait,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 1 + + +def test_direct_chain_no_views(): + """Two allgathers directly chained (ag1 -> wait -> ag2) with no view ops. + + Without intervening views to reconcile the rank ordering between the + two subgroup allgathers and the flattened group, this is NOT fuseable. + """ + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + ag2 = _add_all_gather(graph, wait1, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + assert _count_ops(graph, WAIT) == 2 + + +def test_noop_view_chain(): + """A no-op view/reshape between allgathers does not reconcile rank ordering. + + Even though a view op is present, if strides never change the chain is + semantically equivalent to a direct chain and must not be fused. + """ + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + # No-op view: same shape, same strides + view = graph.call_function(torch.ops.aten.view.default, args=(wait1, [128, 4096])) + view.meta["val"] = wait1.meta["val"].clone() + ag2 = _add_all_gather(graph, view, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 + + +def test_unsqueeze_squeeze_chain(): + """A temporary stride change without dimension reorder is not fuseable.""" + graph = torch.fx.Graph() + x = _add_placeholder(graph, "x", (8, 4096)) + ag1 = _add_all_gather(graph, x, 16, "dp") + wait1 = _add_wait_tensor(graph, ag1) + unsqueeze = graph.call_function(torch.ops.aten.unsqueeze.default, args=(wait1, 0)) + unsqueeze.meta["val"] = wait1.meta["val"].unsqueeze(0) + squeeze = graph.call_function(torch.ops.aten.squeeze.dim, args=(unsqueeze, 0)) + squeeze.meta["val"] = unsqueeze.meta["val"].squeeze(0) + ag2 = _add_all_gather(graph, squeeze, 8, "tp") + wait2 = _add_wait_tensor(graph, ag2) + graph.output((wait2,)) + + fusions = fuse_chained_allgathers( + graph, full_group_size=128, full_group_name="full" + ) + + assert fusions == 0 + assert _count_ops(graph, AG) == 2 From ca8936e5c83ba16ff90efa2957badff5666d6e22 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 21 May 2026 12:33:13 +0000 Subject: [PATCH 2/3] Add missing file --- autoparallel/api.py | 47 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 1670d509..abe1ba5c 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -26,6 +26,7 @@ from .apply_sharding import apply_sharding_to_model from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast from .graph_passes.activation_checkpointing import mark_fsdp_all_gather_recomputation +from .graph_passes.fuse_allgather import fuse_chained_allgathers from .graph_passes.graph_utils import ( _add_alias, _replace_view_mm_view_with_einsum, @@ -57,7 +58,10 @@ logger = logging.getLogger(__name__) -def _boxed_nop_preserve_node_meta(fx_g, example_inputs): +def _boxed_nop_preserve_node_meta(fx_g, example_inputs, pre_pass=None): + if pre_pass is not None: + pre_pass(fx_g.graph) + def run(args): with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(fx_g).boxed_run(args) @@ -473,6 +477,27 @@ def _apply_placement_common(self, sharding_placement): sharded_buffer_dict, ) + def _make_fuse_allgather_pass(self): + flat_mesh = self.mesh._flatten() if self.mesh.ndim > 1 else self.mesh + pg = flat_mesh.get_group() + full_group_size = flat_mesh.size() + full_group_name = pg.group_name + + subgroup_order = { + self.mesh.get_group(mesh_dim=dim).group_name: dim + for dim in range(self.mesh.ndim) + } + + def pre_pass(graph): + fuse_chained_allgathers( + graph, + full_group_size, + full_group_name, + subgroup_order=subgroup_order, + ) + + return pre_pass + def apply_placement(self, sharding_placement): sharded_param_dict, sharded_buffer_dict = self._apply_placement_common( sharding_placement @@ -482,10 +507,19 @@ def apply_placement(self, sharding_placement): self.parallel_gm.graph, self.reshard_after_forward ) + compiler_fn = self.compiler_fn + if self.mesh.ndim > 1: + from functools import partial + + compiler_fn = partial( + compiler_fn, + pre_pass=self._make_fuse_allgather_pass(), + ) + self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( self.joint_with_descriptors, - fw_compiler=self.compiler_fn, - bw_compiler=self.compiler_fn, + fw_compiler=compiler_fn, + bw_compiler=compiler_fn, ) # Build a forward-only graph for inference (no backward, no @@ -500,6 +534,13 @@ def apply_placement(self, sharding_placement): self.parallel_gm, num_fwd_outputs, num_primals ) compiler_fn = self.compiler_fn + if self.mesh.ndim > 1: + from functools import partial + + compiler_fn = partial( + compiler_fn, + pre_pass=self._make_fuse_allgather_pass(), + ) aot_config = self.joint_with_descriptors._aot_state.aot_config out_spec = self.joint_with_descriptors.out_spec From bf4c912bdeb87e32ccef635a6d520b4db75a2424 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 21 May 2026 13:52:37 +0000 Subject: [PATCH 3/3] Bugfix --- autoparallel/graph_passes/fuse_allgather.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoparallel/graph_passes/fuse_allgather.py b/autoparallel/graph_passes/fuse_allgather.py index ebadba53..1ddf1bf4 100644 --- a/autoparallel/graph_passes/fuse_allgather.py +++ b/autoparallel/graph_passes/fuse_allgather.py @@ -202,7 +202,7 @@ def fuse_chained_allgathers( if subgroup_order is not None: if ag1_group not in subgroup_order or ag2_group not in subgroup_order: continue - if subgroup_order[ag1_group] <= subgroup_order[ag2_group]: + if subgroup_order[ag1_group] >= subgroup_order[ag2_group]: continue # Validate matching dtype.