From 15da60f0ec084e67bba7a5beebdfeea790433e84 Mon Sep 17 00:00:00 2001 From: Kaijian Wang Date: Tue, 26 May 2026 11:33:57 -0700 Subject: [PATCH] Add Qwen3 AutoParallel model and examples --- autoparallel/_testing/models/dsv3.py | 2 +- autoparallel/_testing/models/qwen3.py | 974 +++++++++++++++++++++ examples/example_qwen3.py | 252 ++++++ examples/example_sanity_check_qwen3.py | 335 +++++++ examples/example_sanity_check_qwen3_moe.py | 466 ++++++++++ examples/example_torchtitan_qwen3_dense.py | 370 ++++++++ tests/test_dsv3_torchtitan_config.py | 37 + tests/test_qwen3.py | 323 +++++++ 8 files changed, 2758 insertions(+), 1 deletion(-) create mode 100644 autoparallel/_testing/models/qwen3.py create mode 100644 examples/example_qwen3.py create mode 100644 examples/example_sanity_check_qwen3.py create mode 100644 examples/example_sanity_check_qwen3_moe.py create mode 100644 examples/example_torchtitan_qwen3_dense.py create mode 100644 tests/test_dsv3_torchtitan_config.py create mode 100644 tests/test_qwen3.py diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 5a897b71..05f78a92 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -1581,7 +1581,7 @@ def __init__( route_norm=moe_cfg.router.route_norm, route_scale=moe_cfg.router.route_scale, score_before_experts=moe_cfg.experts.token_dispatcher.score_before_experts, - use_grouped_mm=moe_cfg.experts.use_grouped_mm, + use_grouped_mm=getattr(moe_cfg.experts, "use_grouped_mm", True), load_balance_coeff=moe_cfg.load_balance_coeff, mesh=mesh, compute_dtype=compute_dtype, diff --git a/autoparallel/_testing/models/qwen3.py b/autoparallel/_testing/models/qwen3.py new file mode 100644 index 00000000..f2a194c2 --- /dev/null +++ b/autoparallel/_testing/models/qwen3.py @@ -0,0 +1,974 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
def has_cuda_capability(major: int, minor: int) -> bool:
    """Return whether CUDA is usable and its capability is at least ``(major, minor)``.

    The comparison is a tuple comparison against
    ``torch.cuda.get_device_capability()``; the result is ``False`` whenever
    CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return False
    required = (major, minor)
    return torch.cuda.get_device_capability() >= required
@dataclass
class Qwen3ModelArgs:
    """Hyperparameters for the AutoParallel Qwen3 test model (dense or MoE).

    The field defaults are the ones ``qwen3_8b_args`` uses unchanged. Field
    order is part of the interface (positional construction) and must not be
    reshuffled.
    """

    # --- transformer shape -------------------------------------------------
    dim: int = 4096
    n_layers: int = 36
    n_heads: int = 32
    # ``None`` means "use n_heads" (no grouped-query sharing).
    n_kv_heads: Optional[int] = 8
    head_dim: int = 128
    # Width of the dense FeedForward; unused by blocks built with moe_enabled.
    hidden_dim: int = 12288
    vocab_size: int = 151936
    norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    max_seq_len: int = 4096
    # True: per-layer init std 0.02/sqrt(2*(layer_id+1)); False: uses n_layers.
    depth_init: bool = True
    attn_mask_type: str = "causal"
    eos_id: int = 0
    enable_weight_tying: bool = False
    # --- mixture-of-experts ------------------------------------------------
    moe_enabled: bool = False
    moe_hidden_dim: int = 768
    num_experts: int = 64
    top_k: int = 8
    route_norm: bool = True
    route_scale: float = 1.0
    score_before_experts: bool = False
    use_grouped_mm: bool = True
    # ``None`` disables the expert-bias term used when choosing experts.
    load_balance_coeff: Optional[float] = 1e-3
    moe_axis_name: str = "ep"

    def __post_init__(self) -> None:
        """Validate head and routing settings.

        Also called explicitly by the preset builders after they apply
        overrides, so it must stay idempotent and side-effect free.

        Raises:
            ValueError: if the effective ``n_kv_heads`` is not positive, does
                not divide ``n_heads``, or (MoE only) ``top_k`` is outside
                ``[1, num_experts]``.
        """
        n_kv_heads = self.n_heads if self.n_kv_heads is None else self.n_kv_heads
        # Checked first so a zero value surfaces as a configuration error
        # rather than a ZeroDivisionError from the modulo below.
        if n_kv_heads <= 0:
            raise ValueError(f"n_kv_heads ({n_kv_heads}) must be a positive integer.")
        if self.n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({n_kv_heads})."
            )
        if self.moe_enabled:
            if self.top_k < 1:
                raise ValueError(f"top_k ({self.top_k}) must be >= 1.")
            if self.top_k > self.num_experts:
                raise ValueError(
                    f"top_k ({self.top_k}) must be <= num_experts ({self.num_experts})."
                )

    def update_from_config(self, job_config, tokenizer) -> None:
        """Copy vocabulary size, sequence length and EOS id from the job setup.

        Reads ``tokenizer.n_words``, ``tokenizer.eos_id`` and
        ``job_config.training.seq_len``; no other attributes are touched.
        """
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.eos_id = tokenizer.eos_id

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """Return ``(parameter_count, flops_per_token)`` for ``model``.

        The FLOP estimate is ``6 * (nparams - embedding_params)`` plus an
        attention term ``12 * n_layers * n_heads * head_dim * seq_len``. Only
        ``nn.Embedding`` modules that are *direct children* of ``model`` are
        counted as embedding parameters.
        """
        nparams = sum(p.numel() for p in model.parameters())
        nparams_embedding = sum(
            sum(p.numel() for p in child.parameters())
            for child in model.children()
            if isinstance(child, nn.Embedding)
        )

        attention_flops = 12 * self.n_layers * self.n_heads * self.head_dim * seq_len
        num_flops_per_token = 6 * (nparams - nparams_embedding) + attention_flops
        return nparams, num_flops_per_token
def _apply_overrides(args: Qwen3ModelArgs, overrides: dict) -> Qwen3ModelArgs:
    """Apply keyword overrides to a preset and re-run its validation.

    Raises:
        TypeError: if an override does not name a ``Qwen3ModelArgs`` field.
            A bare ``setattr`` would otherwise create a stray attribute and
            silently ignore a misspelled option such as ``max_seqlen``.
    """
    from dataclasses import fields

    known = {field.name for field in fields(args)}
    unknown = sorted(key for key in overrides if key not in known)
    if unknown:
        raise TypeError(
            f"Unknown Qwen3ModelArgs override(s): {', '.join(unknown)}"
        )
    for key, value in overrides.items():
        setattr(args, key, value)
    # Overrides bypass the dataclass constructor, so validate explicitly.
    args.__post_init__()
    return args


def qwen3_debug_args(**overrides) -> Qwen3ModelArgs:
    """Small dense debug preset (dim=256, 8 layers, tied embeddings)."""
    preset = Qwen3ModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        n_kv_heads=8,
        head_dim=128,
        hidden_dim=3072,
        vocab_size=2048,
        max_seq_len=4096,
        enable_weight_tying=True,
    )
    return _apply_overrides(preset, overrides)


def qwen3_0_6b_args(**overrides) -> Qwen3ModelArgs:
    """Dense 0.6B preset (dim=1024, 28 layers, tied embeddings)."""
    preset = Qwen3ModelArgs(
        dim=1024,
        n_layers=28,
        n_heads=16,
        n_kv_heads=8,
        head_dim=128,
        hidden_dim=3072,
        vocab_size=151936,
        enable_weight_tying=True,
    )
    return _apply_overrides(preset, overrides)


def qwen3_1_7b_args(**overrides) -> Qwen3ModelArgs:
    """Dense 1.7B preset (dim=2048, 28 layers, tied embeddings)."""
    preset = Qwen3ModelArgs(
        dim=2048,
        n_layers=28,
        n_heads=16,
        n_kv_heads=8,
        head_dim=128,
        hidden_dim=6144,
        vocab_size=151936,
        enable_weight_tying=True,
    )
    return _apply_overrides(preset, overrides)


def qwen3_4b_args(**overrides) -> Qwen3ModelArgs:
    """Dense 4B preset (dim=2560, 36 layers, tied embeddings)."""
    preset = Qwen3ModelArgs(
        dim=2560,
        n_layers=36,
        n_heads=32,
        n_kv_heads=8,
        head_dim=128,
        hidden_dim=9728,
        vocab_size=151936,
        enable_weight_tying=True,
    )
    return _apply_overrides(preset, overrides)


def qwen3_8b_args(**overrides) -> Qwen3ModelArgs:
    """Dense 8B preset: the ``Qwen3ModelArgs`` defaults, plus overrides."""
    return _apply_overrides(Qwen3ModelArgs(), overrides)


def qwen3_moe_debug_args(**overrides) -> Qwen3ModelArgs:
    """Small MoE debug preset (dim=256, 8 layers, 64 experts, top-8)."""
    preset = Qwen3ModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        n_kv_heads=8,
        head_dim=128,
        hidden_dim=3072,
        vocab_size=2048,
        max_seq_len=4096,
        moe_enabled=True,
        moe_hidden_dim=768,
        num_experts=64,
        top_k=8,
        route_norm=True,
        score_before_experts=False,
    )
    return _apply_overrides(preset, overrides)


def qwen3_30b_a3b_args(**overrides) -> Qwen3ModelArgs:
    """MoE 30B-A3B preset (dim=2048, 48 layers, 128 experts, top-8)."""
    preset = Qwen3ModelArgs(
        dim=2048,
        n_layers=48,
        n_heads=32,
        n_kv_heads=4,
        head_dim=128,
        hidden_dim=6144,
        vocab_size=151936,
        max_seq_len=262144,
        moe_enabled=True,
        moe_hidden_dim=768,
        num_experts=128,
        top_k=8,
        route_norm=True,
        score_before_experts=False,
    )
    return _apply_overrides(preset, overrides)


def qwen3_235b_a22b_args(**overrides) -> Qwen3ModelArgs:
    """MoE 235B-A22B preset (dim=4096, 94 layers, 128 experts, top-8)."""
    preset = Qwen3ModelArgs(
        dim=4096,
        n_layers=94,
        n_heads=64,
        n_kv_heads=4,
        head_dim=128,
        hidden_dim=12288,
        vocab_size=151936,
        max_seq_len=4096,
        rope_theta=5000000.0,
        moe_enabled=True,
        moe_hidden_dim=1536,
        num_experts=128,
        top_k=8,
        route_norm=True,
        score_before_experts=False,
    )
    return _apply_overrides(preset, overrides)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every KV head ``n_rep`` times along the head axis (dim 2).

    ``x`` must be 4-D and is unpacked as ``(bs, slen, n_kv_heads, head_dim)``.
    The result has ``n_kv_heads * n_rep`` heads with the copies of one head
    adjacent to each other. With ``n_rep == 1`` the input tensor itself is
    returned.
    """
    batch, seq, kv_heads, width = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # back into the head axis.
    expanded = x.unsqueeze(3).expand(batch, seq, kv_heads, n_rep, width)
    return expanded.reshape(batch, seq, kv_heads * n_rep, width)
__init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.head_dim + self.scale = self.head_dim**-0.5 + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + self.q_norm = nn.RMSNorm(self.head_dim, eps=model_args.norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=model_args.norm_eps) + self.sdpa = build_attention(model_args.attn_mask_type) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + self.q_norm.reset_parameters() + self.k_norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + freqs_cos_sin: torch.Tensor, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = _linear(x, self.wq), _linear(x, self.wk), _linear(x, self.wv) + + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq = _rms_norm(xq, self.q_norm) + xk = _rms_norm(xk, self.k_norm) + freqs_cos_sin = _to_activation_device(freqs_cos_sin, xq) + xq, xk = apply_rotary_emb_cos_sin(xq, xk, freqs_cos_sin) + + keys = repeat_kv(xk, self.n_rep) + values = repeat_kv(xv, self.n_rep) + + xq = xq.transpose(1, 2) + xk = keys.transpose(1, 2) + xv = values.transpose(1, 2) + + output = self.sdpa(xq, xk, xv, scale=self.scale) + + output = output.transpose(1, 2).contiguous() + output = output.view(bs, seqlen, -1) + return _linear(output, self.wo) 
class FeedForward(nn.Module):
    """Gated MLP computing ``w2(silu(w1(x)) * w3(x))`` with bias-free linears."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # Created in the order w1, w2, w3: this fixes both the parameter
        # ordering and the order in which the default initialisers draw
        # random numbers.
        up_shape = (dim, hidden_dim)
        self.w1 = nn.Linear(*up_shape, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(*up_shape, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back with ``w2``."""
        gate = F.silu(_linear(x, self.w1))
        value = _linear(x, self.w3)
        return _linear(gate * value, self.w2)

    def init_weights(self, init_std: float):
        """Truncated-normal init: std 0.02 for ``w1``, ``init_std`` for ``w2``/``w3``."""
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2.weight, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.w3.weight, mean=0.0, std=init_std)
def qwen3_moe_local_mapped_region(
    x: torch.Tensor,
    selected_experts_indices: torch.Tensor,
    top_scores: torch.Tensor,
    experts_w1: torch.Tensor,
    experts_w3: torch.Tensor,
    experts_w2: torch.Tensor,
    out: torch.Tensor,
    top_k: int,
    num_experts: int,
    score_before_experts: bool,
    axis_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Route tokens to experts, run the experts, and scatter results into ``out``.

    This is the function ``MoE.forward`` wraps with ``local_map``.

    Note the weight parameters arrive in the order ``w1, w3, w2`` but are
    handed to the expert kernel as ``w1, w2, w3``.

    Args:
        x: token activations; indexed along dim 0 and read as ``[..., dim]``.
        selected_experts_indices: expert id chosen for each (token, slot);
            flattened here, so slot ``i`` of token ``t`` sits at ``t * top_k + i``.
        top_scores: routing weight for each entry of
            ``selected_experts_indices`` (same flattened layout).
        experts_w1, experts_w3, experts_w2: expert weights.
        out: accumulator that receives the routed outputs via ``scatter_add``.
        top_k: number of experts per token.
        num_experts: total number of experts; must be divisible by
            ``axis_size(axis_name)``.
        score_before_experts: scale expert *inputs* by the routing weight when
            true, expert *outputs* otherwise.
        axis_name: mesh axis used for the expert-parallel all-to-all.

    Returns:
        ``(out, num_tokens_per_expert)`` where the second element is the
        per-expert histogram computed before dispatch.

    Raises:
        ValueError: if ``num_experts`` is not divisible by the axis size.
    """
    dim = x.shape[-1]
    ep_size = axis_size(axis_name)
    if num_experts % ep_size != 0:
        raise ValueError(
            f"num_experts ({num_experts}) must be divisible by "
            f"axis_size({axis_name!r}) ({ep_size})."
        )

    # Count how many (token, slot) pairs were routed to each expert id.
    num_tokens_per_expert = torch.histc(
        selected_experts_indices.flatten(),
        bins=num_experts,
        min=0,
        max=num_experts,
    ).view(-1)

    # Stable sort groups the flattened slots by expert id while keeping the
    # original order inside each expert's group.
    token_indices_experts_sorted = torch.argsort(
        selected_experts_indices.view(-1), stable=True
    )
    top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
    # Flattened slot index -> row of ``x`` (each token owns ``top_k`` slots).
    token_indices_experts_sorted = token_indices_experts_sorted // top_k

    routed_input = x[token_indices_experts_sorted]
    if score_before_experts:
        # Scale in float32, then return to the activation dtype.
        routed_input = (
            routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)
        ).to(x.dtype)

    # Remembered so the post-combine row count can be checked below.
    shape = routed_input.shape
    (
        input_shape,
        routed_input,
        permuted_indices,
        num_tokens_per_expert_group,
        input_splits,
        output_splits,
    ) = _qwen3_token_dispatch(routed_input, num_tokens_per_expert, axis_name)

    # NOTE(review): the grouped-mm kernel is used unconditionally here; this
    # function has no ``use_grouped_mm`` switch, so that setting on
    # ``GroupedExperts`` does not affect this path. Confirm that is intended.
    routed_output = _run_experts_grouped_mm(
        experts_w1,
        experts_w2,
        experts_w3,
        routed_input,
        num_tokens_per_expert_group,
    )
    routed_output = _token_combine(
        routed_output,
        input_shape,
        permuted_indices,
        input_splits,
        output_splits,
        axis_name,
    )

    # After combine, the row count must match what was sent to dispatch.
    torch._check(routed_output.shape[0] == shape[0])
    if not score_before_experts:
        routed_output = (
            routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)
        ).to(routed_output.dtype)

    # Sum each routed row back into the row of the token it came from.
    out = out.scatter_add(
        dim=0,
        index=token_indices_experts_sorted.reshape(-1, 1).expand(-1, dim),
        src=routed_output,
    )
    return out, num_tokens_per_expert
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route each token to ``top_k`` experts and return the combined output.

        ``x`` is unpacked as ``(bs, slen, dim)``, flattened to rows for
        routing, and the result is reshaped back to ``(bs, slen, dim)``.
        Router scores are computed in float32. Outside of compilation the
        ``tokens_per_expert`` buffer is updated in place as a side effect.
        """
        bs, slen, dim = x.shape
        x = x.view(-1, dim)
        # Annotate as plain tensors: parameters() yields Parameter, but
        # _to_activation_device returns Tensor, and we reassign in place.
        experts_w1: torch.Tensor
        experts_w2: torch.Tensor
        experts_w3: torch.Tensor
        # Relies on GroupedExperts registering its parameters as w1, w2, w3.
        experts_w1, experts_w2, experts_w3 = self.experts.parameters()
        experts_w1 = _to_activation_device(experts_w1, x)
        experts_w2 = _to_activation_device(experts_w2, x)
        experts_w3 = _to_activation_device(experts_w3, x)

        # Router logits and softmax are evaluated in float32.
        scores = F.linear(
            x.to(torch.float32),
            _to_activation_device(self.router.weight, x).to(torch.float32),
            None,
        )
        scores = F.softmax(scores, dim=-1)
        expert_bias = _to_activation_device(self.expert_bias, scores)  # type: ignore[arg-type]
        # The bias only influences *which* experts are picked; the weights
        # gathered below come from the unbiased scores.
        scores_for_choice = (
            scores + expert_bias if self.load_balance_coeff is not None else scores
        )
        _, selected_experts_indices = torch.topk(
            scores_for_choice,
            k=self.top_k,
            dim=-1,
            sorted=False,
        )

        top_scores = scores.gather(dim=-1, index=selected_experts_indices)
        if self.route_norm:
            # Renormalise the selected weights; the epsilon guards a zero sum.
            denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
            top_scores = top_scores / denominator
        top_scores = top_scores * self.route_scale

        # Qwen3 MoE has no shared expert path, but keeping the initial output
        # differentiably tied to x matches the DSv3 local_map autograd shape.
        out = x * 0
        # One placement tuple per tensor argument, in call order; the four
        # trailing ``None`` entries cover the non-tensor arguments.
        out, num_tokens_per_expert = local_map(
            qwen3_moe_local_mapped_region,
            out_placements=(
                (Shard(0), Shard(0)),
                (Partial(reduce_op="sum"), Partial(reduce_op="sum")),
            ),
            in_placements=(
                (Shard(0), Shard(0)),  # x
                (Shard(0), Shard(0)),  # selected_experts_indices
                (Shard(0), Shard(0)),  # top_scores
                (Replicate(), Shard(0)),  # experts_w1
                (Replicate(), Shard(0)),  # experts_w3
                (Replicate(), Shard(0)),  # experts_w2
                (Shard(0), Shard(0)),  # out
                None,
                None,
                None,
                None,
            ),
            redistribute_inputs=True,
            in_grad_placements=None,
            device_mesh=self.mesh,
        )(
            x,
            selected_experts_indices,
            top_scores,
            experts_w1,
            experts_w3,  # the mapped region takes w3 before w2
            experts_w2,
            out,
            self.top_k,
            self.num_experts,
            self.score_before_experts,
            self.axis_name,
        )
        # This counter is only used for runtime load-balance diagnostics. During
        # AutoParallel graph capture the module buffers are fake/meta tensors
        # while the traced local_map output can be CUDA-fake, and recording this
        # mutation is not needed for the solved training graph.
        if not torch.compiler.is_compiling():
            with torch.no_grad():
                self.tokens_per_expert.add_(num_tokens_per_expert)  # type: ignore[operator]
        return out.reshape(bs, slen, dim)
class Transformer(nn.Module):
    """Qwen3 decoder: token embedding, ``n_layers`` blocks, final norm, LM head.

    The rotary cos/sin table is kept in the persistent ``freqs_cos_sin``
    buffer. When ``enable_weight_tying`` is set, ``tok_embeddings.weight`` is
    bound to ``lm_head.weight``.
    """

    def __init__(
        self,
        model_args: Qwen3ModelArgs,
        mesh: DeviceMesh | None = None,
        moe_axis_name: str | None = None,
    ):
        super().__init__()
        if not moe_axis_name:
            moe_axis_name = model_args.moe_axis_name

        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers
        self.eos_id = model_args.eos_id
        self.enable_weight_tying = model_args.enable_weight_tying
        self.mesh = mesh
        self.moe_axis_name = moe_axis_name

        # Registration order: embedding, rope buffer, blocks, norm, head.
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
        rope_table = self._precompute_freqs_cos_sin()
        self.register_buffer("freqs_cos_sin", rope_table, persistent=True)

        self.layers = torch.nn.ModuleDict()
        for index in range(model_args.n_layers):
            block = TransformerBlock(
                index,
                model_args,
                mesh=mesh,
                moe_axis_name=self.moe_axis_name,
            )
            self.layers[str(index)] = block

        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
        self.lm_head = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self._tie_embedding_to_lm_head()

    def _tie_embedding_to_lm_head(self) -> None:
        """Bind the embedding weight to the LM head weight when tying is enabled."""
        if self.enable_weight_tying:
            self.tok_embeddings.weight = self.lm_head.weight

    def init_weights(
        self,
        buffer_device: Optional[torch.device] = None,
        seed: int | None = None,
    ):
        """Initialise all parameters and rebuild the rotary table.

        Args:
            buffer_device: device for recomputed buffers; falls back to the
                device of the current ``freqs_cos_sin`` buffer.
            seed: when given, passed to ``torch.manual_seed`` first.
        """
        if seed is not None:
            torch.manual_seed(seed)

        self._tie_embedding_to_lm_head()

        if not buffer_device:
            buffer_device = self.freqs_cos_sin.device  # type: ignore[has-type,assignment]
        with torch.device(buffer_device):  # type: ignore[arg-type]
            self.freqs_cos_sin = self._precompute_freqs_cos_sin()

        # A tied embedding shares the head's weight, which is initialised below.
        untied = not self.enable_weight_tying
        if untied and self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)

        for block in self.layers.values():
            if block is None:
                continue
            block.init_weights(buffer_device)  # type: ignore[operator]

        if self.norm is not None:
            self.norm.reset_parameters()

        if self.lm_head is not None:
            head_std = self.model_args.dim**-0.5
            cutoff = 3 * head_std
            nn.init.trunc_normal_(
                self.lm_head.weight,
                mean=0.0,
                std=head_std,
                a=-cutoff,
                b=cutoff,
            )

        self._tie_embedding_to_lm_head()

    def _precompute_freqs_cos_sin(self) -> torch.Tensor:
        """Build the rotary cos/sin table from the model arguments."""
        cfg = self.model_args
        return precompute_freqs_cos_sin(cfg.head_dim, cfg.max_seq_len, cfg.rope_theta)

    def _token_embedding(self, tokens: torch.Tensor) -> torch.Tensor:
        """Look up ``tokens``, moving a meta-device weight to the token device."""
        table = self.tok_embeddings.weight
        on_other_device = table.device != tokens.device
        if on_other_device and table.device.type == "meta":
            table = table.to(tokens.device)
        return F.embedding(tokens, table)

    def forward(self, tokens: torch.Tensor, input_batch: Optional[torch.Tensor] = None):
        """Run the full stack on ``tokens``; ``input_batch`` is accepted but unused."""
        hidden = tokens
        if self.tok_embeddings is not None:
            hidden = self._token_embedding(tokens)

        for block in self.layers.values():
            hidden = block(hidden, self.freqs_cos_sin)

        if self.norm is not None:
            hidden = _rms_norm(hidden, self.norm)
        if self.lm_head is not None:
            hidden = _linear(hidden, self.lm_head)
        return hidden
"module_fqn" + + +def _annotate_once(fn: Callable, meta: dict): + if getattr(fn, "_graph_trainer_annotated", False): + return fn + wrapped = fx_traceback.annotate_fn(meta)(fn) + setattr(wrapped, "_graph_trainer_annotated", True) + return wrapped + + +def _annotate_module_fqns(model: nn.Module) -> None: + for fqn, submodule in model.named_modules(): + if fqn: + submodule.forward = _annotate_once( + submodule.forward, + {_MODULE_FQN: fqn}, + ) + + +def annotate_qwen3_for_graph_trainer(model: Transformer) -> None: + """Attach graph_trainer-compatible FX annotations to AP's Qwen3 model.""" + global qwen3_moe_local_mapped_region + + qwen3_moe_local_mapped_region = _annotate_once( + qwen3_moe_local_mapped_region, + {"EP": "compute"}, + ) + MoE.forward = _annotate_once( # type: ignore[method-assign] + MoE.forward, + {"EP": "compute"}, + ) + _annotate_module_fqns(model) diff --git a/examples/example_qwen3.py b/examples/example_qwen3.py new file mode 100644 index 00000000..d82c0c14 --- /dev/null +++ b/examples/example_qwen3.py @@ -0,0 +1,252 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import time + +import torch +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.testing._internal.distributed.fake_pg import FakeStore + +from autoparallel._testing.models.qwen3 import ( + Qwen3ModelArgs, + Transformer, + qwen3_8b_args, + qwen3_30b_a3b_args, + qwen3_235b_a22b_args, + qwen3_debug_args, + qwen3_moe_debug_args, +) +from autoparallel.api import AutoParallel +from autoparallel.compile import autoparallel_backend + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Trace, optimize, and smoke-test dense Qwen3 with AutoParallel." 
def make_model_args(flavor: str, seq_len: int):
    """Build the ``Qwen3ModelArgs`` for a ``--flavor`` value.

    ``tiny`` and ``moe-tiny`` are defined inline; every other flavor is
    delegated to the matching preset builder with ``max_seq_len=seq_len``.

    Raises:
        ValueError: if ``flavor`` is not one of the known names.
    """
    if flavor in ("tiny", "moe-tiny"):
        # Both tiny variants share the same attention shape.
        shared = dict(
            dim=64,
            n_heads=4,
            n_kv_heads=2,
            head_dim=16,
            hidden_dim=128,
            vocab_size=128,
            max_seq_len=seq_len,
        )
        if flavor == "tiny":
            return Qwen3ModelArgs(n_layers=2, **shared)
        return Qwen3ModelArgs(
            n_layers=1,
            moe_enabled=True,
            moe_hidden_dim=32,
            num_experts=8,
            top_k=2,
            route_norm=True,
            score_before_experts=False,
            **shared,
        )

    preset_builders = {
        "debug": qwen3_debug_args,
        "8b": qwen3_8b_args,
        "moe-debug": qwen3_moe_debug_args,
        "30b-a3b": qwen3_30b_a3b_args,
        "235b-a22b": qwen3_235b_a22b_args,
    }
    builder = preset_builders.get(flavor)
    if builder is None:
        raise ValueError(f"Unknown Qwen3 flavor: {flavor}")
    return builder(max_seq_len=seq_len)
autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = ( + (Shard(0), Shard(0)) if model_args.moe_enabled else (Shard(0), Replicate()) + ) + out_sharding = (Shard(0), Shard(2)) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + + sharding_placement = autop.optimize_placement(verbose=args.verbose) + print(f"Tracing + optimization took {time.time() - t0:.1f}s") + + if args.save_optimizer is not None: + autop.sharding_optimizer.save(args.save_optimizer) + autop.sharding_optimizer.save_placements( + f"{args.save_optimizer}.placements.json" + ) + + parallel_mod = autop.apply_placement(sharding_placement) + + if args.skip_run: + print("Placement applied successfully.") + return + + parallel_mod.to_empty(device=device) + parallel_mod.init_weights(buffer_device=device) # type: ignore[operator] + + if args.compile: + parallel_mod = torch.compile(parallel_mod, backend=autoparallel_backend()) + + tokens = torch.randint( + 0, + model_args.vocab_size, + (args.local_batch_size, seq_len), + device=device, + ) + out = parallel_mod(tokens) + if torch.any(torch.isnan(out)): + raise RuntimeError("Found NaNs in Qwen3 forward output.") + out.backward(torch.randn_like(out)) + print("All good!") + + +if __name__ == "__main__": + main() diff --git a/examples/example_sanity_check_qwen3.py b/examples/example_sanity_check_qwen3.py new file mode 100644 index 00000000..b7af6c0d --- /dev/null +++ b/examples/example_sanity_check_qwen3.py @@ -0,0 +1,335 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
import argparse
import logging
import os
import time

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_func
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel._testing.models.qwen3 import Transformer, qwen3_8b_args
from autoparallel.api import AutoParallel
from autoparallel.compile import autoparallel_backend


def parse_args():
    """Parse the command-line options of the dense Qwen3 8B sanity check."""
    parser = argparse.ArgumentParser(
        description="Run a real Qwen3 8B AutoParallel training sanity check."
    )
    parser.add_argument(
        "--global-batch-size",
        type=int,
        default=16,
        help="Global batch size across data-parallel ranks.",
    )
    parser.add_argument(
        "--microbatch-size",
        type=int,
        default=1,
        help="Per-DP-rank microbatch size for gradient accumulation.",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=4096,
        help="Sequence length. Defaults to Qwen3 8B's max sequence length.",
    )
    parser.add_argument(
        "--dp-degree",
        type=int,
        default=2,
        help="Data-parallel mesh degree.",
    )
    parser.add_argument(
        "--tp-degree",
        type=int,
        default=2,
        help="Tensor-parallel mesh degree.",
    )
    parser.add_argument(
        "--train-steps",
        type=int,
        default=20,
        help="Number of optimizer steps.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=3e-4,
        help="AdamW learning rate.",
    )
    parser.add_argument(
        "--max-grad-norm",
        type=float,
        default=1.0,
        help="Gradient clipping max norm.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for model initialization and synthetic data generation.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Compile the placed module with the AutoParallel backend before training.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full AutoParallel optimizer log.",
    )
    return parser.parse_args()


def init_distributed(args):
    """Validate the launch configuration and create the (dp, tp) device mesh.

    Reads WORLD_SIZE and LOCAL_RANK from the environment (as set by torchrun),
    checks that the batch sizes divide evenly, initializes the NCCL process
    group, and builds a 2-D CUDA mesh.

    Returns:
        ``(device, mesh)``: this rank's CUDA device and the device mesh with
        dimension names ``("dp", "tp")``.

    Raises:
        RuntimeError: if the torchrun environment variables are missing.
        ValueError: if the world size or batch sizes are inconsistent.
    """
    if "WORLD_SIZE" not in os.environ or "LOCAL_RANK" not in os.environ:
        raise RuntimeError(
            "Run this example with torchrun, e.g. "
            "torchrun --standalone --nproc-per-node 4 "
            "examples/example_sanity_check_qwen3.py"
        )

    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    expected_world_size = args.dp_degree * args.tp_degree
    if world_size != expected_world_size:
        raise ValueError(
            f"WORLD_SIZE ({world_size}) must equal dp-degree * tp-degree "
            f"({args.dp_degree} * {args.tp_degree} = {expected_world_size})."
        )
    if args.global_batch_size % args.dp_degree != 0:
        raise ValueError(
            f"global-batch-size ({args.global_batch_size}) must be divisible by "
            f"dp-degree ({args.dp_degree})."
        )
    local_batch_size = args.global_batch_size // args.dp_degree
    if local_batch_size % args.microbatch_size != 0:
        raise ValueError(
            f"local batch size ({local_batch_size}) must be divisible by "
            f"microbatch-size ({args.microbatch_size})."
        )

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl", device_id=device)
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda",
        (args.dp_degree, args.tp_degree),
        mesh_dim_names=("dp", "tp"),
    )
    return device, mesh


def make_local_tokens(args, mesh, device, vocab_size: int) -> torch.Tensor:
    """Return this DP rank's slice of a deterministic synthetic token batch.

    Every rank generates the same ``(global_batch_size, seq_len + 1)`` tensor
    from a CPU generator seeded with ``args.seed`` and keeps the rows that
    belong to its data-parallel coordinate, so all TP ranks of one DP group
    see identical tokens. The extra column lets the caller derive inputs and
    next-token labels by shifting.

    Raises:
        RuntimeError: if the mesh has no coordinate for this rank.
    """
    coordinate = mesh.get_coordinate()
    if coordinate is None:
        raise RuntimeError("DeviceMesh coordinate is unavailable on this rank.")
    dp_rank, _tp_rank = coordinate
    local_batch_size = args.global_batch_size // args.dp_degree

    generator = torch.Generator(device="cpu")
    generator.manual_seed(args.seed)
    tokens = torch.randint(
        0,
        vocab_size,
        (args.global_batch_size, args.seq_len + 1),
        generator=generator,
        dtype=torch.long,
    )

    start = dp_rank * local_batch_size
    stop = start + local_batch_size
    return tokens[start:stop].to(device, non_blocking=True)


def vocab_parallel_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    vocab_size: int,
    tp_group,
    tp_rank: int,
    tp_degree: int,
    global_token_count: int,
) -> torch.Tensor:
    """Cross entropy over logits whose last (vocab) dimension is split across TP ranks.

    Args:
        logits: local logits; the first two dims must equal ``labels.shape``
            and the last dim is this rank's slice of the vocabulary.
        labels: target token ids in the global vocabulary.
        vocab_size: size of the full vocabulary.
        tp_group: process group spanning the ranks that share the vocabulary.
        tp_rank: this rank's index inside ``tp_group``.
        tp_degree: number of ranks in ``tp_group``.
        global_token_count: number of tokens the loss is averaged over.

    Returns:
        Scalar ``loss_sum / (global_token_count * tp_degree)``. Every rank of
        the group computes the same value from the all-reduced terms, and
        ``main`` sums the per-rank values with an all-reduce, so the extra
        ``tp_degree`` keeps that sum equal to the mean token loss.

    Raises:
        ValueError: if ``logits`` and ``labels`` disagree on their leading dims.
    """
    if logits.shape[:2] != labels.shape:
        raise ValueError(
            f"logits shape {tuple(logits.shape)} is incompatible with "
            f"labels shape {tuple(labels.shape)}."
        )

    local_vocab_size = logits.shape[-1]
    # Offsets follow chunk-style sharding: every rank before the last owns
    # ceil(vocab_size / tp_degree) columns and the last one may own fewer.
    # Deriving the start from the local width would misplace a narrower last
    # shard and let the gather below index past the local columns.
    shard_size = -(-vocab_size // tp_degree)
    vocab_start = min(tp_rank * shard_size, vocab_size)
    vocab_stop = min(vocab_start + local_vocab_size, vocab_size)

    logits = logits.float()
    local_max = logits.amax(dim=-1)
    # The max is only a numerical-stability shift, so it is taken without grad.
    with torch.no_grad():
        global_max = local_max.detach().clone()
        dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    shifted_logits = logits - global_max.unsqueeze(-1)
    local_exp_sum = shifted_logits.exp().sum(dim=-1)
    global_exp_sum = dist_nn_func.all_reduce(
        local_exp_sum,
        op=dist.ReduceOp.SUM,
        group=tp_group,
    )

    # Only the rank owning a label's vocab slice contributes its target logit;
    # other ranks gather index 0 and zero the result out through the mask.
    target_mask = (labels >= vocab_start) & (labels < vocab_stop)
    local_target = torch.zeros_like(labels, dtype=torch.long)
    local_target[target_mask] = labels[target_mask] - vocab_start
    local_target_logits = logits.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
    local_target_logits = local_target_logits * target_mask.to(logits.dtype)
    target_logits = dist_nn_func.all_reduce(
        local_target_logits,
        op=dist.ReduceOp.SUM,
        group=tp_group,
    )

    loss_sum = (global_exp_sum.log() + global_max - target_logits).sum()
    return loss_sum / (global_token_count * tp_degree)


def print_rank0(message: str) -> None:
    """Print ``message`` (flushed) on global rank 0 only."""
    if dist.get_rank() == 0:
        print(message, flush=True)


def main():
    """Place Qwen3 8B with AutoParallel and train it for a few steps on synthetic data.

    Raises:
        RuntimeError: on NaNs in the logits or loss, or when more than one
            step ran and the final loss is not below the first one.
    """
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG)

    device, mesh = init_distributed(args)
    tp_group = mesh.get_group("tp")
    tp_rank = mesh.get_local_rank("tp")
    local_batch_size = args.global_batch_size // args.dp_degree
    gradient_accumulation_steps = local_batch_size // args.microbatch_size

    torch.manual_seed(args.seed)
    model_args = qwen3_8b_args(max_seq_len=args.seq_len)
    # Trace with the batch that one forward call actually sees: one microbatch
    # per DP rank.
    trace_global_batch_size = args.microbatch_size * args.dp_degree

    with torch.device("meta"):
        model = Transformer(model_args)

    def input_fn():
        return torch.randint(
            0,
            model_args.vocab_size,
            (trace_global_batch_size, args.seq_len),
            device=device,
        )

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    print_rank0(
        "Qwen3 8B sanity check: "
        f"mesh=(dp={args.dp_degree}, tp={args.tp_degree}), "
        f"global_batch={args.global_batch_size}, "
        f"local_batch={local_batch_size}, "
        f"microbatch={args.microbatch_size}, "
        f"grad_accum={gradient_accumulation_steps}, "
        f"trace_global_batch={trace_global_batch_size}, "
        f"seq_len={args.seq_len}"
    )

    t0 = time.time()
    with AutoParallel(
        model,
        input_fn,
        mesh,
        mp_policy,
        repeated_subgraphs=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        autop.add_input_constraints([(Shard(0), Replicate())])
        autop.add_output_constraints([(Shard(0), Shard(2))])
        sharding_placement = autop.optimize_placement(verbose=args.verbose)
        parallel_mod = autop.apply_placement(sharding_placement)

    print_rank0(f"Tracing + optimization took {time.time() - t0:.1f}s")

    parallel_mod.to_empty(device=device)
    parallel_mod.init_weights(buffer_device=device, seed=args.seed)  # type: ignore[operator]

    if args.compile:
        parallel_mod = torch.compile(parallel_mod, backend=autoparallel_backend())

    batch = make_local_tokens(args, mesh, device, model_args.vocab_size)
    inputs = batch[:, :-1].contiguous()
    labels = batch[:, 1:].contiguous()
    input_microbatches = inputs.split(args.microbatch_size, dim=0)
    label_microbatches = labels.split(args.microbatch_size, dim=0)
    global_token_count = args.global_batch_size * args.seq_len
    optimizer = torch.optim.AdamW(parallel_mod.parameters(), lr=args.lr)

    try:
        losses: list[float] = []
        for step in range(args.train_steps):
            optimizer.zero_grad(set_to_none=True)
            step_loss = torch.zeros((), device=device)
            for micro_inputs, micro_labels in zip(
                input_microbatches, label_microbatches
            ):
                logits = parallel_mod(micro_inputs)
                if torch.any(torch.isnan(logits)):
                    raise RuntimeError("Found NaNs in Qwen3 forward output.")

                loss = vocab_parallel_cross_entropy(
                    logits,
                    micro_labels,
                    vocab_size=model_args.vocab_size,
                    tp_group=tp_group,
                    tp_rank=tp_rank,
                    tp_degree=args.tp_degree,
                    global_token_count=global_token_count,
                )
                if torch.any(torch.isnan(loss)):
                    raise RuntimeError("Found NaNs in Qwen3 training loss.")

                loss.backward()
                step_loss = step_loss + loss.detach()

            torch.nn.utils.clip_grad_norm_(
                parallel_mod.parameters(), args.max_grad_norm
            )
            optimizer.step()

            with torch.no_grad():
                logged_loss = step_loss.clone()
                dist.all_reduce(logged_loss, op=dist.ReduceOp.SUM)
            loss_value = float(logged_loss.item())
            losses.append(loss_value)
            print_rank0(f"step={step:03d} loss={loss_value:.6f}")

        # A trend needs two samples: with --train-steps 1 the first and last
        # loss are the same number, and with 0 there is nothing to index.
        if len(losses) > 1:
            if losses[-1] >= losses[0]:
                raise RuntimeError(
                    f"Qwen3 training loss did not improve: initial={losses[0]:.6f}, "
                    f"final={losses[-1]:.6f}"
                )
            print_rank0(
                f"Loss improved: initial={losses[0]:.6f}, final={losses[-1]:.6f}"
            )
        elif losses:
            print_rank0(f"Completed one step: loss={losses[0]:.6f}")
        dist.barrier(device_ids=[device.index])
        torch.cuda.synchronize(device)
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# NOTE: end of examples/example_sanity_check_qwen3.py. The lines below are the
# patch header and license banner of the next file in this patch, kept as
# comments so the file boundary is not lost.
#
# diff --git a/examples/example_sanity_check_qwen3_moe.py b/examples/example_sanity_check_qwen3_moe.py
# new file mode 100644
# index 00000000..7f48d69c
# --- /dev/null
# +++ b/examples/example_sanity_check_qwen3_moe.py
# @@ -0,0 +1,466 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import os
import time

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_func
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Shard

from autoparallel._testing.models.qwen3 import (
    Qwen3ModelArgs,
    Transformer,
    qwen3_30b_a3b_args,
    qwen3_235b_a22b_args,
    qwen3_moe_debug_args,
)
from autoparallel.api import AutoParallel
from autoparallel.compile import autoparallel_backend


def parse_args():
    """Parse the command-line options of the Qwen3 MoE sanity check."""
    parser = argparse.ArgumentParser(
        description="Run a real Qwen3 MoE AutoParallel training sanity check."
    )
    parser.add_argument(
        "--flavor",
        choices=("moe-tiny", "moe-debug", "30b-a3b", "235b-a22b"),
        default="30b-a3b",
        help="Qwen3 MoE model size. Defaults to the real Qwen3-30B-A3B model.",
    )
    parser.add_argument(
        "--global-batch-size",
        type=int,
        default=4,
        help="Global batch size across data-parallel ranks.",
    )
    parser.add_argument(
        "--microbatch-size",
        type=int,
        default=1,
        help="Per-rank input microbatch size before EP all-gather inside the model.",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=8192,
        help="Sequence length. Defaults to 8192 for the 4xH100 sanity run.",
    )
    parser.add_argument(
        "--dp-degree",
        type=int,
        default=2,
        help="Data-parallel mesh degree.",
    )
    parser.add_argument(
        "--ep-degree",
        type=int,
        default=2,
        help="Expert-parallel mesh degree.",
    )
    parser.add_argument(
        "--train-steps",
        type=int,
        default=30,
        help="Number of optimizer steps.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=3e-4,
        help="Optimizer learning rate.",
    )
    parser.add_argument(
        "--optimizer",
        choices=("adamw", "sgd", "none"),
        default="adamw",
        help="Optimizer to use after backward. Use sgd/none for large-model memory smoke runs.",
    )
    parser.add_argument(
        "--max-grad-norm",
        type=float,
        default=1.0,
        help="Gradient clipping max norm.",
    )
    parser.add_argument(
        "--loss-chunk-size",
        type=int,
        default=512,
        help=(
            "Sequence chunk size for vocab-parallel cross entropy. "
            "Keeps the 8192-token real-model run from materializing full-size "
            "float logits and exp buffers at once."
        ),
    )
    parser.add_argument(
        "--skip-loss-improvement-check",
        action="store_true",
        help="Only require finite forward/backward/optimizer steps.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for model initialization and synthetic data generation.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Compile the placed module with the AutoParallel backend before training.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full AutoParallel optimizer log.",
    )
    return parser.parse_args()


def make_model_args(flavor: str, seq_len: int | None) -> Qwen3ModelArgs:
    """Build the model arguments for ``flavor`` with the MoE axis named ``"ep"``.

    ``seq_len`` overrides ``max_seq_len`` when given; the hand-written
    ``moe-tiny`` flavor falls back to 512 and the preset flavors keep their
    own default.

    Raises:
        ValueError: if ``flavor`` is not a known MoE flavor.
    """
    if flavor == "moe-tiny":
        max_seq_len = 512 if seq_len is None else seq_len
        return Qwen3ModelArgs(
            dim=64,
            n_layers=1,
            n_heads=4,
            n_kv_heads=2,
            head_dim=16,
            hidden_dim=128,
            vocab_size=128,
            max_seq_len=max_seq_len,
            moe_enabled=True,
            moe_hidden_dim=32,
            num_experts=8,
            top_k=2,
            route_norm=True,
            score_before_experts=False,
            moe_axis_name="ep",
        )
    overrides = {"moe_axis_name": "ep"}
    if seq_len is not None:
        overrides["max_seq_len"] = seq_len
    if flavor == "moe-debug":
        return qwen3_moe_debug_args(**overrides)
    if flavor == "30b-a3b":
        return qwen3_30b_a3b_args(**overrides)
    if flavor == "235b-a22b":
        return qwen3_235b_a22b_args(**overrides)
    raise ValueError(f"Unknown Qwen3 MoE flavor: {flavor}")


def init_distributed(args):
    """Validate the launch configuration and create the (dp, ep) device mesh.

    Reads WORLD_SIZE and LOCAL_RANK from the environment (as set by torchrun),
    checks that the batch sizes divide evenly, initializes the NCCL process
    group, and builds a 2-D CUDA mesh.

    Returns:
        ``(device, mesh)``: this rank's CUDA device and the device mesh with
        dimension names ``("dp", "ep")``.

    Raises:
        RuntimeError: if the torchrun environment variables are missing.
        ValueError: if the world size or batch sizes are inconsistent.
    """
    if "WORLD_SIZE" not in os.environ or "LOCAL_RANK" not in os.environ:
        raise RuntimeError(
            "Run this example with torchrun, e.g. "
            "torchrun --standalone --nproc-per-node 4 "
            "examples/example_sanity_check_qwen3_moe.py"
        )

    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    expected_world_size = args.dp_degree * args.ep_degree
    if world_size != expected_world_size:
        raise ValueError(
            f"WORLD_SIZE ({world_size}) must equal dp-degree * ep-degree "
            f"({args.dp_degree} * {args.ep_degree} = {expected_world_size})."
        )
    if args.global_batch_size % args.dp_degree != 0:
        raise ValueError(
            f"global-batch-size ({args.global_batch_size}) must be divisible by "
            f"dp-degree ({args.dp_degree})."
        )

    local_dp_batch_size = args.global_batch_size // args.dp_degree
    # One forward call consumes one microbatch from every EP rank of a DP group.
    local_dp_microbatch = args.microbatch_size * args.ep_degree
    if local_dp_batch_size % local_dp_microbatch != 0:
        raise ValueError(
            f"local DP batch size ({local_dp_batch_size}) must be divisible by "
            f"microbatch-size * ep-degree "
            f"({args.microbatch_size} * {args.ep_degree} = {local_dp_microbatch})."
        )

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl", device_id=device)
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda",
        (args.dp_degree, args.ep_degree),
        mesh_dim_names=("dp", "ep"),
    )
    return device, mesh


def make_local_tokens(args, mesh, device, vocab_size: int) -> torch.Tensor:
    """Return this DP rank's slice of a deterministic synthetic token batch.

    Every rank generates the same ``(global_batch_size, seq_len + 1)`` tensor
    from a CPU generator seeded with ``args.seed`` and keeps the rows that
    belong to its data-parallel coordinate, so all EP ranks of one DP group
    start from identical tokens. The extra column lets the caller derive
    inputs and next-token labels by shifting.

    Raises:
        RuntimeError: if the mesh has no coordinate for this rank.
    """
    coordinate = mesh.get_coordinate()
    if coordinate is None:
        raise RuntimeError("DeviceMesh coordinate is unavailable on this rank.")
    dp_rank, _ep_rank = coordinate
    local_dp_batch_size = args.global_batch_size // args.dp_degree

    generator = torch.Generator(device="cpu")
    generator.manual_seed(args.seed)
    tokens = torch.randint(
        0,
        vocab_size,
        (args.global_batch_size, args.seq_len + 1),
        generator=generator,
        dtype=torch.long,
    )

    start = dp_rank * local_dp_batch_size
    stop = start + local_dp_batch_size
    return tokens[start:stop].to(device, non_blocking=True)


def vocab_parallel_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    vocab_size: int,
    vocab_group,
    vocab_rank: int,
    vocab_degree: int,
    global_token_count: int,
) -> torch.Tensor:
    """Cross entropy over logits whose last (vocab) dimension is split across ranks.

    Args:
        logits: local logits; the first two dims must equal ``labels.shape``
            and the last dim is this rank's slice of the vocabulary.
        labels: target token ids in the global vocabulary.
        vocab_size: size of the full vocabulary.
        vocab_group: process group spanning the ranks that share the vocabulary.
        vocab_rank: this rank's index inside ``vocab_group``.
        vocab_degree: number of ranks in ``vocab_group``.
        global_token_count: number of tokens the loss is averaged over.

    Returns:
        Scalar ``loss_sum / (global_token_count * vocab_degree)``. Every rank
        of the group computes the same value from the all-reduced terms, and
        ``main`` sums the per-rank values with an all-reduce, so the extra
        ``vocab_degree`` keeps that sum equal to the mean token loss.

    Raises:
        ValueError: if ``logits`` and ``labels`` disagree on their leading dims.
    """
    if logits.shape[:2] != labels.shape:
        raise ValueError(
            f"logits shape {tuple(logits.shape)} is incompatible with "
            f"labels shape {tuple(labels.shape)}."
        )

    local_vocab_size = logits.shape[-1]
    # Offsets follow chunk-style sharding: every rank before the last owns
    # ceil(vocab_size / vocab_degree) columns and the last one may own fewer.
    # Deriving the start from the local width would misplace a narrower last
    # shard and let the gather below index past the local columns.
    shard_size = -(-vocab_size // vocab_degree)
    vocab_start = min(vocab_rank * shard_size, vocab_size)
    vocab_stop = min(vocab_start + local_vocab_size, vocab_size)

    logits = logits.float()
    local_max = logits.amax(dim=-1)
    # The max is only a numerical-stability shift, so it is taken without grad.
    with torch.no_grad():
        global_max = local_max.detach().clone()
        dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=vocab_group)

    shifted_logits = logits - global_max.unsqueeze(-1)
    local_exp_sum = shifted_logits.exp().sum(dim=-1)
    global_exp_sum = dist_nn_func.all_reduce(
        local_exp_sum,
        op=dist.ReduceOp.SUM,
        group=vocab_group,
    )

    # Only the rank owning a label's vocab slice contributes its target logit;
    # other ranks gather index 0 and zero the result out through the mask.
    target_mask = (labels >= vocab_start) & (labels < vocab_stop)
    local_target = torch.zeros_like(labels, dtype=torch.long)
    local_target[target_mask] = labels[target_mask] - vocab_start
    local_target_logits = logits.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
    local_target_logits = local_target_logits * target_mask.to(logits.dtype)
    target_logits = dist_nn_func.all_reduce(
        local_target_logits,
        op=dist.ReduceOp.SUM,
        group=vocab_group,
    )

    loss_sum = (global_exp_sum.log() + global_max - target_logits).sum()
    return loss_sum / (global_token_count * vocab_degree)


def chunk_ranges(size: int, chunk_size: int):
    """Yield ``(start, stop)`` pairs covering ``range(size)`` in ``chunk_size`` steps.

    A non-positive ``chunk_size`` disables chunking and yields the single
    range ``(0, size)``.
    """
    if chunk_size <= 0:
        yield 0, size
        return
    for start in range(0, size, chunk_size):
        yield start, min(start + chunk_size, size)


def print_rank0(message: str) -> None:
    """Print ``message`` (flushed) on global rank 0 only."""
    if dist.get_rank() == 0:
        print(message, flush=True)


def print_cuda_memory(stage: str, device: torch.device) -> None:
    """Report allocated / reserved / peak-reserved CUDA memory in GiB on rank 0."""
    allocated = torch.cuda.memory_allocated(device) / 1024**3
    reserved = torch.cuda.memory_reserved(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print_rank0(
        f"{stage}: cuda allocated={allocated:.2f}GiB "
        f"reserved={reserved:.2f}GiB max_reserved={max_reserved:.2f}GiB"
    )


def main():
    """Place a Qwen3 MoE model with AutoParallel and train it on synthetic data.

    Raises:
        ValueError: if the expert count is not divisible by the EP degree.
        RuntimeError: on a NaN loss, or when the loss-improvement check is
            active and the final loss is not below the first one.
    """
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG)

    device, mesh = init_distributed(args)
    ep_group = mesh.get_group("ep")
    ep_rank = mesh.get_local_rank("ep")
    local_dp_batch_size = args.global_batch_size // args.dp_degree
    local_dp_microbatch = args.microbatch_size * args.ep_degree
    gradient_accumulation_steps = local_dp_batch_size // local_dp_microbatch

    torch.manual_seed(args.seed)
    model_args = make_model_args(args.flavor, args.seq_len)
    if args.seq_len is None:
        args.seq_len = model_args.max_seq_len
    if model_args.num_experts % args.ep_degree != 0:
        raise ValueError(
            f"num_experts ({model_args.num_experts}) must be divisible by "
            f"ep-degree ({args.ep_degree})."
        )
    # The input is sharded over both mesh dims, so one forward call sees one
    # microbatch from every rank.
    trace_global_batch_size = args.microbatch_size * args.dp_degree * args.ep_degree

    with torch.device("meta"):
        model = Transformer(model_args, mesh=mesh, moe_axis_name="ep")

    def input_fn():
        return torch.randint(
            0,
            model_args.vocab_size,
            (trace_global_batch_size, args.seq_len),
            device=device,
        )

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    print_rank0(
        f"Qwen3 {args.flavor} sanity check: "
        f"mesh=(dp={args.dp_degree}, ep={args.ep_degree}), "
        f"global_batch={args.global_batch_size}, "
        f"local_dp_batch={local_dp_batch_size}, "
        f"per_rank_microbatch={args.microbatch_size}, "
        f"local_dp_microbatch={local_dp_microbatch}, "
        f"grad_accum={gradient_accumulation_steps}, "
        f"trace_global_batch={trace_global_batch_size}, "
        f"seq_len={args.seq_len}, "
        f"loss_chunk_size={args.loss_chunk_size}, "
        f"optimizer={args.optimizer}"
    )

    t0 = time.time()
    with AutoParallel(
        model,
        input_fn,
        mesh,
        mp_policy,
        dynamic=True,
        repeated_subgraphs=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        autop.add_input_constraints([(Shard(0), Shard(0))])
        autop.add_output_constraints([(Shard(0), Shard(2))])
        sharding_placement = autop.optimize_placement(verbose=args.verbose)
        parallel_mod = autop.apply_placement(sharding_placement)

    print_rank0(f"Tracing + optimization took {time.time() - t0:.1f}s")
    print_cuda_memory("after AutoParallel", device)

    parallel_mod.to_empty(device=device)
    print_cuda_memory("after to_empty", device)
    parallel_mod.init_weights(buffer_device=device, seed=args.seed)  # type: ignore[operator]
    print_cuda_memory("after init_weights", device)

    if args.compile:
        parallel_mod = torch.compile(parallel_mod, backend=autoparallel_backend())

    batch = make_local_tokens(args, mesh, device, model_args.vocab_size)
    inputs = batch[:, :-1].contiguous()
    labels = batch[:, 1:].contiguous()

    ep_coordinate = mesh.get_coordinate()[1]
    input_microbatches = []
    label_microbatches = []
    for start in range(0, local_dp_batch_size, local_dp_microbatch):
        stop = start + local_dp_microbatch
        input_block = inputs[start:stop]
        # Inputs are sharded over EP, so each rank feeds only its own rows of
        # the block; labels keep the whole block because the output batch is
        # sharded over DP only.
        input_start = ep_coordinate * args.microbatch_size
        input_stop = input_start + args.microbatch_size
        input_microbatches.append(input_block[input_start:input_stop].contiguous())
        label_microbatches.append(labels[start:stop].contiguous())

    global_token_count = args.global_batch_size * args.seq_len
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(parallel_mod.parameters(), lr=args.lr)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(parallel_mod.parameters(), lr=args.lr)
    else:
        optimizer = None

    try:
        losses: list[float] = []
        for step in range(args.train_steps):
            if optimizer is not None:
                optimizer.zero_grad(set_to_none=True)
            else:
                parallel_mod.zero_grad(set_to_none=True)
            step_loss = torch.zeros((), device=device)
            for micro_inputs, micro_labels in zip(
                input_microbatches, label_microbatches
            ):
                logits = parallel_mod(micro_inputs)

                seq_ranges = list(chunk_ranges(logits.shape[1], args.loss_chunk_size))
                for chunk_idx, (seq_start, seq_stop) in enumerate(seq_ranges):
                    logits_chunk = logits[:, seq_start:seq_stop]
                    labels_chunk = micro_labels[:, seq_start:seq_stop]
                    loss = vocab_parallel_cross_entropy(
                        logits_chunk,
                        labels_chunk,
                        vocab_size=model_args.vocab_size,
                        vocab_group=ep_group,
                        vocab_rank=ep_rank,
                        vocab_degree=args.ep_degree,
                        global_token_count=global_token_count,
                    )
                    if torch.any(torch.isnan(loss)):
                        raise RuntimeError("Found NaNs in Qwen3 MoE training loss.")

                    # All chunks share the forward graph; it may only be
                    # released after the last chunk has been backpropagated.
                    retain_graph = chunk_idx != len(seq_ranges) - 1
                    loss.backward(retain_graph=retain_graph)
                    step_loss = step_loss + loss.detach()

            torch.nn.utils.clip_grad_norm_(
                parallel_mod.parameters(), args.max_grad_norm
            )
            if optimizer is not None:
                optimizer.step()

            with torch.no_grad():
                logged_loss = step_loss.clone()
                dist.all_reduce(logged_loss, op=dist.ReduceOp.SUM)
            loss_value = float(logged_loss.item())
            losses.append(loss_value)
            print_rank0(f"step={step:03d} loss={loss_value:.6f}")
            print_cuda_memory(f"after step {step:03d}", device)

        if len(losses) > 1:
            improved = losses[-1] < losses[0]
            # With --optimizer none the weights never change, so the loss
            # cannot be expected to drop; enforcing the check would make that
            # smoke mode fail on every run.
            enforce_improvement = (
                not args.skip_loss_improvement_check and optimizer is not None
            )
            if enforce_improvement and not improved:
                raise RuntimeError(
                    f"Qwen3 MoE training loss did not improve: "
                    f"initial={losses[0]:.6f}, final={losses[-1]:.6f}"
                )
            if improved:
                print_rank0(
                    f"Loss improved: initial={losses[0]:.6f}, final={losses[-1]:.6f}"
                )
            else:
                print_rank0(
                    f"Loss did not improve (check not enforced): "
                    f"initial={losses[0]:.6f}, final={losses[-1]:.6f}"
                )
        dist.barrier(device_ids=[device.index])
        torch.cuda.synchronize(device)
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# NOTE: end of examples/example_sanity_check_qwen3_moe.py. The lines below are
# the patch header and license banner of the next file in this patch, kept as
# comments so the file boundary is not lost.
#
# diff --git a/examples/example_torchtitan_qwen3_dense.py b/examples/example_torchtitan_qwen3_dense.py
# new file mode 100644
# index 00000000..a4685d1b
# --- /dev/null
# +++ b/examples/example_torchtitan_qwen3_dense.py
# @@ -0,0 +1,370 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import dataclasses
import logging
import os
import sys
import time
from pathlib import Path

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_func
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel.api import AutoParallel
from autoparallel.compile import autoparallel_backend


def _add_sibling_torchtitan_to_path() -> None:
    """Put a ``torchtitan`` checkout that sits next to this repo first on ``sys.path``.

    Does nothing when the sibling directory does not exist.
    """
    repo_root = Path(__file__).resolve().parents[1]
    torchtitan_root = repo_root.parent / "torchtitan"
    if torchtitan_root.exists():
        sys.path.insert(0, str(torchtitan_root))


_add_sibling_torchtitan_to_path()

from torchtitan.models.qwen3 import Qwen3Model, qwen3_configs  # noqa: E402


def parse_args():
    """Parse the command-line options of the torchtitan dense Qwen3 example."""
    parser = argparse.ArgumentParser(
        description=(
            "Run torchtitan's dense Qwen3 model through AutoParallel's "
            "searched placement on real GPUs."
        )
    )
    parser.add_argument(
        "--flavor",
        choices=("debugmodel", "debugmodel_fused_qkv", "0.6B", "1.7B", "4B", "8B"),
        default="8B",
        help="Dense torchtitan Qwen3 flavor.",
    )
    parser.add_argument(
        "--global-batch-size",
        type=int,
        default=4,
        help="Global batch size across data-parallel ranks.",
    )
    parser.add_argument(
        "--microbatch-size",
        type=int,
        default=1,
        help="Per-DP-rank microbatch size for gradient accumulation.",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=2048,
        help="Sequence length for the real sanity run.",
    )
    parser.add_argument(
        "--dp-degree",
        type=int,
        default=2,
        help="Data-parallel mesh degree.",
    )
    parser.add_argument(
        "--tp-degree",
        type=int,
        default=2,
        help="Tensor-parallel mesh degree.",
    )
    parser.add_argument(
        "--train-steps",
        type=int,
        default=2,
        help="Number of optimizer steps.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=3e-4,
        help="AdamW learning rate.",
    )
    parser.add_argument(
        "--max-grad-norm",
        type=float,
        default=1.0,
        help="Gradient clipping max norm.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for model initialization and synthetic data generation.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Compile the placed module with the AutoParallel backend before training.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full AutoParallel optimizer log.",
    )
    return parser.parse_args()


def make_model_config(flavor: str, seq_len: int) -> Qwen3Model.Config:
    """Build the torchtitan config for ``flavor`` with the SDPA attention backend.

    The RoPE settings are replaced by a copy whose ``max_seq_len`` is ``seq_len``.
    """
    config = qwen3_configs[flavor](attn_backend="sdpa")
    config.rope = dataclasses.replace(config.rope, max_seq_len=seq_len)
    return config


def init_distributed(args):
    """Validate the launch configuration and create the (dp, tp) device mesh.

    Reads WORLD_SIZE and LOCAL_RANK from the environment (as set by torchrun),
    checks that the batch sizes divide evenly, initializes the NCCL process
    group, and builds a 2-D CUDA mesh.

    Returns:
        ``(device, mesh)``: this rank's CUDA device and the device mesh with
        dimension names ``("dp", "tp")``.

    Raises:
        RuntimeError: if the torchrun environment variables are missing.
        ValueError: if the world size or batch sizes are inconsistent.
    """
    if "WORLD_SIZE" not in os.environ or "LOCAL_RANK" not in os.environ:
        raise RuntimeError(
            "Run this example with torchrun, e.g. "
            "torchrun --standalone --nproc-per-node 4 "
            "examples/example_torchtitan_qwen3_dense.py"
        )

    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    expected_world_size = args.dp_degree * args.tp_degree
    if world_size != expected_world_size:
        raise ValueError(
            f"WORLD_SIZE ({world_size}) must equal dp-degree * tp-degree "
            f"({args.dp_degree} * {args.tp_degree} = {expected_world_size})."
        )
    if args.global_batch_size % args.dp_degree != 0:
        raise ValueError(
            f"global-batch-size ({args.global_batch_size}) must be divisible by "
            f"dp-degree ({args.dp_degree})."
        )
    local_batch_size = args.global_batch_size // args.dp_degree
    if local_batch_size % args.microbatch_size != 0:
        raise ValueError(
            f"local batch size ({local_batch_size}) must be divisible by "
            f"microbatch-size ({args.microbatch_size})."
        )

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl", device_id=device)
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda",
        (args.dp_degree, args.tp_degree),
        mesh_dim_names=("dp", "tp"),
    )
    return device, mesh


def make_local_tokens(args, mesh, device, vocab_size: int) -> torch.Tensor:
    """Return this DP rank's slice of a deterministic synthetic token batch.

    Every rank generates the same ``(global_batch_size, seq_len + 1)`` tensor
    from a CPU generator seeded with ``args.seed`` and keeps the rows that
    belong to its data-parallel coordinate, so all TP ranks of one DP group
    see identical tokens. The extra column lets the caller derive inputs and
    next-token labels by shifting.

    Raises:
        RuntimeError: if the mesh has no coordinate for this rank.
    """
    coordinate = mesh.get_coordinate()
    if coordinate is None:
        raise RuntimeError("DeviceMesh coordinate is unavailable on this rank.")
    dp_rank, _tp_rank = coordinate
    local_batch_size = args.global_batch_size // args.dp_degree

    generator = torch.Generator(device="cpu")
    generator.manual_seed(args.seed)
    tokens = torch.randint(
        0,
        vocab_size,
        (args.global_batch_size, args.seq_len + 1),
        generator=generator,
        dtype=torch.long,
    )

    start = dp_rank * local_batch_size
    stop = start + local_batch_size
    return tokens[start:stop].to(device, non_blocking=True)


def vocab_parallel_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    vocab_size: int,
    tp_group,
    tp_rank: int,
    tp_degree: int,
    global_token_count: int,
) -> torch.Tensor:
    """Cross entropy over logits whose last (vocab) dimension is split across TP ranks.

    Args:
        logits: local logits; the first two dims must equal ``labels.shape``
            and the last dim is this rank's slice of the vocabulary.
        labels: target token ids in the global vocabulary.
        vocab_size: size of the full vocabulary.
        tp_group: process group spanning the ranks that share the vocabulary.
        tp_rank: this rank's index inside ``tp_group``.
        tp_degree: number of ranks in ``tp_group``.
        global_token_count: number of tokens the loss is averaged over.

    Returns:
        Scalar ``loss_sum / (global_token_count * tp_degree)``. Every rank of
        the group computes the same value from the all-reduced terms, and
        ``main`` sums the per-rank values with an all-reduce, so the extra
        ``tp_degree`` keeps that sum equal to the mean token loss.

    Raises:
        ValueError: if ``logits`` and ``labels`` disagree on their leading dims.
    """
    if logits.shape[:2] != labels.shape:
        raise ValueError(
            f"logits shape {tuple(logits.shape)} is incompatible with "
            f"labels shape {tuple(labels.shape)}."
        )

    local_vocab_size = logits.shape[-1]
    # Offsets follow chunk-style sharding: every rank before the last owns
    # ceil(vocab_size / tp_degree) columns and the last one may own fewer.
    # Deriving the start from the local width would misplace a narrower last
    # shard and let the gather below index past the local columns.
    shard_size = -(-vocab_size // tp_degree)
    vocab_start = min(tp_rank * shard_size, vocab_size)
    vocab_stop = min(vocab_start + local_vocab_size, vocab_size)

    logits = logits.float()
    local_max = logits.amax(dim=-1)
    # The max is only a numerical-stability shift, so it is taken without grad.
    with torch.no_grad():
        global_max = local_max.detach().clone()
        dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    shifted_logits = logits - global_max.unsqueeze(-1)
    local_exp_sum = shifted_logits.exp().sum(dim=-1)
    global_exp_sum = dist_nn_func.all_reduce(
        local_exp_sum,
        op=dist.ReduceOp.SUM,
        group=tp_group,
    )

    # Only the rank owning a label's vocab slice contributes its target logit;
    # other ranks gather index 0 and zero the result out through the mask.
    target_mask = (labels >= vocab_start) & (labels < vocab_stop)
    local_target = torch.zeros_like(labels, dtype=torch.long)
    local_target[target_mask] = labels[target_mask] - vocab_start
    local_target_logits = logits.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
    local_target_logits = local_target_logits * target_mask.to(logits.dtype)
    target_logits = dist_nn_func.all_reduce(
        local_target_logits,
        op=dist.ReduceOp.SUM,
        group=tp_group,
    )

    loss_sum = (global_exp_sum.log() + global_max - target_logits).sum()
    return loss_sum / (global_token_count * tp_degree)


def print_rank0(message: str) -> None:
    """Print ``message`` (flushed) on global rank 0 only."""
    if dist.get_rank() == 0:
        print(message, flush=True)


def main():
    """Place torchtitan's dense Qwen3 with AutoParallel and train it on synthetic data.

    Raises:
        RuntimeError: on NaNs in the logits or loss, or when more than one
            step ran and the final loss is not below the first one.
    """
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG)

    device, mesh = init_distributed(args)
    tp_group = mesh.get_group("tp")
    tp_rank = mesh.get_local_rank("tp")
    local_batch_size = args.global_batch_size // args.dp_degree
    gradient_accumulation_steps = local_batch_size // args.microbatch_size

    torch.manual_seed(args.seed)
    model_config = make_model_config(args.flavor, args.seq_len)
    vocab_size = model_config.vocab_size
    # Trace with the batch that one forward call actually sees (one microbatch
    # per DP rank). Tracing the full global batch would describe a local batch
    # of global_batch_size / dp_degree rows, which differs from the
    # microbatch_size rows fed below whenever gradient accumulation is active.
    trace_global_batch_size = args.microbatch_size * args.dp_degree

    with torch.device("meta"):
        model = model_config.build()

    def input_fn():
        return torch.randint(
            0,
            vocab_size,
            (trace_global_batch_size, args.seq_len),
            device=device,
        )

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    print_rank0(
        f"torchtitan Qwen3 {args.flavor} via AutoParallel: "
        f"mesh=(dp={args.dp_degree}, tp={args.tp_degree}), "
        f"global_batch={args.global_batch_size}, "
        f"local_batch={local_batch_size}, "
        f"microbatch={args.microbatch_size}, "
        f"grad_accum={gradient_accumulation_steps}, "
        f"trace_global_batch={trace_global_batch_size}, "
        f"seq_len={args.seq_len}"
    )

    t0 = time.time()
    with AutoParallel(
        model,
        input_fn,
        mesh,
        mp_policy,
        repeated_subgraphs=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        autop.add_input_constraints([(Shard(0), Replicate())])
        autop.add_output_constraints([(Shard(0), Shard(2))])
        sharding_placement = autop.optimize_placement(verbose=args.verbose)
        parallel_mod = autop.apply_placement(sharding_placement)

    print_rank0(f"Tracing + optimization took {time.time() - t0:.1f}s")

    parallel_mod.to_empty(device=device)
    # Re-seed right before initialization so the weights depend only on --seed.
    torch.manual_seed(args.seed)
    parallel_mod.init_weights(buffer_device=device)  # type: ignore[operator]

    if args.compile:
        parallel_mod = torch.compile(parallel_mod, backend=autoparallel_backend())

    batch = make_local_tokens(args, mesh, device, vocab_size)
    inputs = batch[:, :-1].contiguous()
    labels = batch[:, 1:].contiguous()
    input_microbatches = torch.split(inputs, args.microbatch_size, dim=0)
    label_microbatches = torch.split(labels, args.microbatch_size, dim=0)

    global_token_count = args.global_batch_size * args.seq_len
    optimizer = torch.optim.AdamW(parallel_mod.parameters(), lr=args.lr)

    try:
        losses: list[float] = []
        for step in range(args.train_steps):
            optimizer.zero_grad(set_to_none=True)
            step_loss = torch.zeros((), device=device)
            for micro_inputs, micro_labels in zip(
                input_microbatches, label_microbatches
            ):
                logits = parallel_mod(micro_inputs)
                if torch.any(torch.isnan(logits)):
                    raise RuntimeError("Found NaNs in forward output.")

                loss = vocab_parallel_cross_entropy(
                    logits,
                    micro_labels,
                    vocab_size=vocab_size,
                    tp_group=tp_group,
                    tp_rank=tp_rank,
                    tp_degree=args.tp_degree,
                    global_token_count=global_token_count,
                )
                if torch.any(torch.isnan(loss)):
                    raise RuntimeError("Found NaNs in training loss.")

                loss.backward()
                step_loss = step_loss + loss.detach()

            torch.nn.utils.clip_grad_norm_(
                parallel_mod.parameters(), args.max_grad_norm
            )
            optimizer.step()

            with torch.no_grad():
                logged_loss = step_loss.clone()
                dist.all_reduce(logged_loss, op=dist.ReduceOp.SUM)
            loss_value = float(logged_loss.item())
            losses.append(loss_value)
            print_rank0(f"step={step:03d} loss={loss_value:.6f}")

        if len(losses) > 1 and losses[-1] >= losses[0]:
            raise RuntimeError(
                f"Training loss did not improve: "
                f"initial={losses[0]:.6f}, final={losses[-1]:.6f}"
            )

        if len(losses) > 1:
            print_rank0(
                f"Loss improved: initial={losses[0]:.6f}, final={losses[-1]:.6f}"
            )
        elif losses:
            # Guarded so --train-steps 0 does not index an empty list.
            print_rank0(f"Completed one step: loss={losses[0]:.6f}")
        dist.barrier(device_ids=[device.index])
        torch.cuda.synchronize(device)
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# NOTE: end of examples/example_torchtitan_qwen3_dense.py. The lines below are
# the patch header and license banner of the next file in this patch, kept as
# comments so the file boundary is not lost.
#
# diff --git a/tests/test_dsv3_torchtitan_config.py b/tests/test_dsv3_torchtitan_config.py
# new file mode 100644
# index 00000000..923016f8
# --- /dev/null
# +++ b/tests/test_dsv3_torchtitan_config.py
# @@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
+
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+
+from autoparallel._testing.models.dsv3 import DeepSeekV3Model
+
+
+def test_dsv3_accepts_torchtitan_grouped_experts_config():
+    """Build DeepSeekV3Model from torchtitan's debug config on the meta device.
+
+    Skips when no ``torchtitan`` directory exists two levels above this file's
+    directory, or when importing its DeepSeek-V3 configs fails. Otherwise
+    asserts that the first MoE-enabled layer has ``use_grouped_mm`` set.
+    """
+    torchtitan_root = Path(__file__).resolve().parents[2] / "torchtitan"
+    if not torchtitan_root.exists():
+        pytest.skip("torchtitan sibling checkout not found")
+    # Avoid stacking duplicate entries if the checkout is already importable.
+    torchtitan_path = str(torchtitan_root)
+    if torchtitan_path not in sys.path:
+        sys.path.insert(0, torchtitan_path)
+
+    try:
+        from torchtitan.models.deepseek_v3 import (
+            deepseekv3_configs,  # type: ignore[import-not-found]
+        )
+    except Exception as exc:
+        pytest.skip(f"torchtitan DeepSeek-V3 config unavailable: {exc}")
+
+    with torch.device("meta"):
+        model = DeepSeekV3Model(
+            deepseekv3_configs["debugmodel"](
+                attn_backend="sdpa",
+                moe_comm_backend="standard",
+            )
+        )
+
+    # A default keeps a missing MoE layer from surfacing as a bare StopIteration.
+    moe_layer = next(
+        (layer for layer in model.layers.values() if layer.moe_enabled), None
+    )
+    assert moe_layer is not None, "debugmodel config produced no MoE-enabled layer"
+    assert moe_layer.moe.experts.use_grouped_mm
diff --git a/tests/test_qwen3.py b/tests/test_qwen3.py
new file mode 100644
index 00000000..3d20bace
--- /dev/null
+++ b/tests/test_qwen3.py
@@ -0,0 +1,323 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+from autoparallel._testing.models.qwen3 import (
+    Qwen3ModelArgs,
+    Transformer,
+    apply_rotary_emb_cos_sin,
+    qwen3_args_from_torchtitan_config,
+    qwen3_debug_args,
+    qwen3_moe_debug_args,
+)
+from autoparallel.api import AutoParallel, auto_parallel
+
+
+def _apply_overrides(args: Qwen3ModelArgs, overrides: dict) -> Qwen3ModelArgs:
+    """Set ``overrides`` on ``args`` and re-run ``__post_init__``.
+
+    Raises:
+        AttributeError: if a key is not an existing attribute of ``args``, so a
+            misspelled override fails loudly instead of being ignored.
+    """
+    for key, value in overrides.items():
+        if not hasattr(args, key):
+            raise AttributeError(f"Qwen3ModelArgs has no attribute {key!r}")
+        setattr(args, key, value)
+    args.__post_init__()
+    return args
+
+
+def _use_torchtitan_checkout() -> None:
+    """Put the sibling ``torchtitan`` checkout on ``sys.path`` or skip the test.
+
+    The checkout is looked up two levels above this file's directory. The path
+    is inserted only once, however many tests (or parametrized cases) call this.
+    """
+    torchtitan_root = Path(__file__).resolve().parents[2] / "torchtitan"
+    if not torchtitan_root.exists():
+        pytest.skip("torchtitan sibling checkout not found")
+    torchtitan_path = str(torchtitan_root)
+    if torchtitan_path not in sys.path:
+        sys.path.insert(0, torchtitan_path)
+
+
+def _tiny_args(**overrides) -> Qwen3ModelArgs:
+    """Return small dense model args, with ``overrides`` applied on top."""
+    args = Qwen3ModelArgs(
+        dim=64,
+        n_layers=2,
+        n_heads=4,
+        n_kv_heads=2,
+        head_dim=16,
+        hidden_dim=128,
+        vocab_size=128,
+        max_seq_len=16,
+    )
+    return _apply_overrides(args, overrides)
+
+
+def _tiny_moe_args(**overrides) -> Qwen3ModelArgs:
+    """Return small MoE model args, with ``overrides`` applied on top."""
+    args = Qwen3ModelArgs(
+        dim=32,
+        n_layers=1,
+        n_heads=4,
+        n_kv_heads=2,
+        head_dim=8,
+        hidden_dim=64,
+        vocab_size=64,
+        max_seq_len=4,
+        moe_enabled=True,
+        moe_hidden_dim=16,
+        num_experts=64,
+        top_k=8,
+        route_norm=True,
+        score_before_experts=False,
+        moe_axis_name="tp",
+    )
+    return _apply_overrides(args, overrides)
+
+
+def test_qwen3_forward_shape():
+    """Forward returns logits of shape (batch, seq_len, vocab_size)."""
+    args = _tiny_args()
+    model = Transformer(args)
+    model.init_weights(seed=0)
+
+    tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len))
+    logits = model(tokens)
+
+    assert logits.shape == (2, args.max_seq_len, args.vocab_size)
+
+
+def test_qwen3_qk_norm_changes_logits():
+    """Zeroing the q_norm weight changes the logits, so the norm is in use."""
+    args = _tiny_args(n_layers=1)
+    model = Transformer(args)
+    model.init_weights(seed=0)
+
+    tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len))
+    logits = model(tokens)
+
+    with torch.no_grad():
+        model.layers["0"].attention.q_norm.weight.zero_()
+        logits_without_q = model(tokens)
+
+    assert not torch.allclose(logits, logits_without_q)
+
+
+def test_qwen3_weight_tying_survives_init_weights():
+    """Embedding and lm_head share one weight before and after init_weights."""
+    args = _tiny_args(enable_weight_tying=True)
+    model = Transformer(args)
+
+    assert model.tok_embeddings.weight is model.lm_head.weight
+    model.init_weights(seed=0)
+    assert model.tok_embeddings.weight is model.lm_head.weight
+
+
+def test_qwen3_debug_args_matches_torchtitan_dense_shape():
+    """qwen3_debug_args returns the expected dense debug hyperparameters."""
+    args = qwen3_debug_args(max_seq_len=32)
+
+    assert args.dim == 256
+    assert args.n_layers == 8
+    assert args.n_heads == 16
+    assert args.n_kv_heads == 8
+    assert args.head_dim == 128
+    assert args.hidden_dim == 3072
+    assert args.vocab_size == 2048
+    assert args.rope_theta == 1000000.0
+    assert args.enable_weight_tying
+
+
+def test_qwen3_moe_debug_args_matches_torchtitan_shape():
+    """qwen3_moe_debug_args returns the expected MoE debug hyperparameters."""
+    args = qwen3_moe_debug_args(max_seq_len=32)
+
+    assert args.dim == 256
+    assert args.n_layers == 8
+    assert args.n_heads == 16
+    assert args.n_kv_heads == 8
+    assert args.head_dim == 128
+    assert args.moe_enabled
+    assert args.moe_hidden_dim == 768
+    assert args.num_experts == 64
+    assert args.top_k == 8
+    assert args.route_norm
+    assert not args.score_before_experts
+
+
+@pytest.mark.parametrize(
+    ("flavor", "expected"),
+    [
+        (
+            "8B",
+            {
+                "dim": 4096,
+                "n_layers": 36,
+                "n_heads": 32,
+                "n_kv_heads": 8,
+                "head_dim": 128,
+                "hidden_dim": 12288,
+                "vocab_size": 151936,
+                "moe_enabled": False,
+                "num_experts": 0,
+                "top_k": 1,
+                "max_seq_len": 4096,
+            },
+        ),
+        (
+            "30B-A3B",
+            {
+                "dim": 2048,
+                "n_layers": 48,
+                "n_heads": 32,
+                "n_kv_heads": 4,
+                "head_dim": 128,
+                "hidden_dim": 0,
+                "vocab_size": 151936,
+                "moe_enabled": True,
+                "moe_hidden_dim": 768,
+                "num_experts": 128,
+                "top_k": 8,
+                "route_norm": True,
+                "score_before_experts": False,
+                "max_seq_len": 262144,
+            },
+        ),
+    ],
+)
+def test_qwen3_args_from_torchtitan_config(flavor, expected):
+    """Converting a torchtitan Qwen3 config yields the expected arg values.
+
+    Skips when the torchtitan checkout or its Qwen3 configs are unavailable.
+    """
+    _use_torchtitan_checkout()
+
+    try:
+        from torchtitan.models.qwen3 import (
+            qwen3_configs,  # type: ignore[import-not-found]
+        )
+    except Exception as exc:
+        pytest.skip(f"torchtitan Qwen3 config unavailable: {exc}")
+
+    args = qwen3_args_from_torchtitan_config(qwen3_configs[flavor](attn_backend="sdpa"))
+
+    for attr, value in expected.items():
+        assert getattr(args, attr) == value
+    assert args.rope_theta == 1000000.0
+    assert args.norm_eps == 1e-6
+
+
+def test_qwen3_cos_sin_rope_matches_torchtitan_helper():
+    """The local cos/sin rotary helper matches torchtitan's on the same cache.
+
+    Skips when the torchtitan checkout or its RoPE module is unavailable.
+    """
+    _use_torchtitan_checkout()
+
+    try:
+        from torchtitan.models.common.rope import RoPE
+        from torchtitan.models.common.rope import (
+            apply_rotary_emb_cos_sin as tt_apply_rotary_emb_cos_sin,  # type: ignore[import-not-found]
+        )
+    except Exception as exc:
+        pytest.skip(f"torchtitan Qwen3 RoPE helper unavailable: {exc}")
+
+    args = _tiny_args()
+    rope = RoPE(
+        RoPE.Config(
+            dim=args.head_dim,
+            max_seq_len=args.max_seq_len,
+            theta=args.rope_theta,
+            backend="cos_sin",
+        )
+    )
+    xq = torch.randn(2, args.max_seq_len, args.n_heads, args.head_dim)
+    xk = torch.randn(2, args.max_seq_len, args.n_kv_heads, args.head_dim)
+
+    actual = apply_rotary_emb_cos_sin(xq, xk, rope.cache)
+    expected = tt_apply_rotary_emb_cos_sin(xq, xk, rope.cache)
+
+    torch.testing.assert_close(actual[0], expected[0])
+    torch.testing.assert_close(actual[1], expected[1])
+
+
+def test_qwen3_autoparallel_pipeline_smoke(device_mesh_2d):
+    """The AutoParallel context-manager flow returns a Transformer instance."""
+    args = _tiny_args(n_layers=2, max_seq_len=8)
+    batch_size = 2 * device_mesh_2d.shape[0]
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    def input_fn():
+        return torch.randint(
+            0,
+            args.vocab_size,
+            (batch_size, args.max_seq_len),
+            device="cuda",
+        )
+
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.bfloat16,
+        reduce_dtype=torch.float32,
+    )
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_2d,
+        mp_policy,
+        repeated_subgraphs=True,
+    ) as autop:
+        autop.add_input_constraints([(Shard(0), Replicate())])
+        autop.add_output_constraints([(Shard(0), Shard(2))])
+        sharding_placement = autop.optimize_placement(verbose=False)
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    assert isinstance(parallel_mod, Transformer)
+
+
+def test_qwen3_moe_auto_parallel_smoke(device_mesh_2d):
+    """auto_parallel on a tiny MoE model keeps parameter shapes and runs.
+
+    Checks that parameter names, shapes and total count are unchanged by
+    parallelization, then runs a forward and backward pass on local tokens and
+    checks the local output shape.
+    """
+    args = _tiny_moe_args()
+    local_batch_size = 1
+
+    with torch.device("meta"):
+        model = Transformer(args, mesh=device_mesh_2d, moe_axis_name="tp")
+
+    expected_param_shapes = {
+        name: tuple(param.shape) for name, param in model.named_parameters()
+    }
+    expected_nparams = sum(param.numel() for param in model.parameters())
+
+    tokens = DTensor.from_local(
+        torch.randint(
+            0,
+            args.vocab_size,
+            (local_batch_size, args.max_seq_len),
+            device="cuda",
+        ),
+        device_mesh_2d,
+        [Shard(0), Shard(0)],
+    )
+
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.bfloat16,
+        reduce_dtype=torch.float32,
+    )
+    parallel_mod = auto_parallel(
+        model,
+        device_mesh_2d,
+        sample_inputs=(tokens,),
+        out_shardings=(Shard(0), Shard(2)),
+        mp_policy=mp_policy,
+        dynamic=True,
+    )
+
+    assert isinstance(parallel_mod, Transformer)
+    assert sum(param.numel() for param in parallel_mod.parameters()) == expected_nparams
+    assert {
+        name: tuple(param.shape) for name, param in parallel_mod.named_parameters()
+    } == expected_param_shapes
+    assert parallel_mod.layers["0"].moe.experts.w1.shape == (
+        args.num_experts,
+        args.moe_hidden_dim,
+        args.dim,
+    )
+
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights(buffer_device=torch.device("cuda"), seed=0)
+
+    local_tokens = torch.randint(
+        0,
+        args.vocab_size,
+        (local_batch_size, args.max_seq_len),
+        device="cuda",
+    )
+    out = parallel_mod(local_tokens)
+    # Compared with the local input, dim 0 is larger and the vocab dim smaller
+    # by the size of mesh dim 1.
+    assert out.shape == (
+        local_batch_size * device_mesh_2d.shape[1],
+        args.max_seq_len,
+        args.vocab_size // device_mesh_2d.shape[1],
+    )
+    out.backward(torch.randn_like(out))