Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ jobs:
run_timed python examples/example_autoparallel.py
run_timed python examples/example_llama3.py
run_timed python examples/example_local_map.py
# TODO(#436): Re-enable once OpStrategy.__str__ handles None specs in PyTorch.
# run_timed python examples/example_pp_graph_passes.py
echo "========== Timings =========="
cat /tmp/timings.txt

Expand All @@ -83,5 +81,3 @@ jobs:
python examples/example_dcp.py
# TODO(#436): Re-enable once OpStrategy.__str__ handles None specs in PyTorch.
# torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py
# Skipped: graph PP is being moved out of AutoParallel shortly.
# torchrun --standalone --nproc_per_node=4 examples/example_ds3_pp.py --use-loss-fn --fake-evaluate
2 changes: 0 additions & 2 deletions autoparallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
# LICENSE file in the root directory of this source tree.

from autoparallel.api import AutoParallel, auto_parallel
from autoparallel.api_pp import AutoParallelPP
from autoparallel.collectives import with_sharding_constraint
from autoparallel.compile import autoparallel_backend

__all__ = [
"auto_parallel",
"AutoParallel",
"AutoParallelPP",
"autoparallel_backend",
"with_sharding_constraint",
]
79 changes: 4 additions & 75 deletions autoparallel/_testing/models/dsv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Literal, Optional, Tuple, Union
from typing import Callable, ClassVar, Literal, Optional, Tuple

import torch
import torch.fx.traceback as fx_traceback
Expand Down Expand Up @@ -1621,86 +1621,15 @@ def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
)


########################
# Pipeline stuff start #
########################


class DeepSeekV3StageI(nn.Module):
def __init__(self, layers, model_args):
super().__init__()
self.layers = layers
self.register_buffer(
"freqs_cis", precompute_freqs_cis(model_args), persistent=False
)
self.model_args = model_args

def forward(self, h):
# intermediate stages only have layers
for layer in self.layers.values():
h = layer(h, self.freqs_cis)
return h

def init_weights(
self, buffer_device: torch.device | None = None, seed: int | None = None
) -> None:
_init_weights_layers(self, buffer_device, seed)


class DeepSeekV3Stage0(DeepSeekV3StageI):
def __init__(self, embed, layers, model_args):
super().__init__(layers, model_args)
self.tok_embeddings = embed

def forward(self, tokens):
# torch.Size([1024, 1024])
h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
# torch.Size([1024, 1024, 2048])
return super().forward(h)

def init_weights(
self, buffer_device: torch.device | None = None, seed: int | None = None
) -> None:
_init_weights_tok_embeddings(self, seed)
super().init_weights(buffer_device, seed)


class DeepSeekV3StageN(DeepSeekV3StageI):
def __init__(self, layers, norm, output, model_args):
super().__init__(layers, model_args)
self.norm = norm
self.output = output
self.model_args = model_args

def forward(self, h):
h = super().forward(h)
h = self.norm(h) if self.norm is not None else h
output = self.output(h) if self.output is not None else h
return output

def init_weights(
self, buffer_device: torch.device | None = None, seed: int | None = None
) -> None:
super().init_weights(buffer_device, seed)
_init_weights_norm_and_output(self)


######################
# Pipeline stuff end #
######################


def _init_weights_tok_embeddings(
self: Union[DeepSeekV3Model, DeepSeekV3Stage0], seed: int | None = None
):
def _init_weights_tok_embeddings(self: DeepSeekV3Model, seed: int | None = None):
if seed is not None:
torch.manual_seed(seed)
if self.tok_embeddings is not None:
nn.init.normal_(self.tok_embeddings.weight)


def _init_weights_layers(
self: Union[DeepSeekV3Model, DeepSeekV3StageI],
self: DeepSeekV3Model,
buffer_device: torch.device | None,
seed: int | None = None,
):
Expand All @@ -1716,7 +1645,7 @@ def _init_weights_layers(
layer.init_weights(buffer_device) # type: ignore[arg-type]


def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]):
def _init_weights_norm_and_output(self: DeepSeekV3Model):
if self.norm is not None:
self.norm.reset_parameters()
if self.output is not None:
Expand Down
210 changes: 0 additions & 210 deletions autoparallel/api_pp.py

This file was deleted.

Loading
Loading