Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
cleanup_graph,
update_joint_with_descriptors,
)
from .graph_passes.make_collectives_contiguous import make_collectives_contiguous
from .init_weights import hook_params_setters
from .optimize_sharding import ShardingOptimizer
from .shardings.placement_options import (
Expand Down Expand Up @@ -657,6 +658,7 @@ def _apply_placement_common(self, sharding_placement):
# clean it up by removing the added aliases from previous pass
# as well as redundant views
cleanup_graph(parallel_gm, aggressive=True)
make_collectives_contiguous(parallel_gm)
t_cleanup = time.perf_counter()

trace_structured(
Expand Down
48 changes: 48 additions & 0 deletions autoparallel/graph_passes/make_collectives_contiguous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

_COLLECTIVES_REQUIRING_CONTIGUOUS = {
torch.ops._c10d_functional.all_gather_into_tensor.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
}


def make_collectives_contiguous(gm: torch.fx.GraphModule) -> None:
"""Insert clone(memory_format=contiguous) before collectives that require it.

NCCL collectives like all_gather_into_tensor and reduce_scatter_tensor
require contiguous input tensors. When AP inserts these collectives, the
input may be non-contiguous (e.g. after a transpose or view). This pass
walks the graph and inserts a contiguous clone on any such input.
"""
graph = gm.graph
for node in list(graph.nodes):
if (
node.op != "call_function"
or node.target not in _COLLECTIVES_REQUIRING_CONTIGUOUS
):
continue
tensor_arg = node.args[0]
if not isinstance(tensor_arg, torch.fx.Node):
continue
# Skip if the input is already a contiguous clone
if (
tensor_arg.op == "call_function"
and tensor_arg.target == torch.ops.aten.clone.default
and len(tensor_arg.kwargs) > 0
and tensor_arg.kwargs.get("memory_format") == torch.contiguous_format
):
continue
with graph.inserting_before(node):
clone_node = graph.call_function(
torch.ops.aten.clone.default,
args=(tensor_arg,),
kwargs={"memory_format": torch.contiguous_format},
)
clone_node.meta.update(tensor_arg.meta)
node.replace_input_with(tensor_arg, clone_node)
gm.recompile()
129 changes: 129 additions & 0 deletions tests/test_make_collectives_contiguous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.fx

from autoparallel.graph_passes.make_collectives_contiguous import (
make_collectives_contiguous,
)


def _count_ops(gm, target):
return len(gm.graph.find_nodes(op="call_function", target=target))


def _build_graph_with_collective(collective_target):
"""Build a simple FX graph: placeholder -> collective -> output."""
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(8)
collective = graph.call_function(collective_target, args=(x, 2, "0"))
collective.meta["val"] = torch.randn(16)
output = graph.output(collective)
output.meta["val"] = collective.meta["val"]
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
return gm


def test_all_gather_gets_contiguous_clone():
target = torch.ops._c10d_functional.all_gather_into_tensor.default
gm = _build_graph_with_collective(target)

assert _count_ops(gm, torch.ops.aten.clone.default) == 0
make_collectives_contiguous(gm)
assert _count_ops(gm, torch.ops.aten.clone.default) == 1

# The clone should be the input to the collective
for node in gm.graph.nodes:
if node.target == target:
clone_node = node.args[0]
assert clone_node.target == torch.ops.aten.clone.default
assert clone_node.kwargs["memory_format"] == torch.contiguous_format


def test_reduce_scatter_gets_contiguous_clone():
target = torch.ops._c10d_functional.reduce_scatter_tensor.default
gm = _build_graph_with_collective(target)

make_collectives_contiguous(gm)
assert _count_ops(gm, torch.ops.aten.clone.default) == 1


def test_already_contiguous_clone_is_not_duplicated():
"""If the input is already a contiguous clone, don't insert another."""
target = torch.ops._c10d_functional.all_gather_into_tensor.default
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(8)
clone = graph.call_function(
torch.ops.aten.clone.default,
args=(x,),
kwargs={"memory_format": torch.contiguous_format},
)
clone.meta["val"] = x.meta["val"]
collective = graph.call_function(target, args=(clone, 2, "0"))
collective.meta["val"] = torch.randn(16)
output = graph.output(collective)
output.meta["val"] = collective.meta["val"]
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

make_collectives_contiguous(gm)
# Should still be exactly 1 clone, not 2
assert _count_ops(gm, torch.ops.aten.clone.default) == 1


def test_non_collective_ops_untouched():
"""Ops that aren't collectives should not get a clone inserted."""
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(4, 4)
add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x))
add.meta["val"] = torch.randn(4, 4)
output = graph.output(add)
output.meta["val"] = add.meta["val"]
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

make_collectives_contiguous(gm)
assert _count_ops(gm, torch.ops.aten.clone.default) == 0


def test_multiple_collectives():
"""Each collective gets its own contiguous clone."""
ag_target = torch.ops._c10d_functional.all_gather_into_tensor.default
rs_target = torch.ops._c10d_functional.reduce_scatter_tensor.default

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(8)
ag = graph.call_function(ag_target, args=(x, 2, "0"))
ag.meta["val"] = torch.randn(16)
rs = graph.call_function(rs_target, args=(ag, "sum", 2, "0"))
rs.meta["val"] = torch.randn(8)
output = graph.output(rs)
output.meta["val"] = rs.meta["val"]
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

make_collectives_contiguous(gm)
assert _count_ops(gm, torch.ops.aten.clone.default) == 2


def test_shared_input_gets_separate_clones():
"""When two collectives share the same input, each gets its own clone."""
target = torch.ops._c10d_functional.all_gather_into_tensor.default

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.randn(8)
ag1 = graph.call_function(target, args=(x, 2, "0"))
ag1.meta["val"] = torch.randn(16)
ag2 = graph.call_function(target, args=(x, 4, "1"))
ag2.meta["val"] = torch.randn(32)
output = graph.output((ag1, ag2))
output.meta["val"] = (ag1.meta["val"], ag2.meta["val"])
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

make_collectives_contiguous(gm)
assert _count_ops(gm, torch.ops.aten.clone.default) == 2
Loading