From bbedcebad444cd035118da8bdcfcca37011222f7 Mon Sep 17 00:00:00 2001 From: Sanket Purandare Date: Thu, 30 Apr 2026 19:38:51 -0700 Subject: [PATCH] Remove pipeline parallelism (graph_pp) from AutoParallel Graph-based pipeline parallelism is moving to its own package. This removes all PP modules (api_pp, graph_pp_runner, graph_partition, split_fsdp_collectives, split_di_dw_graph, graph_multiplex), the two PP examples (example_ds3_pp, example_pp_graph_passes), the numerics comparison script (run_ds3_numerics_check), and their test (test_api_pp). Files edited: - __init__.py: removed AutoParallelPP export - placement_options.py: removed log_pp_model_weights/log_pp_grads from NumericsLogger (only called from the deleted PP example) - dsv3.py: removed DeepSeekV3Stage{0,I,N} classes and simplified _init_weights_* type annotations from Union[Model, Stage] to Model - test_cuda.yml: removed commented-out PP CI entries activation_checkpointing.py, module_construction.py, and graph_utils.py are shared with the non-PP codepath and left untouched. Validated: pytest tests/ passes (327 tests, 1 xfail). stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/439, branch: sanketpurandare/stack/1 --- .github/workflows/test_cuda.yml | 4 - autoparallel/__init__.py | 2 - autoparallel/_testing/models/dsv3.py | 79 +- autoparallel/api_pp.py | 210 ---- autoparallel/graph_passes/graph_multiplex.py | 236 ---- autoparallel/graph_passes/graph_partition.py | 101 -- autoparallel/graph_passes/graph_pp_runner.py | 1067 ----------------- .../graph_passes/split_di_dw_graph.py | 138 --- .../graph_passes/split_fsdp_collectives.py | 174 --- autoparallel/shardings/placement_options.py | 57 - examples/example_ds3_pp.py | 742 ------------ examples/example_pp_graph_passes.py | 424 ------- examples/run_ds3_numerics_check.py | 97 -- tests/test_api_pp.py | 304 ----- 14 files changed, 4 insertions(+), 3631 deletions(-) delete mode 100644 autoparallel/api_pp.py delete mode 100644 autoparallel/graph_passes/graph_multiplex.py delete mode 100644 autoparallel/graph_passes/graph_partition.py delete mode 100644 autoparallel/graph_passes/graph_pp_runner.py delete mode 100644 autoparallel/graph_passes/split_di_dw_graph.py delete mode 100644 autoparallel/graph_passes/split_fsdp_collectives.py delete mode 100644 examples/example_ds3_pp.py delete mode 100644 examples/example_pp_graph_passes.py delete mode 100644 examples/run_ds3_numerics_check.py delete mode 100644 tests/test_api_pp.py diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index ac522718..b031c793 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -56,8 +56,6 @@ jobs: run_timed python examples/example_autoparallel.py run_timed python examples/example_llama3.py run_timed python examples/example_local_map.py - # TODO(#436): Re-enable once OpStrategy.__str__ handles None specs in PyTorch. - # run_timed python examples/example_pp_graph_passes.py echo "========== Timings ==========" cat /tmp/timings.txt @@ -83,5 +81,3 @@ jobs: python examples/example_dcp.py # TODO(#436): Re-enable once OpStrategy.__str__ handles None specs in PyTorch. # torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py - # Skipped: graph PP is being moved out of AutoParallel shortly. - # torchrun --standalone --nproc_per_node=4 examples/example_ds3_pp.py --use-loss-fn --fake-evaluate diff --git a/autoparallel/__init__.py b/autoparallel/__init__.py index 70ca3d22..384b66ae 100644 --- a/autoparallel/__init__.py +++ b/autoparallel/__init__.py @@ -4,14 +4,12 @@ # LICENSE file in the root directory of this source tree. from autoparallel.api import AutoParallel, auto_parallel -from autoparallel.api_pp import AutoParallelPP from autoparallel.collectives import with_sharding_constraint from autoparallel.compile import autoparallel_backend __all__ = [ "auto_parallel", "AutoParallel", - "AutoParallelPP", "autoparallel_backend", "with_sharding_constraint", ] diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 6054ad41..863becf1 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -5,7 +5,7 @@ import math from dataclasses import dataclass, field -from typing import Callable, ClassVar, Literal, Optional, Tuple, Union +from typing import Callable, ClassVar, Literal, Optional, Tuple import torch import torch.fx.traceback as fx_traceback @@ -1621,78 +1621,7 @@ def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: ) -######################## -# Pipeline stuff start # -######################## - - -class DeepSeekV3StageI(nn.Module): - def __init__(self, layers, model_args): - super().__init__() - self.layers = layers - self.register_buffer( - "freqs_cis", precompute_freqs_cis(model_args), persistent=False - ) - self.model_args = model_args - - def forward(self, h): - # intermediate stages only have layers - for layer in self.layers.values(): - h = layer(h, self.freqs_cis) - return h - - def init_weights( - self, buffer_device: torch.device | None = None, seed: int | None = None - ) -> None: - _init_weights_layers(self, buffer_device, seed) - - -class DeepSeekV3Stage0(DeepSeekV3StageI): - def __init__(self, embed, layers, model_args): - super().__init__(layers, model_args) - self.tok_embeddings = embed - - def forward(self, tokens): - # torch.Size([1024, 1024]) - h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens - # torch.Size([1024, 1024, 2048]) - return super().forward(h) - - def init_weights( - self, buffer_device: torch.device | None = None, seed: int | None = None - ) -> None: - _init_weights_tok_embeddings(self, seed) - super().init_weights(buffer_device, seed) - - -class DeepSeekV3StageN(DeepSeekV3StageI): - def __init__(self, layers, norm, output, model_args): - super().__init__(layers, model_args) - self.norm = norm - self.output = output - self.model_args = model_args - - def forward(self, h): - h = super().forward(h) - h = self.norm(h) if self.norm is not None else h - output = self.output(h) if self.output is not None else h - return output - - def init_weights( - self, buffer_device: torch.device | None = None, seed: int | None = None - ) -> None: - super().init_weights(buffer_device, seed) - _init_weights_norm_and_output(self) - - -###################### -# Pipeline stuff end # -###################### - - -def _init_weights_tok_embeddings( - self: Union[DeepSeekV3Model, DeepSeekV3Stage0], seed: int | None = None -): +def _init_weights_tok_embeddings(self: DeepSeekV3Model, seed: int | None = None): if seed is not None: torch.manual_seed(seed) if self.tok_embeddings is not None: @@ -1700,7 +1629,7 @@ def _init_weights_tok_embeddings( def _init_weights_layers( - self: Union[DeepSeekV3Model, DeepSeekV3StageI], + self: DeepSeekV3Model, buffer_device: torch.device | None, seed: int | None = None, ): @@ -1716,7 +1645,7 @@ def _init_weights_layers( layer.init_weights(buffer_device) # type: ignore[arg-type] -def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]): +def _init_weights_norm_and_output(self: DeepSeekV3Model): if self.norm is not None: self.norm.reset_parameters() if self.output is not None: diff --git a/autoparallel/api_pp.py b/autoparallel/api_pp.py deleted file mode 100644 index 1d075955..00000000 --- a/autoparallel/api_pp.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from typing import Any, Optional - -import torch -from torch._logging import trace_structured - -from autoparallel.graph_passes.graph_partition import partition_joint_with_descriptors - -from .api import AutoParallel -from .module_construction import make_parallel_module - -logger = logging.getLogger(__name__) - - -def make_pp_module( - sharded_param_dict: dict[str, torch.nn.Parameter], - sharded_buffer_dict: dict[str, torch.Tensor], - ref_model: torch.nn.Module, -): - """Create an AutoParallelPPModule that inherits from the user's model class.""" - return make_parallel_module(ref_model, sharded_param_dict, sharded_buffer_dict) - - -class AutoParallelPP(AutoParallel): - def apply_placement_pp( - self, sharding_placement=None, graph_passes: list[str] = [] - ) -> dict[str, Any]: - assert all( - g_pass in ["split_fsdp_collectives", "split_dI_dW"] - for g_pass in graph_passes - ), "Only split_fsdp_collectives and split_dI_dW_graph are supported" - sharded_param_dict, sharded_buffer_dict = self._apply_placement_common( - sharding_placement - ) - num_params = len(sharded_param_dict) - num_buffers = len(sharded_buffer_dict) - ( - fw_module, - bw_module, - num_params_buffers, - num_user_outputs, - num_mutate_inputs, - num_fw_outs_saved_for_bw, - num_symints_saved_for_bw, - _indices_of_inps_to_detach, - adjusted_flat_args, - ) = partition_joint_with_descriptors(self.joint_with_descriptors) - assert num_params_buffers == ( - num_params + num_buffers - ), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}" - num_input_grads = ( - len(bw_module.graph.find_nodes(op="output")[0].args[0]) - num_params_buffers - ) - logger.info( - f"num_params_buffers: {num_params_buffers}\n" - f"num_user_outputs: {num_user_outputs}\n" - f"num_mutate_inputs: {num_mutate_inputs}\n" - f"num_input_grads: {num_input_grads}\n" - f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n" - f"num_symints_saved_for_bw: {num_symints_saved_for_bw}" - ) - - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_fwd_graph", - "encoding": "string", - }, - payload_fn=lambda: fw_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_bwd_graph", - "encoding": "string", - }, - payload_fn=lambda: bw_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - unshard_module: Optional[torch.fx.GraphModule] = None - reduce_grad_module: Optional[torch.fx.GraphModule] = None - if "split_fsdp_collectives" in graph_passes: - assert ( - not self.reshard_after_forward - ), "reshard_after_forward should be False to disable FSDP all_gather in the backward pass" - from autoparallel.graph_passes.split_fsdp_collectives import ( - split_fsdp_prefetch, - split_fsdp_reduce_scatters_epilogue, - ) - - unshard_module, fw_module = split_fsdp_prefetch(fw_module, num_params) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_unshard_graph", - "encoding": "string", - }, - payload_fn=lambda: unshard_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_fwd_no_fsdp_graph", - "encoding": "string", - }, - payload_fn=lambda: fw_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - bw_module, reduce_grad_module = split_fsdp_reduce_scatters_epilogue( - bw_module, num_params - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_bwd_no_fsdp_graph", - "encoding": "string", - }, - payload_fn=lambda: bw_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_reduce_grad_graph", - "encoding": "string", - }, - payload_fn=lambda: reduce_grad_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - - bw_dI_module: Optional[torch.fx.GraphModule] = None - bw_dW_module: Optional[torch.fx.GraphModule] = None - if "split_dI_dW" in graph_passes: - from autoparallel.graph_passes.split_di_dw_graph import split_di_dw_graph - - bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph( - bw_module, - num_weight_gradients=num_params_buffers, - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_bw_dI_graph", - "encoding": "string", - }, - payload_fn=lambda: bw_dI_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_pp_bw_dW_graph", - "encoding": "string", - }, - payload_fn=lambda: bw_dW_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - if all( - x is None - for x in bw_dI_module.graph.find_nodes(op="output")[0].args[0][ - :num_input_grads - ] - ): - raise RuntimeError( - "attempted to run split dI/dW pass on a graph that has no input gradients" - ) - - graph_meta: dict[str, int] = { - "num_mutate_inputs": num_mutate_inputs, - "num_user_outputs": num_user_outputs, - "num_symints_saved_for_bw": num_symints_saved_for_bw, - "num_params": num_params, - "num_buffers": num_buffers, - "num_input_grads": num_input_grads, - } - - graph_modules: dict[str, Optional[torch.fx.GraphModule]] = { - "fw": fw_module, - "full_bw": bw_module, - "bw_dI": bw_dI_module, - "bw_dW": bw_dW_module, - "unshard": unshard_module, - "reduce_grad": reduce_grad_module, - } - self.parallel_model = make_pp_module( - sharded_param_dict, - sharded_buffer_dict, - self.model, - ) - return { - "graph_callables": graph_modules, - "graph_meta": graph_meta, - "sharded_param_dict": sharded_param_dict, - "sharded_buffer_dict": sharded_buffer_dict, - } diff --git a/autoparallel/graph_passes/graph_multiplex.py b/autoparallel/graph_passes/graph_multiplex.py deleted file mode 100644 index 7a676b9a..00000000 --- a/autoparallel/graph_passes/graph_multiplex.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import copy -from itertools import dropwhile - -import torch -import torch.fx as fx -from torch._inductor.fx_passes.bucketing import is_wait_tensor -from torch._logging import trace_structured - - -def _add_compute_annotations(gm: fx.GraphModule, tag: str) -> bool: - """Add compute_region annotations to nodes without custom metadata.""" - has_comm_region = False - for n in gm.graph.nodes: - if n.op == "placeholder": - continue - if n.meta.get("custom", None) is None: - n.meta["custom"] = {"compute_region": tag} - else: - if "comm_region" in n.meta["custom"]: - has_comm_region = True - val = n.meta["custom"]["comm_region"] - n.meta["custom"]["comm_region"] = tag + " " + val - elif "compute_region" in n.meta["custom"]: - val = n.meta["custom"]["compute_region"] - n.meta["custom"]["compute_region"] = tag + " " + val - else: - n.meta["custom"]["compute_region"] = tag - return has_comm_region - - -def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str): - """Move the last wait_tensor node from each contiguous comm_region to the compute_region of its first user.""" - # First pass: identify the last wait_tensor in each contiguous comm region - last_waits: list[fx.Node] = [] - last_wait: fx.Node | None = None - in_comm_region = False - - for n in gm.graph.nodes: - if n.op == "placeholder": - continue - if "comm_region" in n.meta["custom"]: - in_comm_region = True - if is_wait_tensor(n): - last_wait = n - else: - # Transitioning out of a comm region — flush - if in_comm_region and last_wait is not None: - last_waits.append(last_wait) - last_wait = None - in_comm_region = False - - # Handle graph ending inside a comm region - if in_comm_region and last_wait is not None: - last_waits.append(last_wait) - - # Second pass: re-tag and move only the collected last-wait nodes - for n in last_waits: - assert len(n.users) >= 1, "wait tensor must have at least one user" - user: fx.Node = next(iter(n.users)) - if "compute_region" in user.meta["custom"]: - val = n.meta["custom"].pop("comm_region") - if tag not in val: - val = tag + " " + val - n.meta["custom"].update({"compute_region": val + " " + "wait"}) - if n.next is not user: - user.prepend(n) - - -def multiplex_fw_bw_graph( - fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True -) -> fx.GraphModule: - """ - Multiplexes forward and backward graphs into a single unified graph module. - - This function combines a forward graph and a backward graph into one multiplexed - graph by merging their nodes and outputs. The resulting graph has: - - All placeholders from both forward and backward graphs (backward followed by forward) - - All computation nodes from both graphs (backward followed by forward) - - Combined outputs (backward outputs followed by forward outputs) - - Args: - fw_gm: The forward graph module containing the forward computation - bw_gm: The backward graph module containing the backward computation - - Returns: - A multiplexed fx.GraphModule containing both forward and backward computations - with backward outputs appearing before forward outputs - - Note: - The function preserves node metadata during the merging process. - """ - if overlap_with_annotations: - fw_has_comm = _add_compute_annotations(fw_gm, "forward") - bw_has_comm = _add_compute_annotations(bw_gm, "backward") - assert fw_has_comm and bw_has_comm, "No comm region found in either graph" - _move_wait_tensors_to_compute_region(fw_gm, "forward") - _move_wait_tensors_to_compute_region(bw_gm, "backward") - - # Mapping to track correspondence between forward graph nodes and new nodes - old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {} - - # Start with a deep copy of the backward graph as the base - multiplexed_gm = copy.deepcopy(bw_gm) - - # Copy tensor constant attributes from fw_gm to multiplexed_gm with "fw_" prefix - # to avoid collision with bw's tensor constants - fw_tensor_constant_remap: dict[str, str] = {} - for attr_name in dir(fw_gm): - if attr_name.startswith("_tensor_constant"): - fw_attr = getattr(fw_gm, attr_name) - new_attr_name = ( - f"fw{attr_name}" # e.g., _tensor_constant0 -> fw_tensor_constant0 - ) - setattr(multiplexed_gm, new_attr_name, fw_attr) - fw_tensor_constant_remap[attr_name] = new_attr_name - - # Collect all placeholder nodes from all the graphs - bw_placeholders = bw_gm.graph.find_nodes(op="placeholder") - fw_placeholders = fw_gm.graph.find_nodes(op="placeholder") - insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1] - - # Insert forward placeholders after the backward placeholders of the multiplexed graph - for n in fw_placeholders: - with multiplexed_gm.graph.inserting_after(insert_point): - new_placeholder = multiplexed_gm.graph.placeholder(n.name) - new_placeholder.meta = copy.copy(n.meta) - new_placeholder.target = new_placeholder.name - old_node_to_new_node[n] = new_placeholder - insert_point = new_placeholder - - multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder") - assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len( - bw_placeholders - ) - fw_nodes_iter = iter(fw_gm.graph.nodes) - fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter) - # Initialize the forward node to be the first non-placeholder node - fn = next(fw_nodes_iter) - if overlap_with_annotations: - # Interleave forward and backward nodes to create overlap pattern: - # bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat] - # This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute - bw_in_comm = False - for bn in multiplexed_gm.graph.nodes: - if bn.op == "placeholder" or bn.op == "output": - continue - # Track when we enter a backward comm region - if "comm_region" in bn.meta["custom"] and not bw_in_comm: - bw_in_comm = True - # When we transition from bw_comm to bw_compute, insert forward nodes - elif "compute_region" in bn.meta["custom"] and bw_in_comm: - bw_in_comm = False - fw_in_comm = False - insert_point = bn - # Insert forward nodes before this bw_compute node - # Note: We cannot reorder nodes within a graph, only their relative order between graphs - while fn.op != "output": - if "comm_region" in fn.meta["custom"] and not fw_in_comm: - fw_in_comm = True - elif "compute_region" in fn.meta["custom"] and fw_in_comm: - # Stop when we reach the next fw_compute after fw_comm - # This ensures we insert one fw_compute + fw_comm cycle per bw_comm -> bw_compute transition - # If fw starts with comm (no compute before it), we still insert it to overlap with future bw_compute - fw_in_comm = False - break - with multiplexed_gm.graph.inserting_before(insert_point): - # Copy node and remap its arguments using the node mapping - new_node = multiplexed_gm.graph.node_copy( - fn, lambda x: old_node_to_new_node[x] - ) - new_node.meta = copy.copy(fn.meta) - # Remap get_attr targets for tensor constants to avoid collision - if ( - new_node.op == "get_attr" - and new_node.target in fw_tensor_constant_remap - ): - new_node.target = fw_tensor_constant_remap[ - str(new_node.target) - ] - old_node_to_new_node[fn] = new_node - fn = next(fw_nodes_iter) - # Insert any remaining forward nodes at the end - # If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes - insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1] - while fn.op != "output": - with multiplexed_gm.graph.inserting_before(insert_point): - # Copy node and remap its arguments using the node mapping - new_node = multiplexed_gm.graph.node_copy( - fn, lambda x: old_node_to_new_node[x] - ) - new_node.meta = copy.copy(fn.meta) - # Remap get_attr targets for tensor constants to avoid collision - if ( - new_node.op == "get_attr" - and new_node.target in fw_tensor_constant_remap - ): - new_node.target = fw_tensor_constant_remap[str(new_node.target)] - old_node_to_new_node[fn] = new_node - fn = next(fw_nodes_iter) - - # Collect output arguments from forward graph, remapping to new nodes - fw_outputs = fw_gm.graph.find_nodes(op="output") - multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output") - assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1 - fw_graph_op_node = fw_outputs[0] - fw_op_node_args = [ - old_node_to_new_node[n] if n is not None else None - for n in fw_graph_op_node.args[0] - ] - - # Collect output arguments from multiplexed graph (will contain only bwd_outs) - multiplexed_graph_op_node = multiplexed_graph_outputs[0] - bw_op_node_args = list(multiplexed_graph_op_node.args[0]) - - # Update output node args to prepend backward outputs before forward outputs - multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),) - - multiplexed_gm.graph.eliminate_dead_code() - multiplexed_gm.graph.lint() - multiplexed_gm.recompile() - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "autoparallel_multiplexed_graph", - "encoding": "string", - }, - payload_fn=lambda: multiplexed_gm.print_readable( - print_output=False, include_stride=True, include_device=True - ), - ) - return multiplexed_gm diff --git a/autoparallel/graph_passes/graph_partition.py b/autoparallel/graph_passes/graph_partition.py deleted file mode 100644 index 87f894f7..00000000 --- a/autoparallel/graph_passes/graph_partition.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Callable - -import torch -from torch._functorch._aot_autograd.graph_compile import ( - _aot_stage2a_partition, - _apply_tensorify_python_scalars, -) -from torch._functorch.aot_autograd import ( - AOTConfig, - AOTGraphCapture, - AOTState, - JointWithDescriptors, - OutputType, - ViewAndMutationMeta, - boxed_nop_preserve_node_meta, - default_partition, -) - - -def partition_joint_with_descriptors( - jd: JointWithDescriptors, - *, - partition_fn: Callable = default_partition, - fw_compiler: Callable = boxed_nop_preserve_node_meta, - bw_compiler: Callable = boxed_nop_preserve_node_meta, -) -> tuple[ - torch.fx.GraphModule, - torch.fx.GraphModule, - int, - int, - int, - int, - int, - list[int], - list[Any], -]: - aot_state: AOTState = jd._aot_state - aot_graph_capture: AOTGraphCapture = jd._aot_graph_capture - # Update the AOTState with the provided compilers - aot_state.aot_config.partition_fn = partition_fn - aot_state.aot_config.fw_compiler = fw_compiler - aot_state.aot_config.bw_compiler = bw_compiler - aot_state.aot_config.inference_compiler = fw_compiler - - fx_g: torch.fx.GraphModule = aot_graph_capture.graph_module - maybe_subclass_meta: Any = aot_graph_capture.maybe_subclass_meta - fw_metadata: ViewAndMutationMeta = aot_state.fw_metadata - aot_config: AOTConfig = aot_state.aot_config - - # AOTAutogradStage2a: Partition the graph into forward and backward graphs and - # return the some metadata about the partitioning. - - _apply_tensorify_python_scalars(fx_g) - - ( - fw_module, - bw_module, - num_fw_outs_saved_for_bw, - num_symints_saved_for_bw, - _indices_of_inps_to_detach, - adjusted_flat_args, - ) = _aot_stage2a_partition( - fx_g, - aot_graph_capture.updated_flat_args, - maybe_subclass_meta, - fw_metadata, - aot_config, - ) - - num_user_outputs = ( - len( - [ - x - for x in fw_metadata.output_info - if x.output_type - in (OutputType.non_alias, OutputType.alias_of_intermediate) - ] - ) - + fw_metadata.num_intermediate_bases - ) - - num_mutate_inputs = len( - [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata] - ) - num_params_buffers = aot_config.num_params_buffers - return ( - fw_module, - bw_module, - num_params_buffers, - num_user_outputs, - num_mutate_inputs, - num_fw_outs_saved_for_bw, - num_symints_saved_for_bw, - _indices_of_inps_to_detach, - adjusted_flat_args, - ) diff --git a/autoparallel/graph_passes/graph_pp_runner.py b/autoparallel/graph_passes/graph_pp_runner.py deleted file mode 100644 index 12ad9cd9..00000000 --- a/autoparallel/graph_passes/graph_pp_runner.py +++ /dev/null @@ -1,1067 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from dataclasses import dataclass -from typing import Any, Callable, Optional, Protocol, Union, cast - -import torch -import torch.fx as fx -from torch.distributed.pipelining.schedules import ( - FULL_BACKWARD, - _Action, - _PipelineContext, - _PipelineScheduleRuntime, - _wait_batch_p2p, -) -from torch.distributed.pipelining.stage import ( - PipelineStage, - _normalize_model_output_as_tuple, -) -from torch.distributed.tensor import DTensor - -from autoparallel.shardings.placement_options import DebugInterpreter, NumericsLogger - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -def _execute_graph( - gm: fx.GraphModule, args: list[Any], *, inductor: bool = False -) -> Any: - """Execute a graph module, optionally compiling with Inductor on first call.""" - if inductor: - if not hasattr(gm, "_compiled"): - from torch._inductor.compile_fx import compile_fx_inner - - gm._compiled = compile_fx_inner(gm, args) # type: ignore[assignment, attr-defined] - return gm._compiled(args) # type: ignore[operator, attr-defined] - return fx.Interpreter(gm).boxed_run(args) - - -@dataclass -class GraphCallables: - fw: fx.GraphModule - full_bw: fx.GraphModule - bw_dI: Optional[fx.GraphModule] = None - bw_dW: Optional[fx.GraphModule] = None - unshard: Optional[fx.GraphModule] = None - reduce_grad: Optional[fx.GraphModule] = None - - -@dataclass -class GraphMeta: - num_mutate_inputs: int - num_user_outputs: int - num_symints_saved_for_bw: int - num_params: int - num_buffers: int - num_input_grads: int - - -class MultiplexFwBwGraphPass(Protocol): - """Protocol defining the contract for forward-backward graph multiplexing passes. - - Implementations must accept two GraphModules (forward and backward) and return a fused - GraphModule that multiplexes their execution. - - Contract Requirements: - 1. Input placeholders ordering: The returned GraphModule's placeholders must be ordered - as ``bw_placeholders + fw_placeholders`` (backward placeholders concatenated with - forward placeholders, each maintaining their original order from the input graphs). - - 2. Output node args ordering: The returned GraphModule's output node args must contain - outputs ordered as ``bw_outputs + fw_outputs`` (backward outputs concatenated with - forward outputs, each maintaining their original order from the input graphs). - - Example:: - - def my_multiplex_pass( - fw_graph: fx.GraphModule, - bw_graph: fx.GraphModule - ) -> fx.GraphModule: - # Implementation that satisfies the contract - ... - return multiplexed_graph - """ - - def __call__( - self, - fw_graph: fx.GraphModule, - bw_graph: fx.GraphModule, - ) -> fx.GraphModule: - """Multiplex forward and backward graphs into a single fused graph. - - Args: - fw_graph (fx.GraphModule): Forward graph module. - bw_graph (fx.GraphModule): Backward graph module. - - Returns: - fx.GraphModule: Fused graph module satisfying the contract requirements. - """ - ... - - -def get_multiplexed_graph_callables( - stage_graphs: dict[int, GraphCallables], - multiplex_fw_bw_graph_pass: MultiplexFwBwGraphPass, -) -> dict[tuple[int, int], fx.GraphModule]: - """Generate multiplexed graph modules that fuse forward and backward passes from different stages. - - Creates fused modules for all stage pairs where fw_stage_idx != bw_stage_idx. This enables - pipeline schedules (e.g., DualPipe) to overlap communication with computation. - - Args: - stage_graphs (dict[int, GraphCallables]): Mapping from stage index to GraphCallables - containing forward/backward modules. - multiplex_fw_bw_graph_pass (MultiplexFwBwGraphPass): A callable that takes two - GraphModules (forward and backward) and returns a fused GraphModule that multiplexes - their execution. Must satisfy the contract defined in - :class:`MultiplexFwBwGraphPass`. - - Returns: - dict[tuple[int, int], fx.GraphModule]: Mapping from (fw_stage_idx, bw_stage_idx) to fused - GraphModule that executes forward from fw_stage_idx and backward from bw_stage_idx. - """ - multiplexed_graph_callables: dict[tuple[int, int], torch.fx.GraphModule] = {} - for bw_stage_idx, bw_stage_graph_callables in stage_graphs.items(): - for fw_stage_idx, fw_stage_graph_callables in stage_graphs.items(): - if bw_stage_idx != fw_stage_idx: - fw_bw_module = multiplex_fw_bw_graph_pass( - fw_stage_graph_callables.fw, - bw_stage_graph_callables.full_bw, - ) - multiplexed_graph_callables[(fw_stage_idx, bw_stage_idx)] = fw_bw_module - return multiplexed_graph_callables - - -class GraphPipelineStage(PipelineStage): - def __init__( - self, - submodule: torch.nn.Module, - graph_callables: GraphCallables, - graph_meta: GraphMeta, - stage_index: int, - num_stages: int, - device: torch.device, - input_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, - output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, - group: Optional[torch.distributed.ProcessGroup] = None, - dw_builder: Optional[Callable[[], Callable[..., None]]] = None, - numerics_logger: Optional[NumericsLogger] = None, - should_log_fw_outs: bool = False, - ): - super().__init__( - submodule=submodule, - stage_index=stage_index, - num_stages=num_stages, - device=device, - input_args=input_args, - output_args=output_args, - group=group, - dw_builder=dw_builder, - ) - self.numerics_logger = numerics_logger - self.should_log_fw_outs = should_log_fw_outs - self.graph_callables = graph_callables - self.graph_meta = graph_meta - self.state: dict[str, list[Any]] = { - "sharded_params": [], - "unsharded_params": [], - "buffers": [], - "sharded_grads": [], - "unsharded_grads": [], - } - self.inductor: bool = False - self.bwd_activation_cache: dict[int, tuple[Any]] = {} - - def scale_grads(self, grad_scale_factor: int) -> None: - """Scale stage's gradients by `grad_scale_factor`, which should be specified in coordination with the - loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor` - should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should - be set to 1. - - Should only be called once per pipeline schedule step, after all backwards passes have completed. - """ - - # PP scales only for its own contribution (microbatches), but relies on DP to scale further - # for DP degree. - if grad_scale_factor != 1: - for grad in self.state["unsharded_grads"]: - if grad is not None: - grad.div_(grad_scale_factor) - - def _accumulate_stage_unsharded_grads( - self, - param_buffer_grads: list[Union[torch.Tensor, None]], - ) -> None: - unsharded_grads = self.state["unsharded_grads"] - grads_to_accumulate = param_buffer_grads[: self.graph_meta.num_params] - assert len(unsharded_grads) == len(grads_to_accumulate) - assert not all( - grad is None for grad in grads_to_accumulate - ), "All grads are None" - for i in range(len(unsharded_grads)): - if grads_to_accumulate[i] is not None: - if unsharded_grads[i] is None: - unsharded_grads[i] = grads_to_accumulate[i] - else: - unsharded_grads[i] += grads_to_accumulate[i] - - -def _run_fw_module( - fw_module: fx.GraphModule, - graph_meta: GraphMeta, - fw_args: list[Any], - numerics_logs: Optional[list[str]] = None, - inductor: bool = False, -) -> tuple[Any, tuple[tuple[Any], tuple[Any]]]: - if numerics_logs is not None: - debug_interpreter = DebugInterpreter(fw_module) - fw_outputs = debug_interpreter.boxed_run(fw_args) - numerics_logs += debug_interpreter.get_logs() - else: - fw_outputs = _execute_graph(fw_module, fw_args, inductor=inductor) - - num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs - saved_intermediates = fw_outputs[num_inner_fwd_outputs:] - num_tensors_for_backward = ( - len(saved_intermediates) - graph_meta.num_symints_saved_for_bw - ) - tensors_for_backward = saved_intermediates[:num_tensors_for_backward] - non_tensors_for_backward = saved_intermediates[num_tensors_for_backward:] - save_for_backward = (tensors_for_backward, non_tensors_for_backward) - user_outputs = fw_outputs[graph_meta.num_mutate_inputs : num_inner_fwd_outputs] - if len(user_outputs) == 1: - user_outputs = user_outputs[0] - return user_outputs, save_for_backward - - -def _run_full_bw_module( - bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args, inductor: bool = False -) -> tuple[list[Any], list[Any]]: - bw_outputs = _execute_graph(bw_module, bw_args, inductor=inductor) - num_params_buffers = graph_meta.num_params + graph_meta.num_buffers - param_buffer_grads = bw_outputs[:num_params_buffers] - input_grads = bw_outputs[num_params_buffers:] - return input_grads, param_buffer_grads - - -def _run_dI_bw_module( - bw_dI_module: fx.GraphModule, - graph_meta: GraphMeta, - bw_dI_args, - inductor: bool = False, -) -> tuple[list[Any], list[Any]]: - inp_grads_and_activations = _execute_graph( - bw_dI_module, bw_dI_args, inductor=inductor - ) - inp_grads, activations = inp_grads_and_activations[ - : graph_meta.num_input_grads - ], list(inp_grads_and_activations[graph_meta.num_input_grads :]) - return inp_grads, activations - - -def _run_dW_bw_module( - bw_dW_module: fx.GraphModule, - graph_meta: GraphMeta, - bw_dW_args, - inductor: bool = False, -) -> list[Any]: - param_buffer_grads = _execute_graph(bw_dW_module, bw_dW_args, inductor=inductor) - return param_buffer_grads - - -def _run_unshard_module( - unshard_module: fx.GraphModule, - graph_meta: GraphMeta, - unshard_args, - inductor: bool = False, -) -> list[Any]: - unsharded_params = _execute_graph(unshard_module, unshard_args, inductor=inductor) - return unsharded_params - - -def _run_reduce_grad_module( - reduce_grad_module: fx.GraphModule, - graph_meta: GraphMeta, - reduce_grad_args, - inductor: bool = False, -) -> list[Any]: - sharded_grads = _execute_graph( - reduce_grad_module, reduce_grad_args, inductor=inductor - ) - return sharded_grads - - -def _run_multiplexed_fw_bw_module( - multiplexed_fw_bw_module: fx.GraphModule, - fw_graph_meta: GraphMeta, - bw_graph_meta: GraphMeta, - bw_fw_args, - inductor: bool = False, -) -> tuple[list[Any], list[Any], Any, tuple[tuple[Any], tuple[Any]]]: - multiplexed_outs = _execute_graph( - multiplexed_fw_bw_module, bw_fw_args, inductor=inductor - ) - - num_params_buffers = bw_graph_meta.num_params + bw_graph_meta.num_buffers - num_bw_outs = bw_graph_meta.num_input_grads + num_params_buffers - bw_outputs = multiplexed_outs[:num_bw_outs] - param_buffer_grads = bw_outputs[:num_params_buffers] - input_grads = bw_outputs[num_params_buffers:] - - fw_outputs = multiplexed_outs[num_bw_outs:] - num_inner_fwd_outputs = ( - fw_graph_meta.num_mutate_inputs + fw_graph_meta.num_user_outputs - ) - saved_intermediates = fw_outputs[num_inner_fwd_outputs:] - num_tensors_for_backward = ( - len(saved_intermediates) - fw_graph_meta.num_symints_saved_for_bw - ) - tensors_for_backward = saved_intermediates[:num_tensors_for_backward] - non_tensors_for_backward = saved_intermediates[num_tensors_for_backward:] - save_for_backward = (tensors_for_backward, non_tensors_for_backward) - user_outputs = fw_outputs[fw_graph_meta.num_mutate_inputs : num_inner_fwd_outputs] - if len(user_outputs) == 1: - user_outputs = user_outputs[0] - - return input_grads, param_buffer_grads, user_outputs, save_for_backward - - -def _get_stage_from_action( - action: _Action, - ctx: _PipelineContext, -) -> tuple[_PipelineScheduleRuntime, dict[int, GraphPipelineStage], GraphPipelineStage]: - """Helper to extract schedule, stage mapping, and specific stage from action and context. - - Args: - action: The action containing the stage index. - ctx: The pipeline context containing the schedule. - - Returns: - A tuple containing: - - schedule: The pipeline schedule runtime object. - - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. - - stage: The specific GraphPipelineStage for the action's stage index. - """ - schedule = ctx.schedule_ref - assert isinstance(schedule, _PipelineScheduleRuntime) - stage_index_to_stage: dict[int, GraphPipelineStage] = { - stage.stage_index: cast(GraphPipelineStage, stage) for stage in schedule._stages - } - stage = stage_index_to_stage[action.stage_index] - return schedule, stage_index_to_stage, stage - - -def _prepare_fwd_common( - action: _Action, - ctx: _PipelineContext, -) -> tuple[ - _PipelineScheduleRuntime, - dict[int, GraphPipelineStage], - GraphPipelineStage, - int, - bool, - bool, -]: - """Common setup for forward stage: retrieve stage info and handle recv ops. - - This function performs the shared initialization logic for forward operations, - including waiting for activation receives from the previous pipeline stage. - - Args: - action: The forward action to execute, containing the stage index and microbatch index. - ctx: The pipeline context containing the schedule and pipeline state. - - Returns: - A tuple containing: - - schedule: The pipeline schedule runtime object managing the execution. - - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. - - stage: The GraphPipelineStage for which forward is being computed. - - mb_index: The microbatch index being processed. - - is_next_stage_on_this_rank: True if stage_index + 1 exists on this rank (V-schedule). - - is_prev_stage_on_this_rank: True if stage_index - 1 exists on this rank (V-schedule). - """ - schedule, stage_index_to_stage, stage = _get_stage_from_action(action, ctx) - stage_index = stage.stage_index - - mb_index = action.microbatch_index - assert mb_index is not None - - is_next_stage_on_this_rank = stage_index + 1 in stage_index_to_stage - is_prev_stage_on_this_rank = stage_index - 1 in stage_index_to_stage - - if ( - not stage.is_first - # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) - and not is_prev_stage_on_this_rank - ): - fwd_recv_ops = schedule.fwd_recv_ops - assert ( - stage_index, - mb_index, - ) in fwd_recv_ops, f"Computing {action=} before receiving input" - _wait_batch_p2p(fwd_recv_ops.pop((stage_index, mb_index))) - - return ( - schedule, - stage_index_to_stage, - stage, - mb_index, - is_next_stage_on_this_rank, - is_prev_stage_on_this_rank, - ) - - -def _prepare_fwd_args( - stage: GraphPipelineStage, - mb_index: int, - ctx: _PipelineContext, -) -> list[Any]: - """Prepare forward args from user inputs or received activations. - - Args: - stage: The GraphPipelineStage for which to prepare forward arguments. - mb_index: The microbatch index being processed. - ctx: The pipeline context containing arg_mbs, kwarg_mbs, and target_mbs. - - Returns: - List of forward arguments including unsharded_params, buffers, and composite_args. - """ - arg_mbs = ctx.arg_mbs - kwarg_mbs = ctx.kwarg_mbs - - args = arg_mbs[mb_index] # type: ignore[index] - kwargs = kwarg_mbs[mb_index] # type: ignore[index] - assert not kwargs # TODO: if kwargs can always be ignored, maybe remove? - - if stage.is_first: - # First stage doesn't need to receive anything - composite_args = args - else: - # Receive activations for this chunk - # Activations only come in args form - composite_args = stage._retrieve_recv_activations(mb_index) - if stage.is_last and ctx.target_mbs is not None: - assert isinstance( - composite_args, tuple - ), f"Expected composite args to be a tuple but got {type(composite_args)}" - composite_args = composite_args + (ctx.target_mbs[mb_index],) # type: ignore[index] - - # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args? - fw_args = [ - *stage.state["unsharded_params"], - *stage.state["buffers"], - *composite_args, - ] - del composite_args - return fw_args - - -def _post_fwd_common( - action: _Action, - stage: GraphPipelineStage, - mb_index: int, - output: Any, - saved_intermediates: tuple[tuple[Any], tuple[Any]], - schedule: _PipelineScheduleRuntime, - stage_index_to_stage: dict[int, GraphPipelineStage], - ctx: _PipelineContext, - is_next_stage_on_this_rank: bool, -) -> None: - """Common post-processing after forward pass: cache outputs and propagate. - - This function handles the shared finalization logic for forward operations, - including normalizing outputs, caching for backward, validating outputs, - and propagating activations to the next pipeline stage. - - Args: - stage: The stage that just completed forward computation. - mb_index: The microbatch index that was processed. - output: The output from the forward pass. - saved_intermediates: The intermediates saved for backward pass. - schedule: The pipeline schedule runtime object. - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. - ctx: The pipeline context. - is_next_stage_on_this_rank: True if the next stage exists on this rank. - """ - # See [Note: pipeline model output type] - output_tuple = _normalize_model_output_as_tuple(output) - - # Prepare for final output merge or reduction - # Output chunks is only used for the last stage since we only merge the output of the last stage - if stage.is_last: - stage.output_chunks.append(output) - if ctx.target_mbs is not None: - ctx.schedule_ref._internal_losses.append(output) - - if stage.should_log_fw_outs: - assert stage.numerics_logger is not None - stage.numerics_logger.log_diff( - output, - rank=torch.distributed.get_rank(), - prefix=f"mb{action.microbatch_index} fwd out", - ) - - stage.fwd_cache[mb_index] = (output_tuple, saved_intermediates) # type: ignore[assignment] - - if hasattr(stage, "_validate_fwd_outputs"): - stage._validate_fwd_outputs(output_tuple) - - schedule._maybe_compute_loss(stage, output, ctx.target_mbs, mb_index) - - # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank - # see [Note: V-schedule special case] - if is_next_stage_on_this_rank: - stage_index_to_stage[stage.stage_index + 1].set_local_fwd_input( - output, mb_index - ) - - -def stage_forward( - action: _Action, - ctx: _PipelineContext, - numerics_logs: Optional[list[str]] = None, -) -> None: - ( - schedule, - stage_index_to_stage, - stage, - mb_index, - is_next_stage_on_this_rank, - is_prev_stage_on_this_rank, - ) = _prepare_fwd_common(action, ctx) - - fw_args = _prepare_fwd_args(stage, mb_index, ctx) - - logger.debug( - "GraphPPRunner running action %s", - action, - ) - output, saved_intermediates = _run_fw_module( - stage.graph_callables.fw, - stage.graph_meta, - fw_args, - numerics_logs=numerics_logs, - inductor=stage.inductor, - ) - - _post_fwd_common( - action, - stage, - mb_index, - output, - saved_intermediates, - schedule, - stage_index_to_stage, - ctx, - is_next_stage_on_this_rank, - ) - - -def _prepare_backward_common( - action: _Action, - ctx: _PipelineContext, -) -> tuple[ - _PipelineScheduleRuntime, - dict[int, GraphPipelineStage], - GraphPipelineStage, - int, - bool, - bool, -]: - """Common setup for backward stages: retrieve stage info and handle recv ops. - - This function performs the shared initialization logic for all backward operations, - including waiting for gradient receives from the next pipeline stage and incrementing - the backward counter. - - Args: - action: The backward action to execute, containing the stage index and microbatch index. - ctx: The pipeline context containing the schedule and pipeline state. - - Returns: - A tuple containing: - - schedule: The pipeline schedule runtime object managing the execution. - - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. - - bw_stage: The GraphPipelineStage for which backward is being computed. - - bw_mb_index: The microbatch index being processed. - - is_next_stage_on_this_rank: True if stage_index + 1 exists on this rank (V-schedule). - - is_prev_stage_on_this_rank: True if stage_index - 1 exists on this rank (V-schedule). - """ - schedule, stage_index_to_stage, bw_stage = _get_stage_from_action(action, ctx) - - bw_mb_index = action.microbatch_index - assert bw_mb_index is not None - is_next_stage_on_this_rank = bw_stage.stage_index + 1 in stage_index_to_stage - is_prev_stage_on_this_rank = bw_stage.stage_index - 1 in stage_index_to_stage - - if not bw_stage.is_last and not is_next_stage_on_this_rank: - bwd_recv_ops = schedule.bwd_recv_ops - assert ( - bw_stage.stage_index, - bw_mb_index, - ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" - _wait_batch_p2p(bwd_recv_ops.pop((bw_stage.stage_index, bw_mb_index))) - - schedule.backward_counter[bw_stage.stage_index] += 1 - - return ( - schedule, - stage_index_to_stage, - bw_stage, - bw_mb_index, - is_next_stage_on_this_rank, - is_prev_stage_on_this_rank, - ) - - -def _prepare_backward_args( - bw_stage: GraphPipelineStage, - bw_mb_index: int, -) -> list[Any]: - """Prepare backward kwargs from cached forward outputs.""" - ( - stage_output, - saved_intermediates, - ) = bw_stage.fwd_cache.pop(bw_mb_index) - - if bw_stage.is_last: - assert len(stage_output) == 1 - loss = stage_output[0] - tangents = (torch.ones_like(loss),) - else: - tangents = bw_stage._retrieve_recv_grads(bw_mb_index) - - tensors_for_backward, non_tensors_for_backward = saved_intermediates - - bw_args = [ - *non_tensors_for_backward, - *tensors_for_backward, - *tangents, - ] - del tensors_for_backward, non_tensors_for_backward, tangents, saved_intermediates - return bw_args - - -def _post_backward_common( - bw_stage: GraphPipelineStage, - bw_mb_index: int, - input_grads: list[Any], - stage_index_to_stage: dict[int, GraphPipelineStage], - is_prev_stage_on_this_rank: bool, -) -> None: - """Common post-processing after backward pass: cache input grads and propagate. - - This function handles the shared finalization logic for backward operations, - including caching input gradients and propagating gradients to the previous - pipeline stage. - - Note: Gradient accumulation and scaling are NOT included here as they occur - at different points for full_backward vs split dI/dW: - - full_backward: accumulation and scaling happen immediately after backward - - split dI/dW: accumulation and scaling happen in backward_weight (dW), not backward_input (dI) - - Args: - bw_stage: The stage that just completed backward computation. - bw_mb_index: The microbatch index that was processed. - input_grads: The computed input gradients to cache. - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. - is_prev_stage_on_this_rank: True if the previous stage exists on this rank. - """ - bw_stage.bwd_cache[bw_mb_index] = ( - tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads - ) - - if is_prev_stage_on_this_rank: - stage_index_to_stage[bw_stage.stage_index - 1].set_local_bwd_input( - bw_stage.get_local_bwd_output(bw_mb_index), - bw_mb_index, - ) - - -def stage_full_backward( - action: _Action, - ctx: _PipelineContext, -) -> None: - ( - schedule, - stage_index_to_stage, - bw_stage, - bw_mb_index, - is_next_stage_on_this_rank, - is_prev_stage_on_this_rank, - ) = _prepare_backward_common(action, ctx) - - last_backward = ( - schedule.backward_counter[bw_stage.stage_index] == schedule._n_microbatches - ) - grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 - - if not bw_stage.has_backward: - logger.debug("Returning early for backward stage") - return - - bw_args = _prepare_backward_args(bw_stage, bw_mb_index) - - logger.debug( - "GraphPPRunner running action %s", - action, - ) - input_grads, param_buffer_grads = _run_full_bw_module( - bw_stage.graph_callables.full_bw, - bw_stage.graph_meta, - bw_args, - inductor=bw_stage.inductor, - ) - bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads) - - _post_backward_common( - bw_stage, - bw_mb_index, - input_grads, - stage_index_to_stage, - is_prev_stage_on_this_rank, - ) - - if last_backward: - bw_stage.scale_grads(grad_scale_factor) - - -def stage_backward_input( - action: _Action, - ctx: _PipelineContext, -) -> None: - schedule, stage_index_to_stage, bw_stage = _get_stage_from_action(action, ctx) - - if bw_stage.is_first and bw_stage.graph_callables.bw_dI is None: - # First stage does not have bw_dI graph since usually the inputs of the first stage do not require gradients - # Hence, we do not do a split_dI_dW pass, and call full backward instead during dI action - logger.debug( - "GraphPPRunner skipping action %s", - action, - ) - new_action = _Action( - action.stage_index, - FULL_BACKWARD, - action.microbatch_index, - action.sub_actions, - ) - stage_full_backward(new_action, ctx) - return - - ( - schedule, - stage_index_to_stage, - bw_stage, - bw_mb_index, - is_next_stage_on_this_rank, - is_prev_stage_on_this_rank, - ) = _prepare_backward_common(action, ctx) - - if not bw_stage.has_backward: - logger.debug("Returning early for backward stage") - return - - bw_args = _prepare_backward_args(bw_stage, bw_mb_index) - - logger.debug( - "GraphPPRunner running action %s", - action, - ) - assert bw_stage.graph_callables.bw_dI is not None - input_grads, activations_for_backward = _run_dI_bw_module( - bw_stage.graph_callables.bw_dI, - bw_stage.graph_meta, - bw_args, - inductor=bw_stage.inductor, - ) - - bw_stage.bwd_activation_cache[bw_mb_index] = ( - tuple(activations_for_backward) - if not isinstance(activations_for_backward, tuple) - else activations_for_backward - ) - - _post_backward_common( - bw_stage, - bw_mb_index, - input_grads, - stage_index_to_stage, - is_prev_stage_on_this_rank, - ) - - -def stage_backward_weight( - action: _Action, - ctx: _PipelineContext, -) -> None: - schedule, stage_index_to_stage, bw_stage = _get_stage_from_action(action, ctx) - bw_mb_index = action.microbatch_index - assert bw_mb_index is not None - if bw_stage.is_first and bw_stage.graph_callables.bw_dW is None: - # First stage does not have bw_dW graph since usually the inputs of the first stage do not require gradients - # Hence, we do not do a split_dI_dW pass, and call full backward instead during dI action - # which also performs dW implicitly, hence we skip this step. - logger.debug( - "GraphPPRunner skipping action %s", - action, - ) - return - - last_backward = ( - schedule.backward_counter[bw_stage.stage_index] == schedule._n_microbatches - ) - grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 - - if not bw_stage.has_backward: - logger.debug("Returning early for backward stage") - return - - activations_for_backward = bw_stage.bwd_activation_cache.pop(bw_mb_index) - logger.debug( - "GraphPPRunner running action %s", - action, - ) - bw_args = list(activations_for_backward) - del activations_for_backward - assert bw_stage.graph_callables.bw_dW is not None - param_buffer_grads = _run_dW_bw_module( - bw_stage.graph_callables.bw_dW, - bw_stage.graph_meta, - bw_args, - inductor=bw_stage.inductor, - ) - bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads) - - if last_backward: - bw_stage.scale_grads(grad_scale_factor) - - -def overlap_fw_bw( - multiplexed_graph_callables: dict[tuple[int, int], fx.GraphModule], - action: _Action, - ctx: _PipelineContext, -) -> None: - assert action.sub_actions is not None, "Expected sub actions for overlap callback" - fw_action = action.sub_actions[0] - bw_action = action.sub_actions[1] - - ( - schedule, - stage_index_to_stage, - fw_stage, - fw_mb_index, - fw_is_next_stage_on_this_rank, - fw_is_prev_stage_on_this_rank, - ) = _prepare_fwd_common(fw_action, ctx) - - ( - _, - _, - bw_stage, - bw_mb_index, - bw_is_next_stage_on_this_rank, - bw_is_prev_stage_on_this_rank, - ) = _prepare_backward_common(bw_action, ctx) - - last_backward = ( - schedule.backward_counter[bw_stage.stage_index] == schedule._n_microbatches - ) - grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 - - if not bw_stage.has_backward: - logger.debug("Returning early for backward stage") - return - - fw_args = _prepare_fwd_args(fw_stage, fw_mb_index, ctx) - bw_args = _prepare_backward_args(bw_stage, bw_mb_index) - bw_fw_args = bw_args + fw_args - del bw_args, fw_args - multiplexed_fw_bw_module = multiplexed_graph_callables.get( - (fw_action.stage_index, bw_action.stage_index) - ) - assert ( - multiplexed_fw_bw_module is not None - ), "Expected multiplexed graph callables for overlap callback" - logger.debug( - "GraphPPRunner running action %s", - action, - ) - ( - input_grads, - param_buffer_grads, - output, - saved_intermediates, - ) = _run_multiplexed_fw_bw_module( - multiplexed_fw_bw_module, - fw_stage.graph_meta, - bw_stage.graph_meta, - bw_fw_args, - inductor=fw_stage.inductor, - ) - - bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads) - - _post_fwd_common( - action, - fw_stage, - fw_mb_index, - output, - saved_intermediates, - schedule, - stage_index_to_stage, - ctx, - fw_is_next_stage_on_this_rank, - ) - _post_backward_common( - bw_stage, - bw_mb_index, - input_grads, - stage_index_to_stage, - bw_is_prev_stage_on_this_rank, - ) - - if last_backward: - bw_stage.scale_grads(grad_scale_factor) - - -def stage_unshard( - action: _Action, - ctx: _PipelineContext, -) -> None: - schedule, stage_index_to_stage, stage = _get_stage_from_action(action, ctx) - logger.debug( - "GraphPPRunner running action %s", - action, - ) - if stage.graph_callables.unshard is None: - stage.state["unsharded_params"] = stage.state["sharded_params"] - else: - sharded_params = list(stage.state["sharded_params"]) - unsharded_params = _run_unshard_module( - stage.graph_callables.unshard, - stage.graph_meta, - sharded_params, - inductor=stage.inductor, - ) - stage.state["unsharded_params"] = unsharded_params - - -def stage_reshard( - action: _Action, - ctx: _PipelineContext, -): - schedule, stage_index_to_stage, stage = _get_stage_from_action(action, ctx) - logger.debug( - "GraphPPRunner running action %s", - action, - ) - stage.state["unsharded_params"] = [] - - -def stage_reduce_grad( - action: _Action, - ctx: _PipelineContext, -) -> None: - schedule, stage_index_to_stage, stage = _get_stage_from_action(action, ctx) - logger.debug( - "GraphPPRunner running action %s", - action, - ) - if stage.graph_callables.reduce_grad is None: - stage.state["sharded_grads"] = stage.state["unsharded_grads"] - else: - sharded_grads = _run_reduce_grad_module( - stage.graph_callables.reduce_grad, - stage.graph_meta, - stage.state["unsharded_grads"], - inductor=stage.inductor, - ) - stage.state["sharded_grads"] = sharded_grads - - -class GraphPPRunner: - def __init__( - self, - schedule: _PipelineScheduleRuntime, - inductor: bool = False, - ): - self.schedule = schedule - if not schedule._backward_requires_autograd: - assert all( - isinstance(stage, GraphPipelineStage) - and ( - stage.graph_callables.full_bw is not None - or ( - stage.graph_callables.bw_dI is not None - and stage.graph_callables.bw_dW is not None - ) - ) - for stage in schedule._stages - ) - self.schedule._has_backward = True - for stage in schedule._stages: - assert isinstance(stage, GraphPipelineStage) - stage.inductor = inductor - - def _populate_stage_states(self, stage: GraphPipelineStage) -> None: - sharded_params = [ - v.to_local() if isinstance(v, DTensor) else v - for k, v in dict( - stage.submod.named_parameters(remove_duplicate=False) - ).items() - ] - buffers = [ - v.to_local() if isinstance(v, DTensor) else v - for k, v in dict(stage.submod.named_buffers(remove_duplicate=False)).items() - ] - stage.state["sharded_params"] = sharded_params - stage.state["buffers"] = buffers - stage.state["unsharded_grads"] = [None] * len(sharded_params) - - def _accumulate_stage_sharded_grads(self, stage: GraphPipelineStage) -> None: - grads = stage.state["sharded_grads"] - params = list(stage.submod.parameters()) - for param, grad in zip(params, grads): - if param.requires_grad and grad is not None: - assert isinstance(grad, torch.Tensor) - if isinstance(param, DTensor): - param_spec = param._spec - _grad = DTensor.from_local( - grad, - device_mesh=param_spec.device_mesh, - placements=param_spec.placements, - shape=param_spec.shape, - stride=param_spec.stride, - ) - else: - _grad = grad # type: ignore[assignment] - if param.grad is None: - param.grad = _grad - else: - param.grad += _grad - - def step(self, *args, **kwargs) -> None: - has_targets_and_loss = ( - "losses" in kwargs and "targets" in kwargs if kwargs else False - ) - for stage in self.schedule._stages: - assert isinstance(stage, GraphPipelineStage) - self._populate_stage_states(stage) - - self.schedule.step(*args, **kwargs) - - for stage in self.schedule._stages: - assert isinstance(stage, GraphPipelineStage) - self._accumulate_stage_sharded_grads(stage) - stage.state.clear() - - if has_targets_and_loss: - losses = kwargs["losses"] - assert len(self.schedule._internal_losses) == self.schedule._n_microbatches - losses.extend(self.schedule._internal_losses) - self.schedule._internal_losses.clear() diff --git a/autoparallel/graph_passes/split_di_dw_graph.py b/autoparallel/graph_passes/split_di_dw_graph.py deleted file mode 100644 index 05a19e34..00000000 --- a/autoparallel/graph_passes/split_di_dw_graph.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import copy -import operator - -import torch -import torch.fx as fx -from torch._functorch.partitioners import ( - _extract_fwd_bwd_modules, - _extract_fwd_bwd_outputs, - _extract_graph_with_inputs_outputs, - is_sym_node, -) -from torch.utils._ordered_set import OrderedSet - -from autoparallel.apply_sharding import rename_placeholder_node - -# we are running the default partitioner on the bw graph, which requires AC tags being removed. -# At this stage we have already finished running AC anyway, since we have a bw graph - - -def remove_recompute_tags(bw_gm): - for n in bw_gm.graph.nodes: - if "recompute" in n.meta: - del n.meta["recompute"] - - -# We are using the default partitioner to split our backward into dI and dW subgraphs. -# We want to generate the dI subgraph *first*, because: -# - in pipelining we generally want to schedule dI compute before dW -# - the dI compute will potentially compute more activations that we need to plumb into dW compute -# Today, the default partitioner requires that your split on the first K outputs of your combined graph. -# So here, we reorder the outputs of the backward so grad_inputs are first. - - -def reorder_output_grads(bw_gm, num_weight_gradients): - outputs = bw_gm.graph.find_nodes(op="output") - assert len(outputs) == 1 - output = outputs[0] - assert isinstance(output.args[0], tuple) - grad_weights, grad_inputs = ( - output.args[0][:num_weight_gradients], - output.args[0][num_weight_gradients:], - ) - new_out_tuple = grad_inputs + grad_weights - with bw_gm.graph.inserting_after(output): - # TODO: also set the new node's meta properly - new_out = bw_gm.graph.output(new_out_tuple) - output.replace_all_uses_with(new_out) - bw_gm.graph.erase_node(output) - return len(grad_inputs) - - -# TODO: in theory we can infer num_weight_gradients from the graph metadata directly -def split_di_dw_graph( - bw_gm_old: fx.GraphModule, *, num_weight_gradients: int -) -> tuple[fx.GraphModule, fx.GraphModule, int]: - # we could consider doing this is a non-mutating way - bw_gm = copy.deepcopy(bw_gm_old) - placeholders = bw_gm.graph.find_nodes(op="placeholder") - for p in placeholders: - if p.name.startswith("tangent"): - name_suffix = p.name[8:] - rename_placeholder_node(bw_gm, p, f"not_tngnt{name_suffix}") - - remove_recompute_tags(bw_gm) - num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients) - bw_gm.recompile() - - args = list(bw_gm.graph.find_nodes(op="placeholder")) - - # bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients) - # return bw_inputs, bw_weights, num_input_gradients - - ( - grad_inps, - grad_weights, - grad_inp_descs, - grad_weight_descs, - ) = _extract_fwd_bwd_outputs(bw_gm, num_fwd_outputs=num_input_gradients) - bw_inputs_gm = _extract_graph_with_inputs_outputs( - bw_gm.graph, - args, - grad_inps, - grad_inp_descs, - "forward", - ignore_must_be_in_fw_bw=True, - ) - bw_inputs_gm_node_names = OrderedSet( - node.name for node in bw_inputs_gm.nodes if node.op != "output" - ) - saved_values = [] - saved_sym_nodes = [] - - # TODO: this classification loop is a simplified version of default_partition's - # node classification. It does not handle: get_attr nodes, _assert_scalar/profiler - # ops, MUST_SAVE tags, impure/effectful ops, force_save_collectives, - # force_save_bw_mutation_src, must_recompute skipping, or post-split DCE. - # Ideally we would call default_partition directly instead of reimplementing. - for node in bw_gm.graph.nodes: - if node.name not in bw_inputs_gm_node_names: - # Not handling mutations for now, - # we can try to re-use more of and/or consolidate with default partitioner - continue - if is_sym_node(node): - saved_sym_nodes.append(node) - elif ( - "tensor_meta" not in node.meta - and node.op == "call_function" - and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) - ): - users = node.users - assert all(user.target == operator.getitem for user in users) - saved_values.extend(users) - else: - backward_usages = [ - n for n in node.users if n.name not in bw_inputs_gm_node_names - ] - if "tensor_meta" in node.meta and all( - is_sym_node(n) for n in backward_usages - ): - saved_sym_nodes.extend(backward_usages) - else: - saved_values.append(node) - saved_values = list(dict.fromkeys(saved_values).keys()) - saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - bw_inputs, bw_weights = _extract_fwd_bwd_modules( - bw_gm, - saved_values, - saved_sym_nodes=saved_sym_nodes, - num_fwd_outputs=num_input_gradients, - ignore_must_be_in_fw_bw=True, - omit_aot_autograd_runtime=True, - ) - return bw_inputs, bw_weights, num_input_gradients diff --git a/autoparallel/graph_passes/split_fsdp_collectives.py b/autoparallel/graph_passes/split_fsdp_collectives.py deleted file mode 100644 index d2baa422..00000000 --- a/autoparallel/graph_passes/split_fsdp_collectives.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import dataclasses -from contextlib import contextmanager -from copy import deepcopy -from functools import partial -from typing import Any - -import torch -import torch.fx.node -import torch.utils._pytree as pytree -from torch._functorch._aot_autograd.descriptors import AOTOutput -from torch._functorch.partitioners import _extract_graph_with_inputs_outputs - -from autoparallel.graph_passes.activation_checkpointing import ( - find_last_all_gather_in_chain, - find_last_non_view_node_in_chain, - find_last_user_in_wait_chain, - is_reduce_scatter_tensor, - is_wait_tensor, -) - - -@contextmanager -def exclude_from_fx_side_effectful(exclude_vals: set[Any]): - original_val = torch.fx.node._side_effectful_functions.copy() - try: - torch.fx.node._side_effectful_functions -= exclude_vals - yield - finally: - torch.fx.node._side_effectful_functions.clear() - torch.fx.node._side_effectful_functions.update(original_val) - - -exclude_wait_from_fx_side_effectful = partial( - exclude_from_fx_side_effectful, - { - torch.ops._c10d_functional.wait_tensor, - torch.ops._c10d_functional.wait_tensor.default, - }, -) - - -@dataclasses.dataclass(frozen=True) -class PrefetchOutput(AOTOutput): - pass - - -@dataclasses.dataclass(frozen=True) -class EpilogueInput(AOTOutput): - pass - - -def split_fsdp_prefetch( - gm: torch.fx.GraphModule, - num_params: int, -) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: - g = deepcopy(gm.graph) - all_g_ins = g.find_nodes(op="placeholder") - param_g_ins = all_g_ins[:num_params] - rem_g_ins = all_g_ins[num_params:] - - prefetch_g_outs_map = [] - - for param_g_in in param_g_ins: - # 1. Find last all_gather from each placeholder - last_ag_node = find_last_all_gather_in_chain(param_g_in) - if last_ag_node is None: - prefetch_g_outs_map.append(param_g_in) - else: - # 2. Find last wait_tensor from last all_gather - last_ag_wait_node = next(iter(last_ag_node.users)) - assert is_wait_tensor(last_ag_wait_node) - - # 3. Continue the linear chain from the last wait_tensor - last_wait_chain_user = find_last_user_in_wait_chain(last_ag_wait_node) - - # 4. Get the last non-view node in the wait chain - last_non_view_wait_chain_user = find_last_non_view_node_in_chain( - last_wait_chain_user - ) - - prefetch_g_outs_map.append(last_non_view_wait_chain_user) - - prefetch_g_outs = prefetch_g_outs_map - prefetch_g_outs_descs: list[AOTOutput] = [ - PrefetchOutput() for _ in range(len(prefetch_g_outs)) - ] - g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) - g_outs_descs = pytree.arg_tree_leaves( - next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) - ) - with exclude_wait_from_fx_side_effectful(): - prefetch_g = _extract_graph_with_inputs_outputs( - g, - param_g_ins, - prefetch_g_outs, - prefetch_g_outs_descs, - ignore_must_be_in_fw_bw=True, - ) - - main_g = _extract_graph_with_inputs_outputs( - g, - prefetch_g_outs + rem_g_ins, - g_outs, - g_outs_descs, - ignore_must_be_in_fw_bw=True, - ) - prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g) - main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) - return prefetch_gm, main_gm - - -def split_fsdp_reduce_scatters_epilogue( - gm: torch.fx.GraphModule, - num_grads: int, -) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: - g = deepcopy(gm.graph) - g_ins = g.find_nodes(op="placeholder") - g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) - grad_outs = g_outs[:num_grads] - rem_g_outs = g_outs[num_grads:] - out_descs = pytree.arg_tree_leaves( - next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs)) - ) - grad_outs_descs = out_descs[:num_grads] - rem_g_outs_descs = out_descs[num_grads:] - - grad_outs_map = [] - for grad_out in grad_outs: - n = grad_out - earliest_rs = None - while n is not None: - if len(n.all_input_nodes) != 1: - break - n_in = n.all_input_nodes[0] - if len(n_in.users) > 1: - break - prev_n = n - n = n_in - # Maybe we also need to track all_reduce? - if is_reduce_scatter_tensor(prev_n): - # In AP for mesh dim > 1 - # The reduction of gradients happen in multiple steps - earliest_rs = n - if earliest_rs is not None: - grad_outs_map.append(earliest_rs) - else: - grad_outs_map.append(grad_out) - - epi_g_ins = grad_outs_map - epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] - - with exclude_wait_from_fx_side_effectful(): - main_g = _extract_graph_with_inputs_outputs( - g, - g_ins, - epi_g_ins + rem_g_outs, - epi_g_ins_descs + rem_g_outs_descs, - ignore_must_be_in_fw_bw=True, - ) - epi_g = _extract_graph_with_inputs_outputs( - g, - epi_g_ins, - grad_outs, - grad_outs_descs, - ignore_must_be_in_fw_bw=True, - ) - epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g) - main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) - return main_gm, epi_gm diff --git a/autoparallel/shardings/placement_options.py b/autoparallel/shardings/placement_options.py index 7fc906c2..cb4ca2d5 100644 --- a/autoparallel/shardings/placement_options.py +++ b/autoparallel/shardings/placement_options.py @@ -789,63 +789,6 @@ def log_diff(self, t, rank=0, prefix="?"): with open(path, "a") as f: f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n") - def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_log): - path = self.dir / "pp_weights.log" - - torch.distributed.barrier() - # First print the params of every stage - for i in range(num_world_stages): - if should_log and i in stage_mods: - param_logs = [] - real_params = dict(stage_mods[i].named_parameters()) - for name, _ in orig_mod.named_parameters(): - if name not in real_params: - continue - param = real_params[name] - param_logs.append(f"{name=} hash={hash_tensor(param)}") - with open(path, "a") as f: - f.write("\n".join(param_logs) + "\n") - torch.distributed.barrier() - - # Then print the buffers of every stage - for i in range(num_world_stages): - if should_log and i in stage_mods: - buffer_logs = [] - real_buffers = dict(stage_mods[i].named_buffers()) - for name, _ in orig_mod.named_buffers(): - if name not in real_buffers: - continue - buffer = real_buffers[name] - buffer_logs.append(f"{name=} hash={hash_tensor(buffer)}") - with open(path, "a") as f: - f.write("\n".join(buffer_logs) + "\n") - torch.distributed.barrier() - - if self.rank == 0: - logger.info(f"Weight hashes written to {path}") - - def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log): - path = self.dir / "diff.log" - - for i in range(num_world_stages): - if should_log and i in stage_mods: - grad_logs = [] - real_params = dict(stage_mods[i].named_parameters()) - for name, _ in orig_mod.named_parameters(): - if name not in real_params: - continue - grad = real_params[name].grad - if grad is None: - grad_logs.append(f"[grad {name}] None") - else: - grad = grad.to_local() - grad_logs.append( - f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}" - ) - with open(path, "a") as f: - f.write("\n".join(grad_logs) + "\n") - torch.distributed.barrier() - def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger): from torch._inductor.fx_passes.post_grad import view_to_reshape diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py deleted file mode 100644 index cc391836..00000000 --- a/examples/example_ds3_pp.py +++ /dev/null @@ -1,742 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import os -from contextlib import nullcontext -from functools import partial -from typing import Callable, Optional - -import torch -import torch.distributed._tools.fake_collectives -import torch.nn as nn -from torch._logging import trace_structured -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed.pipelining.schedules import ( - BACKWARD_INPUT, - BACKWARD_WEIGHT, - FORWARD, - FULL_BACKWARD, - OVERLAP_F_B, - REDUCE_GRAD, - RESHARD, - UNSHARD, - PipelineScheduleMulti, - _PipelineSchedule, - _PipelineScheduleRuntime, - get_schedule_class, -) -from torch.distributed.pipelining.stage import PipelineStage -from torch.distributed.tensor.placement_types import Replicate, Shard -from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.testing._internal.distributed.fake_pg import FakeStore - -import autoparallel._testing.models.dsv3 as dsv3_module -from autoparallel._testing.models.dsv3 import ( - DeepSeekV3Model, - DeepSeekV3ModelArgs, - DeepSeekV3Stage0, - DeepSeekV3StageI, - DeepSeekV3StageN, - MoEArgs, - dsv3_loss_fn, -) -from autoparallel.api import move_to_fake -from autoparallel.api_pp import AutoParallelPP, make_pp_module -from autoparallel.graph_passes.graph_pp_runner import ( - GraphCallables, - GraphMeta, - GraphPipelineStage, - GraphPPRunner, - get_multiplexed_graph_callables, - overlap_fw_bw, - stage_backward_input, - stage_backward_weight, - stage_forward, - stage_full_backward, - stage_reduce_grad, - stage_reshard, - stage_unshard, -) -from autoparallel.shardings.placement_options import NumericsLogger - -# Configure logging to show DEBUG messages -logging.basicConfig( - level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -def assign_logical_stages_to_pp_rank( - schedule_name: str, pp_degree: int, stages_per_rank: int -) -> dict[int, list[int]]: - style = "v" if schedule_name in ("ZBVZeroBubble", "DualPipeV") else "loop" - if style == "loop": - pp_rank_to_stage_indices = { - pp_rank: [pp_rank + s * pp_degree for s in range(stages_per_rank)] - for pp_rank in range(pp_degree) - } - elif style == "v": - total_pp_stages = pp_degree * stages_per_rank - pp_rank_to_stage_indices = { - pp_rank: [pp_rank, total_pp_stages - 1 - pp_rank] - for pp_rank in range(pp_degree) - } - return pp_rank_to_stage_indices - - -def build_pipeline_schedule( - stages: list[PipelineStage], - loss_fn: Callable, - pipeline_parallel_schedule: str, - microbatch_size: int, - local_batch_size: int, - pipeline_parallel_degree: int, - backward_requires_autograd: bool = False, - scale_grads: bool = True, -) -> _PipelineSchedule: - """Builds a pipeline schedule for the given configuration and stages.""" - schedule_class = get_schedule_class(pipeline_parallel_schedule) - - looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) - assert looped_schedule, "Only looped schedules are supported" - # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training - if local_batch_size % microbatch_size != 0: - raise ValueError( - f"Batch size {local_batch_size} must be divisible by {microbatch_size=}. " - ) - n_microbatches = local_batch_size // microbatch_size - # We expect that the number of local stages (`len(stages)`) is the same across all pp ranks - num_total_stages = pipeline_parallel_degree * len(stages) - if n_microbatches < num_total_stages: - logger.warning( - f"Number of microbatches ({n_microbatches}) is less than the total number " - f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." - ) - - schedule = schedule_class( - stages if looped_schedule else stages[0], - n_microbatches=n_microbatches, - loss_fn=loss_fn, - backward_requires_autograd=backward_requires_autograd, - scale_grads=scale_grads, - ) - logger.info( - f"Using pipeline schedule {pipeline_parallel_schedule} " - f"with {n_microbatches} microbatches and {num_total_stages} stages." - ) - return schedule - - -def run_test( - fake_evaluate: bool, - use_loss_fn: bool, - schedule_name: str, - rng_seed: Optional[int], - logs_dir: str, - use_cache: bool, - use_inductor: bool = False, -): - if not fake_evaluate: - pp_degree = 2 - dp_mod_ep_degree = 2 - ep_degree = 2 - else: - pp_degree = 4 - dp_mod_ep_degree = 4 - ep_degree = 64 - - dp_degree = dp_mod_ep_degree * ep_degree - world_size = pp_degree * dp_mod_ep_degree * ep_degree - - # Initialize process group based on evaluation mode - if fake_evaluate: - assert ( - "WORLD_SIZE" in os.environ - ), "run with torchrun --standalone --nproc-per-node 4" - assert ( - int(os.getenv("WORLD_SIZE")) == pp_degree - ), "world_size must be 4, for fake evaluation" - rank = int(os.getenv("RANK")) - device = torch.device(f"cuda:{rank}") - torch.cuda.set_device(device) - fake_store = FakeStore() - torch.distributed.init_process_group( - "fake", - store=fake_store, - rank=rank * dp_degree, # global rank is pp_rank * spmd_size - world_size=world_size, - ) - pp_rank = rank - else: - assert ( - "WORLD_SIZE" in os.environ - ), "run with torchrun --standalone --nproc-per-node 8" - assert ( - int(os.getenv("WORLD_SIZE")) == world_size - ), "Need at least 8 GPUs for real evaluation" - local_rank = int(os.getenv("LOCAL_RANK")) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - torch.distributed.init_process_group(backend="nccl") - - # Initialize device mesh (common for both modes) - world_mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - (pp_degree, dp_mod_ep_degree, ep_degree), - mesh_dim_names=( - "pp", - "dp_mod_ep", - "ep", - ), - ) - - # Set pp_rank based on evaluation mode - if not fake_evaluate: - pp_rank = world_mesh["pp"].get_local_rank() - - stages_per_rank = 2 - total_pp_stages = pp_degree * stages_per_rank - - # This is the spmd mesh to be used for tracing - mesh = world_mesh[("dp_mod_ep", "ep")] - - # Batch size that will be supplied to the schedule and will be broken down into microbatches - local_batch_size = 32 - # global_batch_size = local_batch_size * dp_degree - n_microbatches = 16 - # Batch size with which the spmd graphs will actually be executed - microbatch_size = local_batch_size // n_microbatches - assert ( - microbatch_size >= 1 - ), f"invalid config {local_batch_size=}, {n_microbatches=}" - # Batch size to be used for spmd tracing - spmd_batch_size = microbatch_size * dp_degree - - seq_len = 1024 - - if fake_evaluate: - config = DeepSeekV3ModelArgs( - vocab_size=102400, - max_seq_len=seq_len, - dim=2048, - inter_dim=10944, - moe_inter_dim=1408, - n_layers=8, # 27, - n_dense_layers=0, # 1, - n_heads=16, - moe_args=MoEArgs( - num_experts=64, - num_shared_experts=2, - top_k=6, - score_func="softmax", - route_norm=False, - score_before_experts=False, - mesh=mesh, - ), - q_lora_rank=0, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - mscale=0.70, - use_flex_attn=False, - attn_mask_type="causal", - ) - else: - config = DeepSeekV3ModelArgs( - vocab_size=2048, - max_seq_len=seq_len, - dim=256, - inter_dim=1024, - moe_inter_dim=256, - n_layers=4, - n_dense_layers=0, # 1, - n_heads=16, - moe_args=MoEArgs( - num_experts=4, - num_shared_experts=2, - top_k=2, - score_func="softmax", - route_norm=False, - score_before_experts=False, - mesh=mesh, - ), - q_lora_rank=0, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - mscale=0.70, - ) - - with torch.device("meta"): - model = DeepSeekV3Model(config).bfloat16() - embed, layers, norm, output = list(model.children()) - items = list(layers.items()) - assert len(items) == config.n_layers - n_layers_per_rank = len(items) // total_pp_stages - layers = [ - nn.ModuleDict(items[i : i + n_layers_per_rank]) - for i in range(0, len(items), n_layers_per_rank) - ] - assert len(layers) == total_pp_stages - for lst in layers: - assert len(lst) * len(layers) == config.n_layers - - def make_input_fn( - batch_size: int, - inp_type: str, - device: torch.device, - ): - """ - Factory to create input/output generator functions for pipeline stages. - - Args: - batch_size: Batch size (spmd_batch_size, local_batch_size, or microbatch_size) - inp_type: One of "tokens", "embeddings", or "logits" - device: Device to create tensors on (cuda device or "meta") - """ - - def input_fn() -> torch.Tensor: - if inp_type == "tokens": - return torch.randint( - 0, - config.vocab_size, - (batch_size, seq_len), - device=device, - ) - elif inp_type == "embeddings": - return torch.randn( - (batch_size, seq_len, config.dim), - device=device, - dtype=torch.bfloat16, - requires_grad=True, - ) - elif inp_type == "logits": - return torch.randn( - (batch_size, seq_len, config.vocab_size), - device=device, - dtype=torch.bfloat16, - requires_grad=True, - ) - elif inp_type == "loss": - return torch.scalar_tensor( - 1.0, - dtype=torch.float32, - device=device, - requires_grad=True, - ) - else: - raise ValueError(f"Unknown input type: {inp_type}") - - return input_fn - - # Target generators (if needed for loss computation) - tracing_target_fn = make_input_fn(spmd_batch_size, "tokens", device) - runtime_target_fn = make_input_fn(local_batch_size, "tokens", device) - - # Tracing input functions - tracing_input_fn_fist_stage = make_input_fn(spmd_batch_size, "tokens", device) - tracing_input_fn_intermediate_stage = make_input_fn( - spmd_batch_size, "embeddings", device - ) - - def last_stage_inp_with_loss_fn(): - return ( - tracing_input_fn_intermediate_stage(), - tracing_target_fn(), - ) - - tracing_input_fn_last_stage = ( - last_stage_inp_with_loss_fn - if use_loss_fn - else tracing_input_fn_intermediate_stage - ) - - # Runtime input function - runtime_input_fn_first_stage = make_input_fn(local_batch_size, "tokens", device) - - # Shape inference functions - meta_device = torch.device("meta") - shape_inference_input_fn_first_stage = make_input_fn( - microbatch_size, "tokens", meta_device - ) - shape_inference_fn_intermediate_stage = make_input_fn( - microbatch_size, "embeddings", meta_device - ) - shape_inference_output_fn_last_stage = ( - make_input_fn(0, "loss", meta_device) - if use_loss_fn - else make_input_fn(microbatch_size, "logits", meta_device) - ) - - # Step 1. Construct the logical pipeline stages - with torch.device("meta"): - virtual_pp_stages = [DeepSeekV3Stage0(embed, layers[0], config)] - for i in range(1, total_pp_stages - 1): - virtual_pp_stages.append(DeepSeekV3StageI(layers[i], config)) - last_stage = DeepSeekV3StageN(layers[total_pp_stages - 1], norm, output, config) - if use_loss_fn: - - class ModelWithLoss(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, h, labels): - output = self.model(h) - return dsv3_loss_fn(output, labels) - - def init_weights(self, *args, **kwargs): - return self.model.init_weights(*args, **kwargs) - - last_stage = ModelWithLoss(last_stage) - virtual_pp_stages.append(last_stage) - # Step 2. Assign each logical stage(s) to pp ranks for the given schedule - pp_rank_to_stage_indices = assign_logical_stages_to_pp_rank( - schedule_name, pp_degree, stages_per_rank - ) - print(pp_rank_to_stage_indices) - assert len(pp_rank_to_stage_indices) == pp_degree - for stages in pp_rank_to_stage_indices.values(): - assert len(stages) * pp_degree == len(virtual_pp_stages) - stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank] - should_log_weights = should_log_fw_outs = False - if rng_seed: - # Compute the ranks to log from - # 1. for fw_outs, log from coord [pp_rank_containing_last_stage, 0, 0] - last_stage_idx = total_pp_stages - 1 - pp_rank_containing_last_stage = None - for pp_rank_, stage_indices in pp_rank_to_stage_indices.items(): - if last_stage_idx in stage_indices: - assert pp_rank_containing_last_stage is None - pp_rank_containing_last_stage = pp_rank_ - - log_fw_out_rank_coordinate = [] - for mesh_dim_name in world_mesh.mesh_dim_names: - if mesh_dim_name == "pp": - log_fw_out_rank_coordinate.append(pp_rank_containing_last_stage) - else: - log_fw_out_rank_coordinate.append(0) - should_log_fw_outs = world_mesh.get_coordinate() == log_fw_out_rank_coordinate - - # 2. for weights, log from coords [:, 0, 0] - pp_world_size = world_mesh.shape[world_mesh._get_mesh_dim_by_name("pp")] - log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)] - should_log_weights = ( - tuple(world_mesh.get_coordinate()) in log_weights_rank_coordinates - ) - - stage_mods: dict[int, torch.nn.Module] = {} - stage_graphs: dict[int, GraphCallables] = {} - stage_graph_metas: dict[int, GraphMeta] = {} - # Step 3. Apply AutoParallel to each logical stage assigned to this pp rank - root_cache = "tmp" - os.makedirs(root_cache, exist_ok=True) - - for stage_idx in stage_indices_current_pp_rank: - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": f"begin_tracing_stage_{stage_idx}", - "encoding": "string", - }, - payload_fn=lambda: "placeholder text", - ) - stage_mod = virtual_pp_stages[stage_idx] - eval_mode = "fake" if fake_evaluate else "real" - stage_file = os.path.join(root_cache, f"stage_{eval_mode}_{stage_idx}.pth") - if os.path.exists(stage_file) and use_cache: - cache = torch.load(stage_file, weights_only=False) - graph_callables = cache["graph_callables"] - graph_meta = cache["graph_meta"] - cache["sharded_param_dict"] = { - k: nn.Parameter(v.detach()) - for k, v in cache["sharded_param_dict"].items() - } - fake_mode = FakeTensorMode() - stage_mod = move_to_fake(stage_mod, fake_mode, device) - pp_mod = make_pp_module( - cache["sharded_param_dict"], - cache["sharded_buffer_dict"], - stage_mod, - ) - else: - if stage_idx == 0: - input_fn = tracing_input_fn_fist_stage - elif stage_idx == total_pp_stages - 1: - - input_fn = tracing_input_fn_last_stage - - else: - input_fn = tracing_input_fn_intermediate_stage - with AutoParallelPP( - stage_mod, - input_fn, - mesh, - dynamic=True, - reshard_after_forward=False, - ) as autop: - autop.add_parameter_memory_constraint(low=None, high=None) - - # x_sharding = (Shard(0), Replicate()) - x_sharding = (Shard(0), Shard(0)) - if use_loss_fn and stage_idx == total_pp_stages - 1: - autop.add_input_constraints([x_sharding, x_sharding]) - autop.add_output_constraints([(Replicate(), Replicate())]) - else: - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) - - sharding_placement = autop.optimize_placement(verbose=False) - graph_passes = ["split_fsdp_collectives"] - if stage_idx > 0: - # First stage does not produce gradients wrt to input, - # hence we do not do apply the split_dI_dW pass - graph_passes.extend(["split_dI_dW"]) - cache = autop.apply_placement_pp( - sharding_placement=sharding_placement, graph_passes=graph_passes - ) - graph_callables = cache["graph_callables"] - graph_meta = cache["graph_meta"] - pp_mod = autop.parallel_model - if use_cache: - torch.save(cache, stage_file) - - pp_mod.to_empty(device=device) - # run weight init on our sharded DTensor params - pp_mod.init_weights(buffer_device=device, seed=rng_seed) - - # Store each stage's information in stage_mods, stage_graphs, and stage_graph_metas - stage_mods[stage_idx] = pp_mod - stage_graphs[stage_idx] = GraphCallables( - fw=graph_callables["fw"], - full_bw=graph_callables["full_bw"], - bw_dI=graph_callables["bw_dI"], - bw_dW=graph_callables["bw_dW"], - unshard=graph_callables["unshard"], - reduce_grad=graph_callables["reduce_grad"], - ) - stage_graph_metas[stage_idx] = GraphMeta( - num_mutate_inputs=graph_meta["num_mutate_inputs"], - num_user_outputs=graph_meta["num_user_outputs"], - num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], - num_params=graph_meta["num_params"], - num_buffers=graph_meta["num_buffers"], - num_input_grads=graph_meta["num_input_grads"], - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": f"end_tracing_stage_{stage_idx}", - "encoding": "string", - }, - payload_fn=lambda: "placeholder text", - ) - - # Two stages per pp rank - assert ( - len(stage_indices_current_pp_rank) - == len(stage_mods) - == len(stage_graphs) - == len(stage_graph_metas) - ) - - world_size = torch.distributed.get_world_size() - num_world_stages = world_size * len(stage_mods) - - numerics_logger = None - if rng_seed is not None: - numerics_logger = NumericsLogger(logs_dir) - numerics_logger.log_pp_model_weights( - model, stage_mods, num_world_stages, should_log=should_log_weights - ) - torch.manual_seed(rng_seed) - - stages = [] - # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata - for pp_stage_idx, pp_stage_mod in stage_mods.items(): - stage = GraphPipelineStage( - pp_stage_mod, - stage_graphs[pp_stage_idx], - stage_graph_metas[pp_stage_idx], - stage_index=pp_stage_idx, - num_stages=len(virtual_pp_stages), - device=device, - input_args=( - shape_inference_input_fn_first_stage() - if pp_stage_idx == 0 - else shape_inference_fn_intermediate_stage() - ), - output_args=( - shape_inference_output_fn_last_stage() - if pp_stage_idx == (total_pp_stages - 1) - else shape_inference_fn_intermediate_stage() - ), - group=world_mesh.get_group("pp"), - numerics_logger=numerics_logger, - should_log_fw_outs=should_log_fw_outs, - ) - stages.append(stage) - - # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank - schedule = build_pipeline_schedule( - stages=stages, - loss_fn=None, - pipeline_parallel_schedule=schedule_name, - microbatch_size=microbatch_size, - local_batch_size=local_batch_size, - pipeline_parallel_degree=pp_degree, - backward_requires_autograd=False, - scale_grads=rng_seed is None, # In determinism mode, don't scale grads - ) - assert isinstance(schedule, _PipelineScheduleRuntime) - - # Step 6. Override the pipeline runner's action implementations - schedule.register_custom_function(FORWARD, stage_forward) - schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) - schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad) - schedule.register_custom_function(RESHARD, stage_reshard) - schedule.register_custom_function(UNSHARD, stage_unshard) - schedule.register_custom_function(BACKWARD_INPUT, stage_backward_input) - schedule.register_custom_function(BACKWARD_WEIGHT, stage_backward_weight) - if schedule_name == "DualPipeV": - from autoparallel.graph_passes.graph_multiplex import multiplex_fw_bw_graph - - multiplexed_graph_callables = get_multiplexed_graph_callables( - stage_graphs, - partial(multiplex_fw_bw_graph, overlap_with_annotations=True), - ) - schedule.register_custom_function( - OVERLAP_F_B, partial(overlap_fw_bw, multiplexed_graph_callables) - ) - - # Step 7. Register the schedule with the graph runner - graph_pp_runner = GraphPPRunner(schedule, inductor=use_inductor) - - # Step 8. Run the whole pipeline once using the graph runner - has_last_stage = (total_pp_stages - 1) in stage_mods - execution_fake_mode = ( - FakeTensorMode( - allow_non_fake_inputs=True, - shape_env=ShapeEnv(), - ) - if fake_evaluate - else nullcontext() - ) - - with execution_fake_mode: - with torch.no_grad(): - target, losses = ( - (runtime_target_fn(), []) - if has_last_stage and use_loss_fn - else (None, None) - ) - if pp_rank == 0: - x = runtime_input_fn_first_stage() - if numerics_logger is not None: - numerics_logger.log_diff( - x.to(torch.float32), prefix="full batch input" - ) - graph_pp_runner.step( - x, target=target, losses=losses, return_outputs=False - ) - else: - graph_pp_runner.step(target=target, losses=losses, return_outputs=False) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "pipeline_step_losses", - "encoding": "string", - }, - payload_fn=lambda: f"losses: {losses}", - ) - if numerics_logger is not None: - numerics_logger.log_pp_grads( - model, stage_mods, num_world_stages, should_log=should_log_weights - ) - - print("All good!") - - if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.cuda.synchronize() - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="Run DeepSeek V3 pipeline parallel example" - ) - parser.add_argument( - "--fake-evaluate", - action="store_true", - default=False, - help="Use fake evaluation mode with FakeTensorMode (default: False)", - ) - parser.add_argument( - "--use-loss-fn", - action="store_true", - default=False, - help="Trace loss_fn as part of model forward graph for the last stage (default: False)", - ) - parser.add_argument( - "--rng-seed", - type=int, - default=None, - help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).", - ) - parser.add_argument( - "--logs-dir", - type=str, - default="out/", - help="Directory to store logs (default: ./out/).", - ) - parser.add_argument( - "--schedule-name", - type=str, - default="DualPipeV", - choices=["Interleaved1F1B", "ZBVZeroBubble", "DualPipeV"], - help="Schedule to use for PP", - ) - parser.add_argument( - "--use-cache", - action="store_true", - default=False, - help="Use cached graph files if available (default: False)", - ) - parser.add_argument( - "--inductor", - action="store_true", - default=False, - help="Compile subgraphs with Inductor (also forces balanced MoE routing)", - ) - args = parser.parse_args() - - if args.use_cache and not args.fake_evaluate: - parser.error("--use-cache can only be used with --fake-evaluate") - - if args.rng_seed is not None: - torch.use_deterministic_algorithms(True) - torch.manual_seed(args.rng_seed) - - if args.inductor: - # The DSv3 MoE implementation uses .tolist() and data-dependent grouped_mm - # offsets, which Inductor cannot compile. Force balanced routing to make - # all token counts static. - dsv3_module.FORCE_BALANCED_ROUTING = True - - run_test( - fake_evaluate=args.fake_evaluate, - use_loss_fn=args.use_loss_fn, - schedule_name=args.schedule_name, - rng_seed=args.rng_seed, - logs_dir=args.logs_dir, - use_cache=args.use_cache, - use_inductor=args.inductor, - ) diff --git a/examples/example_pp_graph_passes.py b/examples/example_pp_graph_passes.py deleted file mode 100644 index 9c71a062..00000000 --- a/examples/example_pp_graph_passes.py +++ /dev/null @@ -1,424 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -from contextlib import nullcontext -from typing import Callable, Union - -import torch -from torch._subclasses.fake_tensor import ( - FakeTensor, - FakeTensorMode, - unset_fake_temporarily, -) -from torch.distributed.tensor import DeviceMesh, DTensor -from torch.distributed.tensor.placement_types import Replicate, Shard -from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.testing._internal.distributed.fake_pg import FakeStore - -from autoparallel import AutoParallelPP -from autoparallel._testing.models.dsv3 import ( - DeepSeekV3Model, - DeepSeekV3ModelArgs, - MoEArgs, - dsv3_loss_fn, -) -from autoparallel.graph_passes.graph_multiplex import multiplex_fw_bw_graph -from autoparallel.graph_passes.graph_pp_runner import ( - GraphCallables, - GraphMeta, - _run_dI_bw_module, - _run_dW_bw_module, - _run_full_bw_module, - _run_fw_module, - _run_multiplexed_fw_bw_module, - _run_reduce_grad_module, - _run_unshard_module, -) - - -def compare_tuples(tuple1: Union[list, tuple], tuple2: Union[list, tuple]) -> bool: - """Compare two tuples element-by-element with specialized comparison logic. - - For each element pair: - - If both are FakeTensor: compare shape, stride, and dtype - - If both are Tensor: use torch.allclose for numerical comparison - - Otherwise: skip the comparison (always matches) - - Args: - tuple1: First tuple to compare. - tuple2: Second tuple to compare. - - Returns: - True if all comparable elements match according to the rules above, False otherwise. - """ - if len(tuple1) != len(tuple2): - return False - with unset_fake_temporarily(): - for elem1, elem2 in zip(tuple1, tuple2): - # Check if both are FakeTensor - if isinstance(elem1, FakeTensor): - if not isinstance(elem2, FakeTensor): - return False - if elem1.dtype != elem2.dtype: - return False - # Try to compare strides or shape, but skip if it would trigger data-dependent guards - try: - if elem1.shape != elem2.shape: - return False - if elem1.stride() != elem2.stride(): - return False - except ( - torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode - ): - # Skip stride or shape comparison for symbolic shapes - pass - # Check if both are regular Tensor (but not FakeTensor) - elif isinstance(elem1, torch.Tensor) and not isinstance(elem1, FakeTensor): - if not ( - isinstance(elem2, torch.Tensor) - and not isinstance(elem2, FakeTensor) - ): - return False - # Use torch.allclose for numerical comparison - if not torch.allclose(elem1, elem2): - return False - # Otherwise, skip the comparison (neither or mismatched types) - - return True - - -def _extract_graph_modules_and_meta( - res: dict, -) -> tuple[GraphCallables, GraphMeta]: - graph_callables = res["graph_callables"] - graph_modules = GraphCallables( - fw=graph_callables["fw"], - full_bw=graph_callables["full_bw"], - bw_dI=graph_callables["bw_dI"], - bw_dW=graph_callables["bw_dW"], - unshard=graph_callables["unshard"], - reduce_grad=graph_callables["reduce_grad"], - ) - graph_meta = res["graph_meta"] - graph_meta = GraphMeta( - num_mutate_inputs=graph_meta["num_mutate_inputs"], - num_user_outputs=graph_meta["num_user_outputs"], - num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], - num_params=graph_meta["num_params"], - num_buffers=graph_meta["num_buffers"], - num_input_grads=graph_meta["num_input_grads"], - ) - return graph_modules, graph_meta - - -def _get_fw_inputs( - pp_mod: torch.nn.Module, eval_input_fn: Callable -) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: - x: list[torch.Tensor] = list(eval_input_fn()) - sharded_params = [ - v.to_local() if isinstance(v, DTensor) else v - for k, v in dict(pp_mod.named_parameters(remove_duplicate=False)).items() - ] - buffers = [ - v.to_local() if isinstance(v, DTensor) else v - for k, v in dict(pp_mod.named_buffers(remove_duplicate=False)).items() - ] - return (sharded_params, buffers, x) - - -# Symbolically evaluate in case you want to test running a graph bigger than your gpu - - -def _run_graph_test( - pp_mod: torch.nn.Module, - graph_modules: GraphCallables, - graph_meta: GraphMeta, - sharded_params: list[torch.Tensor], - buffers: list[torch.Tensor], - x: list[torch.Tensor], - fake_evaluate: bool, - use_fsdp_collectives: bool, - use_split_dI_dW: bool, - use_multiplexed_graph: bool, -) -> None: - """Execute forward and backward passes with specified graph options.""" - if use_multiplexed_graph: - multiplexed_fw_bw_module = multiplex_fw_bw_graph( - graph_modules.fw, graph_modules.full_bw, overlap_with_annotations=True - ) - with ( - FakeTensorMode( - allow_non_fake_inputs=True, - shape_env=ShapeEnv(), - ) - if fake_evaluate - else nullcontext() - ): - with torch.no_grad(): - # Forward pass setup - if use_fsdp_collectives: - unshard_args = list(sharded_params) - assert graph_modules.unshard is not None - unsharded_params = _run_unshard_module( - graph_modules.unshard, graph_meta, unshard_args - ) - fw_args = [*unsharded_params, *buffers, *x] - else: - fw_args = [*sharded_params, *buffers, *x] - if use_multiplexed_graph: - m_fw_args = list(fw_args) - # Forward pass - loss_or_output, saved_intermediates = _run_fw_module( - graph_modules.fw, graph_meta, fw_args - ) - tangents = [torch.ones_like(loss_or_output)] - tensors_for_backward, non_tensors_for_backward = saved_intermediates - - # Backward pass setup - bw_args = [ - *non_tensors_for_backward, - *tensors_for_backward, - *tangents, - ] - if use_multiplexed_graph: - m_bw_args = list(bw_args) - joint_args = m_bw_args + m_fw_args - del m_bw_args, m_fw_args - ( - m_input_grads, - m_param_buffer_grads, - m_loss_or_output, - m_saved_intermediates, - ) = _run_multiplexed_fw_bw_module( - multiplexed_fw_bw_module, graph_meta, graph_meta, joint_args - ) - ( - m_tensors_for_backward, - m_non_tensors_for_backward, - ) = m_saved_intermediates - assert compare_tuples((m_loss_or_output,), (loss_or_output,)) - assert compare_tuples(m_tensors_for_backward, tensors_for_backward) - assert compare_tuples( - m_non_tensors_for_backward, non_tensors_for_backward - ) - del ( - m_non_tensors_for_backward, - m_tensors_for_backward, - m_loss_or_output, - ) - del ( - tensors_for_backward, - non_tensors_for_backward, - tangents, - saved_intermediates, - ) - - # Backward pass - if use_split_dI_dW: - assert graph_modules.bw_dI is not None - input_grads, activations_for_backward = _run_dI_bw_module( - graph_modules.bw_dI, graph_meta, bw_args - ) - dw_args = list(activations_for_backward) - del activations_for_backward - assert graph_modules.bw_dW is not None - param_buffer_grads = _run_dW_bw_module( - graph_modules.bw_dW, graph_meta, dw_args - ) - else: - input_grads, param_buffer_grads = _run_full_bw_module( - graph_modules.full_bw, graph_meta, bw_args - ) - if use_multiplexed_graph: - assert compare_tuples(m_param_buffer_grads, param_buffer_grads) - assert compare_tuples(m_input_grads, input_grads) - del m_param_buffer_grads, m_input_grads - assert len(param_buffer_grads) == (len(sharded_params) + len(buffers)) - unsharded_grads = list(param_buffer_grads[: len(sharded_params)]) - del param_buffer_grads, input_grads - # Gradient reduction (if using FSDP collectives) - if use_fsdp_collectives: - assert graph_modules.reduce_grad is not None - sharded_grads = _run_reduce_grad_module( - graph_modules.reduce_grad, graph_meta, unsharded_grads - ) - else: - sharded_grads = unsharded_grads - assert len(sharded_grads) == len(sharded_params) - - -def run_all_graph_pass_tests( - model: torch.nn.Module, - mesh: DeviceMesh, - tracing_input_fn: Callable, - eval_input_fn: Callable, - fake_evaluate: bool = True, - use_loss_fn: bool = True, -): - test_configs: list[tuple[str, list[str], bool, bool]] = [ - ("graph_partition", [], False, False), - ("split_fsdp_collectives", ["split_fsdp_collectives"], True, False), - ("split_dI_dW", ["split_dI_dW"], False, True), - ("combined", ["split_fsdp_collectives", "split_dI_dW"], True, True), - ] - - with AutoParallelPP( - model, - tracing_input_fn, - mesh, - dynamic=True, - reshard_after_forward=False, - ) as autop: - autop.add_parameter_memory_constraint(low=None, high=None) - - x_sharding = (Shard(0), Shard(0)) - if use_loss_fn: - autop.add_input_constraints([x_sharding, x_sharding]) - autop.add_output_constraints([(Replicate(), Replicate())]) - else: - autop.add_input_constraints([x_sharding]) - autop.add_output_constraints([x_sharding]) - - sharding_placement = autop.optimize_placement() - - # Get pp_mod and inputs once (identical across all graph_passes configs) - res = autop.apply_placement_pp( - sharding_placement=sharding_placement, - graph_passes=[], - ) - pp_mod = autop.parallel_model - with unset_fake_temporarily(): - pp_mod.to_empty(device="cuda") - pp_mod.init_weights(buffer_device="cuda") - sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn) - - for name, graph_passes, use_fsdp, use_split in test_configs: - if graph_passes: - res = autop.apply_placement_pp( - sharding_placement=sharding_placement, - graph_passes=graph_passes, - ) - graph_modules, graph_meta = _extract_graph_modules_and_meta(res) - _run_graph_test( - pp_mod, - graph_modules, - graph_meta, - sharded_params, - buffers, - x, - fake_evaluate, - use_fsdp_collectives=use_fsdp, - use_split_dI_dW=use_split, - use_multiplexed_graph=True, - ) - print(f"{name}: All good!") - - -if __name__ == "__main__": - # must symbolically evaluate to run on 32 dp ranks - # world_size = 2048 - fake_evaluate = True - use_loss_fn = True - - world_size = 256 - - fake_store = FakeStore() - torch.distributed.init_process_group( - "fake", store=fake_store, rank=0, world_size=world_size - ) - # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) - mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - (world_size // 64, 64), - mesh_dim_names=( - "dp", - "ep", - ), - ) - - device = torch.device("cuda") - - bs = 4 * mesh.shape[0] * mesh.shape[1] - seq_len = 1024 - - config = DeepSeekV3ModelArgs( - vocab_size=102400, - max_seq_len=seq_len, - dim=2048, - inter_dim=10944, - moe_inter_dim=1408, - n_layers=1, # 27, - n_dense_layers=0, # 1, - n_heads=16, - moe_args=MoEArgs( - num_experts=64, - num_shared_experts=2, - top_k=6, - score_func="softmax", - route_norm=False, - score_before_experts=False, - mesh=mesh, - ), - q_lora_rank=0, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - mscale=0.70, - use_flex_attn=False, - attn_mask_type="causal", - ) - - # parallelize the model - with torch.device("meta"): - model = DeepSeekV3Model(config).bfloat16() - model.tok_embeddings = None # type: ignore[assignment] - - if use_loss_fn: - - class ModelWithLoss(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, h, labels): - output = self.model(h) - return dsv3_loss_fn(output, labels) - - def init_weights(self, *args, **kwargs): - return self.model.init_weights(*args, **kwargs) - - model = ModelWithLoss(model) - - def make_input_fn(sharded: bool = False, with_target: bool = False): - """Create input generator. `sharded` uses mesh-adjusted batch size.""" - - def input_fn() -> tuple[torch.Tensor, ...]: - batch_size = bs // (mesh.shape[0] * mesh.shape[1]) if sharded else bs - - inputs = ( - torch.randn( - (batch_size, seq_len, config.dim), - device=device, - dtype=torch.bfloat16, - requires_grad=True, - ), - ) - if with_target: - inputs += ( - torch.randint( - 0, config.vocab_size, (batch_size, seq_len), device=device - ), - ) - return inputs - - return input_fn - - input_fn = make_input_fn(sharded=False, with_target=use_loss_fn) - eval_fn = make_input_fn(sharded=True, with_target=use_loss_fn) - - run_all_graph_pass_tests(model, mesh, input_fn, eval_fn, fake_evaluate, use_loss_fn) - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() diff --git a/examples/run_ds3_numerics_check.py b/examples/run_ds3_numerics_check.py deleted file mode 100644 index 65075de5..00000000 --- a/examples/run_ds3_numerics_check.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -""" -Script to run DS3 numerics check by comparing outputs from local_map and pipeline parallel. -""" -import shutil -import subprocess -import tempfile -import warnings -from pathlib import Path - - -def run_command(cmd, cwd): - """Run a shell command in the specified directory.""" - print(f"Running: {cmd}") - print(f"In directory: {cwd}") - result = subprocess.run(cmd, shell=True, cwd=cwd, capture_output=True, text=True) - print(result.stdout) - if result.stderr: - print("STDERR:", result.stderr) - if result.returncode != 0: - warnings.warn(f"Command failed with return code {result.returncode}") - return result - - -def main(args): - schedule_name = args.schedule_name - - # Create a temporary directory - temp_dir = tempfile.mkdtemp(prefix="ds3_numerics_check_") - print(f"Created temporary directory: {temp_dir}") - - try: - examples_dir = Path(__file__).parent - - print("\n" + "=" * 80) - print("Running non-PP example with 4 GPUs...") - print("=" * 80) - cmd1 = f"torchrun --standalone --nproc-per-node 4 {examples_dir}/example_ds3_local_map.py --rng-seed 42" - run_command(cmd1, temp_dir) - - print("\n" + "=" * 80) - print("Running PP example with 8 GPUs...") - print("=" * 80) - cmd2 = f"torchrun --standalone --nproc-per-node 8 {examples_dir}/example_ds3_pp.py --rng-seed 42 --schedule-name={schedule_name}" - run_command(cmd2, temp_dir) - - out_dir = Path(temp_dir) / "out" - if not out_dir.exists(): - raise RuntimeError(f"Output directory {out_dir} does not exist") - - print("\n" + "=" * 80) - print("Comparing weights.log files...") - print("=" * 80) - run_command("diff out/0/weights.log out/1/pp_weights.log", temp_dir) - - print("\n" + "=" * 80) - print("Comparing diff.log files...") - print("=" * 80) - run_command("diff out/0/diff.log out/1/diff.log", temp_dir) - - print("\n" + "=" * 80) - print("Numerics check completed successfully!") - print(f"Output directory: {temp_dir}/out") - print("=" * 80) - - except Exception as e: - print(f"\nError occurred: {e}") - print(f"Temporary directory preserved at: {temp_dir}") - raise - - print(f"\nTemporary directory location: {temp_dir}") - response = input("Do you want to delete the temporary directory? (y/n): ") - if response.lower() == "y": - shutil.rmtree(temp_dir) - print("Temporary directory deleted.") - else: - print(f"Temporary directory preserved at: {temp_dir}") - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="Run DeepSeek V3 pipeline parallel example" - ) - parser.add_argument( - "--schedule-name", - type=str, - default="ZBVZeroBubble", - help="Schedule to use for PP", - ) - args = parser.parse_args() - main(args) diff --git a/tests/test_api_pp.py b/tests/test_api_pp.py deleted file mode 100644 index ff8650c4..00000000 --- a/tests/test_api_pp.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -# -# This source code is licensed under the BSD license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torch import nn -from torch.distributed.tensor import DTensor - -from autoparallel.api_pp import make_pp_module - - -def _make_sharded_dicts(model, device_mesh): - """Create DTensor param/buffer dicts from a model (replicated placement).""" - param_dict = {} - for name, param in model.named_parameters(): - local = torch.empty_like(param, device="cuda") - dt = DTensor.from_local(local, device_mesh=device_mesh) - param_dict[name] = nn.Parameter(dt, requires_grad=param.requires_grad) - buffer_dict = {} - for name, buf in model.named_buffers(): - local = torch.empty_like(buf, device="cuda") - dt = DTensor.from_local(local, device_mesh=device_mesh) - buffer_dict[name] = dt - return param_dict, buffer_dict - - -def test_pp_init_weights_basic(device_mesh_1d): - """Basic init_weights with in-place fills on the PP module.""" - dim = 128 - - class Model(nn.Module): - def __init__(self, dim): - super().__init__() - self.linear = nn.Linear(dim, dim) - self.register_buffer("buf", torch.empty(dim)) - - def forward(self, x): - return self.linear(x) + self.buf - - def init_weights(self): - with torch.no_grad(): - self.linear.weight.fill_(3.0) - self.linear.bias.fill_(7.0) - self.buf.fill_(5.0) - - with torch.device("meta"): - model = Model(dim) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - pp_mod.init_weights() - - assert torch.equal( - pp_mod.get_parameter("linear.weight").full_tensor(), - torch.full((dim, dim), 3.0, device="cuda"), - ) - assert torch.equal( - pp_mod.get_parameter("linear.bias").full_tensor(), - torch.full((dim,), 7.0, device="cuda"), - ) - assert torch.equal( - pp_mod.get_buffer("buf").full_tensor(), - torch.full((dim,), 5.0, device="cuda"), - ) - - -def test_pp_init_weights_setattr(device_mesh_1d): - """init_weights that assigns new Parameters and buffers via setattr.""" - dim = 128 - - class Model(nn.Module): - def __init__(self, dim): - super().__init__() - self.linear = nn.Linear(dim, dim) - self.register_buffer("buf", torch.empty(dim)) - - def forward(self, x): - return self.linear(x) + self.buf - - def init_weights(self): - self.linear.weight = nn.Parameter(torch.ones(dim, dim) * 9.0) - self.buf = torch.arange(dim, dtype=torch.float32) - - with torch.device("meta"): - model = Model(dim) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - pp_mod.init_weights() - - assert torch.equal( - pp_mod.get_parameter("linear.weight").full_tensor(), - torch.full((dim, dim), 9.0, device="cuda"), - ) - assert torch.equal( - pp_mod.get_buffer("buf").full_tensor(), - torch.arange(dim, dtype=torch.float32, device="cuda"), - ) - - -def test_pp_init_weights_submodule(device_mesh_1d): - """init_weights that delegates to submodule init_weights.""" - dim = 128 - - class MLP(nn.Module): - def __init__(self, dim): - super().__init__() - self.fc1 = nn.Linear(dim, dim) - self.fc2 = nn.Linear(dim, dim) - - def forward(self, x): - return self.fc2(self.fc1(x)) - - def init_weights(self): - with torch.no_grad(): - self.fc1.weight.fill_(1.0) - self.fc1.bias.fill_(0.0) - self.fc2.weight.fill_(2.0) - self.fc2.bias.fill_(0.5) - - class Model(nn.Module): - def __init__(self, dim): - super().__init__() - self.mlp = MLP(dim) - - def forward(self, x): - return self.mlp(x) - - def init_weights(self): - self.mlp.init_weights() - - with torch.device("meta"): - model = Model(dim) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - pp_mod.init_weights() - - assert torch.equal( - pp_mod.get_parameter("mlp.fc1.weight").full_tensor(), - torch.ones(dim, dim, device="cuda"), - ) - assert torch.equal( - pp_mod.get_parameter("mlp.fc2.weight").full_tensor(), - torch.full((dim, dim), 2.0, device="cuda"), - ) - assert torch.equal( - pp_mod.get_parameter("mlp.fc2.bias").full_tensor(), - torch.full((dim,), 0.5, device="cuda"), - ) - - -def test_pp_init_weights_load_state_dict(device_mesh_1d): - """init_weights that uses load_state_dict.""" - dim = 128 - - class Model(nn.Module): - def __init__(self, dim): - super().__init__() - self.linear = nn.Linear(dim, dim) - - def forward(self, x): - return self.linear(x) - - def init_weights(self): - state = { - "linear.weight": torch.ones(dim, dim) * 4.0, - "linear.bias": torch.full((dim,), 2.0), - } - self.load_state_dict(state) - - with torch.device("meta"): - model = Model(dim) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - pp_mod.init_weights() - - assert torch.equal( - pp_mod.get_parameter("linear.weight").full_tensor(), - torch.full((dim, dim), 4.0, device="cuda"), - ) - assert torch.equal( - pp_mod.get_parameter("linear.bias").full_tensor(), - torch.full((dim,), 2.0, device="cuda"), - ) - - -def test_pp_init_weights_user_helper_method(device_mesh_1d): - """init_weights that calls a user-defined helper method on self.""" - dim = 128 - - class Model(nn.Module): - def __init__(self, dim): - super().__init__() - self.linear = nn.Linear(dim, dim) - - def forward(self, x): - return self.linear(x) - - def _init_linear(self, linear): - with torch.no_grad(): - linear.weight.fill_(6.0) - linear.bias.fill_(1.0) - - def init_weights(self): - self._init_linear(self.linear) - - with torch.device("meta"): - model = Model(dim) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - - assert isinstance(pp_mod, Model) - pp_mod.init_weights() - - assert torch.equal( - pp_mod.get_parameter("linear.weight").full_tensor(), - torch.full((dim, dim), 6.0, device="cuda"), - ) - assert torch.equal( - pp_mod.get_parameter("linear.bias").full_tensor(), - torch.full((dim,), 1.0, device="cuda"), - ) - - -def test_pp_init_weights_named_parameters(device_mesh_1d): - """init_weights that iterates self.named_parameters().""" - dim = 128 - - class Model(nn.Module): - def __init__(self, dim): - super().__init__() - self.linear1 = nn.Linear(dim, dim) - self.linear2 = nn.Linear(dim, dim) - - def forward(self, x): - return self.linear2(self.linear1(x)) - - def init_weights(self): - for name, param in self.named_parameters(): - with torch.no_grad(): - if "weight" in name: - param.fill_(1.0) - else: - param.fill_(0.0) - - with torch.device("meta"): - model = Model(dim) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - pp_mod.init_weights() - - assert torch.equal( - pp_mod.get_parameter("linear1.weight").full_tensor(), - torch.ones(dim, dim, device="cuda"), - ) - assert torch.equal( - pp_mod.get_parameter("linear2.bias").full_tensor(), - torch.zeros(dim, device="cuda"), - ) - - -def test_pp_init_weights_optional_submodule(device_mesh_1d): - """init_weights that checks for an optional submodule (self.rope is not None). - - Mirrors the torchtitan Decoder pattern where rope may or may not be present. - When rope is None, the parallel model must still have the attribute so the - None check doesn't raise AttributeError. - """ - dim = 128 - - class Model(nn.Module): - def __init__(self, dim, use_rope=False): - super().__init__() - self.linear = nn.Linear(dim, dim) - self.rope = nn.Linear(dim, dim) if use_rope else None - - def forward(self, x): - return self.linear(x) - - def init_weights(self): - with torch.no_grad(): - self.linear.weight.fill_(1.0) - self.linear.bias.fill_(0.0) - if self.rope is not None: - with torch.no_grad(): - self.rope.weight.fill_(2.0) - - with torch.device("meta"): - model = Model(dim, use_rope=False) - - param_dict, buffer_dict = _make_sharded_dicts(model, device_mesh_1d) - pp_mod = make_pp_module(param_dict, buffer_dict, model) - pp_mod.init_weights() - - assert pp_mod.rope is None - assert torch.equal( - pp_mod.get_parameter("linear.weight").full_tensor(), - torch.ones(dim, dim, device="cuda"), - )