diff --git a/autoparallel/api.py b/autoparallel/api.py index a938d351..bd857381 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -43,6 +43,7 @@ cleanup_graph, update_joint_with_descriptors, ) +from .graph_passes.make_collectives_contiguous import make_collectives_contiguous from .init_weights import hook_params_setters from .optimize_sharding import ShardingOptimizer from .shardings.placement_options import ( @@ -657,6 +658,7 @@ def _apply_placement_common(self, sharding_placement): # clean it up by removing the added aliases from previous pass # as well as redundant views cleanup_graph(parallel_gm, aggressive=True) + make_collectives_contiguous(parallel_gm) t_cleanup = time.perf_counter() trace_structured( diff --git a/autoparallel/graph_passes/make_collectives_contiguous.py b/autoparallel/graph_passes/make_collectives_contiguous.py new file mode 100644 index 00000000..d74e3c53 --- /dev/null +++ b/autoparallel/graph_passes/make_collectives_contiguous.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +_COLLECTIVES_REQUIRING_CONTIGUOUS = { + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, +} + + +def make_collectives_contiguous(gm: torch.fx.GraphModule) -> None: + """Insert clone(memory_format=contiguous) before collectives that require it. + + NCCL collectives like all_gather_into_tensor and reduce_scatter_tensor + require contiguous input tensors. When AP inserts these collectives, the + input may be non-contiguous (e.g. after a transpose or view). This pass + walks the graph and inserts a contiguous clone on any such input. + """ + graph = gm.graph + for node in list(graph.nodes): + if ( + node.op != "call_function" + or node.target not in _COLLECTIVES_REQUIRING_CONTIGUOUS + ): + continue + tensor_arg = node.args[0] + if not isinstance(tensor_arg, torch.fx.Node): + continue + # Skip if the input is already a contiguous clone + if ( + tensor_arg.op == "call_function" + and tensor_arg.target == torch.ops.aten.clone.default + and len(tensor_arg.kwargs) > 0 + and tensor_arg.kwargs.get("memory_format") == torch.contiguous_format + ): + continue + with graph.inserting_before(node): + clone_node = graph.call_function( + torch.ops.aten.clone.default, + args=(tensor_arg,), + kwargs={"memory_format": torch.contiguous_format}, + ) + clone_node.meta.update(tensor_arg.meta) + node.replace_input_with(tensor_arg, clone_node) + gm.recompile() diff --git a/tests/test_make_collectives_contiguous.py b/tests/test_make_collectives_contiguous.py new file mode 100644 index 00000000..6afe90f6 --- /dev/null +++ b/tests/test_make_collectives_contiguous.py @@ -0,0 +1,129 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.fx + +from autoparallel.graph_passes.make_collectives_contiguous import ( + make_collectives_contiguous, +) + + +def _count_ops(gm, target): + return len(gm.graph.find_nodes(op="call_function", target=target)) + + +def _build_graph_with_collective(collective_target): + """Build a simple FX graph: placeholder -> collective -> output.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(8) + collective = graph.call_function(collective_target, args=(x, 2, "0")) + collective.meta["val"] = torch.randn(16) + output = graph.output(collective) + output.meta["val"] = collective.meta["val"] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + return gm + + +def test_all_gather_gets_contiguous_clone(): + target = torch.ops._c10d_functional.all_gather_into_tensor.default + gm = _build_graph_with_collective(target) + + assert _count_ops(gm, torch.ops.aten.clone.default) == 0 + make_collectives_contiguous(gm) + assert _count_ops(gm, torch.ops.aten.clone.default) == 1 + + # The clone should be the input to the collective + for node in gm.graph.nodes: + if node.target == target: + clone_node = node.args[0] + assert clone_node.target == torch.ops.aten.clone.default + assert clone_node.kwargs["memory_format"] == torch.contiguous_format + + +def test_reduce_scatter_gets_contiguous_clone(): + target = torch.ops._c10d_functional.reduce_scatter_tensor.default + gm = _build_graph_with_collective(target) + + make_collectives_contiguous(gm) + assert _count_ops(gm, torch.ops.aten.clone.default) == 1 + + +def test_already_contiguous_clone_is_not_duplicated(): + """If the input is already a contiguous clone, don't insert another.""" + target = torch.ops._c10d_functional.all_gather_into_tensor.default + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(8) + clone = graph.call_function( + torch.ops.aten.clone.default, + args=(x,), + kwargs={"memory_format": torch.contiguous_format}, + ) + clone.meta["val"] = x.meta["val"] + collective = graph.call_function(target, args=(clone, 2, "0")) + collective.meta["val"] = torch.randn(16) + output = graph.output(collective) + output.meta["val"] = collective.meta["val"] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + make_collectives_contiguous(gm) + # Should still be exactly 1 clone, not 2 + assert _count_ops(gm, torch.ops.aten.clone.default) == 1 + + +def test_non_collective_ops_untouched(): + """Ops that aren't collectives should not get a clone inserted.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(4, 4) + add = graph.call_function(torch.ops.aten.add.Tensor, args=(x, x)) + add.meta["val"] = torch.randn(4, 4) + output = graph.output(add) + output.meta["val"] = add.meta["val"] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + make_collectives_contiguous(gm) + assert _count_ops(gm, torch.ops.aten.clone.default) == 0 + + +def test_multiple_collectives(): + """Each collective gets its own contiguous clone.""" + ag_target = torch.ops._c10d_functional.all_gather_into_tensor.default + rs_target = torch.ops._c10d_functional.reduce_scatter_tensor.default + + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(8) + ag = graph.call_function(ag_target, args=(x, 2, "0")) + ag.meta["val"] = torch.randn(16) + rs = graph.call_function(rs_target, args=(ag, "sum", 2, "0")) + rs.meta["val"] = torch.randn(8) + output = graph.output(rs) + output.meta["val"] = rs.meta["val"] + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + make_collectives_contiguous(gm) + assert _count_ops(gm, torch.ops.aten.clone.default) == 2 + + +def test_shared_input_gets_separate_clones(): + """When two collectives share the same input, each gets its own clone.""" + target = torch.ops._c10d_functional.all_gather_into_tensor.default + + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(8) + ag1 = graph.call_function(target, args=(x, 2, "0")) + ag1.meta["val"] = torch.randn(16) + ag2 = graph.call_function(target, args=(x, 4, "1")) + ag2.meta["val"] = torch.randn(32) + output = graph.output((ag1, ag2)) + output.meta["val"] = (ag1.meta["val"], ag2.meta["val"]) + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + + make_collectives_contiguous(gm) + assert _count_ops(gm, torch.ops.aten.clone.default) == 2