From ab174be77cafde602d0aa64374c73f5abab2efda Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Fri, 15 May 2026 01:45:28 -0700 Subject: [PATCH 1/7] evovle qwen3_vl to qwen3_5 --- ...qwen3_vl.py => numerical_tests_qwen3_5.py} | 64 +- .../numerical_tests_qwen3_5_shard.py | 226 ++++ tests/integration_tests/models.py | 27 +- torchtitan/models/__init__.py | 2 +- torchtitan/models/qwen3_5/README.md | 88 ++ torchtitan/models/qwen3_5/__init__.py | 1058 +++++++++++++++++ torchtitan/models/qwen3_5/config_registry.py | 320 +++++ torchtitan/models/qwen3_5/model.py | 1045 ++++++++++++++++ torchtitan/models/qwen3_5/parallelize.py | 200 ++++ .../{qwen3_vl => qwen3_5}/requirements.txt | 0 .../models/{qwen3_vl => qwen3_5}/rope.py | 0 torchtitan/models/qwen3_5/sharding.py | 371 ++++++ .../models/qwen3_5/state_dict_adapter.py | 343 ++++++ .../{qwen3_vl => qwen3_5}/vision_encoder.py | 336 +++--- torchtitan/models/qwen3_vl/README.md | 66 - torchtitan/models/qwen3_vl/__init__.py | 649 ---------- torchtitan/models/qwen3_vl/config_registry.py | 188 --- torchtitan/models/qwen3_vl/model.py | 573 --------- torchtitan/models/qwen3_vl/parallelize.py | 248 ---- torchtitan/models/qwen3_vl/sharding.py | 33 - .../models/qwen3_vl/state_dict_adapter.py | 341 ------ torchtitan/models/utils.py | 12 +- 22 files changed, 3854 insertions(+), 2336 deletions(-) rename scripts/checkpoint_conversion/{numerical_tests_qwen3_vl.py => numerical_tests_qwen3_5.py} (88%) create mode 100644 scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py create mode 100644 torchtitan/models/qwen3_5/README.md create mode 100644 torchtitan/models/qwen3_5/__init__.py create mode 100644 torchtitan/models/qwen3_5/config_registry.py create mode 100644 torchtitan/models/qwen3_5/model.py create mode 100644 torchtitan/models/qwen3_5/parallelize.py rename torchtitan/models/{qwen3_vl => qwen3_5}/requirements.txt (100%) rename torchtitan/models/{qwen3_vl => qwen3_5}/rope.py (100%) create mode 100644 torchtitan/models/qwen3_5/sharding.py create mode 100644 torchtitan/models/qwen3_5/state_dict_adapter.py rename torchtitan/models/{qwen3_vl => qwen3_5}/vision_encoder.py (63%) delete mode 100644 torchtitan/models/qwen3_vl/README.md delete mode 100644 torchtitan/models/qwen3_vl/__init__.py delete mode 100644 torchtitan/models/qwen3_vl/config_registry.py delete mode 100644 torchtitan/models/qwen3_vl/model.py delete mode 100644 torchtitan/models/qwen3_vl/parallelize.py delete mode 100644 torchtitan/models/qwen3_vl/sharding.py delete mode 100644 torchtitan/models/qwen3_vl/state_dict_adapter.py diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_vl.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py similarity index 88% rename from scripts/checkpoint_conversion/numerical_tests_qwen3_vl.py rename to scripts/checkpoint_conversion/numerical_tests_qwen3_5.py index 944363a318..7e5f345266 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_vl.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py @@ -6,21 +6,21 @@ # LICENSE file in the root directory of this source tree. """ -End-to-end numerical correctness test for Qwen3-VL checkpoint conversion. +End-to-end numerical correctness test for Qwen3.5 checkpoint conversion. Compares HuggingFace and TorchTitan next-token logits on multimodal inputs (random image + text prompt). Each pipeline uses its own image preprocessing so the test validates the full path: pixels → vision encoder → decoder → logits. Usage: - python -m scripts.checkpoint_conversion.numerical_tests_qwen3_vl \ - --hf_model_path /path/to/Qwen3-VL-2B-Instruct \ - --tt_checkpoint_path /path/to/qwen3_vl_2b_dcp - - python -m scripts.checkpoint_conversion.numerical_tests_qwen3_vl \ - --hf_model_path /path/to/Qwen3-VL-30B-A3B-Instruct \ - --tt_checkpoint_path /path/to/qwen3_vl_30b_a3b_dcp \ - --model_flavor 30B-A3B + python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \ + --hf_model_path ../hf_models/Qwen/Qwen3.5-4B \ + --tt_checkpoint_path outputs/Qwen/qwen3_5_4b_dcp + + python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \ + --hf_model_path ../hf_models/Qwen/Qwen3.5-35B-A3B \ + --tt_checkpoint_path outputs/Qwen/qwen3_5_35b_a3b_dcp \ + --model_flavor 35B-A3B """ import argparse @@ -34,7 +34,7 @@ torch._dynamo.config.disable = True from torchtitan.components.checkpoint import ModelWrapper -from torchtitan.models.qwen3_vl import model_registry +from torchtitan.models.qwen3_5 import model_registry from transformers import AutoProcessor @@ -88,12 +88,13 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): vision_to_patches, ) - processor = AutoProcessor.from_pretrained(hf_model_path, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(hf_model_path) model_config = model_registry(model_flavor).model - encoder_config = model_config.vision_encoder # pyrefly: ignore[missing-attribute] - patch_size = encoder_config.patch_embed.patch_size - temporal_patch_size = encoder_config.patch_embed.temporal_patch_size + # pyrefly: ignore [missing-attribute] + encoder_config = model_config.vision_encoder + patch_size = encoder_config.patch_size + temporal_patch_size = encoder_config.temporal_patch_size merge_size = encoder_config.spatial_merge_size hf_inputs, tt_inputs, pixel_comparisons = [], [], [] @@ -117,7 +118,8 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): ], } ] - hf_in = processor.apply_chat_template( # pyrefly: ignore[missing-attribute] + # pyrefly: ignore [missing-attribute] + hf_in = processor.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, @@ -213,7 +215,7 @@ def run_hf(model_path, hf_inputs, device): model = AutoModelForImageTextToText.from_pretrained( model_path, device_map=device, - torch_dtype=torch.float16, + dtype=torch.float16, trust_remote_code=True, low_cpu_mem_usage=True, ) @@ -242,13 +244,15 @@ def run_hf(model_path, hf_inputs, device): @torch.no_grad() def run_tt(model_flavor, checkpoint_path, tt_inputs, device): """Run TT model, return last-token logits per sample.""" + from torchtitan.models.common.attention import ScaledDotProductAttention + print(f"Loading TorchTitan model on {device} ...") model_config = model_registry(model_flavor).model with torch.device("meta"): model = model_config.build() model.to_empty(device="cpu") - model.init_weights(buffer_device=torch.device("cpu")) + model.init_states(buffer_device=torch.device("cpu")) model.half() state_dict = ModelWrapper(model)._get_state_dict() @@ -258,11 +262,9 @@ def run_tt(model_flavor, checkpoint_path, tt_inputs, device): # Replace FlexAttention with SDPA for single-process inference # (unfused FlexAttention without torch.compile has poor fp16 numerics). - from torchtitan.models.common.attention import ScaledDotProductAttention - for layer in model.layers.values(): - layer.attention.attn_backend = "sdpa" - layer.attention.inner_attention = ScaledDotProductAttention.Config().build() + if layer.layer_type == "full_attention": + layer.attn.inner_attention = ScaledDotProductAttention.Config().build() class _BidirectionalSDPA(torch.nn.Module): def forward(self, q, k, v, **kwargs): @@ -278,13 +280,7 @@ def forward(self, q, k, v, **kwargs): model.eval() - special_tokens = { - "image_id": 151655, - "video_id": 151656, - "vision_start_id": 151652, - "vision_end_id": 151653, - "pad_id": 151643, - } + special_tokens = {"image_id": 248056, "video_id": 248057} outputs = [] for i, (tokens, pixel_values, grid_thw) in enumerate(tt_inputs): @@ -319,12 +315,13 @@ def compare(hf_outputs, tt_outputs): kl = kl_divergence(hf, tt).item() top1, top5 = top_k_match(hf, tt) cos = F.cosine_similarity(hf.flatten(), tt.flatten(), dim=0).item() + max_diff = (hf - tt).abs().max().item() total_kl += kl total_top1 += top1 total_top5 += top5 print( f" Sample {i + 1:2d}: KL={kl:.4e} cos={cos:.6f} " - f"top1={top1:.0%} top5={top5:.0%}" + f"max_diff={max_diff:.4e} top1={top1:.0%} top5={top5:.0%}" ) n = len(hf_outputs) @@ -351,16 +348,11 @@ def compare(hf_outputs, tt_outputs): def main(): parser = argparse.ArgumentParser( - description="End-to-end numerical correctness test for Qwen3-VL.", + description="End-to-end numerical correctness test for Qwen3.5.", ) parser.add_argument("--hf_model_path", type=str, required=True) parser.add_argument("--tt_checkpoint_path", type=str, required=True) - parser.add_argument( - "--model_flavor", - type=str, - default="2B", - choices=["2B", "8B", "30B-A3B", "235B-A22B"], - ) + parser.add_argument("--model_flavor", type=str, default="4B") parser.add_argument("--num_samples", type=int, default=10) args = parser.parse_args() diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py new file mode 100644 index 0000000000..d63ea09a54 --- /dev/null +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Numerical comparison across parallelism configs for Qwen3.5. + +Feeds identical fake tokens across 4 configs and verifies logits match. +Requires 8 GPUs. + +Usage: + python scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py +""" + +import argparse +import json +import os +import subprocess +import sys +import tempfile + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor + +from torchtitan.config import ( + ActivationCheckpointConfig, + CompileConfig, + ParallelismConfig, + TrainingConfig, +) +from torchtitan.distributed import ParallelDims +from torchtitan.models.qwen3_5 import qwen3_5_configs +from torchtitan.models.qwen3_5.parallelize import parallelize_qwen3_5 + +CONFIGS = [ + {"ngpu": 1, "tp": 1, "ep": 1, "cp": 1, "label": "no_parallel"}, + {"ngpu": 4, "tp": 1, "ep": 1, "cp": 1, "label": "fsdp"}, + {"ngpu": 8, "tp": 1, "ep": 4, "cp": 1, "label": "fsdp+ep"}, + {"ngpu": 8, "tp": 2, "ep": 2, "cp": 1, "label": "fsdp+ep+tp"}, + {"ngpu": 8, "tp": 2, "ep": 2, "cp": 2, "label": "fsdp+ep+tp+cp"}, +] + + +def run_worker(args): + """Worker entry point — called via torchrun.""" + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + dp_shard = world_size // (args.tp * args.cp) + + seed = 42 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + config = qwen3_5_configs["debugmodel_moe"]( + attn_backend="flex", + moe_comm_backend="standard", + ) + + parallel_dims = ParallelDims( + dp_shard=dp_shard, + dp_replicate=1, + cp=args.cp, + tp=args.tp, + pp=1, + ep=args.ep, + world_size=world_size, + ) + parallel_dims.build_mesh() + + parallelism = ParallelismConfig( + tensor_parallel_degree=args.tp, + data_parallel_shard_degree=dp_shard, + context_parallel_degree=args.cp, + expert_parallel_degree=args.ep, + ) + training = TrainingConfig( + local_batch_size=1, + seq_len=128, + steps=1, + mixed_precision_param="bfloat16", + mixed_precision_reduce="float32", + ) + + config.update_from_config( + trainer_config=type( + "C", + (), + { + "training": training, + "parallelism": parallelism, + "debug": type("D", (), {"moe_force_load_balance": False})(), + }, + )(), + ) + + model = config.build() + model.to_empty(device="cuda") + model.init_weights(buffer_device=torch.device("cuda")) + + model = parallelize_qwen3_5( + model, + parallel_dims=parallel_dims, + training=training, + parallelism=parallelism, + compile_config=CompileConfig(), + ac_config=ActivationCheckpointConfig(), + dump_folder="/tmp", + ) + + torch.manual_seed(seed) + seq_len = 128 + tokens = torch.randint(0, 248320, (1, seq_len), device="cuda") + dist.broadcast(tokens, src=0) + + extra_kwargs: dict = {} + if args.cp > 1: + # Shard tokens and create positions for CP + from torchtitan.distributed.context_parallel import ( + prepare_context_parallel_input, + ) + + positions = torch.arange(seq_len, device="cuda").unsqueeze(0) + labels = tokens.clone() + extra_kwargs = {"positions": positions} + tokens, labels, extra_kwargs = prepare_context_parallel_input( + tokens, + labels, + extra_kwargs, + parallel_dims.get_mesh("cp"), + torch.device("cuda"), + ) + + with torch.no_grad(): + output = model( + tokens, + special_tokens={"image_id": 151859, "video_id": 151860}, + **extra_kwargs, + ) + + if isinstance(output, DTensor): + output = output.full_tensor() + + # With CP, each rank has partial output — gather first token's logits from rank 0's CP portion + logits = output[0, 0, :10].float().tolist() + + if rank == 0: + with open(args.output, "w") as f: + json.dump(logits, f) + + dist.destroy_process_group() + + +def main(): + """Orchestrator — launches torchrun for each config and compares.""" + results = {} + + with tempfile.TemporaryDirectory() as tmpdir: + for cfg in CONFIGS: + outfile = os.path.join(tmpdir, f"{cfg['label']}.json") + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc-per-node={cfg['ngpu']}", + __file__, + "--worker", + f"--tp={cfg['tp']}", + f"--ep={cfg['ep']}", + f"--cp={cfg['cp']}", + f"--output={outfile}", + ] + print( + f"Running {cfg['label']} (ngpu={cfg['ngpu']}, " + f"tp={cfg['tp']}, ep={cfg['ep']}, cp={cfg['cp']})..." + ) + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode != 0: + print(f" FAILED:\n{result.stderr[-500:]}") + return 1 + + with open(outfile) as f: + results[cfg["label"]] = json.load(f) + print(f" logits: {[f'{v:.6f}' for v in results[cfg['label']]]}") + + print("\n--- Comparison ---") + baseline_label = CONFIGS[0]["label"] + baseline = results[baseline_label] + all_pass = True + + for cfg in CONFIGS[1:]: + label = cfg["label"] + logits = results[label] + max_diff = max(abs(a - b) for a, b in zip(baseline, logits)) + status = "PASS" if max_diff < 0.02 else "FAIL" + if status == "FAIL": + all_pass = False + print(f" {baseline_label} vs {label}: max_diff={max_diff:.6e} {status}") + + print(f"\n{'ALL PASS' if all_pass else 'SOME FAILED'}") + return 0 if all_pass else 1 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--worker", action="store_true") + parser.add_argument("--tp", type=int) + parser.add_argument("--ep", type=int) + parser.add_argument("--cp", type=int, default=1) + parser.add_argument("--output", type=str) + args = parser.parse_args() + + if args.worker: + run_worker(args) + else: + sys.exit(main()) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 30e03fb0ec..2c0078195f 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -119,18 +119,37 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "qwen3_fsdp+tp+cp", ngpu=8, ), - # Integration Test Cases for Qwen3-VL + # Integration Test Cases for Llama 4 + # TODO: re-enable compile after fixing + # https://github.com/pytorch/torchtitan/issues/2771 + OverrideDefinitions( + [ + [ + "--module llama4 --config llama4_debugmodel_ep", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule Interleaved1F1B", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + # "--compile.enable", + ], + ], + "Llama 4 PP+FSDP+TP+EP+compile", + "llama4_pp+fsdp+tp+ep+compile", + ngpu=8, + ), + # Integration Test Cases for Qwen3.5 OverrideDefinitions( [ [ - "--module qwen3_vl --config qwen3_vl_debugmodel_moe", + "--module qwen3_5 --config qwen35_debugmodel_moe", "--parallelism.data_parallel_shard_degree 4", "--parallelism.tensor_parallel_degree 2", "--parallelism.expert_parallel_degree 4", ], ], - "Qwen3-VL MoE FSDP+TP+EP", - "qwen3_vl_moe_fsdp+tp+ep", + "Qwen3.5 MoE FSDP+TP+EP", + "qwen3_5_moe_fsdp+tp+ep", ngpu=8, ), # Integration Test Cases for gpt-oss diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index 44f32f521f..da1b5c6224 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -5,5 +5,5 @@ # LICENSE file in the root directory of this source tree. _supported_models = frozenset( - ["deepseek_v3", "flux", "gpt_oss", "llama3", "qwen3", "qwen3_vl"] + ["deepseek_v3", "flux", "gpt_oss", "llama3", "qwen3", "qwen3_5"] ) diff --git a/torchtitan/models/qwen3_5/README.md b/torchtitan/models/qwen3_5/README.md new file mode 100644 index 0000000000..9ae6f58fc3 --- /dev/null +++ b/torchtitan/models/qwen3_5/README.md @@ -0,0 +1,88 @@ +# Qwen3.5: Multimodal Model with Hybrid Attention + +## Overview + +Qwen3.5 combines: +- **Hybrid Decoder**: 75% GatedDeltaNet (linear attention) + 25% full attention with output gating and partial RoPE. +- **Vision Encoder**: A Vision Transformer (ViT) with 2D RoPE and bilinear-interpolated learned position embeddings. +- **Patch Merger**: Reduces vision sequence length by merging spatial patches (e.g., 2x2 patches -> 1 token). +- **MRoPE**: Interleaves RoPE from temporal, height, and width position IDs in decoder layers. +- **MoE variant**: Routed experts + shared expert with sigmoid gate. + +## Vision Scatter + +- `tok_embeddings` produces text token embeddings of shape `B×S`. +- `vision_encoder` produces visual token embeddings of shape `N×L`. +- Valid visual tokens (excluding padding) are scattered into the placeholder positions in the text sequence, as illustrated below (credit: [@lkhphuc](https://github.com/lkhphuc)). + +VLM Architecture + +Note: the diagram shows each patch mapping to one vision token. In practice, the Patch Merger groups `merge_size²` patches into one token (e.g., `merge_size=2` → 4 patches per token), reducing the vision sequence length by `merge_size²`. + +## Prerequisites + +Install the additional dependencies: + +```bash +pip install av torchvision +``` + +For GatedDeltaNet GPU efficiency (optional, pure-torch fallback available): + +```bash +pip install flash-linear-attention +``` + +## Model Variants + +### Dense + +| Variant | LLM dim | Layers | Heads | KV Heads | Head Dim | ViT dim | ViT layers | +|---------|---------|--------|-------|----------|----------|---------|------------| +| debugmodel | 256 | 8 | 4 | 2 | 64 | 256 | 4 | +| 0.8B | 1024 | 24 | 8 | 2 | 256 | 768 | 12 | +| 2B | 2048 | 24 | 8 | 2 | 256 | 1024 | 24 | +| 4B | 2560 | 32 | 16 | 4 | 256 | 1024 | 24 | +| 9B | 4096 | 32 | 16 | 4 | 256 | 1152 | 27 | +| 27B | 5120 | 64 | 24 | 4 | 256 | 1152 | 27 | + +### MoE + +| Variant | LLM dim | Layers | Experts | Top-k | Shared Expert | +|---------|---------|--------|---------|-------|---------------| +| debugmodel_moe | 256 | 4 | 8 | 2 | Yes | +| 35B-A3B | 2048 | 40 | 256 | 8 | Yes | +| 122B-A10B | 3072 | 48 | 256 | 8 | Yes | +| 397B-A17B | 4096 | 60 | 512 | 10 | Yes | + +## Datasets + +| Dataset | Type | Format | +|---------|------|--------| +| `cc12m` | Image-text pairs | WebDataset (streaming) | +| `cc12m-test` | Image-text pairs | Local WebDataset (for testing) | + +## Supported Parallelisms + +| Feature | Notes | +|---------|-------| +| FSDP / HSDP | Single `apply_fsdp` call covers both decoder and vision encoder | +| Tensor Parallelism (TP) | With Sequence Parallel; head-sharded TP on GatedDeltaNet projections | +| Expert Parallelism (EP) | For MoE variants | +| Context Parallel (CP) | Text-only; full attention uses ring attention, GatedDeltaNet allgathers full sequence | +| Pipeline Parallel (PP) | Vision encoder assigned to first stage; 1F1B and Interleaved1F1B schedules | +| Sample Packing | Configurable via `packing_buffer_size` in dataloader config | + +## Numerical Parity + +End-to-end KL divergence against HuggingFace Transformers (4B, multimodal inputs): **~3e-7** average, with **100% top-1 and top-5 match**. + +Parallelism correctness: bitwise identical logits across no-parallel, FSDP, FSDP+EP, FSDP+EP+TP, and FSDP+EP+TP+CP configs. + +Test scripts: +- `scripts/checkpoint_conversion/numerical_tests_qwen3_5.py` — HF vs TT comparison +- `scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py` — parallelism correctness + +## TODO + +- Add video dataset training configs diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py new file mode 100644 index 0000000000..42715084f1 --- /dev/null +++ b/torchtitan/models/qwen3_5/__init__.py @@ -0,0 +1,1058 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable +from functools import partial + +import torch.nn as nn + +from torchtitan.components.optimizer import register_moe_load_balancing_hook +from torchtitan.components.quantization import QuantizationConverter + +from torchtitan.models.common import Embedding, Linear, RoPE # noqa: F401 +from torchtitan.models.common.config_utils import ( + get_attention_config, + make_experts_config, + make_ffn_config, + make_moe_config, + make_router_config, +) +from torchtitan.models.common.param_init import depth_scaled_std # noqa: F401 + +from torchtitan.protocols.model_spec import ModelSpec + +from .model import ( + GatedDeltaNet, + OffsetRMSNorm, + Qwen35Attention, + Qwen35Model, + Qwen35TransformerBlock, +) +from .parallelize import parallelize_qwen3_5, pipeline_qwen3_5 +from .state_dict_adapter import Qwen35StateDictAdapter +from .vision_encoder import Qwen35VisionEncoder + +__all__ = [ + "parallelize_qwen3_5", + "Qwen35Model", + "qwen3_5_configs", + "QWEN3_5_SPECIAL_TOKENS", +] + +QWEN3_5_SPECIAL_TOKENS = { + "image_token": "<|image_pad|>", + "video_token": "<|video_pad|>", + "vision_start_token": "<|vision_start|>", + "vision_end_token": "<|vision_end|>", + "pad_token": "<|endoftext|>", +} + + +_LINEAR_INIT = { + "weight": partial(nn.init.trunc_normal_, std=0.02), + "bias": nn.init.zeros_, +} +_OFFSET_NORM_INIT = {"weight": nn.init.zeros_} +_EMBEDDING_INIT = {"weight": partial(nn.init.normal_, std=1.0)} +_POS_EMBED_INIT = {"pos_embed": partial(nn.init.trunc_normal_, mean=0.0, std=0.02)} + +_EPS = 1e-6 + + +def _output_linear_init(dim: int) -> dict[str, Callable]: + s = dim**-0.5 + return { + "weight": partial(nn.init.trunc_normal_, std=s, a=-3 * s, b=3 * s), + "bias": nn.init.zeros_, + } + + +def _depth_init(layer_id: int) -> dict[str, Callable]: + return { + "weight": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), + "bias": nn.init.zeros_, + } + + +def _depth_experts_init(layer_id: int) -> dict[str, Callable]: + return { + "w1": partial(nn.init.trunc_normal_, std=0.02), + "w2": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), + "w3": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), + } + + +def _a_log_init(param: nn.Parameter) -> None: + param.data.uniform_(1e-6, 16.0).log_() + + +def _linear(in_features: int, out_features: int) -> Linear.Config: + return Linear.Config( + in_features=in_features, + out_features=out_features, + bias=True, + param_init=_LINEAR_INIT, + ) + + +def _offset_norm(dim: int) -> OffsetRMSNorm.Config: + return OffsetRMSNorm.Config(dim=dim, eps=_EPS, param_init=_OFFSET_NORM_INIT) + + +def _vision_encoder_config( + *, + dim: int, + ffn_dim: int, + num_layers: int, + num_heads: int, + patch_size: int, + temporal_patch_size: int, + spatial_merge_size: int, + out_hidden_size: int, + num_position_embeddings: int, + in_channels: int = 3, +) -> Qwen35VisionEncoder.Config: + """Build a fully-specified Qwen35VisionEncoder.Config.""" + patch_dim = in_channels * temporal_patch_size * patch_size * patch_size + merged_hidden_size = dim * (spatial_merge_size**2) + return Qwen35VisionEncoder.Config( + dim=dim, + ffn_dim=ffn_dim, + num_layers=num_layers, + num_heads=num_heads, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + spatial_merge_size=spatial_merge_size, + out_hidden_size=out_hidden_size, + num_position_embeddings=num_position_embeddings, + patch_embed_proj=_linear(patch_dim, dim), + attn_wq=_linear(dim, dim), + attn_wk=_linear(dim, dim), + attn_wv=_linear(dim, dim), + attn_proj=_linear(dim, dim), + mlp_fc1=_linear(dim, ffn_dim), + mlp_fc2=_linear(ffn_dim, dim), + merger_fc1=_linear(merged_hidden_size, merged_hidden_size), + merger_fc2=_linear(merged_hidden_size, out_hidden_size), + param_init=_POS_EMBED_INIT, + ) + + +def _qwen35_attention_config( + *, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + rotary_dim: int, + attn_backend: str, + layer_id: int, +) -> Qwen35Attention.Config: + """Build a fully-specified Qwen35Attention.Config.""" + inner_attention, mask_type = get_qwen35_attention_config(attn_backend) + return Qwen35Attention.Config( + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + rotary_dim=rotary_dim, + wq=Linear.Config( + in_features=dim, + out_features=n_heads * head_dim * 2, + param_init=_LINEAR_INIT, + ), + wk=Linear.Config( + in_features=dim, + out_features=n_kv_heads * head_dim, + param_init=_LINEAR_INIT, + ), + wv=Linear.Config( + in_features=dim, + out_features=n_kv_heads * head_dim, + param_init=_LINEAR_INIT, + ), + wo=Linear.Config( + in_features=n_heads * head_dim, + out_features=dim, + param_init=_depth_init(layer_id), + ), + q_norm=_offset_norm(head_dim), + k_norm=_offset_norm(head_dim), + inner_attention=inner_attention, + mask_type=mask_type, + ) + + +def _qwen35_deltanet_config( + *, + dim: int, + n_key_heads: int, + n_value_heads: int, + key_head_dim: int, + value_head_dim: int, + layer_id: int, + fla_backend: str = "fla_chunked", +) -> GatedDeltaNet.Config: + """Build a fully-specified GatedDeltaNet.Config.""" + return GatedDeltaNet.Config( + dim=dim, + n_key_heads=n_key_heads, + n_value_heads=n_value_heads, + key_head_dim=key_head_dim, + value_head_dim=value_head_dim, + fla_backend=fla_backend, + in_proj_init=_LINEAR_INIT, + out_proj_init=_depth_init(layer_id), + norm_init={"weight": nn.init.ones_}, + param_init={ + "A_log": _a_log_init, + "dt_bias": nn.init.ones_, + }, + ) + + +def _build_qwen35_layers( + *, + n_layers: int, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + rotary_dim: int, + hidden_dim: int, + n_key_heads: int, + n_value_heads: int, + key_head_dim: int, + value_head_dim: int, + full_attention_interval: int = 4, + attn_backend: str, + fla_backend: str = "fla_chunked", +) -> list[Qwen35TransformerBlock.Config]: + """Build per-layer configs for dense Qwen3.5 models.""" + # Shared attention config — set on ALL layers so the trainer can read + # attn_config.inner_attention and mask_type from any layer. + shared_attn_config = _qwen35_attention_config( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + rotary_dim=rotary_dim, + attn_backend=attn_backend, + layer_id=0, + ) + layers = [] + for layer_id in range(n_layers): + is_full = (layer_id + 1) % full_attention_interval == 0 + layer_type = "full_attn" if is_full else "linear_attn" + + attention = ( + _qwen35_attention_config( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + rotary_dim=rotary_dim, + attn_backend=attn_backend, + layer_id=layer_id, + ) + if is_full + else shared_attn_config + ) + deltanet = ( + _qwen35_deltanet_config( + dim=dim, + n_key_heads=n_key_heads, + n_value_heads=n_value_heads, + key_head_dim=key_head_dim, + value_head_dim=value_head_dim, + layer_id=layer_id, + fla_backend=fla_backend, + ) + if not is_full + else None + ) + + layers.append( + Qwen35TransformerBlock.Config( + layer_type=layer_type, + attention=attention, + deltanet=deltanet, + feed_forward=make_ffn_config( + dim=dim, + hidden_dim=hidden_dim, + w1_param_init=_LINEAR_INIT, + w2w3_param_init=_depth_init(layer_id), + ), + attention_norm=_offset_norm(dim), + ffn_norm=_offset_norm(dim), + ) + ) + return layers + + +def _build_qwen35_moe_layers( + *, + n_layers: int, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + rotary_dim: int, + moe_hidden_dim: int, + num_experts: int, + top_k: int, + shared_expert_hidden_dim: int, + n_key_heads: int, + n_value_heads: int, + key_head_dim: int, + value_head_dim: int, + full_attention_interval: int = 4, + attn_backend: str, + fla_backend: str = "fla_chunked", + moe_comm_backend: str = "standard", + non_blocking_capacity_factor: float | None = None, +) -> list[Qwen35TransformerBlock.Config]: + """Build per-layer configs for MoE Qwen3.5 models with shared expert.""" + shared_attn_config = _qwen35_attention_config( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + rotary_dim=rotary_dim, + attn_backend=attn_backend, + layer_id=0, + ) + layers = [] + for layer_id in range(n_layers): + is_full = (layer_id + 1) % full_attention_interval == 0 + layer_type = "full_attn" if is_full else "linear_attn" + + attention = ( + _qwen35_attention_config( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + rotary_dim=rotary_dim, + attn_backend=attn_backend, + layer_id=layer_id, + ) + if is_full + else shared_attn_config + ) + deltanet = ( + _qwen35_deltanet_config( + dim=dim, + n_key_heads=n_key_heads, + n_value_heads=n_value_heads, + key_head_dim=key_head_dim, + value_head_dim=value_head_dim, + layer_id=layer_id, + fla_backend=fla_backend, + ) + if not is_full + else None + ) + + layers.append( + Qwen35TransformerBlock.Config( + layer_type=layer_type, + attention=attention, + deltanet=deltanet, + moe=make_moe_config( + num_experts=num_experts, + router=make_router_config( + dim=dim, + num_experts=num_experts, + gate_param_init=_depth_init(layer_id), + top_k=top_k, + score_func="softmax", + route_norm=True, + ), + experts=make_experts_config( + dim=dim, + hidden_dim=moe_hidden_dim, + num_experts=num_experts, + top_k=top_k, + param_init=_depth_experts_init(layer_id), + score_before_experts=False, + comm_backend=moe_comm_backend, + non_blocking_capacity_factor=non_blocking_capacity_factor, + ), + ), + shared_ffn=make_ffn_config( + dim=dim, + hidden_dim=shared_expert_hidden_dim, + w1_param_init=_LINEAR_INIT, + w2w3_param_init=_depth_init(layer_id), + ), + shared_gate=Linear.Config( + in_features=dim, + out_features=1, + param_init=_LINEAR_INIT, + ), + attention_norm=_offset_norm(dim), + ffn_norm=_offset_norm(dim), + ) + ) + return layers + + +def _debugmodel(attn_backend: str) -> Qwen35Model.Config: + """Debug config for Qwen3.5 with vision encoder.""" + dim = 256 + head_dim = 64 + rotary_dim = 16 + n_layers = 8 + vocab_size = 248320 + # mrope_section sum must equal rotary_dim / 2 (8 for rotary_dim=16). + # Real models use [11, 11, 10] with rotary_dim=64. + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=4096, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=4, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + hidden_dim=512, + n_key_heads=2, + n_value_heads=4, + key_head_dim=64, + value_head_dim=64, + fla_backend="fla_chunked", + ), + vision_encoder=_vision_encoder_config( + dim=256, + ffn_dim=512, + num_layers=4, + num_heads=4, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=256, + num_position_embeddings=1024, + ), + mrope_section=[3, 3, 2], + ) + + +def _debugmodel_moe( + attn_backend: str, + moe_comm_backend: str = "standard", +) -> Qwen35Model.Config: + """Debug MoE config for Qwen3.5 with shared expert.""" + dim = 256 + head_dim = 64 + rotary_dim = 16 + n_layers = 4 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=4096, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_moe_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=4, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + moe_hidden_dim=256, + num_experts=8, + top_k=2, + shared_expert_hidden_dim=256, + n_key_heads=2, + n_value_heads=4, + key_head_dim=64, + value_head_dim=64, + moe_comm_backend=moe_comm_backend, + fla_backend="fla_chunked", + ), + vision_encoder=_vision_encoder_config( + dim=256, + ffn_dim=512, + num_layers=2, + num_heads=4, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=256, + num_position_embeddings=1024, + ), + mrope_section=[3, 3, 2], + ) + + +def _0_8b(attn_backend: str) -> Qwen35Model.Config: + """Qwen3.5-0.8B dense config with vision encoder. + + NOTE: HF config has tie_word_embeddings=true. Torchtitan doesn't support + tied embeddings yet, so we use a separate lm_head. Checkpoint conversion + must handle this. + """ + dim = 1024 + head_dim = 256 + rotary_dim = 64 # partial_rotary_factor=0.25 → head_dim * 0.25 + n_layers = 24 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=8, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + hidden_dim=3584, + n_key_heads=16, + n_value_heads=16, + key_head_dim=128, + value_head_dim=128, + ), + vision_encoder=_vision_encoder_config( + dim=768, + ffn_dim=3072, + num_layers=12, + num_heads=12, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=1024, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _2b(attn_backend: str) -> Qwen35Model.Config: + """Qwen3.5-2B dense config with vision encoder. + + NOTE: HF config has tie_word_embeddings=true. Torchtitan doesn't support + tied embeddings yet, so we use a separate lm_head. Checkpoint conversion + must handle this. + """ + dim = 2048 + head_dim = 256 + rotary_dim = 64 # partial_rotary_factor=0.25 + n_layers = 24 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=8, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + hidden_dim=6144, + n_key_heads=16, + n_value_heads=16, + key_head_dim=128, + value_head_dim=128, + ), + vision_encoder=_vision_encoder_config( + dim=1024, + ffn_dim=4096, + num_layers=24, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=2048, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _4b(attn_backend: str) -> Qwen35Model.Config: + """Qwen3.5-4B dense config with vision encoder. + + NOTE: HF config has tie_word_embeddings=true. Torchtitan doesn't support + tied embeddings yet, so we use a separate lm_head. + """ + dim = 2560 + head_dim = 256 + rotary_dim = 64 + n_layers = 32 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=16, + n_kv_heads=4, + head_dim=head_dim, + rotary_dim=rotary_dim, + hidden_dim=9216, + n_key_heads=16, + n_value_heads=32, + key_head_dim=128, + value_head_dim=128, + ), + vision_encoder=_vision_encoder_config( + dim=1024, + ffn_dim=4096, + num_layers=24, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=2560, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _9b(attn_backend: str) -> Qwen35Model.Config: + """Qwen3.5-9B dense config with vision encoder.""" + dim = 4096 + head_dim = 256 + rotary_dim = 64 + n_layers = 32 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=16, + n_kv_heads=4, + head_dim=head_dim, + rotary_dim=rotary_dim, + hidden_dim=12288, + n_key_heads=16, + n_value_heads=32, + key_head_dim=128, + value_head_dim=128, + ), + vision_encoder=_vision_encoder_config( + dim=1152, + ffn_dim=4304, + num_layers=27, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=4096, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _27b(attn_backend: str) -> Qwen35Model.Config: + """Qwen3.5-27B dense config with vision encoder.""" + dim = 5120 + head_dim = 256 + rotary_dim = 64 + n_layers = 64 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=24, + n_kv_heads=4, + head_dim=head_dim, + rotary_dim=rotary_dim, + hidden_dim=17408, + n_key_heads=16, + n_value_heads=48, + key_head_dim=128, + value_head_dim=128, + ), + vision_encoder=_vision_encoder_config( + dim=1152, + ffn_dim=4304, + num_layers=27, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=5120, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _35b_a3b( + attn_backend: str, + moe_comm_backend: str = "standard", +) -> Qwen35Model.Config: + """Qwen3.5-35B-A3B MoE config with vision encoder.""" + dim = 2048 + head_dim = 256 + rotary_dim = 64 + n_layers = 40 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_moe_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=16, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + moe_hidden_dim=512, + num_experts=256, + top_k=8, + shared_expert_hidden_dim=512, + n_key_heads=16, + n_value_heads=32, + key_head_dim=128, + value_head_dim=128, + moe_comm_backend=moe_comm_backend, + ), + vision_encoder=_vision_encoder_config( + dim=1152, + ffn_dim=4304, + num_layers=27, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=2048, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _122b_a10b( + attn_backend: str, + moe_comm_backend: str = "standard", +) -> Qwen35Model.Config: + """Qwen3.5-122B-A10B MoE config with vision encoder.""" + dim = 3072 + head_dim = 256 + rotary_dim = 64 + n_layers = 48 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_moe_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=32, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + moe_hidden_dim=1024, + num_experts=256, + top_k=8, + shared_expert_hidden_dim=1024, + n_key_heads=16, + n_value_heads=64, + key_head_dim=128, + value_head_dim=128, + moe_comm_backend=moe_comm_backend, + ), + vision_encoder=_vision_encoder_config( + dim=1152, + ffn_dim=4304, + num_layers=27, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=3072, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +def _397b_a17b( + attn_backend: str, + moe_comm_backend: str = "standard", +) -> Qwen35Model.Config: + """Qwen3.5-397B-A17B MoE config with vision encoder.""" + dim = 4096 + head_dim = 256 + rotary_dim = 64 + n_layers = 60 + vocab_size = 248320 + return Qwen35Model.Config( + vocab_size=vocab_size, + dim=dim, + # pyrefly: ignore [bad-argument-type] + norm=_offset_norm(dim), + tok_embeddings=Embedding.Config( + num_embeddings=vocab_size, + embedding_dim=dim, + param_init=_EMBEDDING_INIT, + ), + lm_head=Linear.Config( + in_features=dim, + out_features=vocab_size, + param_init=_output_linear_init(dim), + ), + rope=RoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + backend="cos_sin", + ), + layers=_build_qwen35_moe_layers( + attn_backend=attn_backend, + n_layers=n_layers, + dim=dim, + n_heads=32, + n_kv_heads=2, + head_dim=head_dim, + rotary_dim=rotary_dim, + moe_hidden_dim=1024, + num_experts=512, + top_k=10, + shared_expert_hidden_dim=1024, + n_key_heads=16, + n_value_heads=64, + key_head_dim=128, + value_head_dim=128, + moe_comm_backend=moe_comm_backend, + ), + vision_encoder=_vision_encoder_config( + dim=1152, + ffn_dim=4304, + num_layers=27, + num_heads=16, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + out_hidden_size=4096, + num_position_embeddings=2304, + ), + mrope_section=[11, 11, 10], + ) + + +qwen3_5_configs = { + "debugmodel": _debugmodel_qwen35, + "debugmodel_moe": _debugmodel_moe_qwen35, + "0.8B": _0_8b_qwen35, + "2B": _2b_qwen35, + "4B": _4b_qwen35, + "9B": _9b_qwen35, + "27B": _27b_qwen35, + "35B-A3B": _35b_a3b_qwen35, + "122B-A10B": _122b_a10b_qwen35, + "397B-A17B": _397b_a17b_qwen35, +} + + +def model_registry( + flavor: str, + attn_backend: str = "sdpa", + moe_comm_backend: str | None = None, + quantization: list[QuantizationConverter.Config] | None = None, +) -> ModelSpec: + kwargs = dict(attn_backend=attn_backend) + if moe_comm_backend is not None: + kwargs["moe_comm_backend"] = moe_comm_backend + config = qwen3_5_configs[flavor](**kwargs) + if quantization is not None: + for q in quantization: + q.build().convert(config) + + # Detect MoE: check if any layer has moe config + has_moe = any(getattr(layer, "moe", None) is not None for layer in config.layers) + + return ModelSpec( + name="qwen3_5", + flavor=flavor, + model=config, + parallelize_fn=parallelize_qwen3_5, + pipelining_fn=pipeline_qwen3_5, + post_optimizer_build_fn=(register_moe_load_balancing_hook if has_moe else None), + state_dict_adapter=Qwen35StateDictAdapter, + ) diff --git a/torchtitan/models/qwen3_5/config_registry.py b/torchtitan/models/qwen3_5/config_registry.py new file mode 100644 index 0000000000..1af8f0c659 --- /dev/null +++ b/torchtitan/models/qwen3_5/config_registry.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.loss import ChunkedCELoss +from torchtitan.components.lr_scheduler import LRSchedulersContainer +from torchtitan.components.metrics import MetricsProcessor +from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.components.tokenizer import MultiModalTokenizer + +from torchtitan.config import ( + ActivationCheckpointConfig, + ParallelismConfig, + TrainingConfig, +) +from torchtitan.hf_datasets.multimodal.mm_datasets import MMDataLoader +from torchtitan.trainer import Trainer + +from . import model_registry, QWEN3_5_SPECIAL_TOKENS + + +def _dataloader(dataset: str, **kwargs) -> MMDataLoader.Config: + return MMDataLoader.Config( + dataset=dataset, + max_images_per_batch=128, + patch_size=16, + temporal_patch_size=2, + spatial_merge_size=2, + min_pixels=65536, + max_pixels=16777216, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + **kwargs, + ) + + +def qwen35_debugmodel() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./tests/assets/tokenizer", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + metrics=MetricsProcessor.Config(log_freq=1), + model_spec=model_registry("debugmodel"), + dataloader=_dataloader("cc12m-test"), + optimizer=OptimizersContainer.Config(lr=5e-3), + lr_scheduler=LRSchedulersContainer.Config( + warmup_steps=2, + decay_ratio=0.8, + decay_type="linear", + min_lr_factor=0.0, + ), + training=TrainingConfig( + local_batch_size=1, + seq_len=512, + steps=10, + ), + checkpoint=CheckpointManager.Config( + interval=10, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + ), + ) + + +def qwen35_debugmodel_moe() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./tests/assets/tokenizer", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + metrics=MetricsProcessor.Config(log_freq=1), + model_spec=model_registry("debugmodel_moe", moe_comm_backend="standard"), + dataloader=_dataloader("cc12m-test"), + optimizer=OptimizersContainer.Config(lr=5e-3), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=2), + training=TrainingConfig( + local_batch_size=1, + seq_len=512, + steps=10, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=4, + expert_parallel_degree=4, + tensor_parallel_degree=2, + ), + checkpoint=CheckpointManager.Config( + interval=10, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + ), + ) + + +def qwen35_0_8b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-0.8B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("0.8B"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-3), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + ), + ) + + +def qwen35_2b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-2B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("2B"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-3), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + ), + ) + + +def qwen35_4b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-4B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("4B"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-4), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + ), + checkpoint=CheckpointManager.Config( + interval=500, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="full", + ), + ) + + +def qwen35_9b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-9B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("9B"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-4), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + tensor_parallel_degree=2, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="full", + ), + ) + + +def qwen35_27b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-27B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("27B"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-4), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + tensor_parallel_degree=4, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="full", + ), + ) + + +def qwen35_35b_a3b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-35B-A3B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("35B-A3B", moe_comm_backend="standard"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-4), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + tensor_parallel_degree=2, + expert_parallel_degree=8, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="full", + ), + ) + + +def qwen35_122b_a10b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-122B-A10B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("122B-A10B", moe_comm_backend="standard"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-4), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + tensor_parallel_degree=4, + expert_parallel_degree=8, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="full", + ), + ) + + +def qwen35_397b_a17b() -> Trainer.Config: + return Trainer.Config( + loss=ChunkedCELoss.Config(), + hf_assets_path="./assets/hf/Qwen3.5-397B-A17B", + tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), + model_spec=model_registry("397B-A17B", moe_comm_backend="standard"), + dataloader=_dataloader("cc12m"), + optimizer=OptimizersContainer.Config(lr=5e-4), + lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + steps=1000, + ), + parallelism=ParallelismConfig( + data_parallel_shard_degree=-1, + tensor_parallel_degree=8, + expert_parallel_degree=16, + ), + checkpoint=CheckpointManager.Config( + interval=500, + last_save_model_only=False, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="full", + ), + ) diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py new file mode 100644 index 0000000000..ca59e8bd9b --- /dev/null +++ b/torchtitan/models/qwen3_5/model.py @@ -0,0 +1,1045 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import dataclasses +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from torch import nn + +from torchtitan.models.common.attention import AttentionMasksType, BaseAttention +from torchtitan.models.common.decoder import Decoder +from torchtitan.models.common.linear import Linear +from torchtitan.models.common.rope import apply_rotary_emb_cos_sin +from torchtitan.models.utils import get_moe_model_nparams_and_flops +from torchtitan.protocols.module import Module +from torchtitan.tools.logging import logger + +from .sharding import set_qwen35_sharding_config +from .vision_encoder import Qwen35VisionEncoder + +_Conv1d = Module.from_nn_module(nn.Conv1d) + + +try: + from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule as _fla_chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule as _fla_fused_recurrent_gated_delta_rule, + ) + + _HAS_FLA = True +except ImportError: + _HAS_FLA = False + + +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + """L2 norm using rsqrt(sum(x²) + eps), not x/max(norm, eps) like F.normalize, to match FLA kernel.""" + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +def _torch_naive_gated_delta( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, +) -> torch.Tensor: + """Standalone math reference for the gated delta rule recurrence. + + Sequential O(seqlen) loop — use FLA kernels for GPU efficiency. + + Args: + q, k: (bs, seqlen, n_heads, key_head_dim) + v: (bs, seqlen, n_heads, value_head_dim) + g: (bs, seqlen, n_heads) — log-space decay, always negative + beta: (bs, seqlen, n_heads) — update gate ∈ (0, 1) + + Returns: + output: (bs, seqlen, n_heads, value_head_dim) + """ + B, L, H, D_k = q.shape + D_v = v.shape[-1] + dtype = q.dtype + + # Upcast to float32 — recurrence accumulates over seqlen steps + q = _l2norm(q.float(), dim=-1) * (D_k**-0.5) + k = _l2norm(k.float(), dim=-1) + v, g, beta = v.float(), g.float(), beta.float() + + output = torch.zeros(B, L, H, D_v, dtype=torch.float32, device=q.device) + state = torch.zeros(B, H, D_k, D_v, dtype=torch.float32, device=q.device) + + for t in range(L): + q_t = q[:, t] + k_t = k[:, t] + v_t = v[:, t] + g_t = g[:, t].exp().unsqueeze(-1).unsqueeze(-1) + b_t = beta[:, t].unsqueeze(-1) + + state = state * g_t + kv_mem = torch.einsum("bhkv,bhk->bhv", state, k_t) + delta = (v_t - kv_mem) * b_t + state = state + torch.einsum("bhk,bhv->bhkv", k_t, delta) + output[:, t] = torch.einsum("bhkv,bhk->bhv", state, q_t) + + return output.to(dtype) + + +class OffsetRMSNorm(Module): + """RMSNorm with offset: ``(1 + weight) * norm(x)``. + + Weight is zero-initialized so the norm starts as identity-scaled. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + dim: int + eps: float = 1e-6 + + def __init__(self, config: Config): + super().__init__() + self.eps = config.eps + self.weight = nn.Parameter(torch.empty(config.dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Upcast to float32 for numerical stability in pow/rsqrt + input_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return ((1.0 + self.weight.float()) * x).to(input_dtype) + + +class RMSNormGated(Module): + """Gated RMSNorm: ``silu(gate) * weight * norm(x)``. + + Takes ``(x, gate)`` separately. Weight is ones-initialized. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + dim: int + eps: float = 1e-6 + + def __init__(self, config: Config): + super().__init__() + self.eps = config.eps + self.weight = nn.Parameter(torch.empty(config.dim)) + + def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + # Upcast to float32 for numerical stability in pow/rsqrt + input_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = (self.weight.float() * x).to(input_dtype) + x = x * F.silu(gate.float()) + return x.to(input_dtype) + + +class GatedDeltaKernel(Module): + """Stateless dispatch to FLA kernel or pure-torch fallback. + + Provides a module boundary for the sharding code to wrap forward with + DTensor→local conversion — same pattern as FlexAttention. Handles Q/K + head expansion for grouped linear attention internally so that + repeat_interleave runs on local tensors under TP. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + # "fla_chunked": parallel within chunks, fast for training (default) + # "fla_fused_recurrent": token-by-token, lower memory for long sequences + # "torch_naive": pure-Python reference, for numerical testing only + backend: str = "fla_chunked" + + def __init__(self, config: Config): + super().__init__() + self.backend = config.backend + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + # Expand Q/K heads to match V when n_value_heads > n_key_heads + if q.shape[2] != v.shape[2]: + assert v.shape[2] % q.shape[2] == 0 + repeat = v.shape[2] // q.shape[2] + q = q.repeat_interleave(repeat, dim=2) + k = k.repeat_interleave(repeat, dim=2) + + if self.backend == "torch_naive": + return _torch_naive_gated_delta(q, k, v, g, beta) + + if not _HAS_FLA: + raise RuntimeError( + f"Backend '{self.backend}' requires the `fla` package. " + "Install: pip install flash-linear-attention" + ) + + if self.backend == "fla_chunked": + result = _fla_chunk_gated_delta_rule( + q, # pyrefly: ignore [bad-argument-type] + k, # pyrefly: ignore [bad-argument-count] + v, + g, + beta, + use_qk_l2norm_in_kernel=True, # pyrefly: ignore [unexpected-keyword] + ) + elif self.backend == "fla_fused_recurrent": + result = _fla_fused_recurrent_gated_delta_rule( + q, + k, + v, + g, + beta=beta, + use_qk_l2norm_in_kernel=True, + ) + else: + raise ValueError( + f"Unknown fla_backend '{self.backend}'. " + "Valid: 'fla_chunked', 'fla_fused_recurrent', 'torch_naive'." + ) + + # FLA kernels return (output, final_state); we only need output + # pyrefly: ignore [unsupported-operation] + return result[0] + + +class GatedDeltaNet(Module): + """Gated DeltaNet linear attention. + + Uses recurrent state + gated delta rule instead of softmax attention. + No RoPE, no attention masks, different head structure from standard + attention. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + dim: int + n_key_heads: int + n_value_heads: int + key_head_dim: int + value_head_dim: int + conv_kernel_size: int = 4 + norm_eps: float = 1e-6 + fla_backend: str = "fla_chunked" + in_proj_init: dict + out_proj_init: dict + norm_init: dict + + def __init__(self, config: Config): + super().__init__() + self.key_head_dim = config.key_head_dim + self.value_head_dim = config.value_head_dim + self.conv_kernel_size = config.conv_kernel_size + + key_dim = config.n_key_heads * config.key_head_dim + value_dim = config.n_value_heads * config.value_head_dim + + self.in_proj_q = Linear.Config( + in_features=config.dim, + out_features=key_dim, + bias=False, + param_init=config.in_proj_init, + ).build() + self.in_proj_k = Linear.Config( + in_features=config.dim, + out_features=key_dim, + bias=False, + param_init=config.in_proj_init, + ).build() + self.in_proj_v = Linear.Config( + in_features=config.dim, + out_features=value_dim, + bias=False, + param_init=config.in_proj_init, + ).build() + self.in_proj_z = Linear.Config( + in_features=config.dim, + out_features=value_dim, + bias=False, + param_init=config.in_proj_init, + ).build() + self.in_proj_a = Linear.Config( + in_features=config.dim, + out_features=config.n_value_heads, + bias=False, + param_init=config.in_proj_init, + ).build() + self.in_proj_b = Linear.Config( + in_features=config.dim, + out_features=config.n_value_heads, + bias=False, + param_init=config.in_proj_init, + ).build() + + self.conv_q = _Conv1d( + in_channels=key_dim, + out_channels=key_dim, + bias=False, + kernel_size=config.conv_kernel_size, + groups=key_dim, + padding=0, + ) + self.conv_k = _Conv1d( + in_channels=key_dim, + out_channels=key_dim, + bias=False, + kernel_size=config.conv_kernel_size, + groups=key_dim, + padding=0, + ) + self.conv_v = _Conv1d( + in_channels=value_dim, + out_channels=value_dim, + bias=False, + kernel_size=config.conv_kernel_size, + groups=value_dim, + padding=0, + ) + + self.A_log = nn.Parameter(torch.empty(config.n_value_heads)) + self.dt_bias = nn.Parameter(torch.empty(config.n_value_heads)) + + self.kernel = GatedDeltaKernel.Config(backend=config.fla_backend).build() + + self.norm = RMSNormGated.Config( + dim=config.value_head_dim, + eps=config.norm_eps, + param_init=config.norm_init, + ).build() + self.out_proj = Linear.Config( + in_features=value_dim, + out_features=config.dim, + bias=False, + param_init=config.out_proj_init, + ).build() + + def _causal_conv(self, x: torch.Tensor, conv: nn.Module) -> torch.Tensor: + # pyrefly: ignore [bad-argument-type] + x = F.pad(x.transpose(1, 2), (self.conv_kernel_size - 1, 0)) + return F.silu(conv(x)).transpose(1, 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bs, seqlen, _ = x.shape + + # Split projections (not fused QKV) so each is ColwiseParallel for TP. + xq = self._causal_conv(self.in_proj_q(x), self.conv_q) + xk = self._causal_conv(self.in_proj_k(x), self.conv_k) + xv = self._causal_conv(self.in_proj_v(x), self.conv_v) + xz = self.in_proj_z(x) + xa = self.in_proj_a(x) + xb = self.in_proj_b(x) + + xq = xq.view(bs, seqlen, -1, self.key_head_dim) + xk = xk.view(bs, seqlen, -1, self.key_head_dim) + xv = xv.view(bs, seqlen, -1, self.value_head_dim) + + g = -torch.exp(self.A_log.float()) * F.softplus( + xa.float() + self.dt_bias + ) # decay rate, always negative + beta = torch.sigmoid(xb) # update gate ∈ (0, 1) + + output = self.kernel(xq, xk, xv, g, beta) + + xz = xz.view(bs, seqlen, -1, self.value_head_dim) + output = self.norm(output, xz) + + output = output.reshape(bs, seqlen, -1) + return self.out_proj(output) + + +class Qwen35Attention(BaseAttention): + """Full attention with output gating and partial RoPE for Qwen3.5. + + Differences from GQAttention: + - wq is 2x wider: produces both query and sigmoid gate + - Partial RoPE: only first ``rotary_dim`` elements get RoPE + - Output gating: ``attn_output * sigmoid(gate)`` before ``wo`` + - QK norm uses OffsetRMSNorm + """ + + @dataclass(kw_only=True, slots=True) + class Config(BaseAttention.Config): + n_heads: int + n_kv_heads: int + head_dim: int + rotary_dim: int + wq: Linear.Config + wk: Linear.Config + wv: Linear.Config + wo: Linear.Config + q_norm: OffsetRMSNorm.Config + k_norm: OffsetRMSNorm.Config + inner_attention: Module.Config + mask_type: str = "causal" + + def __init__(self, config: Config): + super().__init__() + self.n_heads = config.n_heads + self.n_kv_heads = config.n_kv_heads + self.head_dim = config.head_dim + self.rotary_dim = config.rotary_dim + self.enable_gqa = self.n_heads > self.n_kv_heads + + self.wq = config.wq.build() + self.wk = config.wk.build() + self.wv = config.wv.build() + self.wo = config.wo.build() + + self.q_norm = config.q_norm.build() + self.k_norm = config.k_norm.build() + + self.scaling = self.head_dim**-0.5 + + self.inner_attention = config.inner_attention.build() + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, + ) -> torch.Tensor: + bs, seqlen, _ = x.shape + + # wq is 2x wider: produces query + gate + xq_gate = self.wq(x).view(bs, seqlen, -1, self.head_dim * 2) + xq, gate = xq_gate.chunk(2, dim=-1) + xk = self.wk(x).view(bs, seqlen, -1, self.head_dim) + xv = self.wv(x).view(bs, seqlen, -1, self.head_dim) + + # QK norm (before RoPE) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # Partial RoPE: only first rotary_dim elements get positional encoding + assert self.rotary_dim <= self.head_dim + xq_rot, xq_pass = xq[..., : self.rotary_dim], xq[..., self.rotary_dim :] + xk_rot, xk_pass = xk[..., : self.rotary_dim], xk[..., self.rotary_dim :] + xq_rot, xk_rot = apply_rotary_emb_cos_sin(xq_rot, xk_rot, rope_cache, positions) + xq = torch.cat([xq_rot, xq_pass], dim=-1) + xk = torch.cat([xk_rot, xk_pass], dim=-1) + + output = self.inner_attention( + xq, + xk, + xv, + attention_masks=attention_masks, + scale=self.scaling, + enable_gqa=self.enable_gqa, + ).contiguous() + + # Output gating + output = output * torch.sigmoid(gate) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class Qwen35TransformerBlock(Module): + """Hybrid transformer block for Qwen3.5. + + Each layer uses either full attention (Qwen35Attention) or linear + attention (GatedDeltaNet), determined by ``layer_type`` in config. + Both types share the same FFN/MoE structure. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + layer_type: str # "full_attn" or "linear_attn" + attention: Qwen35Attention.Config | None = None + deltanet: GatedDeltaNet.Config | None = None + feed_forward: Module.Config | None = None + moe: Module.Config | None = None + shared_ffn: Module.Config | None = None + shared_gate: Linear.Config | None = None + attention_norm: OffsetRMSNorm.Config + ffn_norm: OffsetRMSNorm.Config + + def __init__(self, config: Config): + super().__init__() + self.layer_type = config.layer_type + + if config.layer_type == "full_attn": + assert config.attention is not None + self.attn = config.attention.build() + else: + assert config.deltanet is not None + self.attn = config.deltanet.build() + + self.moe_enabled = config.moe is not None + self.shared_expert_enabled = config.shared_ffn is not None + if self.moe_enabled: + # pyrefly: ignore [missing-attribute] + self.moe = config.moe.build() + else: + assert config.feed_forward is not None + self.feed_forward = config.feed_forward.build() + + if self.shared_expert_enabled: + # pyrefly: ignore [missing-attribute] + self.shared_ffn = config.shared_ffn.build() + assert config.shared_gate is not None + self.shared_gate = config.shared_gate.build() + + self.attention_norm = config.attention_norm.build() + self.ffn_norm = config.ffn_norm.build() + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, + ) -> torch.Tensor: + h = self.attention_norm(x) + if self.layer_type == "full_attn": + h = self.attn(h, freqs_cis, attention_masks, positions) + else: + h = self.attn(h) + x = x + h + + h = self.ffn_norm(x) + if self.moe_enabled: + moe_out = self.moe(h) + if self.shared_expert_enabled: + shared_out = torch.sigmoid(self.shared_gate(h)) * self.shared_ffn(h) + x = x + moe_out + shared_out + else: + x = x + moe_out + else: + x = x + self.feed_forward(h) + return x + + +class Qwen35Model(Decoder): + """Qwen3.5: Multimodal model with hybrid attention. + + Combines a hybrid decoder (GatedDeltaNet linear attention + full + attention with output gating and partial RoPE) with a Vision + Transformer encoder for multimodal understanding. + + Key architectural features: + - Hybrid attention: 75% GatedDeltaNet (linear) + 25% full attention + - Output gating on full attention: ``attn_out * sigmoid(gate)`` + - Partial RoPE: only first ``rotary_dim`` elements get positional encoding + - OffsetRMSNorm: ``(1 + weight) * norm(x)`` with zero-init weight + - MRoPE: 3D position IDs (temporal, height, width) for vision tokens + - MoE variant: routed experts + shared expert with sigmoid gate + + Forward pass flow:: + + forward(tokens, pixel_values, grid_thw, ...) + │ + ├─ _prepare_multimodal_embeds + │ ├─ tok_embeddings(tokens) → text embeddings + │ ├─ _get_vision_embeds(pixel_values) → vision embeddings + │ │ └─ vision_encoder(pixel_values) → merge patches + │ ├─ _compute_vision_positions → locate vision regions + │ └─ _scatter_vision_embeds → scatter into text sequence + │ + ├─ _compute_mrope_freqs → 3D position IDs → interleaved cos/sin + │ + └─ transformer layers (hybrid) + └─ for each layer: + ├─ full attention (every Nth): QK-norm → partial RoPE → SDPA → gate + └─ GatedDeltaNet (others): Conv1d → gated delta rule → gated norm + """ + + @dataclass(kw_only=True, slots=True) + class Config(Decoder.Config): + vision_encoder: Qwen35VisionEncoder.Config + + # MRoPE section sizes for interleaved multi-dimensional RoPE + # [temporal, height, width] - controls how position dimensions are interleaved + mrope_section: list[int] = field(default_factory=lambda: [24, 20, 20]) + + def update_from_config( + self, + *, + trainer_config, + **kwargs, + ) -> None: + training = trainer_config.training + parallelism = trainer_config.parallelism + debug = trainer_config.debug + seq_len = training.seq_len + if seq_len > self.rope.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum " + f"{self.rope.max_seq_len}." + ) + self.rope = dataclasses.replace(self.rope, max_seq_len=seq_len) + + for layer_cfg in self.layers: + moe_cfg = getattr(layer_cfg, "moe", None) + if moe_cfg is not None: + moe_cfg.router._debug_force_load_balance = ( + debug.moe_force_load_balance + ) + + tp = parallelism.tensor_parallel_degree + if tp > 1: + attn_cfg = next( + (l.attention for l in self.layers if l.attention is not None), + None, + ) + if attn_cfg is not None and ( + attn_cfg.n_heads % tp != 0 or attn_cfg.n_kv_heads % tp != 0 + ): + raise ValueError( + f"tensor_parallel_degree ({tp}) must divide " + f"n_heads ({attn_cfg.n_heads}) and " + f"n_kv_heads ({attn_cfg.n_kv_heads})." + ) + dn_cfg = next( + (l.deltanet for l in self.layers if l.deltanet is not None), + None, + ) + if dn_cfg is not None and ( + dn_cfg.n_key_heads % tp != 0 or dn_cfg.n_value_heads % tp != 0 + ): + raise ValueError( + f"tensor_parallel_degree ({tp}) must divide " + f"n_key_heads ({dn_cfg.n_key_heads}) and " + f"n_value_heads ({dn_cfg.n_value_heads})." + ) + + set_qwen35_sharding_config( + self, + loss_parallel=not parallelism.disable_loss_parallel, + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, int]: + attn_cfg = next( + (l.attention for l in self.layers if l.attention is not None), + None, + ) + # pyrefly: ignore [missing-attribute] + n_heads = attn_cfg.n_heads + # pyrefly: ignore [missing-attribute] + head_dim = attn_cfg.head_dim + num_full_attn = sum(1 for l in self.layers if l.layer_type == "full_attn") + return get_moe_model_nparams_and_flops( + self, + model, + n_heads, + 2 * head_dim, + seq_len, + num_full_attn=num_full_attn, + ) + + def __init__(self, config: Config): + super().__init__(config) + + self.vision_encoder = config.vision_encoder.build() + + self.mrope_section = config.mrope_section + self.spatial_merge_size = config.vision_encoder.spatial_merge_size + + def _compute_mrope_freqs( + self, + tokens: torch.Tensor, + *, + grid_thw: torch.Tensor | None, + grid_thw_videos: torch.Tensor | None, + special_tokens: dict[str, int], + positions: torch.Tensor | None = None, + ) -> torch.Tensor: + """Build 3D position IDs and compute interleaved MRoPE cos/sin frequencies. + + Constructs (temporal, height, width) position IDs for each token, then + looks up cos/sin from the 1D RoPE table and overwrites H/W-assigned dims + with their own position lookups. + + Args: + tokens: (batch, seq_len) token IDs + grid_thw: (num_images, 3) grid dimensions for images + grid_thw_videos: (num_videos, 3) grid dimensions for videos + special_tokens: Special token definitions + positions: (batch, seq_len) per-token position IDs for packed + sequences. When provided, document boundaries are detected + where positions reset (positions[t] < positions[t-1]), and + pos_id_offset resets to 0 at each boundary + + Returns: + (batch, seq_len, 1, head_dim * 2) pre-computed MRoPE cos/sin + """ + # --- Build 3D position IDs --- + + # Expand each video [T, H, W] into T rows of [1, H, W] so that + # each frame is treated like an image in the MRoPE code below + # Temporal position comes from frame ordering in the sequence + if grid_thw_videos is not None: + grid_thw_videos = torch.repeat_interleave( + grid_thw_videos, grid_thw_videos[:, 0], dim=0 + ) + grid_thw_videos[:, 0] = 1 + + spatial_merge_size = self.spatial_merge_size + image_token_id = special_tokens["image_id"] + video_token_id = special_tokens["video_id"] + + batch_size, seq_len = tokens.shape + position_ids = torch.zeros( + 3, + batch_size, + seq_len, + dtype=tokens.dtype, + device=tokens.device, + ) + + # Precompute document boundaries and vision token positions across batch + if positions is not None: + resets = positions[:, 1:] < positions[:, :-1] # (batch, seq_len-1) + # Find the first token of each consecutive vision region (image or video) + # E.g. for [text, img, img, img, text, vid, vid] → positions [1, 5] + vision_mask = (tokens == image_token_id) | (tokens == video_token_id) + prev_vision = torch.cat( + [torch.zeros_like(vision_mask[:, :1]), vision_mask[:, :-1]], dim=1 + ) + batch_vision_starts = vision_mask & ~prev_vision # (batch, seq_len) + # Cache vision grid indices by shape to avoid redundant construction + grid_cache: dict[tuple[int, int, int], torch.Tensor] = {} + + image_index, video_index = 0, 0 + # Build MRoPE 3D position IDs per sample + # With sample packing, each sample may contain multiple documents + for sample_i in range(batch_size): + llm_pos_ids_list: list[torch.Tensor] = [] + + if positions is not None: + # Detect document boundaries within one packed sample + # pyrefly: ignore [unbound-name] + reset_indices = torch.where(resets[sample_i])[0] + 1 + doc_starts = [0] + reset_indices.tolist() + doc_ranges = [ + ( + doc_starts[d], + doc_starts[d + 1] if d + 1 < len(doc_starts) else seq_len, + ) + for d in range(len(doc_starts)) + ] + else: + doc_ranges = [(0, seq_len)] + + sample_tokens = tokens[sample_i] + sample_vision_starts = torch.where(batch_vision_starts[sample_i])[ + 0 + ].tolist() + vision_start_index = 0 + + for doc_start, doc_end in doc_ranges: + doc_pos_ids_list: list[torch.Tensor] = [] + + # Advance pointer to collect vision region starts in this document + doc_vision_starts: list[int] = [] + while ( + vision_start_index < len(sample_vision_starts) + and sample_vision_starts[vision_start_index] < doc_end + ): + doc_vision_starts.append(sample_vision_starts[vision_start_index]) + vision_start_index += 1 + + # Process [text tokens][vision tokens] pairs within this document + pair_cursor = doc_start + for vision_start in doc_vision_starts: + if sample_tokens[vision_start] == image_token_id: + # pyrefly: ignore [unsupported-operation] + t, h, w = grid_thw[image_index] + image_index += 1 + else: + # pyrefly: ignore [unsupported-operation] + t, h, w = grid_thw_videos[video_index] + video_index += 1 + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = vision_start - pair_cursor + + # pos_id_offset may differ from pair_cursor due to compact + # spatial position IDs for vision regions + pos_id_offset = ( + doc_pos_ids_list[-1].max() + 1 + if len(doc_pos_ids_list) > 0 + else 0 + ) + # [text tokens] — sequential positions, identical on all 3 axes + doc_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset + ) + # [vision tokens] — 3D grid positions (T, H, W) + grid_key = (llm_grid_t, llm_grid_h, llm_grid_w) + if grid_key not in grid_cache: + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + # pyrefly: ignore [no-matching-overload] + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + # pyrefly: ignore [no-matching-overload] + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + # pyrefly: ignore [no-matching-overload] + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + # pyrefly: ignore [unsupported-operation] + grid_cache[grid_key] = torch.stack([t_index, h_index, w_index]) + doc_pos_ids_list.append( + # pyrefly: ignore [bad-index] + grid_cache[grid_key] + + text_len + + pos_id_offset + ) + pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w + + # Trailing [text tokens] after the last [text tokens][vision tokens] pair + if pair_cursor < doc_end: + pos_id_offset = ( + doc_pos_ids_list[-1].max() + 1 + if len(doc_pos_ids_list) > 0 + else 0 + ) + text_len = doc_end - pair_cursor + doc_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset + ) + + llm_pos_ids_list.extend(doc_pos_ids_list) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[:, sample_i, :] = llm_positions.to(position_ids.device) + + # --- Compute interleaved MRoPE cos/sin from position IDs --- + freqs_cis = self.freqs_cis + head_dim = freqs_cis.shape[-1] // 2 + cos_cache = freqs_cis[:, :head_dim] + sin_cache = freqs_cis[:, head_dim:] + + # Initialize with temporal positions, then overwrite H/W slices + t_pos = position_ids[0].long() + mrope_cos = cos_cache[t_pos] + mrope_sin = sin_cache[t_pos] + + # Overwrite H and W slices with their own position lookups + # Both halves of head_dim must be updated (head_dim = cat([freqs, freqs])) + half = head_dim // 2 + for dim, offset in enumerate((1, 2), start=1): # H, W + length = self.mrope_section[dim] * 3 + low = torch.arange(offset, length, 3, device=freqs_cis.device) + col_indices = torch.cat([low, low + half]) + dim_pos = position_ids[dim].long() + mrope_cos[..., col_indices] = cos_cache[:, col_indices][dim_pos] + mrope_sin[..., col_indices] = sin_cache[:, col_indices][dim_pos] + + return torch.cat([mrope_cos, mrope_sin], dim=-1).unsqueeze(2) + + def _compute_vision_positions( + self, + tokens: torch.Tensor, + num_tokens_per_item: torch.Tensor, + vision_token_id: int, + ) -> list[tuple[int, int, int, int]]: + """Compute (item_idx, sample_idx, vision_start, n_tokens) for each vision item. + + Finds where each contiguous run of vision placeholder tokens starts + in the text sequence. + + Args: + tokens: Token IDs (batch, seq_len) + num_tokens_per_item: (num_items,) actual tokens per vision item + vision_token_id: Placeholder token ID + + Returns: + List of (item_idx, sample_idx, vision_start, n_tokens) tuples + """ + vision_mask = tokens == vision_token_id + flat_mask = vision_mask.view(-1) + prev_mask = torch.cat( + [torch.zeros(1, dtype=torch.bool, device=flat_mask.device), flat_mask[:-1]] + ) + region_starts = torch.where(flat_mask & ~prev_mask)[0] + seq_len = tokens.shape[1] + + positions = [] + for i in range(num_tokens_per_item.shape[0]): + start = int(region_starts[i].item()) + n_tokens = int(num_tokens_per_item[i].item()) + positions.append((i, start // seq_len, start % seq_len, n_tokens)) + return positions + + def _get_vision_embeds( + self, + pixel_values: torch.Tensor, + *, + grid_thw: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run vision encoder and return padded embeddings with token counts. + + Args: + pixel_values: Padded patches (num_items, max_num_patch, patch_dim) + grid_thw: Grid dimensions (num_items, 3) for [t, h, w] + + Returns: + merged_embeds: (num_items, max_tokens, dim) padded vision embeddings + num_tokens_per_item: (num_items,) actual token count per item + """ + pixel_values = pixel_values.to(self.vision_encoder.patch_embed.weight.dtype) + merged_embeds = self.vision_encoder(pixel_values, grid_thw=grid_thw) + + merge_unit = self.vision_encoder.spatial_merge_unit + num_tokens_per_item = grid_thw.prod(-1) // merge_unit + + return merged_embeds, num_tokens_per_item + + def _scatter_vision_embeds( + self, + inputs_embeds: torch.Tensor, + *, + merged_embeds: torch.Tensor, + vision_positions: list[tuple[int, int, int, int]], + ) -> torch.Tensor: + """Scatter vision embeddings into text embeddings at placeholder positions. + + Copies directly from the padded vision encoder output into the text + sequence. + + Args: + inputs_embeds: Text embeddings (batch, seq_len, dim) + merged_embeds: Padded vision embeddings (num_items, max_tokens, dim) + vision_positions: List of (item_idx, sample_idx, vision_start, n_tokens) + + Returns: + Updated embeddings + """ + for item_idx, sample_idx, vision_start, n_tokens in vision_positions: + inputs_embeds[ + sample_idx, vision_start : vision_start + n_tokens, : + ] = merged_embeds[item_idx, :n_tokens, :] + return inputs_embeds + + def _prepare_multimodal_embeds( + self, + tokens: torch.Tensor, + *, + pixel_values: torch.Tensor | None, + pixel_values_videos: torch.Tensor | None, + grid_thw: torch.Tensor | None, + grid_thw_videos: torch.Tensor | None, + special_tokens: dict[str, int], + ) -> torch.Tensor: + """Embed tokens, run vision encoder, scatter vision into text. + + Args: + tokens: Input token IDs (batch_size, seq_len) + pixel_values: Image patches or None + pixel_values_videos: Video patches or None + grid_thw: Grid dimensions for images or None + grid_thw_videos: Grid dimensions for videos or None + special_tokens: Special token definitions + + Returns: + (batch, seq_len, dim) embeddings with vision tokens scattered in + """ + image_token_id = special_tokens["image_id"] + video_token_id = special_tokens["video_id"] + + inputs_embeds = ( + self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + ) + + if pixel_values is not None and grid_thw is not None: + merged_embeds, num_tokens = self._get_vision_embeds( + pixel_values, grid_thw=grid_thw + ) + image_positions = self._compute_vision_positions( + tokens, num_tokens, image_token_id + ) + if image_positions: + inputs_embeds = self._scatter_vision_embeds( + inputs_embeds, + merged_embeds=merged_embeds, + vision_positions=image_positions, + ) + + if pixel_values_videos is not None and grid_thw_videos is not None: + merged_embeds, num_tokens = self._get_vision_embeds( + pixel_values_videos, grid_thw=grid_thw_videos + ) + video_positions = self._compute_vision_positions( + tokens, num_tokens, video_token_id + ) + if video_positions: + inputs_embeds = self._scatter_vision_embeds( + inputs_embeds, + merged_embeds=merged_embeds, + vision_positions=video_positions, + ) + + return inputs_embeds + + def forward( # pyrefly: ignore [bad-override] + self, + tokens: torch.Tensor, + *, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + grid_thw: torch.Tensor | None = None, + grid_thw_videos: torch.Tensor | None = None, + attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, + special_tokens: dict[str, int] | None = None, + ): + if self.tok_embeddings is not None: + x = self._prepare_multimodal_embeds( + tokens, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + grid_thw=grid_thw, + grid_thw_videos=grid_thw_videos, + special_tokens=special_tokens, # pyrefly: ignore [bad-argument-type] + ) + else: + x = tokens + + if grid_thw is not None or grid_thw_videos is not None: + freqs_cis = self._compute_mrope_freqs( + tokens, + grid_thw=grid_thw, + grid_thw_videos=grid_thw_videos, + special_tokens=special_tokens, # pyrefly: ignore [bad-argument-type] + positions=positions, + ) + else: + freqs_cis = self.freqs_cis + for layer in self.layers.values(): + x = layer(x, freqs_cis, attention_masks, positions) + + x = self.norm(x) if self.norm is not None else x + if self._skip_lm_head: + return x + return self.lm_head(x) if self.lm_head is not None else x diff --git a/torchtitan/models/qwen3_5/parallelize.py b/torchtitan/models/qwen3_5/parallelize.py new file mode 100644 index 0000000000..b46d61e5a4 --- /dev/null +++ b/torchtitan/models/qwen3_5/parallelize.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Parallelization utilities for Qwen3.5. + +This module applies PT-D parallelisms and various training techniques +(activation checkpointing, compile, FSDP) to the Qwen3.5 model. +""" + +import torch.nn as nn + +from torchtitan.config import ( + ActivationCheckpointConfig, + CompileConfig, + ParallelismConfig, + TORCH_DTYPE_MAP, + TrainingConfig, +) + +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.compile import apply_compile +from torchtitan.distributed.context_parallel import apply_cp_to_forward +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.llama4.parallelize import apply_fsdp, apply_moe_ep_tp +from torchtitan.models.qwen3_5.sharding import ( + set_deltanet_sub_module_sharding, + set_vision_encoder_sub_module_sharding, +) +from torchtitan.tools.logging import logger + + +def parallelize_qwen3_5( + model: nn.Module, + *, + parallel_dims: ParallelDims, + training: TrainingConfig, + parallelism: ParallelismConfig, + compile_config: CompileConfig, + ac_config: ActivationCheckpointConfig, + dump_folder: str, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the Qwen3.5 model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + model_compile_enabled = ( + compile_config.enable and "model" in compile_config.components + ) + + # Context Parallel: wrap inner attention forward BEFORE TP so CP logic + # runs inside the local_map boundary on local tensors. + # Applies to full attention layers only — GatedDeltaNet is recurrent + # and allgathers the full sequence via cp=Replicate() in sharding. + if parallel_dims.cp_enabled: + cp_mesh = parallel_dims.get_mesh("cp") + full_attn_inner_modules = [ + block.attn.inner_attention # pyrefly: ignore [missing-attribute] + for block in model.layers.values() # pyrefly: ignore [not-callable] + if block.layer_type == "full_attn" # pyrefly: ignore [missing-attribute] + ] + if full_attn_inner_modules: + apply_cp_to_forward(full_attn_inner_modules, cp_mesh) + + if parallel_dims.tp_enabled: + if parallelism.enable_async_tensor_parallel and not model_compile_enabled: + raise RuntimeError("Async TP requires torch.compile") + + tp_mesh = parallel_dims.get_mesh("tp") + + # For sub-modules built inline, set _sharding_config on built modules. + # pyrefly: ignore [not-callable] + for block in model.layers.values(): + # pyrefly: ignore [missing-attribute] + if block.layer_type != "full_attn": + # pyrefly: ignore [missing-attribute] + set_deltanet_sub_module_sharding(block.attn) + if model.vision_encoder is not None: + set_vision_encoder_sub_module_sharding(model.vision_encoder) + # pyrefly: ignore [not-callable] + model.parallelize(tp_mesh) + maybe_enable_async_tp(parallelism, compile_config, tp_mesh) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + enable_sp=parallel_dims.tp_enabled, + ) + + if ac_config.mode != "none": + apply_ac( + model, + ac_config, + model_compile_enabled=model_compile_enabled, + base_folder=dump_folder, + ) + if model.vision_encoder is not None: + apply_ac( + # pyrefly: ignore [bad-argument-type] + model.vision_encoder, + ac_config, + model_compile_enabled=model_compile_enabled, + base_folder=dump_folder, + ) + + if model_compile_enabled: + apply_compile(model, compile_config) + if model.vision_encoder is not None: + # pyrefly: ignore [bad-argument-type] + apply_compile(model.vision_encoder, compile_config) + + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) + + edp_mesh = None + if parallel_dims.ep_enabled: + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=training.enable_cpu_offload, + reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, + edp_mesh=edp_mesh, + ) + + logger.info("Applied fully_shard to the Qwen3.5 model") + + if training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the Qwen3.5 model") + + return model + + +def pipeline_qwen3_5( + model: nn.Module, + *, + parallel_dims: ParallelDims, + parallelism: ParallelismConfig, + model_config, + **kwargs, +): + """PP wrapper that assigns vision_encoder to the first pipeline stage. + + Delegates to ``pipeline_llm`` after injecting ``vision_encoder`` into + the first stage's FQN list (the auto-generated LLM split doesn't know + about vision encoder modules). + """ + import dataclasses + + from torchtitan.distributed.pipeline_parallel import ( + _generate_llm_fqn_per_model_part, + _get_pipeline_metadata, + pipeline_llm, + ) + + if parallelism.module_fqns_per_model_part is None: + ( + num_virtual_stages, + num_layers, + input_weight, + output_weight, + ) = _get_pipeline_metadata(parallel_dims, parallelism, model_config) + fqn_per_part = _generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + # Vision encoder lives on the first stage alongside tok_embeddings + if hasattr(model, "vision_encoder") and model.vision_encoder is not None: + fqn_per_part[0].insert(0, "vision_encoder") + parallelism = dataclasses.replace( + parallelism, module_fqns_per_model_part=fqn_per_part + ) + + return pipeline_llm( + model, + parallel_dims=parallel_dims, + parallelism=parallelism, + model_config=model_config, + **kwargs, + ) diff --git a/torchtitan/models/qwen3_vl/requirements.txt b/torchtitan/models/qwen3_5/requirements.txt similarity index 100% rename from torchtitan/models/qwen3_vl/requirements.txt rename to torchtitan/models/qwen3_5/requirements.txt diff --git a/torchtitan/models/qwen3_vl/rope.py b/torchtitan/models/qwen3_5/rope.py similarity index 100% rename from torchtitan/models/qwen3_vl/rope.py rename to torchtitan/models/qwen3_5/rope.py diff --git a/torchtitan/models/qwen3_5/sharding.py b/torchtitan/models/qwen3_5/sharding.py new file mode 100644 index 0000000000..10d138bb23 --- /dev/null +++ b/torchtitan/models/qwen3_5/sharding.py @@ -0,0 +1,371 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Sharding configs for Qwen3.5 hybrid attention model. + +Sets ``ShardingConfig`` on all sub-configs so that ``model.parallelize()`` +applies TP via the Module protocol. Same pattern as ``qwen3/sharding.py``. + +Full attention layers: TP on wq/wk/wv/wo with local_map for inner attention. +GatedDeltaNet layers: head-sharded TP on projections (ColwiseParallel) and +out_proj (RowwiseParallel). Conv1d and FLA kernel forwards are wrapped for +DTensor→local conversion. +""" + +import types +from typing import TYPE_CHECKING + +import torch.nn.functional as F + +from torch import nn +from torch.distributed.tensor import DTensor, Replicate, Shard + +from torchtitan.models.common.decoder_sharding import ( + colwise_config, + dense_activation_placement, + dense_param_placement, + norm_config, + rowwise_config, + set_decoder_sharding_config, + set_dense_ffn_sharding, + set_gqa_inner_attention_local_map, +) +from torchtitan.protocols.sharding import LocalMapConfig, ShardingConfig + +if TYPE_CHECKING: + from torchtitan.models.qwen3_5.model import ( + Qwen35Attention, + Qwen35Model, + Qwen35TransformerBlock, + ) + +TP = "tp" + +_REPLICATE_PARAM = dense_param_placement(tp=Replicate()) +_REPLICATE_STATE = ShardingConfig( + state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM} +) +_REPLICATE_ACT = dense_activation_placement(tp=Replicate()) + +# For norms/modules that receive and emit Replicate activations +_REPLICATE_NORM = ShardingConfig( + state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM}, + in_src_shardings={"input": _REPLICATE_ACT}, + in_dst_shardings={"input": _REPLICATE_ACT}, + out_dst_shardings=_REPLICATE_ACT, +) + + +def set_qwen35_sharding_config( + config: "Qwen35Model.Config", + *, + loss_parallel: bool, +) -> None: + """Fill ``sharding_config`` on all Qwen3.5 sub-configs. + + Uses SP for decoder layers, norm, and lm_head. tok_embeddings output + stays Replicate so vision scatter and MRoPE can access the full sequence. + The model forward redistributes to Shard(1) before entering the layers. + """ + # SP on norm, lm_head, and layers + set_decoder_sharding_config(config, loss_parallel=loss_parallel, enable_sp=True) + # Override: don't distribute freqs_cis — MRoPE indexes it with plain tensors + config.sharding_config = ShardingConfig() + # Override tok_embeddings: output Replicate (not Shard(1)) for vision scatter + config.tok_embeddings.sharding_config = ShardingConfig( + state_shardings={"weight": dense_param_placement(tp=Shard(0))}, + in_src_shardings={"input": _REPLICATE_ACT}, + in_dst_shardings={"input": _REPLICATE_ACT}, + out_dst_shardings=_REPLICATE_ACT, + ) + _set_vision_encoder_sharding(config.vision_encoder) + for layer_cfg in config.layers: + _set_qwen35_layer_sharding(layer_cfg) + + +def _set_qwen35_layer_sharding( + layer_cfg: "Qwen35TransformerBlock.Config", +) -> None: + norm = norm_config(enable_sp=True) + layer_cfg.attention_norm.sharding_config = norm + layer_cfg.ffn_norm.sharding_config = norm + + if layer_cfg.layer_type == "full_attn": + assert layer_cfg.attention is not None + _set_full_attention_sharding(layer_cfg.attention) + else: + assert layer_cfg.deltanet is not None + _set_deltanet_sharding(layer_cfg.deltanet) + + if layer_cfg.feed_forward is not None: + set_dense_ffn_sharding( + layer_cfg.feed_forward, + attn_x_placement=Shard(1), + enable_sp=True, + ) + + if layer_cfg.shared_ffn is not None: + set_dense_ffn_sharding( + layer_cfg.shared_ffn, + attn_x_placement=Shard(1), + enable_sp=True, + ) + if layer_cfg.shared_gate is not None: + layer_cfg.shared_gate.sharding_config = _REPLICATE_STATE + + +def _set_vision_encoder_sharding(ve_cfg) -> None: + """Sharding for the vision encoder. + + All activations flow as Replicate — no SP in the vision encoder. + Linear layers are ColwiseParallel/RowwiseParallel for memory savings. + Norms and patch_embed are Replicate. pos_embed is distributed as + Replicate via state_shardings on the encoder config. + """ + ve_cfg.sharding_config = ShardingConfig( + state_shardings={"pos_embed": _REPLICATE_PARAM}, + ) + + ve_cfg.patch_embed_proj.sharding_config = _REPLICATE_STATE + + # Separate Q/K/V: colwise sharding + ve_cfg.attn_wq.sharding_config = colwise_config() + ve_cfg.attn_wk.sharding_config = colwise_config() + ve_cfg.attn_wv.sharding_config = colwise_config() + ve_cfg.attn_proj.sharding_config = rowwise_config(output_sp=False) + ve_cfg.mlp_fc1.sharding_config = colwise_config() + ve_cfg.mlp_fc2.sharding_config = rowwise_config(output_sp=False) + + ve_cfg.merger_fc1.sharding_config = colwise_config() + ve_cfg.merger_fc2.sharding_config = rowwise_config(output_sp=False) + + +def _set_full_attention_sharding( + attention_cfg: "Qwen35Attention.Config", +) -> None: + """TP sharding for Qwen35Attention (output gating + partial RoPE).""" + attention_cfg.sharding_config = ShardingConfig( + in_src_shardings={ + "x": dense_activation_placement(tp=Shard(1)), + "rope_cache": dense_param_placement(tp=Replicate()), + }, + in_dst_shardings={ + "x": dense_activation_placement(tp=Replicate()), + "rope_cache": dense_param_placement(tp=Replicate()), + }, + ) + attention_cfg.wq.sharding_config = colwise_config() + attention_cfg.wk.sharding_config = colwise_config() + attention_cfg.wv.sharding_config = colwise_config() + attention_cfg.wo.sharding_config = rowwise_config(output_sp=True) + + qk_norm_sharding = ShardingConfig( + state_shardings={"weight": _REPLICATE_PARAM}, + in_src_shardings={"input": dense_activation_placement(tp=Shard(2))}, + in_dst_shardings={"input": dense_activation_placement(tp=Shard(2))}, + out_dst_shardings=dense_activation_placement(tp=Shard(2)), + ) + attention_cfg.q_norm.sharding_config = qk_norm_sharding + attention_cfg.k_norm.sharding_config = qk_norm_sharding + + set_gqa_inner_attention_local_map(attention_cfg.inner_attention) + + +def _set_deltanet_sharding(deltanet_cfg) -> None: + """Sharding for GatedDeltaNet: head-sharded TP on projections. + + Input is allgathered on both TP and CP (Shard(1)→Replicate) because + the recurrence needs the full sequence. Projections are ColwiseParallel + (head-sharded output). Conv1d and FLA kernels are wrapped for + DTensor→local conversion. out_proj is RowwiseParallel (reduce-scatter + back to Shard(1)). + + A_log and dt_bias are per-head parameters, Shard(0) on TP. + Sub-module sharding is set on built modules by + ``set_deltanet_sub_module_sharding`` before ``model.parallelize()``. + """ + deltanet_cfg.sharding_config = ShardingConfig( + state_shardings={ + "A_log": dense_param_placement(tp=Shard(0)), + "dt_bias": dense_param_placement(tp=Shard(0)), + }, + in_src_shardings={"x": dense_activation_placement(tp=Shard(1))}, + # cp=Replicate: GatedDeltaNet is recurrent — needs full sequence + in_dst_shardings={ + "x": dense_activation_placement(tp=Replicate(), cp=Replicate()) + }, + out_dst_shardings=dense_activation_placement(tp=Shard(1)), + ) + + +def set_vision_encoder_sub_module_sharding(vision_encoder) -> None: + """Set _sharding_config on vision encoder sub-modules built inline. + + Norms (LayerNorm) in VisionTransformerBlock and PatchMerger are created + via Module.from_nn_module(nn.LayerNorm) — not from config fields. + Must be called after model build but before model.parallelize(). + """ + for layer in vision_encoder.layers.values(): + for name in ("norm1", "norm2"): + child = getattr(layer, name, None) + if child is not None: + child._sharding_config = _REPLICATE_NORM + # VisionAttention: declare rope_cache as Replicate so plain + # rope_cache is wrapped as DTensor to match DTensor q/k. + layer.attn._sharding_config = ShardingConfig( + in_src_shardings={"rope_cache": _REPLICATE_ACT}, + in_dst_shardings={"rope_cache": _REPLICATE_ACT}, + ) + # FlexAttention: local_map to convert DTensor q/k/v to local. + # Same as set_gqa_inner_attention_local_map but on built module. + if hasattr(layer.attn, "flex_attention"): + qkv_plc = {TP: Shard(2)} + layer.attn.flex_attention._sharding_config = ShardingConfig( + local_map=LocalMapConfig( + # pyrefly: ignore [bad-argument-type] + in_placements=(qkv_plc, qkv_plc, qkv_plc), + # pyrefly: ignore [bad-argument-type] + out_placements=(qkv_plc,), + # pyrefly: ignore [bad-argument-type] + in_grad_placements=(qkv_plc, qkv_plc, qkv_plc), + ), + ) + # Merger norm + if hasattr(vision_encoder.merger, "norm"): + vision_encoder.merger.norm._sharding_config = _REPLICATE_NORM + # Merger GELU: set None to skip protocol wrapping. Per-layer mlp.act_fn + # doesn't need this because its parent VisionMLP has no _sharding_config, + # but the merger's children get processed due to merger.norm having one. + if hasattr(vision_encoder.merger, "act_fn"): + vision_encoder.merger.act_fn._sharding_config = None + # VisionRotaryEmbedding: don't set _sharding_config — wrapping forward + # would break RoPE compute. inv_freq stays as a plain buffer; the + # resulting rope_cache is wrapped as DTensor by VisionAttention's + # in_src_shardings. + + # pos_embed interpolation: F.interpolate's decomposition doesn't + # support DTensor. Wrap to convert pos_embed to local before use. + _wrap_pos_embed_for_interpolation(vision_encoder) + + # patch_embed (Linear): plain pixel_values in → DTensor(Replicate) out + vision_encoder.patch_embed._sharding_config = ShardingConfig( + state_shardings={ + "weight": _REPLICATE_PARAM, + "bias": _REPLICATE_PARAM, + }, + in_src_shardings={"input": _REPLICATE_ACT}, + in_dst_shardings={"input": _REPLICATE_ACT}, + out_dst_shardings=_REPLICATE_ACT, + ) + + +def set_deltanet_sub_module_sharding(deltanet_module) -> None: + """Set head-sharded TP on GatedDeltaNet sub-modules. + + Projections are ColwiseParallel (head-sharded output), out_proj is + RowwiseParallel (reduce-scatter to SP). Conv1d weights are Shard(0) + on the channel dim (matching head sharding). The conv and kernel + forwards are wrapped for DTensor→local conversion (depthwise conv + and FLA kernels don't support DTensor dispatch). + + Must be called after model build but before model.parallelize(). + """ + for name in ( + "in_proj_q", + "in_proj_k", + "in_proj_v", + "in_proj_z", + "in_proj_a", + "in_proj_b", + ): + getattr(deltanet_module, name)._sharding_config = colwise_config() + + _conv_shard = ShardingConfig( + state_shardings={"weight": dense_param_placement(tp=Shard(0))}, + ) + for name in ("conv_q", "conv_k", "conv_v"): + conv = getattr(deltanet_module, name) + conv._sharding_config = _conv_shard + _wrap_conv1d(conv) + + # GatedDeltaKernel: local_map converts DTensor q/k/v/g/beta to local + # for FLA kernels, same pattern as FlexAttention's local_map. + _kernel_plc = {TP: Shard(2)} + deltanet_module.kernel._sharding_config = ShardingConfig( + local_map=LocalMapConfig( + # pyrefly: ignore [bad-argument-type] + in_placements=(_kernel_plc,) * 5, + # pyrefly: ignore [bad-argument-type] + out_placements=(_kernel_plc,), + # pyrefly: ignore [bad-argument-type] + in_grad_placements=(_kernel_plc,) * 5, + ), + ) + + deltanet_module.norm._sharding_config = _REPLICATE_STATE + deltanet_module.out_proj._sharding_config = rowwise_config(output_sp=True) + + +def _wrap_conv1d(conv1d_module) -> None: + """Wrap depthwise Conv1d forward for DTensor→local conversion. + + DTensor dispatch for Conv1d doesn't handle sharded groups: nn.Conv1d + stores groups as a plain int, but when the weight is TP-sharded on + the channel dim, the local weight has fewer channels than groups. + This wrapper converts inputs/weights to local and uses the local + channel count as groups. + + TODO: Remove once DTensor Conv1d dispatch handles sharded groups. + """ + original_forward = conv1d_module.forward.__func__ + + def safe_forward(self, x): + if isinstance(x, DTensor): + mesh, plc = x.device_mesh, x.placements + w = self.weight + if isinstance(w, DTensor): + w = w.to_local() + # self.groups is the global count; use local weight's channel dim + local_groups = w.shape[0] + out = F.conv1d( + x.to_local(), + w, + None, + self.stride, + self.padding, + self.dilation, + local_groups, + ) + return DTensor.from_local(out, mesh, plc, run_check=False) + return original_forward(self, x) + + conv1d_module.forward = types.MethodType(safe_forward, conv1d_module) + + +def _wrap_pos_embed_for_interpolation(vision_encoder) -> None: + """Wrap compute_position_embeddings to convert pos_embed to local. + + F.interpolate's decomposition uses _unsafe_index which doesn't support + DTensor. Since pos_embed is Replicate, to_local is a no-op for data. + + TODO: Remove once F.interpolate on FSDP2-managed DTensors is fixed upstream. + """ + original_fn = vision_encoder.compute_position_embeddings.__func__ + + def safe_compute(self, grid_thw, max_num_patch): + pos = self.pos_embed + if isinstance(pos, DTensor): + mesh, plc = pos.device_mesh, pos.placements + self.pos_embed = nn.Parameter(pos.to_local(), requires_grad=False) + learned_pos, rope_cache = original_fn(self, grid_thw, max_num_patch) + self.pos_embed = pos + learned_pos = DTensor.from_local(learned_pos, mesh, plc, run_check=False) + return learned_pos, rope_cache + return original_fn(self, grid_thw, max_num_patch) + + vision_encoder.compute_position_embeddings = types.MethodType( + safe_compute, vision_encoder + ) diff --git a/torchtitan/models/qwen3_5/state_dict_adapter.py b/torchtitan/models/qwen3_5/state_dict_adapter.py new file mode 100644 index 0000000000..f8788d5c88 --- /dev/null +++ b/torchtitan/models/qwen3_5/state_dict_adapter.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +State dict adapter for Qwen3.5. + +Converts between HuggingFace Qwen3.5 checkpoint format and torchtitan format. + +MoE expert weights require two transformations: +- **Transpose**: HF and TT use transposed layouts for grouped 3D expert weights. + E.g. HF down_proj [E, hidden, dim] <-> TT w2 [E, dim, hidden]. +- **Fuse/split gate_up_proj**: HF fuses gate_proj and up_proj into a single + gate_up_proj [E, dim, 2*hidden_dim]. TT stores them separately as + w1 [E, hidden_dim, dim] and w3 [E, hidden_dim, dim]. + +Other notable conversions: +- Conv3d patch embedding (HF) <-> Linear (TT) via weight reshape +- Vision block naming: HF `blocks` <-> TT `layers` +- Vision QKV: HF fused qkv <-> TT separate wq/wk/wv +- GatedDeltaNet QKV: HF fused in_proj_qkv <-> TT separate in_proj_q/k/v +- GatedDeltaNet Conv1d: HF fused conv1d <-> TT separate conv_q/k/v +""" + +import re +from typing import Any + +import torch + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + +from .model import Qwen35Model + + +class Qwen35StateDictAdapter(StateDictAdapter): + def __init__(self, model_config: Qwen35Model.Config, hf_assets_path: str | None): + super().__init__(model_config, hf_assets_path) + self.model_config = model_config + + self.from_hf_map = { + # ===== Language Model ===== + "model.language_model.embed_tokens.weight": "tok_embeddings.weight", + # Full attention layers (self_attn.*) + "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.wq.weight", + "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.wk.weight", + "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.wv.weight", + "model.language_model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.wo.weight", + "model.language_model.layers.{}.self_attn.q_norm.weight": "layers.{}.attn.q_norm.weight", + "model.language_model.layers.{}.self_attn.k_norm.weight": "layers.{}.attn.k_norm.weight", + "model.language_model.layers.{}.self_attn.rotary_emb.inv_freq": None, + # GatedDeltaNet layers (linear_attn.*) + # QKV and Conv1d: HF fused → TT separate (handled in to_hf/from_hf) + "model.language_model.layers.{}.linear_attn.in_proj_qkv.weight": None, + "model.language_model.layers.{}.linear_attn.conv1d.weight": None, + "model.language_model.layers.{}.linear_attn.in_proj_z.weight": "layers.{}.attn.in_proj_z.weight", + "model.language_model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attn.in_proj_a.weight", + "model.language_model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attn.in_proj_b.weight", + "model.language_model.layers.{}.linear_attn.A_log": "layers.{}.attn.A_log", + "model.language_model.layers.{}.linear_attn.dt_bias": "layers.{}.attn.dt_bias", + "model.language_model.layers.{}.linear_attn.norm.weight": "layers.{}.attn.norm.weight", + "model.language_model.layers.{}.linear_attn.out_proj.weight": "layers.{}.attn.out_proj.weight", + # Non-MoE MLP + "model.language_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.language_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.language_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + # Layer norms + "model.language_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.language_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE (grouped 3D format, handled specially in to_hf/from_hf) + "model.language_model.layers.{}.mlp.experts.down_proj": "layers.{}.moe.experts.w2", + "model.language_model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", + # MoE shared expert + "model.language_model.layers.{}.mlp.shared_expert.gate_proj.weight": "layers.{}.shared_ffn.w1.weight", + "model.language_model.layers.{}.mlp.shared_expert.up_proj.weight": "layers.{}.shared_ffn.w3.weight", + "model.language_model.layers.{}.mlp.shared_expert.down_proj.weight": "layers.{}.shared_ffn.w2.weight", + "model.language_model.layers.{}.mlp.shared_expert_gate.weight": "layers.{}.shared_gate.weight", + # Final norm and output + "model.language_model.norm.weight": "norm.weight", + "lm_head.weight": "lm_head.weight", + # ===== Vision Encoder ===== + # Patch embedding (Conv3d in HF, Linear in TT — weight reshape needed) + "model.visual.patch_embed.proj.weight": "vision_encoder.patch_embed.weight", + "model.visual.patch_embed.proj.bias": "vision_encoder.patch_embed.bias", + # Position embeddings + "model.visual.pos_embed.weight": "vision_encoder.pos_embed", + # Vision transformer blocks (HF: blocks, TT: layers) + "model.visual.blocks.{}.norm1.weight": "vision_encoder.layers.{}.norm1.weight", + "model.visual.blocks.{}.norm1.bias": "vision_encoder.layers.{}.norm1.bias", + "model.visual.blocks.{}.norm2.weight": "vision_encoder.layers.{}.norm2.weight", + "model.visual.blocks.{}.norm2.bias": "vision_encoder.layers.{}.norm2.bias", + # Vision QKV: HF fused → TT separate (handled in to_hf/from_hf) + "model.visual.blocks.{}.attn.qkv.weight": None, + "model.visual.blocks.{}.attn.qkv.bias": None, + "model.visual.blocks.{}.attn.proj.weight": "vision_encoder.layers.{}.attn.proj.weight", + "model.visual.blocks.{}.attn.proj.bias": "vision_encoder.layers.{}.attn.proj.bias", + "model.visual.blocks.{}.mlp.linear_fc1.weight": "vision_encoder.layers.{}.mlp.linear_fc1.weight", + "model.visual.blocks.{}.mlp.linear_fc1.bias": "vision_encoder.layers.{}.mlp.linear_fc1.bias", + "model.visual.blocks.{}.mlp.linear_fc2.weight": "vision_encoder.layers.{}.mlp.linear_fc2.weight", + "model.visual.blocks.{}.mlp.linear_fc2.bias": "vision_encoder.layers.{}.mlp.linear_fc2.bias", + # Merger + "model.visual.merger.norm.weight": "vision_encoder.merger.norm.weight", + "model.visual.merger.norm.bias": "vision_encoder.merger.norm.bias", + "model.visual.merger.linear_fc1.weight": "vision_encoder.merger.linear_fc1.weight", + "model.visual.merger.linear_fc1.bias": "vision_encoder.merger.linear_fc1.bias", + "model.visual.merger.linear_fc2.weight": "vision_encoder.merger.linear_fc2.weight", + "model.visual.merger.linear_fc2.bias": "vision_encoder.merger.linear_fc2.bias", + } + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """Convert torchtitan state dict to HuggingFace Qwen3.5 format.""" + to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None} + hf_state_dict = {} + + moe_w1_by_layer: dict[str, Any] = {} + moe_w3_by_layer: dict[str, Any] = {} + vision_qkv_by_layer: dict[str, dict[str, Any]] = {} + deltanet_qkv_by_layer: dict[str, dict[str, Any]] = {} + + for tt_key, value in state_dict.items(): + if "moe.experts" in tt_key: + tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1) + # pyrefly: ignore [missing-attribute] + layer_num = re.search(r"\d+", tt_key).group(0) + + if tt_abstract_key == "layers.{}.moe.experts.w1": + moe_w1_by_layer[layer_num] = value + continue + elif tt_abstract_key == "layers.{}.moe.experts.w3": + moe_w3_by_layer[layer_num] = value + continue + elif tt_abstract_key == "layers.{}.moe.experts.w2": + hf_key = ( + f"model.language_model.layers.{layer_num}.mlp.experts.down_proj" + ) + hf_state_dict[hf_key] = value.transpose(-2, -1) + continue + + if tt_abstract_key not in to_hf_map: + continue + hf_state_dict[to_hf_map[tt_abstract_key].format(layer_num)] = value + + elif re.search(r"\.\d+\.", tt_key): + tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1) + # pyrefly: ignore [missing-attribute] + layer_num = re.search(r"\d+", tt_key).group(0) + + # Collect deltanet q/k/v projections and conv weights for fusing + if tt_abstract_key in ( + "layers.{}.attn.in_proj_q.weight", + "layers.{}.attn.in_proj_k.weight", + "layers.{}.attn.in_proj_v.weight", + "layers.{}.attn.conv_q.weight", + "layers.{}.attn.conv_k.weight", + "layers.{}.attn.conv_v.weight", + ): + if layer_num not in deltanet_qkv_by_layer: + deltanet_qkv_by_layer[layer_num] = {} + short_key = tt_abstract_key.split("attn.")[-1].replace("{}", "") + deltanet_qkv_by_layer[layer_num][short_key] = value + continue + + # Collect vision wq/wk/wv for fusing into qkv + if tt_abstract_key in ( + "vision_encoder.layers.{}.attn.wq.weight", + "vision_encoder.layers.{}.attn.wq.bias", + "vision_encoder.layers.{}.attn.wk.weight", + "vision_encoder.layers.{}.attn.wk.bias", + "vision_encoder.layers.{}.attn.wv.weight", + "vision_encoder.layers.{}.attn.wv.bias", + ): + if layer_num not in vision_qkv_by_layer: + vision_qkv_by_layer[layer_num] = {} + short_key = tt_abstract_key.split("attn.")[-1].replace("{}", "") + vision_qkv_by_layer[layer_num][short_key] = value + continue + + if tt_abstract_key not in to_hf_map: + continue + hf_state_dict[to_hf_map[tt_abstract_key].format(layer_num)] = value + + else: + if tt_key not in to_hf_map: + continue + if tt_key == "lm_head.weight" and getattr( + self.model_config, "enable_weight_tying", False + ): + continue + hf_value = value + # Linear weight (out, C*T*H*W) → Conv3d weight (out, C, T, H, W) + if tt_key == "vision_encoder.patch_embed.weight": + encoder = self.model_config.vision_encoder + hf_value = value.reshape( + value.shape[0], + encoder.in_channels, + encoder.temporal_patch_size, + encoder.patch_size, + encoder.patch_size, + ) + hf_state_dict[to_hf_map[tt_key]] = hf_value + + # Fuse MoE w1 (gate) + w3 (up) → gate_up_proj + for layer_num in moe_w1_by_layer: + w1 = moe_w1_by_layer[layer_num].transpose(-2, -1) + w3 = moe_w3_by_layer[layer_num].transpose(-2, -1) + hf_state_dict[ + f"model.language_model.layers.{layer_num}.mlp.experts.gate_up_proj" + ] = torch.cat([w1, w3], dim=-1) + + # Fuse vision wq/wk/wv → qkv + for layer_num, parts in vision_qkv_by_layer.items(): + for suffix in ("weight", "bias"): + q = parts.get(f"wq.{suffix}") + k = parts.get(f"wk.{suffix}") + v = parts.get(f"wv.{suffix}") + if q is not None and k is not None and v is not None: + hf_state_dict[ + f"model.visual.blocks.{layer_num}.attn.qkv.{suffix}" + ] = torch.cat([q, k, v], dim=0) + + # Fuse deltanet in_proj_q/k/v → in_proj_qkv, conv_q/k/v → conv1d + for layer_num, parts in deltanet_qkv_by_layer.items(): + q = parts.get("in_proj_q.weight") + k = parts.get("in_proj_k.weight") + v = parts.get("in_proj_v.weight") + if q is not None and k is not None and v is not None: + hf_state_dict[ + f"model.language_model.layers.{layer_num}.linear_attn.in_proj_qkv.weight" + ] = torch.cat([q, k, v], dim=0) + cq = parts.get("conv_q.weight") + ck = parts.get("conv_k.weight") + cv = parts.get("conv_v.weight") + if cq is not None and ck is not None and cv is not None: + hf_state_dict[ + f"model.language_model.layers.{layer_num}.linear_attn.conv1d.weight" + ] = torch.cat([cq, ck, cv], dim=0) + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """Convert HuggingFace Qwen3.5 state dict to torchtitan format.""" + tt_state_dict = {} + + # HF ties lm_head with embed_tokens — copy if missing + if "lm_head.weight" not in hf_state_dict: + embed_key = "model.language_model.embed_tokens.weight" + if embed_key not in hf_state_dict: + raise ValueError( + f"HF checkpoint missing both 'lm_head.weight' and '{embed_key}'" + ) + hf_state_dict["lm_head.weight"] = hf_state_dict[embed_key] + + for hf_key, value in hf_state_dict.items(): + if re.search(r"\.\d+\.", hf_key): + hf_abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=1) + # pyrefly: ignore [missing-attribute] + idx = re.search(r"\d+", hf_key).group(0) + + # MoE gate_up_proj → split into w1 + w3 and transpose + if ( + hf_abstract_key + == "model.language_model.layers.{}.mlp.experts.gate_up_proj" + ): + w1_hf, w3_hf = value.chunk(2, dim=-1) + tt_state_dict[f"layers.{idx}.moe.experts.w1"] = w1_hf.transpose( + -2, -1 + ) + tt_state_dict[f"layers.{idx}.moe.experts.w3"] = w3_hf.transpose( + -2, -1 + ) + continue + + # MoE down_proj → transpose + if ( + hf_abstract_key + == "model.language_model.layers.{}.mlp.experts.down_proj" + ): + tt_state_dict[f"layers.{idx}.moe.experts.w2"] = value.transpose( + -2, -1 + ) + continue + + # GatedDeltaNet fused in_proj_qkv → split into q/k/v + if ( + hf_abstract_key + == "model.language_model.layers.{}.linear_attn.in_proj_qkv.weight" + ): + dn = self.model_config.layers[int(idx)].deltanet + kd = dn.n_key_heads * dn.key_head_dim + vd = dn.n_value_heads * dn.value_head_dim + q, k, v = value.split([kd, kd, vd], dim=0) + tt_state_dict[f"layers.{idx}.attn.in_proj_q.weight"] = q + tt_state_dict[f"layers.{idx}.attn.in_proj_k.weight"] = k + tt_state_dict[f"layers.{idx}.attn.in_proj_v.weight"] = v + continue + + # GatedDeltaNet fused conv1d → split into conv_q/k/v + if ( + hf_abstract_key + == "model.language_model.layers.{}.linear_attn.conv1d.weight" + ): + dn = self.model_config.layers[int(idx)].deltanet + kd = dn.n_key_heads * dn.key_head_dim + vd = dn.n_value_heads * dn.value_head_dim + cq, ck, cv = value.split([kd, kd, vd], dim=0) + tt_state_dict[f"layers.{idx}.attn.conv_q.weight"] = cq + tt_state_dict[f"layers.{idx}.attn.conv_k.weight"] = ck + tt_state_dict[f"layers.{idx}.attn.conv_v.weight"] = cv + continue + + # Vision fused QKV → split into wq/wk/wv + if hf_abstract_key in ( + "model.visual.blocks.{}.attn.qkv.weight", + "model.visual.blocks.{}.attn.qkv.bias", + ): + suffix = "weight" if "weight" in hf_abstract_key else "bias" + q, k, v = value.chunk(3, dim=0) + tt_state_dict[f"vision_encoder.layers.{idx}.attn.wq.{suffix}"] = q + tt_state_dict[f"vision_encoder.layers.{idx}.attn.wk.{suffix}"] = k + tt_state_dict[f"vision_encoder.layers.{idx}.attn.wv.{suffix}"] = v + continue + + if hf_abstract_key not in self.from_hf_map: + continue + tt_key = self.from_hf_map[hf_abstract_key] + if tt_key is None: + continue + tt_state_dict[tt_key.format(idx)] = value + + else: + if hf_key not in self.from_hf_map: + continue + tt_key = self.from_hf_map[hf_key] + if tt_key is None: + continue + tt_value = value + # Conv3d weight (out, C, T, H, W) → Linear weight (out, C*T*H*W) + if hf_key == "model.visual.patch_embed.proj.weight": + tt_value = value.reshape(value.shape[0], -1) + tt_state_dict[tt_key] = tt_value + + return tt_state_dict diff --git a/torchtitan/models/qwen3_vl/vision_encoder.py b/torchtitan/models/qwen3_5/vision_encoder.py similarity index 63% rename from torchtitan/models/qwen3_vl/vision_encoder.py rename to torchtitan/models/qwen3_5/vision_encoder.py index be59176387..e7625b5cdc 100644 --- a/torchtitan/models/qwen3_vl/vision_encoder.py +++ b/torchtitan/models/qwen3_5/vision_encoder.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import dataclass import torch import torch.nn as nn @@ -15,19 +16,18 @@ from torchtitan.models.common.attention import FlexAttention from torchtitan.models.common.nn_modules import GELU, LayerNorm from torchtitan.models.common.rope import CosSinRoPE -from torchtitan.protocols.module import Module, ModuleDict, ModuleList +from torchtitan.protocols.module import Module, ModuleDict _compiled_create_block_mask = torch.compile(create_block_mask) -def get_vision_block_mask_mod(num_patch: torch.Tensor, max_num_patch: int): +def get_vision_block_mask_mod(num_patch: torch.Tensor) -> Callable: """Create a mask modifier for block-diagonal attention. Each image only attends to its own patches. Args: num_patch: (num_vision,) actual number of patches per visual item - max_num_patch: Maximum number of patches (padded length) """ def mask_mod(b, h, q_idx, kv_idx): @@ -61,7 +61,7 @@ def _compute_learned_pos_embeds( dim: Hidden dimension Returns: - learned_pos: (num_vision, max_num_patch, dim) interpolated position embeddings + pos_embeds: (num_vision, max_num_patch, dim) interpolated position embeddings """ num_vision = grid_thw.shape[0] dtype = learned_pos_embed.dtype @@ -90,6 +90,7 @@ def _compute_learned_pos_embeds( for (h, w), indices in hw_to_indices.items(): pos_hw = F.interpolate( pos_grid, + # pyrefly: ignore [bad-argument-type] size=(h, w), mode="bilinear", align_corners=True, @@ -208,7 +209,7 @@ def _compute_2d_rope_cache( else: rope_embeds[i, :seq_len] = rope_2d - # Compute cos/sin in model dtype (HF uses .float() here) + # Compute cos/sin in float32 for numerical precision rope_embeds = torch.cat((rope_embeds, rope_embeds), dim=-1) # (N, L, head_dim) rope_cache = torch.cat([rope_embeds.cos(), rope_embeds.sin()], dim=-1).unsqueeze( 2 @@ -217,45 +218,6 @@ def _compute_2d_rope_cache( return rope_cache -class PatchEmbed(Module): - """Patch Embedding using Linear projection. - - Since patches are already extracted by the collator, we use Linear instead of Conv3d. - This is mathematically equivalent when Conv3d kernel_size equals input size: - - Conv3d: (B, C, T, H, W) with kernel=C*(T,H,W) and dim kernels → (B, dim, 1, 1, 1) - - Linear: (B, C*T*H*W) → (B, dim) - Same weighted sum, but Linear uses efficient batched matrix multiplication. - """ - - @dataclass(kw_only=True, slots=True) - class Config(Module.Config): - patch_size: int - temporal_patch_size: int - in_channels: int - proj: Linear.Config - - def __init__(self, config: Config): - super().__init__() - # Kept on the Module so the state-dict adapter can reshape - # Linear weights into 5D Conv3d weights. - self.patch_size = config.patch_size - self.temporal_patch_size = config.temporal_patch_size - self.in_channels = config.in_channels - - self.proj = config.proj.build() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Project patches to embeddings. - - Args: - hidden_states: (batch, max_num_patch, patch_dim) - - Returns: - (batch, max_num_patch, embed_dim) - """ - return self.proj(hidden_states) - - class VisionRotaryEmbedding(Module): """2D Rotary Position Embedding for Vision Transformer.""" @@ -286,46 +248,38 @@ def _init_self_buffers(self, *, buffer_device: torch.device | None = None) -> No ) def forward(self, seqlen: int) -> torch.Tensor: - """Compute rotary embeddings for a sequence.""" + """Compute rotary frequency table for positions up to seqlen.""" seq = torch.arange( seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype ) - freqs = torch.outer(seq, self.inv_freq) - return freqs + return torch.outer(seq, self.inv_freq) class PatchMerger(Module): """Merge spatial patches to reduce sequence length. - ``use_postshuffle_norm``: If True, apply LayerNorm after spatial reshape - (norm dim = hidden_size * spatial_merge_size^2). If False, apply - before reshape (norm dim = hidden_size). DeepStack mergers use - postshuffle norm; the main merger uses pre-shuffle norm. The caller - must set ``norm.normalized_shape`` to match the chosen mode. + Applies LayerNorm before spatial reshape, then projects through a + two-layer MLP (fc1 → GELU → fc2). """ - @dataclass(kw_only=True, slots=True) - class Config(Module.Config): - hidden_size: int - spatial_merge_size: int - fc1: Linear.Config - fc2: Linear.Config - norm: LayerNorm.Config - act_fn: GELU.Config = field( - default_factory=lambda: GELU.Config(approximate="tanh") - ) - use_postshuffle_norm: bool = False - - def __init__(self, config: Config): + def __init__( + self, + hidden_size: int, + out_hidden_size: int, + spatial_merge_size: int, + layer_norm_eps: float, + *, + fc1: Linear.Config, + fc2: Linear.Config, + ): super().__init__() - self.spatial_merge_size = config.spatial_merge_size - self.merged_hidden_size = config.hidden_size * (config.spatial_merge_size**2) - self.use_postshuffle_norm = config.use_postshuffle_norm + self.spatial_merge_size = spatial_merge_size + self.merged_hidden_size = hidden_size * (spatial_merge_size**2) - self.norm = config.norm.build() - self.linear_fc1 = config.fc1.build() - self.act_fn = config.act_fn.build() - self.linear_fc2 = config.fc2.build() + self.norm = LayerNorm(hidden_size, eps=layer_norm_eps) + self.linear_fc1 = fc1.build() + self.act_fn = GELU(approximate="tanh") + self.linear_fc2 = fc2.build() def forward(self, x: torch.Tensor) -> torch.Tensor: """Merge spatial patches and project to output dimension. @@ -337,20 +291,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: (batch, seq_len // spatial_merge_size^2, out_hidden_size) """ batch_size, seq_len, _ = x.shape - if self.use_postshuffle_norm: - x = x.view( - batch_size, - seq_len // (self.spatial_merge_size**2), - self.merged_hidden_size, - ) - x = self.norm(x) - else: - x = self.norm(x) - x = x.view( - batch_size, - seq_len // (self.spatial_merge_size**2), - self.merged_hidden_size, - ) + x = self.norm(x) + x = x.view( + batch_size, + seq_len // (self.spatial_merge_size**2), + self.merged_hidden_size, + ) x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) return x @@ -358,56 +304,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VisionAttention(Module): """Multi-head attention with FlexAttention for efficient batched processing.""" - @dataclass(kw_only=True, slots=True) - class Config(Module.Config): - dim: int - n_heads: int - qkv: Linear.Config - proj: Linear.Config - flex_attention: FlexAttention.Config = field( - default_factory=lambda: FlexAttention.Config() - ) - - def __init__(self, config: Config): + def __init__( + self, + dim: int, + num_heads: int, + *, + wq: Linear.Config, + wk: Linear.Config, + wv: Linear.Config, + proj: Linear.Config, + ): super().__init__() - self.dim = config.dim - self.num_heads = config.n_heads + self.dim = dim + self.num_heads = num_heads self.head_dim = self.dim // self.num_heads - self.qkv = config.qkv.build() - self.proj = config.proj.build() - self.flex_attention = config.flex_attention.build() + self.wq = wq.build() + self.wk = wk.build() + self.wv = wv.build() + self.proj = proj.build() + self.flex_attention = FlexAttention.Config().build() def forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, *, rope_cache: torch.Tensor, attention_mask: BlockMask, ) -> torch.Tensor: - """Apply multi-head attention with 2D RoPE. + bs, seqlen, _ = x.shape - Args: - hidden_states: (num_vision, max_num_patch, dim) - rope_cache: (num_vision, max_num_patch, 1, head_dim*2) precomputed cos/sin - attention_mask: BlockMask for attention + xq = self.wq(x).view(bs, seqlen, -1, self.head_dim) + xk = self.wk(x).view(bs, seqlen, -1, self.head_dim) + xv = self.wv(x).view(bs, seqlen, -1, self.head_dim) - Returns: - (num_vision, max_num_patch, dim) - """ - num_vision, max_num_patch, _ = hidden_states.shape + xq, xk = CosSinRoPE.apply_rotary_emb(xq, xk, rope_cache) - qkv = self.qkv(hidden_states).reshape( - num_vision, max_num_patch, 3, -1, self.head_dim - ) - # Each: (num_vision, max_num_patch, n_heads, head_dim) - q, k, v = qkv.permute(2, 0, 1, 3, 4).unbind(0) - # Vision RoPE cache is already position-specific from grid_thw, so - # apply the prepared cache directly without a RoPE module instance. - q, k = CosSinRoPE.apply_rotary_emb(q, k, rope_cache) - attn_output = self.flex_attention(q, k, v, attention_masks=attention_mask) - attn_output = attn_output.reshape(num_vision, max_num_patch, -1) - return self.proj(attn_output) + output = self.flex_attention(xq, xk, xv, attention_masks=attention_mask) + output = output.reshape(bs, seqlen, -1) + return self.proj(output) class VisionMLP(Module): @@ -427,74 +362,80 @@ def __init__(self, config: Config): self.linear_fc2 = config.fc2.build() self.act_fn = config.act_fn.build() - def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.act_fn(self.linear_fc1(x))) class VisionTransformerBlock(Module): """Single transformer block for vision encoder.""" - @dataclass(kw_only=True, slots=True) - class Config(Module.Config): - attn: VisionAttention.Config - mlp: VisionMLP.Config - norm1: LayerNorm.Config - norm2: LayerNorm.Config - - def __init__(self, config: Config): + def __init__( + self, + dim: int, + num_heads: int, + layer_norm_eps: float, + *, + attn_wq: Linear.Config, + attn_wk: Linear.Config, + attn_wv: Linear.Config, + attn_proj: Linear.Config, + mlp_fc1: Linear.Config, + mlp_fc2: Linear.Config, + ): super().__init__() - self.norm1 = config.norm1.build() - self.norm2 = config.norm2.build() - self.attn = config.attn.build() - self.mlp = config.mlp.build() + self.norm1 = LayerNorm(dim, eps=layer_norm_eps) + self.norm2 = LayerNorm(dim, eps=layer_norm_eps) + self.attn = VisionAttention( + dim, num_heads, wq=attn_wq, wk=attn_wk, wv=attn_wv, proj=attn_proj + ) + self.mlp = VisionMLP(fc1=mlp_fc1, fc2=mlp_fc2) def forward( self, - hidden_states: torch.Tensor, + x: torch.Tensor, *, rope_cache: torch.Tensor, attention_mask: BlockMask, ) -> torch.Tensor: - hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), - rope_cache=rope_cache, - attention_mask=attention_mask, + x = x + self.attn( + self.norm1(x), rope_cache=rope_cache, attention_mask=attention_mask ) - hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states + x = x + self.mlp(self.norm2(x)) + return x -class Qwen3VLVisionEncoder(Module): - """Qwen3-VL Vision Encoder with FlexAttention. +class Qwen35VisionEncoder(Module): + """Qwen3.5 Vision Encoder with FlexAttention. Uses padded batches (N, L, D) format for efficient processing. """ @dataclass(kw_only=True, slots=True) class Config(Module.Config): - """Configuration for Qwen3-VL Vision Encoder (ViT). - - ``num_position_embeddings`` and ``dim`` stay on this Config because - ``pos_embed`` is an ``nn.Parameter`` owned directly by the encoder - (not factored into a sub-Module). ``n_heads`` is kept to derive - ``head_dim`` inside ``compute_position_embeddings``. - ``spatial_merge_size`` and ``deepstack_visual_indices`` are read by - ``Qwen3VLModel``. - """ + """Configuration for Qwen3.5 Vision Encoder (ViT).""" + + dim: int = 1280 + ffn_dim: int = 5120 + num_layers: int = 32 + num_heads: int = 16 dim: int n_heads: int spatial_merge_size: int num_position_embeddings: int - # DeepStack: layer indices for extracting intermediate visual features - deepstack_visual_indices: list[int] - - patch_embed: PatchEmbed.Config - rotary_pos_emb: VisionRotaryEmbedding.Config - layers: list[VisionTransformerBlock.Config] - merger: PatchMerger.Config - deepstack_mergers: list[PatchMerger.Config] + # Per-layer Linear configs for vision encoder sub-modules + # Linear instead of Conv3d — equivalent when kernel_size equals patch size, + # but more efficient via batched matmul on pre-flattened patches. + patch_embed_proj: Linear.Config + attn_wq: Linear.Config + attn_wk: Linear.Config + attn_wv: Linear.Config + attn_proj: Linear.Config + mlp_fc1: Linear.Config + mlp_fc2: Linear.Config + merger_fc1: Linear.Config + merger_fc2: Linear.Config def __init__(self, config: Config): super().__init__() @@ -502,7 +443,7 @@ def __init__(self, config: Config): self.spatial_merge_size = config.spatial_merge_size self.spatial_merge_unit = config.spatial_merge_size**2 - self.patch_embed = config.patch_embed.build() + self.patch_embed = config.patch_embed_proj.build() # nn.Parameter (not nn.Embedding) because we interpolate the weight directly self.num_position_embeddings = config.num_position_embeddings @@ -511,20 +452,37 @@ def __init__(self, config: Config): ) self.num_grid_per_side = int(config.num_position_embeddings**0.5) - self.rotary_pos_emb = config.rotary_pos_emb.build() + head_dim = config.dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding( + head_dim // 2, theta=config.rope_theta + ) # Cached RoPE freq table — recomputed only when max_hw grows self._cached_freq_table: torch.Tensor | None = None self.layers = ModuleDict( - {str(idx): layer.build() for idx, layer in enumerate(config.layers)} + { + str(idx): VisionTransformerBlock( + config.dim, + config.num_heads, + config.layer_norm_eps, + attn_wq=config.attn_wq, + attn_wk=config.attn_wk, + attn_wv=config.attn_wv, + attn_proj=config.attn_proj, + mlp_fc1=config.mlp_fc1, + mlp_fc2=config.mlp_fc2, + ) + for idx in range(config.num_layers) + } ) - self.merger = config.merger.build() - - # DeepStack mergers for intermediate layers - self.deepstack_visual_indices = config.deepstack_visual_indices - self.deepstack_merger_list = ModuleList( - [cfg.build() for cfg in config.deepstack_mergers] + self.merger = PatchMerger( + hidden_size=config.dim, + out_hidden_size=config.out_hidden_size, + spatial_merge_size=config.spatial_merge_size, + layer_norm_eps=config.layer_norm_eps, + fc1=config.merger_fc1, + fc2=config.merger_fc2, ) def compute_position_embeddings( @@ -537,7 +495,7 @@ def compute_position_embeddings( - ``_compute_2d_rope_cache``: 2D RoPE cache Args: - grid_thw: (num_vision, 3) with pixel patch counts [t, h, w] per visual item + grid_thw: (num_vision, 3) with patch counts [t, h, w] per visual item max_num_patch: Maximum number of patches (for padding) Returns: @@ -545,7 +503,7 @@ def compute_position_embeddings( rope_cache: (num_vision, max_num_patch, 1, head_dim*2) RoPE cache for VisionAttention """ - head_dim = self.config.dim // self.config.n_heads + head_dim = self.config.dim // self.config.num_heads # Get RoPE freq table, reusing cache when possible max_hw = int(grid_thw[:, 1:].max().item()) @@ -576,7 +534,7 @@ def forward( pixel_values: torch.Tensor, *, grid_thw: torch.Tensor, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor: """Forward pass of the vision encoder. Processes both images and videos — each visual item is a batch of @@ -588,40 +546,28 @@ def forward( Returns: merged_hidden_states: (num_vision, max_merged_num_patch, out_hidden_size) - deepstack_features: List of features from intermediate layers """ num_vision, max_num_patch, _ = pixel_values.shape num_patch = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).to(torch.long) - hidden_states = self.patch_embed(pixel_values) + x = self.patch_embed(pixel_values) # (num_vision, max_num_patch, dim) learned_pos, rope_cache = self.compute_position_embeddings( grid_thw, max_num_patch ) - hidden_states = hidden_states + learned_pos + x = x + learned_pos - mask_mod = get_vision_block_mask_mod(num_patch, max_num_patch) + mask_mod = get_vision_block_mask_mod(num_patch) attention_mask = _compiled_create_block_mask( mask_mod, num_vision, None, max_num_patch, max_num_patch, - device=hidden_states.device, + device=x.device, ) - deepstack_features = [] - - for layer_idx, layer in self.layers.items(): - hidden_states = layer( - hidden_states, - rope_cache=rope_cache, - attention_mask=attention_mask, - ) - if int(layer_idx) in self.deepstack_visual_indices: - idx = self.deepstack_visual_indices.index(int(layer_idx)) - deepstack_feature = self.deepstack_merger_list[idx](hidden_states) - deepstack_features.append(deepstack_feature) - merged_hidden_states = self.merger(hidden_states) + for layer in self.layers.values(): + x = layer(x, rope_cache=rope_cache, attention_mask=attention_mask) - return merged_hidden_states, deepstack_features + return self.merger(x) diff --git a/torchtitan/models/qwen3_vl/README.md b/torchtitan/models/qwen3_vl/README.md deleted file mode 100644 index 6f5b54afa7..0000000000 --- a/torchtitan/models/qwen3_vl/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Qwen3-VL: Vision-Language Model - -## Overview - -Qwen3-VL combines: -- **Qwen3 LLM**: The base language model with QK-norm and RoPE. -- **Vision Encoder**: A Vision Transformer (ViT) that supports native resolution images (no fixed square crops) with 2D RoPE and bilinear-interpolated position embeddings. -- **Patch Merger**: Reduces vision sequence length by merging spatial patches (e.g., 2×2 patches → 1 token). -- **DeepStack**: Adds intermediate ViT features to vision positions in early decoder layers. -- **MRoPE**: Interleaves RoPE from temporal, height, and width position IDs in decoder layers. - -## Vision Scatter - -- `tok_embeddings` produces text token embeddings of shape `B×S`. -- `vision_encoder` produces visual token embeddings of shape `N×L`. -- Valid visual tokens (excluding padding) are scattered into the placeholder positions in the text sequence, as illustrated below (credit: [@lkhphuc](https://github.com/lkhphuc)). - -VLM Architecture - -Note: the diagram shows each patch mapping to one vision token. In practice, the Patch Merger groups `merge_size²` patches into one token (e.g., `merge_size=2` → 4 patches per token), reducing the vision sequence length by `merge_size²`. - -## Prerequisites - -Install the additional dependencies required by Qwen3-VL: - -```bash -pip install av torchvision -``` - -## Model Variants - -| Variant | LLM dim | Layers | ViT dim | ViT layers | Patch size | MoE | -|---------|---------|--------|---------|------------|------------|-----| -| debugmodel | 256 | 4 | 256 | 4 | 16 | No | -| debugmodel_moe | 256 | 1 | 256 | 4 | 16 | Yes (8 experts) | -| 2B | 2048 | 28 | 1024 | 24 | 16 | No | -| 8B | 4096 | 36 | 1152 | 27 | 16 | No | -| 30B-A3B | 2048 | 48 | 1152 | 27 | 16 | Yes (128 experts) | -| 235B-A22B | 4096 | 94 | 1152 | 27 | 16 | Yes (128 experts) | - -## Datasets - -| Dataset | Type | Format | -|---------|------|--------| -| `cc12m` | Image-text pairs | WebDataset (streaming) | -| `cc12m-test` | Image-text pairs | Local WebDataset (for testing) | -| `obelics` | Interleaved image-text | HuggingFace (streaming) | - -## Supported Features - -| Feature | Notes | -|---------|-------| -| FSDP / HSDP | Both vision encoder and decoder are individually sharded | -| Tensor Parallelism (TP) | Applied to both vision encoder and decoder (without SequenceParallel due to vision scatter and DeepStack) | -| Expert Parallelism (EP) | For MoE variants (e.g., 30B-A3B) | -| Sample Packing | Configurable via `packing_buffer_size` in dataloader config | - -## Numerical Parity - -End-to-end KL divergence against HuggingFace Transformers (2B, 10 random samples): **~5e-8 to ~5e-5** per sample, with **100% top-1 and top-5 match**. Test scripts are in `scripts/checkpoint_conversion/numerical_tests_qwen3_vl_*.py`. - -## TODO - -- Add Pipeline Parallelism support -- Add Context Parallel support -- Add video dataset training configs diff --git a/torchtitan/models/qwen3_vl/__init__.py b/torchtitan/models/qwen3_vl/__init__.py deleted file mode 100644 index 75b3ef29ee..0000000000 --- a/torchtitan/models/qwen3_vl/__init__.py +++ /dev/null @@ -1,649 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from collections.abc import Callable -from functools import partial - -import torch.nn as nn - -from torchtitan.models.common import Embedding, Linear, RoPE, TransformerBlock -from torchtitan.models.common.config_utils import ( - get_attention_config, - make_experts_config, - make_ffn_config, - make_gqa_config, - make_moe_config, - make_router_config, -) -from torchtitan.models.common.nn_modules import LayerNorm, RMSNorm -from torchtitan.models.common.param_init import depth_scaled_std, skip_param_init -from torchtitan.models.qwen3.model import Qwen3TransformerBlock -from torchtitan.models.utils import validate_converter_order -from torchtitan.protocols.model import ModelConfigConverter -from torchtitan.protocols.model_spec import ModelSpec -from .model import Qwen3VLModel -from .parallelize import parallelize_qwen3_vl -from .rope import MRoPE -from .state_dict_adapter import Qwen3VLStateDictAdapter -from .vision_encoder import ( - PatchEmbed, - PatchMerger, - Qwen3VLVisionEncoder, - VisionAttention, - VisionMLP, - VisionRotaryEmbedding, - VisionTransformerBlock, -) - -__all__ = [ - "parallelize_qwen3_vl", - "Qwen3VLModel", - "qwen3_vl_configs", - "QWEN3_VL_SPECIAL_TOKENS", -] - -QWEN3_VL_SPECIAL_TOKENS = { - "image_token": "<|image_pad|>", - "video_token": "<|video_pad|>", - "vision_start_token": "<|vision_start|>", - "vision_end_token": "<|vision_end|>", - "pad_token": "<|endoftext|>", -} - - -_LINEAR_INIT = { - "weight": partial(nn.init.trunc_normal_, std=0.02), - "bias": nn.init.zeros_, -} -_NORM_INIT = {"weight": nn.init.ones_} -_EMBEDDING_INIT = {"weight": partial(nn.init.normal_, std=1.0)} -_EMBEDDING_SKIP_INIT = {"weight": skip_param_init} -_POS_EMBED_INIT = {"pos_embed": partial(nn.init.trunc_normal_, mean=0.0, std=0.02)} - -_EPS = 1e-6 - - -def _output_linear_init(dim: int) -> dict[str, Callable]: - s = dim**-0.5 - return { - "weight": partial(nn.init.trunc_normal_, std=s, a=-3 * s, b=3 * s), - "bias": nn.init.zeros_, - } - - -def _depth_init(layer_id: int) -> dict[str, Callable]: - return { - "weight": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), - "bias": nn.init.zeros_, - } - - -def _depth_experts_init(layer_id: int) -> dict[str, Callable]: - return { - "w1_EFD": partial(nn.init.trunc_normal_, std=0.02), - "w2_EDF": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), - "w3_EFD": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), - } - - -def _vl_linear(in_features: int, out_features: int) -> Linear.Config: - return Linear.Config( - in_features=in_features, - out_features=out_features, - bias=True, - param_init=_LINEAR_INIT, - ) - - -def _qwen3_vl_norm(dim: int) -> RMSNorm.Config: - return RMSNorm.Config(normalized_shape=dim, eps=_EPS, param_init=_NORM_INIT) - - -def _vl_layer_norm(normalized_shape: int) -> LayerNorm.Config: - return LayerNorm.Config(normalized_shape=normalized_shape, eps=_EPS) - - -def _build_qwen3_vl_vision_block( - *, dim: int, ffn_dim: int, n_heads: int -) -> VisionTransformerBlock.Config: - return VisionTransformerBlock.Config( - attn=VisionAttention.Config( - dim=dim, - n_heads=n_heads, - qkv=_vl_linear(dim, dim * 3), - proj=_vl_linear(dim, dim), - ), - mlp=VisionMLP.Config( - fc1=_vl_linear(dim, ffn_dim), - fc2=_vl_linear(ffn_dim, dim), - ), - norm1=_vl_layer_norm(dim), - norm2=_vl_layer_norm(dim), - ) - - -def _build_qwen3_vl_merger( - *, - dim: int, - out_hidden_size: int, - spatial_merge_size: int, - use_postshuffle_norm: bool, -) -> PatchMerger.Config: - merged_hidden_size = dim * (spatial_merge_size**2) - norm_dim = merged_hidden_size if use_postshuffle_norm else dim - return PatchMerger.Config( - hidden_size=dim, - spatial_merge_size=spatial_merge_size, - fc1=_vl_linear(merged_hidden_size, merged_hidden_size), - fc2=_vl_linear(merged_hidden_size, out_hidden_size), - norm=_vl_layer_norm(norm_dim), - use_postshuffle_norm=use_postshuffle_norm, - ) - - -def _vl_vision_encoder_config( - *, - dim: int, - ffn_dim: int, - n_layers: int, - n_heads: int, - patch_size: int, - temporal_patch_size: int, - spatial_merge_size: int, - out_hidden_size: int, - num_position_embeddings: int, - deepstack_visual_indices: list[int], - in_channels: int = 3, -) -> Qwen3VLVisionEncoder.Config: - """Build a fully-specified Qwen3VLVisionEncoder.Config.""" - patch_dim = in_channels * temporal_patch_size * patch_size * patch_size - head_dim = dim // n_heads - return Qwen3VLVisionEncoder.Config( - dim=dim, - n_heads=n_heads, - spatial_merge_size=spatial_merge_size, - num_position_embeddings=num_position_embeddings, - deepstack_visual_indices=deepstack_visual_indices, - patch_embed=PatchEmbed.Config( - patch_size=patch_size, - temporal_patch_size=temporal_patch_size, - in_channels=in_channels, - proj=_vl_linear(patch_dim, dim), - ), - rotary_pos_emb=VisionRotaryEmbedding.Config(dim=head_dim // 2), - layers=[ - _build_qwen3_vl_vision_block(dim=dim, ffn_dim=ffn_dim, n_heads=n_heads) - for _ in range(n_layers) - ], - merger=_build_qwen3_vl_merger( - dim=dim, - out_hidden_size=out_hidden_size, - spatial_merge_size=spatial_merge_size, - use_postshuffle_norm=False, - ), - deepstack_mergers=[ - _build_qwen3_vl_merger( - dim=dim, - out_hidden_size=out_hidden_size, - spatial_merge_size=spatial_merge_size, - use_postshuffle_norm=True, - ) - for _ in range(len(deepstack_visual_indices)) - ], - param_init=_POS_EMBED_INIT, - ) - - -def _build_qwen3_vl_layers( - *, - n_layers: int, - dim: int, - n_heads: int, - n_kv_heads: int, - head_dim: int, - hidden_dim: int, - attn_backend: str, - rope: RoPE.Config, -) -> list[TransformerBlock.Config]: - """Build per-layer configs for dense Qwen3-VL models with depth-scaled inits.""" - inner_attention = get_attention_config(attn_backend) - layers = [] - for layer_id in range(n_layers): - layers.append( - Qwen3TransformerBlock.Config( - attention_norm=_qwen3_vl_norm(dim), - ffn_norm=_qwen3_vl_norm(dim), - attention=make_gqa_config( - dim=dim, - n_heads=n_heads, - n_kv_heads=n_kv_heads, - head_dim=head_dim, - wqkv_param_init=_LINEAR_INIT, - wo_param_init=_depth_init(layer_id), - inner_attention=inner_attention, - rope=rope, - qk_norm=_qwen3_vl_norm(head_dim), - ), - feed_forward=make_ffn_config( - dim=dim, - hidden_dim=hidden_dim, - w1_param_init=_LINEAR_INIT, - w2w3_param_init=_depth_init(layer_id), - ), - ) - ) - return layers - - -def _build_qwen3_vl_moe_layers( - *, - n_layers: int, - dim: int, - n_heads: int, - n_kv_heads: int, - head_dim: int, - moe_hidden_dim: int, - num_experts: int, - top_k: int, - attn_backend: str, - moe_comm_backend: str, - non_blocking_capacity_factor: float | None = None, - rope: RoPE.Config, -) -> list[TransformerBlock.Config]: - """Build per-layer configs for MoE Qwen3-VL models with depth-scaled inits.""" - inner_attention = get_attention_config(attn_backend) - layers = [] - for layer_id in range(n_layers): - layers.append( - Qwen3TransformerBlock.Config( - attention_norm=_qwen3_vl_norm(dim), - ffn_norm=_qwen3_vl_norm(dim), - attention=make_gqa_config( - dim=dim, - n_heads=n_heads, - n_kv_heads=n_kv_heads, - head_dim=head_dim, - wqkv_param_init=_LINEAR_INIT, - wo_param_init=_depth_init(layer_id), - inner_attention=inner_attention, - rope=rope, - qk_norm=_qwen3_vl_norm(head_dim), - ), - moe=make_moe_config( - num_experts=num_experts, - router=make_router_config( - dim=dim, - num_experts=num_experts, - gate_param_init=_depth_init(layer_id), - top_k=top_k, - score_func="softmax", - route_norm=True, - ), - experts=make_experts_config( - dim=dim, - hidden_dim=moe_hidden_dim, - num_experts=num_experts, - top_k=top_k, - param_init=_depth_experts_init(layer_id), - score_before_experts=False, - comm_backend=moe_comm_backend, - non_blocking_capacity_factor=non_blocking_capacity_factor, - ), - ), - ) - ) - return layers - - -def _debugmodel(attn_backend: str) -> Qwen3VLModel.Config: - dim = 256 - head_dim = 64 - n_layers = 4 - vocab_size = 151936 - return Qwen3VLModel.Config( - vocab_size=vocab_size, - dim=dim, - norm=_qwen3_vl_norm(dim), - tok_embeddings=Embedding.Config( - num_embeddings=vocab_size, - embedding_dim=dim, - param_init=_EMBEDDING_INIT, - ), - lm_head=Linear.Config( - in_features=dim, - out_features=vocab_size, - param_init=_output_linear_init(dim), - ), - layers=_build_qwen3_vl_layers( - attn_backend=attn_backend, - n_layers=n_layers, - dim=dim, - n_heads=4, - n_kv_heads=2, - head_dim=head_dim, - hidden_dim=512, - rope=MRoPE.Config( - dim=head_dim, - max_seq_len=4096, - theta=1000000.0, - mrope_section=[8, 8, 8], - ), - ), - vision_encoder=_vl_vision_encoder_config( - dim=256, - ffn_dim=512, - n_layers=4, - n_heads=4, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - out_hidden_size=256, - num_position_embeddings=1024, - deepstack_visual_indices=[1, 2, 3], - ), - ) - - -def _debugmodel_moe( - attn_backend: str, - moe_comm_backend: str = "standard", -) -> Qwen3VLModel.Config: - dim = 256 - head_dim = 64 - n_layers = 1 - vocab_size = 151936 - return Qwen3VLModel.Config( - vocab_size=vocab_size, - dim=dim, - norm=_qwen3_vl_norm(dim), - tok_embeddings=Embedding.Config( - num_embeddings=vocab_size, - embedding_dim=dim, - param_init=_EMBEDDING_INIT, - ), - lm_head=Linear.Config( - in_features=dim, - out_features=vocab_size, - param_init=_output_linear_init(dim), - ), - layers=_build_qwen3_vl_moe_layers( - attn_backend=attn_backend, - n_layers=n_layers, - dim=dim, - n_heads=4, - n_kv_heads=2, - head_dim=head_dim, - moe_hidden_dim=768, - num_experts=8, - top_k=4, - moe_comm_backend=moe_comm_backend, - rope=MRoPE.Config( - dim=head_dim, - max_seq_len=4096, - theta=1000000.0, - mrope_section=[8, 8, 8], - ), - ), - vision_encoder=_vl_vision_encoder_config( - dim=256, - ffn_dim=512, - n_layers=4, - n_heads=4, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - out_hidden_size=256, - num_position_embeddings=1024, - deepstack_visual_indices=[1, 2, 3], - ), - ) - - -def _2b(attn_backend: str) -> Qwen3VLModel.Config: - dim = 2048 - head_dim = 128 - n_layers = 28 - vocab_size = 151936 - return Qwen3VLModel.Config( - vocab_size=vocab_size, - dim=dim, - norm=_qwen3_vl_norm(dim), - enable_weight_tying=True, - tok_embeddings=Embedding.Config( - num_embeddings=vocab_size, - embedding_dim=dim, - param_init=_EMBEDDING_SKIP_INIT, - ), - lm_head=Linear.Config( - in_features=dim, - out_features=vocab_size, - param_init=_output_linear_init(dim), - ), - layers=_build_qwen3_vl_layers( - attn_backend=attn_backend, - n_layers=n_layers, - dim=dim, - n_heads=16, - n_kv_heads=8, - head_dim=head_dim, - hidden_dim=6144, - rope=MRoPE.Config( - dim=head_dim, - max_seq_len=32768, - theta=5000000.0, - mrope_section=[24, 20, 20], - ), - ), - vision_encoder=_vl_vision_encoder_config( - dim=1024, - ffn_dim=4096, - n_layers=24, - n_heads=16, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - out_hidden_size=2048, - num_position_embeddings=2304, - deepstack_visual_indices=[5, 11, 17], - ), - ) - - -def _8b(attn_backend: str) -> Qwen3VLModel.Config: - dim = 4096 - head_dim = 128 - n_layers = 36 - vocab_size = 151936 - return Qwen3VLModel.Config( - vocab_size=vocab_size, - dim=dim, - norm=_qwen3_vl_norm(dim), - tok_embeddings=Embedding.Config( - num_embeddings=vocab_size, - embedding_dim=dim, - param_init=_EMBEDDING_INIT, - ), - lm_head=Linear.Config( - in_features=dim, - out_features=vocab_size, - param_init=_output_linear_init(dim), - ), - layers=_build_qwen3_vl_layers( - attn_backend=attn_backend, - n_layers=n_layers, - dim=dim, - n_heads=32, - n_kv_heads=8, - head_dim=head_dim, - hidden_dim=12288, - rope=MRoPE.Config( - dim=head_dim, - max_seq_len=32768, - theta=5000000.0, - mrope_section=[24, 20, 20], - ), - ), - vision_encoder=_vl_vision_encoder_config( - dim=1152, - ffn_dim=4304, - n_layers=27, - n_heads=16, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - out_hidden_size=4096, - num_position_embeddings=2304, - deepstack_visual_indices=[8, 16, 24], - ), - ) - - -# Qwen3-VL MoE models - - -def _30b_a3b( - attn_backend: str, - moe_comm_backend: str = "standard", -) -> Qwen3VLModel.Config: - dim = 2048 - head_dim = 128 - n_layers = 48 - vocab_size = 151936 - return Qwen3VLModel.Config( - vocab_size=vocab_size, - dim=dim, - norm=_qwen3_vl_norm(dim), - tok_embeddings=Embedding.Config( - num_embeddings=vocab_size, - embedding_dim=dim, - param_init=_EMBEDDING_INIT, - ), - lm_head=Linear.Config( - in_features=dim, - out_features=vocab_size, - param_init=_output_linear_init(dim), - ), - layers=_build_qwen3_vl_moe_layers( - attn_backend=attn_backend, - n_layers=n_layers, - dim=dim, - n_heads=32, - n_kv_heads=4, - head_dim=head_dim, - moe_hidden_dim=768, - num_experts=128, - top_k=8, - moe_comm_backend=moe_comm_backend, - rope=MRoPE.Config( - dim=head_dim, - max_seq_len=32768, - theta=5000000.0, - mrope_section=[24, 20, 20], - ), - ), - vision_encoder=_vl_vision_encoder_config( - dim=1152, - ffn_dim=4304, - n_layers=27, - n_heads=16, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - out_hidden_size=2048, - num_position_embeddings=2304, - deepstack_visual_indices=[8, 16, 24], - ), - ) - - -def _235b_a22b( - attn_backend: str, - moe_comm_backend: str = "standard", -) -> Qwen3VLModel.Config: - dim = 4096 - head_dim = 128 - n_layers = 94 - vocab_size = 151936 - return Qwen3VLModel.Config( - vocab_size=vocab_size, - dim=dim, - norm=_qwen3_vl_norm(dim), - tok_embeddings=Embedding.Config( - num_embeddings=vocab_size, - embedding_dim=dim, - param_init=_EMBEDDING_INIT, - ), - lm_head=Linear.Config( - in_features=dim, - out_features=vocab_size, - param_init=_output_linear_init(dim), - ), - layers=_build_qwen3_vl_moe_layers( - attn_backend=attn_backend, - n_layers=n_layers, - dim=dim, - n_heads=64, - n_kv_heads=4, - head_dim=head_dim, - moe_hidden_dim=1536, - num_experts=128, - top_k=8, - moe_comm_backend=moe_comm_backend, - rope=MRoPE.Config( - dim=head_dim, - max_seq_len=32768, - theta=5000000.0, - mrope_section=[24, 20, 20], - ), - ), - vision_encoder=_vl_vision_encoder_config( - dim=1152, - ffn_dim=4304, - n_layers=27, - n_heads=16, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - out_hidden_size=4096, - num_position_embeddings=2304, - deepstack_visual_indices=[8, 16, 24], - ), - ) - - -qwen3_vl_configs = { - "debugmodel": _debugmodel, - "debugmodel_moe": _debugmodel_moe, - "2B": _2b, - "8B": _8b, - "30B-A3B": _30b_a3b, - "235B-A22B": _235b_a22b, -} - - -def model_registry( - flavor: str, - attn_backend: str = "flex", - moe_comm_backend: str | None = None, - converters: list[ModelConfigConverter.Config] | None = None, -) -> ModelSpec: - kwargs = dict(attn_backend=attn_backend) - if moe_comm_backend is not None: - kwargs["moe_comm_backend"] = moe_comm_backend - config = qwen3_vl_configs[flavor](**kwargs) - if converters is not None: - validate_converter_order(converters) - for c in converters: - c.build().convert(config) - return ModelSpec( - name="qwen3_vl", - flavor=flavor, - model=config, - parallelize_fn=parallelize_qwen3_vl, - pipelining_fn=None, - post_optimizer_build_fn=None, - state_dict_adapter=Qwen3VLStateDictAdapter, - ) diff --git a/torchtitan/models/qwen3_vl/config_registry.py b/torchtitan/models/qwen3_vl/config_registry.py deleted file mode 100644 index a2a7fd59ff..0000000000 --- a/torchtitan/models/qwen3_vl/config_registry.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torchtitan.components.checkpoint import CheckpointManager -from torchtitan.components.loss import ChunkedCELoss -from torchtitan.components.lr_scheduler import LRSchedulersContainer -from torchtitan.components.metrics import MetricsProcessor -from torchtitan.components.optimizer import default_adamw -from torchtitan.components.tokenizer import MultiModalTokenizer - -from torchtitan.config import ( - ActivationCheckpointConfig, - ParallelismConfig, - TrainingConfig, -) -from torchtitan.hf_datasets.multimodal.mm_datasets import MMDataLoader -from torchtitan.trainer import Trainer - -from . import model_registry, QWEN3_VL_SPECIAL_TOKENS - - -def _qwen3_vl_dataloader(dataset: str, **kwargs) -> MMDataLoader.Config: - return MMDataLoader.Config( - dataset=dataset, - max_images_per_batch=128, - patch_size=16, - temporal_patch_size=2, - spatial_merge_size=2, - min_pixels=65536, - max_pixels=16777216, - image_mean=(0.5, 0.5, 0.5), - image_std=(0.5, 0.5, 0.5), - **kwargs, - ) - - -def qwen3_vl_debugmodel() -> Trainer.Config: - return Trainer.Config( - loss=ChunkedCELoss.Config(), - hf_assets_path="./tests/assets/tokenizer", - tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS), - metrics=MetricsProcessor.Config(log_freq=1), - model_spec=model_registry("debugmodel"), - dataloader=_qwen3_vl_dataloader("cc12m-test"), - optimizer=default_adamw(lr=8e-4), - lr_scheduler=LRSchedulersContainer.Config( - warmup_steps=2, - decay_ratio=0.8, - decay_type="linear", - min_lr_factor=0.0, - ), - training=TrainingConfig( - local_batch_size=1, - seq_len=512, - steps=10, - ), - checkpoint=CheckpointManager.Config( - interval=10, - last_save_model_only=False, - ), - activation_checkpoint=ActivationCheckpointConfig( - mode="selective", - ), - ) - - -def qwen3_vl_debugmodel_moe() -> Trainer.Config: - return Trainer.Config( - loss=ChunkedCELoss.Config(), - hf_assets_path="./tests/assets/tokenizer", - tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS), - metrics=MetricsProcessor.Config(log_freq=1), - model_spec=model_registry("debugmodel_moe"), - dataloader=_qwen3_vl_dataloader("cc12m-test"), - optimizer=default_adamw(lr=3e-3), - lr_scheduler=LRSchedulersContainer.Config(warmup_steps=2), - training=TrainingConfig( - local_batch_size=1, - seq_len=512, - steps=10, - ), - parallelism=ParallelismConfig( - data_parallel_shard_degree=4, - expert_parallel_degree=4, - tensor_parallel_degree=2, - ), - checkpoint=CheckpointManager.Config( - interval=10, - last_save_model_only=False, - ), - activation_checkpoint=ActivationCheckpointConfig( - mode="selective", - ), - ) - - -def qwen3_vl_2b() -> Trainer.Config: - return Trainer.Config( - loss=ChunkedCELoss.Config(), - hf_assets_path="./assets/hf/Qwen3-VL-2B-Instruct", - tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS), - model_spec=model_registry("2B"), - dataloader=_qwen3_vl_dataloader("cc12m"), - optimizer=default_adamw(lr=8e-4), - lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), - training=TrainingConfig( - local_batch_size=8, - seq_len=4096, - steps=1000, - ), - parallelism=ParallelismConfig( - data_parallel_shard_degree=-1, - tensor_parallel_degree=1, - ), - checkpoint=CheckpointManager.Config( - enable=False, - interval=50, - last_save_model_only=False, - export_dtype="float16", - ), - activation_checkpoint=ActivationCheckpointConfig( - mode="full", - ), - ) - - -def qwen3_vl_8b() -> Trainer.Config: - return Trainer.Config( - loss=ChunkedCELoss.Config(), - hf_assets_path="./assets/hf/Qwen3-VL-8B-Instruct", - tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS), - model_spec=model_registry("8B"), - dataloader=_qwen3_vl_dataloader("cc12m"), - optimizer=default_adamw(lr=8e-4), - lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), - training=TrainingConfig( - local_batch_size=4, - seq_len=4096, - steps=1000, - ), - parallelism=ParallelismConfig( - data_parallel_shard_degree=-1, - tensor_parallel_degree=1, - ), - checkpoint=CheckpointManager.Config( - enable=False, - interval=50, - last_save_model_only=False, - export_dtype="float16", - ), - activation_checkpoint=ActivationCheckpointConfig( - mode="full", - ), - ) - - -def qwen3_vl_30b_a3b() -> Trainer.Config: - return Trainer.Config( - loss=ChunkedCELoss.Config(), - hf_assets_path="./assets/hf/Qwen3-VL-30B-A3B-Instruct", - tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS), - model_spec=model_registry("30B-A3B"), - dataloader=_qwen3_vl_dataloader("cc12m"), - optimizer=default_adamw(lr=3e-4), - lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), - training=TrainingConfig( - local_batch_size=4, - seq_len=4096, - steps=1000, - ), - parallelism=ParallelismConfig( - data_parallel_shard_degree=-1, - tensor_parallel_degree=1, - expert_parallel_degree=8, - ), - checkpoint=CheckpointManager.Config( - enable=False, - interval=500, - last_save_model_only=False, - export_dtype="float16", - ), - activation_checkpoint=ActivationCheckpointConfig( - mode="full", - ), - ) diff --git a/torchtitan/models/qwen3_vl/model.py b/torchtitan/models/qwen3_vl/model.py deleted file mode 100644 index ce5e5e0374..0000000000 --- a/torchtitan/models/qwen3_vl/model.py +++ /dev/null @@ -1,573 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from dataclasses import dataclass - -import torch -from torch import nn - -from torchtitan.models.common.attention import AttentionMasksType, GQAttention -from torchtitan.models.common.decoder import Decoder -from torchtitan.models.qwen3.model import Qwen3Model -from torchtitan.models.utils import get_moe_model_nparams_and_flops - -from .vision_encoder import Qwen3VLVisionEncoder - - -class Qwen3VLModel(Qwen3Model): - """Qwen3-VL: A Vision-Language Model based on Qwen3. - - Combines the Qwen3 language model with a Vision Transformer encoder - for multimodal understanding of images and videos. - - Key features: - - DeepStack: Vision embeddings from intermediate ViT layers are added to - early LLM hidden states for better multimodal understanding - - MRoPE: Multi-dimensional RoPE with interleaved temporal/height/width - position encoding for vision tokens - - Forward pass flow:: - - forward(tokens, pixel_values, grid_thw, ...) - │ - ├─ _prepare_multimodal_embeds - │ ├─ tok_embeddings(tokens) → text embeddings - │ ├─ _get_vision_embeds(pixel_values) → padded vision embeddings + deepstack features - │ │ └─ vision_encoder(pixel_values) → per-layer features, merge patches - │ ├─ _compute_vision_positions → locate vision regions in text sequence - │ └─ _scatter_vision_embeds → copy vision into text at placeholder positions - │ - ├─ _compute_mrope_position_ids → build 3D MRoPE position IDs - │ - └─ transformer layers - └─ for each layer: - ├─ layer(hidden_states, masks, positions) - └─ _deepstack_process: add intermediate ViT features at vision positions (early layers only) - """ - - @dataclass(kw_only=True, slots=True) - class Config(Qwen3Model.Config): - vision_encoder: Qwen3VLVisionEncoder.Config - - def update_from_config( - self, - *, - config, - **kwargs, - ) -> None: - Decoder.Config.update_from_config(self, config=config, **kwargs) - parallelism = config.parallelism - - from torchtitan.models.qwen3_vl.sharding import set_qwen3_vl_sharding_config - - set_qwen3_vl_sharding_config( - self, - loss_parallel=not parallelism.disable_loss_parallel, - enable_ep=parallelism.expert_parallel_degree > 1, - ) - - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, int]: - assert isinstance(self.layers[0].attention, GQAttention.Config) - assert self.layers[0].attention.head_dim is not None - return get_moe_model_nparams_and_flops( - self, - model, - self.layers[0].attention.n_heads, - 2 * self.layers[0].attention.head_dim, - seq_len, - ) - - def __init__(self, config: Config): - super().__init__(config) - - self.vision_encoder = config.vision_encoder.build() - - self.spatial_merge_size = config.vision_encoder.spatial_merge_size - - # Number of early LLM layers that receive DeepStack visual features - self.num_deepstack_layers = len(config.vision_encoder.deepstack_visual_indices) - - def _compute_mrope_position_ids( - self, - tokens: torch.Tensor, - *, - grid_thw: torch.Tensor | None, - grid_thw_videos: torch.Tensor | None, - special_tokens: dict[str, int], - positions: torch.Tensor | None = None, - ) -> torch.Tensor: - """Build 3D position IDs for Qwen3-VL MRoPE. - - Constructs temporal/height/width position IDs for each token. The RoPE - module consumes these IDs and computes the interleaved cos/sin cache. - Text tokens use the same position value across all three axes. Vision - tokens use temporal, height, and width positions derived from their - patch grid. - - Args: - tokens: (batch, seq_len) token IDs - grid_thw: (num_images, 3) grid dimensions for images - grid_thw_videos: (num_videos, 3) grid dimensions for videos - special_tokens: Special token definitions - positions: (batch, seq_len) per-token position IDs for packed - sequences. When provided, document boundaries are detected - where positions reset (positions[t] < positions[t-1]), and - pos_id_offset resets to 0 at each boundary - - Returns: - (3, batch, seq_len) MRoPE position IDs - """ - # --- Build 3D position IDs --- - - # Expand each video [T, H, W] into T rows of [1, H, W] so that - # each frame is treated like an image in the MRoPE code below - # Temporal position comes from frame ordering in the sequence - if grid_thw_videos is not None: - grid_thw_videos = torch.repeat_interleave( - grid_thw_videos, grid_thw_videos[:, 0], dim=0 - ) - grid_thw_videos[:, 0] = 1 - - spatial_merge_size = self.spatial_merge_size - image_token_id = special_tokens["image_id"] - video_token_id = special_tokens["video_id"] - - batch_size, seq_len = tokens.shape - position_ids = torch.zeros( - 3, - batch_size, - seq_len, - dtype=tokens.dtype, - device=tokens.device, - ) - - # Precompute document boundaries and vision token positions across batch - if positions is not None: - resets = positions[:, 1:] < positions[:, :-1] # (batch, seq_len-1) - # Find the first token of each consecutive vision region (image or video) - # E.g. for [text, img, img, img, text, vid, vid] → positions [1, 5] - vision_mask = (tokens == image_token_id) | (tokens == video_token_id) - prev_vision = torch.cat( - [torch.zeros_like(vision_mask[:, :1]), vision_mask[:, :-1]], dim=1 - ) - batch_vision_starts = vision_mask & ~prev_vision # (batch, seq_len) - # Cache vision grid indices by shape to avoid redundant construction - grid_cache: dict[tuple[int, int, int], torch.Tensor] = {} - - image_index, video_index = 0, 0 - # Build MRoPE 3D position IDs per sample - # With sample packing, each sample may contain multiple documents - for sample_i in range(batch_size): - llm_pos_ids_list: list[torch.Tensor] = [] - - if positions is not None: - # Detect document boundaries within one packed sample - # pyrefly: ignore [unbound-name] - reset_indices = torch.where(resets[sample_i])[0] + 1 - doc_starts = [0] + reset_indices.tolist() - doc_ranges = [ - ( - doc_starts[d], - doc_starts[d + 1] if d + 1 < len(doc_starts) else seq_len, - ) - for d in range(len(doc_starts)) - ] - else: - doc_ranges = [(0, seq_len)] - - sample_tokens = tokens[sample_i] - sample_vision_starts = torch.where(batch_vision_starts[sample_i])[ - 0 - ].tolist() - vision_start_index = 0 - - for doc_start, doc_end in doc_ranges: - doc_pos_ids_list: list[torch.Tensor] = [] - - # Advance pointer to collect vision region starts in this document - doc_vision_starts: list[int] = [] - while ( - vision_start_index < len(sample_vision_starts) - and sample_vision_starts[vision_start_index] < doc_end - ): - doc_vision_starts.append(sample_vision_starts[vision_start_index]) - vision_start_index += 1 - - # Process [text tokens][vision tokens] pairs within this document - pair_cursor = doc_start - for vision_start in doc_vision_starts: - if sample_tokens[vision_start] == image_token_id: - # pyrefly: ignore [unsupported-operation] - t, h, w = grid_thw[image_index] - image_index += 1 - else: - # pyrefly: ignore [unsupported-operation] - t, h, w = grid_thw_videos[video_index] - video_index += 1 - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = vision_start - pair_cursor - - # pos_id_offset may differ from pair_cursor due to compact - # spatial position IDs for vision regions - pos_id_offset = ( - doc_pos_ids_list[-1].max() + 1 - if len(doc_pos_ids_list) > 0 - else 0 - ) - # [text tokens] — sequential positions, identical on all 3 axes - doc_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset - ) - # [vision tokens] — 3D grid positions (T, H, W) - grid_key = (llm_grid_t, llm_grid_h, llm_grid_w) - if grid_key not in grid_cache: - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - # pyrefly: ignore [no-matching-overload] - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - # pyrefly: ignore [no-matching-overload] - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - # pyrefly: ignore [no-matching-overload] - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - # pyrefly: ignore [unsupported-operation] - grid_cache[grid_key] = torch.stack([t_index, h_index, w_index]) - doc_pos_ids_list.append( - # pyrefly: ignore [bad-index] - grid_cache[grid_key] - + text_len - + pos_id_offset - ) - pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w - - # Trailing [text tokens] after the last [text tokens][vision tokens] pair - if pair_cursor < doc_end: - pos_id_offset = ( - doc_pos_ids_list[-1].max() + 1 - if len(doc_pos_ids_list) > 0 - else 0 - ) - text_len = doc_end - pair_cursor - doc_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset - ) - - llm_pos_ids_list.extend(doc_pos_ids_list) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[:, sample_i, :] = llm_positions.to(position_ids.device) - - return position_ids - - def _compute_vision_positions( - self, - tokens: torch.Tensor, - num_tokens_per_item: torch.Tensor, - vision_token_id: int, - ) -> list[tuple[int, int, int, int]]: - """Compute (item_idx, sample_idx, vision_start, n_tokens) for each vision item. - - Finds where each contiguous run of vision placeholder tokens starts - in the text sequence. - - Args: - tokens: Token IDs (batch, seq_len) - num_tokens_per_item: (num_items,) actual tokens per vision item - vision_token_id: Placeholder token ID - - Returns: - List of (item_idx, sample_idx, vision_start, n_tokens) tuples - """ - vision_mask = tokens == vision_token_id - flat_mask = vision_mask.view(-1) - prev_mask = torch.cat( - [torch.zeros(1, dtype=torch.bool, device=flat_mask.device), flat_mask[:-1]] - ) - region_starts = torch.where(flat_mask & ~prev_mask)[0] - seq_len = tokens.shape[1] - - positions = [] - for i in range(num_tokens_per_item.shape[0]): - start = int(region_starts[i].item()) - n_tokens = int(num_tokens_per_item[i].item()) - positions.append((i, start // seq_len, start % seq_len, n_tokens)) - return positions - - def _get_vision_embeds( - self, - pixel_values: torch.Tensor, - *, - grid_thw: torch.Tensor, - ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: - """Run vision encoder and return padded embeddings with token counts. - - Works for both images and videos — the ViT processes them identically. - - Args: - pixel_values: Padded patches (num_items, max_num_patch, patch_dim) - grid_thw: Grid dimensions (num_items, 3) for [t, h, w] - - Returns: - merged_embeds: (num_items, max_tokens, dim) padded vision embeddings - deepstack_features: List of (num_items, max_tokens, dim) per layer - num_tokens_per_item: (num_items,) actual token count per item - """ - pixel_values = pixel_values.to( - self.vision_encoder.patch_embed.proj.weight.dtype - ) - merged_embeds, deepstack_features = self.vision_encoder( - pixel_values, grid_thw=grid_thw - ) - - merge_unit = self.vision_encoder.spatial_merge_unit - num_tokens_per_item = grid_thw.prod(-1) // merge_unit - - return merged_embeds, deepstack_features, num_tokens_per_item - - def _scatter_vision_embeds( - self, - inputs_embeds: torch.Tensor, - *, - merged_embeds: torch.Tensor, - vision_positions: list[tuple[int, int, int, int]], - ) -> torch.Tensor: - """Scatter vision embeddings into text embeddings at placeholder positions. - - Copies directly from the padded vision encoder output into the text - sequence. - - Args: - inputs_embeds: Text embeddings (batch, seq_len, dim) - merged_embeds: Padded vision embeddings (num_items, max_tokens, dim) - vision_positions: List of (item_idx, sample_idx, vision_start, n_tokens) - - Returns: - Updated embeddings - """ - for item_idx, sample_idx, vision_start, n_tokens in vision_positions: - inputs_embeds[ - sample_idx, vision_start : vision_start + n_tokens, : - ] = merged_embeds[item_idx, :n_tokens, :] - return inputs_embeds - - def _deepstack_process( - self, - hidden_states: torch.Tensor, - *, - vision_positions: list[tuple[int, int, int, int]], - deepstack_embeds: torch.Tensor, - ) -> torch.Tensor: - """Add vision embeddings to hidden states at vision token positions. - - Args: - hidden_states: LLM hidden states (batch, seq_len, dim) - vision_positions: List of (item_idx, sample_idx, vision_start, n_tokens) - deepstack_embeds: Padded vision embeddings (num_items, max_tokens, dim) - - Returns: - Updated hidden states - """ - for item_idx, sample_idx, vision_start, n_tokens in vision_positions: - hidden_states[ - sample_idx, vision_start : vision_start + n_tokens, : - ] += deepstack_embeds[item_idx, :n_tokens, :] - return hidden_states - - def _prepare_multimodal_embeds( - self, - tokens: torch.Tensor, - *, - pixel_values: torch.Tensor | None, - pixel_values_videos: torch.Tensor | None, - grid_thw: torch.Tensor | None, - grid_thw_videos: torch.Tensor | None, - special_tokens: dict[str, int], - ) -> tuple[ - torch.Tensor, - list[tuple[int, int, int, int]], - list[tuple[int, int, int, int]], - list[torch.Tensor] | None, - list[torch.Tensor] | None, - ]: - """Embed tokens, run vision encoder, scatter vision into text, prepare DeepStack. - - Args: - tokens: Input token IDs (batch_size, seq_len) - pixel_values: Image patches or None - pixel_values_videos: Video patches or None - grid_thw: Grid dimensions for images or None - grid_thw_videos: Grid dimensions for videos or None - special_tokens: Special token definitions - - Returns: - inputs_embeds: (batch, seq_len, dim) with vision tokens scattered in - image_positions: List of (item_idx, sample_idx, vision_start, n_tokens) - video_positions: List of (item_idx, sample_idx, vision_start, n_tokens) - deepstack_image_features: List of (num_items, max_tokens, dim) or None - deepstack_video_features: List of (num_items, max_tokens, dim) or None - """ - image_token_id = special_tokens["image_id"] - video_token_id = special_tokens["video_id"] - - inputs_embeds = ( - self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens - ) - - # Process image inputs - image_positions: list[tuple[int, int, int, int]] = [] - deepstack_image_features = None - if pixel_values is not None and grid_thw is not None: - merged_embeds, deepstack_features, num_tokens = self._get_vision_embeds( - pixel_values, grid_thw=grid_thw - ) - image_positions = self._compute_vision_positions( - tokens, num_tokens, image_token_id - ) - if image_positions: - inputs_embeds = self._scatter_vision_embeds( - inputs_embeds, - merged_embeds=merged_embeds, - vision_positions=image_positions, - ) - deepstack_image_features = deepstack_features - - # Process video inputs - video_positions: list[tuple[int, int, int, int]] = [] - deepstack_video_features = None - if pixel_values_videos is not None and grid_thw_videos is not None: - merged_embeds, deepstack_features, num_tokens = self._get_vision_embeds( - pixel_values_videos, grid_thw=grid_thw_videos - ) - video_positions = self._compute_vision_positions( - tokens, num_tokens, video_token_id - ) - if video_positions: - inputs_embeds = self._scatter_vision_embeds( - inputs_embeds, - merged_embeds=merged_embeds, - vision_positions=video_positions, - ) - deepstack_video_features = deepstack_features - - return ( - inputs_embeds, - image_positions, - video_positions, - deepstack_image_features, - deepstack_video_features, - ) - - def forward( # pyrefly: ignore [bad-override] - self, - tokens: torch.Tensor, - *, - pixel_values: torch.Tensor | None = None, - pixel_values_videos: torch.Tensor | None = None, - grid_thw: torch.Tensor | None = None, - grid_thw_videos: torch.Tensor | None = None, - attention_masks: AttentionMasksType | None = None, - positions: torch.Tensor | None = None, - special_tokens: dict[str, int], - ): - """Forward pass of Qwen3-VL. - - Args: - tokens: Input token IDs (batch_size, seq_len) - pixel_values: Flattened image patches (num_images, max_num_patches, patch_dim) - pixel_values_videos: Flattened video patches (num_videos, max_num_patches, patch_dim) - grid_thw: Grid dimensions for images (num_images, 3) - grid_thw_videos: Grid dimensions for videos (num_videos, 3) - attention_masks: Attention masks for block_causal / flex attention - positions: Per-token position IDs (batch_size, seq_len) for packed sequences. - Each document's positions reset to 0. None means sequential positions. - special_tokens: Special token definitions - - Returns: - Output logits (batch_size, seq_len, vocab_size) - """ - ( - inputs_embeds, - image_positions, - video_positions, - deepstack_image_features, - deepstack_video_features, - ) = self._prepare_multimodal_embeds( - tokens, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - grid_thw=grid_thw, - grid_thw_videos=grid_thw_videos, - special_tokens=special_tokens, - ) - - # Compute MRoPE position IDs when vision inputs are present. - if grid_thw is not None or grid_thw_videos is not None: - positions = self._compute_mrope_position_ids( - tokens, - grid_thw=grid_thw, - grid_thw_videos=grid_thw_videos, - special_tokens=special_tokens, - positions=positions, - ) - - # Apply transformer layers with DeepStack - hidden_states = inputs_embeds - for layer_idx, layer in self.layers.items(): - hidden_states = layer(hidden_states, attention_masks, positions) - - # Apply DeepStack: add visual features to early layer hidden states - layer_idx_int = int(layer_idx) - if layer_idx_int < self.num_deepstack_layers: - if ( - deepstack_image_features is not None - and image_positions - and layer_idx_int < len(deepstack_image_features) - ): - hidden_states = self._deepstack_process( - hidden_states, - vision_positions=image_positions, - deepstack_embeds=deepstack_image_features[layer_idx_int], - ) - if ( - deepstack_video_features is not None - and video_positions - and layer_idx_int < len(deepstack_video_features) - ): - hidden_states = self._deepstack_process( - hidden_states, - vision_positions=video_positions, - deepstack_embeds=deepstack_video_features[layer_idx_int], - ) - - hidden_states = ( - self.norm(hidden_states) if self.norm is not None else hidden_states - ) - if self._skip_lm_head: - return hidden_states - output = ( - self.lm_head(hidden_states) if self.lm_head is not None else hidden_states - ) - return output diff --git a/torchtitan/models/qwen3_vl/parallelize.py b/torchtitan/models/qwen3_vl/parallelize.py deleted file mode 100644 index 006d0ea82d..0000000000 --- a/torchtitan/models/qwen3_vl/parallelize.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Parallelization utilities for Qwen3-VL. - -This module applies PT-D parallelisms and various training techniques -(activation checkpointing, compile, FSDP) to the Qwen3-VL model. -""" - -import torch -import torch.nn as nn - -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed.tensor import distribute_tensor, Replicate -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - parallelize_module, - RowwiseParallel, -) - -from torchtitan.config import ( - ActivationCheckpointConfig, - CompileConfig, - ParallelismConfig, - TORCH_DTYPE_MAP, - TrainingConfig, -) - -from torchtitan.distributed import ParallelDims -from torchtitan.distributed.activation_checkpoint import apply_ac -from torchtitan.distributed.compile import apply_compile -from torchtitan.distributed.fsdp import ( - apply_fsdp_to_decoder, - get_fsdp_reshard_after_forward_policy, -) -from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp, NoParallel -from torchtitan.models.qwen3_vl.model import Qwen3VLModel -from torchtitan.tools.logging import logger - - -def _apply_tp_to_vision_encoder( - vision_encoder: nn.Module, - tp_mesh: DeviceMesh, -): - """Apply tensor parallelism to the vision encoder. - - Hidden states flow as DTensor(Replicate) throughout — all ranks hold the - full hidden_states. Only the linear layers (qkv, proj, fc1, fc2) are - sharded via ColwiseParallel/RowwiseParallel to save memory. Norms operate - on Replicate DTensors directly. - """ - # NoParallel on patch_embed distributes its params as Replicate DTensors - # on tp_mesh for FSDP mesh consistency. Its input hook wraps plain - # pixel_values as DTensor(Replicate), and the output stays as DTensor - # (Replicate) to flow through the rest of the vision encoder. - parallelize_module(vision_encoder, tp_mesh, {"patch_embed": NoParallel()}) - - # pos_embed is an nn.Parameter (not a submodule), so it can't be targeted - # by parallelize_module's plan dict. We distribute it as Replicate DTensor - # on tp_mesh for FSDP mesh consistency. - vision_encoder.pos_embed = nn.Parameter( - # pyrefly: ignore [bad-argument-type] - distribute_tensor(vision_encoder.pos_embed.data, tp_mesh, [Replicate()]), - requires_grad=vision_encoder.pos_embed.requires_grad, - ) - - # TP plan for each vision transformer block. - # hidden_states flows through as DTensor (Replicate) so residual adds work. - # NoParallel on norms sets their params as Replicate DTensors on tp_mesh - # (for consistent (fsdp, tp) mesh after FSDP). - # RowwiseParallel uses use_local_output=False to return DTensor (Replicate) - # so residual connections (hidden_states + attn/mlp output) stay in DTensor - # space. ColwiseParallel uses use_local_output=True (default) to return - # local shards for the internal attention/MLP computation. - layer_plan = { - "norm1": NoParallel(), - "norm2": NoParallel(), - "attn.qkv": ColwiseParallel(), # needs plain tensor for reshape after qkv - "attn.proj": RowwiseParallel(use_local_output=False), - "mlp.linear_fc1": ColwiseParallel(use_local_output=False), - "mlp.linear_fc2": RowwiseParallel(use_local_output=False), - } - - # pyrefly: ignore [not-callable] - for transformer_block in vision_encoder.layers.values(): - # pyrefly: ignore [bad-argument-type] - parallelize_module(transformer_block, tp_mesh, layer_plan) - - # TP plan for patch mergers (main + deepstack). - # Mergers output DTensor(Replicate) — the model passes padded embeddings - # directly to vision scatter and DeepStack. - merger_plan = { - "norm": NoParallel(), - "linear_fc1": ColwiseParallel(use_local_output=False), - "linear_fc2": RowwiseParallel(use_local_output=False), - } - - # pyrefly: ignore [bad-argument-type] - parallelize_module(vision_encoder.merger, tp_mesh, merger_plan) - # pyrefly: ignore [not-iterable] - for merger in vision_encoder.deepstack_merger_list: - # pyrefly: ignore [bad-argument-type] - parallelize_module(merger, tp_mesh, merger_plan) - - logger.info("Applied Tensor Parallelism to the vision encoder") - - -def _apply_fsdp_to_vision_encoder( - vision_encoder: nn.Module, - dp_mesh: DeviceMesh, - param_dtype: torch.dtype, - reduce_dtype: torch.dtype, - reshard_after_forward_policy: str = "default", - pp_enabled: bool = False, -): - """ - Apply FSDP to the vision encoder as a single unit. - - Wraps the entire vision encoder with one fully_shard call so all parameters - are gathered in a single AllGather. The vision encoder's compute is small - relative to the decoder, so per-layer sharding would launch many small - AllGather kernels whose total overhead exceeds a single AllGather followed - by computing all layers in one shot — even without overlap. - - Must be called before apply_fsdp on the decoder so the vision encoder is - already sharded when the final fully_shard(model) is applied. - """ - mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} - reshard_after_forward = get_fsdp_reshard_after_forward_policy( - reshard_after_forward_policy, pp_enabled=pp_enabled - ) - - fully_shard( - vision_encoder, - **fsdp_config, - reshard_after_forward=reshard_after_forward, - ) - - -def parallelize_qwen3_vl( - model: Qwen3VLModel, - *, - parallel_dims: ParallelDims, - training: TrainingConfig, - parallelism: ParallelismConfig, - compile_config: CompileConfig, - ac_config: ActivationCheckpointConfig, - dump_folder: str, -): - """ - Apply tensor parallelism, activation checkpointing, torch.compile, and data - parallelism to the Qwen3-VL model. - - NOTE: The passed-in model preferably should be on meta device. Otherwise, - the model must fit on GPU or CPU memory. - """ - if parallelism.spmd_backend == "full_dtensor": - raise NotImplementedError("full_dtensor is not supported yet.") - - model_compile_enabled = ( - compile_config.enable and "model" in compile_config.components - ) - - if parallel_dims.cp_enabled: - raise NotImplementedError("Context Parallel is not yet supported for Qwen3-VL.") - - if parallel_dims.tp_enabled: - # TODO(@fegin): Apply TP to vision encoder (still uses parallelize_module path) - if model.vision_encoder is not None: - _apply_tp_to_vision_encoder( - model.vision_encoder, parallel_dims.get_mesh("tp") - ) - - if parallel_dims.tp_enabled or parallel_dims.ep_enabled: - model.parallelize(parallel_dims) - - if parallel_dims.tp_enabled: - maybe_enable_async_tp(parallelism, compile_config, parallel_dims.get_mesh("tp")) - - # Apply activation checkpointing - if ac_config.mode != "none": - apply_ac( - model, - ac_config, - model_compile_enabled=model_compile_enabled, - base_folder=dump_folder, - ) - if model.vision_encoder is not None: - apply_ac( - model.vision_encoder, - ac_config, - model_compile_enabled=model_compile_enabled, - base_folder=dump_folder, - ) - - # Apply torch.compile after AC wrapping and before FSDP - if model_compile_enabled: - apply_compile(model, compile_config) - if model.vision_encoder is not None: - apply_compile(model.vision_encoder, compile_config) - - # Apply FSDP / HSDP unconditionally (fully_shard handles dp_shard=1) - dp_mesh_names = ( - ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] - ) - dp_mesh = parallel_dims.get_mesh(dp_mesh_names) - - # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - edp_mesh = None - if parallel_dims.ep_enabled: - edp_mesh_names = ( - ["dp_replicate", "efsdp"] - if parallel_dims.dp_replicate_enabled - else ["efsdp"] - ) - edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) - - # FSDP the vision encoder as a single unit (see _apply_fsdp_to_vision_encoder) - if model.vision_encoder is not None: - _apply_fsdp_to_vision_encoder( - model.vision_encoder, - dp_mesh, - param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce], - reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward, - pp_enabled=parallel_dims.pp_enabled, - ) - - # FSDP the decoder with MoE-aware sharding - apply_fsdp_to_decoder( - model, - dp_mesh, - param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce], - pp_enabled=parallel_dims.pp_enabled, - cpu_offload=training.enable_cpu_offload, - reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward, - ep_degree=parallel_dims.ep, - edp_mesh=edp_mesh, - ) - - return model diff --git a/torchtitan/models/qwen3_vl/sharding.py b/torchtitan/models/qwen3_vl/sharding.py deleted file mode 100644 index d93f105de9..0000000000 --- a/torchtitan/models/qwen3_vl/sharding.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import TYPE_CHECKING - -from torchtitan.models.qwen3.sharding import set_qwen3_sharding_config - -if TYPE_CHECKING: - from torchtitan.models.qwen3_vl.model import Qwen3VLModel - - -def set_qwen3_vl_sharding_config( - config: "Qwen3VLModel.Config", - *, - loss_parallel: bool, - enable_ep: bool, -) -> None: - """Fill ``sharding_config`` on all Qwen3-VL sub-configs. - - Delegates to ``set_qwen3_sharding_config`` with ``enable_sp=False`` - because Qwen3-VL keeps hidden states as ``Replicate`` (not - ``Shard(1)``) — no SequenceParallel due to vision scatter and - DeepStack needing full-sequence access. - """ - set_qwen3_sharding_config( - config, - loss_parallel=loss_parallel, - enable_sp=False, - enable_ep=enable_ep, - ) diff --git a/torchtitan/models/qwen3_vl/state_dict_adapter.py b/torchtitan/models/qwen3_vl/state_dict_adapter.py deleted file mode 100644 index be5c097f23..0000000000 --- a/torchtitan/models/qwen3_vl/state_dict_adapter.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -State dict adapter for Qwen3-VL. - -Converts between HuggingFace Qwen3-VL checkpoint format and torchtitan format. - -MoE expert weights require two transformations: -- **Transpose**: HF and TT use transposed layouts for grouped 3D expert weights. - E.g. HF down_proj [E, hidden, dim] <-> TT w2 [E, dim, hidden]. -- **Fuse/split gate_up_proj**: HF fuses gate_proj and up_proj into a single - gate_up_proj [E, dim, 2*hidden_dim]. TT stores them separately as - w1 [E, hidden_dim, dim] and w3 [E, hidden_dim, dim]. - -Other notable conversions: -- Conv3d patch embedding (HF) <-> Linear (TT) via weight reshape -- Vision block naming: HF `blocks` <-> TT `layers` -""" - -import re -from typing import Any - -import torch - -from torchtitan.models.common.attention import FusedQKVLinear -from torchtitan.protocols.state_dict_adapter import StateDictAdapter -from .model import Qwen3VLModel -from .rope import MRoPE - - -class Qwen3VLStateDictAdapter(StateDictAdapter): - def __init__(self, model_config: Qwen3VLModel.Config, hf_assets_path: str | None): - super().__init__(model_config, hf_assets_path) - self.model_config = model_config - self.fuse_qkv = isinstance( - model_config.layers[0].attention.qkv_linear, FusedQKVLinear.Config - ) - - qkv_map: dict[str, str | None] - if self.fuse_qkv: - qkv_map = { - "model.language_model.layers.{}.self_attn.q_proj.weight": None, - "model.language_model.layers.{}.self_attn.k_proj.weight": None, - "model.language_model.layers.{}.self_attn.v_proj.weight": None, - } - else: - qkv_map = { - "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.qkv_linear.wq.weight", - "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.qkv_linear.wk.weight", - "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.qkv_linear.wv.weight", - } - - self.from_hf_map = { - # ===== Language Model ===== - "model.language_model.embed_tokens.weight": "tok_embeddings.weight", - # Attention - **qkv_map, - "model.language_model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - "model.language_model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight", - "model.language_model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight", - "model.language_model.layers.{}.self_attn.rotary_emb.inv_freq": None, - # Non-MoE MLP - "model.language_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", - "model.language_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", - "model.language_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - # Layer norms - "model.language_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", - "model.language_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", - # MoE (grouped 3D format, handled specially in to_hf/from_hf) - # gate_up_proj is fused gate+up, mapped to w1+w3 via custom logic - "model.language_model.layers.{}.mlp.experts.down_proj": "layers.{}.moe.experts.w2", - "model.language_model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", - # Final norm and output - "model.language_model.norm.weight": "norm.weight", - "lm_head.weight": "lm_head.weight", - # ===== Vision Encoder ===== - # Patch embedding (Conv3d in HF, Linear in TT - weight reshape needed) - "model.visual.patch_embed.proj.weight": "vision_encoder.patch_embed.proj.weight", - "model.visual.patch_embed.proj.bias": "vision_encoder.patch_embed.proj.bias", - # Position embeddings - "model.visual.pos_embed.weight": "vision_encoder.pos_embed", - # Vision transformer blocks (HF: blocks, TT: layers) - "model.visual.blocks.{}.norm1.weight": "vision_encoder.layers.{}.norm1.weight", - "model.visual.blocks.{}.norm1.bias": "vision_encoder.layers.{}.norm1.bias", - "model.visual.blocks.{}.norm2.weight": "vision_encoder.layers.{}.norm2.weight", - "model.visual.blocks.{}.norm2.bias": "vision_encoder.layers.{}.norm2.bias", - "model.visual.blocks.{}.attn.qkv.weight": "vision_encoder.layers.{}.attn.qkv.weight", - "model.visual.blocks.{}.attn.qkv.bias": "vision_encoder.layers.{}.attn.qkv.bias", - "model.visual.blocks.{}.attn.proj.weight": "vision_encoder.layers.{}.attn.proj.weight", - "model.visual.blocks.{}.attn.proj.bias": "vision_encoder.layers.{}.attn.proj.bias", - "model.visual.blocks.{}.mlp.linear_fc1.weight": "vision_encoder.layers.{}.mlp.linear_fc1.weight", - "model.visual.blocks.{}.mlp.linear_fc1.bias": "vision_encoder.layers.{}.mlp.linear_fc1.bias", - "model.visual.blocks.{}.mlp.linear_fc2.weight": "vision_encoder.layers.{}.mlp.linear_fc2.weight", - "model.visual.blocks.{}.mlp.linear_fc2.bias": "vision_encoder.layers.{}.mlp.linear_fc2.bias", - # Merger (maps vision dim to LLM dim) - "model.visual.merger.norm.weight": "vision_encoder.merger.norm.weight", - "model.visual.merger.norm.bias": "vision_encoder.merger.norm.bias", - "model.visual.merger.linear_fc1.weight": "vision_encoder.merger.linear_fc1.weight", - "model.visual.merger.linear_fc1.bias": "vision_encoder.merger.linear_fc1.bias", - "model.visual.merger.linear_fc2.weight": "vision_encoder.merger.linear_fc2.weight", - "model.visual.merger.linear_fc2.bias": "vision_encoder.merger.linear_fc2.bias", - # DeepStack mergers - "model.visual.deepstack_merger_list.{}.norm.weight": "vision_encoder.deepstack_merger_list.{}.norm.weight", - "model.visual.deepstack_merger_list.{}.norm.bias": "vision_encoder.deepstack_merger_list.{}.norm.bias", - "model.visual.deepstack_merger_list.{}.linear_fc1.weight": "vision_encoder.deepstack_merger_list.{}.linear_fc1.weight", - "model.visual.deepstack_merger_list.{}.linear_fc1.bias": "vision_encoder.deepstack_merger_list.{}.linear_fc1.bias", - "model.visual.deepstack_merger_list.{}.linear_fc2.weight": "vision_encoder.deepstack_merger_list.{}.linear_fc2.weight", - "model.visual.deepstack_merger_list.{}.linear_fc2.bias": "vision_encoder.deepstack_merger_list.{}.linear_fc2.bias", - } - - def _get_attention_dims(self) -> tuple[int, int, int]: - """Return (n_heads, n_kv_heads, head_dim) from model config.""" - # pyrefly: ignore [missing-attribute] - attn = self.model_config.layers[0].attention - n_heads = attn.n_heads - n_kv_heads = attn.n_kv_heads if attn.n_kv_heads is not None else n_heads - head_dim = ( - attn.head_dim - if attn.head_dim is not None - else self.model_config.dim // n_heads # pyrefly: ignore [missing-attribute] - ) - return n_heads, n_kv_heads, head_dim - - def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: - """Convert torchtitan state dict to HuggingFace Qwen3-VL format.""" - - if self.fuse_qkv: - to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None} - n_heads, n_kv_heads, head_dim = self._get_attention_dims() - else: - to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None} - hf_state_dict = {} - - # Collect MoE w1/w3 per layer to fuse into gate_up_proj - moe_w1_by_layer: dict[str, Any] = {} - moe_w3_by_layer: dict[str, Any] = {} - - for tt_key, value in state_dict.items(): - if "moe.experts" in tt_key: - tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1) - # pyrefly: ignore [missing-attribute] - layer_num = re.search(r"\d+", tt_key).group(0) - - # Collect w1/w3 for fusing into gate_up_proj - if tt_abstract_key == "layers.{}.moe.experts.w1": - moe_w1_by_layer[layer_num] = value - continue - elif tt_abstract_key == "layers.{}.moe.experts.w3": - moe_w3_by_layer[layer_num] = value - continue - - # Handle down_proj: TT w2 [E, dim, hidden] -> HF [E, hidden, dim] - if tt_abstract_key == "layers.{}.moe.experts.w2": - hf_key = ( - f"model.language_model.layers.{layer_num}.mlp.experts.down_proj" - ) - hf_state_dict[hf_key] = value.transpose(-2, -1) - continue - - if tt_abstract_key not in to_hf_map: - continue - hf_key = to_hf_map[tt_abstract_key].format(layer_num) - hf_state_dict[hf_key] = value - # Indexed key: contains a layer/block/merger index (e.g. ".0.") - elif re.search(r"\.\d+\.", tt_key): - tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1) - # pyrefly: ignore [missing-attribute] - layer_num = re.search(r"\d+", tt_key).group(0) - - # Handle fused QKV: split wqkv into separate q/k/v projections - if ( - self.fuse_qkv - and tt_abstract_key == "layers.{}.attention.qkv_linear.wqkv.weight" - ): - wq, wk, wv = self.fused_to_separate_qkv( - value, - n_heads, # pyrefly: ignore [unbound-name] - n_kv_heads, # pyrefly: ignore [unbound-name] - head_dim, # pyrefly: ignore [unbound-name] - ) - hf_state_dict[ - f"model.language_model.layers.{layer_num}.self_attn.q_proj.weight" - ] = wq - hf_state_dict[ - f"model.language_model.layers.{layer_num}.self_attn.k_proj.weight" - ] = wk - hf_state_dict[ - f"model.language_model.layers.{layer_num}.self_attn.v_proj.weight" - ] = wv - continue - - if tt_abstract_key not in to_hf_map: - continue - hf_key = to_hf_map[tt_abstract_key].format(layer_num) - hf_state_dict[hf_key] = value - - else: - if tt_key not in to_hf_map: - continue - if ( - tt_key == "lm_head.weight" - and self.model_config.enable_weight_tying # pyrefly: ignore [missing-attribute] - ): - continue - hf_key = to_hf_map[tt_key] - hf_value = value - # Linear weight (out, C*T*H*W) -> Conv3d weight (out, C, T, H, W) - # Plain reshape since both use channel-first (c pt ph pw) layout. - if tt_key == "vision_encoder.patch_embed.proj.weight": - # pyrefly: ignore [missing-attribute] - patch_embed = self.model_config.vision_encoder.patch_embed - hf_value = value.reshape( - value.shape[0], - patch_embed.in_channels, - patch_embed.temporal_patch_size, - patch_embed.patch_size, - patch_embed.patch_size, - ) - hf_state_dict[hf_key] = hf_value - - # Fuse w1 (gate) and w3 (up) into gate_up_proj per layer - # TT w1/w3: [E, hidden_dim, dim] -> transpose to [E, dim, hidden_dim] -> cat on last dim - for layer_num in moe_w1_by_layer: - w1 = moe_w1_by_layer[layer_num].transpose(-2, -1) # [E, dim, hidden_dim] - w3 = moe_w3_by_layer[layer_num].transpose(-2, -1) # [E, dim, hidden_dim] - gate_up = torch.cat([w1, w3], dim=-1) # [E, dim, 2*hidden_dim] - hf_key = f"model.language_model.layers.{layer_num}.mlp.experts.gate_up_proj" - hf_state_dict[hf_key] = gate_up - - return hf_state_dict - - def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: - """Convert HuggingFace Qwen3-VL state dict to torchtitan format.""" - self._validate_hf_rope_config(MRoPE.Config) - tt_state_dict = {} - # Collect Q/K/V per layer for fusing (only used when fuse_qkv=True) - pending_qkv: dict[str, dict[str, torch.Tensor]] = {} - - if self.fuse_qkv: - n_heads, n_kv_heads, head_dim = self._get_attention_dims() - - # HF Qwen3-VL ties lm_head.weight with embed_tokens.weight, - # so lm_head.weight may not be stored in the checkpoint. - if "lm_head.weight" not in hf_state_dict: - if "model.language_model.embed_tokens.weight" not in hf_state_dict: - raise ValueError( - "HF checkpoint missing both 'lm_head.weight' and " - "'model.language_model.embed_tokens.weight'" - ) - hf_state_dict["lm_head.weight"] = hf_state_dict[ - "model.language_model.embed_tokens.weight" - ] - - for hf_key, value in hf_state_dict.items(): - # Indexed key: contains a layer/block/merger index (e.g. ".0.") - if re.search(r"\.\d+\.", hf_key): - hf_abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=1) - # pyrefly: ignore [missing-attribute] - idx = re.search(r"\d+", hf_key).group(0) - - # Handle fused QKV: collect q/k/v and fuse when all 3 are ready - if self.fuse_qkv and hf_abstract_key in ( - "model.language_model.layers.{}.self_attn.q_proj.weight", - "model.language_model.layers.{}.self_attn.k_proj.weight", - "model.language_model.layers.{}.self_attn.v_proj.weight", - ): - if idx not in pending_qkv: - pending_qkv[idx] = {} - proj = hf_abstract_key.split(".")[-2] # q_proj, k_proj, v_proj - pending_qkv[idx][proj] = value - if len(pending_qkv[idx]) == 3: - fused = self.separate_to_fused_qkv( - pending_qkv[idx]["q_proj"], - pending_qkv[idx]["k_proj"], - pending_qkv[idx]["v_proj"], - n_heads, # pyrefly: ignore [unbound-name] - n_kv_heads, # pyrefly: ignore [unbound-name] - head_dim, # pyrefly: ignore [unbound-name] - ) - tt_state_dict[ - f"layers.{idx}.attention.qkv_linear.wqkv.weight" - ] = fused - del pending_qkv[idx] - continue - - # Handle fused gate_up_proj: split and transpose - # HF gate_up_proj: [E, dim, 2*hidden_dim] -> split -> transpose each to [E, hidden_dim, dim] - if ( - hf_abstract_key - == "model.language_model.layers.{}.mlp.experts.gate_up_proj" - ): - w1_hf, w3_hf = value.chunk(2, dim=-1) # each [E, dim, hidden_dim] - tt_state_dict[f"layers.{idx}.moe.experts.w1"] = w1_hf.transpose( - -2, -1 - ) - tt_state_dict[f"layers.{idx}.moe.experts.w3"] = w3_hf.transpose( - -2, -1 - ) - continue - - # Handle down_proj transpose: HF [E, hidden, dim] -> TT w2 [E, dim, hidden] - if ( - hf_abstract_key - == "model.language_model.layers.{}.mlp.experts.down_proj" - ): - tt_state_dict[f"layers.{idx}.moe.experts.w2"] = value.transpose( - -2, -1 - ) - continue - - if hf_abstract_key not in self.from_hf_map: - continue - tt_key = self.from_hf_map[hf_abstract_key] - if tt_key is None: - continue - tt_key = tt_key.format(idx) - tt_state_dict[tt_key] = value - - else: - if hf_key not in self.from_hf_map: - continue - tt_key = self.from_hf_map[hf_key] - if tt_key is None: - continue - tt_value = value - # Conv3d weight (out, C, T, H, W) -> Linear weight (out, C*T*H*W) - # Plain reshape since both use channel-first (c pt ph pw) layout. - if hf_key == "model.visual.patch_embed.proj.weight": - tt_value = value.reshape(value.shape[0], -1) - tt_state_dict[tt_key] = tt_value - - if self.fuse_qkv and pending_qkv: - raise ValueError( - f"Incomplete Q/K/V projections for layers: {list(pending_qkv.keys())}" - ) - - return tt_state_dict diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py index 96b9ec0b00..abbeaa91b8 100644 --- a/torchtitan/models/utils.py +++ b/torchtitan/models/utils.py @@ -460,6 +460,8 @@ def get_moe_model_nparams_and_flops( n_heads: int, head_dims: int, seq_len: int, + *, + num_full_attn: int | None = None, ) -> tuple[int, int]: """ Calculate nparams and nflops for MoE models. @@ -470,6 +472,11 @@ def get_moe_model_nparams_and_flops( n_heads: The number of attention heads. head_dims: The sum of qk and v head dimensions. seq_len: The sequence length in training configs. + num_full_attn: For hybrid models that mix full (O(L²)) + attention with linear (O(L)) attention, the number of layers using + softmax attention. Only these layers contribute the quadratic + attention FLOPs term. If None (default), all layers are assumed to + use full attention. Returns: Tuple of (nparams, num_flops_per_token): @@ -521,9 +528,10 @@ def get_moe_model_nparams_and_flops( nparams_for_matmul = nparams_dense + nparams_sparse_active else: nparams_for_matmul = nparams_dense - nparams_embedding + nparams_sparse_active + if num_full_attn is None: + num_full_attn = len(model_config.layers) num_flops_per_token = ( - 6 * nparams_for_matmul - + 6 * len(model_config.layers) * n_heads * head_dims * seq_len + 6 * nparams_for_matmul + 6 * num_full_attn * n_heads * head_dims * seq_len ) return nparams, num_flops_per_token From 03912f2e66554c16dbdd26de744251449d834873 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Tue, 2 Jun 2026 10:32:42 -0700 Subject: [PATCH 2/7] rebase refactor --- .../numerical_tests_qwen3_5.py | 12 +- .../numerical_tests_qwen3_5_shard.py | 44 +-- tests/integration_tests/models.py | 9 +- torchtitan/models/common/config_utils.py | 2 + torchtitan/models/common/decoder.py | 13 +- torchtitan/models/common/moe.py | 10 + torchtitan/models/common/moe_sharding.py | 8 + torchtitan/models/flux/trainer.py | 1 - torchtitan/models/qwen3_5/README.md | 6 +- torchtitan/models/qwen3_5/__init__.py | 210 +++++++----- torchtitan/models/qwen3_5/config_registry.py | 5 +- torchtitan/models/qwen3_5/model.py | 228 ++++++------- torchtitan/models/qwen3_5/parallelize.py | 93 ++++-- torchtitan/models/qwen3_5/sharding.py | 309 ++++++------------ .../models/qwen3_5/state_dict_adapter.py | 20 +- torchtitan/models/qwen3_5/vision_encoder.py | 158 ++++----- torchtitan/trainer.py | 25 +- 17 files changed, 530 insertions(+), 623 deletions(-) diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py index 7e5f345266..bb42c237cc 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py @@ -14,17 +14,18 @@ Usage: python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \ - --hf_model_path ../hf_models/Qwen/Qwen3.5-4B \ + --hf_model_path hf_assets/Qwen/Qwen3.5-4B \ --tt_checkpoint_path outputs/Qwen/qwen3_5_4b_dcp python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \ - --hf_model_path ../hf_models/Qwen/Qwen3.5-35B-A3B \ + --hf_model_path hf_assets/Qwen/Qwen3.5-35B-A3B \ --tt_checkpoint_path outputs/Qwen/qwen3_5_35b_a3b_dcp \ --model_flavor 35B-A3B """ import argparse import os +from typing import Any import torch import torch._dynamo @@ -88,7 +89,9 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): vision_to_patches, ) - processor = AutoProcessor.from_pretrained(hf_model_path) + # Annotate as Any: AutoProcessor.from_pretrained is typed Optional, which + # trips the .apply_chat_template call on environments without transformers stubs. + processor: Any = AutoProcessor.from_pretrained(hf_model_path) model_config = model_registry(model_flavor).model # pyrefly: ignore [missing-attribute] @@ -118,7 +121,6 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): ], } ] - # pyrefly: ignore [missing-attribute] hf_in = processor.apply_chat_template( messages, tokenize=True, @@ -263,7 +265,7 @@ def run_tt(model_flavor, checkpoint_path, tt_inputs, device): # Replace FlexAttention with SDPA for single-process inference # (unfused FlexAttention without torch.compile has poor fp16 numerics). for layer in model.layers.values(): - if layer.layer_type == "full_attention": + if layer.full_attn: layer.attn.inner_attention = ScaledDotProductAttention.Config().build() class _BidirectionalSDPA(torch.nn.Module): diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py index d63ea09a54..f9443f109c 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py @@ -7,8 +7,8 @@ """Numerical comparison across parallelism configs for Qwen3.5. -Feeds identical fake tokens across 4 configs and verifies logits match. -Requires 8 GPUs. +Feeds identical fake tokens across configs (no_parallel, FSDP, FSDP+EP, +FSDP+EP+TP) and verifies logits match. Requires up to 8 GPUs. Usage: python scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py @@ -36,11 +36,10 @@ from torchtitan.models.qwen3_5.parallelize import parallelize_qwen3_5 CONFIGS = [ - {"ngpu": 1, "tp": 1, "ep": 1, "cp": 1, "label": "no_parallel"}, - {"ngpu": 4, "tp": 1, "ep": 1, "cp": 1, "label": "fsdp"}, - {"ngpu": 8, "tp": 1, "ep": 4, "cp": 1, "label": "fsdp+ep"}, - {"ngpu": 8, "tp": 2, "ep": 2, "cp": 1, "label": "fsdp+ep+tp"}, - {"ngpu": 8, "tp": 2, "ep": 2, "cp": 2, "label": "fsdp+ep+tp+cp"}, + {"ngpu": 1, "tp": 1, "ep": 1, "label": "no_parallel"}, + {"ngpu": 4, "tp": 1, "ep": 1, "label": "fsdp"}, + {"ngpu": 8, "tp": 1, "ep": 4, "label": "fsdp+ep"}, + {"ngpu": 8, "tp": 2, "ep": 2, "label": "fsdp+ep+tp"}, ] @@ -51,7 +50,7 @@ def run_worker(args): world_size = dist.get_world_size() torch.cuda.set_device(rank) - dp_shard = world_size // (args.tp * args.cp) + dp_shard = world_size // args.tp seed = 42 torch.manual_seed(seed) @@ -65,7 +64,7 @@ def run_worker(args): parallel_dims = ParallelDims( dp_shard=dp_shard, dp_replicate=1, - cp=args.cp, + cp=1, tp=args.tp, pp=1, ep=args.ep, @@ -76,7 +75,6 @@ def run_worker(args): parallelism = ParallelismConfig( tensor_parallel_degree=args.tp, data_parallel_shard_degree=dp_shard, - context_parallel_degree=args.cp, expert_parallel_degree=args.ep, ) training = TrainingConfig( @@ -88,7 +86,7 @@ def run_worker(args): ) config.update_from_config( - trainer_config=type( + config=type( "C", (), { @@ -118,35 +116,15 @@ def run_worker(args): tokens = torch.randint(0, 248320, (1, seq_len), device="cuda") dist.broadcast(tokens, src=0) - extra_kwargs: dict = {} - if args.cp > 1: - # Shard tokens and create positions for CP - from torchtitan.distributed.context_parallel import ( - prepare_context_parallel_input, - ) - - positions = torch.arange(seq_len, device="cuda").unsqueeze(0) - labels = tokens.clone() - extra_kwargs = {"positions": positions} - tokens, labels, extra_kwargs = prepare_context_parallel_input( - tokens, - labels, - extra_kwargs, - parallel_dims.get_mesh("cp"), - torch.device("cuda"), - ) - with torch.no_grad(): output = model( tokens, special_tokens={"image_id": 151859, "video_id": 151860}, - **extra_kwargs, ) if isinstance(output, DTensor): output = output.full_tensor() - # With CP, each rank has partial output — gather first token's logits from rank 0's CP portion logits = output[0, 0, :10].float().tolist() if rank == 0: @@ -172,12 +150,11 @@ def main(): "--worker", f"--tp={cfg['tp']}", f"--ep={cfg['ep']}", - f"--cp={cfg['cp']}", f"--output={outfile}", ] print( f"Running {cfg['label']} (ngpu={cfg['ngpu']}, " - f"tp={cfg['tp']}, ep={cfg['ep']}, cp={cfg['cp']})..." + f"tp={cfg['tp']}, ep={cfg['ep']})..." ) result = subprocess.run( cmd, @@ -216,7 +193,6 @@ def main(): parser.add_argument("--worker", action="store_true") parser.add_argument("--tp", type=int) parser.add_argument("--ep", type=int) - parser.add_argument("--cp", type=int, default=1) parser.add_argument("--output", type=str) args = parser.parse_args() diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 2c0078195f..36365d4f6f 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -143,14 +143,15 @@ def build_model_tests_list() -> list[OverrideDefinitions]: [ [ "--module qwen3_5 --config qwen35_debugmodel_moe", - "--parallelism.data_parallel_shard_degree 4", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.pipeline_parallel_degree 2", "--parallelism.tensor_parallel_degree 2", "--parallelism.expert_parallel_degree 4", ], ], - "Qwen3.5 MoE FSDP+TP+EP", - "qwen3_5_moe_fsdp+tp+ep", - ngpu=8, + "Qwen3.5 MoE FSDP+TP+EP+PP", + "qwen3_5_moe_fsdp+tp+ep+pp", + ngpu=32, ), # Integration Test Cases for gpt-oss # TODO: re-enable compile after fixing diff --git a/torchtitan/models/common/config_utils.py b/torchtitan/models/common/config_utils.py index a8f320ea77..cf4e40cad9 100644 --- a/torchtitan/models/common/config_utils.py +++ b/torchtitan/models/common/config_utils.py @@ -157,6 +157,7 @@ def make_moe_config( router: TokenChoiceTopKRouter.Config, experts: GroupedExperts.Config, shared_experts: FeedForward.Config | None = None, + shared_expert_gate: Module.Config | None = None, load_balance_coeff: float | None = 1e-3, ) -> MoE.Config: """Build a fully-specified MoE.Config.""" @@ -166,6 +167,7 @@ def make_moe_config( router=router, experts=experts, shared_experts=shared_experts, + shared_expert_gate=shared_expert_gate, ) diff --git a/torchtitan/models/common/decoder.py b/torchtitan/models/common/decoder.py index b0ff896e84..7985c19ffe 100644 --- a/torchtitan/models/common/decoder.py +++ b/torchtitan/models/common/decoder.py @@ -116,7 +116,18 @@ def update_from_config( tp = parallelism.tensor_parallel_degree if tp > 1: - attention = self.layers[0].attention + attention = next( + ( + l.attention + for l in self.layers + if getattr(l, "attention", None) is not None + ), + None, + ) + if attention is None: + raise ValueError( + "No layer with attention config found for TP validation." + ) n_heads = attention.n_heads n_kv_heads = getattr(attention, "n_kv_heads", None) or n_heads if n_heads % tp != 0: diff --git a/torchtitan/models/common/moe.py b/torchtitan/models/common/moe.py index bb72f31995..cdddd4ce9f 100644 --- a/torchtitan/models/common/moe.py +++ b/torchtitan/models/common/moe.py @@ -323,6 +323,7 @@ class Config(Module.Config): router: TokenChoiceTopKRouter.Config load_balance_coeff: float | None = 1e-3 shared_experts: FeedForward.Config | None = None + shared_expert_gate: Module.Config | None = None def __init__(self, config: Config): super().__init__() @@ -333,6 +334,11 @@ def __init__(self, config: Config): self.shared_experts = ( config.shared_experts.build() if config.shared_experts is not None else None ) + self.shared_expert_gate = ( + config.shared_expert_gate.build() + if config.shared_expert_gate is not None + else None + ) # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) # NOTE: tokens_per_expert_E is accumulated in the model forward pass. @@ -447,6 +453,10 @@ def _generate_routing_map( sync_combine() if shared_out_BLD is not None: + if self.shared_expert_gate is not None: + shared_out_BLD = ( + torch.sigmoid(self.shared_expert_gate(x_BLD)) * shared_out_BLD + ) out_BLD = out_BLD + shared_out_BLD return out_BLD diff --git a/torchtitan/models/common/moe_sharding.py b/torchtitan/models/common/moe_sharding.py index ebe0da6452..8953d2436d 100644 --- a/torchtitan/models/common/moe_sharding.py +++ b/torchtitan/models/common/moe_sharding.py @@ -241,6 +241,14 @@ def set_moe_sharding_config( enable_ep=enable_ep, enable_sp=enable_sp ) + if getattr(moe_cfg, "shared_expert_gate", None) is not None: + moe_cfg.shared_expert_gate.sharding_config = ShardingConfig( + state_shardings={ + "weight": dense_param_placement(tp=Replicate()), + "bias": dense_param_placement(tp=Replicate()), + } + ) + # Routed experts: local_map converts DTensor inputs to local for # dispatch/compute/combine, then wraps local output as DTensor(Partial). # Routed experts: the three things that differ between EP and TP-only diff --git a/torchtitan/models/flux/trainer.py b/torchtitan/models/flux/trainer.py index ec2e44a6d9..7dd5c4ebe7 100644 --- a/torchtitan/models/flux/trainer.py +++ b/torchtitan/models/flux/trainer.py @@ -287,7 +287,6 @@ def train_step( if self.gradient_accumulation_steps > 1: raise ValueError("FLUX doesn't support gradient accumulation for now.") - # pyrefly: ignore [no-matching-overload] input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict=input_dict, labels=labels) diff --git a/torchtitan/models/qwen3_5/README.md b/torchtitan/models/qwen3_5/README.md index 9ae6f58fc3..aed8135752 100644 --- a/torchtitan/models/qwen3_5/README.md +++ b/torchtitan/models/qwen3_5/README.md @@ -66,10 +66,9 @@ pip install flash-linear-attention | Feature | Notes | |---------|-------| -| FSDP / HSDP | Single `apply_fsdp` call covers both decoder and vision encoder | +| FSDP / HSDP | Decoder sharded per-layer; vision encoder sharded as a single unit (one AllGather) | | Tensor Parallelism (TP) | With Sequence Parallel; head-sharded TP on GatedDeltaNet projections | | Expert Parallelism (EP) | For MoE variants | -| Context Parallel (CP) | Text-only; full attention uses ring attention, GatedDeltaNet allgathers full sequence | | Pipeline Parallel (PP) | Vision encoder assigned to first stage; 1F1B and Interleaved1F1B schedules | | Sample Packing | Configurable via `packing_buffer_size` in dataloader config | @@ -77,7 +76,7 @@ pip install flash-linear-attention End-to-end KL divergence against HuggingFace Transformers (4B, multimodal inputs): **~3e-7** average, with **100% top-1 and top-5 match**. -Parallelism correctness: bitwise identical logits across no-parallel, FSDP, FSDP+EP, FSDP+EP+TP, and FSDP+EP+TP+CP configs. +Parallelism correctness: bit-identical logits (max diff `0.0`) across no-parallel, FSDP, FSDP+EP, and FSDP+EP+TP configs. Test scripts: - `scripts/checkpoint_conversion/numerical_tests_qwen3_5.py` — HF vs TT comparison @@ -86,3 +85,4 @@ Test scripts: ## TODO - Add video dataset training configs +- Add Context Parallel (CP) support diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py index 42715084f1..dba06463f4 100644 --- a/torchtitan/models/qwen3_5/__init__.py +++ b/torchtitan/models/qwen3_5/__init__.py @@ -6,11 +6,11 @@ from collections.abc import Callable from functools import partial +from typing import Literal import torch.nn as nn from torchtitan.components.optimizer import register_moe_load_balancing_hook -from torchtitan.components.quantization import QuantizationConverter from torchtitan.models.common import Embedding, Linear, RoPE # noqa: F401 from torchtitan.models.common.config_utils import ( @@ -20,20 +20,32 @@ make_moe_config, make_router_config, ) +from torchtitan.models.common.nn_modules import LayerNorm from torchtitan.models.common.param_init import depth_scaled_std # noqa: F401 +from torchtitan.models.utils import validate_converter_order +from torchtitan.protocols.model import ModelConfigConverter from torchtitan.protocols.model_spec import ModelSpec from .model import ( + GatedDeltaKernel, GatedDeltaNet, OffsetRMSNorm, Qwen35Attention, Qwen35Model, Qwen35TransformerBlock, + RMSNormGated, ) from .parallelize import parallelize_qwen3_5, pipeline_qwen3_5 from .state_dict_adapter import Qwen35StateDictAdapter -from .vision_encoder import Qwen35VisionEncoder +from .vision_encoder import ( + PatchMerger, + Qwen35VisionEncoder, + VisionAttention, + VisionMLP, + VisionRotaryEmbedding, + VisionTransformerBlock, +) __all__ = [ "parallelize_qwen3_5", @@ -79,9 +91,9 @@ def _depth_init(layer_id: int) -> dict[str, Callable]: def _depth_experts_init(layer_id: int) -> dict[str, Callable]: return { - "w1": partial(nn.init.trunc_normal_, std=0.02), - "w2": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), - "w3": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), + "w1_EFD": partial(nn.init.trunc_normal_, std=0.02), + "w2_EDF": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), + "w3_EFD": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)), } @@ -102,7 +114,7 @@ def _offset_norm(dim: int) -> OffsetRMSNorm.Config: return OffsetRMSNorm.Config(dim=dim, eps=_EPS, param_init=_OFFSET_NORM_INIT) -def _vision_encoder_config( +def _qwen35_vision_encoder_config( *, dim: int, ffn_dim: int, @@ -113,30 +125,51 @@ def _vision_encoder_config( spatial_merge_size: int, out_hidden_size: int, num_position_embeddings: int, + layer_norm_eps: float = 1e-6, + rope_theta: float = 10000.0, in_channels: int = 3, ) -> Qwen35VisionEncoder.Config: """Build a fully-specified Qwen35VisionEncoder.Config.""" patch_dim = in_channels * temporal_patch_size * patch_size * patch_size merged_hidden_size = dim * (spatial_merge_size**2) + head_dim = dim // num_heads + _norm = LayerNorm.Config(normalized_shape=dim, eps=layer_norm_eps) return Qwen35VisionEncoder.Config( dim=dim, - ffn_dim=ffn_dim, num_layers=num_layers, num_heads=num_heads, patch_size=patch_size, temporal_patch_size=temporal_patch_size, + in_channels=in_channels, spatial_merge_size=spatial_merge_size, - out_hidden_size=out_hidden_size, num_position_embeddings=num_position_embeddings, patch_embed_proj=_linear(patch_dim, dim), - attn_wq=_linear(dim, dim), - attn_wk=_linear(dim, dim), - attn_wv=_linear(dim, dim), - attn_proj=_linear(dim, dim), - mlp_fc1=_linear(dim, ffn_dim), - mlp_fc2=_linear(ffn_dim, dim), - merger_fc1=_linear(merged_hidden_size, merged_hidden_size), - merger_fc2=_linear(merged_hidden_size, out_hidden_size), + block=VisionTransformerBlock.Config( + norm1=_norm, + norm2=_norm, + attn=VisionAttention.Config( + dim=dim, + num_heads=num_heads, + wq=_linear(dim, dim), + wk=_linear(dim, dim), + wv=_linear(dim, dim), + proj=_linear(dim, dim), + ), + mlp=VisionMLP.Config( + fc1=_linear(dim, ffn_dim), + fc2=_linear(ffn_dim, dim), + ), + ), + rotary_pos_emb=VisionRotaryEmbedding.Config( + dim=head_dim // 2, theta=rope_theta + ), + merger=PatchMerger.Config( + spatial_merge_size=spatial_merge_size, + merged_hidden_size=merged_hidden_size, + norm=LayerNorm.Config(normalized_shape=dim, eps=layer_norm_eps), + fc1=_linear(merged_hidden_size, merged_hidden_size), + fc2=_linear(merged_hidden_size, out_hidden_size), + ), param_init=_POS_EMBED_INIT, ) @@ -152,7 +185,7 @@ def _qwen35_attention_config( layer_id: int, ) -> Qwen35Attention.Config: """Build a fully-specified Qwen35Attention.Config.""" - inner_attention, mask_type = get_qwen35_attention_config(attn_backend) + inner_attention, mask_type = get_attention_config(attn_backend) return Qwen35Attention.Config( n_heads=n_heads, n_kv_heads=n_kv_heads, @@ -193,19 +226,35 @@ def _qwen35_deltanet_config( key_head_dim: int, value_head_dim: int, layer_id: int, - fla_backend: str = "fla_chunked", + fla_backend: Literal[ + "fla_chunked", "fla_fused_recurrent", "torch_naive" + ] = "fla_chunked", ) -> GatedDeltaNet.Config: """Build a fully-specified GatedDeltaNet.Config.""" + key_dim = n_key_heads * key_head_dim + value_dim = n_value_heads * value_head_dim + + def _proj(in_f: int, out_f: int, init: dict) -> Linear.Config: + return Linear.Config( + in_features=in_f, out_features=out_f, bias=False, param_init=init + ) + return GatedDeltaNet.Config( - dim=dim, - n_key_heads=n_key_heads, - n_value_heads=n_value_heads, key_head_dim=key_head_dim, value_head_dim=value_head_dim, - fla_backend=fla_backend, - in_proj_init=_LINEAR_INIT, - out_proj_init=_depth_init(layer_id), - norm_init={"weight": nn.init.ones_}, + in_proj_q=_proj(dim, key_dim, _LINEAR_INIT), + in_proj_k=_proj(dim, key_dim, _LINEAR_INIT), + in_proj_v=_proj(dim, value_dim, _LINEAR_INIT), + in_proj_z=_proj(dim, value_dim, _LINEAR_INIT), + in_proj_a=_proj(dim, n_value_heads, _LINEAR_INIT), + in_proj_b=_proj(dim, n_value_heads, _LINEAR_INIT), + kernel=GatedDeltaKernel.Config(backend=fla_backend), + norm=RMSNormGated.Config( + dim=value_head_dim, + eps=1e-6, + param_init={"weight": nn.init.ones_}, + ), + out_proj=_proj(value_dim, dim, _depth_init(layer_id)), param_init={ "A_log": _a_log_init, "dt_bias": nn.init.ones_, @@ -228,24 +277,14 @@ def _build_qwen35_layers( value_head_dim: int, full_attention_interval: int = 4, attn_backend: str, - fla_backend: str = "fla_chunked", + fla_backend: Literal[ + "fla_chunked", "fla_fused_recurrent", "torch_naive" + ] = "fla_chunked", ) -> list[Qwen35TransformerBlock.Config]: """Build per-layer configs for dense Qwen3.5 models.""" - # Shared attention config — set on ALL layers so the trainer can read - # attn_config.inner_attention and mask_type from any layer. - shared_attn_config = _qwen35_attention_config( - dim=dim, - n_heads=n_heads, - n_kv_heads=n_kv_heads, - head_dim=head_dim, - rotary_dim=rotary_dim, - attn_backend=attn_backend, - layer_id=0, - ) layers = [] for layer_id in range(n_layers): is_full = (layer_id + 1) % full_attention_interval == 0 - layer_type = "full_attn" if is_full else "linear_attn" attention = ( _qwen35_attention_config( @@ -258,7 +297,7 @@ def _build_qwen35_layers( layer_id=layer_id, ) if is_full - else shared_attn_config + else None ) deltanet = ( _qwen35_deltanet_config( @@ -276,9 +315,8 @@ def _build_qwen35_layers( layers.append( Qwen35TransformerBlock.Config( - layer_type=layer_type, attention=attention, - deltanet=deltanet, + delta_net=deltanet, feed_forward=make_ffn_config( dim=dim, hidden_dim=hidden_dim, @@ -310,24 +348,16 @@ def _build_qwen35_moe_layers( value_head_dim: int, full_attention_interval: int = 4, attn_backend: str, - fla_backend: str = "fla_chunked", + fla_backend: Literal[ + "fla_chunked", "fla_fused_recurrent", "torch_naive" + ] = "fla_chunked", moe_comm_backend: str = "standard", non_blocking_capacity_factor: float | None = None, ) -> list[Qwen35TransformerBlock.Config]: """Build per-layer configs for MoE Qwen3.5 models with shared expert.""" - shared_attn_config = _qwen35_attention_config( - dim=dim, - n_heads=n_heads, - n_kv_heads=n_kv_heads, - head_dim=head_dim, - rotary_dim=rotary_dim, - attn_backend=attn_backend, - layer_id=0, - ) layers = [] for layer_id in range(n_layers): is_full = (layer_id + 1) % full_attention_interval == 0 - layer_type = "full_attn" if is_full else "linear_attn" attention = ( _qwen35_attention_config( @@ -340,7 +370,7 @@ def _build_qwen35_moe_layers( layer_id=layer_id, ) if is_full - else shared_attn_config + else None ) deltanet = ( _qwen35_deltanet_config( @@ -358,9 +388,8 @@ def _build_qwen35_moe_layers( layers.append( Qwen35TransformerBlock.Config( - layer_type=layer_type, attention=attention, - deltanet=deltanet, + delta_net=deltanet, moe=make_moe_config( num_experts=num_experts, router=make_router_config( @@ -381,17 +410,17 @@ def _build_qwen35_moe_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), - ), - shared_ffn=make_ffn_config( - dim=dim, - hidden_dim=shared_expert_hidden_dim, - w1_param_init=_LINEAR_INIT, - w2w3_param_init=_depth_init(layer_id), - ), - shared_gate=Linear.Config( - in_features=dim, - out_features=1, - param_init=_LINEAR_INIT, + shared_experts=make_ffn_config( + dim=dim, + hidden_dim=shared_expert_hidden_dim, + w1_param_init=_LINEAR_INIT, + w2w3_param_init=_depth_init(layer_id), + ), + shared_expert_gate=Linear.Config( + in_features=dim, + out_features=1, + param_init=_LINEAR_INIT, + ), ), attention_norm=_offset_norm(dim), ffn_norm=_offset_norm(dim), @@ -445,7 +474,7 @@ def _debugmodel(attn_backend: str) -> Qwen35Model.Config: value_head_dim=64, fla_backend="fla_chunked", ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=256, ffn_dim=512, num_layers=4, @@ -510,7 +539,7 @@ def _debugmodel_moe( moe_comm_backend=moe_comm_backend, fla_backend="fla_chunked", ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=256, ffn_dim=512, num_layers=2, @@ -572,7 +601,7 @@ def _0_8b(attn_backend: str) -> Qwen35Model.Config: key_head_dim=128, value_head_dim=128, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=768, ffn_dim=3072, num_layers=12, @@ -634,7 +663,7 @@ def _2b(attn_backend: str) -> Qwen35Model.Config: key_head_dim=128, value_head_dim=128, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1024, ffn_dim=4096, num_layers=24, @@ -695,7 +724,7 @@ def _4b(attn_backend: str) -> Qwen35Model.Config: key_head_dim=128, value_head_dim=128, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1024, ffn_dim=4096, num_layers=24, @@ -752,7 +781,7 @@ def _9b(attn_backend: str) -> Qwen35Model.Config: key_head_dim=128, value_head_dim=128, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1152, ffn_dim=4304, num_layers=27, @@ -809,7 +838,7 @@ def _27b(attn_backend: str) -> Qwen35Model.Config: key_head_dim=128, value_head_dim=128, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1152, ffn_dim=4304, num_layers=27, @@ -873,7 +902,7 @@ def _35b_a3b( value_head_dim=128, moe_comm_backend=moe_comm_backend, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1152, ffn_dim=4304, num_layers=27, @@ -937,7 +966,7 @@ def _122b_a10b( value_head_dim=128, moe_comm_backend=moe_comm_backend, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1152, ffn_dim=4304, num_layers=27, @@ -1001,7 +1030,7 @@ def _397b_a17b( value_head_dim=128, moe_comm_backend=moe_comm_backend, ), - vision_encoder=_vision_encoder_config( + vision_encoder=_qwen35_vision_encoder_config( dim=1152, ffn_dim=4304, num_layers=27, @@ -1017,16 +1046,16 @@ def _397b_a17b( qwen3_5_configs = { - "debugmodel": _debugmodel_qwen35, - "debugmodel_moe": _debugmodel_moe_qwen35, - "0.8B": _0_8b_qwen35, - "2B": _2b_qwen35, - "4B": _4b_qwen35, - "9B": _9b_qwen35, - "27B": _27b_qwen35, - "35B-A3B": _35b_a3b_qwen35, - "122B-A10B": _122b_a10b_qwen35, - "397B-A17B": _397b_a17b_qwen35, + "debugmodel": _debugmodel, + "debugmodel_moe": _debugmodel_moe, + "0.8B": _0_8b, + "2B": _2b, + "4B": _4b, + "9B": _9b, + "27B": _27b, + "35B-A3B": _35b_a3b, + "122B-A10B": _122b_a10b, + "397B-A17B": _397b_a17b, } @@ -1034,15 +1063,16 @@ def model_registry( flavor: str, attn_backend: str = "sdpa", moe_comm_backend: str | None = None, - quantization: list[QuantizationConverter.Config] | None = None, + converters: list[ModelConfigConverter.Config] | None = None, ) -> ModelSpec: kwargs = dict(attn_backend=attn_backend) if moe_comm_backend is not None: kwargs["moe_comm_backend"] = moe_comm_backend config = qwen3_5_configs[flavor](**kwargs) - if quantization is not None: - for q in quantization: - q.build().convert(config) + if converters is not None: + validate_converter_order(converters) + for c in converters: + c.build().convert(config) # Detect MoE: check if any layer has moe config has_moe = any(getattr(layer, "moe", None) is not None for layer in config.layers) diff --git a/torchtitan/models/qwen3_5/config_registry.py b/torchtitan/models/qwen3_5/config_registry.py index 1af8f0c659..3bcb1c890c 100644 --- a/torchtitan/models/qwen3_5/config_registry.py +++ b/torchtitan/models/qwen3_5/config_registry.py @@ -78,12 +78,13 @@ def qwen35_debugmodel_moe() -> Trainer.Config: optimizer=OptimizersContainer.Config(lr=5e-3), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=2), training=TrainingConfig( - local_batch_size=1, + local_batch_size=2, seq_len=512, steps=10, ), parallelism=ParallelismConfig( - data_parallel_shard_degree=4, + data_parallel_shard_degree=2, + pipeline_parallel_degree=2, expert_parallel_degree=4, tensor_parallel_degree=2, ), diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py index ca59e8bd9b..033f58cec7 100644 --- a/torchtitan/models/qwen3_5/model.py +++ b/torchtitan/models/qwen3_5/model.py @@ -5,25 +5,27 @@ # LICENSE file in the root directory of this source tree. -import dataclasses from dataclasses import dataclass, field +from typing import Literal import torch import torch.nn.functional as F from torch import nn +from torch.distributed.tensor import DTensor +from torchtitan.models.common import Linear from torchtitan.models.common.attention import AttentionMasksType, BaseAttention from torchtitan.models.common.decoder import Decoder -from torchtitan.models.common.linear import Linear from torchtitan.models.common.rope import apply_rotary_emb_cos_sin from torchtitan.models.utils import get_moe_model_nparams_and_flops from torchtitan.protocols.module import Module -from torchtitan.tools.logging import logger from .sharding import set_qwen35_sharding_config from .vision_encoder import Qwen35VisionEncoder -_Conv1d = Module.from_nn_module(nn.Conv1d) + +class _Conv1d(nn.Conv1d, Module): + pass try: @@ -156,7 +158,9 @@ class Config(Module.Config): # "fla_chunked": parallel within chunks, fast for training (default) # "fla_fused_recurrent": token-by-token, lower memory for long sequences # "torch_naive": pure-Python reference, for numerical testing only - backend: str = "fla_chunked" + backend: Literal[ + "fla_chunked", "fla_fused_recurrent", "torch_naive" + ] = "fla_chunked" def __init__(self, config: Config): super().__init__() @@ -188,12 +192,12 @@ def forward( if self.backend == "fla_chunked": result = _fla_chunk_gated_delta_rule( - q, # pyrefly: ignore [bad-argument-type] - k, # pyrefly: ignore [bad-argument-count] + q, + k, v, g, beta, - use_qk_l2norm_in_kernel=True, # pyrefly: ignore [unexpected-keyword] + use_qk_l2norm_in_kernel=True, ) elif self.backend == "fla_fused_recurrent": result = _fla_fused_recurrent_gated_delta_rule( @@ -211,7 +215,6 @@ def forward( ) # FLA kernels return (output, final_state); we only need output - # pyrefly: ignore [unsupported-operation] return result[0] @@ -225,17 +228,20 @@ class GatedDeltaNet(Module): @dataclass(kw_only=True, slots=True) class Config(Module.Config): - dim: int - n_key_heads: int - n_value_heads: int key_head_dim: int value_head_dim: int conv_kernel_size: int = 4 - norm_eps: float = 1e-6 - fla_backend: str = "fla_chunked" - in_proj_init: dict - out_proj_init: dict - norm_init: dict + + # Sub-module configs + in_proj_q: Linear.Config + in_proj_k: Linear.Config + in_proj_v: Linear.Config + in_proj_z: Linear.Config + in_proj_a: Linear.Config + in_proj_b: Linear.Config + kernel: GatedDeltaKernel.Config + norm: RMSNormGated.Config + out_proj: Linear.Config def __init__(self, config: Config): super().__init__() @@ -243,45 +249,15 @@ def __init__(self, config: Config): self.value_head_dim = config.value_head_dim self.conv_kernel_size = config.conv_kernel_size - key_dim = config.n_key_heads * config.key_head_dim - value_dim = config.n_value_heads * config.value_head_dim + key_dim = config.in_proj_q.out_features + value_dim = config.in_proj_v.out_features - self.in_proj_q = Linear.Config( - in_features=config.dim, - out_features=key_dim, - bias=False, - param_init=config.in_proj_init, - ).build() - self.in_proj_k = Linear.Config( - in_features=config.dim, - out_features=key_dim, - bias=False, - param_init=config.in_proj_init, - ).build() - self.in_proj_v = Linear.Config( - in_features=config.dim, - out_features=value_dim, - bias=False, - param_init=config.in_proj_init, - ).build() - self.in_proj_z = Linear.Config( - in_features=config.dim, - out_features=value_dim, - bias=False, - param_init=config.in_proj_init, - ).build() - self.in_proj_a = Linear.Config( - in_features=config.dim, - out_features=config.n_value_heads, - bias=False, - param_init=config.in_proj_init, - ).build() - self.in_proj_b = Linear.Config( - in_features=config.dim, - out_features=config.n_value_heads, - bias=False, - param_init=config.in_proj_init, - ).build() + self.in_proj_q = config.in_proj_q.build() + self.in_proj_k = config.in_proj_k.build() + self.in_proj_v = config.in_proj_v.build() + self.in_proj_z = config.in_proj_z.build() + self.in_proj_a = config.in_proj_a.build() + self.in_proj_b = config.in_proj_b.build() self.conv_q = _Conv1d( in_channels=key_dim, @@ -308,27 +284,38 @@ def __init__(self, config: Config): padding=0, ) - self.A_log = nn.Parameter(torch.empty(config.n_value_heads)) - self.dt_bias = nn.Parameter(torch.empty(config.n_value_heads)) - - self.kernel = GatedDeltaKernel.Config(backend=config.fla_backend).build() + n_value_heads = value_dim // config.value_head_dim + self.A_log = nn.Parameter(torch.empty(n_value_heads)) + self.dt_bias = nn.Parameter(torch.empty(n_value_heads)) - self.norm = RMSNormGated.Config( - dim=config.value_head_dim, - eps=config.norm_eps, - param_init=config.norm_init, - ).build() - self.out_proj = Linear.Config( - in_features=value_dim, - out_features=config.dim, - bias=False, - param_init=config.out_proj_init, - ).build() + self.kernel = config.kernel.build() + self.norm = config.norm.build() + self.out_proj = config.out_proj.build() def _causal_conv(self, x: torch.Tensor, conv: nn.Module) -> torch.Tensor: # pyrefly: ignore [bad-argument-type] x = F.pad(x.transpose(1, 2), (self.conv_kernel_size - 1, 0)) - return F.silu(conv(x)).transpose(1, 2) + if isinstance(x, DTensor): + # TODO: Remove once DTensor Conv1d dispatch handles sharded groups. + mesh, plc = x.device_mesh, x.placements + w: torch.Tensor = conv.weight # pyrefly: ignore [bad-assignment] + if isinstance(w, DTensor): + w = w.to_local() + local_groups = w.size(0) + # pyrefly: ignore [no-matching-overload] + out = F.conv1d( + x.to_local(), + w, + None, + conv.stride, + conv.padding, + conv.dilation, + local_groups, + ) + x = DTensor.from_local(out, mesh, plc, run_check=False) + else: + x = conv(x) + return F.silu(x).transpose(1, 2) def forward(self, x: torch.Tensor) -> torch.Tensor: bs, seqlen, _ = x.shape @@ -450,35 +437,30 @@ class Qwen35TransformerBlock(Module): """Hybrid transformer block for Qwen3.5. Each layer uses either full attention (Qwen35Attention) or linear - attention (GatedDeltaNet), determined by ``layer_type`` in config. + attention (GatedDeltaNet), determined by which config is provided. Both types share the same FFN/MoE structure. """ @dataclass(kw_only=True, slots=True) class Config(Module.Config): - layer_type: str # "full_attn" or "linear_attn" attention: Qwen35Attention.Config | None = None - deltanet: GatedDeltaNet.Config | None = None + delta_net: GatedDeltaNet.Config | None = None feed_forward: Module.Config | None = None moe: Module.Config | None = None - shared_ffn: Module.Config | None = None - shared_gate: Linear.Config | None = None attention_norm: OffsetRMSNorm.Config ffn_norm: OffsetRMSNorm.Config def __init__(self, config: Config): super().__init__() - self.layer_type = config.layer_type + self.full_attn = config.attention is not None - if config.layer_type == "full_attn": - assert config.attention is not None - self.attn = config.attention.build() + if self.full_attn: + self.attn = config.attention.build() # pyrefly: ignore [missing-attribute] else: - assert config.deltanet is not None - self.attn = config.deltanet.build() + assert config.delta_net is not None + self.attn = config.delta_net.build() self.moe_enabled = config.moe is not None - self.shared_expert_enabled = config.shared_ffn is not None if self.moe_enabled: # pyrefly: ignore [missing-attribute] self.moe = config.moe.build() @@ -486,12 +468,6 @@ def __init__(self, config: Config): assert config.feed_forward is not None self.feed_forward = config.feed_forward.build() - if self.shared_expert_enabled: - # pyrefly: ignore [missing-attribute] - self.shared_ffn = config.shared_ffn.build() - assert config.shared_gate is not None - self.shared_gate = config.shared_gate.build() - self.attention_norm = config.attention_norm.build() self.ffn_norm = config.ffn_norm.build() @@ -503,7 +479,7 @@ def forward( positions: torch.Tensor | None = None, ) -> torch.Tensor: h = self.attention_norm(x) - if self.layer_type == "full_attn": + if self.full_attn: h = self.attn(h, freqs_cis, attention_masks, positions) else: h = self.attn(h) @@ -511,12 +487,7 @@ def forward( h = self.ffn_norm(x) if self.moe_enabled: - moe_out = self.moe(h) - if self.shared_expert_enabled: - shared_out = torch.sigmoid(self.shared_gate(h)) * self.shared_ffn(h) - x = x + moe_out + shared_out - else: - x = x + moe_out + x = x + self.moe(h) else: x = x + self.feed_forward(h) return x @@ -567,26 +538,11 @@ class Config(Decoder.Config): def update_from_config( self, *, - trainer_config, + config, **kwargs, ) -> None: - training = trainer_config.training - parallelism = trainer_config.parallelism - debug = trainer_config.debug - seq_len = training.seq_len - if seq_len > self.rope.max_seq_len: - logger.warning( - f"Sequence length {seq_len} exceeds original maximum " - f"{self.rope.max_seq_len}." - ) - self.rope = dataclasses.replace(self.rope, max_seq_len=seq_len) - - for layer_cfg in self.layers: - moe_cfg = getattr(layer_cfg, "moe", None) - if moe_cfg is not None: - moe_cfg.router._debug_force_load_balance = ( - debug.moe_force_load_balance - ) + Decoder.Config.update_from_config(self, config=config, **kwargs) + parallelism = config.parallelism tp = parallelism.tensor_parallel_degree if tp > 1: @@ -603,21 +559,25 @@ def update_from_config( f"n_kv_heads ({attn_cfg.n_kv_heads})." ) dn_cfg = next( - (l.deltanet for l in self.layers if l.deltanet is not None), + (l.delta_net for l in self.layers if l.delta_net is not None), None, ) - if dn_cfg is not None and ( - dn_cfg.n_key_heads % tp != 0 or dn_cfg.n_value_heads % tp != 0 - ): - raise ValueError( - f"tensor_parallel_degree ({tp}) must divide " - f"n_key_heads ({dn_cfg.n_key_heads}) and " - f"n_value_heads ({dn_cfg.n_value_heads})." + if dn_cfg is not None: + n_key_heads = dn_cfg.in_proj_q.out_features // dn_cfg.key_head_dim + n_value_heads = ( + dn_cfg.in_proj_v.out_features // dn_cfg.value_head_dim ) + if n_key_heads % tp != 0 or n_value_heads % tp != 0: + raise ValueError( + f"tensor_parallel_degree ({tp}) must divide " + f"n_key_heads ({n_key_heads}) and " + f"n_value_heads ({n_value_heads})." + ) set_qwen35_sharding_config( self, loss_parallel=not parallelism.disable_loss_parallel, + enable_ep=parallelism.expert_parallel_degree > 1, ) def get_nparams_and_flops( @@ -631,7 +591,7 @@ def get_nparams_and_flops( n_heads = attn_cfg.n_heads # pyrefly: ignore [missing-attribute] head_dim = attn_cfg.head_dim - num_full_attn = sum(1 for l in self.layers if l.layer_type == "full_attn") + num_full_attn = sum(1 for l in self.layers if l.attention is not None) return get_moe_model_nparams_and_flops( self, model, @@ -766,9 +726,9 @@ def _compute_mrope_freqs( video_index += 1 llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, + int(t.item()), + int(h.item()) // spatial_merge_size, + int(w.item()) // spatial_merge_size, ) text_len = vision_start - pair_cursor @@ -786,34 +746,28 @@ def _compute_mrope_freqs( # [vision tokens] — 3D grid positions (T, H, W) grid_key = (llm_grid_t, llm_grid_h, llm_grid_w) if grid_key not in grid_cache: + hw = llm_grid_h * llm_grid_w t_index = ( torch.arange(llm_grid_t) .view(-1, 1) - # pyrefly: ignore [no-matching-overload] - .expand(-1, llm_grid_h * llm_grid_w) + .expand(-1, hw) .flatten() ) h_index = ( torch.arange(llm_grid_h) .view(1, -1, 1) - # pyrefly: ignore [no-matching-overload] .expand(llm_grid_t, -1, llm_grid_w) .flatten() ) w_index = ( torch.arange(llm_grid_w) .view(1, 1, -1) - # pyrefly: ignore [no-matching-overload] .expand(llm_grid_t, llm_grid_h, -1) .flatten() ) - # pyrefly: ignore [unsupported-operation] grid_cache[grid_key] = torch.stack([t_index, h_index, w_index]) doc_pos_ids_list.append( - # pyrefly: ignore [bad-index] - grid_cache[grid_key] - + text_len - + pos_id_offset + grid_cache[grid_key] + text_len + pos_id_offset ) pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w @@ -835,7 +789,11 @@ def _compute_mrope_freqs( position_ids[:, sample_i, :] = llm_positions.to(position_ids.device) # --- Compute interleaved MRoPE cos/sin from position IDs --- + # Convert to local — DTensor doesn't support fancy indexing with + # plain-tensor indices (cos_cache[t_pos], sin_cache[:, col][dim_pos]). freqs_cis = self.freqs_cis + if isinstance(freqs_cis, DTensor): + freqs_cis = freqs_cis.to_local() head_dim = freqs_cis.shape[-1] // 2 cos_cache = freqs_cis[:, :head_dim] sin_cache = freqs_cis[:, head_dim:] diff --git a/torchtitan/models/qwen3_5/parallelize.py b/torchtitan/models/qwen3_5/parallelize.py index b46d61e5a4..0416da20b8 100644 --- a/torchtitan/models/qwen3_5/parallelize.py +++ b/torchtitan/models/qwen3_5/parallelize.py @@ -11,7 +11,10 @@ (activation checkpointing, compile, FSDP) to the Qwen3.5 model. """ +import torch import torch.nn as nn +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.fsdp import MixedPrecisionPolicy from torchtitan.config import ( ActivationCheckpointConfig, @@ -24,16 +27,40 @@ from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.compile import apply_compile -from torchtitan.distributed.context_parallel import apply_cp_to_forward +from torchtitan.distributed.fsdp import get_fsdp_reshard_after_forward_policy + from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp -from torchtitan.models.llama4.parallelize import apply_fsdp, apply_moe_ep_tp -from torchtitan.models.qwen3_5.sharding import ( - set_deltanet_sub_module_sharding, - set_vision_encoder_sub_module_sharding, -) +from torchtitan.models.llama4.parallelize import apply_fsdp +from torchtitan.models.qwen3_5.sharding import set_deltanet_conv1d_sharding from torchtitan.tools.logging import logger +def _apply_fsdp_to_vision_encoder( + vision_encoder: nn.Module, + dp_mesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + reshard_after_forward_policy: str = "default", + pp_enabled: bool = False, +): + """FSDP the vision encoder as a single unit. + + One AllGather for all vision params is more efficient than per-layer + sharding — the vision encoder is small relative to the decoder. + Must be called before apply_fsdp on the decoder. + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + reshard_after_forward = get_fsdp_reshard_after_forward_policy( + reshard_after_forward_policy, pp_enabled=pp_enabled + ) + fully_shard( + vision_encoder, + mesh=dp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + ) + + def parallelize_qwen3_5( model: nn.Module, *, @@ -51,50 +78,36 @@ def parallelize_qwen3_5( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + if parallelism.full_dtensor: + raise NotImplementedError("full_dtensor is not supported yet.") + model_compile_enabled = ( compile_config.enable and "model" in compile_config.components ) - # Context Parallel: wrap inner attention forward BEFORE TP so CP logic - # runs inside the local_map boundary on local tensors. - # Applies to full attention layers only — GatedDeltaNet is recurrent - # and allgathers the full sequence via cp=Replicate() in sharding. if parallel_dims.cp_enabled: - cp_mesh = parallel_dims.get_mesh("cp") - full_attn_inner_modules = [ - block.attn.inner_attention # pyrefly: ignore [missing-attribute] - for block in model.layers.values() # pyrefly: ignore [not-callable] - if block.layer_type == "full_attn" # pyrefly: ignore [missing-attribute] - ] - if full_attn_inner_modules: - apply_cp_to_forward(full_attn_inner_modules, cp_mesh) + raise NotImplementedError( + "Context Parallel is not yet supported for Qwen3.5. " + "GatedDeltaNet (75% of layers) requires full-sequence allgather, " + "and multimodal CP needs vision scatter before CP sharding." + ) - if parallel_dims.tp_enabled: + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: if parallelism.enable_async_tensor_parallel and not model_compile_enabled: raise RuntimeError("Async TP requires torch.compile") - tp_mesh = parallel_dims.get_mesh("tp") - - # For sub-modules built inline, set _sharding_config on built modules. + # Conv1d modules don't have Config fields, set sharding on built modules. # pyrefly: ignore [not-callable] for block in model.layers.values(): # pyrefly: ignore [missing-attribute] - if block.layer_type != "full_attn": + if not block.full_attn: # pyrefly: ignore [missing-attribute] - set_deltanet_sub_module_sharding(block.attn) - if model.vision_encoder is not None: - set_vision_encoder_sub_module_sharding(model.vision_encoder) + set_deltanet_conv1d_sharding(block.attn) # pyrefly: ignore [not-callable] - model.parallelize(tp_mesh) - maybe_enable_async_tp(parallelism, compile_config, tp_mesh) + model.parallelize(parallel_dims) - if parallel_dims.tp_enabled or parallel_dims.ep_enabled: - apply_moe_ep_tp( - model, - tp_mesh=parallel_dims.get_optional_mesh("tp"), - ep_mesh=parallel_dims.get_optional_mesh("ep"), - enable_sp=parallel_dims.tp_enabled, - ) + if parallel_dims.tp_enabled: + maybe_enable_async_tp(parallelism, compile_config, parallel_dims.get_mesh("tp")) if ac_config.mode != "none": apply_ac( @@ -123,6 +136,16 @@ def parallelize_qwen3_5( ) dp_mesh = parallel_dims.get_mesh(dp_mesh_names) + if model.vision_encoder is not None: + _apply_fsdp_to_vision_encoder( + model.vision_encoder, # pyrefly: ignore [bad-argument-type] + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce], + reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward, + pp_enabled=parallel_dims.pp_enabled, + ) + edp_mesh = None if parallel_dims.ep_enabled: edp_mesh_names = ( diff --git a/torchtitan/models/qwen3_5/sharding.py b/torchtitan/models/qwen3_5/sharding.py index 10d138bb23..7fc81f4e41 100644 --- a/torchtitan/models/qwen3_5/sharding.py +++ b/torchtitan/models/qwen3_5/sharding.py @@ -11,17 +11,13 @@ Full attention layers: TP on wq/wk/wv/wo with local_map for inner attention. GatedDeltaNet layers: head-sharded TP on projections (ColwiseParallel) and -out_proj (RowwiseParallel). Conv1d and FLA kernel forwards are wrapped for -DTensor→local conversion. +out_proj (RowwiseParallel). FLA kernel uses local_map for DTensor→local +conversion. Conv1d sharding is set on built modules. """ -import types from typing import TYPE_CHECKING -import torch.nn.functional as F - -from torch import nn -from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor import Placement, Replicate, Shard from torchtitan.models.common.decoder_sharding import ( colwise_config, @@ -33,6 +29,7 @@ set_dense_ffn_sharding, set_gqa_inner_attention_local_map, ) +from torchtitan.models.common.moe_sharding import set_moe_sharding_config from torchtitan.protocols.sharding import LocalMapConfig, ShardingConfig if TYPE_CHECKING: @@ -42,8 +39,6 @@ Qwen35TransformerBlock, ) -TP = "tp" - _REPLICATE_PARAM = dense_param_placement(tp=Replicate()) _REPLICATE_STATE = ShardingConfig( state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM} @@ -59,10 +54,18 @@ ) +_GROUPED_EXPERTS_PARAM_LAYOUT: dict[str, Placement] = { + "w1_EFD": Shard(1), + "w2_EDF": Shard(2), + "w3_EFD": Shard(1), +} + + def set_qwen35_sharding_config( config: "Qwen35Model.Config", *, loss_parallel: bool, + enable_ep: bool, ) -> None: """Fill ``sharding_config`` on all Qwen3.5 sub-configs. @@ -70,10 +73,8 @@ def set_qwen35_sharding_config( stays Replicate so vision scatter and MRoPE can access the full sequence. The model forward redistributes to Shard(1) before entering the layers. """ - # SP on norm, lm_head, and layers + # SP on norm, lm_head, and layers. freqs_cis stays Replicate (set by base). set_decoder_sharding_config(config, loss_parallel=loss_parallel, enable_sp=True) - # Override: don't distribute freqs_cis — MRoPE indexes it with plain tensors - config.sharding_config = ShardingConfig() # Override tok_embeddings: output Replicate (not Shard(1)) for vision scatter config.tok_embeddings.sharding_config = ShardingConfig( state_shardings={"weight": dense_param_placement(tp=Shard(0))}, @@ -83,22 +84,23 @@ def set_qwen35_sharding_config( ) _set_vision_encoder_sharding(config.vision_encoder) for layer_cfg in config.layers: - _set_qwen35_layer_sharding(layer_cfg) + _set_qwen35_layer_sharding(layer_cfg, enable_ep=enable_ep) def _set_qwen35_layer_sharding( layer_cfg: "Qwen35TransformerBlock.Config", + *, + enable_ep: bool, ) -> None: norm = norm_config(enable_sp=True) layer_cfg.attention_norm.sharding_config = norm layer_cfg.ffn_norm.sharding_config = norm - if layer_cfg.layer_type == "full_attn": - assert layer_cfg.attention is not None + if layer_cfg.attention is not None: _set_full_attention_sharding(layer_cfg.attention) else: - assert layer_cfg.deltanet is not None - _set_deltanet_sharding(layer_cfg.deltanet) + assert layer_cfg.delta_net is not None + _set_deltanet_sharding(layer_cfg.delta_net) if layer_cfg.feed_forward is not None: set_dense_ffn_sharding( @@ -107,14 +109,13 @@ def _set_qwen35_layer_sharding( enable_sp=True, ) - if layer_cfg.shared_ffn is not None: - set_dense_ffn_sharding( - layer_cfg.shared_ffn, - attn_x_placement=Shard(1), + if layer_cfg.moe is not None: + set_moe_sharding_config( + layer_cfg.moe, + enable_ep=enable_ep, enable_sp=True, + expert_param_layout=_GROUPED_EXPERTS_PARAM_LAYOUT, ) - if layer_cfg.shared_gate is not None: - layer_cfg.shared_gate.sharding_config = _REPLICATE_STATE def _set_vision_encoder_sharding(ve_cfg) -> None: @@ -122,25 +123,43 @@ def _set_vision_encoder_sharding(ve_cfg) -> None: All activations flow as Replicate — no SP in the vision encoder. Linear layers are ColwiseParallel/RowwiseParallel for memory savings. - Norms and patch_embed are Replicate. pos_embed is distributed as - Replicate via state_shardings on the encoder config. + Norms are Replicate. pos_embed is Replicate via state_shardings. """ ve_cfg.sharding_config = ShardingConfig( state_shardings={"pos_embed": _REPLICATE_PARAM}, ) - ve_cfg.patch_embed_proj.sharding_config = _REPLICATE_STATE + # patch_embed receives plain pixel_values — wrap as DTensor(Replicate) + ve_cfg.patch_embed_proj.sharding_config = ShardingConfig( + state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM}, + in_src_shardings={"input": _REPLICATE_ACT}, + in_dst_shardings={"input": _REPLICATE_ACT}, + out_dst_shardings=_REPLICATE_ACT, + ) + + # Block sub-modules + block = ve_cfg.block + block.norm1.sharding_config = _REPLICATE_NORM + block.norm2.sharding_config = _REPLICATE_NORM - # Separate Q/K/V: colwise sharding - ve_cfg.attn_wq.sharding_config = colwise_config() - ve_cfg.attn_wk.sharding_config = colwise_config() - ve_cfg.attn_wv.sharding_config = colwise_config() - ve_cfg.attn_proj.sharding_config = rowwise_config(output_sp=False) - ve_cfg.mlp_fc1.sharding_config = colwise_config() - ve_cfg.mlp_fc2.sharding_config = rowwise_config(output_sp=False) + block.attn.sharding_config = ShardingConfig( + in_src_shardings={"rope_cache": _REPLICATE_ACT}, + in_dst_shardings={"rope_cache": _REPLICATE_ACT}, + ) + block.attn.wq.sharding_config = colwise_config() + block.attn.wk.sharding_config = colwise_config() + block.attn.wv.sharding_config = colwise_config() + block.attn.proj.sharding_config = rowwise_config(output_sp=False) + set_gqa_inner_attention_local_map(block.attn.inner_attention) + + block.mlp.fc1.sharding_config = colwise_config() + block.mlp.fc2.sharding_config = rowwise_config(output_sp=False) - ve_cfg.merger_fc1.sharding_config = colwise_config() - ve_cfg.merger_fc2.sharding_config = rowwise_config(output_sp=False) + # Merger sub-modules + merger = ve_cfg.merger + merger.norm.sharding_config = _REPLICATE_NORM + merger.fc1.sharding_config = colwise_config() + merger.fc2.sharding_config = rowwise_config(output_sp=False) def _set_full_attention_sharding( @@ -162,11 +181,12 @@ def _set_full_attention_sharding( attention_cfg.wv.sharding_config = colwise_config() attention_cfg.wo.sharding_config = rowwise_config(output_sp=True) + _head_plc = dense_activation_placement(tp=Shard(2)) qk_norm_sharding = ShardingConfig( state_shardings={"weight": _REPLICATE_PARAM}, - in_src_shardings={"input": dense_activation_placement(tp=Shard(2))}, - in_dst_shardings={"input": dense_activation_placement(tp=Shard(2))}, - out_dst_shardings=dense_activation_placement(tp=Shard(2)), + in_src_shardings={"input": _head_plc}, + in_dst_shardings={"input": _head_plc}, + out_dst_shardings=_head_plc, ) attention_cfg.q_norm.sharding_config = qk_norm_sharding attention_cfg.k_norm.sharding_config = qk_norm_sharding @@ -177,102 +197,16 @@ def _set_full_attention_sharding( def _set_deltanet_sharding(deltanet_cfg) -> None: """Sharding for GatedDeltaNet: head-sharded TP on projections. - Input is allgathered on both TP and CP (Shard(1)→Replicate) because - the recurrence needs the full sequence. Projections are ColwiseParallel - (head-sharded output). Conv1d and FLA kernels are wrapped for - DTensor→local conversion. out_proj is RowwiseParallel (reduce-scatter - back to Shard(1)). + Input is allgathered (Shard(1)→Replicate) so that the recurrence + sees the full sequence. Projections are ColwiseParallel (head-sharded + output). The FLA kernel runs on local tensors via local_map. + out_proj is RowwiseParallel (reduce-scatter back to Shard(1)). A_log and dt_bias are per-head parameters, Shard(0) on TP. - Sub-module sharding is set on built modules by - ``set_deltanet_sub_module_sharding`` before ``model.parallelize()``. - """ - deltanet_cfg.sharding_config = ShardingConfig( - state_shardings={ - "A_log": dense_param_placement(tp=Shard(0)), - "dt_bias": dense_param_placement(tp=Shard(0)), - }, - in_src_shardings={"x": dense_activation_placement(tp=Shard(1))}, - # cp=Replicate: GatedDeltaNet is recurrent — needs full sequence - in_dst_shardings={ - "x": dense_activation_placement(tp=Replicate(), cp=Replicate()) - }, - out_dst_shardings=dense_activation_placement(tp=Shard(1)), - ) - - -def set_vision_encoder_sub_module_sharding(vision_encoder) -> None: - """Set _sharding_config on vision encoder sub-modules built inline. - - Norms (LayerNorm) in VisionTransformerBlock and PatchMerger are created - via Module.from_nn_module(nn.LayerNorm) — not from config fields. - Must be called after model build but before model.parallelize(). - """ - for layer in vision_encoder.layers.values(): - for name in ("norm1", "norm2"): - child = getattr(layer, name, None) - if child is not None: - child._sharding_config = _REPLICATE_NORM - # VisionAttention: declare rope_cache as Replicate so plain - # rope_cache is wrapped as DTensor to match DTensor q/k. - layer.attn._sharding_config = ShardingConfig( - in_src_shardings={"rope_cache": _REPLICATE_ACT}, - in_dst_shardings={"rope_cache": _REPLICATE_ACT}, - ) - # FlexAttention: local_map to convert DTensor q/k/v to local. - # Same as set_gqa_inner_attention_local_map but on built module. - if hasattr(layer.attn, "flex_attention"): - qkv_plc = {TP: Shard(2)} - layer.attn.flex_attention._sharding_config = ShardingConfig( - local_map=LocalMapConfig( - # pyrefly: ignore [bad-argument-type] - in_placements=(qkv_plc, qkv_plc, qkv_plc), - # pyrefly: ignore [bad-argument-type] - out_placements=(qkv_plc,), - # pyrefly: ignore [bad-argument-type] - in_grad_placements=(qkv_plc, qkv_plc, qkv_plc), - ), - ) - # Merger norm - if hasattr(vision_encoder.merger, "norm"): - vision_encoder.merger.norm._sharding_config = _REPLICATE_NORM - # Merger GELU: set None to skip protocol wrapping. Per-layer mlp.act_fn - # doesn't need this because its parent VisionMLP has no _sharding_config, - # but the merger's children get processed due to merger.norm having one. - if hasattr(vision_encoder.merger, "act_fn"): - vision_encoder.merger.act_fn._sharding_config = None - # VisionRotaryEmbedding: don't set _sharding_config — wrapping forward - # would break RoPE compute. inv_freq stays as a plain buffer; the - # resulting rope_cache is wrapped as DTensor by VisionAttention's - # in_src_shardings. - - # pos_embed interpolation: F.interpolate's decomposition doesn't - # support DTensor. Wrap to convert pos_embed to local before use. - _wrap_pos_embed_for_interpolation(vision_encoder) - - # patch_embed (Linear): plain pixel_values in → DTensor(Replicate) out - vision_encoder.patch_embed._sharding_config = ShardingConfig( - state_shardings={ - "weight": _REPLICATE_PARAM, - "bias": _REPLICATE_PARAM, - }, - in_src_shardings={"input": _REPLICATE_ACT}, - in_dst_shardings={"input": _REPLICATE_ACT}, - out_dst_shardings=_REPLICATE_ACT, - ) - - -def set_deltanet_sub_module_sharding(deltanet_module) -> None: - """Set head-sharded TP on GatedDeltaNet sub-modules. - - Projections are ColwiseParallel (head-sharded output), out_proj is - RowwiseParallel (reduce-scatter to SP). Conv1d weights are Shard(0) - on the channel dim (matching head sharding). The conv and kernel - forwards are wrapped for DTensor→local conversion (depthwise conv - and FLA kernels don't support DTensor dispatch). - - Must be called after model build but before model.parallelize(). + Conv1d sharding is set on built modules by + ``set_deltanet_sub_module_sharding`` (Conv1d doesn't have Config). """ + # ColwiseParallel on all input projections for name in ( "in_proj_q", "in_proj_k", @@ -281,91 +215,58 @@ def set_deltanet_sub_module_sharding(deltanet_module) -> None: "in_proj_a", "in_proj_b", ): - getattr(deltanet_module, name)._sharding_config = colwise_config() + getattr(deltanet_cfg, name).sharding_config = colwise_config() - _conv_shard = ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Shard(0))}, + # RowwiseParallel on output projection (reduce-scatter to SP) + deltanet_cfg.out_proj.sharding_config = rowwise_config(output_sp=True) + + # RMSNormGated: per-head norm, weight Replicate, activations Shard(2) + _norm_plc = dense_activation_placement(tp=Shard(2)) + deltanet_cfg.norm.sharding_config = ShardingConfig( + state_shardings={"weight": _REPLICATE_PARAM}, + in_src_shardings={"x": _norm_plc, "gate": _norm_plc}, + in_dst_shardings={"x": _norm_plc, "gate": _norm_plc}, + out_dst_shardings=_norm_plc, ) - for name in ("conv_q", "conv_k", "conv_v"): - conv = getattr(deltanet_module, name) - conv._sharding_config = _conv_shard - _wrap_conv1d(conv) # GatedDeltaKernel: local_map converts DTensor q/k/v/g/beta to local - # for FLA kernels, same pattern as FlexAttention's local_map. - _kernel_plc = {TP: Shard(2)} - deltanet_module.kernel._sharding_config = ShardingConfig( + _kernel_plc = dense_activation_placement(tp=Shard(2)) + deltanet_cfg.kernel.sharding_config = ShardingConfig( + in_dst_shardings={ + "q": _kernel_plc, + "k": _kernel_plc, + "v": _kernel_plc, + "g": _kernel_plc, + "beta": _kernel_plc, + }, + out_src_shardings=_kernel_plc, local_map=LocalMapConfig( - # pyrefly: ignore [bad-argument-type] - in_placements=(_kernel_plc,) * 5, - # pyrefly: ignore [bad-argument-type] - out_placements=(_kernel_plc,), - # pyrefly: ignore [bad-argument-type] in_grad_placements=(_kernel_plc,) * 5, ), ) - deltanet_module.norm._sharding_config = _REPLICATE_STATE - deltanet_module.out_proj._sharding_config = rowwise_config(output_sp=True) + deltanet_cfg.sharding_config = ShardingConfig( + state_shardings={ + "A_log": dense_param_placement(tp=Shard(0)), + "dt_bias": dense_param_placement(tp=Shard(0)), + }, + in_src_shardings={"x": dense_activation_placement(tp=Shard(1))}, + in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, + out_dst_shardings=dense_activation_placement(tp=Shard(1)), + ) -def _wrap_conv1d(conv1d_module) -> None: - """Wrap depthwise Conv1d forward for DTensor→local conversion. +def set_deltanet_conv1d_sharding(deltanet_module) -> None: + """Set sharding on GatedDeltaNet sub-modules built inline. - DTensor dispatch for Conv1d doesn't handle sharded groups: nn.Conv1d - stores groups as a plain int, but when the weight is TP-sharded on - the channel dim, the local weight has fewer channels than groups. - This wrapper converts inputs/weights to local and uses the local - channel count as groups. + Conv1d modules don't have Config fields, so their sharding must be + set on the built modules. DTensor→local conversion for Conv1d is + handled in the model's _causal_conv (same pattern as GroupedExperts). - TODO: Remove once DTensor Conv1d dispatch handles sharded groups. - """ - original_forward = conv1d_module.forward.__func__ - - def safe_forward(self, x): - if isinstance(x, DTensor): - mesh, plc = x.device_mesh, x.placements - w = self.weight - if isinstance(w, DTensor): - w = w.to_local() - # self.groups is the global count; use local weight's channel dim - local_groups = w.shape[0] - out = F.conv1d( - x.to_local(), - w, - None, - self.stride, - self.padding, - self.dilation, - local_groups, - ) - return DTensor.from_local(out, mesh, plc, run_check=False) - return original_forward(self, x) - - conv1d_module.forward = types.MethodType(safe_forward, conv1d_module) - - -def _wrap_pos_embed_for_interpolation(vision_encoder) -> None: - """Wrap compute_position_embeddings to convert pos_embed to local. - - F.interpolate's decomposition uses _unsafe_index which doesn't support - DTensor. Since pos_embed is Replicate, to_local is a no-op for data. - - TODO: Remove once F.interpolate on FSDP2-managed DTensors is fixed upstream. + Must be called after model build but before model.parallelize(). """ - original_fn = vision_encoder.compute_position_embeddings.__func__ - - def safe_compute(self, grid_thw, max_num_patch): - pos = self.pos_embed - if isinstance(pos, DTensor): - mesh, plc = pos.device_mesh, pos.placements - self.pos_embed = nn.Parameter(pos.to_local(), requires_grad=False) - learned_pos, rope_cache = original_fn(self, grid_thw, max_num_patch) - self.pos_embed = pos - learned_pos = DTensor.from_local(learned_pos, mesh, plc, run_check=False) - return learned_pos, rope_cache - return original_fn(self, grid_thw, max_num_patch) - - vision_encoder.compute_position_embeddings = types.MethodType( - safe_compute, vision_encoder + _conv_shard = ShardingConfig( + state_shardings={"weight": dense_param_placement(tp=Shard(0))}, ) + for name in ("conv_q", "conv_k", "conv_v"): + getattr(deltanet_module, name)._sharding_config = _conv_shard diff --git a/torchtitan/models/qwen3_5/state_dict_adapter.py b/torchtitan/models/qwen3_5/state_dict_adapter.py index f8788d5c88..f3cb1c5236 100644 --- a/torchtitan/models/qwen3_5/state_dict_adapter.py +++ b/torchtitan/models/qwen3_5/state_dict_adapter.py @@ -72,10 +72,10 @@ def __init__(self, model_config: Qwen35Model.Config, hf_assets_path: str | None) "model.language_model.layers.{}.mlp.experts.down_proj": "layers.{}.moe.experts.w2", "model.language_model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", # MoE shared expert - "model.language_model.layers.{}.mlp.shared_expert.gate_proj.weight": "layers.{}.shared_ffn.w1.weight", - "model.language_model.layers.{}.mlp.shared_expert.up_proj.weight": "layers.{}.shared_ffn.w3.weight", - "model.language_model.layers.{}.mlp.shared_expert.down_proj.weight": "layers.{}.shared_ffn.w2.weight", - "model.language_model.layers.{}.mlp.shared_expert_gate.weight": "layers.{}.shared_gate.weight", + "model.language_model.layers.{}.mlp.shared_expert.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight", + "model.language_model.layers.{}.mlp.shared_expert.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight", + "model.language_model.layers.{}.mlp.shared_expert.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight", + "model.language_model.layers.{}.mlp.shared_expert_gate.weight": "layers.{}.moe.shared_expert_gate.weight", # Final norm and output "model.language_model.norm.weight": "norm.weight", "lm_head.weight": "lm_head.weight", @@ -286,9 +286,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key == "model.language_model.layers.{}.linear_attn.in_proj_qkv.weight" ): - dn = self.model_config.layers[int(idx)].deltanet - kd = dn.n_key_heads * dn.key_head_dim - vd = dn.n_value_heads * dn.value_head_dim + dn = self.model_config.layers[int(idx)].delta_net + kd = dn.in_proj_q.out_features + vd = dn.in_proj_v.out_features q, k, v = value.split([kd, kd, vd], dim=0) tt_state_dict[f"layers.{idx}.attn.in_proj_q.weight"] = q tt_state_dict[f"layers.{idx}.attn.in_proj_k.weight"] = k @@ -300,9 +300,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key == "model.language_model.layers.{}.linear_attn.conv1d.weight" ): - dn = self.model_config.layers[int(idx)].deltanet - kd = dn.n_key_heads * dn.key_head_dim - vd = dn.n_value_heads * dn.value_head_dim + dn = self.model_config.layers[int(idx)].delta_net + kd = dn.in_proj_q.out_features + vd = dn.in_proj_v.out_features cq, ck, cv = value.split([kd, kd, vd], dim=0) tt_state_dict[f"layers.{idx}.attn.conv_q.weight"] = cq tt_state_dict[f"layers.{idx}.attn.conv_k.weight"] = ck diff --git a/torchtitan/models/qwen3_5/vision_encoder.py b/torchtitan/models/qwen3_5/vision_encoder.py index e7625b5cdc..b720452037 100644 --- a/torchtitan/models/qwen3_5/vision_encoder.py +++ b/torchtitan/models/qwen3_5/vision_encoder.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field import torch import torch.nn as nn @@ -262,24 +262,26 @@ class PatchMerger(Module): two-layer MLP (fc1 → GELU → fc2). """ - def __init__( - self, - hidden_size: int, - out_hidden_size: int, - spatial_merge_size: int, - layer_norm_eps: float, - *, - fc1: Linear.Config, - fc2: Linear.Config, - ): + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + spatial_merge_size: int + merged_hidden_size: int + norm: LayerNorm.Config + fc1: Linear.Config + act_fn: GELU.Config = field( + default_factory=lambda: GELU.Config(approximate="tanh") + ) + fc2: Linear.Config + + def __init__(self, config: Config): super().__init__() - self.spatial_merge_size = spatial_merge_size - self.merged_hidden_size = hidden_size * (spatial_merge_size**2) + self.spatial_merge_size = config.spatial_merge_size + self.merged_hidden_size = config.merged_hidden_size - self.norm = LayerNorm(hidden_size, eps=layer_norm_eps) - self.linear_fc1 = fc1.build() - self.act_fn = GELU(approximate="tanh") - self.linear_fc2 = fc2.build() + self.norm = config.norm.build() + self.linear_fc1 = config.fc1.build() + self.act_fn = config.act_fn.build() + self.linear_fc2 = config.fc2.build() def forward(self, x: torch.Tensor) -> torch.Tensor: """Merge spatial patches and project to output dimension. @@ -304,26 +306,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VisionAttention(Module): """Multi-head attention with FlexAttention for efficient batched processing.""" - def __init__( - self, - dim: int, - num_heads: int, - *, - wq: Linear.Config, - wk: Linear.Config, - wv: Linear.Config, - proj: Linear.Config, - ): + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + dim: int + num_heads: int + wq: Linear.Config + wk: Linear.Config + wv: Linear.Config + proj: Linear.Config + inner_attention: Module.Config = field(default_factory=FlexAttention.Config) + + def __init__(self, config: Config): super().__init__() - self.dim = dim - self.num_heads = num_heads + self.dim = config.dim + self.num_heads = config.num_heads self.head_dim = self.dim // self.num_heads - self.wq = wq.build() - self.wk = wk.build() - self.wv = wv.build() - self.proj = proj.build() - self.flex_attention = FlexAttention.Config().build() + self.wq = config.wq.build() + self.wk = config.wk.build() + self.wv = config.wv.build() + self.proj = config.proj.build() + self.flex_attention = config.inner_attention.build() def forward( self, @@ -369,26 +372,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VisionTransformerBlock(Module): """Single transformer block for vision encoder.""" - def __init__( - self, - dim: int, - num_heads: int, - layer_norm_eps: float, - *, - attn_wq: Linear.Config, - attn_wk: Linear.Config, - attn_wv: Linear.Config, - attn_proj: Linear.Config, - mlp_fc1: Linear.Config, - mlp_fc2: Linear.Config, - ): + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + norm1: LayerNorm.Config + norm2: LayerNorm.Config + attn: VisionAttention.Config + mlp: VisionMLP.Config + + def __init__(self, config: Config): super().__init__() - self.norm1 = LayerNorm(dim, eps=layer_norm_eps) - self.norm2 = LayerNorm(dim, eps=layer_norm_eps) - self.attn = VisionAttention( - dim, num_heads, wq=attn_wq, wk=attn_wk, wv=attn_wv, proj=attn_proj - ) - self.mlp = VisionMLP(fc1=mlp_fc1, fc2=mlp_fc2) + self.norm1 = config.norm1.build() + self.norm2 = config.norm2.build() + self.attn = config.attn.build() + self.mlp = config.mlp.build() def forward( self, @@ -415,27 +411,21 @@ class Config(Module.Config): """Configuration for Qwen3.5 Vision Encoder (ViT).""" dim: int = 1280 - ffn_dim: int = 5120 num_layers: int = 32 num_heads: int = 16 - dim: int - n_heads: int - spatial_merge_size: int - num_position_embeddings: int + patch_size: int = 16 + temporal_patch_size: int = 2 + in_channels: int = 3 + spatial_merge_size: int = 2 + + num_position_embeddings: int = 4096 - # Per-layer Linear configs for vision encoder sub-modules - # Linear instead of Conv3d — equivalent when kernel_size equals patch size, - # but more efficient via batched matmul on pre-flattened patches. + # Sub-module configs patch_embed_proj: Linear.Config - attn_wq: Linear.Config - attn_wk: Linear.Config - attn_wv: Linear.Config - attn_proj: Linear.Config - mlp_fc1: Linear.Config - mlp_fc2: Linear.Config - merger_fc1: Linear.Config - merger_fc2: Linear.Config + block: VisionTransformerBlock.Config + rotary_pos_emb: VisionRotaryEmbedding.Config + merger: PatchMerger.Config def __init__(self, config: Config): super().__init__() @@ -452,38 +442,14 @@ def __init__(self, config: Config): ) self.num_grid_per_side = int(config.num_position_embeddings**0.5) - head_dim = config.dim // config.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding( - head_dim // 2, theta=config.rope_theta - ) - # Cached RoPE freq table — recomputed only when max_hw grows + self.rotary_pos_emb = config.rotary_pos_emb.build() self._cached_freq_table: torch.Tensor | None = None self.layers = ModuleDict( - { - str(idx): VisionTransformerBlock( - config.dim, - config.num_heads, - config.layer_norm_eps, - attn_wq=config.attn_wq, - attn_wk=config.attn_wk, - attn_wv=config.attn_wv, - attn_proj=config.attn_proj, - mlp_fc1=config.mlp_fc1, - mlp_fc2=config.mlp_fc2, - ) - for idx in range(config.num_layers) - } + {str(idx): config.block.build() for idx in range(config.num_layers)} ) - self.merger = PatchMerger( - hidden_size=config.dim, - out_hidden_size=config.out_hidden_size, - spatial_merge_size=config.spatial_merge_size, - layer_norm_eps=config.layer_norm_eps, - fc1=config.merger_fc1, - fc2=config.merger_fc2, - ) + self.merger = config.merger.build() def compute_position_embeddings( self, grid_thw: torch.Tensor, max_num_patch: int diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py index d4209351b4..8ed2982d17 100644 --- a/torchtitan/trainer.py +++ b/torchtitan/trainer.py @@ -188,7 +188,6 @@ def __init__(self, config: Config): model_spec = config.model_spec device_module, device_type = utils.device_module, utils.device_type - # pyrefly: ignore [read-only] self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) @@ -593,8 +592,28 @@ def post_dataloading_process( # maskless backend (e.g. the SDPA config used by the graph_trainer # tests) still receives positions for RoPE but no masks — it relies on # is_causal instead. - if isinstance(self.model_config, Decoder.Config) and positions is not None: - inner_attention = self.model_config.layers[0].attention.inner_attention + if isinstance(self.model_config, Decoder.Config): + attn_config = next( + ( + l.attention + for l in self.model_config.layers + if getattr(l, "attention", None) is not None + ), + None, + ) + inner_attention = ( + attn_config.inner_attention if attn_config is not None else None + ) + + if attn_config is not None and attn_config.mask_type == "block_causal": + assert ( + positions is not None + ), "block_causal mask requires per-document positions from the dataloader" + else: + positions = torch.arange( + inputs.shape[1], dtype=torch.int32, device=inputs.device + ).repeat(inputs.shape[0], 1) + if isinstance( inner_attention, (FlexAttention.Config, VarlenAttention.Config) ): From 541c1b67768e8cda82d9a3361d7be8f6eb98e6f7 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Thu, 4 Jun 2026 19:14:28 -0700 Subject: [PATCH 3/7] rewrite sharedexperts and use local map --- .ci/docker/requirements-vlm.txt | 1 + torchtitan/models/common/__init__.py | 2 + torchtitan/models/common/config_utils.py | 39 +++++++- torchtitan/models/common/decoder.py | 25 ++--- torchtitan/models/common/moe.py | 36 ++++--- torchtitan/models/common/moe_sharding.py | 40 +++++--- torchtitan/models/common/nn_modules.py | 28 ++++++ torchtitan/models/deepseek_v3/__init__.py | 3 +- torchtitan/models/qwen3_5/__init__.py | 28 ++++-- torchtitan/models/qwen3_5/model.py | 99 +++++++------------ torchtitan/models/qwen3_5/parallelize.py | 8 -- torchtitan/models/qwen3_5/sharding.py | 28 ++---- .../models/qwen3_5/state_dict_adapter.py | 2 +- torchtitan/trainer.py | 13 +-- 14 files changed, 205 insertions(+), 147 deletions(-) diff --git a/.ci/docker/requirements-vlm.txt b/.ci/docker/requirements-vlm.txt index e82b2e33ba..9b1e557669 100644 --- a/.ci/docker/requirements-vlm.txt +++ b/.ci/docker/requirements-vlm.txt @@ -2,3 +2,4 @@ av einops pillow torchvision +flash-linear-attention diff --git a/torchtitan/models/common/__init__.py b/torchtitan/models/common/__init__.py index 52a1634726..2902247763 100644 --- a/torchtitan/models/common/__init__.py +++ b/torchtitan/models/common/__init__.py @@ -25,6 +25,7 @@ from .feed_forward import compute_ffn_hidden_dim, FeedForward from .moe import MoE from .nn_modules import ( + Conv1d, Conv2d, Embedding, GELU, @@ -38,6 +39,7 @@ from .rope import ComplexRoPE, CosSinRoPE, RoPE __all__ = [ + "Conv1d", "Conv2d", "ComplexRoPE", "CosSinRoPE", diff --git a/torchtitan/models/common/config_utils.py b/torchtitan/models/common/config_utils.py index cf4e40cad9..81d7fcb96a 100644 --- a/torchtitan/models/common/config_utils.py +++ b/torchtitan/models/common/config_utils.py @@ -22,7 +22,12 @@ VarlenAttention, ) from torchtitan.models.common.feed_forward import FeedForward -from torchtitan.models.common.moe import GroupedExperts, MoE, TokenChoiceTopKRouter +from torchtitan.models.common.moe import ( + GroupedExperts, + MoE, + SharedExperts, + TokenChoiceTopKRouter, +) from torchtitan.models.common.nn_modules import Linear, RMSNorm from torchtitan.models.common.rope import RoPE from torchtitan.models.common.token_dispatcher import ( @@ -151,13 +156,40 @@ def make_ffn_config( ) +def make_shared_experts_config( + *, + dim: int, + hidden_dim: int, + w1_param_init: dict[str, Callable], + w2w3_param_init: dict[str, Callable], + gate_param_init: dict[str, Callable] | None = None, +) -> SharedExperts.Config: + """Build a SharedExperts.Config (SwiGLU FFN with optional sigmoid gate). + + When ``gate_param_init`` is given, the shared expert applies a per-token + sigmoid gate (``sigmoid(gate(x)) * ffn(x)``), e.g. the Qwen3.5 shared + expert. Otherwise it is a plain SwiGLU FFN. + """ + ffn = make_ffn_config( + dim=dim, + hidden_dim=hidden_dim, + w1_param_init=w1_param_init, + w2w3_param_init=w2w3_param_init, + ) + gate = ( + Linear.Config(in_features=dim, out_features=1, param_init=gate_param_init) + if gate_param_init is not None + else None + ) + return SharedExperts.Config(w1=ffn.w1, w2=ffn.w2, w3=ffn.w3, gate=gate) + + def make_moe_config( *, num_experts: int = 8, router: TokenChoiceTopKRouter.Config, experts: GroupedExperts.Config, - shared_experts: FeedForward.Config | None = None, - shared_expert_gate: Module.Config | None = None, + shared_experts: SharedExperts.Config | None = None, load_balance_coeff: float | None = 1e-3, ) -> MoE.Config: """Build a fully-specified MoE.Config.""" @@ -167,7 +199,6 @@ def make_moe_config( router=router, experts=experts, shared_experts=shared_experts, - shared_expert_gate=shared_expert_gate, ) diff --git a/torchtitan/models/common/decoder.py b/torchtitan/models/common/decoder.py index 7985c19ffe..872817a1a2 100644 --- a/torchtitan/models/common/decoder.py +++ b/torchtitan/models/common/decoder.py @@ -79,7 +79,14 @@ class Config(BaseModel.Config): @property def max_seq_len(self) -> int: - return self.layers[0].attention.rope.max_seq_len + # Llama4/iRoPE can have NoPE layers with ``rope=None``; use the + # first layer that carries RoPE to expose the model context length. + for layer_cfg in self.layers: + attention_cfg = getattr(layer_cfg, "attention", None) + rope_cfg = getattr(attention_cfg, "rope", None) + if rope_cfg is not None: + return rope_cfg.max_seq_len + raise ValueError("Decoder config does not define RoPE max_seq_len.") def update_from_config( self, @@ -116,14 +123,7 @@ def update_from_config( tp = parallelism.tensor_parallel_degree if tp > 1: - attention = next( - ( - l.attention - for l in self.layers - if getattr(l, "attention", None) is not None - ), - None, - ) + attention = self.first_attn_config if attention is None: raise ValueError( "No layer with attention config found for TP validation." @@ -142,7 +142,7 @@ def update_from_config( ) for layer_cfg in self.layers: - if hasattr(layer_cfg, "moe") and layer_cfg.moe is not None: + if layer_cfg.moe is not None: from torchtitan.models.common.token_dispatcher import ( DeepEPTokenDispatcher, HybridEPTokenDispatcher, @@ -285,7 +285,10 @@ def get_attention_masks( self, positions: torch.Tensor, ) -> AttentionMasksType: - attn_config = self.config.layers[0].attention + attn_config = self.config.first_attn_config + assert ( + attn_config is not None + ), "get_attention_masks requires an attention layer" inner_attn = attn_config.inner_attention if isinstance(inner_attn, FlexAttention.Config): return self._create_flex_attention_mask_for_document(positions, attn_config) diff --git a/torchtitan/models/common/moe.py b/torchtitan/models/common/moe.py index cdddd4ce9f..ce0bb7bcf3 100644 --- a/torchtitan/models/common/moe.py +++ b/torchtitan/models/common/moe.py @@ -295,6 +295,30 @@ def forward( ) +class SharedExperts(FeedForward): + """Shared expert: SwiGLU FFN with an optional per-token sigmoid gate. + + When ``gate`` is set, the output is ``sigmoid(gate(x)) * ffn(x)``; + otherwise it is a plain SwiGLU FFN. Inherits ``w1/w2/w3`` from + FeedForward so weight FQNs are unchanged. + """ + + @dataclass(kw_only=True, slots=True) + class Config(FeedForward.Config): + gate: Linear.Config | None = None + + def __init__(self, config: Config): + super().__init__(config) + self.gate = config.gate.build() if config.gate is not None else None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = super().forward(x) + if self.gate is not None: + # TODO: make the gate activation configurable (e.g. softmax, silu) + out = torch.sigmoid(self.gate(x)) * out + return out + + class MoE(Module): """Mixture of Experts layer. @@ -322,8 +346,7 @@ class Config(Module.Config): experts: GroupedExperts.Config router: TokenChoiceTopKRouter.Config load_balance_coeff: float | None = 1e-3 - shared_experts: FeedForward.Config | None = None - shared_expert_gate: Module.Config | None = None + shared_experts: SharedExperts.Config | None = None def __init__(self, config: Config): super().__init__() @@ -334,11 +357,6 @@ def __init__(self, config: Config): self.shared_experts = ( config.shared_experts.build() if config.shared_experts is not None else None ) - self.shared_expert_gate = ( - config.shared_expert_gate.build() - if config.shared_expert_gate is not None - else None - ) # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) # NOTE: tokens_per_expert_E is accumulated in the model forward pass. @@ -453,10 +471,6 @@ def _generate_routing_map( sync_combine() if shared_out_BLD is not None: - if self.shared_expert_gate is not None: - shared_out_BLD = ( - torch.sigmoid(self.shared_expert_gate(x_BLD)) * shared_out_BLD - ) out_BLD = out_BLD + shared_out_BLD return out_BLD diff --git a/torchtitan/models/common/moe_sharding.py b/torchtitan/models/common/moe_sharding.py index 8953d2436d..e1d99bb968 100644 --- a/torchtitan/models/common/moe_sharding.py +++ b/torchtitan/models/common/moe_sharding.py @@ -230,24 +230,38 @@ def set_moe_sharding_config( # Router gate: dense-family TP plan with Partial output grad. moe_cfg.router.gate.sharding_config = _router_gate_config(enable_ep=enable_ep) - # Shared experts: optional. Use Partial-flow variants so the - # Partial->sp_layout reduce only happens once at the MoE boundary. - if getattr(moe_cfg, "shared_experts", None) is not None: - moe_cfg.shared_experts.w1.sharding_config = _shared_expert_colwise_config( + # Shared experts: optional SharedExperts (SwiGLU FFN + optional gate). + # Gather x to Replicate ONCE at the module boundary so w1/w3/gate all share + # it (their per-linear input redistributions become no-ops). w2 (rowwise) + # keeps the output Partial; the Partial->sp_layout reduce happens once at + # the MoE boundary. + shared = moe_cfg.shared_experts + if shared is not None: + sp_layout = Shard(1) if enable_sp else Replicate() + shared_input_layout = Replicate() if not enable_ep else sp_layout + shared.sharding_config = ShardingConfig( + in_src_shardings={"x": dense_activation_placement(tp=shared_input_layout)}, + in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, + ) + + shared.w1.sharding_config = _shared_expert_colwise_config( enable_ep=enable_ep, enable_sp=enable_sp ) - moe_cfg.shared_experts.w2.sharding_config = _shared_expert_rowwise_config() - moe_cfg.shared_experts.w3.sharding_config = _shared_expert_colwise_config( + shared.w2.sharding_config = _shared_expert_rowwise_config() + shared.w3.sharding_config = _shared_expert_colwise_config( enable_ep=enable_ep, enable_sp=enable_sp ) - if getattr(moe_cfg, "shared_expert_gate", None) is not None: - moe_cfg.shared_expert_gate.sharding_config = ShardingConfig( - state_shardings={ - "weight": dense_param_placement(tp=Replicate()), - "bias": dense_param_placement(tp=Replicate()), - } - ) + if shared.gate is not None: + # Gate output is Replicate, so `gate * ffn` is + # `Replicate * Partial = Partial` with no extra collective. + shared.gate.sharding_config = ShardingConfig( + state_shardings={ + "weight": dense_param_placement(tp=Replicate()), + "bias": dense_param_placement(tp=Replicate()), + }, + out_dst_shardings=dense_activation_placement(tp=Replicate()), + ) # Routed experts: local_map converts DTensor inputs to local for # dispatch/compute/combine, then wraps local output as DTensor(Partial). diff --git a/torchtitan/models/common/nn_modules.py b/torchtitan/models/common/nn_modules.py index 10273967d7..db2127e54b 100644 --- a/torchtitan/models/common/nn_modules.py +++ b/torchtitan/models/common/nn_modules.py @@ -23,6 +23,33 @@ from torchtitan.protocols.module import Module +class Conv1d(nn.Conv1d, Module): + """Configurable nn.Conv1d.""" + + @dataclass(kw_only=True, slots=True) + class Config(Module.Config): + in_channels: int + out_channels: int + kernel_size: int + stride: int = 1 + padding: int = 0 + groups: int = 1 + # Matches the upstream ``nn.Conv1d`` default (differs from + # ``Linear.Config.bias``, which defaults to False). + bias: bool = True + + def __init__(self, config: Config): + super().__init__( + config.in_channels, + config.out_channels, + config.kernel_size, + stride=config.stride, + padding=config.padding, + groups=config.groups, + bias=config.bias, + ) + + class Conv2d(nn.Conv2d, Module): """Configurable nn.Conv2d.""" @@ -162,6 +189,7 @@ def __init__(self, config: Config): __all__ = [ + "Conv1d", "Conv2d", "Embedding", "GELU", diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index d66071a641..ce540a5db4 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -27,6 +27,7 @@ make_ffn_config, make_moe_config, make_router_config, + make_shared_experts_config, ) from torchtitan.models.common.param_init import depth_scaled_std from torchtitan.models.utils import validate_converter_order @@ -242,7 +243,7 @@ def _build_dsv3_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), - shared_experts=make_ffn_config( + shared_experts=make_shared_experts_config( dim=dim, hidden_dim=moe_hidden_dim * num_shared_experts, w1_param_init=_LINEAR_INIT, diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py index dba06463f4..2203bf18ae 100644 --- a/torchtitan/models/qwen3_5/__init__.py +++ b/torchtitan/models/qwen3_5/__init__.py @@ -12,13 +12,14 @@ from torchtitan.components.optimizer import register_moe_load_balancing_hook -from torchtitan.models.common import Embedding, Linear, RoPE # noqa: F401 +from torchtitan.models.common import Conv1d, Embedding, Linear, RoPE # noqa: F401 from torchtitan.models.common.config_utils import ( get_attention_config, make_experts_config, make_ffn_config, make_moe_config, make_router_config, + make_shared_experts_config, ) from torchtitan.models.common.nn_modules import LayerNorm from torchtitan.models.common.param_init import depth_scaled_std # noqa: F401 @@ -226,6 +227,7 @@ def _qwen35_deltanet_config( key_head_dim: int, value_head_dim: int, layer_id: int, + conv_kernel_size: int = 4, fla_backend: Literal[ "fla_chunked", "fla_fused_recurrent", "torch_naive" ] = "fla_chunked", @@ -239,15 +241,31 @@ def _proj(in_f: int, out_f: int, init: dict) -> Linear.Config: in_features=in_f, out_features=out_f, bias=False, param_init=init ) + def _conv(channels: int) -> Conv1d.Config: + # Depthwise causal conv (groups == channels). Causal left-padding is + # applied in the forward, so padding=0 here. + return Conv1d.Config( + in_channels=channels, + out_channels=channels, + kernel_size=conv_kernel_size, + groups=channels, + padding=0, + bias=False, + ) + return GatedDeltaNet.Config( key_head_dim=key_head_dim, value_head_dim=value_head_dim, + conv_kernel_size=conv_kernel_size, in_proj_q=_proj(dim, key_dim, _LINEAR_INIT), in_proj_k=_proj(dim, key_dim, _LINEAR_INIT), in_proj_v=_proj(dim, value_dim, _LINEAR_INIT), in_proj_z=_proj(dim, value_dim, _LINEAR_INIT), in_proj_a=_proj(dim, n_value_heads, _LINEAR_INIT), in_proj_b=_proj(dim, n_value_heads, _LINEAR_INIT), + conv_q=_conv(key_dim), + conv_k=_conv(key_dim), + conv_v=_conv(value_dim), kernel=GatedDeltaKernel.Config(backend=fla_backend), norm=RMSNormGated.Config( dim=value_head_dim, @@ -410,16 +428,12 @@ def _build_qwen35_moe_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), - shared_experts=make_ffn_config( + shared_experts=make_shared_experts_config( dim=dim, hidden_dim=shared_expert_hidden_dim, w1_param_init=_LINEAR_INIT, w2w3_param_init=_depth_init(layer_id), - ), - shared_expert_gate=Linear.Config( - in_features=dim, - out_features=1, - param_init=_LINEAR_INIT, + gate_param_init=_LINEAR_INIT, ), ), attention_norm=_offset_norm(dim), diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py index 033f58cec7..02e4174d3e 100644 --- a/torchtitan/models/qwen3_5/model.py +++ b/torchtitan/models/qwen3_5/model.py @@ -12,8 +12,9 @@ import torch.nn.functional as F from torch import nn from torch.distributed.tensor import DTensor +from torch.distributed.tensor.experimental import local_map -from torchtitan.models.common import Linear +from torchtitan.models.common import Conv1d, Linear from torchtitan.models.common.attention import AttentionMasksType, BaseAttention from torchtitan.models.common.decoder import Decoder from torchtitan.models.common.rope import apply_rotary_emb_cos_sin @@ -24,10 +25,6 @@ from .vision_encoder import Qwen35VisionEncoder -class _Conv1d(nn.Conv1d, Module): - pass - - try: from fla.ops.gated_delta_rule import ( chunk_gated_delta_rule as _fla_chunk_gated_delta_rule, @@ -239,6 +236,9 @@ class Config(Module.Config): in_proj_z: Linear.Config in_proj_a: Linear.Config in_proj_b: Linear.Config + conv_q: Conv1d.Config + conv_k: Conv1d.Config + conv_v: Conv1d.Config kernel: GatedDeltaKernel.Config norm: RMSNormGated.Config out_proj: Linear.Config @@ -249,7 +249,6 @@ def __init__(self, config: Config): self.value_head_dim = config.value_head_dim self.conv_kernel_size = config.conv_kernel_size - key_dim = config.in_proj_q.out_features value_dim = config.in_proj_v.out_features self.in_proj_q = config.in_proj_q.build() @@ -259,30 +258,9 @@ def __init__(self, config: Config): self.in_proj_a = config.in_proj_a.build() self.in_proj_b = config.in_proj_b.build() - self.conv_q = _Conv1d( - in_channels=key_dim, - out_channels=key_dim, - bias=False, - kernel_size=config.conv_kernel_size, - groups=key_dim, - padding=0, - ) - self.conv_k = _Conv1d( - in_channels=key_dim, - out_channels=key_dim, - bias=False, - kernel_size=config.conv_kernel_size, - groups=key_dim, - padding=0, - ) - self.conv_v = _Conv1d( - in_channels=value_dim, - out_channels=value_dim, - bias=False, - kernel_size=config.conv_kernel_size, - groups=value_dim, - padding=0, - ) + self.conv_q = config.conv_q.build() + self.conv_k = config.conv_k.build() + self.conv_v = config.conv_v.build() n_value_heads = value_dim // config.value_head_dim self.A_log = nn.Parameter(torch.empty(n_value_heads)) @@ -296,23 +274,35 @@ def _causal_conv(self, x: torch.Tensor, conv: nn.Module) -> torch.Tensor: # pyrefly: ignore [bad-argument-type] x = F.pad(x.transpose(1, 2), (self.conv_kernel_size - 1, 0)) if isinstance(x, DTensor): - # TODO: Remove once DTensor Conv1d dispatch handles sharded groups. - mesh, plc = x.device_mesh, x.placements - w: torch.Tensor = conv.weight # pyrefly: ignore [bad-assignment] - if isinstance(w, DTensor): - w = w.to_local() - local_groups = w.size(0) - # pyrefly: ignore [no-matching-overload] - out = F.conv1d( - x.to_local(), - w, - None, - conv.stride, - conv.padding, - conv.dilation, - local_groups, + # TODO: Remove once the DTensor Conv1d dispatch fix for sharded + # groups lands in a released torch. local_map runs the conv on + # local shards (channel-sharded input + Shard(0) weight) and + # restores DTensor-ness, with explicit gradient placements. + x_plc = x.placements + w = conv.weight + w_plc = w.placements # pyrefly: ignore [missing-attribute] + + def _conv(x_local: torch.Tensor, w_local: torch.Tensor) -> torch.Tensor: + # groups == local out-channels (depthwise, channel-sharded) + # pyrefly: ignore [no-matching-overload] + return F.conv1d( + x_local, + w_local, + None, + conv.stride, + conv.padding, + conv.dilation, + w_local.size(0), + ) + + conv_dt = local_map( + _conv, + out_placements=(x_plc,), + in_placements=(x_plc, w_plc), + in_grad_placements=(x_plc, w_plc), + device_mesh=x.device_mesh, ) - x = DTensor.from_local(out, mesh, plc, run_check=False) + x = conv_dt(x, w) # pyrefly: ignore [bad-argument-count] else: x = conv(x) return F.silu(x).transpose(1, 2) @@ -546,18 +536,6 @@ def update_from_config( tp = parallelism.tensor_parallel_degree if tp > 1: - attn_cfg = next( - (l.attention for l in self.layers if l.attention is not None), - None, - ) - if attn_cfg is not None and ( - attn_cfg.n_heads % tp != 0 or attn_cfg.n_kv_heads % tp != 0 - ): - raise ValueError( - f"tensor_parallel_degree ({tp}) must divide " - f"n_heads ({attn_cfg.n_heads}) and " - f"n_kv_heads ({attn_cfg.n_kv_heads})." - ) dn_cfg = next( (l.delta_net for l in self.layers if l.delta_net is not None), None, @@ -583,10 +561,7 @@ def update_from_config( def get_nparams_and_flops( self, model: nn.Module, seq_len: int ) -> tuple[int, int]: - attn_cfg = next( - (l.attention for l in self.layers if l.attention is not None), - None, - ) + attn_cfg = self.first_attn_config # pyrefly: ignore [missing-attribute] n_heads = attn_cfg.n_heads # pyrefly: ignore [missing-attribute] diff --git a/torchtitan/models/qwen3_5/parallelize.py b/torchtitan/models/qwen3_5/parallelize.py index 0416da20b8..0e625acea4 100644 --- a/torchtitan/models/qwen3_5/parallelize.py +++ b/torchtitan/models/qwen3_5/parallelize.py @@ -31,7 +31,6 @@ from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama4.parallelize import apply_fsdp -from torchtitan.models.qwen3_5.sharding import set_deltanet_conv1d_sharding from torchtitan.tools.logging import logger @@ -96,13 +95,6 @@ def parallelize_qwen3_5( if parallelism.enable_async_tensor_parallel and not model_compile_enabled: raise RuntimeError("Async TP requires torch.compile") - # Conv1d modules don't have Config fields, set sharding on built modules. - # pyrefly: ignore [not-callable] - for block in model.layers.values(): - # pyrefly: ignore [missing-attribute] - if not block.full_attn: - # pyrefly: ignore [missing-attribute] - set_deltanet_conv1d_sharding(block.attn) # pyrefly: ignore [not-callable] model.parallelize(parallel_dims) diff --git a/torchtitan/models/qwen3_5/sharding.py b/torchtitan/models/qwen3_5/sharding.py index 7fc81f4e41..a90fbbc88a 100644 --- a/torchtitan/models/qwen3_5/sharding.py +++ b/torchtitan/models/qwen3_5/sharding.py @@ -203,8 +203,8 @@ def _set_deltanet_sharding(deltanet_cfg) -> None: out_proj is RowwiseParallel (reduce-scatter back to Shard(1)). A_log and dt_bias are per-head parameters, Shard(0) on TP. - Conv1d sharding is set on built modules by - ``set_deltanet_sub_module_sharding`` (Conv1d doesn't have Config). + Conv1d weights are Shard(0) (out-channels); the DTensor->local conversion + for the depthwise conv is handled in the model's ``_causal_conv``. """ # ColwiseParallel on all input projections for name in ( @@ -217,6 +217,14 @@ def _set_deltanet_sharding(deltanet_cfg) -> None: ): getattr(deltanet_cfg, name).sharding_config = colwise_config() + # Depthwise Conv1d weights: Shard(0) on out-channels (head-sharded). + _conv_shard = ShardingConfig( + state_shardings={"weight": dense_param_placement(tp=Shard(0))}, + ) + deltanet_cfg.conv_q.sharding_config = _conv_shard + deltanet_cfg.conv_k.sharding_config = _conv_shard + deltanet_cfg.conv_v.sharding_config = _conv_shard + # RowwiseParallel on output projection (reduce-scatter to SP) deltanet_cfg.out_proj.sharding_config = rowwise_config(output_sp=True) @@ -254,19 +262,3 @@ def _set_deltanet_sharding(deltanet_cfg) -> None: in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, out_dst_shardings=dense_activation_placement(tp=Shard(1)), ) - - -def set_deltanet_conv1d_sharding(deltanet_module) -> None: - """Set sharding on GatedDeltaNet sub-modules built inline. - - Conv1d modules don't have Config fields, so their sharding must be - set on the built modules. DTensor→local conversion for Conv1d is - handled in the model's _causal_conv (same pattern as GroupedExperts). - - Must be called after model build but before model.parallelize(). - """ - _conv_shard = ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Shard(0))}, - ) - for name in ("conv_q", "conv_k", "conv_v"): - getattr(deltanet_module, name)._sharding_config = _conv_shard diff --git a/torchtitan/models/qwen3_5/state_dict_adapter.py b/torchtitan/models/qwen3_5/state_dict_adapter.py index f3cb1c5236..8f3c579dec 100644 --- a/torchtitan/models/qwen3_5/state_dict_adapter.py +++ b/torchtitan/models/qwen3_5/state_dict_adapter.py @@ -75,7 +75,7 @@ def __init__(self, model_config: Qwen35Model.Config, hf_assets_path: str | None) "model.language_model.layers.{}.mlp.shared_expert.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight", "model.language_model.layers.{}.mlp.shared_expert.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight", "model.language_model.layers.{}.mlp.shared_expert.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight", - "model.language_model.layers.{}.mlp.shared_expert_gate.weight": "layers.{}.moe.shared_expert_gate.weight", + "model.language_model.layers.{}.mlp.shared_expert_gate.weight": "layers.{}.moe.shared_experts.gate.weight", # Final norm and output "model.language_model.norm.weight": "norm.weight", "lm_head.weight": "lm_head.weight", diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py index 8ed2982d17..442bd721b9 100644 --- a/torchtitan/trainer.py +++ b/torchtitan/trainer.py @@ -593,17 +593,8 @@ def post_dataloading_process( # tests) still receives positions for RoPE but no masks — it relies on # is_causal instead. if isinstance(self.model_config, Decoder.Config): - attn_config = next( - ( - l.attention - for l in self.model_config.layers - if getattr(l, "attention", None) is not None - ), - None, - ) - inner_attention = ( - attn_config.inner_attention if attn_config is not None else None - ) + attn_config = self.model_config.first_attn_config + inner_attention = getattr(attn_config, "inner_attention", None) if attn_config is not None and attn_config.mask_type == "block_causal": assert ( From 8c92617e434755065329f63d716788c71a615914 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Thu, 4 Jun 2026 19:41:34 -0700 Subject: [PATCH 4/7] refactor optimizer config --- pyproject.toml | 2 +- tests/integration_tests/models.py | 2 +- torchtitan/models/flux/trainer.py | 1 + torchtitan/models/qwen3_5/__init__.py | 11 ++++------ torchtitan/models/qwen3_5/config_registry.py | 22 ++++++++++---------- torchtitan/models/qwen3_5/model.py | 17 +++++++-------- torchtitan/models/qwen3_5/vision_encoder.py | 3 +-- torchtitan/trainer.py | 1 + 8 files changed, 28 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 21a2d28397..92bffa30e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,5 +65,5 @@ testpaths = ["tests"] [tool.pyrefly] project-excludes = ["torchtitan/experiments", "**/tests/**"] -replace-imports-with-any = ["torchao.*", "torchft", "torchvision.*", "deep_ep.*", "jinja2.*"] # optional dependencies +replace-imports-with-any = ["torchao.*", "torchft", "torchvision.*", "deep_ep.*", "jinja2.*", "fla.*"] # optional dependencies search-path = ["../pytorch"] # local built pytorch diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 36365d4f6f..7e615d7302 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -151,7 +151,7 @@ def build_model_tests_list() -> list[OverrideDefinitions]: ], "Qwen3.5 MoE FSDP+TP+EP+PP", "qwen3_5_moe_fsdp+tp+ep+pp", - ngpu=32, + ngpu=8, ), # Integration Test Cases for gpt-oss # TODO: re-enable compile after fixing diff --git a/torchtitan/models/flux/trainer.py b/torchtitan/models/flux/trainer.py index 7dd5c4ebe7..ec2e44a6d9 100644 --- a/torchtitan/models/flux/trainer.py +++ b/torchtitan/models/flux/trainer.py @@ -287,6 +287,7 @@ def train_step( if self.gradient_accumulation_steps > 1: raise ValueError("FLUX doesn't support gradient accumulation for now.") + # pyrefly: ignore [no-matching-overload] input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict=input_dict, labels=labels) diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py index 2203bf18ae..777581f72c 100644 --- a/torchtitan/models/qwen3_5/__init__.py +++ b/torchtitan/models/qwen3_5/__init__.py @@ -229,7 +229,7 @@ def _qwen35_deltanet_config( layer_id: int, conv_kernel_size: int = 4, fla_backend: Literal[ - "fla_chunked", "fla_fused_recurrent", "torch_naive" + "fla_chunked", "fla_fused_recurrent", "torch_native" ] = "fla_chunked", ) -> GatedDeltaNet.Config: """Build a fully-specified GatedDeltaNet.Config.""" @@ -296,7 +296,7 @@ def _build_qwen35_layers( full_attention_interval: int = 4, attn_backend: str, fla_backend: Literal[ - "fla_chunked", "fla_fused_recurrent", "torch_naive" + "fla_chunked", "fla_fused_recurrent", "torch_native" ] = "fla_chunked", ) -> list[Qwen35TransformerBlock.Config]: """Build per-layer configs for dense Qwen3.5 models.""" @@ -367,7 +367,7 @@ def _build_qwen35_moe_layers( full_attention_interval: int = 4, attn_backend: str, fla_backend: Literal[ - "fla_chunked", "fla_fused_recurrent", "torch_naive" + "fla_chunked", "fla_fused_recurrent", "torch_native" ] = "fla_chunked", moe_comm_backend: str = "standard", non_blocking_capacity_factor: float | None = None, @@ -1088,15 +1088,12 @@ def model_registry( for c in converters: c.build().convert(config) - # Detect MoE: check if any layer has moe config - has_moe = any(getattr(layer, "moe", None) is not None for layer in config.layers) - return ModelSpec( name="qwen3_5", flavor=flavor, model=config, parallelize_fn=parallelize_qwen3_5, pipelining_fn=pipeline_qwen3_5, - post_optimizer_build_fn=(register_moe_load_balancing_hook if has_moe else None), + post_optimizer_build_fn=register_moe_load_balancing_hook, state_dict_adapter=Qwen35StateDictAdapter, ) diff --git a/torchtitan/models/qwen3_5/config_registry.py b/torchtitan/models/qwen3_5/config_registry.py index 3bcb1c890c..6fa2f35a17 100644 --- a/torchtitan/models/qwen3_5/config_registry.py +++ b/torchtitan/models/qwen3_5/config_registry.py @@ -8,7 +8,7 @@ from torchtitan.components.loss import ChunkedCELoss from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.metrics import MetricsProcessor -from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.components.optimizer import default_adamw from torchtitan.components.tokenizer import MultiModalTokenizer from torchtitan.config import ( @@ -45,7 +45,7 @@ def qwen35_debugmodel() -> Trainer.Config: metrics=MetricsProcessor.Config(log_freq=1), model_spec=model_registry("debugmodel"), dataloader=_dataloader("cc12m-test"), - optimizer=OptimizersContainer.Config(lr=5e-3), + optimizer=default_adamw(lr=5e-3), lr_scheduler=LRSchedulersContainer.Config( warmup_steps=2, decay_ratio=0.8, @@ -75,7 +75,7 @@ def qwen35_debugmodel_moe() -> Trainer.Config: metrics=MetricsProcessor.Config(log_freq=1), model_spec=model_registry("debugmodel_moe", moe_comm_backend="standard"), dataloader=_dataloader("cc12m-test"), - optimizer=OptimizersContainer.Config(lr=5e-3), + optimizer=default_adamw(lr=5e-3), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=2), training=TrainingConfig( local_batch_size=2, @@ -105,7 +105,7 @@ def qwen35_0_8b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("0.8B"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-3), + optimizer=default_adamw(lr=5e-3), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -132,7 +132,7 @@ def qwen35_2b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("2B"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-3), + optimizer=default_adamw(lr=5e-3), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -159,7 +159,7 @@ def qwen35_4b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("4B"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-4), + optimizer=default_adamw(lr=5e-4), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -185,7 +185,7 @@ def qwen35_9b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("9B"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-4), + optimizer=default_adamw(lr=5e-4), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -213,7 +213,7 @@ def qwen35_27b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("27B"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-4), + optimizer=default_adamw(lr=5e-4), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -241,7 +241,7 @@ def qwen35_35b_a3b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("35B-A3B", moe_comm_backend="standard"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-4), + optimizer=default_adamw(lr=5e-4), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -270,7 +270,7 @@ def qwen35_122b_a10b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("122B-A10B", moe_comm_backend="standard"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-4), + optimizer=default_adamw(lr=5e-4), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, @@ -299,7 +299,7 @@ def qwen35_397b_a17b() -> Trainer.Config: tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS), model_spec=model_registry("397B-A17B", moe_comm_backend="standard"), dataloader=_dataloader("cc12m"), - optimizer=OptimizersContainer.Config(lr=5e-4), + optimizer=default_adamw(lr=5e-4), lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20), training=TrainingConfig( local_batch_size=4, diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py index 02e4174d3e..e751c25282 100644 --- a/torchtitan/models/qwen3_5/model.py +++ b/torchtitan/models/qwen3_5/model.py @@ -41,7 +41,7 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) -def _torch_naive_gated_delta( +def _torch_native_gated_delta( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -154,9 +154,9 @@ class GatedDeltaKernel(Module): class Config(Module.Config): # "fla_chunked": parallel within chunks, fast for training (default) # "fla_fused_recurrent": token-by-token, lower memory for long sequences - # "torch_naive": pure-Python reference, for numerical testing only + # "torch_native": pure-Python reference, for numerical testing only backend: Literal[ - "fla_chunked", "fla_fused_recurrent", "torch_naive" + "fla_chunked", "fla_fused_recurrent", "torch_native" ] = "fla_chunked" def __init__(self, config: Config): @@ -178,8 +178,8 @@ def forward( q = q.repeat_interleave(repeat, dim=2) k = k.repeat_interleave(repeat, dim=2) - if self.backend == "torch_naive": - return _torch_naive_gated_delta(q, k, v, g, beta) + if self.backend == "torch_native": + return _torch_native_gated_delta(q, k, v, g, beta) if not _HAS_FLA: raise RuntimeError( @@ -208,7 +208,7 @@ def forward( else: raise ValueError( f"Unknown fla_backend '{self.backend}'. " - "Valid: 'fla_chunked', 'fla_fused_recurrent', 'torch_naive'." + "Valid: 'fla_chunked', 'fla_fused_recurrent', 'torch_native'." ) # FLA kernels return (output, final_state); we only need output @@ -271,8 +271,7 @@ def __init__(self, config: Config): self.out_proj = config.out_proj.build() def _causal_conv(self, x: torch.Tensor, conv: nn.Module) -> torch.Tensor: - # pyrefly: ignore [bad-argument-type] - x = F.pad(x.transpose(1, 2), (self.conv_kernel_size - 1, 0)) + x = F.pad(x.transpose(1, 2), [self.conv_kernel_size - 1, 0]) if isinstance(x, DTensor): # TODO: Remove once the DTensor Conv1d dispatch fix for sharded # groups lands in a released torch. local_map runs the conv on @@ -302,7 +301,7 @@ def _conv(x_local: torch.Tensor, w_local: torch.Tensor) -> torch.Tensor: in_grad_placements=(x_plc, w_plc), device_mesh=x.device_mesh, ) - x = conv_dt(x, w) # pyrefly: ignore [bad-argument-count] + x = conv_dt(x, w) # pyrefly: ignore else: x = conv(x) return F.silu(x).transpose(1, 2) diff --git a/torchtitan/models/qwen3_5/vision_encoder.py b/torchtitan/models/qwen3_5/vision_encoder.py index b720452037..8e2bf38317 100644 --- a/torchtitan/models/qwen3_5/vision_encoder.py +++ b/torchtitan/models/qwen3_5/vision_encoder.py @@ -90,8 +90,7 @@ def _compute_learned_pos_embeds( for (h, w), indices in hw_to_indices.items(): pos_hw = F.interpolate( pos_grid, - # pyrefly: ignore [bad-argument-type] - size=(h, w), + size=[h, w], mode="bilinear", align_corners=True, ) diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py index 442bd721b9..9246da2c0d 100644 --- a/torchtitan/trainer.py +++ b/torchtitan/trainer.py @@ -188,6 +188,7 @@ def __init__(self, config: Config): model_spec = config.model_spec device_module, device_type = utils.device_module, utils.device_type + # pyrefly: ignore [read-only] self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) From c08c4d6dff4e45a43391fe1fd70b0b22498976b6 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Tue, 9 Jun 2026 01:21:59 -0700 Subject: [PATCH 5/7] refactors of shared experts, mrope --- .../numerical_tests_qwen3_5.py | 80 +++-- .../numerical_tests_qwen3_5_shard.py | 6 +- torchtitan/components/validate.py | 18 +- .../hf_datasets/multimodal/mm_collator.py | 179 ++++++++++ .../hf_datasets/multimodal/mm_datasets.py | 6 + torchtitan/models/common/config_utils.py | 31 +- torchtitan/models/common/decoder.py | 18 + torchtitan/models/common/moe.py | 26 +- torchtitan/models/common/moe_sharding.py | 22 +- torchtitan/models/deepseek_v3/__init__.py | 3 +- torchtitan/models/qwen3_5/README.md | 8 +- torchtitan/models/qwen3_5/__init__.py | 165 +++++---- torchtitan/models/qwen3_5/config_registry.py | 1 + torchtitan/models/qwen3_5/model.py | 329 ++++-------------- torchtitan/models/qwen3_5/rope.py | 57 ++- torchtitan/models/qwen3_5/sharding.py | 156 +++++---- torchtitan/trainer.py | 3 + 17 files changed, 579 insertions(+), 529 deletions(-) diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py index bb42c237cc..e9d993e25f 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py @@ -14,29 +14,34 @@ Usage: python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \ - --hf_model_path hf_assets/Qwen/Qwen3.5-4B \ - --tt_checkpoint_path outputs/Qwen/qwen3_5_4b_dcp - - python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \ - --hf_model_path hf_assets/Qwen/Qwen3.5-35B-A3B \ - --tt_checkpoint_path outputs/Qwen/qwen3_5_35b_a3b_dcp \ - --model_flavor 35B-A3B + --hf_model_path hf_assets/Qwen/Qwen3.5-2B \ + --tt_checkpoint_path outputs/Qwen/qwen3_5_2b_dcp \ + --model_flavor 2B """ import argparse import os +import types from typing import Any +import einops as E import torch import torch._dynamo import torch.distributed.checkpoint as dcp import torch.nn.functional as F +from PIL import Image torch._dynamo.config.disable = True from torchtitan.components.checkpoint import ModelWrapper -from torchtitan.models.qwen3_5 import model_registry -from transformers import AutoProcessor +from torchtitan.hf_datasets.multimodal.mm_collator import MultiModalCollator +from torchtitan.hf_datasets.multimodal.utils.image import ( + process_image, + vision_to_patches, +) +from torchtitan.models.common.attention import ScaledDotProductAttention +from torchtitan.models.qwen3_5 import model_registry, QWEN3_5_SPECIAL_TOKENS +from transformers import AutoModelForImageTextToText, AutoProcessor # ============================================================ @@ -45,7 +50,7 @@ def kl_divergence(logits_a, logits_b): - """KL(a || b) between two logit tensors.""" + """KL(softmax(b) || softmax(a)) — F.kl_div(log Q, P) computes KL(P || Q).""" return F.kl_div( F.log_softmax(logits_a, dim=-1), F.softmax(logits_b, dim=-1), @@ -78,17 +83,10 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): Returns: hf_inputs: list of dicts (processor output, ready for HF model) - tt_inputs: list of (input_ids, pixel_values, grid_thw) + tt_inputs: list of (input_ids, pixel_values, grid_thw, mrope_positions) pixel_comparisons: list of per-sample pixel diff stats + special_tokens: {"image_id", "video_id"} from the tokenizer """ - import einops as E - from PIL import Image - - from torchtitan.hf_datasets.multimodal.utils.image import ( - process_image, - vision_to_patches, - ) - # Annotate as Any: AutoProcessor.from_pretrained is typed Optional, which # trips the .apply_chat_template call on environments without transformers stubs. processor: Any = AutoProcessor.from_pretrained(hf_model_path) @@ -100,6 +98,17 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): temporal_patch_size = encoder_config.temporal_patch_size merge_size = encoder_config.spatial_merge_size + image_token_id = processor.tokenizer.convert_tokens_to_ids( + QWEN3_5_SPECIAL_TOKENS["image_token"] + ) + video_token_id = processor.tokenizer.convert_tokens_to_ids( + QWEN3_5_SPECIAL_TOKENS["video_token"] + ) + special_tokens = {"image_id": image_token_id, "video_id": video_token_id} + # Reuse the collator's MRoPE builder (only needs spatial_merge_size) so the + # 3D position IDs match the training path exactly. + mrope_builder = types.SimpleNamespace(spatial_merge_size=merge_size) + hf_inputs, tt_inputs, pixel_comparisons = [], [], [] for i in range(num_samples): @@ -141,8 +150,23 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): temporal_patch_size, merge_size, ) + # 3D MRoPE positions (1, S, 3); positions=None → single document. + mrope_positions = MultiModalCollator._build_mrope_positions( + mrope_builder, + hf_in["input_ids"], + grid_thw.unsqueeze(0), + None, + None, + image_token_id=image_token_id, + video_token_id=video_token_id, + ) tt_inputs.append( - (hf_in["input_ids"], patches.unsqueeze(0), grid_thw.unsqueeze(0)) + ( + hf_in["input_ids"], + patches.unsqueeze(0), + grid_thw.unsqueeze(0), + mrope_positions, + ) ) # --- Compare pixel values in image space --- @@ -166,7 +190,7 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): tt_img = E.rearrange(patches, pattern, **kwargs) pixel_comparisons.append(_compare_images(hf_img[:1], tt_img[:1], i)) - return hf_inputs, tt_inputs, pixel_comparisons + return hf_inputs, tt_inputs, pixel_comparisons, special_tokens def _compare_images(hf_img, tt_img, sample_idx): @@ -211,8 +235,6 @@ def print_pixel_comparisons(comparisons): @torch.no_grad() def run_hf(model_path, hf_inputs, device): """Run HF model, return last-token logits per sample.""" - from transformers import AutoModelForImageTextToText - print(f"Loading HuggingFace model on {device} ...") model = AutoModelForImageTextToText.from_pretrained( model_path, @@ -244,10 +266,8 @@ def run_hf(model_path, hf_inputs, device): @torch.no_grad() -def run_tt(model_flavor, checkpoint_path, tt_inputs, device): +def run_tt(model_flavor, checkpoint_path, tt_inputs, special_tokens, device): """Run TT model, return last-token logits per sample.""" - from torchtitan.models.common.attention import ScaledDotProductAttention - print(f"Loading TorchTitan model on {device} ...") model_config = model_registry(model_flavor).model @@ -282,14 +302,13 @@ def forward(self, q, k, v, **kwargs): model.eval() - special_tokens = {"image_id": 248056, "video_id": 248057} - outputs = [] - for i, (tokens, pixel_values, grid_thw) in enumerate(tt_inputs): + for i, (tokens, pixel_values, grid_thw, mrope_positions) in enumerate(tt_inputs): logits = model( tokens.to(device), pixel_values=pixel_values.half().to(device), grid_thw=grid_thw.to(device), + mrope_positions=mrope_positions.to(device), special_tokens=special_tokens, ) outputs.append(logits[:, -1:, :].cpu()) @@ -367,7 +386,7 @@ def main(): print(f"Using {torch.cuda.get_device_name(0)}") print(f"\nBuilding {args.num_samples} test samples ...") - hf_inputs, tt_inputs, pixel_comparisons = build_inputs( + hf_inputs, tt_inputs, pixel_comparisons, special_tokens = build_inputs( args.hf_model_path, args.model_flavor, args.num_samples, @@ -382,6 +401,7 @@ def main(): args.model_flavor, args.tt_checkpoint_path, tt_inputs, + special_tokens, device, ) diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py index f9443f109c..62308692c8 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py @@ -116,10 +116,14 @@ def run_worker(args): tokens = torch.randint(0, 248320, (1, seq_len), device="cuda") dist.broadcast(tokens, src=0) + # Text-only inputs + positions = torch.arange(seq_len, device="cuda").unsqueeze(0) + with torch.no_grad(): output = model( tokens, - special_tokens={"image_id": 151859, "video_id": 151860}, + positions=positions, + special_tokens={"image_id": 248056, "video_id": 248057}, ) if isinstance(output, DTensor): diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 7190bb1a72..ab81e6cd24 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -191,8 +191,20 @@ def post_dataloading_process( # which is where get_attention_masks is defined. A maskless backend (the # SDPA config used by the graph_trainer tests) still receives positions # for RoPE but no masks — it relies on is_causal instead. - if isinstance(model_config, Decoder.Config) and positions is not None: - inner_attention = model_config.layers[0].attention.inner_attention + mrope_positions = extra_inputs.pop("mrope_positions", None) + if isinstance(model_config, Decoder.Config): + attn_config = model_config.layers[0].attention + inner_attention = attn_config.inner_attention + + if attn_config.mask_type == "block_causal": + assert ( + positions is not None + ), "block_causal mask requires per-document positions from the dataloader" + else: + positions = torch.arange( + inputs.shape[1], dtype=torch.int32, device=inputs.device + ).repeat(inputs.shape[0], 1) + if isinstance( inner_attention, (FlexAttention.Config, VarlenAttention.Config) ): @@ -202,6 +214,8 @@ def post_dataloading_process( ) extra_kwargs["positions"] = positions + if mrope_positions is not None: + extra_kwargs["mrope_positions"] = mrope_positions if self.parallel_dims.cp_enabled: inputs, labels, extra_kwargs = prepare_context_parallel_input( diff --git a/torchtitan/hf_datasets/multimodal/mm_collator.py b/torchtitan/hf_datasets/multimodal/mm_collator.py index 97c65f470e..3350d89de8 100644 --- a/torchtitan/hf_datasets/multimodal/mm_collator.py +++ b/torchtitan/hf_datasets/multimodal/mm_collator.py @@ -34,6 +34,7 @@ class MultiModalCollator: temporal_patch_size: int spatial_merge_size: int tokenizer: MultiModalTokenizer + build_mrope_positions: bool def collate_images( self, all_images: list[torch.Tensor] @@ -133,6 +134,171 @@ def collate_text( return input_ids[:, :-1], labels[:, 1:], positions[:, :-1] + def _build_mrope_positions( + self, + tokens: torch.Tensor, + grid_thw: torch.Tensor | None, + grid_thw_videos: torch.Tensor | None, + positions: torch.Tensor | None, + *, + image_token_id: int, + video_token_id: int, + ) -> torch.Tensor: + """Build 3D (temporal, height, width) MRoPE position IDs per token. + + Returns ``(batch, seq_len, 3)`` — batch/seq leading (like the 2D + ``positions``) so pipeline-parallel microbatching chunks the batch dim + and context parallel can shard the seq dim, with the 3 T/H/W coords as + the last (feature) axis. Runs here on CPU data workers, off the GPU + training path. + + Args: + tokens: (batch, seq_len) token IDs. + grid_thw: (num_images, 3) image grid dims, or None. + grid_thw_videos: (num_videos, 3) video grid dims, or None. + positions: (batch, seq_len) per-token positions; document + boundaries are detected where positions reset. + image_token_id: Placeholder token ID marking image positions. + video_token_id: Placeholder token ID marking video positions. + + Returns: + (batch, seq_len, 3) MRoPE position IDs. + """ + # Expand each video [T, H, W] into T rows of [1, H, W] so each frame is + # treated like an image; temporal position comes from frame ordering. + if grid_thw_videos is not None: + grid_thw_videos = torch.repeat_interleave( + grid_thw_videos, grid_thw_videos[:, 0], dim=0 + ) + grid_thw_videos[:, 0] = 1 + + spatial_merge_size = self.spatial_merge_size + + batch_size, seq_len = tokens.shape + mrope_positions = torch.zeros( + batch_size, seq_len, 3, dtype=tokens.dtype, device=tokens.device + ) + + if positions is not None: + resets = positions[:, 1:] < positions[:, :-1] # (batch, seq_len-1) + # First token of each consecutive vision region (image or video). + vision_mask = (tokens == image_token_id) | (tokens == video_token_id) + prev_vision = torch.cat( + [torch.zeros_like(vision_mask[:, :1]), vision_mask[:, :-1]], dim=1 + ) + batch_vision_starts = vision_mask & ~prev_vision # (batch, seq_len) + grid_cache: dict[tuple[int, int, int], torch.Tensor] = {} + + image_index, video_index = 0, 0 + # With sample packing, each sample may contain multiple documents. + for sample_i in range(batch_size): + llm_pos_ids_list: list[torch.Tensor] = [] + + if positions is not None: + # pyrefly: ignore [unbound-name] + reset_indices = torch.where(resets[sample_i])[0] + 1 + doc_starts = [0] + reset_indices.tolist() + doc_ranges = [ + ( + doc_starts[d], + doc_starts[d + 1] if d + 1 < len(doc_starts) else seq_len, + ) + for d in range(len(doc_starts)) + ] + else: + doc_ranges = [(0, seq_len)] + + sample_tokens = tokens[sample_i] + sample_vision_starts = torch.where(batch_vision_starts[sample_i])[ + 0 + ].tolist() + vision_start_index = 0 + + for doc_start, doc_end in doc_ranges: + doc_pos_ids_list: list[torch.Tensor] = [] + + doc_vision_starts: list[int] = [] + while ( + vision_start_index < len(sample_vision_starts) + and sample_vision_starts[vision_start_index] < doc_end + ): + doc_vision_starts.append(sample_vision_starts[vision_start_index]) + vision_start_index += 1 + + pair_cursor = doc_start + for vision_start in doc_vision_starts: + if sample_tokens[vision_start] == image_token_id: + # pyrefly: ignore [unsupported-operation] + t, h, w = grid_thw[image_index] + image_index += 1 + else: + # pyrefly: ignore [unsupported-operation] + t, h, w = grid_thw_videos[video_index] + video_index += 1 + + llm_grid_t, llm_grid_h, llm_grid_w = ( + int(t.item()), + int(h.item()) // spatial_merge_size, + int(w.item()) // spatial_merge_size, + ) + text_len = vision_start - pair_cursor + + pos_id_offset = ( + doc_pos_ids_list[-1].max() + 1 + if len(doc_pos_ids_list) > 0 + else 0 + ) + # [text tokens] — sequential positions, identical on all 3 axes. + doc_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset + ) + # [vision tokens] — 3D grid positions (T, H, W). + grid_key = (llm_grid_t, llm_grid_h, llm_grid_w) + if grid_key not in grid_cache: + hw = llm_grid_h * llm_grid_w + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, hw) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + grid_cache[grid_key] = torch.stack([t_index, h_index, w_index]) + doc_pos_ids_list.append( + grid_cache[grid_key] + text_len + pos_id_offset + ) + pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w + + # Trailing [text tokens] after the last text/vision pair. + if pair_cursor < doc_end: + pos_id_offset = ( + doc_pos_ids_list[-1].max() + 1 + if len(doc_pos_ids_list) > 0 + else 0 + ) + text_len = doc_end - pair_cursor + doc_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset + ) + + llm_pos_ids_list.extend(doc_pos_ids_list) + + # llm_pos_ids_list is (3, segment_len); concat -> (3, seq), then transpose + mrope_positions[sample_i] = torch.cat(llm_pos_ids_list, dim=1).T + + return mrope_positions + def __call__( self, batch: list[dict[str, Any]] ) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]: @@ -186,5 +352,18 @@ def __call__( }, } + if self.build_mrope_positions and ( + grids is not None or video_grids is not None + ): + special_tokens = input_dict["special_tokens"] + input_dict["mrope_positions"] = self._build_mrope_positions( + input_ids, + grids, + video_grids, + positions, + image_token_id=special_tokens["image_id"], + video_token_id=special_tokens["video_id"], + ) + # pyrefly: ignore [bad-return] return input_dict, labels diff --git a/torchtitan/hf_datasets/multimodal/mm_datasets.py b/torchtitan/hf_datasets/multimodal/mm_datasets.py index b9a9dca862..71adda51a2 100644 --- a/torchtitan/hf_datasets/multimodal/mm_datasets.py +++ b/torchtitan/hf_datasets/multimodal/mm_datasets.py @@ -531,6 +531,11 @@ class Config(ParallelAwareDataloader.Config): video_max_frames: int = 768 """Maximum number of frames to sample from a video.""" + # Other loading configs + build_mrope_positions: bool = False + """Build 3D MRoPE position IDs (``mrope_positions``) for models that use + multi-dimensional RoPE""" + def __init__( self, config: Config, @@ -574,6 +579,7 @@ def __init__( temporal_patch_size=config.temporal_patch_size, spatial_merge_size=config.spatial_merge_size, tokenizer=tokenizer, + build_mrope_positions=config.build_mrope_positions, ) dataloader_kwargs = { diff --git a/torchtitan/models/common/config_utils.py b/torchtitan/models/common/config_utils.py index 81d7fcb96a..d88c9d138d 100644 --- a/torchtitan/models/common/config_utils.py +++ b/torchtitan/models/common/config_utils.py @@ -25,7 +25,6 @@ from torchtitan.models.common.moe import ( GroupedExperts, MoE, - SharedExperts, TokenChoiceTopKRouter, ) from torchtitan.models.common.nn_modules import Linear, RMSNorm @@ -156,40 +155,12 @@ def make_ffn_config( ) -def make_shared_experts_config( - *, - dim: int, - hidden_dim: int, - w1_param_init: dict[str, Callable], - w2w3_param_init: dict[str, Callable], - gate_param_init: dict[str, Callable] | None = None, -) -> SharedExperts.Config: - """Build a SharedExperts.Config (SwiGLU FFN with optional sigmoid gate). - - When ``gate_param_init`` is given, the shared expert applies a per-token - sigmoid gate (``sigmoid(gate(x)) * ffn(x)``), e.g. the Qwen3.5 shared - expert. Otherwise it is a plain SwiGLU FFN. - """ - ffn = make_ffn_config( - dim=dim, - hidden_dim=hidden_dim, - w1_param_init=w1_param_init, - w2w3_param_init=w2w3_param_init, - ) - gate = ( - Linear.Config(in_features=dim, out_features=1, param_init=gate_param_init) - if gate_param_init is not None - else None - ) - return SharedExperts.Config(w1=ffn.w1, w2=ffn.w2, w3=ffn.w3, gate=gate) - - def make_moe_config( *, num_experts: int = 8, router: TokenChoiceTopKRouter.Config, experts: GroupedExperts.Config, - shared_experts: SharedExperts.Config | None = None, + shared_experts: FeedForward.Config | None = None, load_balance_coeff: float | None = 1e-3, ) -> MoE.Config: """Build a fully-specified MoE.Config.""" diff --git a/torchtitan/models/common/decoder.py b/torchtitan/models/common/decoder.py index 872817a1a2..c27cdda35e 100644 --- a/torchtitan/models/common/decoder.py +++ b/torchtitan/models/common/decoder.py @@ -88,6 +88,24 @@ def max_seq_len(self) -> int: return rope_cfg.max_seq_len raise ValueError("Decoder config does not define RoPE max_seq_len.") + @property + def first_attn_config(self) -> BaseAttention.Config | None: + """Attention config of the first layer that has one, else None. + + Hybrid models (linear + full attention) don't carry an attention + config on every layer, so callers needing attention metadata (TP + validation, FLOPs, mask type) look up the first full-attention + layer rather than assuming ``layers[0]``. + """ + return next( + ( + layer.attention + for layer in self.layers + if layer.attention is not None + ), + None, + ) + def update_from_config( self, *, diff --git a/torchtitan/models/common/moe.py b/torchtitan/models/common/moe.py index ce0bb7bcf3..bb72f31995 100644 --- a/torchtitan/models/common/moe.py +++ b/torchtitan/models/common/moe.py @@ -295,30 +295,6 @@ def forward( ) -class SharedExperts(FeedForward): - """Shared expert: SwiGLU FFN with an optional per-token sigmoid gate. - - When ``gate`` is set, the output is ``sigmoid(gate(x)) * ffn(x)``; - otherwise it is a plain SwiGLU FFN. Inherits ``w1/w2/w3`` from - FeedForward so weight FQNs are unchanged. - """ - - @dataclass(kw_only=True, slots=True) - class Config(FeedForward.Config): - gate: Linear.Config | None = None - - def __init__(self, config: Config): - super().__init__(config) - self.gate = config.gate.build() if config.gate is not None else None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = super().forward(x) - if self.gate is not None: - # TODO: make the gate activation configurable (e.g. softmax, silu) - out = torch.sigmoid(self.gate(x)) * out - return out - - class MoE(Module): """Mixture of Experts layer. @@ -346,7 +322,7 @@ class Config(Module.Config): experts: GroupedExperts.Config router: TokenChoiceTopKRouter.Config load_balance_coeff: float | None = 1e-3 - shared_experts: SharedExperts.Config | None = None + shared_experts: FeedForward.Config | None = None def __init__(self, config: Config): super().__init__() diff --git a/torchtitan/models/common/moe_sharding.py b/torchtitan/models/common/moe_sharding.py index e1d99bb968..67d219688e 100644 --- a/torchtitan/models/common/moe_sharding.py +++ b/torchtitan/models/common/moe_sharding.py @@ -230,11 +230,12 @@ def set_moe_sharding_config( # Router gate: dense-family TP plan with Partial output grad. moe_cfg.router.gate.sharding_config = _router_gate_config(enable_ep=enable_ep) - # Shared experts: optional SharedExperts (SwiGLU FFN + optional gate). - # Gather x to Replicate ONCE at the module boundary so w1/w3/gate all share - # it (their per-linear input redistributions become no-ops). w2 (rowwise) - # keeps the output Partial; the Partial->sp_layout reduce happens once at - # the MoE boundary. + # Shared experts: SwiGLU FFN run in parallel with the routed experts. + # Gather x to Replicate ONCE at the module boundary so w1/w3 share it (their + # per-linear input redistributions become no-ops). w2 (rowwise) keeps the + # output Partial; the Partial->sp_layout reduce happens once at the MoE + # boundary. Model-specific shared-expert extras are sharded by that + # model's own sharding code. shared = moe_cfg.shared_experts if shared is not None: sp_layout = Shard(1) if enable_sp else Replicate() @@ -252,17 +253,6 @@ def set_moe_sharding_config( enable_ep=enable_ep, enable_sp=enable_sp ) - if shared.gate is not None: - # Gate output is Replicate, so `gate * ffn` is - # `Replicate * Partial = Partial` with no extra collective. - shared.gate.sharding_config = ShardingConfig( - state_shardings={ - "weight": dense_param_placement(tp=Replicate()), - "bias": dense_param_placement(tp=Replicate()), - }, - out_dst_shardings=dense_activation_placement(tp=Replicate()), - ) - # Routed experts: local_map converts DTensor inputs to local for # dispatch/compute/combine, then wraps local output as DTensor(Partial). # Routed experts: the three things that differ between EP and TP-only diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index ce540a5db4..d66071a641 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -27,7 +27,6 @@ make_ffn_config, make_moe_config, make_router_config, - make_shared_experts_config, ) from torchtitan.models.common.param_init import depth_scaled_std from torchtitan.models.utils import validate_converter_order @@ -243,7 +242,7 @@ def _build_dsv3_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), - shared_experts=make_shared_experts_config( + shared_experts=make_ffn_config( dim=dim, hidden_dim=moe_hidden_dim * num_shared_experts, w1_param_init=_LINEAR_INIT, diff --git a/torchtitan/models/qwen3_5/README.md b/torchtitan/models/qwen3_5/README.md index aed8135752..7c2b9169da 100644 --- a/torchtitan/models/qwen3_5/README.md +++ b/torchtitan/models/qwen3_5/README.md @@ -24,13 +24,7 @@ Note: the diagram shows each patch mapping to one vision token. In practice, the Install the additional dependencies: ```bash -pip install av torchvision -``` - -For GatedDeltaNet GPU efficiency (optional, pure-torch fallback available): - -```bash -pip install flash-linear-attention +pip install av torchvision flash-linear-attention ``` ## Model Variants diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py index 777581f72c..158364a48a 100644 --- a/torchtitan/models/qwen3_5/__init__.py +++ b/torchtitan/models/qwen3_5/__init__.py @@ -12,14 +12,13 @@ from torchtitan.components.optimizer import register_moe_load_balancing_hook -from torchtitan.models.common import Conv1d, Embedding, Linear, RoPE # noqa: F401 +from torchtitan.models.common import Conv1d, Embedding, Linear # noqa: F401 from torchtitan.models.common.config_utils import ( get_attention_config, make_experts_config, make_ffn_config, make_moe_config, make_router_config, - make_shared_experts_config, ) from torchtitan.models.common.nn_modules import LayerNorm from torchtitan.models.common.param_init import depth_scaled_std # noqa: F401 @@ -36,8 +35,10 @@ Qwen35Model, Qwen35TransformerBlock, RMSNormGated, + SharedExperts, ) from .parallelize import parallelize_qwen3_5, pipeline_qwen3_5 +from .rope import MRoPE from .state_dict_adapter import Qwen35StateDictAdapter from .vision_encoder import ( PatchMerger, @@ -115,6 +116,24 @@ def _offset_norm(dim: int) -> OffsetRMSNorm.Config: return OffsetRMSNorm.Config(dim=dim, eps=_EPS, param_init=_OFFSET_NORM_INIT) +def _shared_experts_config( + *, dim: int, hidden_dim: int, layer_id: int +) -> SharedExperts.Config: + """Build Qwen3.5's sigmoid-gated shared-expert config (SwiGLU FFN + gate).""" + ffn = make_ffn_config( + dim=dim, + hidden_dim=hidden_dim, + w1_param_init=_LINEAR_INIT, + w2w3_param_init=_depth_init(layer_id), + ) + return SharedExperts.Config( + w1=ffn.w1, + w2=ffn.w2, + w3=ffn.w3, + gate=Linear.Config(in_features=dim, out_features=1, param_init=_LINEAR_INIT), + ) + + def _qwen35_vision_encoder_config( *, dim: int, @@ -182,6 +201,7 @@ def _qwen35_attention_config( n_kv_heads: int, head_dim: int, rotary_dim: int, + rope: MRoPE.Config, attn_backend: str, layer_id: int, ) -> Qwen35Attention.Config: @@ -192,6 +212,7 @@ def _qwen35_attention_config( n_kv_heads=n_kv_heads, head_dim=head_dim, rotary_dim=rotary_dim, + rope=rope, wq=Linear.Config( in_features=dim, out_features=n_heads * head_dim * 2, @@ -288,6 +309,7 @@ def _build_qwen35_layers( n_kv_heads: int, head_dim: int, rotary_dim: int, + rope: MRoPE.Config, hidden_dim: int, n_key_heads: int, n_value_heads: int, @@ -311,6 +333,7 @@ def _build_qwen35_layers( n_kv_heads=n_kv_heads, head_dim=head_dim, rotary_dim=rotary_dim, + rope=rope, attn_backend=attn_backend, layer_id=layer_id, ) @@ -356,6 +379,7 @@ def _build_qwen35_moe_layers( n_kv_heads: int, head_dim: int, rotary_dim: int, + rope: MRoPE.Config, moe_hidden_dim: int, num_experts: int, top_k: int, @@ -384,6 +408,7 @@ def _build_qwen35_moe_layers( n_kv_heads=n_kv_heads, head_dim=head_dim, rotary_dim=rotary_dim, + rope=rope, attn_backend=attn_backend, layer_id=layer_id, ) @@ -428,12 +453,10 @@ def _build_qwen35_moe_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), - shared_experts=make_shared_experts_config( + shared_experts=_shared_experts_config( dim=dim, hidden_dim=shared_expert_hidden_dim, - w1_param_init=_LINEAR_INIT, - w2w3_param_init=_depth_init(layer_id), - gate_param_init=_LINEAR_INIT, + layer_id=layer_id, ), ), attention_norm=_offset_norm(dim), @@ -467,13 +490,13 @@ def _debugmodel(attn_backend: str) -> Qwen35Model.Config: out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=4096, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=4096, + theta=10_000_000.0, + mrope_section=[3, 3, 2], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -499,7 +522,6 @@ def _debugmodel(attn_backend: str) -> Qwen35Model.Config: out_hidden_size=256, num_position_embeddings=1024, ), - mrope_section=[3, 3, 2], ) @@ -528,13 +550,13 @@ def _debugmodel_moe( out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=4096, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_moe_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=4096, + theta=10_000_000.0, + mrope_section=[3, 3, 2], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -564,7 +586,6 @@ def _debugmodel_moe( out_hidden_size=256, num_position_embeddings=1024, ), - mrope_section=[3, 3, 2], ) @@ -595,13 +616,13 @@ def _0_8b(attn_backend: str) -> Qwen35Model.Config: out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -626,7 +647,6 @@ def _0_8b(attn_backend: str) -> Qwen35Model.Config: out_hidden_size=1024, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -657,13 +677,13 @@ def _2b(attn_backend: str) -> Qwen35Model.Config: out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -688,7 +708,6 @@ def _2b(attn_backend: str) -> Qwen35Model.Config: out_hidden_size=2048, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -718,13 +737,13 @@ def _4b(attn_backend: str) -> Qwen35Model.Config: out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -749,7 +768,6 @@ def _4b(attn_backend: str) -> Qwen35Model.Config: out_hidden_size=2560, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -775,13 +793,13 @@ def _9b(attn_backend: str) -> Qwen35Model.Config: out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -806,7 +824,6 @@ def _9b(attn_backend: str) -> Qwen35Model.Config: out_hidden_size=4096, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -832,13 +849,13 @@ def _27b(attn_backend: str) -> Qwen35Model.Config: out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -863,7 +880,6 @@ def _27b(attn_backend: str) -> Qwen35Model.Config: out_hidden_size=5120, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -892,13 +908,13 @@ def _35b_a3b( out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_moe_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -927,7 +943,6 @@ def _35b_a3b( out_hidden_size=2048, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -956,13 +971,13 @@ def _122b_a10b( out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_moe_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -991,7 +1006,6 @@ def _122b_a10b( out_hidden_size=3072, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) @@ -1020,13 +1034,13 @@ def _397b_a17b( out_features=vocab_size, param_init=_output_linear_init(dim), ), - rope=RoPE.Config( - dim=rotary_dim, - max_seq_len=262144, - theta=10_000_000.0, - backend="cos_sin", - ), layers=_build_qwen35_moe_layers( + rope=MRoPE.Config( + dim=rotary_dim, + max_seq_len=262144, + theta=10_000_000.0, + mrope_section=[11, 11, 10], + ), attn_backend=attn_backend, n_layers=n_layers, dim=dim, @@ -1055,7 +1069,6 @@ def _397b_a17b( out_hidden_size=4096, num_position_embeddings=2304, ), - mrope_section=[11, 11, 10], ) diff --git a/torchtitan/models/qwen3_5/config_registry.py b/torchtitan/models/qwen3_5/config_registry.py index 6fa2f35a17..d6d1518d4d 100644 --- a/torchtitan/models/qwen3_5/config_registry.py +++ b/torchtitan/models/qwen3_5/config_registry.py @@ -33,6 +33,7 @@ def _dataloader(dataset: str, **kwargs) -> MMDataLoader.Config: max_pixels=16777216, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), + build_mrope_positions=True, **kwargs, ) diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py index e751c25282..d43042516c 100644 --- a/torchtitan/models/qwen3_5/model.py +++ b/torchtitan/models/qwen3_5/model.py @@ -5,37 +5,31 @@ # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Literal import torch import torch.nn.functional as F + +from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule as _fla_chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule as _fla_fused_recurrent_gated_delta_rule, +) from torch import nn from torch.distributed.tensor import DTensor from torch.distributed.tensor.experimental import local_map -from torchtitan.models.common import Conv1d, Linear +from torchtitan.models.common import Conv1d, FeedForward, Linear from torchtitan.models.common.attention import AttentionMasksType, BaseAttention from torchtitan.models.common.decoder import Decoder -from torchtitan.models.common.rope import apply_rotary_emb_cos_sin from torchtitan.models.utils import get_moe_model_nparams_and_flops from torchtitan.protocols.module import Module +from .rope import MRoPE from .sharding import set_qwen35_sharding_config from .vision_encoder import Qwen35VisionEncoder -try: - from fla.ops.gated_delta_rule import ( - chunk_gated_delta_rule as _fla_chunk_gated_delta_rule, - fused_recurrent_gated_delta_rule as _fla_fused_recurrent_gated_delta_rule, - ) - - _HAS_FLA = True -except ImportError: - _HAS_FLA = False - - def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: """L2 norm using rsqrt(sum(x²) + eps), not x/max(norm, eps) like F.normalize, to match FLA kernel.""" return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) @@ -89,6 +83,27 @@ def _torch_native_gated_delta( return output.to(dtype) +class SharedExperts(FeedForward): + """Qwen3.5 shared expert: SwiGLU FFN with a per-token sigmoid gate. + + The output is ``sigmoid(gate(x)) * ffn(x)``. Inherits ``w1/w2/w3`` from + FeedForward so weight FQNs are unchanged. This gate is specific to + Qwen3.5; other models use a plain ``FeedForward`` shared expert. + """ + + @dataclass(kw_only=True, slots=True) + class Config(FeedForward.Config): + gate: Linear.Config + + def __init__(self, config: Config): + super().__init__(config) + self.gate = config.gate.build() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = super().forward(x) + return torch.sigmoid(self.gate(x)) * out + + class OffsetRMSNorm(Module): """RMSNorm with offset: ``(1 + weight) * norm(x)``. @@ -181,12 +196,6 @@ def forward( if self.backend == "torch_native": return _torch_native_gated_delta(q, k, v, g, beta) - if not _HAS_FLA: - raise RuntimeError( - f"Backend '{self.backend}' requires the `fla` package. " - "Install: pip install flash-linear-attention" - ) - if self.backend == "fla_chunked": result = _fla_chunk_gated_delta_rule( q, @@ -309,7 +318,10 @@ def _conv(x_local: torch.Tensor, w_local: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: bs, seqlen, _ = x.shape - # Split projections (not fused QKV) so each is ColwiseParallel for TP. + # Shapes: + # xq, xk: (bs, seqlen, n_key_heads * key_head_dim) + # xv, xz: (bs, seqlen, n_value_heads * value_head_dim) + # xa, xb: (bs, seqlen, n_value_heads) xq = self._causal_conv(self.in_proj_q(x), self.conv_q) xk = self._causal_conv(self.in_proj_k(x), self.conv_k) xv = self._causal_conv(self.in_proj_v(x), self.conv_v) @@ -321,10 +333,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: xk = xk.view(bs, seqlen, -1, self.key_head_dim) xv = xv.view(bs, seqlen, -1, self.value_head_dim) - g = -torch.exp(self.A_log.float()) * F.softplus( - xa.float() + self.dt_bias - ) # decay rate, always negative - beta = torch.sigmoid(xb) # update gate ∈ (0, 1) + # Gating signals, shape (bs, seqlen, n_value_heads): + # g: decay rate per head, always negative + # beta: update gate ∈ (0, 1) + g = -torch.exp(self.A_log.float()) * F.softplus(xa.float() + self.dt_bias) + beta = torch.sigmoid(xb) output = self.kernel(xq, xk, xv, g, beta) @@ -343,6 +356,10 @@ class Qwen35Attention(BaseAttention): - Partial RoPE: only first ``rotary_dim`` elements get RoPE - Output gating: ``attn_output * sigmoid(gate)`` before ``wo`` - QK norm uses OffsetRMSNorm + + Uses separate ``wq``/``wk``/``wv`` instead of the common fused ``qkv_linear`` + (so this subclasses ``BaseAttention``, not ``GQAttention``): the 2x-wide, + gated ``wq`` doesn't fit a fused QKV projection that TP-shards by head. """ @dataclass(kw_only=True, slots=True) @@ -351,6 +368,7 @@ class Config(BaseAttention.Config): n_kv_heads: int head_dim: int rotary_dim: int + rope: MRoPE.Config wq: Linear.Config wk: Linear.Config wv: Linear.Config @@ -373,6 +391,8 @@ def __init__(self, config: Config): self.wv = config.wv.build() self.wo = config.wo.build() + self.rope = config.rope.build() + self.q_norm = config.q_norm.build() self.k_norm = config.k_norm.build() @@ -383,7 +403,6 @@ def __init__(self, config: Config): def forward( self, x: torch.Tensor, - rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, positions: torch.Tensor | None = None, ) -> torch.Tensor: @@ -403,7 +422,7 @@ def forward( assert self.rotary_dim <= self.head_dim xq_rot, xq_pass = xq[..., : self.rotary_dim], xq[..., self.rotary_dim :] xk_rot, xk_pass = xk[..., : self.rotary_dim], xk[..., self.rotary_dim :] - xq_rot, xk_rot = apply_rotary_emb_cos_sin(xq_rot, xk_rot, rope_cache, positions) + xq_rot, xk_rot = self.rope(xq_rot, xk_rot, positions) xq = torch.cat([xq_rot, xq_pass], dim=-1) xk = torch.cat([xk_rot, xk_pass], dim=-1) @@ -463,13 +482,12 @@ def __init__(self, config: Config): def forward( self, x: torch.Tensor, - freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, positions: torch.Tensor | None = None, ) -> torch.Tensor: h = self.attention_norm(x) if self.full_attn: - h = self.attn(h, freqs_cis, attention_masks, positions) + h = self.attn(h, attention_masks, positions) else: h = self.attn(h) x = x + h @@ -494,25 +512,31 @@ class Qwen35Model(Decoder): - Output gating on full attention: ``attn_out * sigmoid(gate)`` - Partial RoPE: only first ``rotary_dim`` elements get positional encoding - OffsetRMSNorm: ``(1 + weight) * norm(x)`` with zero-init weight - - MRoPE: 3D position IDs (temporal, height, width) for vision tokens + - MRoPE: 3D (temporal/height/width) position IDs for multimodal batches; + text batches use the plain 1D positions - MoE variant: routed experts + shared expert with sigmoid gate + MRoPE positions (``mrope_positions``, shape ``(batch, seq, 3)``) are built by + the dataloader and forwarded to every pipeline stage, so RoPE stays consistent + across stages even though the raw vision inputs (``pixel_values``/``grid_thw``) + only reach the first stage. Text batches carry no ``mrope_positions`` and use + the 2D ``positions`` instead. + Forward pass flow:: - forward(tokens, pixel_values, grid_thw, ...) + forward(tokens, pixel_values, grid_thw, mrope_positions, ...) │ ├─ _prepare_multimodal_embeds │ ├─ tok_embeddings(tokens) → text embeddings │ ├─ _get_vision_embeds(pixel_values) → vision embeddings │ │ └─ vision_encoder(pixel_values) → merge patches - │ ├─ _compute_vision_positions → locate vision regions + │ ├─ _get_vision_positions → locate vision regions │ └─ _scatter_vision_embeds → scatter into text sequence │ - ├─ _compute_mrope_freqs → 3D position IDs → interleaved cos/sin - │ - └─ transformer layers (hybrid) + └─ transformer layers (hybrid), each given (mrope_positions or positions) └─ for each layer: ├─ full attention (every Nth): QK-norm → partial RoPE → SDPA → gate + │ (the layer's MRoPE builds the cos/sin cache from positions) └─ GatedDeltaNet (others): Conv1d → gated delta rule → gated norm """ @@ -520,10 +544,6 @@ class Qwen35Model(Decoder): class Config(Decoder.Config): vision_encoder: Qwen35VisionEncoder.Config - # MRoPE section sizes for interleaved multi-dimensional RoPE - # [temporal, height, width] - controls how position dimensions are interleaved - mrope_section: list[int] = field(default_factory=lambda: [24, 20, 20]) - def update_from_config( self, *, @@ -579,218 +599,9 @@ def __init__(self, config: Config): super().__init__(config) self.vision_encoder = config.vision_encoder.build() - - self.mrope_section = config.mrope_section self.spatial_merge_size = config.vision_encoder.spatial_merge_size - def _compute_mrope_freqs( - self, - tokens: torch.Tensor, - *, - grid_thw: torch.Tensor | None, - grid_thw_videos: torch.Tensor | None, - special_tokens: dict[str, int], - positions: torch.Tensor | None = None, - ) -> torch.Tensor: - """Build 3D position IDs and compute interleaved MRoPE cos/sin frequencies. - - Constructs (temporal, height, width) position IDs for each token, then - looks up cos/sin from the 1D RoPE table and overwrites H/W-assigned dims - with their own position lookups. - - Args: - tokens: (batch, seq_len) token IDs - grid_thw: (num_images, 3) grid dimensions for images - grid_thw_videos: (num_videos, 3) grid dimensions for videos - special_tokens: Special token definitions - positions: (batch, seq_len) per-token position IDs for packed - sequences. When provided, document boundaries are detected - where positions reset (positions[t] < positions[t-1]), and - pos_id_offset resets to 0 at each boundary - - Returns: - (batch, seq_len, 1, head_dim * 2) pre-computed MRoPE cos/sin - """ - # --- Build 3D position IDs --- - - # Expand each video [T, H, W] into T rows of [1, H, W] so that - # each frame is treated like an image in the MRoPE code below - # Temporal position comes from frame ordering in the sequence - if grid_thw_videos is not None: - grid_thw_videos = torch.repeat_interleave( - grid_thw_videos, grid_thw_videos[:, 0], dim=0 - ) - grid_thw_videos[:, 0] = 1 - - spatial_merge_size = self.spatial_merge_size - image_token_id = special_tokens["image_id"] - video_token_id = special_tokens["video_id"] - - batch_size, seq_len = tokens.shape - position_ids = torch.zeros( - 3, - batch_size, - seq_len, - dtype=tokens.dtype, - device=tokens.device, - ) - - # Precompute document boundaries and vision token positions across batch - if positions is not None: - resets = positions[:, 1:] < positions[:, :-1] # (batch, seq_len-1) - # Find the first token of each consecutive vision region (image or video) - # E.g. for [text, img, img, img, text, vid, vid] → positions [1, 5] - vision_mask = (tokens == image_token_id) | (tokens == video_token_id) - prev_vision = torch.cat( - [torch.zeros_like(vision_mask[:, :1]), vision_mask[:, :-1]], dim=1 - ) - batch_vision_starts = vision_mask & ~prev_vision # (batch, seq_len) - # Cache vision grid indices by shape to avoid redundant construction - grid_cache: dict[tuple[int, int, int], torch.Tensor] = {} - - image_index, video_index = 0, 0 - # Build MRoPE 3D position IDs per sample - # With sample packing, each sample may contain multiple documents - for sample_i in range(batch_size): - llm_pos_ids_list: list[torch.Tensor] = [] - - if positions is not None: - # Detect document boundaries within one packed sample - # pyrefly: ignore [unbound-name] - reset_indices = torch.where(resets[sample_i])[0] + 1 - doc_starts = [0] + reset_indices.tolist() - doc_ranges = [ - ( - doc_starts[d], - doc_starts[d + 1] if d + 1 < len(doc_starts) else seq_len, - ) - for d in range(len(doc_starts)) - ] - else: - doc_ranges = [(0, seq_len)] - - sample_tokens = tokens[sample_i] - sample_vision_starts = torch.where(batch_vision_starts[sample_i])[ - 0 - ].tolist() - vision_start_index = 0 - - for doc_start, doc_end in doc_ranges: - doc_pos_ids_list: list[torch.Tensor] = [] - - # Advance pointer to collect vision region starts in this document - doc_vision_starts: list[int] = [] - while ( - vision_start_index < len(sample_vision_starts) - and sample_vision_starts[vision_start_index] < doc_end - ): - doc_vision_starts.append(sample_vision_starts[vision_start_index]) - vision_start_index += 1 - - # Process [text tokens][vision tokens] pairs within this document - pair_cursor = doc_start - for vision_start in doc_vision_starts: - if sample_tokens[vision_start] == image_token_id: - # pyrefly: ignore [unsupported-operation] - t, h, w = grid_thw[image_index] - image_index += 1 - else: - # pyrefly: ignore [unsupported-operation] - t, h, w = grid_thw_videos[video_index] - video_index += 1 - - llm_grid_t, llm_grid_h, llm_grid_w = ( - int(t.item()), - int(h.item()) // spatial_merge_size, - int(w.item()) // spatial_merge_size, - ) - text_len = vision_start - pair_cursor - - # pos_id_offset may differ from pair_cursor due to compact - # spatial position IDs for vision regions - pos_id_offset = ( - doc_pos_ids_list[-1].max() + 1 - if len(doc_pos_ids_list) > 0 - else 0 - ) - # [text tokens] — sequential positions, identical on all 3 axes - doc_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset - ) - # [vision tokens] — 3D grid positions (T, H, W) - grid_key = (llm_grid_t, llm_grid_h, llm_grid_w) - if grid_key not in grid_cache: - hw = llm_grid_h * llm_grid_w - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, hw) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - grid_cache[grid_key] = torch.stack([t_index, h_index, w_index]) - doc_pos_ids_list.append( - grid_cache[grid_key] + text_len + pos_id_offset - ) - pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w - - # Trailing [text tokens] after the last [text tokens][vision tokens] pair - if pair_cursor < doc_end: - pos_id_offset = ( - doc_pos_ids_list[-1].max() + 1 - if len(doc_pos_ids_list) > 0 - else 0 - ) - text_len = doc_end - pair_cursor - doc_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset - ) - - llm_pos_ids_list.extend(doc_pos_ids_list) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[:, sample_i, :] = llm_positions.to(position_ids.device) - - # --- Compute interleaved MRoPE cos/sin from position IDs --- - # Convert to local — DTensor doesn't support fancy indexing with - # plain-tensor indices (cos_cache[t_pos], sin_cache[:, col][dim_pos]). - freqs_cis = self.freqs_cis - if isinstance(freqs_cis, DTensor): - freqs_cis = freqs_cis.to_local() - head_dim = freqs_cis.shape[-1] // 2 - cos_cache = freqs_cis[:, :head_dim] - sin_cache = freqs_cis[:, head_dim:] - - # Initialize with temporal positions, then overwrite H/W slices - t_pos = position_ids[0].long() - mrope_cos = cos_cache[t_pos] - mrope_sin = sin_cache[t_pos] - - # Overwrite H and W slices with their own position lookups - # Both halves of head_dim must be updated (head_dim = cat([freqs, freqs])) - half = head_dim // 2 - for dim, offset in enumerate((1, 2), start=1): # H, W - length = self.mrope_section[dim] * 3 - low = torch.arange(offset, length, 3, device=freqs_cis.device) - col_indices = torch.cat([low, low + half]) - dim_pos = position_ids[dim].long() - mrope_cos[..., col_indices] = cos_cache[:, col_indices][dim_pos] - mrope_sin[..., col_indices] = sin_cache[:, col_indices][dim_pos] - - return torch.cat([mrope_cos, mrope_sin], dim=-1).unsqueeze(2) - - def _compute_vision_positions( + def _get_vision_positions( self, tokens: torch.Tensor, num_tokens_per_item: torch.Tensor, @@ -908,7 +719,7 @@ def _prepare_multimodal_embeds( merged_embeds, num_tokens = self._get_vision_embeds( pixel_values, grid_thw=grid_thw ) - image_positions = self._compute_vision_positions( + image_positions = self._get_vision_positions( tokens, num_tokens, image_token_id ) if image_positions: @@ -922,7 +733,7 @@ def _prepare_multimodal_embeds( merged_embeds, num_tokens = self._get_vision_embeds( pixel_values_videos, grid_thw=grid_thw_videos ) - video_positions = self._compute_vision_positions( + video_positions = self._get_vision_positions( tokens, num_tokens, video_token_id ) if video_positions: @@ -944,6 +755,7 @@ def forward( # pyrefly: ignore [bad-override] grid_thw_videos: torch.Tensor | None = None, attention_masks: AttentionMasksType | None = None, positions: torch.Tensor | None = None, + mrope_positions: torch.Tensor | None = None, special_tokens: dict[str, int] | None = None, ): if self.tok_embeddings is not None: @@ -958,18 +770,11 @@ def forward( # pyrefly: ignore [bad-override] else: x = tokens - if grid_thw is not None or grid_thw_videos is not None: - freqs_cis = self._compute_mrope_freqs( - tokens, - grid_thw=grid_thw, - grid_thw_videos=grid_thw_videos, - special_tokens=special_tokens, # pyrefly: ignore [bad-argument-type] - positions=positions, - ) - else: - freqs_cis = self.freqs_cis + # 3D MRoPE positions for multimodal batches, else 2D text positions. + rope_positions = mrope_positions if mrope_positions is not None else positions + assert rope_positions is not None for layer in self.layers.values(): - x = layer(x, freqs_cis, attention_masks, positions) + x = layer(x, attention_masks, rope_positions) x = self.norm(x) if self.norm is not None else x if self._skip_lm_head: diff --git a/torchtitan/models/qwen3_5/rope.py b/torchtitan/models/qwen3_5/rope.py index 50368130b7..e2577c29b6 100644 --- a/torchtitan/models/qwen3_5/rope.py +++ b/torchtitan/models/qwen3_5/rope.py @@ -13,7 +13,14 @@ class MRoPE(CosSinRoPE): - """Multi-dimensional RoPE for Qwen3-VL temporal/height/width positions.""" + """Multi-dimensional RoPE for Qwen3.5 temporal/height/width positions. + + Standard per-layer RoPE: each full-attention layer owns an ``MRoPE`` and + applies it through ``RoPE.forward`` -> ``_reshape_cache`` -> ``apply_rotary_emb``. + The only override is ``_reshape_cache``: for 3D ``(batch, seq, 3)`` MRoPE + positions it builds an interleaved cos/sin cache; for 2D ``(batch, seq)`` text + positions it falls back to the plain ``CosSinRoPE`` per-token lookup. + """ @dataclass(kw_only=True, slots=True) class Config(CosSinRoPE.Config): @@ -31,44 +38,56 @@ def _reshape_cache( query: torch.Tensor, positions: torch.Tensor | None = None, ) -> torch.Tensor: - """Build a position-specific cache for 3D MRoPE position IDs.""" + """Build a query-broadcastable cos/sin cache. + + Dispatches on position rank: 3D ``(batch, seq, 3)`` MRoPE positions take + the interleaved scatter; everything else (2D text positions or ``None``) + falls back to the plain ``CosSinRoPE`` lookup. + """ if positions is not None and positions.ndim == 3: return self._compute_mrope_cache(positions) return super()._reshape_cache(query, positions) def _compute_mrope_cache(self, position_ids: torch.Tensor) -> torch.Tensor: + """Build the interleaved cos/sin cache for 3D MRoPE positions. + + Args: + position_ids: ``(batch, seq, 3)`` T/H/W positions. Plain, or a DTensor + under TP matching the rope ``cache`` buffer's Replicate placement. + + Returns: + ``(batch, seq, 1, dim * 2)`` cache, broadcastable to the + ``(batch, seq, n_heads, rotary_dim)`` query/key in ``apply_rotary_emb``. + + The scatter runs on plain local tensors. Under TP the ``cache`` buffer is a + Replicate DTensor, so it is unwrapped to local here and the result is + re-distributed with the buffer's placements, yielding a DTensor that + composes with the sharded query/key without any manual wrapping in the + attention forward. + """ cfg = self.config assert isinstance(cfg, MRoPE.Config) - if position_ids.shape[0] != 3: - raise ValueError( - f"MRoPE position IDs must have shape (3, batch, seq), " - f"got {tuple(position_ids.shape)}." - ) rope_cache = self.cache cache_dtensor = rope_cache if isinstance(rope_cache, DTensor) else None if cache_dtensor is not None: rope_cache = cache_dtensor.to_local() - - position_dtensor = position_ids if isinstance(position_ids, DTensor) else None - pos_local = ( - position_dtensor.to_local() - if position_dtensor is not None + pos = ( + position_ids.to_local() + if isinstance(position_ids, DTensor) else position_ids ) - pos_local = pos_local.to(device=rope_cache.device) + pos = pos.to(device=rope_cache.device) - _maybe_check_max_pos( - pos_local, - max_valid_pos=rope_cache.shape[0] - 1, - ) + _maybe_check_max_pos(pos, max_valid_pos=rope_cache.shape[0] - 1) head_dim = rope_cache.shape[-1] // 2 cos_cache = rope_cache[:, :head_dim] sin_cache = rope_cache[:, head_dim:] # Start from temporal positions for all dimensions, then overwrite the # height/width interleaved sections with their own position IDs. - t_pos = pos_local[0].long() + # ``pos`` is (batch, seq, 3); the last axis selects T/H/W. + t_pos = pos[..., 0].long() mrope_cos = cos_cache[t_pos] mrope_sin = sin_cache[t_pos] @@ -77,7 +96,7 @@ def _compute_mrope_cache(self, position_ids: torch.Tensor) -> torch.Tensor: length = cfg.mrope_section[dim] * 3 low = torch.arange(offset, length, 3, device=rope_cache.device) col_indices = torch.cat([low, low + half]) - dim_pos = pos_local[dim].long() + dim_pos = pos[..., dim].long() mrope_cos[..., col_indices] = cos_cache[:, col_indices][dim_pos] mrope_sin[..., col_indices] = sin_cache[:, col_indices][dim_pos] diff --git a/torchtitan/models/qwen3_5/sharding.py b/torchtitan/models/qwen3_5/sharding.py index a90fbbc88a..234b605c2c 100644 --- a/torchtitan/models/qwen3_5/sharding.py +++ b/torchtitan/models/qwen3_5/sharding.py @@ -9,10 +9,11 @@ Sets ``ShardingConfig`` on all sub-configs so that ``model.parallelize()`` applies TP via the Module protocol. Same pattern as ``qwen3/sharding.py``. -Full attention layers: TP on wq/wk/wv/wo with local_map for inner attention. +Full-attention layers: TP on wq/wk/wv/wo with local_map for inner attention; +each layer's MRoPE ``cache`` buffer is sharded Replicate. GatedDeltaNet layers: head-sharded TP on projections (ColwiseParallel) and -out_proj (RowwiseParallel). FLA kernel uses local_map for DTensor→local -conversion. Conv1d sharding is set on built modules. +out_proj (RowwiseParallel); the FLA kernel and depthwise Conv1d run on local +tensors via local_map. """ from typing import TYPE_CHECKING @@ -34,24 +35,44 @@ if TYPE_CHECKING: from torchtitan.models.qwen3_5.model import ( + GatedDeltaNet, Qwen35Attention, Qwen35Model, Qwen35TransformerBlock, + SharedExperts, ) + from torchtitan.models.qwen3_5.vision_encoder import Qwen35VisionEncoder -_REPLICATE_PARAM = dense_param_placement(tp=Replicate()) -_REPLICATE_STATE = ShardingConfig( - state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM} -) -_REPLICATE_ACT = dense_activation_placement(tp=Replicate()) - -# For norms/modules that receive and emit Replicate activations -_REPLICATE_NORM = ShardingConfig( - state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM}, - in_src_shardings={"input": _REPLICATE_ACT}, - in_dst_shardings={"input": _REPLICATE_ACT}, - out_dst_shardings=_REPLICATE_ACT, -) +def _replicate_norm() -> ShardingConfig: + """Replicate norm (weight/bias and activations) — used by the vision + encoder, which runs without sequence parallelism.""" + return ShardingConfig( + state_shardings={ + "weight": dense_param_placement(tp=Replicate()), + "bias": dense_param_placement(tp=Replicate()), + }, + in_src_shardings={"input": dense_activation_placement(tp=Replicate())}, + in_dst_shardings={"input": dense_activation_placement(tp=Replicate())}, + out_dst_shardings=dense_activation_placement(tp=Replicate()), + ) + + +def _qk_norm_sharding() -> ShardingConfig: + """Per-head QK-norm sharding: weight Replicate, activations Shard(2).""" + head_plc = dense_activation_placement(tp=Shard(2)) + return ShardingConfig( + state_shardings={"weight": dense_param_placement(tp=Replicate())}, + in_src_shardings={"input": head_plc}, + in_dst_shardings={"input": head_plc}, + out_dst_shardings=head_plc, + ) + + +def _conv_weight_sharding() -> ShardingConfig: + """Depthwise Conv1d weight sharded Shard(0) on out-channels (head-sharded).""" + return ShardingConfig( + state_shardings={"weight": dense_param_placement(tp=Shard(0))}, + ) _GROUPED_EXPERTS_PARAM_LAYOUT: dict[str, Placement] = { @@ -73,14 +94,15 @@ def set_qwen35_sharding_config( stays Replicate so vision scatter and MRoPE can access the full sequence. The model forward redistributes to Shard(1) before entering the layers. """ - # SP on norm, lm_head, and layers. freqs_cis stays Replicate (set by base). + # SP on norm, lm_head, and layers. Each full-attention layer owns its rope; + # its cache buffer is sharded Replicate in _set_full_attention_sharding. set_decoder_sharding_config(config, loss_parallel=loss_parallel, enable_sp=True) # Override tok_embeddings: output Replicate (not Shard(1)) for vision scatter config.tok_embeddings.sharding_config = ShardingConfig( state_shardings={"weight": dense_param_placement(tp=Shard(0))}, - in_src_shardings={"input": _REPLICATE_ACT}, - in_dst_shardings={"input": _REPLICATE_ACT}, - out_dst_shardings=_REPLICATE_ACT, + in_src_shardings={"input": dense_activation_placement(tp=Replicate())}, + in_dst_shardings={"input": dense_activation_placement(tp=Replicate())}, + out_dst_shardings=dense_activation_placement(tp=Replicate()), ) _set_vision_encoder_sharding(config.vision_encoder) for layer_cfg in config.layers: @@ -92,9 +114,8 @@ def _set_qwen35_layer_sharding( *, enable_ep: bool, ) -> None: - norm = norm_config(enable_sp=True) - layer_cfg.attention_norm.sharding_config = norm - layer_cfg.ffn_norm.sharding_config = norm + layer_cfg.attention_norm.sharding_config = norm_config(enable_sp=True) + layer_cfg.ffn_norm.sharding_config = norm_config(enable_sp=True) if layer_cfg.attention is not None: _set_full_attention_sharding(layer_cfg.attention) @@ -116,9 +137,34 @@ def _set_qwen35_layer_sharding( enable_sp=True, expert_param_layout=_GROUPED_EXPERTS_PARAM_LAYOUT, ) + _set_shared_expert_gate_sharding(layer_cfg.moe.shared_experts) + + +def _set_shared_expert_gate_sharding( + shared_experts: "SharedExperts.Config | None", +) -> None: + """Shard Qwen3.5's shared-expert sigmoid gate. + + The common MoE sharding handles the shared FFN (w1/w2/w3) and the + module-boundary gather that feeds the gate a Replicate ``x``. Here we only + add the gate: its weight is Replicate and its output is Replicate, so + ``sigmoid(gate(x)) * ffn(x)`` is ``Replicate * Partial = Partial`` with no + extra collective. ``getattr`` keeps this a no-op when the MoE has no shared + expert (``None``); Qwen3.5's shared expert always carries the gate. + """ + gate = getattr(shared_experts, "gate", None) + if gate is None: + return + gate.sharding_config = ShardingConfig( + state_shardings={ + "weight": dense_param_placement(tp=Replicate()), + "bias": dense_param_placement(tp=Replicate()), + }, + out_dst_shardings=dense_activation_placement(tp=Replicate()), + ) -def _set_vision_encoder_sharding(ve_cfg) -> None: +def _set_vision_encoder_sharding(ve_cfg: "Qwen35VisionEncoder.Config") -> None: """Sharding for the vision encoder. All activations flow as Replicate — no SP in the vision encoder. @@ -126,25 +172,28 @@ def _set_vision_encoder_sharding(ve_cfg) -> None: Norms are Replicate. pos_embed is Replicate via state_shardings. """ ve_cfg.sharding_config = ShardingConfig( - state_shardings={"pos_embed": _REPLICATE_PARAM}, + state_shardings={"pos_embed": dense_param_placement(tp=Replicate())}, ) # patch_embed receives plain pixel_values — wrap as DTensor(Replicate) ve_cfg.patch_embed_proj.sharding_config = ShardingConfig( - state_shardings={"weight": _REPLICATE_PARAM, "bias": _REPLICATE_PARAM}, - in_src_shardings={"input": _REPLICATE_ACT}, - in_dst_shardings={"input": _REPLICATE_ACT}, - out_dst_shardings=_REPLICATE_ACT, + state_shardings={ + "weight": dense_param_placement(tp=Replicate()), + "bias": dense_param_placement(tp=Replicate()), + }, + in_src_shardings={"input": dense_activation_placement(tp=Replicate())}, + in_dst_shardings={"input": dense_activation_placement(tp=Replicate())}, + out_dst_shardings=dense_activation_placement(tp=Replicate()), ) # Block sub-modules block = ve_cfg.block - block.norm1.sharding_config = _REPLICATE_NORM - block.norm2.sharding_config = _REPLICATE_NORM + block.norm1.sharding_config = _replicate_norm() + block.norm2.sharding_config = _replicate_norm() block.attn.sharding_config = ShardingConfig( - in_src_shardings={"rope_cache": _REPLICATE_ACT}, - in_dst_shardings={"rope_cache": _REPLICATE_ACT}, + in_src_shardings={"rope_cache": dense_activation_placement(tp=Replicate())}, + in_dst_shardings={"rope_cache": dense_activation_placement(tp=Replicate())}, ) block.attn.wq.sharding_config = colwise_config() block.attn.wk.sharding_config = colwise_config() @@ -157,7 +206,7 @@ def _set_vision_encoder_sharding(ve_cfg) -> None: # Merger sub-modules merger = ve_cfg.merger - merger.norm.sharding_config = _REPLICATE_NORM + merger.norm.sharding_config = _replicate_norm() merger.fc1.sharding_config = colwise_config() merger.fc2.sharding_config = rowwise_config(output_sp=False) @@ -167,34 +216,26 @@ def _set_full_attention_sharding( ) -> None: """TP sharding for Qwen35Attention (output gating + partial RoPE).""" attention_cfg.sharding_config = ShardingConfig( - in_src_shardings={ - "x": dense_activation_placement(tp=Shard(1)), - "rope_cache": dense_param_placement(tp=Replicate()), - }, - in_dst_shardings={ - "x": dense_activation_placement(tp=Replicate()), - "rope_cache": dense_param_placement(tp=Replicate()), - }, + in_src_shardings={"x": dense_activation_placement(tp=Shard(1))}, + in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, + ) + # The per-layer rope ``cache`` buffer is a Replicate DTensor; MRoPE builds the + # position-resolved cache from it (``positions`` stays a plain input). + attention_cfg.rope.sharding_config = ShardingConfig( + state_shardings={"cache": dense_param_placement(tp=Replicate())}, ) attention_cfg.wq.sharding_config = colwise_config() attention_cfg.wk.sharding_config = colwise_config() attention_cfg.wv.sharding_config = colwise_config() attention_cfg.wo.sharding_config = rowwise_config(output_sp=True) - _head_plc = dense_activation_placement(tp=Shard(2)) - qk_norm_sharding = ShardingConfig( - state_shardings={"weight": _REPLICATE_PARAM}, - in_src_shardings={"input": _head_plc}, - in_dst_shardings={"input": _head_plc}, - out_dst_shardings=_head_plc, - ) - attention_cfg.q_norm.sharding_config = qk_norm_sharding - attention_cfg.k_norm.sharding_config = qk_norm_sharding + attention_cfg.q_norm.sharding_config = _qk_norm_sharding() + attention_cfg.k_norm.sharding_config = _qk_norm_sharding() set_gqa_inner_attention_local_map(attention_cfg.inner_attention) -def _set_deltanet_sharding(deltanet_cfg) -> None: +def _set_deltanet_sharding(deltanet_cfg: "GatedDeltaNet.Config") -> None: """Sharding for GatedDeltaNet: head-sharded TP on projections. Input is allgathered (Shard(1)→Replicate) so that the recurrence @@ -218,12 +259,9 @@ def _set_deltanet_sharding(deltanet_cfg) -> None: getattr(deltanet_cfg, name).sharding_config = colwise_config() # Depthwise Conv1d weights: Shard(0) on out-channels (head-sharded). - _conv_shard = ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Shard(0))}, - ) - deltanet_cfg.conv_q.sharding_config = _conv_shard - deltanet_cfg.conv_k.sharding_config = _conv_shard - deltanet_cfg.conv_v.sharding_config = _conv_shard + deltanet_cfg.conv_q.sharding_config = _conv_weight_sharding() + deltanet_cfg.conv_k.sharding_config = _conv_weight_sharding() + deltanet_cfg.conv_v.sharding_config = _conv_weight_sharding() # RowwiseParallel on output projection (reduce-scatter to SP) deltanet_cfg.out_proj.sharding_config = rowwise_config(output_sp=True) @@ -231,7 +269,7 @@ def _set_deltanet_sharding(deltanet_cfg) -> None: # RMSNormGated: per-head norm, weight Replicate, activations Shard(2) _norm_plc = dense_activation_placement(tp=Shard(2)) deltanet_cfg.norm.sharding_config = ShardingConfig( - state_shardings={"weight": _REPLICATE_PARAM}, + state_shardings={"weight": dense_param_placement(tp=Replicate())}, in_src_shardings={"x": _norm_plc, "gate": _norm_plc}, in_dst_shardings={"x": _norm_plc, "gate": _norm_plc}, out_dst_shardings=_norm_plc, diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py index 9246da2c0d..cab7a71f02 100644 --- a/torchtitan/trainer.py +++ b/torchtitan/trainer.py @@ -586,6 +586,7 @@ def post_dataloading_process( extra_kwargs: dict[str, Any] = {} positions = extra_inputs.pop("positions", None) + mrope_positions = extra_inputs.pop("mrope_positions", None) # positions and attention_masks are optional (Decoder.forward defaults # both to None). Build attention masks only for the masked backends @@ -615,6 +616,8 @@ def post_dataloading_process( ) extra_kwargs["positions"] = positions + if mrope_positions is not None: + extra_kwargs["mrope_positions"] = mrope_positions if self.parallel_dims.cp_enabled: inputs, labels, extra_kwargs = prepare_context_parallel_input( From 355fed86cee60482b688b8eae40eabdcaa21b005 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Tue, 9 Jun 2026 02:20:48 -0700 Subject: [PATCH 6/7] refactor spmd types --- .../numerical_tests_qwen3_5.py | 2 +- tests/integration_tests/models.py | 19 ----- tests/unit_tests/test_rope.py | 8 +- torchtitan/models/common/config_utils.py | 6 +- torchtitan/models/common/decoder.py | 1 - torchtitan/models/common/moe_sharding.py | 13 ++- torchtitan/models/qwen3_5/parallelize.py | 13 +-- torchtitan/models/qwen3_5/sharding.py | 83 ++++++++++--------- .../models/qwen3_5/state_dict_adapter.py | 3 + 9 files changed, 68 insertions(+), 80 deletions(-) diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py index e9d993e25f..d120461a40 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py @@ -152,7 +152,7 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224): ) # 3D MRoPE positions (1, S, 3); positions=None → single document. mrope_positions = MultiModalCollator._build_mrope_positions( - mrope_builder, + mrope_builder, # pyrefly: ignore [bad-argument-type] hf_in["input_ids"], grid_thw.unsqueeze(0), None, diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 7e615d7302..51f8a4365a 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -119,25 +119,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "qwen3_fsdp+tp+cp", ngpu=8, ), - # Integration Test Cases for Llama 4 - # TODO: re-enable compile after fixing - # https://github.com/pytorch/torchtitan/issues/2771 - OverrideDefinitions( - [ - [ - "--module llama4 --config llama4_debugmodel_ep", - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule Interleaved1F1B", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - "--parallelism.expert_parallel_degree 4", - # "--compile.enable", - ], - ], - "Llama 4 PP+FSDP+TP+EP+compile", - "llama4_pp+fsdp+tp+ep+compile", - ngpu=8, - ), # Integration Test Cases for Qwen3.5 OverrideDefinitions( [ diff --git a/tests/unit_tests/test_rope.py b/tests/unit_tests/test_rope.py index 5e6f0bc240..f97c02bec0 100644 --- a/tests/unit_tests/test_rope.py +++ b/tests/unit_tests/test_rope.py @@ -20,7 +20,7 @@ CosSinRoPE, RoPE, ) -from torchtitan.models.qwen3_vl.rope import MRoPE +from torchtitan.models.qwen3_5.rope import MRoPE class TestApplyRotaryEmbCosSin(unittest.TestCase): @@ -173,11 +173,11 @@ def test_forward_accepts_three_axis_positions(self): max_seq_len=8, mrope_section=[2, 1, 1], ).build() + # (batch, seq, 3): per-token [temporal, height, width] positions. position_ids = torch.tensor( [ - [[0, 1, 2], [3, 4, 5]], # temporal - [[1, 2, 3], [4, 5, 6]], # height - [[2, 3, 4], [5, 6, 7]], # width + [[0, 1, 2], [1, 2, 3], [2, 3, 4]], # batch 0 + [[3, 4, 5], [4, 5, 6], [5, 6, 7]], # batch 1 ] ) xq = torch.randn(bsz, seqlen, n_heads, head_dim) diff --git a/torchtitan/models/common/config_utils.py b/torchtitan/models/common/config_utils.py index d88c9d138d..a8f320ea77 100644 --- a/torchtitan/models/common/config_utils.py +++ b/torchtitan/models/common/config_utils.py @@ -22,11 +22,7 @@ VarlenAttention, ) from torchtitan.models.common.feed_forward import FeedForward -from torchtitan.models.common.moe import ( - GroupedExperts, - MoE, - TokenChoiceTopKRouter, -) +from torchtitan.models.common.moe import GroupedExperts, MoE, TokenChoiceTopKRouter from torchtitan.models.common.nn_modules import Linear, RMSNorm from torchtitan.models.common.rope import RoPE from torchtitan.models.common.token_dispatcher import ( diff --git a/torchtitan/models/common/decoder.py b/torchtitan/models/common/decoder.py index c27cdda35e..8a17d0fae2 100644 --- a/torchtitan/models/common/decoder.py +++ b/torchtitan/models/common/decoder.py @@ -79,7 +79,6 @@ class Config(BaseModel.Config): @property def max_seq_len(self) -> int: - # Llama4/iRoPE can have NoPE layers with ``rope=None``; use the # first layer that carries RoPE to expose the model context length. for layer_cfg in self.layers: attention_cfg = getattr(layer_cfg, "attention", None) diff --git a/torchtitan/models/common/moe_sharding.py b/torchtitan/models/common/moe_sharding.py index 67d219688e..9a335d8b5b 100644 --- a/torchtitan/models/common/moe_sharding.py +++ b/torchtitan/models/common/moe_sharding.py @@ -238,11 +238,16 @@ def set_moe_sharding_config( # model's own sharding code. shared = moe_cfg.shared_experts if shared is not None: - sp_layout = Shard(1) if enable_sp else Replicate() - shared_input_layout = Replicate() if not enable_ep else sp_layout + # Shared-expert input matches the MoE input: sequence-parallel under + # EP+SP (Shard(1) on seq, ordered via partition_spec), else Replicate. + shared_input = ( + dense_sequence_parallel_placement() + if enable_ep and enable_sp + else dense_activation_placement(tp=spmd.R) + ) shared.sharding_config = ShardingConfig( - in_src_shardings={"x": dense_activation_placement(tp=shared_input_layout)}, - in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, + in_src_shardings={"x": shared_input}, + in_dst_shardings={"x": dense_activation_placement(tp=spmd.R)}, ) shared.w1.sharding_config = _shared_expert_colwise_config( diff --git a/torchtitan/models/qwen3_5/parallelize.py b/torchtitan/models/qwen3_5/parallelize.py index 0e625acea4..9c39b01f1a 100644 --- a/torchtitan/models/qwen3_5/parallelize.py +++ b/torchtitan/models/qwen3_5/parallelize.py @@ -27,10 +27,11 @@ from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.compile import apply_compile -from torchtitan.distributed.fsdp import get_fsdp_reshard_after_forward_policy - +from torchtitan.distributed.fsdp import ( + apply_fsdp_to_decoder, + get_fsdp_reshard_after_forward_policy, +) from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp -from torchtitan.models.llama4.parallelize import apply_fsdp from torchtitan.tools.logging import logger @@ -77,7 +78,7 @@ def parallelize_qwen3_5( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - if parallelism.full_dtensor: + if parallelism.spmd_backend == "full_dtensor": raise NotImplementedError("full_dtensor is not supported yet.") model_compile_enabled = ( @@ -147,8 +148,8 @@ def parallelize_qwen3_5( ) edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) - apply_fsdp( - model, + apply_fsdp_to_decoder( + model, # pyrefly: ignore [bad-argument-type] dp_mesh, param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce], diff --git a/torchtitan/models/qwen3_5/sharding.py b/torchtitan/models/qwen3_5/sharding.py index 234b605c2c..77db288696 100644 --- a/torchtitan/models/qwen3_5/sharding.py +++ b/torchtitan/models/qwen3_5/sharding.py @@ -18,12 +18,13 @@ from typing import TYPE_CHECKING -from torch.distributed.tensor import Placement, Replicate, Shard +import spmd_types as spmd from torchtitan.models.common.decoder_sharding import ( colwise_config, dense_activation_placement, dense_param_placement, + dense_sequence_parallel_placement, norm_config, rowwise_config, set_decoder_sharding_config, @@ -43,25 +44,26 @@ ) from torchtitan.models.qwen3_5.vision_encoder import Qwen35VisionEncoder + def _replicate_norm() -> ShardingConfig: """Replicate norm (weight/bias and activations) — used by the vision encoder, which runs without sequence parallelism.""" return ShardingConfig( state_shardings={ - "weight": dense_param_placement(tp=Replicate()), - "bias": dense_param_placement(tp=Replicate()), + "weight": dense_param_placement(tp=spmd.R), + "bias": dense_param_placement(tp=spmd.R), }, - in_src_shardings={"input": dense_activation_placement(tp=Replicate())}, - in_dst_shardings={"input": dense_activation_placement(tp=Replicate())}, - out_dst_shardings=dense_activation_placement(tp=Replicate()), + in_src_shardings={"input": dense_activation_placement(tp=spmd.R)}, + in_dst_shardings={"input": dense_activation_placement(tp=spmd.R)}, + out_dst_shardings=dense_activation_placement(tp=spmd.R), ) def _qk_norm_sharding() -> ShardingConfig: """Per-head QK-norm sharding: weight Replicate, activations Shard(2).""" - head_plc = dense_activation_placement(tp=Shard(2)) + head_plc = dense_activation_placement(tp=spmd.S(2)) return ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Replicate())}, + state_shardings={"weight": dense_param_placement(tp=spmd.R)}, in_src_shardings={"input": head_plc}, in_dst_shardings={"input": head_plc}, out_dst_shardings=head_plc, @@ -71,14 +73,14 @@ def _qk_norm_sharding() -> ShardingConfig: def _conv_weight_sharding() -> ShardingConfig: """Depthwise Conv1d weight sharded Shard(0) on out-channels (head-sharded).""" return ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Shard(0))}, + state_shardings={"weight": dense_param_placement(tp=spmd.S(0))}, ) -_GROUPED_EXPERTS_PARAM_LAYOUT: dict[str, Placement] = { - "w1_EFD": Shard(1), - "w2_EDF": Shard(2), - "w3_EFD": Shard(1), +_GROUPED_EXPERTS_PARAM_LAYOUT: dict[str, spmd.PerMeshAxisSpmdType] = { + "w1_EFD": spmd.S(1), + "w2_EDF": spmd.S(2), + "w3_EFD": spmd.S(1), } @@ -99,10 +101,10 @@ def set_qwen35_sharding_config( set_decoder_sharding_config(config, loss_parallel=loss_parallel, enable_sp=True) # Override tok_embeddings: output Replicate (not Shard(1)) for vision scatter config.tok_embeddings.sharding_config = ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Shard(0))}, - in_src_shardings={"input": dense_activation_placement(tp=Replicate())}, - in_dst_shardings={"input": dense_activation_placement(tp=Replicate())}, - out_dst_shardings=dense_activation_placement(tp=Replicate()), + state_shardings={"weight": dense_param_placement(tp=spmd.S(0))}, + in_src_shardings={"input": dense_activation_placement(tp=spmd.R)}, + in_dst_shardings={"input": dense_activation_placement(tp=spmd.R)}, + out_dst_shardings=dense_activation_placement(tp=spmd.R), ) _set_vision_encoder_sharding(config.vision_encoder) for layer_cfg in config.layers: @@ -126,7 +128,7 @@ def _set_qwen35_layer_sharding( if layer_cfg.feed_forward is not None: set_dense_ffn_sharding( layer_cfg.feed_forward, - attn_x_placement=Shard(1), + attn_x_layout=dense_sequence_parallel_placement(), enable_sp=True, ) @@ -137,6 +139,7 @@ def _set_qwen35_layer_sharding( enable_sp=True, expert_param_layout=_GROUPED_EXPERTS_PARAM_LAYOUT, ) + # pyrefly: ignore [missing-attribute] _set_shared_expert_gate_sharding(layer_cfg.moe.shared_experts) @@ -157,10 +160,10 @@ def _set_shared_expert_gate_sharding( return gate.sharding_config = ShardingConfig( state_shardings={ - "weight": dense_param_placement(tp=Replicate()), - "bias": dense_param_placement(tp=Replicate()), + "weight": dense_param_placement(tp=spmd.R), + "bias": dense_param_placement(tp=spmd.R), }, - out_dst_shardings=dense_activation_placement(tp=Replicate()), + out_dst_shardings=dense_activation_placement(tp=spmd.R), ) @@ -172,18 +175,18 @@ def _set_vision_encoder_sharding(ve_cfg: "Qwen35VisionEncoder.Config") -> None: Norms are Replicate. pos_embed is Replicate via state_shardings. """ ve_cfg.sharding_config = ShardingConfig( - state_shardings={"pos_embed": dense_param_placement(tp=Replicate())}, + state_shardings={"pos_embed": dense_param_placement(tp=spmd.R)}, ) # patch_embed receives plain pixel_values — wrap as DTensor(Replicate) ve_cfg.patch_embed_proj.sharding_config = ShardingConfig( state_shardings={ - "weight": dense_param_placement(tp=Replicate()), - "bias": dense_param_placement(tp=Replicate()), + "weight": dense_param_placement(tp=spmd.R), + "bias": dense_param_placement(tp=spmd.R), }, - in_src_shardings={"input": dense_activation_placement(tp=Replicate())}, - in_dst_shardings={"input": dense_activation_placement(tp=Replicate())}, - out_dst_shardings=dense_activation_placement(tp=Replicate()), + in_src_shardings={"input": dense_activation_placement(tp=spmd.R)}, + in_dst_shardings={"input": dense_activation_placement(tp=spmd.R)}, + out_dst_shardings=dense_activation_placement(tp=spmd.R), ) # Block sub-modules @@ -192,8 +195,8 @@ def _set_vision_encoder_sharding(ve_cfg: "Qwen35VisionEncoder.Config") -> None: block.norm2.sharding_config = _replicate_norm() block.attn.sharding_config = ShardingConfig( - in_src_shardings={"rope_cache": dense_activation_placement(tp=Replicate())}, - in_dst_shardings={"rope_cache": dense_activation_placement(tp=Replicate())}, + in_src_shardings={"rope_cache": dense_activation_placement(tp=spmd.R)}, + in_dst_shardings={"rope_cache": dense_activation_placement(tp=spmd.R)}, ) block.attn.wq.sharding_config = colwise_config() block.attn.wk.sharding_config = colwise_config() @@ -216,13 +219,13 @@ def _set_full_attention_sharding( ) -> None: """TP sharding for Qwen35Attention (output gating + partial RoPE).""" attention_cfg.sharding_config = ShardingConfig( - in_src_shardings={"x": dense_activation_placement(tp=Shard(1))}, - in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, + in_src_shardings={"x": dense_sequence_parallel_placement()}, + in_dst_shardings={"x": dense_activation_placement(tp=spmd.R)}, ) # The per-layer rope ``cache`` buffer is a Replicate DTensor; MRoPE builds the # position-resolved cache from it (``positions`` stays a plain input). attention_cfg.rope.sharding_config = ShardingConfig( - state_shardings={"cache": dense_param_placement(tp=Replicate())}, + state_shardings={"cache": dense_param_placement(tp=spmd.R)}, ) attention_cfg.wq.sharding_config = colwise_config() attention_cfg.wk.sharding_config = colwise_config() @@ -267,16 +270,16 @@ def _set_deltanet_sharding(deltanet_cfg: "GatedDeltaNet.Config") -> None: deltanet_cfg.out_proj.sharding_config = rowwise_config(output_sp=True) # RMSNormGated: per-head norm, weight Replicate, activations Shard(2) - _norm_plc = dense_activation_placement(tp=Shard(2)) + _norm_plc = dense_activation_placement(tp=spmd.S(2)) deltanet_cfg.norm.sharding_config = ShardingConfig( - state_shardings={"weight": dense_param_placement(tp=Replicate())}, + state_shardings={"weight": dense_param_placement(tp=spmd.R)}, in_src_shardings={"x": _norm_plc, "gate": _norm_plc}, in_dst_shardings={"x": _norm_plc, "gate": _norm_plc}, out_dst_shardings=_norm_plc, ) # GatedDeltaKernel: local_map converts DTensor q/k/v/g/beta to local - _kernel_plc = dense_activation_placement(tp=Shard(2)) + _kernel_plc = dense_activation_placement(tp=spmd.S(2)) deltanet_cfg.kernel.sharding_config = ShardingConfig( in_dst_shardings={ "q": _kernel_plc, @@ -293,10 +296,10 @@ def _set_deltanet_sharding(deltanet_cfg: "GatedDeltaNet.Config") -> None: deltanet_cfg.sharding_config = ShardingConfig( state_shardings={ - "A_log": dense_param_placement(tp=Shard(0)), - "dt_bias": dense_param_placement(tp=Shard(0)), + "A_log": dense_param_placement(tp=spmd.S(0)), + "dt_bias": dense_param_placement(tp=spmd.S(0)), }, - in_src_shardings={"x": dense_activation_placement(tp=Shard(1))}, - in_dst_shardings={"x": dense_activation_placement(tp=Replicate())}, - out_dst_shardings=dense_activation_placement(tp=Shard(1)), + in_src_shardings={"x": dense_sequence_parallel_placement()}, + in_dst_shardings={"x": dense_activation_placement(tp=spmd.R)}, + out_dst_shardings=dense_sequence_parallel_placement(), ) diff --git a/torchtitan/models/qwen3_5/state_dict_adapter.py b/torchtitan/models/qwen3_5/state_dict_adapter.py index 8f3c579dec..1bba2ce66f 100644 --- a/torchtitan/models/qwen3_5/state_dict_adapter.py +++ b/torchtitan/models/qwen3_5/state_dict_adapter.py @@ -190,6 +190,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_value = value # Linear weight (out, C*T*H*W) → Conv3d weight (out, C, T, H, W) if tt_key == "vision_encoder.patch_embed.weight": + # pyrefly: ignore [missing-attribute] encoder = self.model_config.vision_encoder hf_value = value.reshape( value.shape[0], @@ -286,6 +287,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key == "model.language_model.layers.{}.linear_attn.in_proj_qkv.weight" ): + # pyrefly: ignore [missing-attribute] dn = self.model_config.layers[int(idx)].delta_net kd = dn.in_proj_q.out_features vd = dn.in_proj_v.out_features @@ -300,6 +302,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key == "model.language_model.layers.{}.linear_attn.conv1d.weight" ): + # pyrefly: ignore [missing-attribute] dn = self.model_config.layers[int(idx)].delta_net kd = dn.in_proj_q.out_features vd = dn.in_proj_v.out_features From b917c7fb06f7107e6775785727a5e1c48dc9950d Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Wed, 10 Jun 2026 00:13:30 -0700 Subject: [PATCH 7/7] merge extra_inputs into extra_kwargs in trainer --- .../numerical_tests_qwen3_5_shard.py | 9 ++- torchtitan/components/validate.py | 57 +++++------------ torchtitan/experiments/forge/example_train.py | 23 +++---- .../graph_trainer/precompile_main.py | 2 - .../tests/test_bitwise_deterministic.py | 3 - .../experiments/graph_trainer/trainer.py | 12 +--- torchtitan/models/common/decoder.py | 39 +++++------- torchtitan/models/qwen3_5/__init__.py | 5 +- torchtitan/models/qwen3_5/model.py | 5 +- torchtitan/models/qwen3_5/parallelize.py | 11 ++-- torchtitan/models/qwen3_5/vision_encoder.py | 1 + torchtitan/models/utils.py | 16 +++-- torchtitan/trainer.py | 63 ++++++------------- 13 files changed, 86 insertions(+), 160 deletions(-) diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py index 62308692c8..a9da123221 100644 --- a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py +++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py @@ -20,6 +20,7 @@ import subprocess import sys import tempfile +from typing import cast import torch import torch.distributed as dist @@ -32,7 +33,7 @@ TrainingConfig, ) from torchtitan.distributed import ParallelDims -from torchtitan.models.qwen3_5 import qwen3_5_configs +from torchtitan.models.qwen3_5 import Qwen35Model, qwen3_5_configs from torchtitan.models.qwen3_5.parallelize import parallelize_qwen3_5 CONFIGS = [ @@ -116,13 +117,17 @@ def run_worker(args): tokens = torch.randint(0, 248320, (1, seq_len), device="cuda") dist.broadcast(tokens, src=0) - # Text-only inputs + # Text-only inputs: plain sequential positions. The flex backend requires a + # BlockMask, which the trainer normally builds in post_dataloading_process; + # build it here directly since we call the model outside the trainer. positions = torch.arange(seq_len, device="cuda").unsqueeze(0) + attention_masks = cast(Qwen35Model, model).get_attention_masks(positions=positions) with torch.no_grad(): output = model( tokens, positions=positions, + attention_masks=attention_masks, special_tokens={"image_id": 248056, "video_id": 248057}, ) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index ab81e6cd24..62db7ef4da 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -141,7 +141,7 @@ def post_dataloading_process( input_dict: dict[str, torch.Tensor], labels: torch.Tensor, model_parts: list[nn.Module], - ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: """ Post-processing hook after data loading and before model forward pass. @@ -157,27 +157,17 @@ def post_dataloading_process( model_parts: List of model parts for accessing model methods. Returns: - A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: + A tuple of (inputs, labels, extra_kwargs) where: - inputs: Main input tensor extracted from input_dict["input"]. - labels: Target labels (potentially modified by CP sharding). - - extra_inputs: Dict of auxiliary input tensors (all keys except - "input" from input_dict). These are passed to the model forward - but are NOT forwarded across pipeline parallel stages. - - extra_kwargs: Dict of additional keyword arguments for model forward. - These ARE forwarded across pipeline parallel stages. Contains - attention_masks if flex attention is enabled. - - Note: - The distinction between extra_inputs and extra_kwargs is important for - pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, - while extra_inputs are only available to the first stage. + - extra_kwargs: Additional keyword arguments for the model forward + (e.g. positions, attention_masks), forwarded to every + pipeline-parallel stage. """ inputs = input_dict["input"] - extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # For arguments, like attention_masks, we have to put them in a separate - # dict as extra_inputs are not forwarded to other stages in PP, but - # extra_kwargs are. - extra_kwargs: dict[str, Any] = {} + extra_kwargs: dict[str, Any] = { + k: v for k, v in input_dict.items() if k != "input" + } # TODO: deduplicate with Trainer.post_dataloading_process which has # the same logic; extract a shared function to prevent further drift. @@ -185,26 +175,16 @@ def post_dataloading_process( # both RoPE and block_causal attention masking. model_config = getattr(model_parts[0], "config", None) - positions = extra_inputs.pop("positions", None) + positions = extra_kwargs.get("positions", None) # positions and attention_masks are optional (Decoder.forward defaults # both to None). Build masks only for the masked backends (Flex/Varlen), # which is where get_attention_masks is defined. A maskless backend (the # SDPA config used by the graph_trainer tests) still receives positions # for RoPE but no masks — it relies on is_causal instead. - mrope_positions = extra_inputs.pop("mrope_positions", None) - if isinstance(model_config, Decoder.Config): - attn_config = model_config.layers[0].attention - inner_attention = attn_config.inner_attention - - if attn_config.mask_type == "block_causal": - assert ( - positions is not None - ), "block_causal mask requires per-document positions from the dataloader" - else: - positions = torch.arange( - inputs.shape[1], dtype=torch.int32, device=inputs.device - ).repeat(inputs.shape[0], 1) - + if isinstance(model_config, Decoder.Config) and positions is not None: + inner_attention = getattr( + model_config.first_attention, "inner_attention", None + ) if isinstance( inner_attention, (FlexAttention.Config, VarlenAttention.Config) ): @@ -213,10 +193,6 @@ def post_dataloading_process( positions=positions, ) - extra_kwargs["positions"] = positions - if mrope_positions is not None: - extra_kwargs["mrope_positions"] = mrope_positions - if self.parallel_dims.cp_enabled: inputs, labels, extra_kwargs = prepare_context_parallel_input( inputs, @@ -232,7 +208,7 @@ def post_dataloading_process( self.parallel_dims, inputs, labels, extra_kwargs ) - return inputs, labels, extra_inputs, extra_kwargs + return inputs, labels, extra_kwargs @sl.log_trace_span("eval") @torch.no_grad() @@ -284,7 +260,7 @@ def validate( global_valid_tokens = float(local_valid_tokens.item()) # Process data (extract inputs, handle attention masks, CP sharding) - inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + inputs, labels, extra_kwargs = self.post_dataloading_process( input_dict, labels, model_parts ) @@ -300,7 +276,6 @@ def validate( if self.pp_has_first_stage: self.pp_schedule.eval( inputs, - **extra_inputs, **extra_kwargs, target=targets, losses=losses, @@ -323,7 +298,7 @@ def validate( else: with self.validation_context(): assert len(model_parts) == 1 - predictions = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + predictions = model_parts[0](inputs, **extra_kwargs) loss_sum = self.loss_fn(predictions, labels) accumulated_losses.append(loss_sum.detach() / global_valid_tokens) diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index dc706aeacd..b51727f2a4 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -166,15 +166,15 @@ def batch_generator( def post_dataloading_process( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: inputs = input_dict["input"] - extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # For arguments, like attention_masks, we have to put them in a separate - # dict as extra_inputs are not forwarded to other stages in PP, but - # extra_kwargs are. - extra_kwargs: dict[str, Any] = {} + # Everything except the pipelined input is a model-forward kwarg, + # forwarded to all PP stages by the schedule. + extra_kwargs: dict[str, Any] = { + k: v for k, v in input_dict.items() if k != "input" + } - positions = extra_inputs.pop("positions", None) + positions = extra_kwargs.get("positions", None) try: # pyrefly: ignore [not-callable] @@ -194,7 +194,7 @@ def post_dataloading_process( self.config.parallelism.context_parallel_load_balancer, ) - return inputs, labels, extra_inputs, extra_kwargs + return inputs, labels, extra_kwargs def forward_backward_step( self, @@ -206,9 +206,7 @@ def forward_backward_step( model_parts = self.model_parts parallel_dims = self.parallel_dims - inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( - input_dict, labels - ) + inputs, labels, extra_kwargs = self.post_dataloading_process(input_dict, labels) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call @@ -219,7 +217,6 @@ def forward_backward_step( if self.pp_has_first_stage: self.pp_schedule.step( inputs, - **extra_inputs, **extra_kwargs, target=targets, losses=losses, @@ -244,7 +241,7 @@ def forward_backward_step( # Non-PP forward / backward with self.train_context(): assert len(model_parts) == 1 - pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + pred = model_parts[0](inputs, **extra_kwargs) # Compute loss sum (reduction='sum') loss_sum = self.loss_fn(pred, labels) diff --git a/torchtitan/experiments/graph_trainer/precompile_main.py b/torchtitan/experiments/graph_trainer/precompile_main.py index 8b12e81e2d..a0db7987c4 100644 --- a/torchtitan/experiments/graph_trainer/precompile_main.py +++ b/torchtitan/experiments/graph_trainer/precompile_main.py @@ -207,7 +207,6 @@ def _precompile_aot_fx_trace( * parallel_dims.cp ) dummy_global_valid_tokens = float(global_batch_size * seq_len) - extra_inputs: dict[str, torch.Tensor] = {} extra_kwargs: dict[str, Any] = {} if isinstance(model_config, Decoder.Config) and model_config.layers: @@ -257,7 +256,6 @@ def _precompile_aot_fx_trace( dummy_inputs, dummy_labels, dummy_global_valid_tokens, - extra_inputs, extra_kwargs, ) logger.info( diff --git a/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py b/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py index 2182529c41..087e28f136 100644 --- a/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py +++ b/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py @@ -210,7 +210,6 @@ def _run_steps_with_precompile( global_valid_tokens = torch.tensor( BATCH_SIZE * SEQ_LEN, dtype=torch.float, device="cuda" ) - extra_inputs: dict[str, torch.Tensor] = {} extra_kwargs: dict[str, object] = { "positions": self.positions, **self._get_extra_kwargs(model), @@ -222,7 +221,6 @@ def _run_steps_with_precompile( self.inputs, self.labels, global_valid_tokens, - extra_inputs, extra_kwargs, ) @@ -284,7 +282,6 @@ def _run_steps_with_precompile( self.inputs, self.labels, global_valid_tokens, - extra_inputs, extra_kwargs, ) loss = outputs[0] diff --git a/torchtitan/experiments/graph_trainer/trainer.py b/torchtitan/experiments/graph_trainer/trainer.py index 72c48fa469..8a45f8b6a0 100644 --- a/torchtitan/experiments/graph_trainer/trainer.py +++ b/torchtitan/experiments/graph_trainer/trainer.py @@ -70,8 +70,8 @@ def make_fwd_bwd_step(model, loss_fn): to thread its parameters/buffers as static graph inputs. """ - def fwd_bwd_step(inputs, labels, global_valid_tokens, extra_inputs, extra_kwargs): - pred = model(inputs, **extra_inputs, **extra_kwargs) + def fwd_bwd_step(inputs, labels, global_valid_tokens, extra_kwargs): + pred = model(inputs, **extra_kwargs) # The loss function is not a submodule of the model, so # annotate_module_fqns won't tag it. Annotate it here so that # downstream passes (bucketing, SAC, kernel annotations) can @@ -130,9 +130,7 @@ def forward_backward_step( assert len(self.model_parts) == 1 model = self.model_parts[0] - inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( - input_dict, labels - ) + inputs, labels, extra_kwargs = self.post_dataloading_process(input_dict, labels) # remove_duplicate=False to preserve duplicate parameter entries # from weight tying (e.g. shared embedding/output weights). params = [ @@ -146,7 +144,6 @@ def forward_backward_step( labels, global_valid_tokens, params, - extra_inputs, extra_kwargs, ) @@ -185,7 +182,6 @@ def _make_fx_forward_backward_step( labels: torch.Tensor, global_valid_tokens: float, params: list[torch.Tensor], - extra_inputs: dict[str, torch.Tensor], extra_kwargs: dict[str, Any], ) -> torch.Tensor: maybe_register_blockmask_pytree_node() @@ -199,7 +195,6 @@ def _make_fx_forward_backward_step( inputs, labels, global_valid_tokens, - extra_inputs, extra_kwargs, ) @@ -221,7 +216,6 @@ def _make_fx_forward_backward_step( inputs, labels, global_valid_tokens, - extra_inputs, extra_kwargs, ) loss = outputs[0] diff --git a/torchtitan/models/common/decoder.py b/torchtitan/models/common/decoder.py index 8a17d0fae2..b26860e620 100644 --- a/torchtitan/models/common/decoder.py +++ b/torchtitan/models/common/decoder.py @@ -78,17 +78,7 @@ class Config(BaseModel.Config): enable_weight_tying: bool = False @property - def max_seq_len(self) -> int: - # first layer that carries RoPE to expose the model context length. - for layer_cfg in self.layers: - attention_cfg = getattr(layer_cfg, "attention", None) - rope_cfg = getattr(attention_cfg, "rope", None) - if rope_cfg is not None: - return rope_cfg.max_seq_len - raise ValueError("Decoder config does not define RoPE max_seq_len.") - - @property - def first_attn_config(self) -> BaseAttention.Config | None: + def first_attention(self) -> BaseAttention.Config | None: """Attention config of the first layer that has one, else None. Hybrid models (linear + full attention) don't carry an attention @@ -105,6 +95,14 @@ def first_attn_config(self) -> BaseAttention.Config | None: None, ) + @property + def max_seq_len(self) -> int: + # The first full-attention layer's RoPE defines the context length. + rope_cfg = getattr(self.first_attention, "rope", None) + if rope_cfg is None: + raise ValueError("Decoder config does not define RoPE max_seq_len.") + return rope_cfg.max_seq_len + def update_from_config( self, *, @@ -139,12 +137,8 @@ def update_from_config( ) tp = parallelism.tensor_parallel_degree - if tp > 1: - attention = self.first_attn_config - if attention is None: - raise ValueError( - "No layer with attention config found for TP validation." - ) + attention = self.first_attention + if tp > 1 and attention is not None: n_heads = attention.n_heads n_kv_heads = getattr(attention, "n_kv_heads", None) or n_heads if n_heads % tp != 0: @@ -301,11 +295,12 @@ def _create_flex_attention_mask_for_document( def get_attention_masks( self, positions: torch.Tensor, - ) -> AttentionMasksType: - attn_config = self.config.first_attn_config - assert ( - attn_config is not None - ), "get_attention_masks requires an attention layer" + ) -> AttentionMasksType | None: + attn_config = self.config.first_attention + if attn_config is None: + # No full-attention layers (e.g. a pure linear-attention model, or a + # pipeline stage holding only linear-attention blocks) → no masks. + return None inner_attn = attn_config.inner_attention if isinstance(inner_attn, FlexAttention.Config): return self._create_flex_attention_mask_for_document(positions, attn_config) diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py index 158364a48a..1449dbfe9f 100644 --- a/torchtitan/models/qwen3_5/__init__.py +++ b/torchtitan/models/qwen3_5/__init__.py @@ -206,7 +206,7 @@ def _qwen35_attention_config( layer_id: int, ) -> Qwen35Attention.Config: """Build a fully-specified Qwen35Attention.Config.""" - inner_attention, mask_type = get_attention_config(attn_backend) + inner_attention = get_attention_config(attn_backend) return Qwen35Attention.Config( n_heads=n_heads, n_kv_heads=n_kv_heads, @@ -236,7 +236,6 @@ def _qwen35_attention_config( q_norm=_offset_norm(head_dim), k_norm=_offset_norm(head_dim), inner_attention=inner_attention, - mask_type=mask_type, ) @@ -1088,7 +1087,7 @@ def _397b_a17b( def model_registry( flavor: str, - attn_backend: str = "sdpa", + attn_backend: str = "flex", moe_comm_backend: str | None = None, converters: list[ModelConfigConverter.Config] | None = None, ) -> ModelSpec: diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py index d43042516c..37879019b0 100644 --- a/torchtitan/models/qwen3_5/model.py +++ b/torchtitan/models/qwen3_5/model.py @@ -376,7 +376,6 @@ class Config(BaseAttention.Config): q_norm: OffsetRMSNorm.Config k_norm: OffsetRMSNorm.Config inner_attention: Module.Config - mask_type: str = "causal" def __init__(self, config: Config): super().__init__() @@ -580,19 +579,17 @@ def update_from_config( def get_nparams_and_flops( self, model: nn.Module, seq_len: int ) -> tuple[int, int]: - attn_cfg = self.first_attn_config + attn_cfg = self.first_attention # pyrefly: ignore [missing-attribute] n_heads = attn_cfg.n_heads # pyrefly: ignore [missing-attribute] head_dim = attn_cfg.head_dim - num_full_attn = sum(1 for l in self.layers if l.attention is not None) return get_moe_model_nparams_and_flops( self, model, n_heads, 2 * head_dim, seq_len, - num_full_attn=num_full_attn, ) def __init__(self, config: Config): diff --git a/torchtitan/models/qwen3_5/parallelize.py b/torchtitan/models/qwen3_5/parallelize.py index 9c39b01f1a..5b1de894e3 100644 --- a/torchtitan/models/qwen3_5/parallelize.py +++ b/torchtitan/models/qwen3_5/parallelize.py @@ -32,7 +32,6 @@ get_fsdp_reshard_after_forward_policy, ) from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp -from torchtitan.tools.logging import logger def _apply_fsdp_to_vision_encoder( @@ -160,11 +159,6 @@ def parallelize_qwen3_5( edp_mesh=edp_mesh, ) - logger.info("Applied fully_shard to the Qwen3.5 model") - - if training.enable_cpu_offload: - logger.info("Applied CPU Offloading to the Qwen3.5 model") - return model @@ -200,7 +194,10 @@ def pipeline_qwen3_5( fqn_per_part = _generate_llm_fqn_per_model_part( num_virtual_stages, num_layers, input_weight, output_weight ) - # Vision encoder lives on the first stage alongside tok_embeddings + # Vision encoder lives on the first stage alongside tok_embeddings. This + # adds load to stage 0 that the auto split doesn't model (input_weight + # only accounts for tok_embeddings); for a heavy vision encoder, bump + # parallelism.pipeline_parallel_first_stage_less_layers to rebalance. if hasattr(model, "vision_encoder") and model.vision_encoder is not None: fqn_per_part[0].insert(0, "vision_encoder") parallelism = dataclasses.replace( diff --git a/torchtitan/models/qwen3_5/vision_encoder.py b/torchtitan/models/qwen3_5/vision_encoder.py index 8e2bf38317..83ceea35ee 100644 --- a/torchtitan/models/qwen3_5/vision_encoder.py +++ b/torchtitan/models/qwen3_5/vision_encoder.py @@ -432,6 +432,7 @@ def __init__(self, config: Config): self.spatial_merge_size = config.spatial_merge_size self.spatial_merge_unit = config.spatial_merge_size**2 + # Patches are pre-extracted by the collator, so Linear replaces Conv3d (equivalent at full-patch kernel size). self.patch_embed = config.patch_embed_proj.build() # nn.Parameter (not nn.Embedding) because we interpolate the weight directly diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py index abbeaa91b8..5df66cc5d4 100644 --- a/torchtitan/models/utils.py +++ b/torchtitan/models/utils.py @@ -460,8 +460,6 @@ def get_moe_model_nparams_and_flops( n_heads: int, head_dims: int, seq_len: int, - *, - num_full_attn: int | None = None, ) -> tuple[int, int]: """ Calculate nparams and nflops for MoE models. @@ -472,11 +470,6 @@ def get_moe_model_nparams_and_flops( n_heads: The number of attention heads. head_dims: The sum of qk and v head dimensions. seq_len: The sequence length in training configs. - num_full_attn: For hybrid models that mix full (O(L²)) - attention with linear (O(L)) attention, the number of layers using - softmax attention. Only these layers contribute the quadratic - attention FLOPs term. If None (default), all layers are assumed to - use full attention. Returns: Tuple of (nparams, num_flops_per_token): @@ -528,8 +521,13 @@ def get_moe_model_nparams_and_flops( nparams_for_matmul = nparams_dense + nparams_sparse_active else: nparams_for_matmul = nparams_dense - nparams_embedding + nparams_sparse_active - if num_full_attn is None: - num_full_attn = len(model_config.layers) + # Only full attention layers contribute the quadratic O(L²) FLOPs + # term. Hybrid models mix full attention with linear attention + # layers whose block leaves ``attention=None``; standard decoders carry + # full attention on every layer, so this counts all of them. + num_full_attn = sum( + 1 for l in model_config.layers if getattr(l, "attention", None) is not None + ) num_flops_per_token = ( 6 * nparams_for_matmul + 6 * num_full_attn * n_heads * head_dims * seq_len ) diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py index cab7a71f02..ec09dc502a 100644 --- a/torchtitan/trainer.py +++ b/torchtitan/trainer.py @@ -542,7 +542,7 @@ def batch_generator( @sl.log_trace_span("post_dataloading_process") def post_dataloading_process( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: """ Post-processing hook after data loading and before model forward pass. @@ -561,32 +561,21 @@ def post_dataloading_process( labels: Target labels for the batch. Returns: - A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: + A tuple of (inputs, labels, extra_kwargs) where: - inputs: Main input tensor extracted from input_dict["input"]. - labels: Target labels (unchanged from input parameter). - - extra_inputs: Dict of auxiliary input tensors from input_dict - (excluding "input" and "positions"). These are passed to the - model forward but are NOT forwarded across pipeline parallel - stages. - - extra_kwargs: Dict of additional keyword arguments for model - forward (positions, attention_masks). These ARE forwarded - across all pipeline parallel stages. - - Note: - The distinction between extra_inputs and extra_kwargs is important for - pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, - while extra_inputs are only available to the first stage. Positions - always go into extra_kwargs so every stage can apply RoPE correctly. + - extra_kwargs: Additional keyword arguments for the model forward + (e.g. positions, attention_masks), forwarded to every + pipeline-parallel stage. """ inputs = input_dict["input"] - extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # extra_kwargs are forwarded to all PP stages; extra_inputs are only - # available to the first stage. Positions go into extra_kwargs so - # every stage can apply RoPE correctly. - extra_kwargs: dict[str, Any] = {} + # Everything else becomes a model-forward kwarg, forwarded to all PP + # stages by the schedule. positions is read here so we can build masks. + extra_kwargs: dict[str, Any] = { + k: v for k, v in input_dict.items() if k != "input" + } - positions = extra_inputs.pop("positions", None) - mrope_positions = extra_inputs.pop("mrope_positions", None) + positions = extra_kwargs.get("positions", None) # positions and attention_masks are optional (Decoder.forward defaults # both to None). Build attention masks only for the masked backends @@ -594,19 +583,10 @@ def post_dataloading_process( # maskless backend (e.g. the SDPA config used by the graph_trainer # tests) still receives positions for RoPE but no masks — it relies on # is_causal instead. - if isinstance(self.model_config, Decoder.Config): - attn_config = self.model_config.first_attn_config - inner_attention = getattr(attn_config, "inner_attention", None) - - if attn_config is not None and attn_config.mask_type == "block_causal": - assert ( - positions is not None - ), "block_causal mask requires per-document positions from the dataloader" - else: - positions = torch.arange( - inputs.shape[1], dtype=torch.int32, device=inputs.device - ).repeat(inputs.shape[0], 1) - + if isinstance(self.model_config, Decoder.Config) and positions is not None: + inner_attention = getattr( + self.model_config.first_attention, "inner_attention", None + ) if isinstance( inner_attention, (FlexAttention.Config, VarlenAttention.Config) ): @@ -615,10 +595,6 @@ def post_dataloading_process( positions=positions, ) - extra_kwargs["positions"] = positions - if mrope_positions is not None: - extra_kwargs["mrope_positions"] = mrope_positions - if self.parallel_dims.cp_enabled: inputs, labels, extra_kwargs = prepare_context_parallel_input( inputs, @@ -638,7 +614,7 @@ def post_dataloading_process( self.parallel_dims, inputs, labels, extra_kwargs ) - return inputs, labels, extra_inputs, extra_kwargs + return inputs, labels, extra_kwargs @sl.log_trace_span("fwd_bwd") def forward_backward_step( @@ -651,9 +627,7 @@ def forward_backward_step( model_parts = self.model_parts parallel_dims = self.parallel_dims - inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( - input_dict, labels - ) + inputs, labels, extra_kwargs = self.post_dataloading_process(input_dict, labels) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call @@ -665,7 +639,6 @@ def forward_backward_step( if self.pp_has_first_stage: self.pp_schedule.step( inputs, - **extra_inputs, **extra_kwargs, target=targets, losses=losses, @@ -693,7 +666,7 @@ def forward_backward_step( # Non-PP forward / backward assert len(model_parts) == 1 with self.train_context(): - pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + pred = model_parts[0](inputs, **extra_kwargs) # Under non-full_dtensor, labels stay as plain tensors. See # ``cross_entropy_loss`` for why pred must also be plain. # Remove once non-full_dtensor is no longer supported.