diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index fc10336ea4..f215c6e713 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -_supported_models = frozenset(["deepseek_v3", "llama3", "llama3_ft", "llama4", "qwen3"]) +_supported_models = frozenset(["deepseek_v3", "llama3", "llama3_ft", "llama4", "qwen3", "qwen3_5_moe"]) diff --git a/torchtitan/models/qwen3_5_moe/__init__.py b/torchtitan/models/qwen3_5_moe/__init__.py new file mode 100644 index 0000000000..4ca267cf42 --- /dev/null +++ b/torchtitan/models/qwen3_5_moe/__init__.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.moe import MoEArgs +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize import parallelize_qwen35_moe +from .model.args import Qwen35MoEModelArgs +from .model.model import Qwen35MoEModel +from .model.state_dict_adapter import Qwen35MoEStateDictAdapter + +__all__ = [ + "parallelize_qwen35_moe", + "Qwen35MoEModelArgs", + "Qwen35MoEModel", + "qwen3_5_moe_args", +] + +# Adding different variants of the model +# Translated from PR #2545's nested Config definitions to flat dataclass args. 
# Keyword arguments that carry the same value in every flavor listed below.
_QWEN35_MOE_COMMON = dict(
    norm_eps=1e-6,
    n_kv_heads=2,
    qk_norm=True,
    gdn_conv_kernel_size=4,
    gdn_norm_eps=1e-6,
    gdn_fla_backend="fla_chunked",
    full_attention_interval=4,
)

# Keyword arguments shared by every flavor except "debugmodel".
_QWEN35_MOE_FULL_SIZE = dict(
    vocab_size=248320,
    head_dim=256,
    rotary_dim=64,
    gdn_n_key_heads=16,
    gdn_key_head_dim=128,
    gdn_value_head_dim=128,
    rope_theta=10_000_000.0,
)


def _routing(num_experts: int, top_k: int, **extra) -> MoEArgs:
    """Return a fresh ``MoEArgs`` with the router settings used by all flavors.

    Every flavor uses softmax scoring with ``route_norm=True``,
    ``score_before_experts=False`` and ``num_shared_experts=0``; only the
    expert count, ``top_k`` and optional extras (``extra``) vary.
    A new instance is built per call so flavors never share mutable state.
    """
    return MoEArgs(
        num_experts=num_experts,
        num_shared_experts=0,
        top_k=top_k,
        score_func="softmax",
        route_norm=True,
        score_before_experts=False,
        **extra,
    )


def _make_args(*, full_size: bool = True, **overrides) -> Qwen35MoEModelArgs:
    """Assemble one flavor from the shared tables plus per-flavor overrides.

    Args:
        full_size: When true, ``_QWEN35_MOE_FULL_SIZE`` is layered on top of
            ``_QWEN35_MOE_COMMON`` before the overrides are applied.
        **overrides: Flavor-specific keyword arguments; they win over both
            shared tables.
    """
    settings = dict(_QWEN35_MOE_COMMON)
    if full_size:
        settings.update(_QWEN35_MOE_FULL_SIZE)
    settings.update(overrides)
    return Qwen35MoEModelArgs(**settings)


# Model flavors, keyed by the name used to select them.
# Translated from PR #2545's nested Config definitions to flat dataclass args.
qwen3_5_moe_args = {
    "debugmodel": _make_args(
        full_size=False,
        dim=256,
        n_layers=8,
        vocab_size=2048,
        max_seq_len=1024,
        # Full attention
        n_heads=4,
        head_dim=64,
        rotary_dim=16,
        # GatedDeltaNet
        gdn_n_key_heads=2,
        gdn_n_value_heads=4,
        gdn_key_head_dim=64,
        gdn_value_head_dim=64,
        # RoPE
        rope_theta=10000.0,
        # MoE + shared expert
        moe_inter_dim=256,
        moe_args=_routing(num_experts=8, top_k=2, use_grouped_mm=False),
        shared_ffn_hidden_dim=256,
    ),
    "35b-a3b": _make_args(
        dim=2048,
        n_layers=40,
        max_seq_len=262144,
        n_heads=16,
        gdn_n_value_heads=32,
        moe_inter_dim=512,
        moe_args=_routing(num_experts=256, top_k=8),
        shared_ffn_hidden_dim=512,
    ),
    "122b-a10b": _make_args(
        dim=3072,
        n_layers=48,
        max_seq_len=262144,
        n_heads=32,
        gdn_n_value_heads=64,
        moe_inter_dim=1024,
        moe_args=_routing(num_experts=256, top_k=8),
        shared_ffn_hidden_dim=1024,
    ),
    "397b-a17b": _make_args(
        dim=4096,
        n_layers=60,
        max_seq_len=262144,
        n_heads=32,
        gdn_n_value_heads=64,
        moe_inter_dim=1024,
        moe_args=_routing(num_experts=512, top_k=10),
        shared_ffn_hidden_dim=1024,
    ),
    # Same shape as "397b-a17b" but with a 1,000,000-token max_seq_len.
    "397B_A19B": _make_args(
        dim=4096,
        n_layers=60,
        max_seq_len=1_000_000,
        n_heads=32,
        gdn_n_value_heads=64,
        moe_inter_dim=1024,
        moe_args=_routing(num_experts=512, top_k=10),
        shared_ffn_hidden_dim=1024,
    ),
}


def get_train_spec() -> TrainSpec:
    """Return the ``TrainSpec`` that wires Qwen3.5 MoE into the trainer.

    Pipeline parallelism is not provided (``pipelining_fn=None``); the
    remaining entries are the builder callables imported by this module.
    """
    builders = dict(
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        build_validator_fn=build_validator,
    )
    return TrainSpec(
        model_cls=Qwen35MoEModel,
        model_args=qwen3_5_moe_args,
        parallelize_fn=parallelize_qwen35_moe,
        pipelining_fn=None,
        state_dict_adapter=Qwen35MoEStateDictAdapter,
        **builders,
    )
+ +- Full attention layers: standard TP on wq/wk/wv/wo with SP on norms +- GatedDeltaNet layers: allgather input, Replicate DTensors internally, + with DTensor-safe wrappers for conv1d (depthwise) and FLA kernel +- MoE: reuses apply_moe_ep_tp from llama4 +- Shared expert: TP on w1/w3/w2 with Shard(1) residual +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + PrepareModuleInputOutput, + RowwiseParallel, + SequenceParallel, +) + +import torchtitan.models.qwen3_5_moe.model.model as _qwen3_model +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.models.llama3.infra.parallelize import apply_ddp +from torchtitan.models.llama4.infra.parallelize import ( + apply_compile, + apply_fsdp, + apply_moe_ep_tp, +) +from torchtitan.tools.logging import logger + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + + +# --------------------------------------------------------------------------- +# DTensor-safe wrappers +# --------------------------------------------------------------------------- + + +class _DTensorSafeInnerAttention(nn.Module): + """Wrapper that strips DTensor from Q/K/V before inner_attention and + wraps the output back as DTensor.""" + + def __init__(self, inner: nn.Module): + super().__init__() + self.inner = inner + + def forward( + self, q: 
torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + is_dtensor = isinstance(q, DTensor) + if is_dtensor: + mesh, placements = q.device_mesh, q.placements + q, k, v = q.to_local(), k.to_local(), v.to_local() + out = self.inner(q, k, v, *args, **kwargs) + if is_dtensor: + out = DTensor.from_local(out, mesh, placements, run_check=False) + return out + + +class _DTensorSafeConv1d(nn.Module): + """Conv1d wrapper that bypasses DTensor dispatch for depthwise conv. + + DTensor's _tp_conv handler doesn't support depthwise conv (groups > 1). + This wrapper stores weight as a Replicate DTensor (for mesh consistency + needed by gradient norm clipping) but runs F.conv1d on local tensors. + """ + + def __init__(self, original: nn.Conv1d, tp_mesh: DeviceMesh): + super().__init__() + self.weight = nn.Parameter( + DTensor.from_local( + original.weight.data, tp_mesh, [Replicate()], run_check=False + ), + requires_grad=original.weight.requires_grad, + ) + self.bias: nn.Parameter | None = None + if original.bias is not None: + self.bias = nn.Parameter( + DTensor.from_local( + original.bias.data, tp_mesh, [Replicate()], run_check=False + ), + requires_grad=original.bias.requires_grad, + ) + self.stride = original.stride + self.padding = original.padding + self.dilation = original.dilation + self.groups = original.groups + + def forward(self, x: torch.Tensor) -> torch.Tensor: + is_dtensor = isinstance(x, DTensor) + x_local = x.to_local() if is_dtensor else x + w_local = ( + self.weight.to_local() if isinstance(self.weight, DTensor) else self.weight + ) + b_local = None + if self.bias is not None: + b_local = ( + self.bias.to_local() if isinstance(self.bias, DTensor) else self.bias + ) + out = F.conv1d( + x_local, + w_local, + b_local, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + if is_dtensor: + out = DTensor.from_local(out, x.device_mesh, x.placements, run_check=False) + return out + + +_dispatch_patched = False 
# Set once softplus (forward and backward) has been registered with DTensor.
_softplus_registered = False


def _register_dtensor_softplus() -> None:
    """Register aten.softplus.default (and backward) as DTensor pointwise ops.

    The call is idempotent. The guard flag is flipped only after both
    registrations succeed, so a failed import of the private DTensor modules
    (or a failed registration) does not leave the module believing the ops
    are registered.
    """
    global _softplus_registered
    if _softplus_registered:
        return

    # Imported lazily: these are private DTensor modules.
    from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
    from torch.distributed.tensor._ops._pointwise_ops import pointwise_strategy
    from torch.distributed.tensor._ops.registration import register_op_strategy

    for op in (
        torch.ops.aten.softplus.default,
        torch.ops.aten.softplus_backward.default,
    ):
        register_op_strategy(
            op,
            schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]),
        )(pointwise_strategy)

    _softplus_registered = True


def _install_dtensor_safe_dispatch() -> None:
    """Monkey-patch _gated_delta_rule_dispatch to handle DTensor inputs.

    The replacement unwraps DTensor arguments with ``to_local()``, calls the
    original dispatch, and re-wraps the result using ``q``'s mesh and
    placements. Plain tensors go straight to the original dispatch.
    The call is idempotent; the guard flag is set only once the patch has
    actually been installed.
    """
    global _dispatch_patched
    if _dispatch_patched:
        return

    original_dispatch = _qwen3_model._gated_delta_rule_dispatch

    def _dtensor_safe_dispatch(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        backend: str,
    ) -> torch.Tensor:
        if isinstance(q, DTensor):
            # Capture the layout before stripping the DTensor wrappers.
            mesh, placements = q.device_mesh, q.placements
            out = original_dispatch(
                q.to_local(),
                k.to_local(),
                v.to_local(),
                g.to_local(),
                beta.to_local(),
                backend,
            )
            return DTensor.from_local(out, mesh, placements, run_check=False)
        return original_dispatch(q, k, v, g, beta, backend)

    _qwen3_model._gated_delta_rule_dispatch = _dtensor_safe_dispatch
    _dispatch_patched = True
def parallelize_qwen35_moe(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> nn.Module:
    """Apply the configured parallelisms to a Qwen3.5 MoE model.

    Steps run in a fixed order: non-MoE tensor parallelism, MoE
    expert/tensor parallelism, activation checkpointing, per-block compile,
    and finally FSDP/HSDP or DDP.

    Args:
        model: Model to parallelize; ``model.model_args.attn_backend`` is read.
        parallel_dims: Enabled parallel dimensions and the world device mesh.
        job_config: Training job configuration.

    Returns:
        The same ``model`` object after the wrappers have been applied.

    Raises:
        AssertionError: If the training sequence length is not divisible by
            ``parallel_dims.seq_len_divisor``.
        NotImplementedError: If context parallelism is requested with an
            attention backend other than ``"sdpa"``.
        RuntimeError: If async TP is requested without model compile, or DDP
            is requested on a mesh with more than one dimension.
    """
    world_mesh = parallel_dims.world_mesh
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
        Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
        ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
        """

    # Context parallelism is only allowed together with the SDPA backend.
    attn_backend = model.model_args.attn_backend
    if job_config.parallelism.context_parallel_degree > 1 and attn_backend not in (
        "sdpa",
    ):
        raise NotImplementedError(
            f"Context Parallel only supports SDPA attention for Qwen3.5 MoE on v0.2.0. "
            f"Got attn_backend='{attn_backend}'."
        )

    # Compile applies to the model only when "model" is among the components.
    model_compile_enabled = (
        job_config.compile.enable and "model" in job_config.compile.components
    )

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not model_compile_enabled
        ):
            raise RuntimeError("Async TP requires torch.compile")

        # Float8 tensorwise TP is used only when the float8 converter is
        # enabled and the recipe is not one of the rowwise recipes.
        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_non_moe_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    # MoE layers: each mesh is passed only if its dimension is enabled; the
    # combined ("ep", "tp") mesh additionally requires ETP.
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
            ep_tp_mesh=(
                world_mesh["ep", "tp"]
                if parallel_dims.tp_enabled
                and parallel_dims.ep_enabled
                and parallel_dims.etp_enabled
                else None
            ),
            etp_enabled=parallel_dims.etp_enabled,
        )

    if job_config.activation_checkpoint.mode != "none":
        use_flex_attn = attn_backend == "flex"
        apply_ac(
            model,
            job_config.activation_checkpoint,
            model_compile_enabled=model_compile_enabled,
            use_flex_attn=use_flex_attn,
            op_sac_save_list=_op_sac_save_list,
            base_folder=job_config.job.dump_folder,
        )

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if model_compile_enabled:
        apply_compile(model, job_config.compile)

    if parallel_dims.fsdp_enabled:
        # HSDP adds the "dp_replicate" dimension in front of the shard dimension.
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)
        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]

        # Mesh dimension names passed as dp_mod_ep_mesh; only filled with EP.
        dp_mod_ep_mesh_dim_names = []
        if parallel_dims.ep_enabled:
            if parallel_dims.dp_replicate_enabled:
                dp_mod_ep_mesh_dim_names.append("dp_replicate")
            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")

        apply_fsdp(
            model,
            dp_mesh,
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
            ep_degree=parallel_dims.ep,
            dp_mod_ep_mesh=(
                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
                if parallel_dims.ep_enabled
                else None
            ),
            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # Pure DDP path: only taken when FSDP is off, and only on a 1-D mesh.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=model_compile_enabled,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model
def apply_non_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
) -> None:
    """Apply tensor parallelism to non-MoE components.

    Handles the hybrid architecture:
    - Full attention layers: standard TP on Q/K/V/O projections
    - GatedDeltaNet layers: NoParallel on all submodules (Replicate DTensors)
      with DTensor-safe wrappers for conv1d and FLA kernel dispatch
    - Shared expert: TP on w1/w3/w2

    Args:
        model: Model whose ``tok_embeddings``, ``norm``, ``output`` and
            ``layers`` submodules are parallelized in place.
        tp_mesh: Device mesh for the tensor-parallel dimension.
        loss_parallel: If true, the output head keeps its result sharded on
            the last dimension as a DTensor; otherwise it is replicated and
            returned as a local tensor.
        enable_float8_tensorwise_tp: If true, the float8 parallel styles from
            ``torchao`` replace the standard ones for per-layer plans.
        enable_async_tp: If true, turns on inductor micro-pipelined TP and
            symmetric memory for the TP process group.
    """
    # Patch FLA kernel dispatch to handle Replicate DTensor inputs (idempotent).
    _install_dtensor_safe_dispatch()
    # Register softplus as a DTensor pointwise op.
    _register_dtensor_softplus()

    # Parallel styles for float8 vs standard
    if enable_float8_tensorwise_tp:
        # Imported lazily so torchao is needed only on the float8 path.
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Global: embedding, final norm, output head
    # (these always use the standard styles, never the float8 variants)
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Per-layer plans
    for transformer_block in model.layers.values():
        # Both norms are sequence-parallel regardless of the layer type.
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "ffn_norm": SequenceParallel(),
        }

        if transformer_block.layer_type == "full_attention":
            # Full attention: standard TP on Q/K/V/O projections
            layer_plan.update(
                {
                    "attn": prepare_module_input(
                        input_layouts=(Shard(1), Replicate(), None, None),
                        desired_input_layouts=(Replicate(), Replicate(), None, None),
                    ),
                    "attn.wq": colwise_parallel(use_local_output=False),
                    "attn.wk": colwise_parallel(use_local_output=False),
                    "attn.wv": colwise_parallel(use_local_output=False),
                    "attn.q_norm": SequenceParallel(sequence_dim=2),
                    "attn.k_norm": SequenceParallel(sequence_dim=2),
                    "attn.wo": rowwise_parallel(output_layouts=Shard(1)),
                }
            )
        else:
            # GatedDeltaNet: conv1d needs full sequence, FLA kernel needs
            # plain tensors. Keep intermediates as Replicate DTensors.

            # Replace depthwise conv1d with DTensor-safe wrapper
            transformer_block.attn.conv1d = _DTensorSafeConv1d(
                transformer_block.attn.conv1d, tp_mesh
            )

            # The module input goes Shard(1) -> Replicate and the output goes
            # Replicate -> Shard(1); all inner projections use NoParallel.
            layer_plan.update(
                {
                    "attn": PrepareModuleInputOutput(
                        input_layouts=(Shard(1),),
                        desired_input_layouts=(Replicate(),),
                        output_layouts=(Replicate(),),
                        desired_output_layouts=(Shard(1),),
                    ),
                    "attn.in_proj_qkv": NoParallel(use_local_output=False),
                    "attn.in_proj_z": NoParallel(use_local_output=False),
                    "attn.in_proj_a": NoParallel(use_local_output=False),
                    "attn.in_proj_b": NoParallel(use_local_output=False),
                    "attn.out_proj": NoParallel(use_local_output=False),
                    "attn.norm": NoParallel(use_local_output=False),
                }
            )

        # Shared expert gate + shared expert FFN
        layer_plan.update(
            {
                "shared_gate": NoParallel(
                    input_layout=Shard(1),
                    output_layout=Shard(1),
                    use_local_output=True,
                ),
                "shared_ffn": prepare_module_input(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "shared_ffn.w1": colwise_parallel(),
                "shared_ffn.w2": rowwise_parallel(output_layouts=Shard(1)),
                "shared_ffn.w3": colwise_parallel(),
            }
        )

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

        # Distribute standalone GatedDeltaNet parameters (A_log, dt_bias)
        # as Replicate DTensors on the TP mesh.
        # These are bare parameters, so the plan above does not cover them.
        if transformer_block.layer_type != "full_attention":
            attn = transformer_block.attn
            attn.A_log = nn.Parameter(
                DTensor.from_local(
                    attn.A_log.data, tp_mesh, [Replicate()], run_check=False
                ),
                requires_grad=attn.A_log.requires_grad,
            )
            attn.dt_bias = nn.Parameter(
                DTensor.from_local(
                    attn.dt_bias.data, tp_mesh, [Replicate()], run_check=False
                ),
                requires_grad=attn.dt_bias.requires_grad,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        # Process-wide inductor switch plus symmetric memory for the TP group.
        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
@dataclass
class Qwen35MoEModelArgs(BaseModelArgs):
    """Single-level hyper-parameter record for the Qwen3.5 MoE hybrid decoder.

    The upstream PR spreads these values over nested Config classes
    (Model.Config, TransformerBlock.Config, Attention.Config,
    GatedDeltaNet.Config); here they are flattened into one dataclass.
    """

    # Core transformer dimensions
    dim: int = 2048
    n_layers: int = 40
    vocab_size: int = 248320
    norm_eps: float = 1e-6
    max_seq_len: int = 262144

    # Full (softmax) attention layers
    n_heads: int = 16
    n_kv_heads: int = 2
    head_dim: int = 256
    rotary_dim: int = 64  # partial RoPE: only first rotary_dim dims get RoPE
    qk_norm: bool = True

    # GatedDeltaNet (linear attention) layers
    gdn_n_key_heads: int = 16
    gdn_n_value_heads: int = 32
    gdn_key_head_dim: int = 128
    gdn_value_head_dim: int = 128
    gdn_conv_kernel_size: int = 4
    gdn_norm_eps: float = 1e-6
    gdn_fla_backend: str = "fla_chunked"

    # Rotary position embedding base
    rope_theta: float = 10_000_000.0

    # Layer pattern: every Nth layer is full attention
    full_attention_interval: int = 4

    # Routed experts
    moe_inter_dim: int = 512  # per-expert FFN hidden dim
    moe_args: MoEArgs = field(default_factory=MoEArgs)

    # Shared expert
    shared_ffn_hidden_dim: int = 512

    # Attention backend selection
    attn_backend: str = "sdpa"  # sdpa, flex
    attn_mask_type: str = "causal"

    # Miscellaneous
    depth_init: bool = True
    eos_id: int = 151645

    def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
        """Sync ``max_seq_len`` and the MoE debug flag with the job config.

        ``max_seq_len`` is always overwritten with the training sequence
        length; a warning is logged first when that length exceeds the
        current maximum.
        """
        training_cfg = job_config.training
        seq_len = training_cfg.seq_len
        if seq_len > self.max_seq_len:
            logger.warning(
                f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}."
            )
        self.max_seq_len = seq_len

        force_balance = training_cfg.debug_moe_force_load_balance
        self.moe_args._debug_force_load_balance = force_balance

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        """Delegate parameter and FLOP counting to the shared MoE helper."""
        counts = get_moe_model_nparams_and_flops(self, model, seq_len)
        return counts
class OffsetRMSNorm(nn.Module):
    """RMSNorm whose scale is ``1 + weight``; ``weight`` starts at zero.

    The normalisation runs in float32 and the result is cast back to the
    input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        scale = 1.0 + self.weight.float()
        return (scale * (x_fp32 * inv_rms)).to(x.dtype)

    def reset_parameters(self):
        nn.init.zeros_(self.weight)


class RMSNormGated(nn.Module):
    """RMSNorm followed by a SiLU gate: ``silu(gate) * weight * norm(x)``.

    ``weight`` starts at one. ``hidden_states`` and ``gate`` are passed as
    separate tensors. The weighted norm is cast to the input dtype before
    the float32 gate is applied, and the product is cast once more.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        out_dtype = hidden_states.dtype
        h = hidden_states.float()
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.eps)
        weighted = (self.weight * h).to(out_dtype)
        gated = weighted * F.silu(gate.float())
        return gated.to(out_dtype)

    def reset_parameters(self):
        nn.init.ones_(self.weight)
freqs).float() + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast( + rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: + ndim = x.ndim + assert ndim > 1 + bz, seqlen, _, head_dim = x.shape + if positions is None: + rope_cache = rope_cache[0:seqlen] + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + rope_cache = rope_cache[positions.squeeze(0)] + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + else: + assert positions.shape == (bz, seqlen) + rope_cache_expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1) + rope_cache = torch.gather( + rope_cache_expanded, + dim=1, + index=positions.view(bz, seqlen, 1, 1).expand(bz, seqlen, 1, head_dim * 2), + ) + assert rope_cache.shape == (bz, seqlen, 1, head_dim * 2) + return rope_cache + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + rope_cache: torch.Tensor, + positions: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + head_dim = xq.shape[-1] + rope_cache = reshape_for_broadcast(rope_cache, xq, positions) + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def apply_partial_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + rope_cache: torch.Tensor, + rotary_dim: int, + positions: torch.Tensor | None 
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)

    ``x`` is unpacked as a 4-D tensor; with ``n_rep == 1`` the input object
    itself is returned.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis after the head axis, broadcast it, then fold it in.
    widened = x.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return widened.reshape(bs, slen, n_kv_heads * n_rep, head_dim)


def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """L2 normalization matching the FLA library implementation.

    Computes ``x * rsqrt(sum(x * x) + eps)`` along ``dim``.
    """
    squared_norm = torch.sum(x * x, dim=dim, keepdim=True)
    inv_norm = torch.rsqrt(squared_norm + eps)
    return x * inv_norm
def _gated_delta_rule_dispatch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    backend: str,
) -> torch.Tensor:
    """Run the gated delta rule with the requested backend.

    Args:
        q: (B, L, H, D_k)
        k: (B, L, H, D_k)
        v: (B, L, H, D_v)
        g: (B, L, H) -- log-space decay
        beta: (B, L, H) -- update weight
        backend: One of ``"fla_chunked"``, ``"fla_fused_recurrent"``,
            ``"torch_naive"``.

    Returns:
        output: (B, L, H, D_v)

    Raises:
        ValueError: If ``backend`` is not one of the names above.
        RuntimeError: If an FLA backend is requested while ``_HAS_FLA`` is false.
    """
    known_backends = {"fla_chunked", "fla_fused_recurrent", "torch_naive"}
    if backend not in known_backends:
        raise ValueError(
            f"Unknown fla_backend '{backend}'. Valid options: "
            "'fla_chunked', 'fla_fused_recurrent', 'torch_naive'."
        )

    # The pure-torch path has no dependency on the optional `fla` package.
    if backend == "torch_naive":
        return _torch_chunk_gated_delta_rule(q, k, v, g, beta)

    if not _HAS_FLA:
        raise RuntimeError(
            f"Backend '{backend}' requires the `fla` package, but it is not installed."
        )

    # Only the two FLA names remain after the checks above.
    if backend == "fla_chunked":
        kernel_out = _fla_chunk_gated_delta_rule(
            q, k, v, g, beta, use_qk_l2norm_in_kernel=True
        )
    else:
        kernel_out = _fla_fused_recurrent_gated_delta_rule(
            q, k, v, g, beta=beta, use_qk_l2norm_in_kernel=True
        )

    # A tuple result carries the output in slot 0; unwrap it in that case.
    return kernel_out[0] if isinstance(kernel_out, tuple) else kernel_out
manually in forward + ) + + self.A_log = nn.Parameter(torch.zeros(self.n_value_heads)) + self.dt_bias = nn.Parameter(torch.ones(self.n_value_heads)) + + self.norm = RMSNormGated(self.value_head_dim, eps=model_args.gdn_norm_eps) + self.out_proj = nn.Linear(value_dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, L, D = x.shape + + # Projections + qkv = self.in_proj_qkv(x) # (B, L, conv_dim) + z = self.in_proj_z(x) # (B, L, value_dim) + a = self.in_proj_a(x) # (B, L, n_value_heads) + b = self.in_proj_b(x) # (B, L, n_value_heads) + + # Causal Conv1d + SiLU + qkv = F.pad(qkv.transpose(1, 2), (self.conv_kernel_size - 1, 0)) + qkv = F.silu(self.conv1d(qkv).transpose(1, 2)) # (B, L, conv_dim) + + # Split into q, k, v + key_dim = self.n_key_heads * self.key_head_dim + value_dim = self.n_value_heads * self.value_head_dim + q, k, v = qkv.split([key_dim, key_dim, value_dim], dim=-1) + + # Reshape to heads -- stay in (B, L, H, D) layout for FLA kernel + q = q.view(B, L, self.n_key_heads, self.key_head_dim) + k = k.view(B, L, self.n_key_heads, self.key_head_dim) + v = v.view(B, L, self.n_value_heads, self.value_head_dim) + + # Repeat q, k if n_value_heads > n_key_heads (grouped heads) + if self.n_value_heads > self.n_key_heads: + repeat = self.n_value_heads // self.n_key_heads + q = q.repeat_interleave(repeat, dim=2) + k = k.repeat_interleave(repeat, dim=2) + + # Compute log-decay (g) and update weight (beta) + g = -torch.exp(self.A_log.float()) * F.softplus( + a.float() + self.dt_bias + ) # (B, L, H_v) + beta = torch.sigmoid(b) # (B, L, H_v) + + # Gated delta rule + output = _gated_delta_rule_dispatch( + q, k, v, g, beta, self.fla_backend + ) # (B, L, H_v, D_v) + + # Apply gated norm + z = z.view(B, L, self.n_value_heads, self.value_head_dim) + output = self.norm(output, z) + + # Project output + output = output.reshape(B, L, -1) + return self.out_proj(output) + + def init_weights(self, init_std: float): + for linear in ( + self.in_proj_qkv, + 
class Attention(nn.Module):
    """Full attention with output gating and partial RoPE for Qwen3.5 MoE.

    Key differences from standard GQAttention:
    - ``wq`` is 2x wider: produces both query and sigmoid gate
    - Partial RoPE: only first ``rotary_dim`` elements get RoPE
    - Output gating: ``attn_output * sigmoid(gate)`` before ``wo``
    - QK norm uses ``OffsetRMSNorm``
    """

    def __init__(self, model_args: Qwen35MoEModelArgs):
        """Build the projections, optional QK norms and the inner attention op.

        Reads ``n_heads``, ``n_kv_heads``, ``head_dim``, ``rotary_dim``,
        ``attn_backend``, ``qk_norm``, ``norm_eps`` and ``dim`` from
        ``model_args``.
        """
        super().__init__()
        self.n_heads = model_args.n_heads
        # Fall back to one KV head per query head when n_kv_heads is unset.
        self.n_kv_heads = (
            model_args.n_heads
            if model_args.n_kv_heads is None
            else model_args.n_kv_heads
        )
        # Number of query heads served by each KV head (integer division).
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = model_args.head_dim
        self.rotary_dim = model_args.rotary_dim
        # Softmax scale passed explicitly to the inner attention op.
        self.scaling = self.head_dim**-0.5
        self.use_flex_attn = model_args.attn_backend == "flex"

        # QK norm uses OffsetRMSNorm (not nn.RMSNorm)
        if model_args.qk_norm:
            self.q_norm = OffsetRMSNorm(self.head_dim, eps=model_args.norm_eps)
            self.k_norm = OffsetRMSNorm(self.head_dim, eps=model_args.norm_eps)
        else:
            self.q_norm = None
            self.k_norm = None

        # wq is 2x wider: produces query + gate
        self.wq = nn.Linear(
            model_args.dim, self.n_heads * self.head_dim * 2, bias=False
        )
        self.wk = nn.Linear(
            model_args.dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.wv = nn.Linear(
            model_args.dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.wo = nn.Linear(
            self.n_heads * self.head_dim, model_args.dim, bias=False
        )

        # Pick the attention kernel wrapper once, based on the configured backend.
        if self.use_flex_attn:
            self.inner_attention = FlexAttentionWrapper()
        else:
            self.inner_attention = ScaledDotProductAttentionWrapper()

    def init_weights(self, init_std: float):
        """Re-initialize projection weights and reset the optional QK norms.

        ``wq``, ``wk`` and ``wv`` use a truncated normal with std 0.02; only
        the output projection ``wo`` uses the caller-supplied ``init_std``.
        """
        nn.init.trunc_normal_(self.wq.weight, mean=0.0, std=0.02)
        for linear in (self.wk, self.wv):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        if self.k_norm is not None:
            self.k_norm.reset_parameters()

    def forward(
        self,
        x: torch.Tensor,
        rope_cache: torch.Tensor,
        attention_masks: AttentionMasksType | None,
        positions: torch.Tensor | None = None,
    ):
        """Apply gated full attention to ``x``.

        Args:
            x: Input unpacked as ``(bs, seqlen, dim)``.
            rope_cache: RoPE table forwarded to ``apply_partial_rotary_emb``.
            attention_masks: Must be a ``BlockMask`` when the flex backend is
                selected, and ``None`` otherwise (both are asserted).
            positions: Optional position indices forwarded to the RoPE helper.

        Returns:
            The result of ``wo`` applied to the gated attention output.
        """
        bs, seqlen, _ = x.shape

        # Project Q (2x wider for query + gate), K, V
        xq_gate = self.wq(x).view(bs, seqlen, -1, self.head_dim * 2)
        xq, gate = xq_gate.chunk(2, dim=-1)  # each (bs, seqlen, n_heads, head_dim)
        xk = self.wk(x).view(bs, seqlen, -1, self.head_dim)
        xv = self.wv(x).view(bs, seqlen, -1, self.head_dim)

        # QK norm (before RoPE)
        if self.q_norm is not None:
            xq = self.q_norm(xq)
        if self.k_norm is not None:
            xk = self.k_norm(xk)

        # Partial RoPE
        xq, xk = apply_partial_rotary_emb(
            xq, xk, rope_cache, self.rotary_dim, positions
        )

        # Repeat k/v heads for GQA
        keys = repeat_kv(xk, self.n_rep)
        values = repeat_kv(xv, self.n_rep)

        # Move the head axis ahead of the sequence axis for the attention op.
        xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        xk = keys.transpose(1, 2)
        xv = values.transpose(1, 2)

        if self.use_flex_attn:
            assert isinstance(attention_masks, BlockMask), attention_masks
            output = self.inner_attention(
                xq, xk, xv, block_mask=attention_masks, scale=self.scaling
            )
        else:
            # This path takes no mask argument, so a supplied mask is rejected.
            assert attention_masks is None
            output = self.inner_attention(xq, xk, xv, scale=self.scaling)

        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_heads, head_dim)

        # Output gating: attn_output * sigmoid(gate) before wo
        output = output * torch.sigmoid(gate)
        output = output.view(bs, seqlen, -1)
        return self.wo(output)
class Qwen35MoEModel(nn.Module, ModelProtocol):
    """Qwen3.5 MoE hybrid decoder model.

    Alternates between GatedDeltaNet (linear attention) and full attention
    layers, controlled by ``full_attention_interval``. Every Nth layer uses
    full attention; the rest use GatedDeltaNet.
    """

    def __init__(self, model_args: Qwen35MoEModelArgs):
        """Create embeddings, the RoPE buffer, all decoder layers and the head."""
        super().__init__()
        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

        # Non-persistent: the cache is recomputed, not stored in the state dict.
        self.register_buffer(
            "rope_cache", self._precompute_rope_cache(), persistent=False
        )

        # Layers are keyed by their stringified index.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

        self.norm = OffsetRMSNorm(model_args.dim, eps=model_args.norm_eps)
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

    def _precompute_rope_cache(self) -> torch.Tensor:
        """Build the RoPE table from ``rotary_dim``, ``max_seq_len`` and ``rope_theta``."""
        return precompute_rope_cache(
            self.model_args.rotary_dim,
            self.model_args.max_seq_len,
            self.model_args.rope_theta,
        )

    def init_weights(
        self,
        buffer_device: torch.device | None = None,
    ):
        """Recompute the RoPE buffer and initialize every present submodule.

        Args:
            buffer_device: Device on which to rebuild ``rope_cache``; defaults
                to the device the current buffer lives on.
        """
        buffer_device = buffer_device or self.rope_cache.device
        with torch.device(buffer_device):
            self.rope_cache = self._precompute_rope_cache()
        # NOTE(review): the `is not None` guards suggest submodules may be
        # replaced by None elsewhere (e.g. when the model is split) -- confirm.
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers.values():
            if layer is not None:
                layer.init_weights(buffer_device)
        if self.norm is not None:
            self.norm.reset_parameters()
        # Output head: truncated normal with std dim**-0.5, clipped at +/-3 std.
        final_out_std = self.model_args.dim**-0.5
        cutoff_factor = 3
        if self.output is not None:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=final_out_std,
                a=-cutoff_factor * final_out_std,
                b=cutoff_factor * final_out_std,
            )

    def get_attention_masks(
        self,
        input_batch: torch.Tensor,
        tokenizer: BaseTokenizer,
        extra_inputs: dict[str, torch.Tensor] | None = None,
    ) -> AttentionMasksType:
        """Build the attention mask selected by ``model_args.attn_mask_type``.

        ``"causal"`` uses only the causal mask mod with ``B = 1``;
        ``"block_causal"`` additionally applies a document mask derived from
        ``input_batch`` and ``tokenizer.eos_id``, with ``B`` set to the batch
        size. ``extra_inputs`` is accepted but not used here.

        Raises:
            ValueError: For any other mask type.
        """
        mask_mods = [get_causal_mask_mod()]
        match self.model_args.attn_mask_type:
            case "causal":
                B = 1
            case "block_causal":
                B = input_batch.shape[0]
                mask_mods.append(
                    get_document_mask_mod(input_batch, tokenizer.eos_id)
                )
            case _:
                raise ValueError(
                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                )
        # Query and key/value lengths are both the sequence length of the batch.
        return create_attention_mask(
            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
        )

    def forward(
        self,
        tokens: torch.Tensor,
        attention_masks: AttentionMasksType | None = None,
        positions: torch.Tensor | None = None,
    ):
        """Run embeddings, every decoder layer, the final norm and the head.

        Each of ``tok_embeddings``, ``norm`` and ``output`` is skipped when it
        is falsy, in which case its input is passed through unchanged.
        """
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.rope_cache, attention_masks, positions)

        h = self.norm(h) if self.norm else h
        output = self.output(h) if self.output else h
        return output
class Qwen35MoEStateDictAdapter(StateDictAdapter):
    """Translate Qwen3.5 MoE checkpoints between HF and torchtitan key layouts."""

    def __init__(self, model_args: Qwen35MoEModelArgs, hf_assets_path: str | None):
        super().__init__(model_args, hf_assets_path)
        self.model_args = model_args

        hf_layer = f"{_HF_PREFIX}layers.{{}}."
        tt_layer = "layers.{}."

        # Non-expert key mapping: HF (with prefix) -> torchtitan.
        # Fused expert tensors (gate_up_proj / down_proj) are not listed here;
        # to_hf / from_hf handle them because they need a split or concat.
        mapping: dict[str, str] = {
            # --- top-level ---
            f"{_HF_PREFIX}embed_tokens.weight": "tok_embeddings.weight",
            f"{_HF_PREFIX}norm.weight": "norm.weight",
            "lm_head.weight": "output.weight",
            # --- per-layer norms ---
            f"{hf_layer}input_layernorm.weight": f"{tt_layer}attention_norm.weight",
            f"{hf_layer}post_attention_layernorm.weight": f"{tt_layer}ffn_norm.weight",
            # --- MoE router + shared expert ---
            f"{hf_layer}mlp.gate.weight": f"{tt_layer}moe.router.gate.weight",
            f"{hf_layer}mlp.shared_expert.gate_proj.weight": f"{tt_layer}shared_ffn.w1.weight",
            f"{hf_layer}mlp.shared_expert.up_proj.weight": f"{tt_layer}shared_ffn.w3.weight",
            f"{hf_layer}mlp.shared_expert.down_proj.weight": f"{tt_layer}shared_ffn.w2.weight",
            f"{hf_layer}mlp.shared_expert_gate.weight": f"{tt_layer}shared_gate.weight",
        }

        # --- GatedDeltaNet (linear_attn): parameter names match on both sides ---
        for param in (
            "in_proj_qkv.weight",
            "in_proj_z.weight",
            "in_proj_a.weight",
            "in_proj_b.weight",
            "conv1d.weight",
            "A_log",
            "dt_bias",
            "norm.weight",
            "out_proj.weight",
        ):
            mapping[f"{hf_layer}linear_attn.{param}"] = f"{tt_layer}attn.{param}"

        # --- Full attention (self_attn) ---
        for hf_name, tt_name in (
            ("q_proj", "wq"),
            ("k_proj", "wk"),
            ("v_proj", "wv"),
            ("o_proj", "wo"),
            ("q_norm", "q_norm"),
            ("k_norm", "k_norm"),
        ):
            mapping[f"{hf_layer}self_attn.{hf_name}.weight"] = (
                f"{tt_layer}attn.{tt_name}.weight"
            )

        self.from_hf_map = mapping

    @staticmethod
    def _layer_index(key: str) -> str:
        """Return the digits that follow ``layers.`` in ``key``."""
        return re.search(r"layers\.(\d+)\.", key).group(1)

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert a torchtitan state dict into HF layout.

        ``moe.experts.w1`` and ``moe.experts.w3`` are concatenated along dim 1
        into ``gate_up_proj``; ``moe.experts.w2`` becomes ``down_proj``. Keys
        with no mapping, and the training-only buffers, are dropped.
        """
        reverse_map = {tt_key: hf_key for hf_key, tt_key in self.from_hf_map.items()}
        converted: dict[str, Any] = {}

        for name, tensor in state_dict.items():
            # Training-only buffers and the RoPE cache have no HF counterpart.
            if name.endswith((".expert_bias", ".tokens_per_expert")):
                continue
            if name == "rope_cache":
                continue

            # Expert w1 + w3 -> fused gate_up_proj
            if ".moe.experts.w1" in name:
                index = self._layer_index(name)
                partner = state_dict[
                    name.replace(".moe.experts.w1", ".moe.experts.w3")
                ]
                fused_key = f"{_HF_PREFIX}layers.{index}.mlp.experts.gate_up_proj"
                converted[fused_key] = torch.cat([tensor, partner], dim=1)
                continue

            # w3 was consumed together with its w1 partner.
            if ".moe.experts.w3" in name:
                continue

            # Expert w2 -> down_proj
            if ".moe.experts.w2" in name:
                index = self._layer_index(name)
                converted[f"{_HF_PREFIX}layers.{index}.mlp.experts.down_proj"] = tensor
                continue

            # Everything else goes through the reversed key table.
            if "layers" in name:
                template = re.sub(r"(\d+)", "{}", name, count=1)
                if template not in reverse_map:
                    continue
                index = re.search(r"\d+", name).group(0)
                converted[reverse_map[template].format(index)] = tensor
            elif name in reverse_map:
                converted[reverse_map[name]] = tensor

        return converted

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """Convert an HF state dict into torchtitan layout.

        The fused ``gate_up_proj [E, 2*I, D]`` is chunked along dim 1 into
        ``w1`` and ``w3``; ``down_proj`` becomes ``w2``. Visual-encoder, MTP
        and rotary ``inv_freq`` keys are skipped. A zero ``expert_bias`` is
        added for every layer that has a router gate weight and no bias yet.
        """
        converted: dict[str, Any] = {}

        for name, tensor in hf_state_dict.items():
            # Skip visual encoder, MTP, and rotary embedding keys
            if name.startswith(("model.visual.", "mtp.")):
                continue
            if "rotary_emb.inv_freq" in name:
                continue

            # Fused gate_up_proj [num_experts, 2 * intermediate, dim] -> w1 + w3
            if ".mlp.experts.gate_up_proj" in name:
                index = self._layer_index(name)
                gate_half, up_half = tensor.chunk(2, dim=1)
                converted[f"layers.{index}.moe.experts.w1"] = gate_half
                converted[f"layers.{index}.moe.experts.w3"] = up_half
                continue

            # Expert down_proj -> w2
            if ".mlp.experts.down_proj" in name:
                index = self._layer_index(name)
                converted[f"layers.{index}.moe.experts.w2"] = tensor
                continue

            # Everything else goes through the key table.
            if "layers" in name:
                template = re.sub(r"(\d+)", "{}", name, count=1)
                if template not in self.from_hf_map:
                    continue
                index = re.search(r"\d+", name).group(0)
                target = self.from_hf_map[template]
                if target is not None:
                    converted[target.format(index)] = tensor
            elif name in self.from_hf_map:
                target = self.from_hf_map[name]
                if target is not None:
                    converted[target] = tensor

        # Populate expert_bias as zeros for each MoE layer.
        router_suffix = ".moe.router.gate.weight"
        router_keys = [k for k in converted if k.endswith(router_suffix)]
        for router_key in router_keys:
            layer_prefix = router_key[: -len(router_suffix)]
            num_experts = converted[router_key].shape[0]
            bias_key = f"{layer_prefix}.moe.expert_bias"
            if bias_key not in converted:
                converted[bias_key] = torch.zeros(num_experts, dtype=torch.float32)

        return converted