Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .apply_sharding import apply_sharding_to_model
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
from .graph_passes.activation_checkpointing import mark_fsdp_all_gather_recomputation
from .graph_passes.fuse_allgather import fuse_chained_allgathers
from .graph_passes.graph_utils import (
_add_alias,
_replace_view_mm_view_with_einsum,
Expand Down Expand Up @@ -57,7 +58,10 @@
logger = logging.getLogger(__name__)


def _boxed_nop_preserve_node_meta(fx_g, example_inputs):
def _boxed_nop_preserve_node_meta(fx_g, example_inputs, pre_pass=None):
if pre_pass is not None:
pre_pass(fx_g.graph)

def run(args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(fx_g).boxed_run(args)
Expand Down Expand Up @@ -473,6 +477,27 @@ def _apply_placement_common(self, sharding_placement):
sharded_buffer_dict,
)

def _make_fuse_allgather_pass(self):
    """Build a graph pre-pass that fuses chained allgathers for ``self.mesh``.

    The full process group is taken from the flattened mesh when the mesh
    has more than one dimension, otherwise from the mesh itself. Each
    per-dimension process-group name is mapped to its mesh-dim index and
    handed to ``fuse_chained_allgathers`` as ``subgroup_order``.

    Returns:
        A callable taking an FX graph and running ``fuse_chained_allgathers``
        on it in place; the callable itself returns ``None``.
    """
    mesh = self.mesh
    if mesh.ndim > 1:
        flat_mesh = mesh._flatten()
    else:
        flat_mesh = mesh

    full_pg = flat_mesh.get_group()
    full_group_size = flat_mesh.size()
    full_group_name = full_pg.group_name

    # Process-group name -> index of the mesh dimension it belongs to.
    subgroup_order = {}
    for dim in range(mesh.ndim):
        subgroup_order[mesh.get_group(mesh_dim=dim).group_name] = dim

    def pre_pass(graph):
        # The fusion count is intentionally discarded; the pass mutates ``graph``.
        fuse_chained_allgathers(
            graph,
            full_group_size=full_group_size,
            full_group_name=full_group_name,
            subgroup_order=subgroup_order,
        )

    return pre_pass

def apply_placement(self, sharding_placement):
sharded_param_dict, sharded_buffer_dict = self._apply_placement_common(
sharding_placement
Expand All @@ -482,10 +507,19 @@ def apply_placement(self, sharding_placement):
self.parallel_gm.graph, self.reshard_after_forward
)

compiler_fn = self.compiler_fn
if self.mesh.ndim > 1:
from functools import partial

compiler_fn = partial(
compiler_fn,
pre_pass=self._make_fuse_allgather_pass(),
)

self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
self.joint_with_descriptors,
fw_compiler=self.compiler_fn,
bw_compiler=self.compiler_fn,
fw_compiler=compiler_fn,
bw_compiler=compiler_fn,
)

# Build a forward-only graph for inference (no backward, no
Expand All @@ -500,6 +534,13 @@ def apply_placement(self, sharding_placement):
self.parallel_gm, num_fwd_outputs, num_primals
)
compiler_fn = self.compiler_fn
if self.mesh.ndim > 1:
from functools import partial

compiler_fn = partial(
compiler_fn,
pre_pass=self._make_fuse_allgather_pass(),
)
aot_config = self.joint_with_descriptors._aot_state.aot_config
out_spec = self.joint_with_descriptors.out_spec

Expand Down
263 changes: 263 additions & 0 deletions autoparallel/graph_passes/fuse_allgather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

logger: logging.Logger = logging.getLogger(__name__)


def _is_all_gather(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
)


def _is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)


def _is_nontrivial_dim_reorder(node: torch.fx.Node) -> bool:
if node.op != "call_function":
return False
if node.target == torch.ops.aten.t.default:
return True
if node.target == torch.ops.aten.transpose.int:
return node.args[1] != node.args[2]
if node.target == torch.ops.aten.permute.default and isinstance(
node.args[1], (list, tuple)
):
dims = list(node.args[1])
return dims != list(range(len(dims)))
return False


def _is_identity_view_chain(start: torch.fx.Node, end: torch.fx.Node) -> bool:
    """Decide whether the ops strictly between ``start`` and ``end`` cancel out.

    The chain is followed forward from ``start`` one user at a time. It is
    accepted only when all of these hold:

    - ``start`` is not consumed directly by ``end`` (an empty chain is
      rejected);
    - every node on the path has exactly one user;
    - every intermediate node is a ``call_function`` whose target is one of
      permute, transpose, t, view, reshape, expand, unsqueeze or squeeze
      (slice and any other op are rejected);
    - at least one intermediate node is a non-trivial dimension reorder;
    - the ``"val"`` metadata of ``start`` and of ``end``'s first argument
      report the same shape and the same stride.

    Returns False whenever the required ``"val"`` metadata is missing.

    NOTE(review): ``aten.reshape`` may materialize a copy, in which case
    equal shape and stride would not by themselves prove the data order is
    unchanged -- confirm reshape never copies in the graphs this runs on.
    """
    view_targets = frozenset(
        (
            torch.ops.aten.permute.default,
            torch.ops.aten.transpose.int,
            torch.ops.aten.t.default,
            torch.ops.aten.view.default,
            torch.ops.aten.reshape.default,
            torch.ops.aten.expand.default,
            torch.ops.aten.unsqueeze.default,
            torch.ops.aten.squeeze.default,
            torch.ops.aten.squeeze.dim,
        )
    )

    # Nothing between start and end means there is no chain to validate.
    first_users = list(start.users)
    if len(first_users) == 1 and first_users[0] is end:
        return False

    src_val = start.meta.get("val")
    if src_val is None:
        return False
    src_stride = src_val.stride()

    # Follow the single-user path until ``end`` is reached.
    reordered = False
    cursor = start
    while cursor is not end:
        consumers = tuple(cursor.users)
        if len(consumers) != 1:
            return False
        (cursor,) = consumers
        if cursor is end:
            break
        is_allowed_view = (
            cursor.op == "call_function" and cursor.target in view_targets
        )
        if not is_allowed_view:
            return False
        if _is_nontrivial_dim_reorder(cursor):
            reordered = True

    if not reordered:
        return False

    # Compare the recorded layout feeding ``end`` with the one leaving ``start``.
    dst_node = end.args[0]
    if not isinstance(dst_node, torch.fx.Node):
        return False
    dst_val = dst_node.meta.get("val")
    if dst_val is None:
        return False

    layout_differs = (
        src_val.shape != dst_val.shape or src_stride != dst_val.stride()
    )
    return not layout_differs


def fuse_chained_allgathers(
graph: torch.fx.Graph,
full_group_size: int,
full_group_name: str,
subgroup_order: dict[str, int] | None = None,
) -> int:
"""Fuse consecutive allgather chains on different subgroups into a single allgather.

Detects chains of two allgathers on different process groups connected
through single-user view ops that compose to the identity::

ag1 = all_gather(x, size1, pg1)
wait1 = wait_tensor(ag1)
... = identity_view_ops(wait1)
ag2 = all_gather(..., size2, pg2)
wait2 = wait_tensor(ag2)

and replaces them with::

full_ag = all_gather(x, size1 * size2, full_pg)
full_wait = wait_tensor(full_ag)

Requirements:
- The two group sizes must multiply to ``full_group_size``.
- Every node between the two allgathers must have exactly one user.
- The view ops between them must compose to the identity (verified
via FakeTensor shape and stride metadata).
- Both allgathers must have the same dtype.
- When ``subgroup_order`` is provided, both process groups must be in
that mapping and appear in descending mesh-dim order.

Returns the number of fusions performed.
"""
fusions = 0
all_nodes = list(graph.nodes)

for ag2 in all_nodes:
if not _is_all_gather(ag2):
continue

# Walk ag2's input backward through single-user nodes to find wait1.
node = ag2.args[0]
if not isinstance(node, torch.fx.Node):
continue

# Find the wait_tensor that starts the chain.
wait1 = node
while not _is_wait_tensor(wait1):
if len(wait1.users) != 1:
break
if len(wait1.args) == 0:
break
inp = wait1.args[0]
if not isinstance(inp, torch.fx.Node):
break
wait1 = inp

if not _is_wait_tensor(wait1):
continue
if len(wait1.users) != 1:
continue

ag1 = wait1.args[0]
if not isinstance(ag1, torch.fx.Node) or not _is_all_gather(ag1):
continue
if len(ag1.users) != 1:
continue

# Validate that the view chain between wait1 and ag2 is identity.
if not _is_identity_view_chain(wait1, ag2):
continue

# Validate group sizes.
ag1_group_size = ag1.args[1]
ag2_group_size = ag2.args[1]
if ag1_group_size * ag2_group_size != full_group_size:
continue

# Validate group names.
ag1_group = ag1.args[2]
ag2_group = ag2.args[2]
assert isinstance(ag1_group, str)
assert isinstance(ag2_group, str)
if ag1_group == ag2_group:
continue
if subgroup_order is not None:
if ag1_group not in subgroup_order or ag2_group not in subgroup_order:
continue
if subgroup_order[ag1_group] >= subgroup_order[ag2_group]:
continue

# Validate matching dtype.
ag1_val = ag1.meta.get("val")
ag2_val = ag2.meta.get("val")
if (
ag1_val is not None
and ag2_val is not None
and ag1_val.dtype != ag2_val.dtype
):
continue

# Find wait2.
wait2 = None
for user in ag2.users:
if _is_wait_tensor(user):
wait2 = user
break
if wait2 is None:
continue

# Build the fused allgather.
original_input = ag1.args[0]

with graph.inserting_before(ag2):
full_ag = graph.call_function(
torch.ops._c10d_functional.all_gather_into_tensor.default,
args=(original_input, full_group_size, full_group_name),
)
full_ag.meta.update(ag2.meta)

full_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(full_ag,),
)
full_wait.meta.update(wait2.meta)

wait2.replace_all_uses_with(full_wait)
fusions += 1

logger.debug(
"Fused ag(%s, gs=%d, pg=%s) + ag(gs=%d, pg=%s) -> ag(gs=%d, pg=%s)",
original_input,
ag1_group_size,
ag1_group,
ag2_group_size,
ag2_group,
full_group_size,
full_group_name,
)

if fusions > 0:
graph.eliminate_dead_code()
logger.info(
"Fused %d chained allgather pairs into full-mesh allgathers", fusions
)

return fusions
Loading
Loading