diff --git a/.ci/docker/requirements-vlm.txt b/.ci/docker/requirements-vlm.txt
index e82b2e33ba..9b1e557669 100644
--- a/.ci/docker/requirements-vlm.txt
+++ b/.ci/docker/requirements-vlm.txt
@@ -2,3 +2,4 @@ av
einops
pillow
torchvision
+flash-linear-attention
diff --git a/pyproject.toml b/pyproject.toml
index 21a2d28397..92bffa30e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,5 +65,5 @@ testpaths = ["tests"]
[tool.pyrefly]
project-excludes = ["torchtitan/experiments", "**/tests/**"]
-replace-imports-with-any = ["torchao.*", "torchft", "torchvision.*", "deep_ep.*", "jinja2.*"] # optional dependencies
+replace-imports-with-any = ["torchao.*", "torchft", "torchvision.*", "deep_ep.*", "jinja2.*", "fla.*"] # optional dependencies
search-path = ["../pytorch"] # local built pytorch
diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_vl.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py
similarity index 76%
rename from scripts/checkpoint_conversion/numerical_tests_qwen3_vl.py
rename to scripts/checkpoint_conversion/numerical_tests_qwen3_5.py
index 944363a318..d120461a40 100644
--- a/scripts/checkpoint_conversion/numerical_tests_qwen3_vl.py
+++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5.py
@@ -6,36 +6,42 @@
# LICENSE file in the root directory of this source tree.
"""
-End-to-end numerical correctness test for Qwen3-VL checkpoint conversion.
+End-to-end numerical correctness test for Qwen3.5 checkpoint conversion.
Compares HuggingFace and TorchTitan next-token logits on multimodal inputs
(random image + text prompt). Each pipeline uses its own image preprocessing
so the test validates the full path: pixels → vision encoder → decoder → logits.
Usage:
- python -m scripts.checkpoint_conversion.numerical_tests_qwen3_vl \
- --hf_model_path /path/to/Qwen3-VL-2B-Instruct \
- --tt_checkpoint_path /path/to/qwen3_vl_2b_dcp
-
- python -m scripts.checkpoint_conversion.numerical_tests_qwen3_vl \
- --hf_model_path /path/to/Qwen3-VL-30B-A3B-Instruct \
- --tt_checkpoint_path /path/to/qwen3_vl_30b_a3b_dcp \
- --model_flavor 30B-A3B
+ python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \
+ --hf_model_path hf_assets/Qwen/Qwen3.5-2B \
+ --tt_checkpoint_path outputs/Qwen/qwen3_5_2b_dcp \
+ --model_flavor 2B
"""
import argparse
import os
+import types
+from typing import Any
+import einops as E
import torch
import torch._dynamo
import torch.distributed.checkpoint as dcp
import torch.nn.functional as F
+from PIL import Image
torch._dynamo.config.disable = True
from torchtitan.components.checkpoint import ModelWrapper
-from torchtitan.models.qwen3_vl import model_registry
-from transformers import AutoProcessor
+from torchtitan.hf_datasets.multimodal.mm_collator import MultiModalCollator
+from torchtitan.hf_datasets.multimodal.utils.image import (
+ process_image,
+ vision_to_patches,
+)
+from torchtitan.models.common.attention import ScaledDotProductAttention
+from torchtitan.models.qwen3_5 import model_registry, QWEN3_5_SPECIAL_TOKENS
+from transformers import AutoModelForImageTextToText, AutoProcessor
# ============================================================
@@ -44,7 +50,7 @@
def kl_divergence(logits_a, logits_b):
- """KL(a || b) between two logit tensors."""
+ """KL(softmax(b) || softmax(a)) — F.kl_div(log Q, P) computes KL(P || Q)."""
return F.kl_div(
F.log_softmax(logits_a, dim=-1),
F.softmax(logits_b, dim=-1),
@@ -77,25 +83,32 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
Returns:
hf_inputs: list of dicts (processor output, ready for HF model)
- tt_inputs: list of (input_ids, pixel_values, grid_thw)
+ tt_inputs: list of (input_ids, pixel_values, grid_thw, mrope_positions)
pixel_comparisons: list of per-sample pixel diff stats
+ special_tokens: {"image_id", "video_id"} from the tokenizer
"""
- import einops as E
- from PIL import Image
-
- from torchtitan.hf_datasets.multimodal.utils.image import (
- process_image,
- vision_to_patches,
- )
-
- processor = AutoProcessor.from_pretrained(hf_model_path, trust_remote_code=True)
+ # Annotate as Any: AutoProcessor.from_pretrained is typed Optional, so the type
+ # checker flags .apply_chat_template in environments without transformers stubs.
+ processor: Any = AutoProcessor.from_pretrained(hf_model_path)
model_config = model_registry(model_flavor).model
- encoder_config = model_config.vision_encoder # pyrefly: ignore[missing-attribute]
- patch_size = encoder_config.patch_embed.patch_size
- temporal_patch_size = encoder_config.patch_embed.temporal_patch_size
+ # pyrefly: ignore [missing-attribute]
+ encoder_config = model_config.vision_encoder
+ patch_size = encoder_config.patch_size
+ temporal_patch_size = encoder_config.temporal_patch_size
merge_size = encoder_config.spatial_merge_size
+ image_token_id = processor.tokenizer.convert_tokens_to_ids(
+ QWEN3_5_SPECIAL_TOKENS["image_token"]
+ )
+ video_token_id = processor.tokenizer.convert_tokens_to_ids(
+ QWEN3_5_SPECIAL_TOKENS["video_token"]
+ )
+ special_tokens = {"image_id": image_token_id, "video_id": video_token_id}
+ # Reuse the collator's MRoPE builder (only needs spatial_merge_size) so the
+ # 3D position IDs match the training path exactly.
+ mrope_builder = types.SimpleNamespace(spatial_merge_size=merge_size)
+
hf_inputs, tt_inputs, pixel_comparisons = [], [], []
for i in range(num_samples):
@@ -117,7 +130,7 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
],
}
]
- hf_in = processor.apply_chat_template( # pyrefly: ignore[missing-attribute]
+ hf_in = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
@@ -137,8 +150,23 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
temporal_patch_size,
merge_size,
)
+ # 3D MRoPE positions (1, S, 3); positions=None → single document.
+ mrope_positions = MultiModalCollator._build_mrope_positions(
+ mrope_builder, # pyrefly: ignore [bad-argument-type]
+ hf_in["input_ids"],
+ grid_thw.unsqueeze(0),
+ None,
+ None,
+ image_token_id=image_token_id,
+ video_token_id=video_token_id,
+ )
tt_inputs.append(
- (hf_in["input_ids"], patches.unsqueeze(0), grid_thw.unsqueeze(0))
+ (
+ hf_in["input_ids"],
+ patches.unsqueeze(0),
+ grid_thw.unsqueeze(0),
+ mrope_positions,
+ )
)
# --- Compare pixel values in image space ---
@@ -162,7 +190,7 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
tt_img = E.rearrange(patches, pattern, **kwargs)
pixel_comparisons.append(_compare_images(hf_img[:1], tt_img[:1], i))
- return hf_inputs, tt_inputs, pixel_comparisons
+ return hf_inputs, tt_inputs, pixel_comparisons, special_tokens
def _compare_images(hf_img, tt_img, sample_idx):
@@ -207,13 +235,11 @@ def print_pixel_comparisons(comparisons):
@torch.no_grad()
def run_hf(model_path, hf_inputs, device):
"""Run HF model, return last-token logits per sample."""
- from transformers import AutoModelForImageTextToText
-
print(f"Loading HuggingFace model on {device} ...")
model = AutoModelForImageTextToText.from_pretrained(
model_path,
device_map=device,
- torch_dtype=torch.float16,
+ dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True,
)
@@ -240,7 +266,7 @@ def run_hf(model_path, hf_inputs, device):
@torch.no_grad()
-def run_tt(model_flavor, checkpoint_path, tt_inputs, device):
+def run_tt(model_flavor, checkpoint_path, tt_inputs, special_tokens, device):
"""Run TT model, return last-token logits per sample."""
print(f"Loading TorchTitan model on {device} ...")
@@ -248,7 +274,7 @@ def run_tt(model_flavor, checkpoint_path, tt_inputs, device):
with torch.device("meta"):
model = model_config.build()
model.to_empty(device="cpu")
- model.init_weights(buffer_device=torch.device("cpu"))
+ model.init_states(buffer_device=torch.device("cpu"))
model.half()
state_dict = ModelWrapper(model)._get_state_dict()
@@ -258,11 +284,9 @@ def run_tt(model_flavor, checkpoint_path, tt_inputs, device):
# Replace FlexAttention with SDPA for single-process inference
# (unfused FlexAttention without torch.compile has poor fp16 numerics).
- from torchtitan.models.common.attention import ScaledDotProductAttention
-
for layer in model.layers.values():
- layer.attention.attn_backend = "sdpa"
- layer.attention.inner_attention = ScaledDotProductAttention.Config().build()
+ if layer.full_attn:
+ layer.attn.inner_attention = ScaledDotProductAttention.Config().build()
class _BidirectionalSDPA(torch.nn.Module):
def forward(self, q, k, v, **kwargs):
@@ -278,20 +302,13 @@ def forward(self, q, k, v, **kwargs):
model.eval()
- special_tokens = {
- "image_id": 151655,
- "video_id": 151656,
- "vision_start_id": 151652,
- "vision_end_id": 151653,
- "pad_id": 151643,
- }
-
outputs = []
- for i, (tokens, pixel_values, grid_thw) in enumerate(tt_inputs):
+ for i, (tokens, pixel_values, grid_thw, mrope_positions) in enumerate(tt_inputs):
logits = model(
tokens.to(device),
pixel_values=pixel_values.half().to(device),
grid_thw=grid_thw.to(device),
+ mrope_positions=mrope_positions.to(device),
special_tokens=special_tokens,
)
outputs.append(logits[:, -1:, :].cpu())
@@ -319,12 +336,13 @@ def compare(hf_outputs, tt_outputs):
kl = kl_divergence(hf, tt).item()
top1, top5 = top_k_match(hf, tt)
cos = F.cosine_similarity(hf.flatten(), tt.flatten(), dim=0).item()
+ max_diff = (hf - tt).abs().max().item()
total_kl += kl
total_top1 += top1
total_top5 += top5
print(
f" Sample {i + 1:2d}: KL={kl:.4e} cos={cos:.6f} "
- f"top1={top1:.0%} top5={top5:.0%}"
+ f"max_diff={max_diff:.4e} top1={top1:.0%} top5={top5:.0%}"
)
n = len(hf_outputs)
@@ -351,16 +369,11 @@ def compare(hf_outputs, tt_outputs):
def main():
parser = argparse.ArgumentParser(
- description="End-to-end numerical correctness test for Qwen3-VL.",
+ description="End-to-end numerical correctness test for Qwen3.5.",
)
parser.add_argument("--hf_model_path", type=str, required=True)
parser.add_argument("--tt_checkpoint_path", type=str, required=True)
- parser.add_argument(
- "--model_flavor",
- type=str,
- default="2B",
- choices=["2B", "8B", "30B-A3B", "235B-A22B"],
- )
+ parser.add_argument("--model_flavor", type=str, default="4B")
parser.add_argument("--num_samples", type=int, default=10)
args = parser.parse_args()
@@ -373,7 +386,7 @@ def main():
print(f"Using {torch.cuda.get_device_name(0)}")
print(f"\nBuilding {args.num_samples} test samples ...")
- hf_inputs, tt_inputs, pixel_comparisons = build_inputs(
+ hf_inputs, tt_inputs, pixel_comparisons, special_tokens = build_inputs(
args.hf_model_path,
args.model_flavor,
args.num_samples,
@@ -388,6 +401,7 @@ def main():
args.model_flavor,
args.tt_checkpoint_path,
tt_inputs,
+ special_tokens,
device,
)
diff --git a/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py
new file mode 100644
index 0000000000..a9da123221
--- /dev/null
+++ b/scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Numerical comparison across parallelism configs for Qwen3.5.
+
+Runs identical fake tokens through no_parallel, FSDP, FSDP+EP and FSDP+EP+TP and
+checks 10 first-token logits against no_parallel (max diff < 0.02). Needs up to 8 GPUs.
+
+Usage:
+ python scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import tempfile
+from typing import cast
+
+import torch
+import torch.distributed as dist
+from torch.distributed.tensor import DTensor
+
+from torchtitan.config import (
+ ActivationCheckpointConfig,
+ CompileConfig,
+ ParallelismConfig,
+ TrainingConfig,
+)
+from torchtitan.distributed import ParallelDims
+from torchtitan.models.qwen3_5 import Qwen35Model, qwen3_5_configs
+from torchtitan.models.qwen3_5.parallelize import parallelize_qwen3_5
+
+CONFIGS = [
+ {"ngpu": 1, "tp": 1, "ep": 1, "label": "no_parallel"},
+ {"ngpu": 4, "tp": 1, "ep": 1, "label": "fsdp"},
+ {"ngpu": 8, "tp": 1, "ep": 4, "label": "fsdp+ep"},
+ {"ngpu": 8, "tp": 2, "ep": 2, "label": "fsdp+ep+tp"},
+]
+
+
+def run_worker(args):
+ """Worker entry point — called via torchrun."""
+ dist.init_process_group("nccl")
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ torch.cuda.set_device(rank)
+
+ dp_shard = world_size // args.tp
+
+ seed = 42
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ config = qwen3_5_configs["debugmodel_moe"](
+ attn_backend="flex",
+ moe_comm_backend="standard",
+ )
+
+ parallel_dims = ParallelDims(
+ dp_shard=dp_shard,
+ dp_replicate=1,
+ cp=1,
+ tp=args.tp,
+ pp=1,
+ ep=args.ep,
+ world_size=world_size,
+ )
+ parallel_dims.build_mesh()
+
+ parallelism = ParallelismConfig(
+ tensor_parallel_degree=args.tp,
+ data_parallel_shard_degree=dp_shard,
+ expert_parallel_degree=args.ep,
+ )
+ training = TrainingConfig(
+ local_batch_size=1,
+ seq_len=128,
+ steps=1,
+ mixed_precision_param="bfloat16",
+ mixed_precision_reduce="float32",
+ )
+
+ config.update_from_config(
+ config=type(
+ "C",
+ (),
+ {
+ "training": training,
+ "parallelism": parallelism,
+ "debug": type("D", (), {"moe_force_load_balance": False})(),
+ },
+ )(),
+ )
+
+ model = config.build()
+ model.to_empty(device="cuda")
+ model.init_weights(buffer_device=torch.device("cuda"))
+
+ model = parallelize_qwen3_5(
+ model,
+ parallel_dims=parallel_dims,
+ training=training,
+ parallelism=parallelism,
+ compile_config=CompileConfig(),
+ ac_config=ActivationCheckpointConfig(),
+ dump_folder="/tmp",
+ )
+
+ torch.manual_seed(seed)
+ seq_len = 128
+ tokens = torch.randint(0, 248320, (1, seq_len), device="cuda")
+ dist.broadcast(tokens, src=0)
+
+ # Text-only inputs: plain sequential positions. The flex backend requires a
+ # BlockMask, which the trainer normally builds in post_dataloading_process;
+ # build it here directly since we call the model outside the trainer.
+ positions = torch.arange(seq_len, device="cuda").unsqueeze(0)
+ attention_masks = cast(Qwen35Model, model).get_attention_masks(positions=positions)
+
+ with torch.no_grad():
+ output = model(
+ tokens,
+ positions=positions,
+ attention_masks=attention_masks,
+ special_tokens={"image_id": 248056, "video_id": 248057},
+ )
+
+ if isinstance(output, DTensor):
+ output = output.full_tensor()
+
+ logits = output[0, 0, :10].float().tolist()
+
+ if rank == 0:
+ with open(args.output, "w") as f:
+ json.dump(logits, f)
+
+ dist.destroy_process_group()
+
+
+def main():
+ """Orchestrator — launches torchrun for each config and compares."""
+ results = {}
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ for cfg in CONFIGS:
+ outfile = os.path.join(tmpdir, f"{cfg['label']}.json")
+ cmd = [
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ f"--nproc-per-node={cfg['ngpu']}",
+ __file__,
+ "--worker",
+ f"--tp={cfg['tp']}",
+ f"--ep={cfg['ep']}",
+ f"--output={outfile}",
+ ]
+ print(
+ f"Running {cfg['label']} (ngpu={cfg['ngpu']}, "
+ f"tp={cfg['tp']}, ep={cfg['ep']})..."
+ )
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ timeout=300,
+ )
+ if result.returncode != 0:
+ print(f" FAILED:\n{result.stderr[-500:]}")
+ return 1
+
+ with open(outfile) as f:
+ results[cfg["label"]] = json.load(f)
+ print(f" logits: {[f'{v:.6f}' for v in results[cfg['label']]]}")
+
+ print("\n--- Comparison ---")
+ baseline_label = CONFIGS[0]["label"]
+ baseline = results[baseline_label]
+ all_pass = True
+
+ for cfg in CONFIGS[1:]:
+ label = cfg["label"]
+ logits = results[label]
+ max_diff = max(abs(a - b) for a, b in zip(baseline, logits))
+ status = "PASS" if max_diff < 0.02 else "FAIL"
+ if status == "FAIL":
+ all_pass = False
+ print(f" {baseline_label} vs {label}: max_diff={max_diff:.6e} {status}")
+
+ print(f"\n{'ALL PASS' if all_pass else 'SOME FAILED'}")
+ return 0 if all_pass else 1
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--worker", action="store_true")
+ parser.add_argument("--tp", type=int)
+ parser.add_argument("--ep", type=int)
+ parser.add_argument("--output", type=str)
+ args = parser.parse_args()
+
+ if args.worker:
+ run_worker(args)
+ else:
+ sys.exit(main())
diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py
index 30e03fb0ec..51f8a4365a 100755
--- a/tests/integration_tests/models.py
+++ b/tests/integration_tests/models.py
@@ -119,18 +119,19 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
"qwen3_fsdp+tp+cp",
ngpu=8,
),
- # Integration Test Cases for Qwen3-VL
+ # Integration Test Cases for Qwen3.5
OverrideDefinitions(
[
[
- "--module qwen3_vl --config qwen3_vl_debugmodel_moe",
- "--parallelism.data_parallel_shard_degree 4",
+ "--module qwen3_5 --config qwen35_debugmodel_moe",
+ "--parallelism.data_parallel_shard_degree 2",
+ "--parallelism.pipeline_parallel_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.expert_parallel_degree 4",
],
],
- "Qwen3-VL MoE FSDP+TP+EP",
- "qwen3_vl_moe_fsdp+tp+ep",
+ "Qwen3.5 MoE FSDP+TP+EP+PP",
+ "qwen3_5_moe_fsdp+tp+ep+pp",
ngpu=8,
),
# Integration Test Cases for gpt-oss
diff --git a/tests/unit_tests/test_rope.py b/tests/unit_tests/test_rope.py
index 5e6f0bc240..f97c02bec0 100644
--- a/tests/unit_tests/test_rope.py
+++ b/tests/unit_tests/test_rope.py
@@ -20,7 +20,7 @@
CosSinRoPE,
RoPE,
)
-from torchtitan.models.qwen3_vl.rope import MRoPE
+from torchtitan.models.qwen3_5.rope import MRoPE
class TestApplyRotaryEmbCosSin(unittest.TestCase):
@@ -173,11 +173,11 @@ def test_forward_accepts_three_axis_positions(self):
max_seq_len=8,
mrope_section=[2, 1, 1],
).build()
+ # (batch, seq, 3): per-token [temporal, height, width] positions.
position_ids = torch.tensor(
[
- [[0, 1, 2], [3, 4, 5]], # temporal
- [[1, 2, 3], [4, 5, 6]], # height
- [[2, 3, 4], [5, 6, 7]], # width
+ [[0, 1, 2], [1, 2, 3], [2, 3, 4]], # batch 0
+ [[3, 4, 5], [4, 5, 6], [5, 6, 7]], # batch 1
]
)
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py
index 7190bb1a72..62db7ef4da 100644
--- a/torchtitan/components/validate.py
+++ b/torchtitan/components/validate.py
@@ -141,7 +141,7 @@ def post_dataloading_process(
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
model_parts: list[nn.Module],
- ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+ ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
"""
Post-processing hook after data loading and before model forward pass.
@@ -157,27 +157,17 @@ def post_dataloading_process(
model_parts: List of model parts for accessing model methods.
Returns:
- A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
+ A tuple of (inputs, labels, extra_kwargs) where:
- inputs: Main input tensor extracted from input_dict["input"].
- labels: Target labels (potentially modified by CP sharding).
- - extra_inputs: Dict of auxiliary input tensors (all keys except
- "input" from input_dict). These are passed to the model forward
- but are NOT forwarded across pipeline parallel stages.
- - extra_kwargs: Dict of additional keyword arguments for model forward.
- These ARE forwarded across pipeline parallel stages. Contains
- attention_masks if flex attention is enabled.
-
- Note:
- The distinction between extra_inputs and extra_kwargs is important for
- pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
- while extra_inputs are only available to the first stage.
+ - extra_kwargs: Additional keyword arguments for the model forward
+ (e.g. positions, attention_masks), forwarded to every
+ pipeline-parallel stage.
"""
inputs = input_dict["input"]
- extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
- # For arguments, like attention_masks, we have to put them in a separate
- # dict as extra_inputs are not forwarded to other stages in PP, but
- # extra_kwargs are.
- extra_kwargs: dict[str, Any] = {}
+ extra_kwargs: dict[str, Any] = {
+ k: v for k, v in input_dict.items() if k != "input"
+ }
# TODO: deduplicate with Trainer.post_dataloading_process which has
# the same logic; extract a shared function to prevent further drift.
@@ -185,14 +175,16 @@ def post_dataloading_process(
# both RoPE and block_causal attention masking.
model_config = getattr(model_parts[0], "config", None)
- positions = extra_inputs.pop("positions", None)
+ positions = extra_kwargs.get("positions", None)
# positions and attention_masks are optional (Decoder.forward defaults
# both to None). Build masks only for the masked backends (Flex/Varlen),
# which is where get_attention_masks is defined. A maskless backend (the
# SDPA config used by the graph_trainer tests) still receives positions
# for RoPE but no masks — it relies on is_causal instead.
if isinstance(model_config, Decoder.Config) and positions is not None:
- inner_attention = model_config.layers[0].attention.inner_attention
+ inner_attention = getattr(
+ model_config.first_attention, "inner_attention", None
+ )
if isinstance(
inner_attention, (FlexAttention.Config, VarlenAttention.Config)
):
@@ -201,8 +193,6 @@ def post_dataloading_process(
positions=positions,
)
- extra_kwargs["positions"] = positions
-
if self.parallel_dims.cp_enabled:
inputs, labels, extra_kwargs = prepare_context_parallel_input(
inputs,
@@ -218,7 +208,7 @@ def post_dataloading_process(
self.parallel_dims, inputs, labels, extra_kwargs
)
- return inputs, labels, extra_inputs, extra_kwargs
+ return inputs, labels, extra_kwargs
@sl.log_trace_span("eval")
@torch.no_grad()
@@ -270,7 +260,7 @@ def validate(
global_valid_tokens = float(local_valid_tokens.item())
# Process data (extract inputs, handle attention masks, CP sharding)
- inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
+ inputs, labels, extra_kwargs = self.post_dataloading_process(
input_dict, labels, model_parts
)
@@ -286,7 +276,6 @@ def validate(
if self.pp_has_first_stage:
self.pp_schedule.eval(
inputs,
- **extra_inputs,
**extra_kwargs,
target=targets,
losses=losses,
@@ -309,7 +298,7 @@ def validate(
else:
with self.validation_context():
assert len(model_parts) == 1
- predictions = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
+ predictions = model_parts[0](inputs, **extra_kwargs)
loss_sum = self.loss_fn(predictions, labels)
accumulated_losses.append(loss_sum.detach() / global_valid_tokens)
diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py
index dc706aeacd..b51727f2a4 100644
--- a/torchtitan/experiments/forge/example_train.py
+++ b/torchtitan/experiments/forge/example_train.py
@@ -166,15 +166,15 @@ def batch_generator(
def post_dataloading_process(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
- ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+ ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
inputs = input_dict["input"]
- extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
- # For arguments, like attention_masks, we have to put them in a separate
- # dict as extra_inputs are not forwarded to other stages in PP, but
- # extra_kwargs are.
- extra_kwargs: dict[str, Any] = {}
+ # Everything except the pipelined input is a model-forward kwarg,
+ # forwarded to all PP stages by the schedule.
+ extra_kwargs: dict[str, Any] = {
+ k: v for k, v in input_dict.items() if k != "input"
+ }
- positions = extra_inputs.pop("positions", None)
+ positions = extra_kwargs.get("positions", None)
try:
# pyrefly: ignore [not-callable]
@@ -194,7 +194,7 @@ def post_dataloading_process(
self.config.parallelism.context_parallel_load_balancer,
)
- return inputs, labels, extra_inputs, extra_kwargs
+ return inputs, labels, extra_kwargs
def forward_backward_step(
self,
@@ -206,9 +206,7 @@ def forward_backward_step(
model_parts = self.model_parts
parallel_dims = self.parallel_dims
- inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
- input_dict, labels
- )
+ inputs, labels, extra_kwargs = self.post_dataloading_process(input_dict, labels)
if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
@@ -219,7 +217,6 @@ def forward_backward_step(
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs,
- **extra_inputs,
**extra_kwargs,
target=targets,
losses=losses,
@@ -244,7 +241,7 @@ def forward_backward_step(
# Non-PP forward / backward
with self.train_context():
assert len(model_parts) == 1
- pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
+ pred = model_parts[0](inputs, **extra_kwargs)
# Compute loss sum (reduction='sum')
loss_sum = self.loss_fn(pred, labels)
diff --git a/torchtitan/experiments/graph_trainer/precompile_main.py b/torchtitan/experiments/graph_trainer/precompile_main.py
index 8b12e81e2d..a0db7987c4 100644
--- a/torchtitan/experiments/graph_trainer/precompile_main.py
+++ b/torchtitan/experiments/graph_trainer/precompile_main.py
@@ -207,7 +207,6 @@ def _precompile_aot_fx_trace(
* parallel_dims.cp
)
dummy_global_valid_tokens = float(global_batch_size * seq_len)
- extra_inputs: dict[str, torch.Tensor] = {}
extra_kwargs: dict[str, Any] = {}
if isinstance(model_config, Decoder.Config) and model_config.layers:
@@ -257,7 +256,6 @@ def _precompile_aot_fx_trace(
dummy_inputs,
dummy_labels,
dummy_global_valid_tokens,
- extra_inputs,
extra_kwargs,
)
logger.info(
diff --git a/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py b/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py
index 2182529c41..087e28f136 100644
--- a/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py
+++ b/torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py
@@ -210,7 +210,6 @@ def _run_steps_with_precompile(
global_valid_tokens = torch.tensor(
BATCH_SIZE * SEQ_LEN, dtype=torch.float, device="cuda"
)
- extra_inputs: dict[str, torch.Tensor] = {}
extra_kwargs: dict[str, object] = {
"positions": self.positions,
**self._get_extra_kwargs(model),
@@ -222,7 +221,6 @@ def _run_steps_with_precompile(
self.inputs,
self.labels,
global_valid_tokens,
- extra_inputs,
extra_kwargs,
)
@@ -284,7 +282,6 @@ def _run_steps_with_precompile(
self.inputs,
self.labels,
global_valid_tokens,
- extra_inputs,
extra_kwargs,
)
loss = outputs[0]
diff --git a/torchtitan/experiments/graph_trainer/trainer.py b/torchtitan/experiments/graph_trainer/trainer.py
index 72c48fa469..8a45f8b6a0 100644
--- a/torchtitan/experiments/graph_trainer/trainer.py
+++ b/torchtitan/experiments/graph_trainer/trainer.py
@@ -70,8 +70,8 @@ def make_fwd_bwd_step(model, loss_fn):
to thread its parameters/buffers as static graph inputs.
"""
- def fwd_bwd_step(inputs, labels, global_valid_tokens, extra_inputs, extra_kwargs):
- pred = model(inputs, **extra_inputs, **extra_kwargs)
+ def fwd_bwd_step(inputs, labels, global_valid_tokens, extra_kwargs):
+ pred = model(inputs, **extra_kwargs)
# The loss function is not a submodule of the model, so
# annotate_module_fqns won't tag it. Annotate it here so that
# downstream passes (bucketing, SAC, kernel annotations) can
@@ -130,9 +130,7 @@ def forward_backward_step(
assert len(self.model_parts) == 1
model = self.model_parts[0]
- inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
- input_dict, labels
- )
+ inputs, labels, extra_kwargs = self.post_dataloading_process(input_dict, labels)
# remove_duplicate=False to preserve duplicate parameter entries
# from weight tying (e.g. shared embedding/output weights).
params = [
@@ -146,7 +144,6 @@ def forward_backward_step(
labels,
global_valid_tokens,
params,
- extra_inputs,
extra_kwargs,
)
@@ -185,7 +182,6 @@ def _make_fx_forward_backward_step(
labels: torch.Tensor,
global_valid_tokens: float,
params: list[torch.Tensor],
- extra_inputs: dict[str, torch.Tensor],
extra_kwargs: dict[str, Any],
) -> torch.Tensor:
maybe_register_blockmask_pytree_node()
@@ -199,7 +195,6 @@ def _make_fx_forward_backward_step(
inputs,
labels,
global_valid_tokens,
- extra_inputs,
extra_kwargs,
)
@@ -221,7 +216,6 @@ def _make_fx_forward_backward_step(
inputs,
labels,
global_valid_tokens,
- extra_inputs,
extra_kwargs,
)
loss = outputs[0]
diff --git a/torchtitan/hf_datasets/multimodal/mm_collator.py b/torchtitan/hf_datasets/multimodal/mm_collator.py
index 97c65f470e..3350d89de8 100644
--- a/torchtitan/hf_datasets/multimodal/mm_collator.py
+++ b/torchtitan/hf_datasets/multimodal/mm_collator.py
@@ -34,6 +34,7 @@ class MultiModalCollator:
temporal_patch_size: int
spatial_merge_size: int
tokenizer: MultiModalTokenizer
+ build_mrope_positions: bool
def collate_images(
self, all_images: list[torch.Tensor]
@@ -133,6 +134,171 @@ def collate_text(
return input_ids[:, :-1], labels[:, 1:], positions[:, :-1]
+ def _build_mrope_positions(
+ self,
+ tokens: torch.Tensor,
+ grid_thw: torch.Tensor | None,
+ grid_thw_videos: torch.Tensor | None,
+ positions: torch.Tensor | None,
+ *,
+ image_token_id: int,
+ video_token_id: int,
+ ) -> torch.Tensor:
+ """Build 3D (temporal, height, width) MRoPE position IDs per token.
+
+ Returns ``(batch, seq_len, 3)`` — batch/seq leading (like the 2D
+ ``positions``) so pipeline-parallel microbatching chunks the batch dim
+ and context parallel can shard the seq dim, with the 3 T/H/W coords as
+ the last (feature) axis. Runs here on CPU data workers, off the GPU
+ training path.
+
+ Args:
+ tokens: (batch, seq_len) token IDs.
+ grid_thw: (num_images, 3) image grid dims, or None.
+ grid_thw_videos: (num_videos, 3) video grid dims, or None.
+ positions: (batch, seq_len) per-token positions, or None to treat each
+ sample as one document; boundaries are detected where positions reset.
+ image_token_id: Placeholder token ID marking image positions.
+ video_token_id: Placeholder token ID marking video positions.
+
+ Returns:
+ (batch, seq_len, 3) MRoPE position IDs.
+ """
+ # Expand each video [T, H, W] into T rows of [1, H, W] so each frame is
+ # treated like an image; temporal position comes from frame ordering.
+ if grid_thw_videos is not None:
+ grid_thw_videos = torch.repeat_interleave(
+ grid_thw_videos, grid_thw_videos[:, 0], dim=0
+ )
+ grid_thw_videos[:, 0] = 1
+
+ spatial_merge_size = self.spatial_merge_size
+
+ batch_size, seq_len = tokens.shape
+ mrope_positions = torch.zeros(
+ batch_size, seq_len, 3, dtype=tokens.dtype, device=tokens.device
+ )
+
+ if positions is not None:
+ resets = positions[:, 1:] < positions[:, :-1] # (batch, seq_len-1)
+ # First token of each consecutive vision region (image or video).
+ vision_mask = (tokens == image_token_id) | (tokens == video_token_id)
+ prev_vision = torch.cat(
+ [torch.zeros_like(vision_mask[:, :1]), vision_mask[:, :-1]], dim=1
+ )
+ batch_vision_starts = vision_mask & ~prev_vision # (batch, seq_len)
+ grid_cache: dict[tuple[int, int, int], torch.Tensor] = {}
+
+ image_index, video_index = 0, 0
+ # With sample packing, each sample may contain multiple documents.
+ for sample_i in range(batch_size):
+ llm_pos_ids_list: list[torch.Tensor] = []
+
+ if positions is not None:
+ # pyrefly: ignore [unbound-name]
+ reset_indices = torch.where(resets[sample_i])[0] + 1
+ doc_starts = [0] + reset_indices.tolist()
+ doc_ranges = [
+ (
+ doc_starts[d],
+ doc_starts[d + 1] if d + 1 < len(doc_starts) else seq_len,
+ )
+ for d in range(len(doc_starts))
+ ]
+ else:
+ doc_ranges = [(0, seq_len)]
+
+ sample_tokens = tokens[sample_i]
+ sample_vision_starts = torch.where(batch_vision_starts[sample_i])[
+ 0
+ ].tolist()
+ vision_start_index = 0
+
+ for doc_start, doc_end in doc_ranges:
+ doc_pos_ids_list: list[torch.Tensor] = []
+
+ doc_vision_starts: list[int] = []
+ while (
+ vision_start_index < len(sample_vision_starts)
+ and sample_vision_starts[vision_start_index] < doc_end
+ ):
+ doc_vision_starts.append(sample_vision_starts[vision_start_index])
+ vision_start_index += 1
+
+ pair_cursor = doc_start
+ for vision_start in doc_vision_starts:
+ if sample_tokens[vision_start] == image_token_id:
+ # pyrefly: ignore [unsupported-operation]
+ t, h, w = grid_thw[image_index]
+ image_index += 1
+ else:
+ # pyrefly: ignore [unsupported-operation]
+ t, h, w = grid_thw_videos[video_index]
+ video_index += 1
+
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ int(t.item()),
+ int(h.item()) // spatial_merge_size,
+ int(w.item()) // spatial_merge_size,
+ )
+ text_len = vision_start - pair_cursor
+
+ pos_id_offset = (
+ doc_pos_ids_list[-1].max() + 1
+ if len(doc_pos_ids_list) > 0
+ else 0
+ )
+ # [text tokens] — sequential positions, identical on all 3 axes.
+ doc_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset
+ )
+ # [vision tokens] — 3D grid positions (T, H, W).
+ grid_key = (llm_grid_t, llm_grid_h, llm_grid_w)
+ if grid_key not in grid_cache:
+ hw = llm_grid_h * llm_grid_w
+ t_index = (
+ torch.arange(llm_grid_t)
+ .view(-1, 1)
+ .expand(-1, hw)
+ .flatten()
+ )
+ h_index = (
+ torch.arange(llm_grid_h)
+ .view(1, -1, 1)
+ .expand(llm_grid_t, -1, llm_grid_w)
+ .flatten()
+ )
+ w_index = (
+ torch.arange(llm_grid_w)
+ .view(1, 1, -1)
+ .expand(llm_grid_t, llm_grid_h, -1)
+ .flatten()
+ )
+ grid_cache[grid_key] = torch.stack([t_index, h_index, w_index])
+ doc_pos_ids_list.append(
+ grid_cache[grid_key] + text_len + pos_id_offset
+ )
+ pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w
+
+ # Trailing [text tokens] after the last text/vision pair.
+ if pair_cursor < doc_end:
+ pos_id_offset = (
+ doc_pos_ids_list[-1].max() + 1
+ if len(doc_pos_ids_list) > 0
+ else 0
+ )
+ text_len = doc_end - pair_cursor
+ doc_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset
+ )
+
+ llm_pos_ids_list.extend(doc_pos_ids_list)
+
+ # Each entry is (3, segment_len); concat -> (3, seq_len), then .T -> (seq_len, 3).
+ mrope_positions[sample_i] = torch.cat(llm_pos_ids_list, dim=1).T
+
+ return mrope_positions
+
def __call__(
self, batch: list[dict[str, Any]]
) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]:
@@ -186,5 +352,18 @@ def __call__(
},
}
+ if self.build_mrope_positions and (
+ grids is not None or video_grids is not None
+ ):
+ special_tokens = input_dict["special_tokens"]
+ input_dict["mrope_positions"] = self._build_mrope_positions(
+ input_ids,
+ grids,
+ video_grids,
+ positions,
+ image_token_id=special_tokens["image_id"],
+ video_token_id=special_tokens["video_id"],
+ )
+
# pyrefly: ignore [bad-return]
return input_dict, labels
diff --git a/torchtitan/hf_datasets/multimodal/mm_datasets.py b/torchtitan/hf_datasets/multimodal/mm_datasets.py
index b9a9dca862..71adda51a2 100644
--- a/torchtitan/hf_datasets/multimodal/mm_datasets.py
+++ b/torchtitan/hf_datasets/multimodal/mm_datasets.py
@@ -531,6 +531,11 @@ class Config(ParallelAwareDataloader.Config):
video_max_frames: int = 768
"""Maximum number of frames to sample from a video."""
+ # Other loading configs
+ build_mrope_positions: bool = False
+ """Build 3D MRoPE position IDs (``mrope_positions``) for models that use
+ multi-dimensional RoPE"""
+
def __init__(
self,
config: Config,
@@ -574,6 +579,7 @@ def __init__(
temporal_patch_size=config.temporal_patch_size,
spatial_merge_size=config.spatial_merge_size,
tokenizer=tokenizer,
+ build_mrope_positions=config.build_mrope_positions,
)
dataloader_kwargs = {
diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py
index 44f32f521f..da1b5c6224 100644
--- a/torchtitan/models/__init__.py
+++ b/torchtitan/models/__init__.py
@@ -5,5 +5,5 @@
# LICENSE file in the root directory of this source tree.
_supported_models = frozenset(
- ["deepseek_v3", "flux", "gpt_oss", "llama3", "qwen3", "qwen3_vl"]
+ ["deepseek_v3", "flux", "gpt_oss", "llama3", "qwen3", "qwen3_5"]
)
diff --git a/torchtitan/models/common/__init__.py b/torchtitan/models/common/__init__.py
index 52a1634726..2902247763 100644
--- a/torchtitan/models/common/__init__.py
+++ b/torchtitan/models/common/__init__.py
@@ -25,6 +25,7 @@
from .feed_forward import compute_ffn_hidden_dim, FeedForward
from .moe import MoE
from .nn_modules import (
+ Conv1d,
Conv2d,
Embedding,
GELU,
@@ -38,6 +39,7 @@
from .rope import ComplexRoPE, CosSinRoPE, RoPE
__all__ = [
+ "Conv1d",
"Conv2d",
"ComplexRoPE",
"CosSinRoPE",
diff --git a/torchtitan/models/common/decoder.py b/torchtitan/models/common/decoder.py
index b0ff896e84..b26860e620 100644
--- a/torchtitan/models/common/decoder.py
+++ b/torchtitan/models/common/decoder.py
@@ -77,9 +77,31 @@ class Config(BaseModel.Config):
# itself is handled by ``Decoder.__init__`` / ``Decoder.init_states``.
enable_weight_tying: bool = False
+        @property
+        def first_attention(self) -> BaseAttention.Config | None:
+            """Return the attention config of the earliest layer that has one.
+
+            Hybrid models (linear + full attention) leave ``attention`` unset on
+            some layers, so callers that need attention metadata (TP validation,
+            FLOPs, mask type) use this lookup instead of assuming ``layers[0]``.
+
+            Returns:
+                The first ``layer.attention`` that is not None, scanning
+                ``self.layers`` in order, or None when no layer defines one.
+            """
+            # Layers are scanned in order; stop at the first match.
+            for layer in self.layers:
+                if layer.attention is not None:
+                    return layer.attention
+            return None
+
@property
def max_seq_len(self) -> int:
- return self.layers[0].attention.rope.max_seq_len
+ # The first full-attention layer's RoPE defines the context length.
+ rope_cfg = getattr(self.first_attention, "rope", None)
+ if rope_cfg is None:
+ raise ValueError("Decoder config does not define RoPE max_seq_len.")
+ return rope_cfg.max_seq_len
def update_from_config(
self,
@@ -115,8 +137,8 @@ def update_from_config(
)
tp = parallelism.tensor_parallel_degree
- if tp > 1:
- attention = self.layers[0].attention
+ attention = self.first_attention
+ if tp > 1 and attention is not None:
n_heads = attention.n_heads
n_kv_heads = getattr(attention, "n_kv_heads", None) or n_heads
if n_heads % tp != 0:
@@ -131,7 +153,7 @@ def update_from_config(
)
for layer_cfg in self.layers:
- if hasattr(layer_cfg, "moe") and layer_cfg.moe is not None:
+ if layer_cfg.moe is not None:
from torchtitan.models.common.token_dispatcher import (
DeepEPTokenDispatcher,
HybridEPTokenDispatcher,
@@ -273,8 +295,12 @@ def _create_flex_attention_mask_for_document(
def get_attention_masks(
self,
positions: torch.Tensor,
- ) -> AttentionMasksType:
- attn_config = self.config.layers[0].attention
+ ) -> AttentionMasksType | None:
+ attn_config = self.config.first_attention
+ if attn_config is None:
+ # No full-attention layers (e.g. a pure linear-attention model, or a
+ # pipeline stage holding only linear-attention blocks) → no masks.
+ return None
inner_attn = attn_config.inner_attention
if isinstance(inner_attn, FlexAttention.Config):
return self._create_flex_attention_mask_for_document(positions, attn_config)
diff --git a/torchtitan/models/common/moe_sharding.py b/torchtitan/models/common/moe_sharding.py
index ebe0da6452..9a335d8b5b 100644
--- a/torchtitan/models/common/moe_sharding.py
+++ b/torchtitan/models/common/moe_sharding.py
@@ -230,14 +230,31 @@ def set_moe_sharding_config(
# Router gate: dense-family TP plan with Partial output grad.
moe_cfg.router.gate.sharding_config = _router_gate_config(enable_ep=enable_ep)
- # Shared experts: optional. Use Partial-flow variants so the
- # Partial->sp_layout reduce only happens once at the MoE boundary.
- if getattr(moe_cfg, "shared_experts", None) is not None:
- moe_cfg.shared_experts.w1.sharding_config = _shared_expert_colwise_config(
+ # Shared experts: SwiGLU FFN run in parallel with the routed experts.
+ # Gather x to Replicate ONCE at the module boundary so w1/w3 share it (their
+ # per-linear input redistributions become no-ops). w2 (rowwise) keeps the
+ # output Partial; the Partial->sp_layout reduce happens once at the MoE
+ # boundary. Model-specific shared-expert extras are sharded by that
+ # model's own sharding code.
+ shared = moe_cfg.shared_experts
+ if shared is not None:
+ # Shared-expert input matches the MoE input: sequence-parallel under
+ # EP+SP (Shard(1) on seq, ordered via partition_spec), else Replicate.
+ shared_input = (
+ dense_sequence_parallel_placement()
+ if enable_ep and enable_sp
+ else dense_activation_placement(tp=spmd.R)
+ )
+ shared.sharding_config = ShardingConfig(
+ in_src_shardings={"x": shared_input},
+ in_dst_shardings={"x": dense_activation_placement(tp=spmd.R)},
+ )
+
+ shared.w1.sharding_config = _shared_expert_colwise_config(
enable_ep=enable_ep, enable_sp=enable_sp
)
- moe_cfg.shared_experts.w2.sharding_config = _shared_expert_rowwise_config()
- moe_cfg.shared_experts.w3.sharding_config = _shared_expert_colwise_config(
+ shared.w2.sharding_config = _shared_expert_rowwise_config()
+ shared.w3.sharding_config = _shared_expert_colwise_config(
enable_ep=enable_ep, enable_sp=enable_sp
)
diff --git a/torchtitan/models/common/nn_modules.py b/torchtitan/models/common/nn_modules.py
index 10273967d7..db2127e54b 100644
--- a/torchtitan/models/common/nn_modules.py
+++ b/torchtitan/models/common/nn_modules.py
@@ -23,6 +23,33 @@
from torchtitan.protocols.module import Module
+class Conv1d(nn.Conv1d, Module):
+    """``nn.Conv1d`` that is constructed from a ``Config`` dataclass."""
+
+    @dataclass(kw_only=True, slots=True)
+    class Config(Module.Config):
+        in_channels: int
+        out_channels: int
+        kernel_size: int
+        stride: int = 1
+        padding: int = 0
+        groups: int = 1
+        # Defaults to True like nn.Conv1d (unlike Linear.Config.bias, which is False).
+        bias: bool = True
+
+    def __init__(self, config: Config):
+        # Forward each Config field by keyword; other nn.Conv1d options keep defaults.
+        super().__init__(
+            in_channels=config.in_channels,
+            out_channels=config.out_channels,
+            kernel_size=config.kernel_size,
+            stride=config.stride,
+            padding=config.padding,
+            groups=config.groups,
+            bias=config.bias,
+        )
+
+
class Conv2d(nn.Conv2d, Module):
"""Configurable nn.Conv2d."""
@@ -162,6 +189,7 @@ def __init__(self, config: Config):
__all__ = [
+ "Conv1d",
"Conv2d",
"Embedding",
"GELU",
diff --git a/torchtitan/models/qwen3_5/README.md b/torchtitan/models/qwen3_5/README.md
new file mode 100644
index 0000000000..7c2b9169da
--- /dev/null
+++ b/torchtitan/models/qwen3_5/README.md
@@ -0,0 +1,82 @@
+# Qwen3.5: Multimodal Model with Hybrid Attention
+
+## Overview
+
+Qwen3.5 combines:
+- **Hybrid Decoder**: 75% of layers use GatedDeltaNet (linear attention); every 4th layer (25%) uses full attention with output gating and partial RoPE.
+- **Vision Encoder**: A Vision Transformer (ViT) with 2D RoPE and bilinear-interpolated learned position embeddings.
+- **Patch Merger**: Reduces vision sequence length by merging spatial patches (e.g., 2x2 patches -> 1 token).
+- **MRoPE**: Interleaves RoPE from temporal, height, and width position IDs in decoder layers.
+- **MoE variant**: Routed experts + shared expert with sigmoid gate.
+
+## Vision Scatter
+
+- `tok_embeddings` produces text token embeddings of shape `B×S`.
+- `vision_encoder` produces visual token embeddings of shape `N×L`.
+- Valid visual tokens (excluding padding) are scattered into the placeholder positions in the text sequence, as illustrated below (credit: [@lkhphuc](https://github.com/lkhphuc)).
+
+
+
+Note: the diagram shows each patch mapping to one vision token. In practice, the Patch Merger groups `merge_size²` patches into one token (e.g., `merge_size=2` → 4 patches per token), reducing the vision sequence length by a factor of `merge_size²`.
+
+## Prerequisites
+
+Install the additional dependencies:
+
+```bash
+pip install av torchvision flash-linear-attention
+```
+
+## Model Variants
+
+### Dense
+
+| Variant | LLM dim | Layers | Heads | KV Heads | Head Dim | ViT dim | ViT layers |
+|---------|---------|--------|-------|----------|----------|---------|------------|
+| debugmodel | 256 | 8 | 4 | 2 | 64 | 256 | 4 |
+| 0.8B | 1024 | 24 | 8 | 2 | 256 | 768 | 12 |
+| 2B | 2048 | 24 | 8 | 2 | 256 | 1024 | 24 |
+| 4B | 2560 | 32 | 16 | 4 | 256 | 1024 | 24 |
+| 9B | 4096 | 32 | 16 | 4 | 256 | 1152 | 27 |
+| 27B | 5120 | 64 | 24 | 4 | 256 | 1152 | 27 |
+
+### MoE
+
+| Variant | LLM dim | Layers | Experts | Top-k | Shared Expert |
+|---------|---------|--------|---------|-------|---------------|
+| debugmodel_moe | 256 | 4 | 8 | 2 | Yes |
+| 35B-A3B | 2048 | 40 | 256 | 8 | Yes |
+| 122B-A10B | 3072 | 48 | 256 | 8 | Yes |
+| 397B-A17B | 4096 | 60 | 512 | 10 | Yes |
+
+## Datasets
+
+| Dataset | Type | Format |
+|---------|------|--------|
+| `cc12m` | Image-text pairs | WebDataset (streaming) |
+| `cc12m-test` | Image-text pairs | Local WebDataset (for testing) |
+
+## Supported Parallelisms
+
+| Feature | Notes |
+|---------|-------|
+| FSDP / HSDP | Decoder sharded per-layer; vision encoder sharded as a single unit (one AllGather) |
+| Tensor Parallelism (TP) | With Sequence Parallel; head-sharded TP on GatedDeltaNet projections |
+| Expert Parallelism (EP) | For MoE variants |
+| Pipeline Parallel (PP) | Vision encoder assigned to first stage; 1F1B and Interleaved1F1B schedules |
+| Sample Packing | Configurable via `packing_buffer_size` in dataloader config |
+
+## Numerical Parity
+
+End-to-end KL divergence against HuggingFace Transformers (4B, multimodal inputs): **~3e-7** average, with **100% top-1 and top-5 match**.
+
+Parallelism correctness: bit-identical logits (max diff `0.0`) across no-parallel, FSDP, FSDP+EP, and FSDP+EP+TP configs.
+
+Test scripts:
+- `scripts/checkpoint_conversion/numerical_tests_qwen3_5.py` — HF vs TT comparison
+- `scripts/checkpoint_conversion/numerical_tests_qwen3_5_shard.py` — parallelism correctness
+
+## TODO
+
+- Add video dataset training configs
+- Add Context Parallel (CP) support
diff --git a/torchtitan/models/qwen3_5/__init__.py b/torchtitan/models/qwen3_5/__init__.py
new file mode 100644
index 0000000000..1449dbfe9f
--- /dev/null
+++ b/torchtitan/models/qwen3_5/__init__.py
@@ -0,0 +1,1111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections.abc import Callable
+from functools import partial
+from typing import Literal
+
+import torch.nn as nn
+
+from torchtitan.components.optimizer import register_moe_load_balancing_hook
+
+from torchtitan.models.common import Conv1d, Embedding, Linear # noqa: F401
+from torchtitan.models.common.config_utils import (
+ get_attention_config,
+ make_experts_config,
+ make_ffn_config,
+ make_moe_config,
+ make_router_config,
+)
+from torchtitan.models.common.nn_modules import LayerNorm
+from torchtitan.models.common.param_init import depth_scaled_std # noqa: F401
+from torchtitan.models.utils import validate_converter_order
+from torchtitan.protocols.model import ModelConfigConverter
+
+from torchtitan.protocols.model_spec import ModelSpec
+
+from .model import (
+ GatedDeltaKernel,
+ GatedDeltaNet,
+ OffsetRMSNorm,
+ Qwen35Attention,
+ Qwen35Model,
+ Qwen35TransformerBlock,
+ RMSNormGated,
+ SharedExperts,
+)
+from .parallelize import parallelize_qwen3_5, pipeline_qwen3_5
+from .rope import MRoPE
+from .state_dict_adapter import Qwen35StateDictAdapter
+from .vision_encoder import (
+ PatchMerger,
+ Qwen35VisionEncoder,
+ VisionAttention,
+ VisionMLP,
+ VisionRotaryEmbedding,
+ VisionTransformerBlock,
+)
+
+__all__ = [
+ "parallelize_qwen3_5",
+ "Qwen35Model",
+ "qwen3_5_configs",
+ "QWEN3_5_SPECIAL_TOKENS",
+]
+
+QWEN3_5_SPECIAL_TOKENS = {  # special-token role -> literal token string
+    "image_token": "<|image_pad|>",
+    "video_token": "<|video_pad|>",
+    "vision_start_token": "<|vision_start|>",
+    "vision_end_token": "<|vision_end|>",
+    "pad_token": "<|endoftext|>",
+}
+
+# Shared ``param_init`` maps: parameter name -> in-place initializer callable.
+_LINEAR_INIT = {
+    "weight": partial(nn.init.trunc_normal_, std=0.02),
+    "bias": nn.init.zeros_,
+}
+_OFFSET_NORM_INIT = {"weight": nn.init.zeros_}
+_EMBEDDING_INIT = {"weight": partial(nn.init.normal_, std=1.0)}
+_POS_EMBED_INIT = {"pos_embed": partial(nn.init.trunc_normal_, mean=0.0, std=0.02)}
+# Epsilon passed to every ``OffsetRMSNorm.Config`` built by ``_offset_norm``.
+_EPS = 1e-6
+
+
+def _output_linear_init(dim: int) -> dict[str, Callable]:
+    """Init map: weight ~ trunc-normal(std=dim**-0.5, cut at 3 std), bias zero."""
+    std = dim**-0.5
+    cutoff = 3 * std
+    weight_init = partial(nn.init.trunc_normal_, std=std, a=-cutoff, b=cutoff)
+    return {"weight": weight_init, "bias": nn.init.zeros_}
+
+
+def _depth_init(layer_id: int) -> dict[str, Callable]:
+    """Init map: weight std = ``depth_scaled_std(0.02, layer_id)``, zero bias."""
+    layer_std = depth_scaled_std(0.02, layer_id)
+    weight_init = partial(nn.init.trunc_normal_, std=layer_std)
+    return {"weight": weight_init, "bias": nn.init.zeros_}
+
+
+def _depth_experts_init(layer_id: int) -> dict[str, Callable]:
+    """Init map for grouped experts: w1 at std 0.02, w2/w3 at the depth-scaled std."""
+    base = partial(nn.init.trunc_normal_, std=0.02)
+    scaled = partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id))
+    # w2 and w3 use the same std, so one stateless initializer serves both keys.
+    return {"w1_EFD": base, "w2_EDF": scaled, "w3_EFD": scaled}
+
+
+def _a_log_init(param: nn.Parameter) -> None:
+    nn.init.uniform_(param.data, 1e-6, 16.0).log_()  # A ~ U(1e-6, 16), stored as log(A)
+
+
+def _linear(in_features: int, out_features: int) -> Linear.Config:
+    """Return a ``Linear.Config`` with a bias term and the shared ``_LINEAR_INIT``.
+
+    The vision-encoder builder uses this for all of its projections.
+    """
+    biased_linear = partial(Linear.Config, bias=True, param_init=_LINEAR_INIT)
+    return biased_linear(in_features=in_features, out_features=out_features)
+
+
+def _offset_norm(dim: int) -> OffsetRMSNorm.Config:  # zero-init weight, eps=_EPS
+    return OffsetRMSNorm.Config(param_init=_OFFSET_NORM_INIT, eps=_EPS, dim=dim)
+
+
+def _shared_experts_config(
+    *, dim: int, hidden_dim: int, layer_id: int
+) -> SharedExperts.Config:
+    """Return Qwen3.5's sigmoid-gated shared-expert config (SwiGLU FFN + gate).
+
+    w1/w2/w3 are taken from ``make_ffn_config`` (w1 uses ``_LINEAR_INIT``, w2/w3
+    the depth-scaled init for ``layer_id``); the gate is a ``dim -> 1`` linear.
+    """
+    mlp = make_ffn_config(
+        dim=dim,
+        hidden_dim=hidden_dim,
+        w1_param_init=_LINEAR_INIT,
+        w2w3_param_init=_depth_init(layer_id),
+    )
+    gate = Linear.Config(in_features=dim, out_features=1, param_init=_LINEAR_INIT)
+    return SharedExperts.Config(w1=mlp.w1, w2=mlp.w2, w3=mlp.w3, gate=gate)
+
+
+def _qwen35_vision_encoder_config(
+    *,
+    dim: int,
+    ffn_dim: int,
+    num_layers: int,
+    num_heads: int,
+    patch_size: int,
+    temporal_patch_size: int,
+    spatial_merge_size: int,
+    out_hidden_size: int,
+    num_position_embeddings: int,
+    layer_norm_eps: float = 1e-6,
+    rope_theta: float = 10000.0,
+    in_channels: int = 3,
+) -> Qwen35VisionEncoder.Config:
+    """Assemble a complete ``Qwen35VisionEncoder.Config`` for one model variant.
+
+    Every projection is a biased linear built by ``_linear``. The patch embedding
+    takes ``in_channels * temporal_patch_size * patch_size**2`` input features.
+    """
+    patch_dim = in_channels * temporal_patch_size * patch_size * patch_size
+    merged_hidden_size = dim * (spatial_merge_size**2)
+    block_norm = LayerNorm.Config(normalized_shape=dim, eps=layer_norm_eps)
+    attn = VisionAttention.Config(
+        dim=dim,
+        num_heads=num_heads,
+        wq=_linear(dim, dim),
+        wk=_linear(dim, dim),
+        wv=_linear(dim, dim),
+        proj=_linear(dim, dim),
+    )
+    mlp = VisionMLP.Config(fc1=_linear(dim, ffn_dim), fc2=_linear(ffn_dim, dim))
+    block = VisionTransformerBlock.Config(
+        norm1=block_norm, norm2=block_norm, attn=attn, mlp=mlp
+    )
+    rotary = VisionRotaryEmbedding.Config(dim=(dim // num_heads) // 2, theta=rope_theta)
+    merger = PatchMerger.Config(
+        spatial_merge_size=spatial_merge_size,
+        merged_hidden_size=merged_hidden_size,
+        norm=LayerNorm.Config(normalized_shape=dim, eps=layer_norm_eps),
+        fc1=_linear(merged_hidden_size, merged_hidden_size),
+        fc2=_linear(merged_hidden_size, out_hidden_size),
+    )
+    return Qwen35VisionEncoder.Config(
+        dim=dim,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        patch_size=patch_size,
+        temporal_patch_size=temporal_patch_size,
+        in_channels=in_channels,
+        spatial_merge_size=spatial_merge_size,
+        num_position_embeddings=num_position_embeddings,
+        patch_embed_proj=_linear(patch_dim, dim),
+        block=block,
+        rotary_pos_emb=rotary,
+        merger=merger,
+        param_init=_POS_EMBED_INIT,
+    )
+
+
+def _qwen35_attention_config(
+    *,
+    dim: int,
+    n_heads: int,
+    n_kv_heads: int,
+    head_dim: int,
+    rotary_dim: int,
+    rope: MRoPE.Config,
+    attn_backend: str,
+    layer_id: int,
+) -> Qwen35Attention.Config:
+    """Return the ``Qwen35Attention.Config`` for one full-attention layer.
+
+    ``wq`` projects to ``2 * n_heads * head_dim`` features while ``wk``/``wv``
+    project to ``n_kv_heads * head_dim``. ``wo`` uses the depth-scaled init
+    for ``layer_id``; the other projections use ``_LINEAR_INIT``. Q/K norms
+    are ``OffsetRMSNorm`` configs over ``head_dim``, and ``attn_backend`` is
+    resolved to the inner-attention config by ``get_attention_config``.
+    """
+    inner_attention = get_attention_config(attn_backend)
+    # Total projection widths across all query heads / all KV heads.
+    q_width = n_heads * head_dim
+    kv_width = n_kv_heads * head_dim
+
+    def _proj(in_f: int, out_f: int, init: dict = _LINEAR_INIT) -> Linear.Config:
+        return Linear.Config(in_features=in_f, out_features=out_f, param_init=init)
+
+    return Qwen35Attention.Config(
+        n_heads=n_heads,
+        n_kv_heads=n_kv_heads,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        # wq is twice the query width (out_features = n_heads * head_dim * 2).
+        wq=_proj(dim, q_width * 2),
+        wk=_proj(dim, kv_width),
+        wv=_proj(dim, kv_width),
+        # Only wo uses the depth-scaled init.
+        wo=_proj(q_width, dim, _depth_init(layer_id)),
+        q_norm=_offset_norm(head_dim),
+        k_norm=_offset_norm(head_dim),
+        inner_attention=inner_attention,
+    )
+
+
+def _qwen35_deltanet_config(
+    *,
+    dim: int,
+    n_key_heads: int,
+    n_value_heads: int,
+    key_head_dim: int,
+    value_head_dim: int,
+    layer_id: int,
+    conv_kernel_size: int = 4,
+    fla_backend: Literal[
+        "fla_chunked", "fla_fused_recurrent", "torch_native"
+    ] = "fla_chunked",
+) -> GatedDeltaNet.Config:
+    """Return the ``GatedDeltaNet.Config`` for one linear-attention layer.
+
+    Q/K widths are ``n_key_heads * key_head_dim``, V/Z widths are
+    ``n_value_heads * value_head_dim``; only ``out_proj`` is depth-scaled.
+    """
+    key_dim = n_key_heads * key_head_dim
+    value_dim = n_value_heads * value_head_dim
+
+    def _proj(in_f: int, out_f: int, init: dict) -> Linear.Config:
+        return Linear.Config(
+            in_features=in_f, out_features=out_f, bias=False, param_init=init
+        )
+
+    def _conv(channels: int) -> Conv1d.Config:
+        # Depthwise (groups == channels); causal left-padding is applied in forward.
+        return Conv1d.Config(
+            in_channels=channels,
+            out_channels=channels,
+            kernel_size=conv_kernel_size,
+            groups=channels,
+            padding=0,
+            bias=False,
+        )
+
+    # Use the module-level ``_EPS`` (1e-6) instead of repeating the literal.
+    gated_norm = RMSNormGated.Config(
+        dim=value_head_dim, eps=_EPS, param_init={"weight": nn.init.ones_}
+    )
+    return GatedDeltaNet.Config(
+        key_head_dim=key_head_dim,
+        value_head_dim=value_head_dim,
+        conv_kernel_size=conv_kernel_size,
+        in_proj_q=_proj(dim, key_dim, _LINEAR_INIT),
+        in_proj_k=_proj(dim, key_dim, _LINEAR_INIT),
+        in_proj_v=_proj(dim, value_dim, _LINEAR_INIT),
+        in_proj_z=_proj(dim, value_dim, _LINEAR_INIT),
+        in_proj_a=_proj(dim, n_value_heads, _LINEAR_INIT),
+        in_proj_b=_proj(dim, n_value_heads, _LINEAR_INIT),
+        conv_q=_conv(key_dim),
+        conv_k=_conv(key_dim),
+        conv_v=_conv(value_dim),
+        kernel=GatedDeltaKernel.Config(backend=fla_backend),
+        norm=gated_norm,
+        out_proj=_proj(value_dim, dim, _depth_init(layer_id)),
+        param_init={"A_log": _a_log_init, "dt_bias": nn.init.ones_},
+    )
+
+
+def _build_qwen35_layers(
+    *,
+    n_layers: int,
+    dim: int,
+    n_heads: int,
+    n_kv_heads: int,
+    head_dim: int,
+    rotary_dim: int,
+    rope: MRoPE.Config,
+    hidden_dim: int,
+    n_key_heads: int,
+    n_value_heads: int,
+    key_head_dim: int,
+    value_head_dim: int,
+    full_attention_interval: int = 4,
+    attn_backend: str,
+    fla_backend: Literal[
+        "fla_chunked", "fla_fused_recurrent", "torch_native"
+    ] = "fla_chunked",
+) -> list[Qwen35TransformerBlock.Config]:
+    """Build per-layer block configs for the dense Qwen3.5 decoder.
+
+    Layer ``i`` (0-based) gets full attention when ``i + 1`` is a multiple of
+    ``full_attention_interval``; every other layer gets a GatedDeltaNet mixer.
+
+    Raises ``ValueError`` if ``full_attention_interval`` is less than 1.
+    """
+    if full_attention_interval < 1:
+        raise ValueError(f"{full_attention_interval=} must be a positive integer")
+    layers = []
+    for layer_id in range(n_layers):
+        attention = deltanet = None
+        if (layer_id + 1) % full_attention_interval == 0:
+            attention = _qwen35_attention_config(
+                dim=dim,
+                n_heads=n_heads,
+                n_kv_heads=n_kv_heads,
+                head_dim=head_dim,
+                rotary_dim=rotary_dim,
+                rope=rope,
+                attn_backend=attn_backend,
+                layer_id=layer_id,
+            )
+        else:
+            deltanet = _qwen35_deltanet_config(
+                dim=dim,
+                n_key_heads=n_key_heads,
+                n_value_heads=n_value_heads,
+                key_head_dim=key_head_dim,
+                value_head_dim=value_head_dim,
+                layer_id=layer_id,
+                fla_backend=fla_backend,
+            )
+        feed_forward = make_ffn_config(
+            dim=dim,
+            hidden_dim=hidden_dim,
+            w1_param_init=_LINEAR_INIT,
+            w2w3_param_init=_depth_init(layer_id),
+        )
+        block = Qwen35TransformerBlock.Config(
+            attention=attention,
+            delta_net=deltanet,
+            feed_forward=feed_forward,
+            attention_norm=_offset_norm(dim),
+            ffn_norm=_offset_norm(dim),
+        )
+        layers.append(block)
+    return layers
+
+
+def _build_qwen35_moe_layers(
+    *,
+    n_layers: int,
+    dim: int,
+    n_heads: int,
+    n_kv_heads: int,
+    head_dim: int,
+    rotary_dim: int,
+    rope: MRoPE.Config,
+    moe_hidden_dim: int,
+    num_experts: int,
+    top_k: int,
+    shared_expert_hidden_dim: int,
+    n_key_heads: int,
+    n_value_heads: int,
+    key_head_dim: int,
+    value_head_dim: int,
+    full_attention_interval: int = 4,
+    attn_backend: str,
+    fla_backend: Literal[
+        "fla_chunked", "fla_fused_recurrent", "torch_native"
+    ] = "fla_chunked",
+    moe_comm_backend: str = "standard",
+    non_blocking_capacity_factor: float | None = None,
+) -> list[Qwen35TransformerBlock.Config]:
+    """Build per-layer block configs for the MoE Qwen3.5 decoder.
+
+    Layer ``i`` gets full attention when ``i + 1`` is a multiple of
+    ``full_attention_interval``, otherwise GatedDeltaNet; every block uses a
+    routed MoE with a shared expert. Raises ``ValueError`` if the interval < 1.
+    """
+    if full_attention_interval < 1:
+        raise ValueError(f"{full_attention_interval=} must be a positive integer")
+    layers = []
+    for layer_id in range(n_layers):
+        attention = deltanet = None
+        if (layer_id + 1) % full_attention_interval == 0:
+            attention = _qwen35_attention_config(
+                dim=dim,
+                n_heads=n_heads,
+                n_kv_heads=n_kv_heads,
+                head_dim=head_dim,
+                rotary_dim=rotary_dim,
+                rope=rope,
+                attn_backend=attn_backend,
+                layer_id=layer_id,
+            )
+        else:
+            deltanet = _qwen35_deltanet_config(
+                dim=dim,
+                n_key_heads=n_key_heads,
+                n_value_heads=n_value_heads,
+                key_head_dim=key_head_dim,
+                value_head_dim=value_head_dim,
+                layer_id=layer_id,
+                fla_backend=fla_backend,
+            )
+        router = make_router_config(
+            dim=dim,
+            num_experts=num_experts,
+            gate_param_init=_depth_init(layer_id),
+            top_k=top_k,
+            score_func="softmax",
+            route_norm=True,
+        )
+        experts = make_experts_config(
+            dim=dim,
+            hidden_dim=moe_hidden_dim,
+            num_experts=num_experts,
+            top_k=top_k,
+            param_init=_depth_experts_init(layer_id),
+            score_before_experts=False,
+            comm_backend=moe_comm_backend,
+            non_blocking_capacity_factor=non_blocking_capacity_factor,
+        )
+        shared = _shared_experts_config(
+            dim=dim, hidden_dim=shared_expert_hidden_dim, layer_id=layer_id
+        )
+        moe = make_moe_config(
+            num_experts=num_experts,
+            router=router,
+            experts=experts,
+            shared_experts=shared,
+        )
+        block = Qwen35TransformerBlock.Config(
+            attention=attention,
+            delta_net=deltanet,
+            moe=moe,
+            attention_norm=_offset_norm(dim),
+            ffn_norm=_offset_norm(dim),
+        )
+        layers.append(block)
+    return layers
+
+
+def _debugmodel(attn_backend: str) -> Qwen35Model.Config:
+    """Small dense Qwen3.5 config (8 decoder layers, 4 ViT layers) for debugging."""
+    dim = 256
+    vocab_size = 248320
+    head_dim = 64
+    rotary_dim = 16
+    # mrope_section must sum to rotary_dim / 2 (3 + 3 + 2 = 8 here); the
+    # full-size variants use [11, 11, 10] with rotary_dim=64.
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=4096,
+        theta=10_000_000.0,
+        mrope_section=[3, 3, 2],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_layers(
+        n_layers=8,
+        dim=dim,
+        n_heads=4,
+        n_kv_heads=2,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        hidden_dim=512,
+        n_key_heads=2,
+        n_value_heads=4,
+        key_head_dim=64,
+        value_head_dim=64,
+        attn_backend=attn_backend,
+        fla_backend="fla_chunked",
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=256,
+        ffn_dim=512,
+        num_layers=4,
+        num_heads=4,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=256,
+        num_position_embeddings=1024,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _debugmodel_moe(
+    attn_backend: str,
+    moe_comm_backend: str = "standard",
+) -> Qwen35Model.Config:
+    """Small MoE Qwen3.5 config (4 decoder layers, shared expert) for debugging."""
+    dim = 256
+    vocab_size = 248320
+    head_dim = 64
+    rotary_dim = 16
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=4096,
+        theta=10_000_000.0,
+        mrope_section=[3, 3, 2],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_moe_layers(
+        n_layers=4,
+        dim=dim,
+        n_heads=4,
+        n_kv_heads=2,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        moe_hidden_dim=256,
+        num_experts=8,
+        top_k=2,
+        shared_expert_hidden_dim=256,
+        n_key_heads=2,
+        n_value_heads=4,
+        key_head_dim=64,
+        value_head_dim=64,
+        attn_backend=attn_backend,
+        fla_backend="fla_chunked",
+        moe_comm_backend=moe_comm_backend,
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=256,
+        ffn_dim=512,
+        num_layers=2,
+        num_heads=4,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=256,
+        num_position_embeddings=1024,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _0_8b(attn_backend: str) -> Qwen35Model.Config:
+    """Qwen3.5-0.8B dense model config, including its vision encoder.
+
+    NOTE: the HF config sets tie_word_embeddings=true. Torchtitan doesn't support
+    tied embeddings yet, so a separate lm_head is used and checkpoint conversion
+    must handle this. NOTE(review): common/decoder.py's Config has an
+    ``enable_weight_tying`` flag - confirm this limitation is still current.
+    """
+    dim = 1024
+    vocab_size = 248320
+    rotary_dim = 64  # partial_rotary_factor=0.25 applied to head_dim=256
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=262144,
+        theta=10_000_000.0,
+        mrope_section=[11, 11, 10],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_layers(
+        n_layers=24,
+        dim=dim,
+        n_heads=8,
+        n_kv_heads=2,
+        head_dim=256,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        hidden_dim=3584,
+        n_key_heads=16,
+        n_value_heads=16,
+        key_head_dim=128,
+        value_head_dim=128,
+        attn_backend=attn_backend,
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=768,
+        ffn_dim=3072,
+        num_layers=12,
+        num_heads=12,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=1024,
+        num_position_embeddings=2304,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _2b(attn_backend: str) -> Qwen35Model.Config:
+    """Qwen3.5-2B dense model config, including its vision encoder.
+
+    NOTE: the HF config sets tie_word_embeddings=true. Torchtitan doesn't support
+    tied embeddings yet, so a separate lm_head is used and checkpoint conversion
+    must handle this. NOTE(review): common/decoder.py's Config has an
+    ``enable_weight_tying`` flag - confirm this limitation is still current.
+    """
+    dim = 2048
+    vocab_size = 248320
+    rotary_dim = 64  # partial_rotary_factor=0.25
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=262144,
+        theta=10_000_000.0,
+        mrope_section=[11, 11, 10],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_layers(
+        n_layers=24,
+        dim=dim,
+        n_heads=8,
+        n_kv_heads=2,
+        head_dim=256,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        hidden_dim=6144,
+        n_key_heads=16,
+        n_value_heads=16,
+        key_head_dim=128,
+        value_head_dim=128,
+        attn_backend=attn_backend,
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=1024,
+        ffn_dim=4096,
+        num_layers=24,
+        num_heads=16,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=2048,
+        num_position_embeddings=2304,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _4b(attn_backend: str) -> Qwen35Model.Config:
+    """Qwen3.5-4B dense model config, including its vision encoder.
+
+    NOTE: the HF config sets tie_word_embeddings=true. Torchtitan doesn't support
+    tied embeddings yet, so a separate lm_head is used. NOTE(review): confirm
+    against the ``enable_weight_tying`` flag on common/decoder.py's Config.
+    """
+    dim = 2560
+    vocab_size = 248320
+    rotary_dim = 64
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=262144,
+        theta=10_000_000.0,
+        mrope_section=[11, 11, 10],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_layers(
+        n_layers=32,
+        dim=dim,
+        n_heads=16,
+        n_kv_heads=4,
+        head_dim=256,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        hidden_dim=9216,
+        n_key_heads=16,
+        n_value_heads=32,
+        key_head_dim=128,
+        value_head_dim=128,
+        attn_backend=attn_backend,
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=1024,
+        ffn_dim=4096,
+        num_layers=24,
+        num_heads=16,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=2560,
+        num_position_embeddings=2304,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _9b(attn_backend: str) -> Qwen35Model.Config:
+    """Qwen3.5-9B dense model config, including its vision encoder."""
+    dim = 4096
+    vocab_size = 248320
+    head_dim = 256
+    rotary_dim = 64
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=262144,
+        theta=10_000_000.0,
+        mrope_section=[11, 11, 10],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_layers(
+        n_layers=32,
+        dim=dim,
+        n_heads=16,
+        n_kv_heads=4,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        hidden_dim=12288,
+        n_key_heads=16,
+        n_value_heads=32,
+        key_head_dim=128,
+        value_head_dim=128,
+        attn_backend=attn_backend,
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=1152,
+        ffn_dim=4304,
+        num_layers=27,
+        num_heads=16,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=4096,
+        num_position_embeddings=2304,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _27b(attn_backend: str) -> Qwen35Model.Config:
+    """Qwen3.5-27B dense model config, including its vision encoder."""
+    dim = 5120
+    vocab_size = 248320
+    head_dim = 256
+    rotary_dim = 64
+    rope = MRoPE.Config(
+        dim=rotary_dim,
+        max_seq_len=262144,
+        theta=10_000_000.0,
+        mrope_section=[11, 11, 10],
+    )
+    tok_embeddings = Embedding.Config(
+        num_embeddings=vocab_size, embedding_dim=dim, param_init=_EMBEDDING_INIT
+    )
+    lm_head = Linear.Config(
+        in_features=dim, out_features=vocab_size, param_init=_output_linear_init(dim)
+    )
+    layers = _build_qwen35_layers(
+        n_layers=64,
+        dim=dim,
+        n_heads=24,
+        n_kv_heads=4,
+        head_dim=head_dim,
+        rotary_dim=rotary_dim,
+        rope=rope,
+        hidden_dim=17408,
+        n_key_heads=16,
+        n_value_heads=48,
+        key_head_dim=128,
+        value_head_dim=128,
+        attn_backend=attn_backend,
+    )
+    vision_encoder = _qwen35_vision_encoder_config(
+        dim=1152,
+        ffn_dim=4304,
+        num_layers=27,
+        num_heads=16,
+        patch_size=16,
+        temporal_patch_size=2,
+        spatial_merge_size=2,
+        out_hidden_size=5120,
+        num_position_embeddings=2304,
+    )
+    return Qwen35Model.Config(
+        vocab_size=vocab_size,
+        dim=dim,
+        # pyrefly: ignore [bad-argument-type]
+        norm=_offset_norm(dim),
+        tok_embeddings=tok_embeddings,
+        lm_head=lm_head,
+        layers=layers,
+        vision_encoder=vision_encoder,
+    )
+
+
+def _35b_a3b(
+ attn_backend: str,
+ moe_comm_backend: str = "standard",
+) -> Qwen35Model.Config:
+ """Qwen3.5-35B-A3B MoE config with vision encoder."""
+ dim = 2048
+ head_dim = 256
+ rotary_dim = 64
+ n_layers = 40
+ vocab_size = 248320
+ return Qwen35Model.Config(
+ vocab_size=vocab_size,
+ dim=dim,
+ # pyrefly: ignore [bad-argument-type]
+ norm=_offset_norm(dim),
+ tok_embeddings=Embedding.Config(
+ num_embeddings=vocab_size,
+ embedding_dim=dim,
+ param_init=_EMBEDDING_INIT,
+ ),
+ lm_head=Linear.Config(
+ in_features=dim,
+ out_features=vocab_size,
+ param_init=_output_linear_init(dim),
+ ),
+ layers=_build_qwen35_moe_layers(
+ rope=MRoPE.Config(
+ dim=rotary_dim,
+ max_seq_len=262144,
+ theta=10_000_000.0,
+ mrope_section=[11, 11, 10],
+ ),
+ attn_backend=attn_backend,
+ n_layers=n_layers,
+ dim=dim,
+ n_heads=16,
+ n_kv_heads=2,
+ head_dim=head_dim,
+ rotary_dim=rotary_dim,
+ moe_hidden_dim=512,
+ num_experts=256,
+ top_k=8,
+ shared_expert_hidden_dim=512,
+ n_key_heads=16,
+ n_value_heads=32,
+ key_head_dim=128,
+ value_head_dim=128,
+ moe_comm_backend=moe_comm_backend,
+ ),
+ vision_encoder=_qwen35_vision_encoder_config(
+ dim=1152,
+ ffn_dim=4304,
+ num_layers=27,
+ num_heads=16,
+ patch_size=16,
+ temporal_patch_size=2,
+ spatial_merge_size=2,
+ out_hidden_size=2048,
+ num_position_embeddings=2304,
+ ),
+ )
+
+
+def _122b_a10b(
+ attn_backend: str,
+ moe_comm_backend: str = "standard",
+) -> Qwen35Model.Config:
+ """Qwen3.5-122B-A10B MoE config with vision encoder."""
+ dim = 3072
+ head_dim = 256
+ rotary_dim = 64
+ n_layers = 48
+ vocab_size = 248320
+ return Qwen35Model.Config(
+ vocab_size=vocab_size,
+ dim=dim,
+ # pyrefly: ignore [bad-argument-type]
+ norm=_offset_norm(dim),
+ tok_embeddings=Embedding.Config(
+ num_embeddings=vocab_size,
+ embedding_dim=dim,
+ param_init=_EMBEDDING_INIT,
+ ),
+ lm_head=Linear.Config(
+ in_features=dim,
+ out_features=vocab_size,
+ param_init=_output_linear_init(dim),
+ ),
+ layers=_build_qwen35_moe_layers(
+ rope=MRoPE.Config(
+ dim=rotary_dim,
+ max_seq_len=262144,
+ theta=10_000_000.0,
+ mrope_section=[11, 11, 10],
+ ),
+ attn_backend=attn_backend,
+ n_layers=n_layers,
+ dim=dim,
+ n_heads=32,
+ n_kv_heads=2,
+ head_dim=head_dim,
+ rotary_dim=rotary_dim,
+ moe_hidden_dim=1024,
+ num_experts=256,
+ top_k=8,
+ shared_expert_hidden_dim=1024,
+ n_key_heads=16,
+ n_value_heads=64,
+ key_head_dim=128,
+ value_head_dim=128,
+ moe_comm_backend=moe_comm_backend,
+ ),
+ vision_encoder=_qwen35_vision_encoder_config(
+ dim=1152,
+ ffn_dim=4304,
+ num_layers=27,
+ num_heads=16,
+ patch_size=16,
+ temporal_patch_size=2,
+ spatial_merge_size=2,
+ out_hidden_size=3072,
+ num_position_embeddings=2304,
+ ),
+ )
+
+
+def _397b_a17b(
+ attn_backend: str,
+ moe_comm_backend: str = "standard",
+) -> Qwen35Model.Config:
+ """Qwen3.5-397B-A17B MoE config with vision encoder."""
+ dim = 4096
+ head_dim = 256
+ rotary_dim = 64
+ n_layers = 60
+ vocab_size = 248320
+ return Qwen35Model.Config(
+ vocab_size=vocab_size,
+ dim=dim,
+ # pyrefly: ignore [bad-argument-type]
+ norm=_offset_norm(dim),
+ tok_embeddings=Embedding.Config(
+ num_embeddings=vocab_size,
+ embedding_dim=dim,
+ param_init=_EMBEDDING_INIT,
+ ),
+ lm_head=Linear.Config(
+ in_features=dim,
+ out_features=vocab_size,
+ param_init=_output_linear_init(dim),
+ ),
+ layers=_build_qwen35_moe_layers(
+ rope=MRoPE.Config(
+ dim=rotary_dim,
+ max_seq_len=262144,
+ theta=10_000_000.0,
+ mrope_section=[11, 11, 10],
+ ),
+ attn_backend=attn_backend,
+ n_layers=n_layers,
+ dim=dim,
+ n_heads=32,
+ n_kv_heads=2,
+ head_dim=head_dim,
+ rotary_dim=rotary_dim,
+ moe_hidden_dim=1024,
+ num_experts=512,
+ top_k=10,
+ shared_expert_hidden_dim=1024,
+ n_key_heads=16,
+ n_value_heads=64,
+ key_head_dim=128,
+ value_head_dim=128,
+ moe_comm_backend=moe_comm_backend,
+ ),
+ vision_encoder=_qwen35_vision_encoder_config(
+ dim=1152,
+ ffn_dim=4304,
+ num_layers=27,
+ num_heads=16,
+ patch_size=16,
+ temporal_patch_size=2,
+ spatial_merge_size=2,
+ out_hidden_size=4096,
+ num_position_embeddings=2304,
+ ),
+ )
+
+
+# Maps a flavor name to the builder function that produces its
+# ``Qwen35Model.Config``. ``model_registry`` looks the flavor up here and calls
+# the builder with ``attn_backend`` (plus ``moe_comm_backend`` when one is given).
+# Of the builders defined nearby, the dense ones (``_9b``, ``_27b``) take only
+# ``attn_backend``; the MoE ones (``_35b_a3b``, ``_122b_a10b``, ``_397b_a17b``)
+# also accept ``moe_comm_backend``.
+qwen3_5_configs = {
+    "debugmodel": _debugmodel,
+    "debugmodel_moe": _debugmodel_moe,
+    "0.8B": _0_8b,
+    "2B": _2b,
+    "4B": _4b,
+    "9B": _9b,
+    "27B": _27b,
+    "35B-A3B": _35b_a3b,
+    "122B-A10B": _122b_a10b,
+    "397B-A17B": _397b_a17b,
+}
+
+
+def model_registry(
+ flavor: str,
+ attn_backend: str = "flex",
+ moe_comm_backend: str | None = None,
+ converters: list[ModelConfigConverter.Config] | None = None,
+) -> ModelSpec:
+ kwargs = dict(attn_backend=attn_backend)
+ if moe_comm_backend is not None:
+ kwargs["moe_comm_backend"] = moe_comm_backend
+ config = qwen3_5_configs[flavor](**kwargs)
+ if converters is not None:
+ validate_converter_order(converters)
+ for c in converters:
+ c.build().convert(config)
+
+ return ModelSpec(
+ name="qwen3_5",
+ flavor=flavor,
+ model=config,
+ parallelize_fn=parallelize_qwen3_5,
+ pipelining_fn=pipeline_qwen3_5,
+ post_optimizer_build_fn=register_moe_load_balancing_hook,
+ state_dict_adapter=Qwen35StateDictAdapter,
+ )
diff --git a/torchtitan/models/qwen3_5/config_registry.py b/torchtitan/models/qwen3_5/config_registry.py
new file mode 100644
index 0000000000..d6d1518d4d
--- /dev/null
+++ b/torchtitan/models/qwen3_5/config_registry.py
@@ -0,0 +1,322 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.components.checkpoint import CheckpointManager
+from torchtitan.components.loss import ChunkedCELoss
+from torchtitan.components.lr_scheduler import LRSchedulersContainer
+from torchtitan.components.metrics import MetricsProcessor
+from torchtitan.components.optimizer import default_adamw
+from torchtitan.components.tokenizer import MultiModalTokenizer
+
+from torchtitan.config import (
+ ActivationCheckpointConfig,
+ ParallelismConfig,
+ TrainingConfig,
+)
+from torchtitan.hf_datasets.multimodal.mm_datasets import MMDataLoader
+from torchtitan.trainer import Trainer
+
+from . import model_registry, QWEN3_5_SPECIAL_TOKENS
+
+
+def _dataloader(dataset: str, **kwargs) -> MMDataLoader.Config:
+ return MMDataLoader.Config(
+ dataset=dataset,
+ max_images_per_batch=128,
+ patch_size=16,
+ temporal_patch_size=2,
+ spatial_merge_size=2,
+ min_pixels=65536,
+ max_pixels=16777216,
+ image_mean=(0.5, 0.5, 0.5),
+ image_std=(0.5, 0.5, 0.5),
+ build_mrope_positions=True,
+ **kwargs,
+ )
+
+
+def qwen35_debugmodel() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./tests/assets/tokenizer",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ metrics=MetricsProcessor.Config(log_freq=1),
+ model_spec=model_registry("debugmodel"),
+ dataloader=_dataloader("cc12m-test"),
+ optimizer=default_adamw(lr=5e-3),
+ lr_scheduler=LRSchedulersContainer.Config(
+ warmup_steps=2,
+ decay_ratio=0.8,
+ decay_type="linear",
+ min_lr_factor=0.0,
+ ),
+ training=TrainingConfig(
+ local_batch_size=1,
+ seq_len=512,
+ steps=10,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=10,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="selective",
+ ),
+ )
+
+
+def qwen35_debugmodel_moe() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./tests/assets/tokenizer",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ metrics=MetricsProcessor.Config(log_freq=1),
+ model_spec=model_registry("debugmodel_moe", moe_comm_backend="standard"),
+ dataloader=_dataloader("cc12m-test"),
+ optimizer=default_adamw(lr=5e-3),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=2),
+ training=TrainingConfig(
+ local_batch_size=2,
+ seq_len=512,
+ steps=10,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=2,
+ pipeline_parallel_degree=2,
+ expert_parallel_degree=4,
+ tensor_parallel_degree=2,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=10,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="selective",
+ ),
+ )
+
+
+def qwen35_0_8b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-0.8B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("0.8B"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-3),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="selective",
+ ),
+ )
+
+
+def qwen35_2b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-2B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("2B"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-3),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="selective",
+ ),
+ )
+
+
+def qwen35_4b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-4B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("4B"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-4),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="full",
+ ),
+ )
+
+
+def qwen35_9b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-9B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("9B"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-4),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ tensor_parallel_degree=2,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="full",
+ ),
+ )
+
+
+def qwen35_27b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-27B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("27B"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-4),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ tensor_parallel_degree=4,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="full",
+ ),
+ )
+
+
+def qwen35_35b_a3b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-35B-A3B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("35B-A3B", moe_comm_backend="standard"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-4),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ tensor_parallel_degree=2,
+ expert_parallel_degree=8,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="full",
+ ),
+ )
+
+
+def qwen35_122b_a10b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-122B-A10B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("122B-A10B", moe_comm_backend="standard"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-4),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ tensor_parallel_degree=4,
+ expert_parallel_degree=8,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="full",
+ ),
+ )
+
+
+def qwen35_397b_a17b() -> Trainer.Config:
+ return Trainer.Config(
+ loss=ChunkedCELoss.Config(),
+ hf_assets_path="./assets/hf/Qwen3.5-397B-A17B",
+ tokenizer=MultiModalTokenizer.Config(**QWEN3_5_SPECIAL_TOKENS),
+ model_spec=model_registry("397B-A17B", moe_comm_backend="standard"),
+ dataloader=_dataloader("cc12m"),
+ optimizer=default_adamw(lr=5e-4),
+ lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
+ training=TrainingConfig(
+ local_batch_size=4,
+ seq_len=4096,
+ steps=1000,
+ ),
+ parallelism=ParallelismConfig(
+ data_parallel_shard_degree=-1,
+ tensor_parallel_degree=8,
+ expert_parallel_degree=16,
+ ),
+ checkpoint=CheckpointManager.Config(
+ interval=500,
+ last_save_model_only=False,
+ ),
+ activation_checkpoint=ActivationCheckpointConfig(
+ mode="full",
+ ),
+ )
diff --git a/torchtitan/models/qwen3_5/model.py b/torchtitan/models/qwen3_5/model.py
new file mode 100644
index 0000000000..37879019b0
--- /dev/null
+++ b/torchtitan/models/qwen3_5/model.py
@@ -0,0 +1,779 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+
+from fla.ops.gated_delta_rule import (
+ chunk_gated_delta_rule as _fla_chunk_gated_delta_rule,
+ fused_recurrent_gated_delta_rule as _fla_fused_recurrent_gated_delta_rule,
+)
+from torch import nn
+from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental import local_map
+
+from torchtitan.models.common import Conv1d, FeedForward, Linear
+from torchtitan.models.common.attention import AttentionMasksType, BaseAttention
+from torchtitan.models.common.decoder import Decoder
+from torchtitan.models.utils import get_moe_model_nparams_and_flops
+from torchtitan.protocols.module import Module
+
+from .rope import MRoPE
+from .sharding import set_qwen35_sharding_config
+from .vision_encoder import Qwen35VisionEncoder
+
+
+def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
+ """L2 norm using rsqrt(sum(x²) + eps), not x/max(norm, eps) like F.normalize, to match FLA kernel."""
+ return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+
+
+def _torch_native_gated_delta(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+) -> torch.Tensor:
+ """Standalone math reference for the gated delta rule recurrence.
+
+ Sequential O(seqlen) loop — use FLA kernels for GPU efficiency.
+
+ Args:
+ q, k: (bs, seqlen, n_heads, key_head_dim)
+ v: (bs, seqlen, n_heads, value_head_dim)
+ g: (bs, seqlen, n_heads) — log-space decay, always negative
+ beta: (bs, seqlen, n_heads) — update gate ∈ (0, 1)
+
+ Returns:
+ output: (bs, seqlen, n_heads, value_head_dim)
+ """
+ B, L, H, D_k = q.shape
+ D_v = v.shape[-1]
+ dtype = q.dtype
+
+ # Upcast to float32 — recurrence accumulates over seqlen steps
+ q = _l2norm(q.float(), dim=-1) * (D_k**-0.5)
+ k = _l2norm(k.float(), dim=-1)
+ v, g, beta = v.float(), g.float(), beta.float()
+
+ output = torch.zeros(B, L, H, D_v, dtype=torch.float32, device=q.device)
+ state = torch.zeros(B, H, D_k, D_v, dtype=torch.float32, device=q.device)
+
+ for t in range(L):
+ q_t = q[:, t]
+ k_t = k[:, t]
+ v_t = v[:, t]
+ g_t = g[:, t].exp().unsqueeze(-1).unsqueeze(-1)
+ b_t = beta[:, t].unsqueeze(-1)
+
+ state = state * g_t
+ kv_mem = torch.einsum("bhkv,bhk->bhv", state, k_t)
+ delta = (v_t - kv_mem) * b_t
+ state = state + torch.einsum("bhk,bhv->bhkv", k_t, delta)
+ output[:, t] = torch.einsum("bhkv,bhk->bhv", state, q_t)
+
+ return output.to(dtype)
+
+
+class SharedExperts(FeedForward):
+ """Qwen3.5 shared expert: SwiGLU FFN with a per-token sigmoid gate.
+
+ The output is ``sigmoid(gate(x)) * ffn(x)``. Inherits ``w1/w2/w3`` from
+ FeedForward so weight FQNs are unchanged. This gate is specific to
+ Qwen3.5; other models use a plain ``FeedForward`` shared expert.
+ """
+
+ @dataclass(kw_only=True, slots=True)
+ class Config(FeedForward.Config):
+ gate: Linear.Config
+
+ def __init__(self, config: Config):
+ super().__init__(config)
+ self.gate = config.gate.build()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = super().forward(x)
+ return torch.sigmoid(self.gate(x)) * out
+
+
+class OffsetRMSNorm(Module):
+ """RMSNorm with offset: ``(1 + weight) * norm(x)``.
+
+ Weight is zero-initialized so the norm starts as identity-scaled.
+ """
+
+ @dataclass(kw_only=True, slots=True)
+ class Config(Module.Config):
+ dim: int
+ eps: float = 1e-6
+
+ def __init__(self, config: Config):
+ super().__init__()
+ self.eps = config.eps
+ self.weight = nn.Parameter(torch.empty(config.dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Upcast to float32 for numerical stability in pow/rsqrt
+ input_dtype = x.dtype
+ x = x.float()
+ variance = x.pow(2).mean(-1, keepdim=True)
+ x = x * torch.rsqrt(variance + self.eps)
+ return ((1.0 + self.weight.float()) * x).to(input_dtype)
+
+
+class RMSNormGated(Module):
+ """Gated RMSNorm: ``silu(gate) * weight * norm(x)``.
+
+ Takes ``(x, gate)`` separately. Weight is ones-initialized.
+ """
+
+ @dataclass(kw_only=True, slots=True)
+ class Config(Module.Config):
+ dim: int
+ eps: float = 1e-6
+
+ def __init__(self, config: Config):
+ super().__init__()
+ self.eps = config.eps
+ self.weight = nn.Parameter(torch.empty(config.dim))
+
+ def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+ # Upcast to float32 for numerical stability in pow/rsqrt
+ input_dtype = x.dtype
+ x = x.float()
+ variance = x.pow(2).mean(-1, keepdim=True)
+ x = x * torch.rsqrt(variance + self.eps)
+ x = (self.weight.float() * x).to(input_dtype)
+ x = x * F.silu(gate.float())
+ return x.to(input_dtype)
+
+
+class GatedDeltaKernel(Module):
+ """Stateless dispatch to FLA kernel or pure-torch fallback.
+
+ Provides a module boundary for the sharding code to wrap forward with
+ DTensor→local conversion — same pattern as FlexAttention. Handles Q/K
+ head expansion for grouped linear attention internally so that
+ repeat_interleave runs on local tensors under TP.
+ """
+
+ @dataclass(kw_only=True, slots=True)
+ class Config(Module.Config):
+ # "fla_chunked": parallel within chunks, fast for training (default)
+ # "fla_fused_recurrent": token-by-token, lower memory for long sequences
+ # "torch_native": pure-Python reference, for numerical testing only
+ backend: Literal[
+ "fla_chunked", "fla_fused_recurrent", "torch_native"
+ ] = "fla_chunked"
+
+ def __init__(self, config: Config):
+ super().__init__()
+ self.backend = config.backend
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ ) -> torch.Tensor:
+ # Expand Q/K heads to match V when n_value_heads > n_key_heads
+ if q.shape[2] != v.shape[2]:
+ assert v.shape[2] % q.shape[2] == 0
+ repeat = v.shape[2] // q.shape[2]
+ q = q.repeat_interleave(repeat, dim=2)
+ k = k.repeat_interleave(repeat, dim=2)
+
+ if self.backend == "torch_native":
+ return _torch_native_gated_delta(q, k, v, g, beta)
+
+ if self.backend == "fla_chunked":
+ result = _fla_chunk_gated_delta_rule(
+ q,
+ k,
+ v,
+ g,
+ beta,
+ use_qk_l2norm_in_kernel=True,
+ )
+ elif self.backend == "fla_fused_recurrent":
+ result = _fla_fused_recurrent_gated_delta_rule(
+ q,
+ k,
+ v,
+ g,
+ beta=beta,
+ use_qk_l2norm_in_kernel=True,
+ )
+ else:
+ raise ValueError(
+ f"Unknown fla_backend '{self.backend}'. "
+ "Valid: 'fla_chunked', 'fla_fused_recurrent', 'torch_native'."
+ )
+
+ # FLA kernels return (output, final_state); we only need output
+ return result[0]
+
+
+class GatedDeltaNet(Module):
+    """Gated DeltaNet linear attention.
+
+    Uses recurrent state + gated delta rule instead of softmax attention.
+    No RoPE, no attention masks, different head structure from standard
+    attention.
+
+    Data flow in ``forward``: six input projections (q, k, v, z, a, b); a
+    causal depthwise convolution + SiLU on q, k and v; gating signals ``g``
+    and ``beta`` derived from a and b; the gated delta rule kernel; a gated
+    RMSNorm using z as the gate; and a final output projection.
+    """
+
+    @dataclass(kw_only=True, slots=True)
+    class Config(Module.Config):
+        # Per-head channel counts used to reshape the flat q/k and v/z
+        # projections into (bs, seqlen, heads, head_dim).
+        key_head_dim: int
+        value_head_dim: int
+        # Kernel width of the causal convolutions; the input is left-padded
+        # by ``conv_kernel_size - 1`` before each convolution.
+        conv_kernel_size: int = 4
+
+        # Sub-module configs
+        in_proj_q: Linear.Config
+        in_proj_k: Linear.Config
+        in_proj_v: Linear.Config
+        in_proj_z: Linear.Config
+        in_proj_a: Linear.Config
+        in_proj_b: Linear.Config
+        conv_q: Conv1d.Config
+        conv_k: Conv1d.Config
+        conv_v: Conv1d.Config
+        kernel: GatedDeltaKernel.Config
+        norm: RMSNormGated.Config
+        out_proj: Linear.Config
+
+    def __init__(self, config: Config):
+        """Build all sub-modules and allocate the per-head decay parameters."""
+        super().__init__()
+        self.key_head_dim = config.key_head_dim
+        self.value_head_dim = config.value_head_dim
+        self.conv_kernel_size = config.conv_kernel_size
+
+        # Total value width; used below to derive the number of value heads.
+        value_dim = config.in_proj_v.out_features
+
+        self.in_proj_q = config.in_proj_q.build()
+        self.in_proj_k = config.in_proj_k.build()
+        self.in_proj_v = config.in_proj_v.build()
+        self.in_proj_z = config.in_proj_z.build()
+        self.in_proj_a = config.in_proj_a.build()
+        self.in_proj_b = config.in_proj_b.build()
+
+        self.conv_q = config.conv_q.build()
+        self.conv_k = config.conv_k.build()
+        self.conv_v = config.conv_v.build()
+
+        # One A_log and one dt_bias entry per value head. Both are allocated
+        # uninitialized here; their values are expected to be set elsewhere.
+        n_value_heads = value_dim // config.value_head_dim
+        self.A_log = nn.Parameter(torch.empty(n_value_heads))
+        self.dt_bias = nn.Parameter(torch.empty(n_value_heads))
+
+        self.kernel = config.kernel.build()
+        self.norm = config.norm.build()
+        self.out_proj = config.out_proj.build()
+
+    def _causal_conv(self, x: torch.Tensor, conv: nn.Module) -> torch.Tensor:
+        """Apply a left-padded (causal) convolution followed by SiLU.
+
+        ``x`` is transposed so dim 1 and dim 2 swap, padded on the left of the
+        last dimension by ``conv_kernel_size - 1``, convolved, activated with
+        SiLU and transposed back. For DTensor inputs the convolution runs on
+        local shards through ``local_map`` instead of calling ``conv`` directly.
+        """
+        x = F.pad(x.transpose(1, 2), [self.conv_kernel_size - 1, 0])
+        if isinstance(x, DTensor):
+            # TODO: Remove once the DTensor Conv1d dispatch fix for sharded
+            # groups lands in a released torch. local_map runs the conv on
+            # local shards (channel-sharded input + Shard(0) weight) and
+            # restores DTensor-ness, with explicit gradient placements.
+            x_plc = x.placements
+            w = conv.weight
+            w_plc = w.placements  # pyrefly: ignore [missing-attribute]
+
+            def _conv(x_local: torch.Tensor, w_local: torch.Tensor) -> torch.Tensor:
+                # groups == local out-channels (depthwise, channel-sharded)
+                # Bias is passed as None; stride/padding/dilation are taken
+                # from the wrapped conv module.
+                # pyrefly: ignore [no-matching-overload]
+                return F.conv1d(
+                    x_local,
+                    w_local,
+                    None,
+                    conv.stride,
+                    conv.padding,
+                    conv.dilation,
+                    w_local.size(0),
+                )
+
+            # Output and input-gradient placements mirror the input placements.
+            conv_dt = local_map(
+                _conv,
+                out_placements=(x_plc,),
+                in_placements=(x_plc, w_plc),
+                in_grad_placements=(x_plc, w_plc),
+                device_mesh=x.device_mesh,
+            )
+            x = conv_dt(x, w)  # pyrefly: ignore
+        else:
+            x = conv(x)
+        return F.silu(x).transpose(1, 2)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Run gated delta-rule linear attention over ``x``.
+
+        Args:
+            x: Input with three dimensions, unpacked as (bs, seqlen, _).
+
+        Returns:
+            The result of ``out_proj`` applied to the gated, normalized kernel
+            output flattened to (bs, seqlen, -1).
+        """
+        bs, seqlen, _ = x.shape
+
+        # Shapes:
+        #   xq, xk: (bs, seqlen, n_key_heads * key_head_dim)
+        #   xv, xz: (bs, seqlen, n_value_heads * value_head_dim)
+        #   xa, xb: (bs, seqlen, n_value_heads)
+        # Only q, k and v go through the causal conv; z, a and b do not.
+        xq = self._causal_conv(self.in_proj_q(x), self.conv_q)
+        xk = self._causal_conv(self.in_proj_k(x), self.conv_k)
+        xv = self._causal_conv(self.in_proj_v(x), self.conv_v)
+        xz = self.in_proj_z(x)
+        xa = self.in_proj_a(x)
+        xb = self.in_proj_b(x)
+
+        # Split the flat channel dimension into (heads, head_dim).
+        xq = xq.view(bs, seqlen, -1, self.key_head_dim)
+        xk = xk.view(bs, seqlen, -1, self.key_head_dim)
+        xv = xv.view(bs, seqlen, -1, self.value_head_dim)
+
+        # Gating signals, shape (bs, seqlen, n_value_heads):
+        #   g: decay rate per head, always negative
+        #      (-exp(A_log) * softplus(.) is a negative times a non-negative)
+        #   beta: update gate ∈ (0, 1)
+        g = -torch.exp(self.A_log.float()) * F.softplus(xa.float() + self.dt_bias)
+        beta = torch.sigmoid(xb)
+
+        output = self.kernel(xq, xk, xv, g, beta)
+
+        # z is reshaped per value head and used as the gate of the gated norm.
+        xz = xz.view(bs, seqlen, -1, self.value_head_dim)
+        output = self.norm(output, xz)
+
+        # Merge heads back into one channel dimension before projecting out.
+        output = output.reshape(bs, seqlen, -1)
+        return self.out_proj(output)
+
+
+class Qwen35Attention(BaseAttention):
+ """Full attention with output gating and partial RoPE for Qwen3.5.
+
+ Differences from GQAttention:
+ - wq is 2x wider: produces both query and sigmoid gate
+ - Partial RoPE: only first ``rotary_dim`` elements get RoPE
+ - Output gating: ``attn_output * sigmoid(gate)`` before ``wo``
+ - QK norm uses OffsetRMSNorm
+
+ Uses separate ``wq``/``wk``/``wv`` instead of the common fused ``qkv_linear``
+ (so this subclasses ``BaseAttention``, not ``GQAttention``): the 2x-wide,
+ gated ``wq`` doesn't fit a fused QKV projection that TP-shards by head.
+ """
+
+ @dataclass(kw_only=True, slots=True)
+ class Config(BaseAttention.Config):
+ n_heads: int
+ n_kv_heads: int
+ head_dim: int
+ rotary_dim: int
+ rope: MRoPE.Config
+ wq: Linear.Config
+ wk: Linear.Config
+ wv: Linear.Config
+ wo: Linear.Config
+ q_norm: OffsetRMSNorm.Config
+ k_norm: OffsetRMSNorm.Config
+ inner_attention: Module.Config
+
+ def __init__(self, config: Config):
+ super().__init__()
+ self.n_heads = config.n_heads
+ self.n_kv_heads = config.n_kv_heads
+ self.head_dim = config.head_dim
+ self.rotary_dim = config.rotary_dim
+ self.enable_gqa = self.n_heads > self.n_kv_heads
+
+ self.wq = config.wq.build()
+ self.wk = config.wk.build()
+ self.wv = config.wv.build()
+ self.wo = config.wo.build()
+
+ self.rope = config.rope.build()
+
+ self.q_norm = config.q_norm.build()
+ self.k_norm = config.k_norm.build()
+
+ self.scaling = self.head_dim**-0.5
+
+ self.inner_attention = config.inner_attention.build()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_masks: AttentionMasksType | None,
+ positions: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ bs, seqlen, _ = x.shape
+
+ # wq is 2x wider: produces query + gate
+ xq_gate = self.wq(x).view(bs, seqlen, -1, self.head_dim * 2)
+ xq, gate = xq_gate.chunk(2, dim=-1)
+ xk = self.wk(x).view(bs, seqlen, -1, self.head_dim)
+ xv = self.wv(x).view(bs, seqlen, -1, self.head_dim)
+
+ # QK norm (before RoPE)
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+
+ # Partial RoPE: only first rotary_dim elements get positional encoding
+ assert self.rotary_dim <= self.head_dim
+ xq_rot, xq_pass = xq[..., : self.rotary_dim], xq[..., self.rotary_dim :]
+ xk_rot, xk_pass = xk[..., : self.rotary_dim], xk[..., self.rotary_dim :]
+ xq_rot, xk_rot = self.rope(xq_rot, xk_rot, positions)
+ xq = torch.cat([xq_rot, xq_pass], dim=-1)
+ xk = torch.cat([xk_rot, xk_pass], dim=-1)
+
+ output = self.inner_attention(
+ xq,
+ xk,
+ xv,
+ attention_masks=attention_masks,
+ scale=self.scaling,
+ enable_gqa=self.enable_gqa,
+ ).contiguous()
+
+ # Output gating
+ output = output * torch.sigmoid(gate)
+ output = output.view(bs, seqlen, -1)
+ return self.wo(output)
+
+
+class Qwen35TransformerBlock(Module):
+ """Hybrid transformer block for Qwen3.5.
+
+ Each layer uses either full attention (Qwen35Attention) or linear
+ attention (GatedDeltaNet), determined by which config is provided.
+ Both types share the same FFN/MoE structure.
+ """
+
+ @dataclass(kw_only=True, slots=True)
+ class Config(Module.Config):
+ attention: Qwen35Attention.Config | None = None
+ delta_net: GatedDeltaNet.Config | None = None
+ feed_forward: Module.Config | None = None
+ moe: Module.Config | None = None
+ attention_norm: OffsetRMSNorm.Config
+ ffn_norm: OffsetRMSNorm.Config
+
+ def __init__(self, config: Config):
+ super().__init__()
+ self.full_attn = config.attention is not None
+
+ if self.full_attn:
+ self.attn = config.attention.build() # pyrefly: ignore [missing-attribute]
+ else:
+ assert config.delta_net is not None
+ self.attn = config.delta_net.build()
+
+ self.moe_enabled = config.moe is not None
+ if self.moe_enabled:
+ # pyrefly: ignore [missing-attribute]
+ self.moe = config.moe.build()
+ else:
+ assert config.feed_forward is not None
+ self.feed_forward = config.feed_forward.build()
+
+ self.attention_norm = config.attention_norm.build()
+ self.ffn_norm = config.ffn_norm.build()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_masks: AttentionMasksType | None,
+ positions: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ h = self.attention_norm(x)
+ if self.full_attn:
+ h = self.attn(h, attention_masks, positions)
+ else:
+ h = self.attn(h)
+ x = x + h
+
+ h = self.ffn_norm(x)
+ if self.moe_enabled:
+ x = x + self.moe(h)
+ else:
+ x = x + self.feed_forward(h)
+ return x
+
+
+class Qwen35Model(Decoder):
+ """Qwen3.5: Multimodal model with hybrid attention.
+
+ Combines a hybrid decoder (GatedDeltaNet linear attention + full
+ attention with output gating and partial RoPE) with a Vision
+ Transformer encoder for multimodal understanding.
+
+ Key architectural features:
+ - Hybrid attention: 75% GatedDeltaNet (linear) + 25% full attention
+ - Output gating on full attention: ``attn_out * sigmoid(gate)``
+ - Partial RoPE: only first ``rotary_dim`` elements get positional encoding
+ - OffsetRMSNorm: ``(1 + weight) * norm(x)`` with zero-init weight
+ - MRoPE: 3D (temporal/height/width) position IDs for multimodal batches;
+ text batches use the plain 1D positions
+ - MoE variant: routed experts + shared expert with sigmoid gate
+
+ MRoPE positions (``mrope_positions``, shape ``(batch, seq, 3)``) are built by
+ the dataloader and forwarded to every pipeline stage, so RoPE stays consistent
+ across stages even though the raw vision inputs (``pixel_values``/``grid_thw``)
+ only reach the first stage. Text batches carry no ``mrope_positions`` and use
+ the 2D ``positions`` instead.
+
+ Forward pass flow::
+
+ forward(tokens, pixel_values, grid_thw, mrope_positions, ...)
+ │
+ ├─ _prepare_multimodal_embeds
+ │ ├─ tok_embeddings(tokens) → text embeddings
+ │ ├─ _get_vision_embeds(pixel_values) → vision embeddings
+ │ │ └─ vision_encoder(pixel_values) → merge patches
+ │ ├─ _get_vision_positions → locate vision regions
+ │ └─ _scatter_vision_embeds → scatter into text sequence
+ │
+ └─ transformer layers (hybrid), each given (mrope_positions or positions)
+ └─ for each layer:
+ ├─ full attention (every Nth): QK-norm → partial RoPE → SDPA → gate
+ │ (the layer's MRoPE builds the cos/sin cache from positions)
+ └─ GatedDeltaNet (others): Conv1d → gated delta rule → gated norm
+ """
+
+    @dataclass(kw_only=True, slots=True)
+    class Config(Decoder.Config):
+        """Decoder configuration extended with the vision encoder's config."""
+
+        # Config from which the model's ``__init__`` builds the vision encoder.
+        vision_encoder: Qwen35VisionEncoder.Config
+
+        def update_from_config(
+            self,
+            *,
+            config,
+            **kwargs,
+        ) -> None:
+            """Run the base-class update, validate TP, then set shardings.
+
+            Args:
+                config: Job-level config; only ``config.parallelism`` is read
+                    directly here.
+                **kwargs: Forwarded unchanged to
+                    ``Decoder.Config.update_from_config``.
+
+            Raises:
+                ValueError: If tensor parallelism is enabled and its degree
+                    does not divide both GatedDeltaNet head counts.
+            """
+            # NOTE(review): the base class is called explicitly rather than via
+            # zero-arg ``super()`` -- presumably because ``slots=True``
+            # dataclasses recreate the class; confirm before changing.
+            Decoder.Config.update_from_config(self, config=config, **kwargs)
+            parallelism = config.parallelism
+            tp = parallelism.tensor_parallel_degree
+
+            # Only the first layer that carries a GatedDeltaNet config is
+            # inspected, and only when tensor parallelism is actually enabled.
+            delta_net_cfg = None
+            if tp > 1:
+                for layer_cfg in self.layers:
+                    if layer_cfg.delta_net is not None:
+                        delta_net_cfg = layer_cfg.delta_net
+                        break
+
+            if delta_net_cfg is not None:
+                n_key_heads = (
+                    delta_net_cfg.in_proj_q.out_features
+                    // delta_net_cfg.key_head_dim
+                )
+                n_value_heads = (
+                    delta_net_cfg.in_proj_v.out_features
+                    // delta_net_cfg.value_head_dim
+                )
+                heads_divisible = n_key_heads % tp == 0 and n_value_heads % tp == 0
+                if not heads_divisible:
+                    raise ValueError(
+                        f"tensor_parallel_degree ({tp}) must divide "
+                        f"n_key_heads ({n_key_heads}) and "
+                        f"n_value_heads ({n_value_heads})."
+                    )
+
+            set_qwen35_sharding_config(
+                self,
+                loss_parallel=not parallelism.disable_loss_parallel,
+                enable_ep=parallelism.expert_parallel_degree > 1,
+            )
+
+        def get_nparams_and_flops(
+            self, model: nn.Module, seq_len: int
+        ) -> tuple[int, int]:
+            """Return the result of ``get_moe_model_nparams_and_flops``.
+
+            Head count and head dim are read from ``self.first_attention``.
+            NOTE(review): ``2 * head_dim`` is presumably the combined
+            query/key and value head dims expected by the helper -- confirm
+            against its signature.
+            """
+            first_attn = self.first_attention
+            return get_moe_model_nparams_and_flops(
+                self,
+                model,
+                # pyrefly: ignore [missing-attribute]
+                first_attn.n_heads,
+                # pyrefly: ignore [missing-attribute]
+                2 * first_attn.head_dim,
+                seq_len,
+            )
+
+    def __init__(self, config: Config):
+        """Initialise the parent module from ``config`` and build the vision encoder.
+
+        Args:
+            config: Model config; ``config.vision_encoder`` supplies both the
+                encoder builder and ``spatial_merge_size``.
+        """
+        super().__init__(config)
+
+        vision_cfg = config.vision_encoder
+        self.vision_encoder = vision_cfg.build()
+        self.spatial_merge_size = vision_cfg.spatial_merge_size
+
+    def _get_vision_positions(
+        self,
+        tokens: torch.Tensor,
+        num_tokens_per_item: torch.Tensor,
+        vision_token_id: int,
+    ) -> list[tuple[int, int, int, int]]:
+        """Compute (item_idx, sample_idx, vision_start, n_tokens) for each vision item.
+
+        Finds where each contiguous run of vision placeholder tokens starts.
+        Runs are detected per sample: a run that ends one row and a run that
+        begins the next row are two separate regions, never one merged run.
+        Regions are matched to items in row-major order; surplus regions
+        beyond ``num_tokens_per_item`` are ignored.
+
+        Args:
+            tokens: Token IDs (batch, seq_len)
+            num_tokens_per_item: (num_items,) actual tokens per vision item
+            vision_token_id: Placeholder token ID
+
+        Returns:
+            List of (item_idx, sample_idx, vision_start, n_tokens) tuples
+
+        Raises:
+            ValueError: If ``tokens`` holds fewer placeholder runs than there
+                are vision items.
+        """
+        vision_mask = tokens == vision_token_id
+        # A run starts at a placeholder whose left neighbour *in the same row*
+        # is not a placeholder; column 0 never has a left neighbour.
+        prev_mask = torch.zeros_like(vision_mask)
+        prev_mask[:, 1:] = vision_mask[:, :-1]
+        # ``nonzero`` yields (row, col) pairs in row-major order; one host
+        # transfer each instead of one ``.item()`` sync per vision item.
+        region_starts = torch.nonzero(vision_mask & ~prev_mask).tolist()
+        token_counts = num_tokens_per_item.tolist()
+
+        if len(region_starts) < len(token_counts):
+            raise ValueError(
+                f"Found {len(region_starts)} placeholder region(s) for token id "
+                f"{vision_token_id} but {len(token_counts)} vision item(s)."
+            )
+
+        return [
+            (item_idx, sample_idx, vision_start, int(n_tokens))
+            for item_idx, ((sample_idx, vision_start), n_tokens) in enumerate(
+                zip(region_starts, token_counts)
+            )
+        ]
+
+    def _get_vision_embeds(
+        self,
+        pixel_values: torch.Tensor,
+        *,
+        grid_thw: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode vision patches and report how many tokens each item yields.
+
+        Args:
+            pixel_values: Padded patches (num_items, max_num_patch, patch_dim)
+            grid_thw: Grid dimensions (num_items, 3) for [t, h, w]
+
+        Returns:
+            merged_embeds: (num_items, max_tokens, dim) padded vision embeddings
+            num_tokens_per_item: (num_items,) actual token count per item
+        """
+        encoder = self.vision_encoder
+        # Cast the input to the dtype of the patch-embedding weight.
+        encoder_dtype = encoder.patch_embed.weight.dtype
+        merged_embeds = encoder(pixel_values.to(encoder_dtype), grid_thw=grid_thw)
+
+        # t * h * w patches per item, floor-divided by the encoder's merge unit.
+        patches_per_item = grid_thw.prod(-1)
+        num_tokens_per_item = patches_per_item // encoder.spatial_merge_unit
+
+        return merged_embeds, num_tokens_per_item
+
+    def _scatter_vision_embeds(
+        self,
+        inputs_embeds: torch.Tensor,
+        *,
+        merged_embeds: torch.Tensor,
+        vision_positions: list[tuple[int, int, int, int]],
+    ) -> torch.Tensor:
+        """Overwrite placeholder spans of the text embeddings with vision embeddings.
+
+        ``inputs_embeds`` is written in place and the same tensor is returned.
+        Only the first ``n_tokens`` rows of each item's padded embeddings are
+        copied.
+
+        Args:
+            inputs_embeds: Text embeddings (batch, seq_len, dim)
+            merged_embeds: Padded vision embeddings (num_items, max_tokens, dim)
+            vision_positions: List of (item_idx, sample_idx, vision_start, n_tokens)
+
+        Returns:
+            Updated embeddings
+        """
+        for item_idx, sample_idx, vision_start, n_tokens in vision_positions:
+            target_span = slice(vision_start, vision_start + n_tokens)
+            item_embeds = merged_embeds[item_idx, :n_tokens, :]
+            inputs_embeds[sample_idx, target_span, :] = item_embeds
+        return inputs_embeds
+
+    def _prepare_multimodal_embeds(
+        self,
+        tokens: torch.Tensor,
+        *,
+        pixel_values: torch.Tensor | None,
+        pixel_values_videos: torch.Tensor | None,
+        grid_thw: torch.Tensor | None,
+        grid_thw_videos: torch.Tensor | None,
+        special_tokens: dict[str, int],
+    ) -> torch.Tensor:
+        """Embed tokens, then splice image and video embeddings over placeholders.
+
+        Images are handled first, then videos. A modality is skipped unless
+        both its patches and its grid are given. Both placeholder IDs are
+        looked up in ``special_tokens`` up front, whichever modalities are
+        present.
+
+        Args:
+            tokens: Input token IDs (batch_size, seq_len)
+            pixel_values: Image patches or None
+            pixel_values_videos: Video patches or None
+            grid_thw: Grid dimensions for images or None
+            grid_thw_videos: Grid dimensions for videos or None
+            special_tokens: Special token definitions; must hold ``"image_id"``
+                and ``"video_id"``
+
+        Returns:
+            (batch, seq_len, dim) embeddings with vision tokens scattered in
+        """
+        image_token_id = special_tokens["image_id"]
+        video_token_id = special_tokens["video_id"]
+
+        if self.tok_embeddings is not None:
+            inputs_embeds = self.tok_embeddings(tokens)
+        else:
+            inputs_embeds = tokens
+
+        # (patches, grid, placeholder id) per modality, in processing order.
+        modalities = (
+            (pixel_values, grid_thw, image_token_id),
+            (pixel_values_videos, grid_thw_videos, video_token_id),
+        )
+        for patches, grid, placeholder_id in modalities:
+            if patches is None or grid is None:
+                continue
+            merged_embeds, num_tokens = self._get_vision_embeds(
+                patches, grid_thw=grid
+            )
+            vision_positions = self._get_vision_positions(
+                tokens, num_tokens, placeholder_id
+            )
+            if not vision_positions:
+                continue
+            inputs_embeds = self._scatter_vision_embeds(
+                inputs_embeds,
+                merged_embeds=merged_embeds,
+                vision_positions=vision_positions,
+            )
+
+        return inputs_embeds
+
+    def forward(  # pyrefly: ignore [bad-override]
+        self,
+        tokens: torch.Tensor,
+        *,
+        pixel_values: torch.Tensor | None = None,
+        pixel_values_videos: torch.Tensor | None = None,
+        grid_thw: torch.Tensor | None = None,
+        grid_thw_videos: torch.Tensor | None = None,
+        attention_masks: AttentionMasksType | None = None,
+        positions: torch.Tensor | None = None,
+        mrope_positions: torch.Tensor | None = None,
+        special_tokens: dict[str, int] | None = None,
+    ):
+        """Run embedding (when present), every layer, then norm and lm_head.
+
+        When ``tok_embeddings`` is ``None`` the ``tokens`` argument is fed to
+        the layers as-is and the vision inputs are not used. Each layer
+        receives ``mrope_positions`` when given, otherwise ``positions``; at
+        least one of the two must be provided.
+
+        Returns:
+            The ``lm_head`` output, or the (optionally normed) hidden states
+            when ``_skip_lm_head`` is set or ``lm_head`` is ``None``.
+        """
+        if self.tok_embeddings is None:
+            hidden = tokens
+        else:
+            hidden = self._prepare_multimodal_embeds(
+                tokens,
+                pixel_values=pixel_values,
+                pixel_values_videos=pixel_values_videos,
+                grid_thw=grid_thw,
+                grid_thw_videos=grid_thw_videos,
+                special_tokens=special_tokens,  # pyrefly: ignore [bad-argument-type]
+            )
+
+        # Prefer 3D MRoPE positions (multimodal); fall back to text positions.
+        rope_positions = positions if mrope_positions is None else mrope_positions
+        assert rope_positions is not None
+        for block in self.layers.values():
+            hidden = block(hidden, attention_masks, rope_positions)
+
+        if self.norm is not None:
+            hidden = self.norm(hidden)
+        if self._skip_lm_head or self.lm_head is None:
+            return hidden
+        return self.lm_head(hidden)
diff --git a/torchtitan/models/qwen3_5/parallelize.py b/torchtitan/models/qwen3_5/parallelize.py
new file mode 100644
index 0000000000..5b1de894e3
--- /dev/null
+++ b/torchtitan/models/qwen3_5/parallelize.py
@@ -0,0 +1,213 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Parallelization utilities for Qwen3.5.
+
+This module applies PT-D parallelisms and various training techniques
+(activation checkpointing, compile, FSDP) to the Qwen3.5 model.
+"""
+
+import torch
+import torch.nn as nn
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.fsdp import MixedPrecisionPolicy
+
+from torchtitan.config import (
+ ActivationCheckpointConfig,
+ CompileConfig,
+ ParallelismConfig,
+ TORCH_DTYPE_MAP,
+ TrainingConfig,
+)
+
+from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.activation_checkpoint import apply_ac
+from torchtitan.distributed.compile import apply_compile
+from torchtitan.distributed.fsdp import (
+ apply_fsdp_to_decoder,
+ get_fsdp_reshard_after_forward_policy,
+)
+from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+
+
+def _apply_fsdp_to_vision_encoder(
+    vision_encoder: nn.Module,
+    dp_mesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    reshard_after_forward_policy: str = "default",
+    pp_enabled: bool = False,
+):
+    """Wrap the whole vision encoder in a single ``fully_shard`` call.
+
+    The encoder is sharded as one unit rather than per layer: it is small
+    relative to the decoder, so one AllGather for all vision params is more
+    efficient. Must be called before FSDP is applied to the decoder.
+
+    Args:
+        vision_encoder: Module to shard.
+        dp_mesh: Mesh passed to ``fully_shard``.
+        param_dtype: ``param_dtype`` of the mixed-precision policy.
+        reduce_dtype: ``reduce_dtype`` of the mixed-precision policy.
+        reshard_after_forward_policy: Policy name resolved by
+            ``get_fsdp_reshard_after_forward_policy``.
+        pp_enabled: Forwarded to ``get_fsdp_reshard_after_forward_policy``.
+    """
+    fully_shard(
+        vision_encoder,
+        mesh=dp_mesh,
+        mp_policy=MixedPrecisionPolicy(
+            param_dtype=param_dtype, reduce_dtype=reduce_dtype
+        ),
+        reshard_after_forward=get_fsdp_reshard_after_forward_policy(
+            reshard_after_forward_policy, pp_enabled=pp_enabled
+        ),
+    )
+
+
+def parallelize_qwen3_5(
+ model: nn.Module,
+ *,
+ parallel_dims: ParallelDims,
+ training: TrainingConfig,
+ parallelism: ParallelismConfig,
+ compile_config: CompileConfig,
+ ac_config: ActivationCheckpointConfig,
+ dump_folder: str,
+):
+ """
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
+ parallelism to the Qwen3.5 model.
+
+ Steps run in this order: TP/EP via ``model.parallelize``, activation
+ checkpointing, ``torch.compile``, FSDP on the vision encoder, then FSDP on
+ the decoder. Activation checkpointing, compile and FSDP are each applied to
+ ``model.vision_encoder`` separately when it is not ``None``.
+
+ Args:
+ model: Model to parallelize; modified and returned.
+ parallel_dims: Supplies the enabled-parallelism flags and device meshes.
+ training: Supplies mixed-precision dtypes and the CPU-offload flag.
+ parallelism: Supplies SPMD backend, async-TP and FSDP reshard settings.
+ compile_config: Compile settings; the model is compiled only when
+ ``enable`` is set and ``"model"`` is in ``components``.
+ ac_config: Activation checkpointing settings; skipped when ``mode`` is
+ ``"none"``.
+ dump_folder: Passed to ``apply_ac`` as ``base_folder``.
+
+ Returns:
+ The same ``model`` object.
+
+ Raises:
+ NotImplementedError: For the ``full_dtensor`` SPMD backend, or when
+ context parallelism is enabled.
+ RuntimeError: If async TP is requested without compiling the model.
+
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
+ the model must fit on GPU or CPU memory.
+ """
+ if parallelism.spmd_backend == "full_dtensor":
+ raise NotImplementedError("full_dtensor is not supported yet.")
+
+ # True only when compile is enabled AND targets the "model" component.
+ model_compile_enabled = (
+ compile_config.enable and "model" in compile_config.components
+ )
+
+ if parallel_dims.cp_enabled:
+ raise NotImplementedError(
+ "Context Parallel is not yet supported for Qwen3.5. "
+ "GatedDeltaNet (75% of layers) requires full-sequence allgather, "
+ "and multimodal CP needs vision scatter before CP sharding."
+ )
+
+ # TP and EP are both applied by the model's own ``parallelize`` method.
+ if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
+ if parallelism.enable_async_tensor_parallel and not model_compile_enabled:
+ raise RuntimeError("Async TP requires torch.compile")
+
+ # pyrefly: ignore [not-callable]
+ model.parallelize(parallel_dims)
+
+ if parallel_dims.tp_enabled:
+ maybe_enable_async_tp(parallelism, compile_config, parallel_dims.get_mesh("tp"))
+
+ # Activation checkpointing: once on the model, once on the vision encoder.
+ if ac_config.mode != "none":
+ apply_ac(
+ model,
+ ac_config,
+ model_compile_enabled=model_compile_enabled,
+ base_folder=dump_folder,
+ )
+ if model.vision_encoder is not None:
+ apply_ac(
+ # pyrefly: ignore [bad-argument-type]
+ model.vision_encoder,
+ ac_config,
+ model_compile_enabled=model_compile_enabled,
+ base_folder=dump_folder,
+ )
+
+ # Compilation follows the same model-then-vision-encoder pattern.
+ if model_compile_enabled:
+ apply_compile(model, compile_config)
+ if model.vision_encoder is not None:
+ # pyrefly: ignore [bad-argument-type]
+ apply_compile(model.vision_encoder, compile_config)
+
+ # The data-parallel mesh gains the "dp_replicate" dim only when enabled.
+ dp_mesh_names = (
+ ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
+ )
+ dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
+
+ # The vision encoder is sharded as one unit, before the decoder.
+ if model.vision_encoder is not None:
+ _apply_fsdp_to_vision_encoder(
+ model.vision_encoder, # pyrefly: ignore [bad-argument-type]
+ dp_mesh,
+ param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param],
+ reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce],
+ reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward,
+ pp_enabled=parallel_dims.pp_enabled,
+ )
+
+ # The expert mesh is looked up only under EP; ``get_optional_mesh`` is used,
+ # so it may still be ``None`` (NOTE(review): presumably when that mesh does
+ # not exist -- confirm).
+ edp_mesh = None
+ if parallel_dims.ep_enabled:
+ edp_mesh_names = (
+ ["dp_replicate", "efsdp"]
+ if parallel_dims.dp_replicate_enabled
+ else ["efsdp"]
+ )
+ edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names)
+
+ apply_fsdp_to_decoder(
+ model, # pyrefly: ignore [bad-argument-type]
+ dp_mesh,
+ param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param],
+ reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce],
+ pp_enabled=parallel_dims.pp_enabled,
+ cpu_offload=training.enable_cpu_offload,
+ reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward,
+ ep_degree=parallel_dims.ep,
+ edp_mesh=edp_mesh,
+ )
+
+ return model
+
+
+def pipeline_qwen3_5(
+    model: nn.Module,
+    *,
+    parallel_dims: ParallelDims,
+    parallelism: ParallelismConfig,
+    model_config,
+    **kwargs,
+):
+    """Pipeline-split the model with ``vision_encoder`` on the first stage.
+
+    When ``parallelism.module_fqns_per_model_part`` is unset, the LLM split
+    is generated here, ``"vision_encoder"`` is prepended to the first stage's
+    FQN list (the auto-generated split does not know about it), and the
+    result is passed on through a copy of ``parallelism``. A user-provided
+    split is forwarded untouched. All other arguments go straight to
+    ``pipeline_llm``.
+
+    Returns:
+        Whatever ``pipeline_llm`` returns.
+    """
+    import dataclasses
+
+    from torchtitan.distributed.pipeline_parallel import (
+        _generate_llm_fqn_per_model_part,
+        _get_pipeline_metadata,
+        pipeline_llm,
+    )
+
+    needs_auto_split = parallelism.module_fqns_per_model_part is None
+    if needs_auto_split:
+        pipeline_metadata = _get_pipeline_metadata(
+            parallel_dims, parallelism, model_config
+        )
+        num_virtual_stages, num_layers, input_weight, output_weight = (
+            pipeline_metadata
+        )
+        stage_fqns = _generate_llm_fqn_per_model_part(
+            num_virtual_stages, num_layers, input_weight, output_weight
+        )
+        # Vision encoder lives on the first stage alongside tok_embeddings. This
+        # adds load to stage 0 that the auto split doesn't model (input_weight
+        # only accounts for tok_embeddings); for a heavy vision encoder, bump
+        # parallelism.pipeline_parallel_first_stage_less_layers to rebalance.
+        if getattr(model, "vision_encoder", None) is not None:
+            stage_fqns[0].insert(0, "vision_encoder")
+        # ``replace`` builds a new config; the caller's object is not mutated.
+        parallelism = dataclasses.replace(
+            parallelism, module_fqns_per_model_part=stage_fqns
+        )
+
+    return pipeline_llm(
+        model,
+        parallel_dims=parallel_dims,
+        parallelism=parallelism,
+        model_config=model_config,
+        **kwargs,
+    )
diff --git a/torchtitan/models/qwen3_vl/requirements.txt b/torchtitan/models/qwen3_5/requirements.txt
similarity index 100%
rename from torchtitan/models/qwen3_vl/requirements.txt
rename to torchtitan/models/qwen3_5/requirements.txt
diff --git a/torchtitan/models/qwen3_vl/rope.py b/torchtitan/models/qwen3_5/rope.py
similarity index 57%
rename from torchtitan/models/qwen3_vl/rope.py
rename to torchtitan/models/qwen3_5/rope.py
index 50368130b7..e2577c29b6 100644
--- a/torchtitan/models/qwen3_vl/rope.py
+++ b/torchtitan/models/qwen3_5/rope.py
@@ -13,7 +13,14 @@
class MRoPE(CosSinRoPE):
- """Multi-dimensional RoPE for Qwen3-VL temporal/height/width positions."""
+ """Multi-dimensional RoPE for Qwen3.5 temporal/height/width positions.
+
+ Standard per-layer RoPE: each full-attention layer owns an ``MRoPE`` and
+ applies it through ``RoPE.forward`` -> ``_reshape_cache`` -> ``apply_rotary_emb``.
+ The only override is ``_reshape_cache``: for 3D ``(batch, seq, 3)`` MRoPE
+ positions it builds an interleaved cos/sin cache; for 2D ``(batch, seq)`` text
+ positions it falls back to the plain ``CosSinRoPE`` per-token lookup.
+ """
@dataclass(kw_only=True, slots=True)
class Config(CosSinRoPE.Config):
@@ -31,44 +38,56 @@ def _reshape_cache(
query: torch.Tensor,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
- """Build a position-specific cache for 3D MRoPE position IDs."""
+ """Build a query-broadcastable cos/sin cache.
+
+ Dispatches on position rank: 3D ``(batch, seq, 3)`` MRoPE positions take
+ the interleaved scatter; everything else (2D text positions or ``None``)
+ falls back to the plain ``CosSinRoPE`` lookup.
+ """
if positions is not None and positions.ndim == 3:
return self._compute_mrope_cache(positions)
return super()._reshape_cache(query, positions)
def _compute_mrope_cache(self, position_ids: torch.Tensor) -> torch.Tensor:
+ """Build the interleaved cos/sin cache for 3D MRoPE positions.
+
+ Args:
+ position_ids: ``(batch, seq, 3)`` T/H/W positions. Plain, or a DTensor
+ under TP matching the rope ``cache`` buffer's Replicate placement.
+
+ Returns:
+ ``(batch, seq, 1, dim * 2)`` cache, broadcastable to the
+ ``(batch, seq, n_heads, rotary_dim)`` query/key in ``apply_rotary_emb``.
+
+ The scatter runs on plain local tensors. Under TP the ``cache`` buffer is a
+ Replicate DTensor, so it is unwrapped to local here and the result is
+ re-distributed with the buffer's placements, yielding a DTensor that
+ composes with the sharded query/key without any manual wrapping in the
+ attention forward.
+ """
cfg = self.config
assert isinstance(cfg, MRoPE.Config)
- if position_ids.shape[0] != 3:
- raise ValueError(
- f"MRoPE position IDs must have shape (3, batch, seq), "
- f"got {tuple(position_ids.shape)}."
- )
rope_cache = self.cache
cache_dtensor = rope_cache if isinstance(rope_cache, DTensor) else None
if cache_dtensor is not None:
rope_cache = cache_dtensor.to_local()
-
- position_dtensor = position_ids if isinstance(position_ids, DTensor) else None
- pos_local = (
- position_dtensor.to_local()
- if position_dtensor is not None
+ pos = (
+ position_ids.to_local()
+ if isinstance(position_ids, DTensor)
else position_ids
)
- pos_local = pos_local.to(device=rope_cache.device)
+ pos = pos.to(device=rope_cache.device)
- _maybe_check_max_pos(
- pos_local,
- max_valid_pos=rope_cache.shape[0] - 1,
- )
+ _maybe_check_max_pos(pos, max_valid_pos=rope_cache.shape[0] - 1)
head_dim = rope_cache.shape[-1] // 2
cos_cache = rope_cache[:, :head_dim]
sin_cache = rope_cache[:, head_dim:]
# Start from temporal positions for all dimensions, then overwrite the
# height/width interleaved sections with their own position IDs.
- t_pos = pos_local[0].long()
+ # ``pos`` is (batch, seq, 3); the last axis selects T/H/W.
+ t_pos = pos[..., 0].long()
mrope_cos = cos_cache[t_pos]
mrope_sin = sin_cache[t_pos]
@@ -77,7 +96,7 @@ def _compute_mrope_cache(self, position_ids: torch.Tensor) -> torch.Tensor:
length = cfg.mrope_section[dim] * 3
low = torch.arange(offset, length, 3, device=rope_cache.device)
col_indices = torch.cat([low, low + half])
- dim_pos = pos_local[dim].long()
+ dim_pos = pos[..., dim].long()
mrope_cos[..., col_indices] = cos_cache[:, col_indices][dim_pos]
mrope_sin[..., col_indices] = sin_cache[:, col_indices][dim_pos]
diff --git a/torchtitan/models/qwen3_5/sharding.py b/torchtitan/models/qwen3_5/sharding.py
new file mode 100644
index 0000000000..77db288696
--- /dev/null
+++ b/torchtitan/models/qwen3_5/sharding.py
@@ -0,0 +1,305 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Sharding configs for Qwen3.5 hybrid attention model.
+
+Sets ``ShardingConfig`` on all sub-configs so that ``model.parallelize()``
+applies TP via the Module protocol. Same pattern as ``qwen3/sharding.py``.
+
+Full-attention layers: TP on wq/wk/wv/wo with local_map for inner attention;
+each layer's MRoPE ``cache`` buffer is sharded Replicate.
+GatedDeltaNet layers: head-sharded TP on projections (ColwiseParallel) and
+out_proj (RowwiseParallel); the FLA kernel and depthwise Conv1d run on local
+tensors via local_map.
+"""
+
+from typing import TYPE_CHECKING
+
+import spmd_types as spmd
+
+from torchtitan.models.common.decoder_sharding import (
+ colwise_config,
+ dense_activation_placement,
+ dense_param_placement,
+ dense_sequence_parallel_placement,
+ norm_config,
+ rowwise_config,
+ set_decoder_sharding_config,
+ set_dense_ffn_sharding,
+ set_gqa_inner_attention_local_map,
+)
+from torchtitan.models.common.moe_sharding import set_moe_sharding_config
+from torchtitan.protocols.sharding import LocalMapConfig, ShardingConfig
+
+if TYPE_CHECKING:
+ from torchtitan.models.qwen3_5.model import (
+ GatedDeltaNet,
+ Qwen35Attention,
+ Qwen35Model,
+ Qwen35TransformerBlock,
+ SharedExperts,
+ )
+ from torchtitan.models.qwen3_5.vision_encoder import Qwen35VisionEncoder
+
+
+def _replicate_norm() -> ShardingConfig:
+    """Build a sharding config with every placement set to Replicate on TP.
+
+    Covers the ``weight`` and ``bias`` states, the ``input`` activation
+    (source and destination) and the output. Used by the vision encoder,
+    which runs without sequence parallelism.
+    """
+    param_shardings = {
+        "weight": dense_param_placement(tp=spmd.R),
+        "bias": dense_param_placement(tp=spmd.R),
+    }
+    input_src = {"input": dense_activation_placement(tp=spmd.R)}
+    input_dst = {"input": dense_activation_placement(tp=spmd.R)}
+    output_dst = dense_activation_placement(tp=spmd.R)
+    return ShardingConfig(
+        state_shardings=param_shardings,
+        in_src_shardings=input_src,
+        in_dst_shardings=input_dst,
+        out_dst_shardings=output_dst,
+    )
+
+
+def _qk_norm_sharding() -> ShardingConfig:
+    """Build the per-head QK-norm config: Replicate weight, Shard(2) activations.
+
+    One Shard(2) placement object is shared by the input source, input
+    destination and output.
+    """
+    per_head_plc = dense_activation_placement(tp=spmd.S(2))
+    weight_shardings = {"weight": dense_param_placement(tp=spmd.R)}
+    return ShardingConfig(
+        state_shardings=weight_shardings,
+        in_src_shardings={"input": per_head_plc},
+        in_dst_shardings={"input": per_head_plc},
+        out_dst_shardings=per_head_plc,
+    )
+
+
+def _conv_weight_sharding() -> ShardingConfig:
+    """Build the depthwise Conv1d config: ``weight`` is Shard(0) on TP.
+
+    Dim 0 is the out-channel dim, which makes the conv head-sharded.
+    """
+    out_channel_shard = dense_param_placement(tp=spmd.S(0))
+    return ShardingConfig(state_shardings={"weight": out_channel_shard})
+
+
+# TP placement per grouped-expert weight, handed to ``set_moe_sharding_config``
+# as ``expert_param_layout``.
+# NOTE(review): going by the name suffixes the dims are presumably
+# (E, F, D) / (E, D, F), so S(1) / S(2) would shard the same F dim -- confirm.
+_GROUPED_EXPERTS_PARAM_LAYOUT: dict[str, spmd.PerMeshAxisSpmdType] = {
+ "w1_EFD": spmd.S(1),
+ "w2_EDF": spmd.S(2),
+ "w3_EFD": spmd.S(1),
+}
+
+
+def set_qwen35_sharding_config(
+    config: "Qwen35Model.Config",
+    *,
+    loss_parallel: bool,
+    enable_ep: bool,
+) -> None:
+    """Populate ``sharding_config`` across the Qwen3.5 model config.
+
+    Order matters: the common decoder sharding (with SP enabled) is applied
+    first, then the ``tok_embeddings`` entry is overridden, then the vision
+    encoder and every layer are configured.
+
+    Args:
+        config: Model config, mutated in place.
+        loss_parallel: Forwarded to ``set_decoder_sharding_config``.
+        enable_ep: Forwarded to each layer's sharding setup.
+    """
+    # SP on norm, lm_head, and layers. Each full-attention layer owns its rope;
+    # its cache buffer is sharded Replicate in _set_full_attention_sharding.
+    set_decoder_sharding_config(config, loss_parallel=loss_parallel, enable_sp=True)
+
+    # Override tok_embeddings: output Replicate (not Shard(1)) for vision scatter
+    embedding_sharding = ShardingConfig(
+        state_shardings={"weight": dense_param_placement(tp=spmd.S(0))},
+        in_src_shardings={"input": dense_activation_placement(tp=spmd.R)},
+        in_dst_shardings={"input": dense_activation_placement(tp=spmd.R)},
+        out_dst_shardings=dense_activation_placement(tp=spmd.R),
+    )
+    config.tok_embeddings.sharding_config = embedding_sharding
+
+    _set_vision_encoder_sharding(config.vision_encoder)
+
+    for block_cfg in config.layers:
+        _set_qwen35_layer_sharding(block_cfg, enable_ep=enable_ep)
+
+
+def _set_qwen35_layer_sharding(
+    layer_cfg: "Qwen35TransformerBlock.Config",
+    *,
+    enable_ep: bool,
+) -> None:
+    """Set sharding on one transformer block's norms, mixer and FFN/MoE.
+
+    The block is configured as full attention when ``layer_cfg.attention``
+    is set, otherwise as GatedDeltaNet (``delta_net`` must then be set). The
+    dense FFN and the MoE are each configured only when present.
+    """
+    layer_cfg.attention_norm.sharding_config = norm_config(enable_sp=True)
+    layer_cfg.ffn_norm.sharding_config = norm_config(enable_sp=True)
+
+    if layer_cfg.attention is None:
+        assert layer_cfg.delta_net is not None
+        _set_deltanet_sharding(layer_cfg.delta_net)
+    else:
+        _set_full_attention_sharding(layer_cfg.attention)
+
+    if layer_cfg.feed_forward is not None:
+        set_dense_ffn_sharding(
+            layer_cfg.feed_forward,
+            attn_x_layout=dense_sequence_parallel_placement(),
+            enable_sp=True,
+        )
+
+    if layer_cfg.moe is None:
+        return
+    set_moe_sharding_config(
+        layer_cfg.moe,
+        enable_ep=enable_ep,
+        enable_sp=True,
+        expert_param_layout=_GROUPED_EXPERTS_PARAM_LAYOUT,
+    )
+    # The common MoE helper is followed by the Qwen3.5-specific gate sharding.
+    # pyrefly: ignore [missing-attribute]
+    _set_shared_expert_gate_sharding(layer_cfg.moe.shared_experts)
+
+
+def _set_shared_expert_gate_sharding(
+    shared_experts: "SharedExperts.Config | None",
+) -> None:
+    """Give the shared expert's sigmoid gate a fully Replicate sharding.
+
+    Only the gate is handled here: its ``weight`` and ``bias`` states and
+    its output are Replicate, so ``sigmoid(gate(x)) * ffn(x)`` is
+    ``Replicate * Partial = Partial`` with no extra collective. The shared
+    FFN (w1/w2/w3) and the gather that feeds the gate a Replicate ``x`` are
+    left to the common MoE sharding.
+
+    Does nothing when ``shared_experts`` is ``None`` or has no ``gate``
+    attribute; the lookup goes through ``getattr`` with a ``None`` default.
+    """
+    gate = getattr(shared_experts, "gate", None)
+    if gate is not None:
+        replicated_params = {
+            "weight": dense_param_placement(tp=spmd.R),
+            "bias": dense_param_placement(tp=spmd.R),
+        }
+        gate.sharding_config = ShardingConfig(
+            state_shardings=replicated_params,
+            out_dst_shardings=dense_activation_placement(tp=spmd.R),
+        )
+
+
+def _set_vision_encoder_sharding(ve_cfg: "Qwen35VisionEncoder.Config") -> None:
+    """Set sharding on the vision encoder config and its sub-configs.
+
+    Activations are Replicate throughout (no sequence parallelism). Linear
+    pairs use a colwise config followed by a rowwise config with
+    ``output_sp=False``. Norms, ``pos_embed`` and the patch-embedding
+    projection are Replicate.
+    """
+    ve_cfg.sharding_config = ShardingConfig(
+        state_shardings={"pos_embed": dense_param_placement(tp=spmd.R)},
+    )
+
+    # patch_embed receives plain pixel_values — wrap as DTensor(Replicate)
+    patch_embed_sharding = ShardingConfig(
+        state_shardings={
+            "weight": dense_param_placement(tp=spmd.R),
+            "bias": dense_param_placement(tp=spmd.R),
+        },
+        in_src_shardings={"input": dense_activation_placement(tp=spmd.R)},
+        in_dst_shardings={"input": dense_activation_placement(tp=spmd.R)},
+        out_dst_shardings=dense_activation_placement(tp=spmd.R),
+    )
+    ve_cfg.patch_embed_proj.sharding_config = patch_embed_sharding
+
+    # Transformer block: norms, attention, MLP.
+    block_cfg = ve_cfg.block
+    block_cfg.norm1.sharding_config = _replicate_norm()
+    block_cfg.norm2.sharding_config = _replicate_norm()
+
+    attn_cfg = block_cfg.attn
+    attn_cfg.sharding_config = ShardingConfig(
+        in_src_shardings={"rope_cache": dense_activation_placement(tp=spmd.R)},
+        in_dst_shardings={"rope_cache": dense_activation_placement(tp=spmd.R)},
+    )
+    for qkv_proj_cfg in (attn_cfg.wq, attn_cfg.wk, attn_cfg.wv):
+        qkv_proj_cfg.sharding_config = colwise_config()
+    attn_cfg.proj.sharding_config = rowwise_config(output_sp=False)
+    set_gqa_inner_attention_local_map(attn_cfg.inner_attention)
+
+    mlp_cfg = block_cfg.mlp
+    mlp_cfg.fc1.sharding_config = colwise_config()
+    mlp_cfg.fc2.sharding_config = rowwise_config(output_sp=False)
+
+    # Patch merger: norm plus a colwise/rowwise linear pair.
+    merger_cfg = ve_cfg.merger
+    merger_cfg.norm.sharding_config = _replicate_norm()
+    merger_cfg.fc1.sharding_config = colwise_config()
+    merger_cfg.fc2.sharding_config = rowwise_config(output_sp=False)
+
+
+def _set_full_attention_sharding(
+    attention_cfg: "Qwen35Attention.Config",
+) -> None:
+    """Set TP sharding on a Qwen35Attention config (output gating + partial RoPE).
+
+    The module input ``x`` goes from the sequence-parallel placement to
+    Replicate. wq/wk/wv are colwise, wo is rowwise with ``output_sp=True``,
+    and both QK norms use the per-head QK-norm config.
+    """
+    attention_cfg.sharding_config = ShardingConfig(
+        in_src_shardings={"x": dense_sequence_parallel_placement()},
+        in_dst_shardings={"x": dense_activation_placement(tp=spmd.R)},
+    )
+    # The per-layer rope ``cache`` buffer is a Replicate DTensor; MRoPE builds the
+    # position-resolved cache from it (``positions`` stays a plain input).
+    rope_sharding = ShardingConfig(
+        state_shardings={"cache": dense_param_placement(tp=spmd.R)},
+    )
+    attention_cfg.rope.sharding_config = rope_sharding
+
+    for qkv_proj_cfg in (attention_cfg.wq, attention_cfg.wk, attention_cfg.wv):
+        qkv_proj_cfg.sharding_config = colwise_config()
+    attention_cfg.wo.sharding_config = rowwise_config(output_sp=True)
+
+    for qk_norm_cfg in (attention_cfg.q_norm, attention_cfg.k_norm):
+        qk_norm_cfg.sharding_config = _qk_norm_sharding()
+
+    set_gqa_inner_attention_local_map(attention_cfg.inner_attention)
+
+
+def _set_deltanet_sharding(deltanet_cfg: "GatedDeltaNet.Config") -> None:
+    """Set head-sharded TP sharding on a GatedDeltaNet config.
+
+    The module input ``x`` goes from the sequence-parallel placement to
+    Replicate (an allgather, so the recurrence sees the full sequence), and
+    the module output returns to the sequence-parallel placement. The six
+    input projections are colwise; ``out_proj`` is rowwise with
+    ``output_sp=True``. The gated norm and the kernel take Shard(2)
+    activations; the kernel also gets a ``LocalMapConfig``.
+
+    ``A_log`` and ``dt_bias`` are per-head parameters, Shard(0) on TP.
+    Conv1d weights are Shard(0) (out-channels); the DTensor->local conversion
+    for the depthwise conv is handled in the model's ``_causal_conv``.
+    """
+    # ColwiseParallel on all input projections
+    input_projections = (
+        deltanet_cfg.in_proj_q,
+        deltanet_cfg.in_proj_k,
+        deltanet_cfg.in_proj_v,
+        deltanet_cfg.in_proj_z,
+        deltanet_cfg.in_proj_a,
+        deltanet_cfg.in_proj_b,
+    )
+    for proj_cfg in input_projections:
+        proj_cfg.sharding_config = colwise_config()
+
+    # Depthwise Conv1d weights: Shard(0) on out-channels (head-sharded).
+    for conv_cfg in (deltanet_cfg.conv_q, deltanet_cfg.conv_k, deltanet_cfg.conv_v):
+        conv_cfg.sharding_config = _conv_weight_sharding()
+
+    # RowwiseParallel on output projection (reduce-scatter to SP)
+    deltanet_cfg.out_proj.sharding_config = rowwise_config(output_sp=True)
+
+    # RMSNormGated: per-head norm, weight Replicate, activations Shard(2)
+    norm_plc = dense_activation_placement(tp=spmd.S(2))
+    deltanet_cfg.norm.sharding_config = ShardingConfig(
+        state_shardings={"weight": dense_param_placement(tp=spmd.R)},
+        in_src_shardings={"x": norm_plc, "gate": norm_plc},
+        in_dst_shardings={"x": norm_plc, "gate": norm_plc},
+        out_dst_shardings=norm_plc,
+    )
+
+    # GatedDeltaKernel: local_map converts DTensor q/k/v/g/beta to local.
+    # One placement object is shared by every kernel input and the output.
+    kernel_inputs = ("q", "k", "v", "g", "beta")
+    kernel_plc = dense_activation_placement(tp=spmd.S(2))
+    deltanet_cfg.kernel.sharding_config = ShardingConfig(
+        in_dst_shardings=dict.fromkeys(kernel_inputs, kernel_plc),
+        out_src_shardings=kernel_plc,
+        local_map=LocalMapConfig(
+            in_grad_placements=(kernel_plc,) * len(kernel_inputs),
+        ),
+    )
+
+    deltanet_cfg.sharding_config = ShardingConfig(
+        state_shardings={
+            "A_log": dense_param_placement(tp=spmd.S(0)),
+            "dt_bias": dense_param_placement(tp=spmd.S(0)),
+        },
+        in_src_shardings={"x": dense_sequence_parallel_placement()},
+        in_dst_shardings={"x": dense_activation_placement(tp=spmd.R)},
+        out_dst_shardings=dense_sequence_parallel_placement(),
+    )
diff --git a/torchtitan/models/qwen3_5/state_dict_adapter.py b/torchtitan/models/qwen3_5/state_dict_adapter.py
new file mode 100644
index 0000000000..1bba2ce66f
--- /dev/null
+++ b/torchtitan/models/qwen3_5/state_dict_adapter.py
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+State dict adapter for Qwen3.5.
+
+Converts between HuggingFace Qwen3.5 checkpoint format and torchtitan format.
+
+MoE expert weights require two transformations:
+- **Transpose**: HF and TT use transposed layouts for grouped 3D expert weights.
+ E.g. HF down_proj [E, hidden, dim] <-> TT w2 [E, dim, hidden].
+- **Fuse/split gate_up_proj**: HF fuses gate_proj and up_proj into a single
+ gate_up_proj [E, dim, 2*hidden_dim]. TT stores them separately as
+ w1 [E, hidden_dim, dim] and w3 [E, hidden_dim, dim].
+
+Other notable conversions:
+- Conv3d patch embedding (HF) <-> Linear (TT) via weight reshape
+- Vision block naming: HF `blocks` <-> TT `layers`
+- Vision QKV: HF fused qkv <-> TT separate wq/wk/wv
+- GatedDeltaNet QKV: HF fused in_proj_qkv <-> TT separate in_proj_q/k/v
+- GatedDeltaNet Conv1d: HF fused conv1d <-> TT separate conv_q/k/v
+"""
+
+import re
+from typing import Any
+
+import torch
+
+from torchtitan.protocols.state_dict_adapter import StateDictAdapter
+
+from .model import Qwen35Model
+
+
+class Qwen35StateDictAdapter(StateDictAdapter):
+    def __init__(self, model_config: Qwen35Model.Config, hf_assets_path: str | None):
+        super().__init__(model_config, hf_assets_path)
+        self.model_config = model_config  # to_hf/from_hf read shapes and flags from it
+        # HF key → TT key ("{}" = layer index); None = no direct TT key (dropped or converted in to_hf/from_hf)
+        self.from_hf_map = {
+            # ===== Language Model =====
+            "model.language_model.embed_tokens.weight": "tok_embeddings.weight",
+            # Full attention layers (self_attn.*)
+            "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.wq.weight",
+            "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.wk.weight",
+            "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.wv.weight",
+            "model.language_model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.wo.weight",
+            "model.language_model.layers.{}.self_attn.q_norm.weight": "layers.{}.attn.q_norm.weight",
+            "model.language_model.layers.{}.self_attn.k_norm.weight": "layers.{}.attn.k_norm.weight",
+            "model.language_model.layers.{}.self_attn.rotary_emb.inv_freq": None,  # dropped: no TT key
+            # GatedDeltaNet layers (linear_attn.*)
+            # QKV and Conv1d: HF fused → TT separate (None here; split in from_hf, fused in to_hf)
+            "model.language_model.layers.{}.linear_attn.in_proj_qkv.weight": None,
+            "model.language_model.layers.{}.linear_attn.conv1d.weight": None,
+            "model.language_model.layers.{}.linear_attn.in_proj_z.weight": "layers.{}.attn.in_proj_z.weight",
+            "model.language_model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attn.in_proj_a.weight",
+            "model.language_model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attn.in_proj_b.weight",
+            "model.language_model.layers.{}.linear_attn.A_log": "layers.{}.attn.A_log",
+            "model.language_model.layers.{}.linear_attn.dt_bias": "layers.{}.attn.dt_bias",
+            "model.language_model.layers.{}.linear_attn.norm.weight": "layers.{}.attn.norm.weight",
+            "model.language_model.layers.{}.linear_attn.out_proj.weight": "layers.{}.attn.out_proj.weight",
+            # Non-MoE MLP
+            "model.language_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+            "model.language_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+            "model.language_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            # Layer norms
+            "model.language_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.language_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            # MoE experts (grouped 3D format, transposed in to_hf/from_hf; gate_up_proj has no entry and is split/fused there)
+            "model.language_model.layers.{}.mlp.experts.down_proj": "layers.{}.moe.experts.w2",
+            "model.language_model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight",
+            # MoE shared expert
+            "model.language_model.layers.{}.mlp.shared_expert.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight",
+            "model.language_model.layers.{}.mlp.shared_expert.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight",
+            "model.language_model.layers.{}.mlp.shared_expert.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight",
+            "model.language_model.layers.{}.mlp.shared_expert_gate.weight": "layers.{}.moe.shared_experts.gate.weight",
+            # Final norm and output
+            "model.language_model.norm.weight": "norm.weight",
+            "lm_head.weight": "lm_head.weight",
+            # ===== Vision Encoder =====
+            # Patch embedding (Conv3d in HF, Linear in TT — weight reshape needed)
+            "model.visual.patch_embed.proj.weight": "vision_encoder.patch_embed.weight",
+            "model.visual.patch_embed.proj.bias": "vision_encoder.patch_embed.bias",
+            # Position embeddings
+            "model.visual.pos_embed.weight": "vision_encoder.pos_embed",
+            # Vision transformer blocks (HF: blocks, TT: layers)
+            "model.visual.blocks.{}.norm1.weight": "vision_encoder.layers.{}.norm1.weight",
+            "model.visual.blocks.{}.norm1.bias": "vision_encoder.layers.{}.norm1.bias",
+            "model.visual.blocks.{}.norm2.weight": "vision_encoder.layers.{}.norm2.weight",
+            "model.visual.blocks.{}.norm2.bias": "vision_encoder.layers.{}.norm2.bias",
+            # Vision QKV: HF fused → TT separate (None here; chunked in from_hf, concatenated in to_hf)
+            "model.visual.blocks.{}.attn.qkv.weight": None,
+            "model.visual.blocks.{}.attn.qkv.bias": None,
+            "model.visual.blocks.{}.attn.proj.weight": "vision_encoder.layers.{}.attn.proj.weight",
+            "model.visual.blocks.{}.attn.proj.bias": "vision_encoder.layers.{}.attn.proj.bias",
+            "model.visual.blocks.{}.mlp.linear_fc1.weight": "vision_encoder.layers.{}.mlp.linear_fc1.weight",
+            "model.visual.blocks.{}.mlp.linear_fc1.bias": "vision_encoder.layers.{}.mlp.linear_fc1.bias",
+            "model.visual.blocks.{}.mlp.linear_fc2.weight": "vision_encoder.layers.{}.mlp.linear_fc2.weight",
+            "model.visual.blocks.{}.mlp.linear_fc2.bias": "vision_encoder.layers.{}.mlp.linear_fc2.bias",
+            # Merger
+            "model.visual.merger.norm.weight": "vision_encoder.merger.norm.weight",
+            "model.visual.merger.norm.bias": "vision_encoder.merger.norm.bias",
+            "model.visual.merger.linear_fc1.weight": "vision_encoder.merger.linear_fc1.weight",
+            "model.visual.merger.linear_fc1.bias": "vision_encoder.merger.linear_fc1.bias",
+            "model.visual.merger.linear_fc2.weight": "vision_encoder.merger.linear_fc2.weight",
+            "model.visual.merger.linear_fc2.bias": "vision_encoder.merger.linear_fc2.bias",
+        }
+
+    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
+        """Convert torchtitan state dict to HuggingFace Qwen3.5 format.
+
+        Separate TT tensors (MoE w1/w3, vision wq/wk/wv, GatedDeltaNet
+        in_proj_q/k/v and conv_q/k/v) are collected per layer and fused into
+        single HF tensors; TT keys without an HF counterpart are dropped, and
+        ``lm_head.weight`` is omitted when weight tying is enabled.
+        """
+        to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None}
+        hf_state_dict = {}
+
+        moe_w1_by_layer: dict[str, Any] = {}
+        moe_w3_by_layer: dict[str, Any] = {}
+        vision_qkv_by_layer: dict[str, dict[str, Any]] = {}
+        deltanet_qkv_by_layer: dict[str, dict[str, Any]] = {}
+
+        for tt_key, value in state_dict.items():
+            if "moe.experts" in tt_key:
+                tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1)
+                # pyrefly: ignore [missing-attribute]
+                layer_num = re.search(r"\d+", tt_key).group(0)
+
+                if tt_abstract_key == "layers.{}.moe.experts.w1":
+                    moe_w1_by_layer[layer_num] = value
+                    continue
+                elif tt_abstract_key == "layers.{}.moe.experts.w3":
+                    moe_w3_by_layer[layer_num] = value
+                    continue
+                elif tt_abstract_key == "layers.{}.moe.experts.w2":
+                    hf_key = (
+                        f"model.language_model.layers.{layer_num}.mlp.experts.down_proj"
+                    )
+                    hf_state_dict[hf_key] = value.transpose(-2, -1)
+                    continue
+
+                if tt_abstract_key in to_hf_map:
+                    hf_state_dict[to_hf_map[tt_abstract_key].format(layer_num)] = value
+
+            elif re.search(r"\.\d+\.", tt_key):
+                tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1)
+                # pyrefly: ignore [missing-attribute]
+                layer_num = re.search(r"\d+", tt_key).group(0)
+
+                # Collect deltanet q/k/v projections and conv weights for fusing
+                if tt_abstract_key in (
+                    "layers.{}.attn.in_proj_q.weight",
+                    "layers.{}.attn.in_proj_k.weight",
+                    "layers.{}.attn.in_proj_v.weight",
+                    "layers.{}.attn.conv_q.weight",
+                    "layers.{}.attn.conv_k.weight",
+                    "layers.{}.attn.conv_v.weight",
+                ):
+                    short_key = tt_abstract_key.split("attn.")[-1]
+                    deltanet_qkv_by_layer.setdefault(layer_num, {})[short_key] = value
+                    continue
+
+                # Collect vision wq/wk/wv for fusing into qkv
+                if tt_abstract_key in (
+                    "vision_encoder.layers.{}.attn.wq.weight",
+                    "vision_encoder.layers.{}.attn.wq.bias",
+                    "vision_encoder.layers.{}.attn.wk.weight",
+                    "vision_encoder.layers.{}.attn.wk.bias",
+                    "vision_encoder.layers.{}.attn.wv.weight",
+                    "vision_encoder.layers.{}.attn.wv.bias",
+                ):
+                    short_key = tt_abstract_key.split("attn.")[-1]
+                    vision_qkv_by_layer.setdefault(layer_num, {})[short_key] = value
+                    continue
+
+                if tt_abstract_key in to_hf_map:
+                    hf_state_dict[to_hf_map[tt_abstract_key].format(layer_num)] = value
+
+            else:
+                if tt_key not in to_hf_map:
+                    continue
+                if tt_key == "lm_head.weight" and getattr(
+                    self.model_config, "enable_weight_tying", False
+                ):
+                    continue
+                hf_value = value
+                # Linear weight (out, C*T*H*W) → Conv3d weight (out, C, T, H, W)
+                if tt_key == "vision_encoder.patch_embed.weight":
+                    # pyrefly: ignore [missing-attribute]
+                    encoder = self.model_config.vision_encoder
+                    hf_value = value.reshape(
+                        value.shape[0],
+                        encoder.in_channels,
+                        encoder.temporal_patch_size,
+                        encoder.patch_size,
+                        encoder.patch_size,
+                    )
+                hf_state_dict[to_hf_map[tt_key]] = hf_value
+
+        # Fuse MoE w1 (gate) + w3 (up) → gate_up_proj
+        for layer_num in moe_w1_by_layer:
+            w1 = moe_w1_by_layer[layer_num].transpose(-2, -1)
+            w3 = moe_w3_by_layer[layer_num].transpose(-2, -1)
+            hf_state_dict[
+                f"model.language_model.layers.{layer_num}.mlp.experts.gate_up_proj"
+            ] = torch.cat([w1, w3], dim=-1)
+
+        # Fuse vision wq/wk/wv → qkv
+        for layer_num, parts in vision_qkv_by_layer.items():
+            for suffix in ("weight", "bias"):
+                q = parts.get(f"wq.{suffix}")
+                k = parts.get(f"wk.{suffix}")
+                v = parts.get(f"wv.{suffix}")
+                if q is not None and k is not None and v is not None:
+                    hf_state_dict[
+                        f"model.visual.blocks.{layer_num}.attn.qkv.{suffix}"
+                    ] = torch.cat([q, k, v], dim=0)
+
+        # Fuse deltanet in_proj_q/k/v → in_proj_qkv, conv_q/k/v → conv1d
+        for layer_num, parts in deltanet_qkv_by_layer.items():
+            q = parts.get("in_proj_q.weight")
+            k = parts.get("in_proj_k.weight")
+            v = parts.get("in_proj_v.weight")
+            if q is not None and k is not None and v is not None:
+                hf_state_dict[
+                    f"model.language_model.layers.{layer_num}.linear_attn.in_proj_qkv.weight"
+                ] = torch.cat([q, k, v], dim=0)
+            cq = parts.get("conv_q.weight")
+            ck = parts.get("conv_k.weight")
+            cv = parts.get("conv_v.weight")
+            if cq is not None and ck is not None and cv is not None:
+                hf_state_dict[
+                    f"model.language_model.layers.{layer_num}.linear_attn.conv1d.weight"
+                ] = torch.cat([cq, ck, cv], dim=0)
+
+        return hf_state_dict
+
+    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
+        """Convert HuggingFace Qwen3.5 state dict to torchtitan format.
+
+        Fused HF tensors are split into separate TT tensors and unmapped keys
+        are dropped; ``hf_state_dict`` itself is not modified. Raises
+        ``ValueError`` if both ``lm_head.weight`` and the embedding are missing.
+        """
+        tt_state_dict = {}
+
+        # HF ties lm_head with embed_tokens — copy if missing (on a shallow copy)
+        if "lm_head.weight" not in hf_state_dict:
+            embed_key = "model.language_model.embed_tokens.weight"
+            if embed_key not in hf_state_dict:
+                raise ValueError(
+                    f"HF checkpoint missing both 'lm_head.weight' and '{embed_key}'"
+                )
+            hf_state_dict = dict(hf_state_dict)
+            hf_state_dict["lm_head.weight"] = hf_state_dict[embed_key]
+
+        for hf_key, value in hf_state_dict.items():
+            if re.search(r"\.\d+\.", hf_key):
+                hf_abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=1)
+                # pyrefly: ignore [missing-attribute]
+                idx = re.search(r"\d+", hf_key).group(0)
+
+                # MoE gate_up_proj → split into w1 + w3 and transpose
+                if (
+                    hf_abstract_key
+                    == "model.language_model.layers.{}.mlp.experts.gate_up_proj"
+                ):
+                    w1_hf, w3_hf = value.chunk(2, dim=-1)
+                    tt_state_dict[f"layers.{idx}.moe.experts.w1"] = w1_hf.transpose(
+                        -2, -1
+                    )
+                    tt_state_dict[f"layers.{idx}.moe.experts.w3"] = w3_hf.transpose(
+                        -2, -1
+                    )
+                    continue
+
+                # MoE down_proj → transpose
+                if (
+                    hf_abstract_key
+                    == "model.language_model.layers.{}.mlp.experts.down_proj"
+                ):
+                    tt_state_dict[f"layers.{idx}.moe.experts.w2"] = value.transpose(
+                        -2, -1
+                    )
+                    continue
+
+                # GatedDeltaNet fused in_proj_qkv → split into q/k/v
+                if (
+                    hf_abstract_key
+                    == "model.language_model.layers.{}.linear_attn.in_proj_qkv.weight"
+                ):
+                    # pyrefly: ignore [missing-attribute]
+                    dn = self.model_config.layers[int(idx)].delta_net
+                    kd = dn.in_proj_q.out_features
+                    vd = dn.in_proj_v.out_features
+                    q, k, v = value.split([kd, kd, vd], dim=0)
+                    tt_state_dict[f"layers.{idx}.attn.in_proj_q.weight"] = q
+                    tt_state_dict[f"layers.{idx}.attn.in_proj_k.weight"] = k
+                    tt_state_dict[f"layers.{idx}.attn.in_proj_v.weight"] = v
+                    continue
+
+                # GatedDeltaNet fused conv1d → split into conv_q/k/v
+                if (
+                    hf_abstract_key
+                    == "model.language_model.layers.{}.linear_attn.conv1d.weight"
+                ):
+                    # pyrefly: ignore [missing-attribute]
+                    dn = self.model_config.layers[int(idx)].delta_net
+                    kd = dn.in_proj_q.out_features
+                    vd = dn.in_proj_v.out_features
+                    cq, ck, cv = value.split([kd, kd, vd], dim=0)
+                    tt_state_dict[f"layers.{idx}.attn.conv_q.weight"] = cq
+                    tt_state_dict[f"layers.{idx}.attn.conv_k.weight"] = ck
+                    tt_state_dict[f"layers.{idx}.attn.conv_v.weight"] = cv
+                    continue
+
+                # Vision fused QKV → split into wq/wk/wv
+                if hf_abstract_key in (
+                    "model.visual.blocks.{}.attn.qkv.weight",
+                    "model.visual.blocks.{}.attn.qkv.bias",
+                ):
+                    suffix = "weight" if "weight" in hf_abstract_key else "bias"
+                    q, k, v = value.chunk(3, dim=0)
+                    tt_state_dict[f"vision_encoder.layers.{idx}.attn.wq.{suffix}"] = q
+                    tt_state_dict[f"vision_encoder.layers.{idx}.attn.wk.{suffix}"] = k
+                    tt_state_dict[f"vision_encoder.layers.{idx}.attn.wv.{suffix}"] = v
+                    continue
+
+                tt_key = self.from_hf_map.get(hf_abstract_key)
+                if tt_key is not None:
+                    tt_state_dict[tt_key.format(idx)] = value
+
+            else:
+                tt_key = self.from_hf_map.get(hf_key)
+                if tt_key is None:
+                    continue
+                # Conv3d weight (out, C, T, H, W) → Linear weight (out, C*T*H*W)
+                if hf_key == "model.visual.patch_embed.proj.weight":
+                    value = value.reshape(value.shape[0], -1)
+                tt_state_dict[tt_key] = value
+
+        return tt_state_dict
diff --git a/torchtitan/models/qwen3_vl/vision_encoder.py b/torchtitan/models/qwen3_5/vision_encoder.py
similarity index 68%
rename from torchtitan/models/qwen3_vl/vision_encoder.py
rename to torchtitan/models/qwen3_5/vision_encoder.py
index be59176387..83ceea35ee 100644
--- a/torchtitan/models/qwen3_vl/vision_encoder.py
+++ b/torchtitan/models/qwen3_5/vision_encoder.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+from collections.abc import Callable
from dataclasses import dataclass, field
import torch
@@ -15,19 +16,18 @@
from torchtitan.models.common.attention import FlexAttention
from torchtitan.models.common.nn_modules import GELU, LayerNorm
from torchtitan.models.common.rope import CosSinRoPE
-from torchtitan.protocols.module import Module, ModuleDict, ModuleList
+from torchtitan.protocols.module import Module, ModuleDict
_compiled_create_block_mask = torch.compile(create_block_mask)
-def get_vision_block_mask_mod(num_patch: torch.Tensor, max_num_patch: int):
+def get_vision_block_mask_mod(num_patch: torch.Tensor) -> Callable:
"""Create a mask modifier for block-diagonal attention.
Each image only attends to its own patches.
Args:
num_patch: (num_vision,) actual number of patches per visual item
- max_num_patch: Maximum number of patches (padded length)
"""
def mask_mod(b, h, q_idx, kv_idx):
@@ -61,7 +61,7 @@ def _compute_learned_pos_embeds(
dim: Hidden dimension
Returns:
- learned_pos: (num_vision, max_num_patch, dim) interpolated position embeddings
+ pos_embeds: (num_vision, max_num_patch, dim) interpolated position embeddings
"""
num_vision = grid_thw.shape[0]
dtype = learned_pos_embed.dtype
@@ -90,7 +90,7 @@ def _compute_learned_pos_embeds(
for (h, w), indices in hw_to_indices.items():
pos_hw = F.interpolate(
pos_grid,
- size=(h, w),
+ size=[h, w],
mode="bilinear",
align_corners=True,
)
@@ -208,7 +208,7 @@ def _compute_2d_rope_cache(
else:
rope_embeds[i, :seq_len] = rope_2d
- # Compute cos/sin in model dtype (HF uses .float() here)
+ # Compute cos/sin in float32 for numerical precision
rope_embeds = torch.cat((rope_embeds, rope_embeds), dim=-1) # (N, L, head_dim)
rope_cache = torch.cat([rope_embeds.cos(), rope_embeds.sin()], dim=-1).unsqueeze(
2
@@ -217,45 +217,6 @@ def _compute_2d_rope_cache(
return rope_cache
-class PatchEmbed(Module):
- """Patch Embedding using Linear projection.
-
- Since patches are already extracted by the collator, we use Linear instead of Conv3d.
- This is mathematically equivalent when Conv3d kernel_size equals input size:
- - Conv3d: (B, C, T, H, W) with kernel=C*(T,H,W) and dim kernels → (B, dim, 1, 1, 1)
- - Linear: (B, C*T*H*W) → (B, dim)
- Same weighted sum, but Linear uses efficient batched matrix multiplication.
- """
-
- @dataclass(kw_only=True, slots=True)
- class Config(Module.Config):
- patch_size: int
- temporal_patch_size: int
- in_channels: int
- proj: Linear.Config
-
- def __init__(self, config: Config):
- super().__init__()
- # Kept on the Module so the state-dict adapter can reshape
- # Linear weights into 5D Conv3d weights.
- self.patch_size = config.patch_size
- self.temporal_patch_size = config.temporal_patch_size
- self.in_channels = config.in_channels
-
- self.proj = config.proj.build()
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """Project patches to embeddings.
-
- Args:
- hidden_states: (batch, max_num_patch, patch_dim)
-
- Returns:
- (batch, max_num_patch, embed_dim)
- """
- return self.proj(hidden_states)
-
-
class VisionRotaryEmbedding(Module):
"""2D Rotary Position Embedding for Vision Transformer."""
@@ -286,41 +247,35 @@ def _init_self_buffers(self, *, buffer_device: torch.device | None = None) -> No
)
def forward(self, seqlen: int) -> torch.Tensor:
- """Compute rotary embeddings for a sequence."""
+ """Compute rotary frequency table for positions up to seqlen."""
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
- freqs = torch.outer(seq, self.inv_freq)
- return freqs
+ return torch.outer(seq, self.inv_freq)
class PatchMerger(Module):
"""Merge spatial patches to reduce sequence length.
- ``use_postshuffle_norm``: If True, apply LayerNorm after spatial reshape
- (norm dim = hidden_size * spatial_merge_size^2). If False, apply
- before reshape (norm dim = hidden_size). DeepStack mergers use
- postshuffle norm; the main merger uses pre-shuffle norm. The caller
- must set ``norm.normalized_shape`` to match the chosen mode.
+ Applies LayerNorm before spatial reshape, then projects through a
+ two-layer MLP (fc1 → GELU → fc2).
"""
@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
- hidden_size: int
spatial_merge_size: int
- fc1: Linear.Config
- fc2: Linear.Config
+ merged_hidden_size: int
norm: LayerNorm.Config
+ fc1: Linear.Config
act_fn: GELU.Config = field(
default_factory=lambda: GELU.Config(approximate="tanh")
)
- use_postshuffle_norm: bool = False
+ fc2: Linear.Config
def __init__(self, config: Config):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
- self.merged_hidden_size = config.hidden_size * (config.spatial_merge_size**2)
- self.use_postshuffle_norm = config.use_postshuffle_norm
+ self.merged_hidden_size = config.merged_hidden_size
self.norm = config.norm.build()
self.linear_fc1 = config.fc1.build()
@@ -337,20 +292,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
(batch, seq_len // spatial_merge_size^2, out_hidden_size)
"""
batch_size, seq_len, _ = x.shape
- if self.use_postshuffle_norm:
- x = x.view(
- batch_size,
- seq_len // (self.spatial_merge_size**2),
- self.merged_hidden_size,
- )
- x = self.norm(x)
- else:
- x = self.norm(x)
- x = x.view(
- batch_size,
- seq_len // (self.spatial_merge_size**2),
- self.merged_hidden_size,
- )
+ x = self.norm(x)
+ x = x.view(
+ batch_size,
+ seq_len // (self.spatial_merge_size**2),
+ self.merged_hidden_size,
+ )
x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
return x
@@ -361,53 +308,43 @@ class VisionAttention(Module):
@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
dim: int
- n_heads: int
- qkv: Linear.Config
+ num_heads: int
+ wq: Linear.Config
+ wk: Linear.Config
+ wv: Linear.Config
proj: Linear.Config
- flex_attention: FlexAttention.Config = field(
- default_factory=lambda: FlexAttention.Config()
- )
+ inner_attention: Module.Config = field(default_factory=FlexAttention.Config)
def __init__(self, config: Config):
super().__init__()
self.dim = config.dim
- self.num_heads = config.n_heads
+ self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
- self.qkv = config.qkv.build()
+ self.wq = config.wq.build()
+ self.wk = config.wk.build()
+ self.wv = config.wv.build()
self.proj = config.proj.build()
- self.flex_attention = config.flex_attention.build()
+ self.flex_attention = config.inner_attention.build()
def forward(
self,
- hidden_states: torch.Tensor,
+ x: torch.Tensor,
*,
rope_cache: torch.Tensor,
attention_mask: BlockMask,
) -> torch.Tensor:
- """Apply multi-head attention with 2D RoPE.
+ bs, seqlen, _ = x.shape
- Args:
- hidden_states: (num_vision, max_num_patch, dim)
- rope_cache: (num_vision, max_num_patch, 1, head_dim*2) precomputed cos/sin
- attention_mask: BlockMask for attention
+ xq = self.wq(x).view(bs, seqlen, -1, self.head_dim)
+ xk = self.wk(x).view(bs, seqlen, -1, self.head_dim)
+ xv = self.wv(x).view(bs, seqlen, -1, self.head_dim)
- Returns:
- (num_vision, max_num_patch, dim)
- """
- num_vision, max_num_patch, _ = hidden_states.shape
+ xq, xk = CosSinRoPE.apply_rotary_emb(xq, xk, rope_cache)
- qkv = self.qkv(hidden_states).reshape(
- num_vision, max_num_patch, 3, -1, self.head_dim
- )
- # Each: (num_vision, max_num_patch, n_heads, head_dim)
- q, k, v = qkv.permute(2, 0, 1, 3, 4).unbind(0)
- # Vision RoPE cache is already position-specific from grid_thw, so
- # apply the prepared cache directly without a RoPE module instance.
- q, k = CosSinRoPE.apply_rotary_emb(q, k, rope_cache)
- attn_output = self.flex_attention(q, k, v, attention_masks=attention_mask)
- attn_output = attn_output.reshape(num_vision, max_num_patch, -1)
- return self.proj(attn_output)
+ output = self.flex_attention(xq, xk, xv, attention_masks=attention_mask)
+ output = output.reshape(bs, seqlen, -1)
+ return self.proj(output)
class VisionMLP(Module):
@@ -427,8 +364,8 @@ def __init__(self, config: Config):
self.linear_fc2 = config.fc2.build()
self.act_fn = config.act_fn.build()
- def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
- return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
class VisionTransformerBlock(Module):
@@ -436,10 +373,10 @@ class VisionTransformerBlock(Module):
@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
- attn: VisionAttention.Config
- mlp: VisionMLP.Config
norm1: LayerNorm.Config
norm2: LayerNorm.Config
+ attn: VisionAttention.Config
+ mlp: VisionMLP.Config
def __init__(self, config: Config):
super().__init__()
@@ -450,51 +387,44 @@ def __init__(self, config: Config):
def forward(
self,
- hidden_states: torch.Tensor,
+ x: torch.Tensor,
*,
rope_cache: torch.Tensor,
attention_mask: BlockMask,
) -> torch.Tensor:
- hidden_states = hidden_states + self.attn(
- self.norm1(hidden_states),
- rope_cache=rope_cache,
- attention_mask=attention_mask,
+ x = x + self.attn(
+ self.norm1(x), rope_cache=rope_cache, attention_mask=attention_mask
)
- hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
- return hidden_states
+ x = x + self.mlp(self.norm2(x))
+ return x
-class Qwen3VLVisionEncoder(Module):
- """Qwen3-VL Vision Encoder with FlexAttention.
+class Qwen35VisionEncoder(Module):
+ """Qwen3.5 Vision Encoder with FlexAttention.
Uses padded batches (N, L, D) format for efficient processing.
"""
@dataclass(kw_only=True, slots=True)
class Config(Module.Config):
- """Configuration for Qwen3-VL Vision Encoder (ViT).
-
- ``num_position_embeddings`` and ``dim`` stay on this Config because
- ``pos_embed`` is an ``nn.Parameter`` owned directly by the encoder
- (not factored into a sub-Module). ``n_heads`` is kept to derive
- ``head_dim`` inside ``compute_position_embeddings``.
- ``spatial_merge_size`` and ``deepstack_visual_indices`` are read by
- ``Qwen3VLModel``.
- """
+ """Configuration for Qwen3.5 Vision Encoder (ViT)."""
- dim: int
- n_heads: int
- spatial_merge_size: int
- num_position_embeddings: int
+ dim: int = 1280
+ num_layers: int = 32
+ num_heads: int = 16
+
+ patch_size: int = 16
+ temporal_patch_size: int = 2
+ in_channels: int = 3
+ spatial_merge_size: int = 2
- # DeepStack: layer indices for extracting intermediate visual features
- deepstack_visual_indices: list[int]
+ num_position_embeddings: int = 4096
- patch_embed: PatchEmbed.Config
+ # Sub-module configs
+ patch_embed_proj: Linear.Config
+ block: VisionTransformerBlock.Config
rotary_pos_emb: VisionRotaryEmbedding.Config
- layers: list[VisionTransformerBlock.Config]
merger: PatchMerger.Config
- deepstack_mergers: list[PatchMerger.Config]
def __init__(self, config: Config):
super().__init__()
@@ -502,7 +432,8 @@ def __init__(self, config: Config):
self.spatial_merge_size = config.spatial_merge_size
self.spatial_merge_unit = config.spatial_merge_size**2
- self.patch_embed = config.patch_embed.build()
+ # Patches are pre-extracted by the collator, so Linear replaces Conv3d (equivalent at full-patch kernel size).
+ self.patch_embed = config.patch_embed_proj.build()
# nn.Parameter (not nn.Embedding) because we interpolate the weight directly
self.num_position_embeddings = config.num_position_embeddings
@@ -512,21 +443,14 @@ def __init__(self, config: Config):
self.num_grid_per_side = int(config.num_position_embeddings**0.5)
self.rotary_pos_emb = config.rotary_pos_emb.build()
- # Cached RoPE freq table — recomputed only when max_hw grows
self._cached_freq_table: torch.Tensor | None = None
self.layers = ModuleDict(
- {str(idx): layer.build() for idx, layer in enumerate(config.layers)}
+ {str(idx): config.block.build() for idx in range(config.num_layers)}
)
self.merger = config.merger.build()
- # DeepStack mergers for intermediate layers
- self.deepstack_visual_indices = config.deepstack_visual_indices
- self.deepstack_merger_list = ModuleList(
- [cfg.build() for cfg in config.deepstack_mergers]
- )
-
def compute_position_embeddings(
self, grid_thw: torch.Tensor, max_num_patch: int
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -537,7 +461,7 @@ def compute_position_embeddings(
- ``_compute_2d_rope_cache``: 2D RoPE cache
Args:
- grid_thw: (num_vision, 3) with pixel patch counts [t, h, w] per visual item
+ grid_thw: (num_vision, 3) with patch counts [t, h, w] per visual item
max_num_patch: Maximum number of patches (for padding)
Returns:
@@ -545,7 +469,7 @@ def compute_position_embeddings(
rope_cache: (num_vision, max_num_patch, 1, head_dim*2) RoPE cache for
VisionAttention
"""
- head_dim = self.config.dim // self.config.n_heads
+ head_dim = self.config.dim // self.config.num_heads
# Get RoPE freq table, reusing cache when possible
max_hw = int(grid_thw[:, 1:].max().item())
@@ -576,7 +500,7 @@ def forward(
pixel_values: torch.Tensor,
*,
grid_thw: torch.Tensor,
- ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ ) -> torch.Tensor:
"""Forward pass of the vision encoder.
Processes both images and videos — each visual item is a batch of
@@ -588,40 +512,28 @@ def forward(
Returns:
merged_hidden_states: (num_vision, max_merged_num_patch, out_hidden_size)
- deepstack_features: List of features from intermediate layers
"""
num_vision, max_num_patch, _ = pixel_values.shape
num_patch = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).to(torch.long)
- hidden_states = self.patch_embed(pixel_values)
+ x = self.patch_embed(pixel_values) # (num_vision, max_num_patch, dim)
learned_pos, rope_cache = self.compute_position_embeddings(
grid_thw, max_num_patch
)
- hidden_states = hidden_states + learned_pos
+ x = x + learned_pos
- mask_mod = get_vision_block_mask_mod(num_patch, max_num_patch)
+ mask_mod = get_vision_block_mask_mod(num_patch)
attention_mask = _compiled_create_block_mask(
mask_mod,
num_vision,
None,
max_num_patch,
max_num_patch,
- device=hidden_states.device,
+ device=x.device,
)
- deepstack_features = []
-
- for layer_idx, layer in self.layers.items():
- hidden_states = layer(
- hidden_states,
- rope_cache=rope_cache,
- attention_mask=attention_mask,
- )
- if int(layer_idx) in self.deepstack_visual_indices:
- idx = self.deepstack_visual_indices.index(int(layer_idx))
- deepstack_feature = self.deepstack_merger_list[idx](hidden_states)
- deepstack_features.append(deepstack_feature)
- merged_hidden_states = self.merger(hidden_states)
+ for layer in self.layers.values():
+ x = layer(x, rope_cache=rope_cache, attention_mask=attention_mask)
- return merged_hidden_states, deepstack_features
+ return self.merger(x)
diff --git a/torchtitan/models/qwen3_vl/README.md b/torchtitan/models/qwen3_vl/README.md
deleted file mode 100644
index 6f5b54afa7..0000000000
--- a/torchtitan/models/qwen3_vl/README.md
+++ /dev/null
@@ -1,66 +0,0 @@
-# Qwen3-VL: Vision-Language Model
-
-## Overview
-
-Qwen3-VL combines:
-- **Qwen3 LLM**: The base language model with QK-norm and RoPE.
-- **Vision Encoder**: A Vision Transformer (ViT) that supports native resolution images (no fixed square crops) with 2D RoPE and bilinear-interpolated position embeddings.
-- **Patch Merger**: Reduces vision sequence length by merging spatial patches (e.g., 2×2 patches → 1 token).
-- **DeepStack**: Adds intermediate ViT features to vision positions in early decoder layers.
-- **MRoPE**: Interleaves RoPE from temporal, height, and width position IDs in decoder layers.
-
-## Vision Scatter
-
-- `tok_embeddings` produces text token embeddings of shape `B×S`.
-- `vision_encoder` produces visual token embeddings of shape `N×L`.
-- Valid visual tokens (excluding padding) are scattered into the placeholder positions in the text sequence, as illustrated below (credit: [@lkhphuc](https://github.com/lkhphuc)).
-
-
-
-Note: the diagram shows each patch mapping to one vision token. In practice, the Patch Merger groups `merge_size²` patches into one token (e.g., `merge_size=2` → 4 patches per token), reducing the vision sequence length by `merge_size²`.
-
-## Prerequisites
-
-Install the additional dependencies required by Qwen3-VL:
-
-```bash
-pip install av torchvision
-```
-
-## Model Variants
-
-| Variant | LLM dim | Layers | ViT dim | ViT layers | Patch size | MoE |
-|---------|---------|--------|---------|------------|------------|-----|
-| debugmodel | 256 | 4 | 256 | 4 | 16 | No |
-| debugmodel_moe | 256 | 1 | 256 | 4 | 16 | Yes (8 experts) |
-| 2B | 2048 | 28 | 1024 | 24 | 16 | No |
-| 8B | 4096 | 36 | 1152 | 27 | 16 | No |
-| 30B-A3B | 2048 | 48 | 1152 | 27 | 16 | Yes (128 experts) |
-| 235B-A22B | 4096 | 94 | 1152 | 27 | 16 | Yes (128 experts) |
-
-## Datasets
-
-| Dataset | Type | Format |
-|---------|------|--------|
-| `cc12m` | Image-text pairs | WebDataset (streaming) |
-| `cc12m-test` | Image-text pairs | Local WebDataset (for testing) |
-| `obelics` | Interleaved image-text | HuggingFace (streaming) |
-
-## Supported Features
-
-| Feature | Notes |
-|---------|-------|
-| FSDP / HSDP | Both vision encoder and decoder are individually sharded |
-| Tensor Parallelism (TP) | Applied to both vision encoder and decoder (without SequenceParallel due to vision scatter and DeepStack) |
-| Expert Parallelism (EP) | For MoE variants (e.g., 30B-A3B) |
-| Sample Packing | Configurable via `packing_buffer_size` in dataloader config |
-
-## Numerical Parity
-
-End-to-end KL divergence against HuggingFace Transformers (2B, 10 random samples): **~5e-8 to ~5e-5** per sample, with **100% top-1 and top-5 match**. Test scripts are in `scripts/checkpoint_conversion/numerical_tests_qwen3_vl_*.py`.
-
-## TODO
-
-- Add Pipeline Parallelism support
-- Add Context Parallel support
-- Add video dataset training configs
diff --git a/torchtitan/models/qwen3_vl/__init__.py b/torchtitan/models/qwen3_vl/__init__.py
deleted file mode 100644
index 75b3ef29ee..0000000000
--- a/torchtitan/models/qwen3_vl/__init__.py
+++ /dev/null
@@ -1,649 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from collections.abc import Callable
-from functools import partial
-
-import torch.nn as nn
-
-from torchtitan.models.common import Embedding, Linear, RoPE, TransformerBlock
-from torchtitan.models.common.config_utils import (
- get_attention_config,
- make_experts_config,
- make_ffn_config,
- make_gqa_config,
- make_moe_config,
- make_router_config,
-)
-from torchtitan.models.common.nn_modules import LayerNorm, RMSNorm
-from torchtitan.models.common.param_init import depth_scaled_std, skip_param_init
-from torchtitan.models.qwen3.model import Qwen3TransformerBlock
-from torchtitan.models.utils import validate_converter_order
-from torchtitan.protocols.model import ModelConfigConverter
-from torchtitan.protocols.model_spec import ModelSpec
-from .model import Qwen3VLModel
-from .parallelize import parallelize_qwen3_vl
-from .rope import MRoPE
-from .state_dict_adapter import Qwen3VLStateDictAdapter
-from .vision_encoder import (
- PatchEmbed,
- PatchMerger,
- Qwen3VLVisionEncoder,
- VisionAttention,
- VisionMLP,
- VisionRotaryEmbedding,
- VisionTransformerBlock,
-)
-
-__all__ = [
- "parallelize_qwen3_vl",
- "Qwen3VLModel",
- "qwen3_vl_configs",
- "QWEN3_VL_SPECIAL_TOKENS",
-]
-
-QWEN3_VL_SPECIAL_TOKENS = {
- "image_token": "<|image_pad|>",
- "video_token": "<|video_pad|>",
- "vision_start_token": "<|vision_start|>",
- "vision_end_token": "<|vision_end|>",
- "pad_token": "<|endoftext|>",
-}
-
-
-_LINEAR_INIT = {
- "weight": partial(nn.init.trunc_normal_, std=0.02),
- "bias": nn.init.zeros_,
-}
-_NORM_INIT = {"weight": nn.init.ones_}
-_EMBEDDING_INIT = {"weight": partial(nn.init.normal_, std=1.0)}
-_EMBEDDING_SKIP_INIT = {"weight": skip_param_init}
-_POS_EMBED_INIT = {"pos_embed": partial(nn.init.trunc_normal_, mean=0.0, std=0.02)}
-
-_EPS = 1e-6
-
-
-def _output_linear_init(dim: int) -> dict[str, Callable]:
- s = dim**-0.5
- return {
- "weight": partial(nn.init.trunc_normal_, std=s, a=-3 * s, b=3 * s),
- "bias": nn.init.zeros_,
- }
-
-
-def _depth_init(layer_id: int) -> dict[str, Callable]:
- return {
- "weight": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)),
- "bias": nn.init.zeros_,
- }
-
-
-def _depth_experts_init(layer_id: int) -> dict[str, Callable]:
- return {
- "w1_EFD": partial(nn.init.trunc_normal_, std=0.02),
- "w2_EDF": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)),
- "w3_EFD": partial(nn.init.trunc_normal_, std=depth_scaled_std(0.02, layer_id)),
- }
-
-
-def _vl_linear(in_features: int, out_features: int) -> Linear.Config:
- return Linear.Config(
- in_features=in_features,
- out_features=out_features,
- bias=True,
- param_init=_LINEAR_INIT,
- )
-
-
-def _qwen3_vl_norm(dim: int) -> RMSNorm.Config:
- return RMSNorm.Config(normalized_shape=dim, eps=_EPS, param_init=_NORM_INIT)
-
-
-def _vl_layer_norm(normalized_shape: int) -> LayerNorm.Config:
- return LayerNorm.Config(normalized_shape=normalized_shape, eps=_EPS)
-
-
-def _build_qwen3_vl_vision_block(
- *, dim: int, ffn_dim: int, n_heads: int
-) -> VisionTransformerBlock.Config:
- return VisionTransformerBlock.Config(
- attn=VisionAttention.Config(
- dim=dim,
- n_heads=n_heads,
- qkv=_vl_linear(dim, dim * 3),
- proj=_vl_linear(dim, dim),
- ),
- mlp=VisionMLP.Config(
- fc1=_vl_linear(dim, ffn_dim),
- fc2=_vl_linear(ffn_dim, dim),
- ),
- norm1=_vl_layer_norm(dim),
- norm2=_vl_layer_norm(dim),
- )
-
-
-def _build_qwen3_vl_merger(
- *,
- dim: int,
- out_hidden_size: int,
- spatial_merge_size: int,
- use_postshuffle_norm: bool,
-) -> PatchMerger.Config:
- merged_hidden_size = dim * (spatial_merge_size**2)
- norm_dim = merged_hidden_size if use_postshuffle_norm else dim
- return PatchMerger.Config(
- hidden_size=dim,
- spatial_merge_size=spatial_merge_size,
- fc1=_vl_linear(merged_hidden_size, merged_hidden_size),
- fc2=_vl_linear(merged_hidden_size, out_hidden_size),
- norm=_vl_layer_norm(norm_dim),
- use_postshuffle_norm=use_postshuffle_norm,
- )
-
-
-def _vl_vision_encoder_config(
- *,
- dim: int,
- ffn_dim: int,
- n_layers: int,
- n_heads: int,
- patch_size: int,
- temporal_patch_size: int,
- spatial_merge_size: int,
- out_hidden_size: int,
- num_position_embeddings: int,
- deepstack_visual_indices: list[int],
- in_channels: int = 3,
-) -> Qwen3VLVisionEncoder.Config:
- """Build a fully-specified Qwen3VLVisionEncoder.Config."""
- patch_dim = in_channels * temporal_patch_size * patch_size * patch_size
- head_dim = dim // n_heads
- return Qwen3VLVisionEncoder.Config(
- dim=dim,
- n_heads=n_heads,
- spatial_merge_size=spatial_merge_size,
- num_position_embeddings=num_position_embeddings,
- deepstack_visual_indices=deepstack_visual_indices,
- patch_embed=PatchEmbed.Config(
- patch_size=patch_size,
- temporal_patch_size=temporal_patch_size,
- in_channels=in_channels,
- proj=_vl_linear(patch_dim, dim),
- ),
- rotary_pos_emb=VisionRotaryEmbedding.Config(dim=head_dim // 2),
- layers=[
- _build_qwen3_vl_vision_block(dim=dim, ffn_dim=ffn_dim, n_heads=n_heads)
- for _ in range(n_layers)
- ],
- merger=_build_qwen3_vl_merger(
- dim=dim,
- out_hidden_size=out_hidden_size,
- spatial_merge_size=spatial_merge_size,
- use_postshuffle_norm=False,
- ),
- deepstack_mergers=[
- _build_qwen3_vl_merger(
- dim=dim,
- out_hidden_size=out_hidden_size,
- spatial_merge_size=spatial_merge_size,
- use_postshuffle_norm=True,
- )
- for _ in range(len(deepstack_visual_indices))
- ],
- param_init=_POS_EMBED_INIT,
- )
-
-
-def _build_qwen3_vl_layers(
- *,
- n_layers: int,
- dim: int,
- n_heads: int,
- n_kv_heads: int,
- head_dim: int,
- hidden_dim: int,
- attn_backend: str,
- rope: RoPE.Config,
-) -> list[TransformerBlock.Config]:
- """Build per-layer configs for dense Qwen3-VL models with depth-scaled inits."""
- inner_attention = get_attention_config(attn_backend)
- layers = []
- for layer_id in range(n_layers):
- layers.append(
- Qwen3TransformerBlock.Config(
- attention_norm=_qwen3_vl_norm(dim),
- ffn_norm=_qwen3_vl_norm(dim),
- attention=make_gqa_config(
- dim=dim,
- n_heads=n_heads,
- n_kv_heads=n_kv_heads,
- head_dim=head_dim,
- wqkv_param_init=_LINEAR_INIT,
- wo_param_init=_depth_init(layer_id),
- inner_attention=inner_attention,
- rope=rope,
- qk_norm=_qwen3_vl_norm(head_dim),
- ),
- feed_forward=make_ffn_config(
- dim=dim,
- hidden_dim=hidden_dim,
- w1_param_init=_LINEAR_INIT,
- w2w3_param_init=_depth_init(layer_id),
- ),
- )
- )
- return layers
-
-
-def _build_qwen3_vl_moe_layers(
- *,
- n_layers: int,
- dim: int,
- n_heads: int,
- n_kv_heads: int,
- head_dim: int,
- moe_hidden_dim: int,
- num_experts: int,
- top_k: int,
- attn_backend: str,
- moe_comm_backend: str,
- non_blocking_capacity_factor: float | None = None,
- rope: RoPE.Config,
-) -> list[TransformerBlock.Config]:
- """Build per-layer configs for MoE Qwen3-VL models with depth-scaled inits."""
- inner_attention = get_attention_config(attn_backend)
- layers = []
- for layer_id in range(n_layers):
- layers.append(
- Qwen3TransformerBlock.Config(
- attention_norm=_qwen3_vl_norm(dim),
- ffn_norm=_qwen3_vl_norm(dim),
- attention=make_gqa_config(
- dim=dim,
- n_heads=n_heads,
- n_kv_heads=n_kv_heads,
- head_dim=head_dim,
- wqkv_param_init=_LINEAR_INIT,
- wo_param_init=_depth_init(layer_id),
- inner_attention=inner_attention,
- rope=rope,
- qk_norm=_qwen3_vl_norm(head_dim),
- ),
- moe=make_moe_config(
- num_experts=num_experts,
- router=make_router_config(
- dim=dim,
- num_experts=num_experts,
- gate_param_init=_depth_init(layer_id),
- top_k=top_k,
- score_func="softmax",
- route_norm=True,
- ),
- experts=make_experts_config(
- dim=dim,
- hidden_dim=moe_hidden_dim,
- num_experts=num_experts,
- top_k=top_k,
- param_init=_depth_experts_init(layer_id),
- score_before_experts=False,
- comm_backend=moe_comm_backend,
- non_blocking_capacity_factor=non_blocking_capacity_factor,
- ),
- ),
- )
- )
- return layers
-
-
-def _debugmodel(attn_backend: str) -> Qwen3VLModel.Config:
- dim = 256
- head_dim = 64
- n_layers = 4
- vocab_size = 151936
- return Qwen3VLModel.Config(
- vocab_size=vocab_size,
- dim=dim,
- norm=_qwen3_vl_norm(dim),
- tok_embeddings=Embedding.Config(
- num_embeddings=vocab_size,
- embedding_dim=dim,
- param_init=_EMBEDDING_INIT,
- ),
- lm_head=Linear.Config(
- in_features=dim,
- out_features=vocab_size,
- param_init=_output_linear_init(dim),
- ),
- layers=_build_qwen3_vl_layers(
- attn_backend=attn_backend,
- n_layers=n_layers,
- dim=dim,
- n_heads=4,
- n_kv_heads=2,
- head_dim=head_dim,
- hidden_dim=512,
- rope=MRoPE.Config(
- dim=head_dim,
- max_seq_len=4096,
- theta=1000000.0,
- mrope_section=[8, 8, 8],
- ),
- ),
- vision_encoder=_vl_vision_encoder_config(
- dim=256,
- ffn_dim=512,
- n_layers=4,
- n_heads=4,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- out_hidden_size=256,
- num_position_embeddings=1024,
- deepstack_visual_indices=[1, 2, 3],
- ),
- )
-
-
-def _debugmodel_moe(
- attn_backend: str,
- moe_comm_backend: str = "standard",
-) -> Qwen3VLModel.Config:
- dim = 256
- head_dim = 64
- n_layers = 1
- vocab_size = 151936
- return Qwen3VLModel.Config(
- vocab_size=vocab_size,
- dim=dim,
- norm=_qwen3_vl_norm(dim),
- tok_embeddings=Embedding.Config(
- num_embeddings=vocab_size,
- embedding_dim=dim,
- param_init=_EMBEDDING_INIT,
- ),
- lm_head=Linear.Config(
- in_features=dim,
- out_features=vocab_size,
- param_init=_output_linear_init(dim),
- ),
- layers=_build_qwen3_vl_moe_layers(
- attn_backend=attn_backend,
- n_layers=n_layers,
- dim=dim,
- n_heads=4,
- n_kv_heads=2,
- head_dim=head_dim,
- moe_hidden_dim=768,
- num_experts=8,
- top_k=4,
- moe_comm_backend=moe_comm_backend,
- rope=MRoPE.Config(
- dim=head_dim,
- max_seq_len=4096,
- theta=1000000.0,
- mrope_section=[8, 8, 8],
- ),
- ),
- vision_encoder=_vl_vision_encoder_config(
- dim=256,
- ffn_dim=512,
- n_layers=4,
- n_heads=4,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- out_hidden_size=256,
- num_position_embeddings=1024,
- deepstack_visual_indices=[1, 2, 3],
- ),
- )
-
-
-def _2b(attn_backend: str) -> Qwen3VLModel.Config:
- dim = 2048
- head_dim = 128
- n_layers = 28
- vocab_size = 151936
- return Qwen3VLModel.Config(
- vocab_size=vocab_size,
- dim=dim,
- norm=_qwen3_vl_norm(dim),
- enable_weight_tying=True,
- tok_embeddings=Embedding.Config(
- num_embeddings=vocab_size,
- embedding_dim=dim,
- param_init=_EMBEDDING_SKIP_INIT,
- ),
- lm_head=Linear.Config(
- in_features=dim,
- out_features=vocab_size,
- param_init=_output_linear_init(dim),
- ),
- layers=_build_qwen3_vl_layers(
- attn_backend=attn_backend,
- n_layers=n_layers,
- dim=dim,
- n_heads=16,
- n_kv_heads=8,
- head_dim=head_dim,
- hidden_dim=6144,
- rope=MRoPE.Config(
- dim=head_dim,
- max_seq_len=32768,
- theta=5000000.0,
- mrope_section=[24, 20, 20],
- ),
- ),
- vision_encoder=_vl_vision_encoder_config(
- dim=1024,
- ffn_dim=4096,
- n_layers=24,
- n_heads=16,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- out_hidden_size=2048,
- num_position_embeddings=2304,
- deepstack_visual_indices=[5, 11, 17],
- ),
- )
-
-
-def _8b(attn_backend: str) -> Qwen3VLModel.Config:
- dim = 4096
- head_dim = 128
- n_layers = 36
- vocab_size = 151936
- return Qwen3VLModel.Config(
- vocab_size=vocab_size,
- dim=dim,
- norm=_qwen3_vl_norm(dim),
- tok_embeddings=Embedding.Config(
- num_embeddings=vocab_size,
- embedding_dim=dim,
- param_init=_EMBEDDING_INIT,
- ),
- lm_head=Linear.Config(
- in_features=dim,
- out_features=vocab_size,
- param_init=_output_linear_init(dim),
- ),
- layers=_build_qwen3_vl_layers(
- attn_backend=attn_backend,
- n_layers=n_layers,
- dim=dim,
- n_heads=32,
- n_kv_heads=8,
- head_dim=head_dim,
- hidden_dim=12288,
- rope=MRoPE.Config(
- dim=head_dim,
- max_seq_len=32768,
- theta=5000000.0,
- mrope_section=[24, 20, 20],
- ),
- ),
- vision_encoder=_vl_vision_encoder_config(
- dim=1152,
- ffn_dim=4304,
- n_layers=27,
- n_heads=16,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- out_hidden_size=4096,
- num_position_embeddings=2304,
- deepstack_visual_indices=[8, 16, 24],
- ),
- )
-
-
-# Qwen3-VL MoE models
-
-
-def _30b_a3b(
- attn_backend: str,
- moe_comm_backend: str = "standard",
-) -> Qwen3VLModel.Config:
- dim = 2048
- head_dim = 128
- n_layers = 48
- vocab_size = 151936
- return Qwen3VLModel.Config(
- vocab_size=vocab_size,
- dim=dim,
- norm=_qwen3_vl_norm(dim),
- tok_embeddings=Embedding.Config(
- num_embeddings=vocab_size,
- embedding_dim=dim,
- param_init=_EMBEDDING_INIT,
- ),
- lm_head=Linear.Config(
- in_features=dim,
- out_features=vocab_size,
- param_init=_output_linear_init(dim),
- ),
- layers=_build_qwen3_vl_moe_layers(
- attn_backend=attn_backend,
- n_layers=n_layers,
- dim=dim,
- n_heads=32,
- n_kv_heads=4,
- head_dim=head_dim,
- moe_hidden_dim=768,
- num_experts=128,
- top_k=8,
- moe_comm_backend=moe_comm_backend,
- rope=MRoPE.Config(
- dim=head_dim,
- max_seq_len=32768,
- theta=5000000.0,
- mrope_section=[24, 20, 20],
- ),
- ),
- vision_encoder=_vl_vision_encoder_config(
- dim=1152,
- ffn_dim=4304,
- n_layers=27,
- n_heads=16,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- out_hidden_size=2048,
- num_position_embeddings=2304,
- deepstack_visual_indices=[8, 16, 24],
- ),
- )
-
-
-def _235b_a22b(
- attn_backend: str,
- moe_comm_backend: str = "standard",
-) -> Qwen3VLModel.Config:
- dim = 4096
- head_dim = 128
- n_layers = 94
- vocab_size = 151936
- return Qwen3VLModel.Config(
- vocab_size=vocab_size,
- dim=dim,
- norm=_qwen3_vl_norm(dim),
- tok_embeddings=Embedding.Config(
- num_embeddings=vocab_size,
- embedding_dim=dim,
- param_init=_EMBEDDING_INIT,
- ),
- lm_head=Linear.Config(
- in_features=dim,
- out_features=vocab_size,
- param_init=_output_linear_init(dim),
- ),
- layers=_build_qwen3_vl_moe_layers(
- attn_backend=attn_backend,
- n_layers=n_layers,
- dim=dim,
- n_heads=64,
- n_kv_heads=4,
- head_dim=head_dim,
- moe_hidden_dim=1536,
- num_experts=128,
- top_k=8,
- moe_comm_backend=moe_comm_backend,
- rope=MRoPE.Config(
- dim=head_dim,
- max_seq_len=32768,
- theta=5000000.0,
- mrope_section=[24, 20, 20],
- ),
- ),
- vision_encoder=_vl_vision_encoder_config(
- dim=1152,
- ffn_dim=4304,
- n_layers=27,
- n_heads=16,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- out_hidden_size=4096,
- num_position_embeddings=2304,
- deepstack_visual_indices=[8, 16, 24],
- ),
- )
-
-
-qwen3_vl_configs = {
- "debugmodel": _debugmodel,
- "debugmodel_moe": _debugmodel_moe,
- "2B": _2b,
- "8B": _8b,
- "30B-A3B": _30b_a3b,
- "235B-A22B": _235b_a22b,
-}
-
-
-def model_registry(
- flavor: str,
- attn_backend: str = "flex",
- moe_comm_backend: str | None = None,
- converters: list[ModelConfigConverter.Config] | None = None,
-) -> ModelSpec:
- kwargs = dict(attn_backend=attn_backend)
- if moe_comm_backend is not None:
- kwargs["moe_comm_backend"] = moe_comm_backend
- config = qwen3_vl_configs[flavor](**kwargs)
- if converters is not None:
- validate_converter_order(converters)
- for c in converters:
- c.build().convert(config)
- return ModelSpec(
- name="qwen3_vl",
- flavor=flavor,
- model=config,
- parallelize_fn=parallelize_qwen3_vl,
- pipelining_fn=None,
- post_optimizer_build_fn=None,
- state_dict_adapter=Qwen3VLStateDictAdapter,
- )
diff --git a/torchtitan/models/qwen3_vl/config_registry.py b/torchtitan/models/qwen3_vl/config_registry.py
deleted file mode 100644
index a2a7fd59ff..0000000000
--- a/torchtitan/models/qwen3_vl/config_registry.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from torchtitan.components.checkpoint import CheckpointManager
-from torchtitan.components.loss import ChunkedCELoss
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.components.metrics import MetricsProcessor
-from torchtitan.components.optimizer import default_adamw
-from torchtitan.components.tokenizer import MultiModalTokenizer
-
-from torchtitan.config import (
- ActivationCheckpointConfig,
- ParallelismConfig,
- TrainingConfig,
-)
-from torchtitan.hf_datasets.multimodal.mm_datasets import MMDataLoader
-from torchtitan.trainer import Trainer
-
-from . import model_registry, QWEN3_VL_SPECIAL_TOKENS
-
-
-def _qwen3_vl_dataloader(dataset: str, **kwargs) -> MMDataLoader.Config:
- return MMDataLoader.Config(
- dataset=dataset,
- max_images_per_batch=128,
- patch_size=16,
- temporal_patch_size=2,
- spatial_merge_size=2,
- min_pixels=65536,
- max_pixels=16777216,
- image_mean=(0.5, 0.5, 0.5),
- image_std=(0.5, 0.5, 0.5),
- **kwargs,
- )
-
-
-def qwen3_vl_debugmodel() -> Trainer.Config:
- return Trainer.Config(
- loss=ChunkedCELoss.Config(),
- hf_assets_path="./tests/assets/tokenizer",
- tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS),
- metrics=MetricsProcessor.Config(log_freq=1),
- model_spec=model_registry("debugmodel"),
- dataloader=_qwen3_vl_dataloader("cc12m-test"),
- optimizer=default_adamw(lr=8e-4),
- lr_scheduler=LRSchedulersContainer.Config(
- warmup_steps=2,
- decay_ratio=0.8,
- decay_type="linear",
- min_lr_factor=0.0,
- ),
- training=TrainingConfig(
- local_batch_size=1,
- seq_len=512,
- steps=10,
- ),
- checkpoint=CheckpointManager.Config(
- interval=10,
- last_save_model_only=False,
- ),
- activation_checkpoint=ActivationCheckpointConfig(
- mode="selective",
- ),
- )
-
-
-def qwen3_vl_debugmodel_moe() -> Trainer.Config:
- return Trainer.Config(
- loss=ChunkedCELoss.Config(),
- hf_assets_path="./tests/assets/tokenizer",
- tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS),
- metrics=MetricsProcessor.Config(log_freq=1),
- model_spec=model_registry("debugmodel_moe"),
- dataloader=_qwen3_vl_dataloader("cc12m-test"),
- optimizer=default_adamw(lr=3e-3),
- lr_scheduler=LRSchedulersContainer.Config(warmup_steps=2),
- training=TrainingConfig(
- local_batch_size=1,
- seq_len=512,
- steps=10,
- ),
- parallelism=ParallelismConfig(
- data_parallel_shard_degree=4,
- expert_parallel_degree=4,
- tensor_parallel_degree=2,
- ),
- checkpoint=CheckpointManager.Config(
- interval=10,
- last_save_model_only=False,
- ),
- activation_checkpoint=ActivationCheckpointConfig(
- mode="selective",
- ),
- )
-
-
-def qwen3_vl_2b() -> Trainer.Config:
- return Trainer.Config(
- loss=ChunkedCELoss.Config(),
- hf_assets_path="./assets/hf/Qwen3-VL-2B-Instruct",
- tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS),
- model_spec=model_registry("2B"),
- dataloader=_qwen3_vl_dataloader("cc12m"),
- optimizer=default_adamw(lr=8e-4),
- lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
- training=TrainingConfig(
- local_batch_size=8,
- seq_len=4096,
- steps=1000,
- ),
- parallelism=ParallelismConfig(
- data_parallel_shard_degree=-1,
- tensor_parallel_degree=1,
- ),
- checkpoint=CheckpointManager.Config(
- enable=False,
- interval=50,
- last_save_model_only=False,
- export_dtype="float16",
- ),
- activation_checkpoint=ActivationCheckpointConfig(
- mode="full",
- ),
- )
-
-
-def qwen3_vl_8b() -> Trainer.Config:
- return Trainer.Config(
- loss=ChunkedCELoss.Config(),
- hf_assets_path="./assets/hf/Qwen3-VL-8B-Instruct",
- tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS),
- model_spec=model_registry("8B"),
- dataloader=_qwen3_vl_dataloader("cc12m"),
- optimizer=default_adamw(lr=8e-4),
- lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
- training=TrainingConfig(
- local_batch_size=4,
- seq_len=4096,
- steps=1000,
- ),
- parallelism=ParallelismConfig(
- data_parallel_shard_degree=-1,
- tensor_parallel_degree=1,
- ),
- checkpoint=CheckpointManager.Config(
- enable=False,
- interval=50,
- last_save_model_only=False,
- export_dtype="float16",
- ),
- activation_checkpoint=ActivationCheckpointConfig(
- mode="full",
- ),
- )
-
-
-def qwen3_vl_30b_a3b() -> Trainer.Config:
- return Trainer.Config(
- loss=ChunkedCELoss.Config(),
- hf_assets_path="./assets/hf/Qwen3-VL-30B-A3B-Instruct",
- tokenizer=MultiModalTokenizer.Config(**QWEN3_VL_SPECIAL_TOKENS),
- model_spec=model_registry("30B-A3B"),
- dataloader=_qwen3_vl_dataloader("cc12m"),
- optimizer=default_adamw(lr=3e-4),
- lr_scheduler=LRSchedulersContainer.Config(warmup_steps=20),
- training=TrainingConfig(
- local_batch_size=4,
- seq_len=4096,
- steps=1000,
- ),
- parallelism=ParallelismConfig(
- data_parallel_shard_degree=-1,
- tensor_parallel_degree=1,
- expert_parallel_degree=8,
- ),
- checkpoint=CheckpointManager.Config(
- enable=False,
- interval=500,
- last_save_model_only=False,
- export_dtype="float16",
- ),
- activation_checkpoint=ActivationCheckpointConfig(
- mode="full",
- ),
- )
diff --git a/torchtitan/models/qwen3_vl/model.py b/torchtitan/models/qwen3_vl/model.py
deleted file mode 100644
index ce5e5e0374..0000000000
--- a/torchtitan/models/qwen3_vl/model.py
+++ /dev/null
@@ -1,573 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-from dataclasses import dataclass
-
-import torch
-from torch import nn
-
-from torchtitan.models.common.attention import AttentionMasksType, GQAttention
-from torchtitan.models.common.decoder import Decoder
-from torchtitan.models.qwen3.model import Qwen3Model
-from torchtitan.models.utils import get_moe_model_nparams_and_flops
-
-from .vision_encoder import Qwen3VLVisionEncoder
-
-
-class Qwen3VLModel(Qwen3Model):
- """Qwen3-VL: A Vision-Language Model based on Qwen3.
-
- Combines the Qwen3 language model with a Vision Transformer encoder
- for multimodal understanding of images and videos.
-
- Key features:
- - DeepStack: Vision embeddings from intermediate ViT layers are added to
- early LLM hidden states for better multimodal understanding
- - MRoPE: Multi-dimensional RoPE with interleaved temporal/height/width
- position encoding for vision tokens
-
- Forward pass flow::
-
- forward(tokens, pixel_values, grid_thw, ...)
- │
- ├─ _prepare_multimodal_embeds
- │ ├─ tok_embeddings(tokens) → text embeddings
- │ ├─ _get_vision_embeds(pixel_values) → padded vision embeddings + deepstack features
- │ │ └─ vision_encoder(pixel_values) → per-layer features, merge patches
- │ ├─ _compute_vision_positions → locate vision regions in text sequence
- │ └─ _scatter_vision_embeds → copy vision into text at placeholder positions
- │
- ├─ _compute_mrope_position_ids → build 3D MRoPE position IDs
- │
- └─ transformer layers
- └─ for each layer:
- ├─ layer(hidden_states, masks, positions)
- └─ _deepstack_process: add intermediate ViT features at vision positions (early layers only)
- """
-
- @dataclass(kw_only=True, slots=True)
- class Config(Qwen3Model.Config):
- vision_encoder: Qwen3VLVisionEncoder.Config
-
- def update_from_config(
- self,
- *,
- config,
- **kwargs,
- ) -> None:
- Decoder.Config.update_from_config(self, config=config, **kwargs)
- parallelism = config.parallelism
-
- from torchtitan.models.qwen3_vl.sharding import set_qwen3_vl_sharding_config
-
- set_qwen3_vl_sharding_config(
- self,
- loss_parallel=not parallelism.disable_loss_parallel,
- enable_ep=parallelism.expert_parallel_degree > 1,
- )
-
- def get_nparams_and_flops(
- self, model: nn.Module, seq_len: int
- ) -> tuple[int, int]:
- assert isinstance(self.layers[0].attention, GQAttention.Config)
- assert self.layers[0].attention.head_dim is not None
- return get_moe_model_nparams_and_flops(
- self,
- model,
- self.layers[0].attention.n_heads,
- 2 * self.layers[0].attention.head_dim,
- seq_len,
- )
-
- def __init__(self, config: Config):
- super().__init__(config)
-
- self.vision_encoder = config.vision_encoder.build()
-
- self.spatial_merge_size = config.vision_encoder.spatial_merge_size
-
- # Number of early LLM layers that receive DeepStack visual features
- self.num_deepstack_layers = len(config.vision_encoder.deepstack_visual_indices)
-
- def _compute_mrope_position_ids(
- self,
- tokens: torch.Tensor,
- *,
- grid_thw: torch.Tensor | None,
- grid_thw_videos: torch.Tensor | None,
- special_tokens: dict[str, int],
- positions: torch.Tensor | None = None,
- ) -> torch.Tensor:
- """Build 3D position IDs for Qwen3-VL MRoPE.
-
- Constructs temporal/height/width position IDs for each token. The RoPE
- module consumes these IDs and computes the interleaved cos/sin cache.
- Text tokens use the same position value across all three axes. Vision
- tokens use temporal, height, and width positions derived from their
- patch grid.
-
- Args:
- tokens: (batch, seq_len) token IDs
- grid_thw: (num_images, 3) grid dimensions for images
- grid_thw_videos: (num_videos, 3) grid dimensions for videos
- special_tokens: Special token definitions
- positions: (batch, seq_len) per-token position IDs for packed
- sequences. When provided, document boundaries are detected
- where positions reset (positions[t] < positions[t-1]), and
- pos_id_offset resets to 0 at each boundary
-
- Returns:
- (3, batch, seq_len) MRoPE position IDs
- """
- # --- Build 3D position IDs ---
-
- # Expand each video [T, H, W] into T rows of [1, H, W] so that
- # each frame is treated like an image in the MRoPE code below
- # Temporal position comes from frame ordering in the sequence
- if grid_thw_videos is not None:
- grid_thw_videos = torch.repeat_interleave(
- grid_thw_videos, grid_thw_videos[:, 0], dim=0
- )
- grid_thw_videos[:, 0] = 1
-
- spatial_merge_size = self.spatial_merge_size
- image_token_id = special_tokens["image_id"]
- video_token_id = special_tokens["video_id"]
-
- batch_size, seq_len = tokens.shape
- position_ids = torch.zeros(
- 3,
- batch_size,
- seq_len,
- dtype=tokens.dtype,
- device=tokens.device,
- )
-
- # Precompute document boundaries and vision token positions across batch
- if positions is not None:
- resets = positions[:, 1:] < positions[:, :-1] # (batch, seq_len-1)
- # Find the first token of each consecutive vision region (image or video)
- # E.g. for [text, img, img, img, text, vid, vid] → positions [1, 5]
- vision_mask = (tokens == image_token_id) | (tokens == video_token_id)
- prev_vision = torch.cat(
- [torch.zeros_like(vision_mask[:, :1]), vision_mask[:, :-1]], dim=1
- )
- batch_vision_starts = vision_mask & ~prev_vision # (batch, seq_len)
- # Cache vision grid indices by shape to avoid redundant construction
- grid_cache: dict[tuple[int, int, int], torch.Tensor] = {}
-
- image_index, video_index = 0, 0
- # Build MRoPE 3D position IDs per sample
- # With sample packing, each sample may contain multiple documents
- for sample_i in range(batch_size):
- llm_pos_ids_list: list[torch.Tensor] = []
-
- if positions is not None:
- # Detect document boundaries within one packed sample
- # pyrefly: ignore [unbound-name]
- reset_indices = torch.where(resets[sample_i])[0] + 1
- doc_starts = [0] + reset_indices.tolist()
- doc_ranges = [
- (
- doc_starts[d],
- doc_starts[d + 1] if d + 1 < len(doc_starts) else seq_len,
- )
- for d in range(len(doc_starts))
- ]
- else:
- doc_ranges = [(0, seq_len)]
-
- sample_tokens = tokens[sample_i]
- sample_vision_starts = torch.where(batch_vision_starts[sample_i])[
- 0
- ].tolist()
- vision_start_index = 0
-
- for doc_start, doc_end in doc_ranges:
- doc_pos_ids_list: list[torch.Tensor] = []
-
- # Advance pointer to collect vision region starts in this document
- doc_vision_starts: list[int] = []
- while (
- vision_start_index < len(sample_vision_starts)
- and sample_vision_starts[vision_start_index] < doc_end
- ):
- doc_vision_starts.append(sample_vision_starts[vision_start_index])
- vision_start_index += 1
-
- # Process [text tokens][vision tokens] pairs within this document
- pair_cursor = doc_start
- for vision_start in doc_vision_starts:
- if sample_tokens[vision_start] == image_token_id:
- # pyrefly: ignore [unsupported-operation]
- t, h, w = grid_thw[image_index]
- image_index += 1
- else:
- # pyrefly: ignore [unsupported-operation]
- t, h, w = grid_thw_videos[video_index]
- video_index += 1
-
- llm_grid_t, llm_grid_h, llm_grid_w = (
- t.item(),
- h.item() // spatial_merge_size,
- w.item() // spatial_merge_size,
- )
- text_len = vision_start - pair_cursor
-
- # pos_id_offset may differ from pair_cursor due to compact
- # spatial position IDs for vision regions
- pos_id_offset = (
- doc_pos_ids_list[-1].max() + 1
- if len(doc_pos_ids_list) > 0
- else 0
- )
- # [text tokens] — sequential positions, identical on all 3 axes
- doc_pos_ids_list.append(
- torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset
- )
- # [vision tokens] — 3D grid positions (T, H, W)
- grid_key = (llm_grid_t, llm_grid_h, llm_grid_w)
- if grid_key not in grid_cache:
- t_index = (
- torch.arange(llm_grid_t)
- .view(-1, 1)
- # pyrefly: ignore [no-matching-overload]
- .expand(-1, llm_grid_h * llm_grid_w)
- .flatten()
- )
- h_index = (
- torch.arange(llm_grid_h)
- .view(1, -1, 1)
- # pyrefly: ignore [no-matching-overload]
- .expand(llm_grid_t, -1, llm_grid_w)
- .flatten()
- )
- w_index = (
- torch.arange(llm_grid_w)
- .view(1, 1, -1)
- # pyrefly: ignore [no-matching-overload]
- .expand(llm_grid_t, llm_grid_h, -1)
- .flatten()
- )
- # pyrefly: ignore [unsupported-operation]
- grid_cache[grid_key] = torch.stack([t_index, h_index, w_index])
- doc_pos_ids_list.append(
- # pyrefly: ignore [bad-index]
- grid_cache[grid_key]
- + text_len
- + pos_id_offset
- )
- pair_cursor = vision_start + llm_grid_t * llm_grid_h * llm_grid_w
-
- # Trailing [text tokens] after the last [text tokens][vision tokens] pair
- if pair_cursor < doc_end:
- pos_id_offset = (
- doc_pos_ids_list[-1].max() + 1
- if len(doc_pos_ids_list) > 0
- else 0
- )
- text_len = doc_end - pair_cursor
- doc_pos_ids_list.append(
- torch.arange(text_len).view(1, -1).expand(3, -1) + pos_id_offset
- )
-
- llm_pos_ids_list.extend(doc_pos_ids_list)
-
- llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
- position_ids[:, sample_i, :] = llm_positions.to(position_ids.device)
-
- return position_ids
-
- def _compute_vision_positions(
- self,
- tokens: torch.Tensor,
- num_tokens_per_item: torch.Tensor,
- vision_token_id: int,
- ) -> list[tuple[int, int, int, int]]:
- """Compute (item_idx, sample_idx, vision_start, n_tokens) for each vision item.
-
- Finds where each contiguous run of vision placeholder tokens starts
- in the text sequence.
-
- Args:
- tokens: Token IDs (batch, seq_len)
- num_tokens_per_item: (num_items,) actual tokens per vision item
- vision_token_id: Placeholder token ID
-
- Returns:
- List of (item_idx, sample_idx, vision_start, n_tokens) tuples
- """
- vision_mask = tokens == vision_token_id
- flat_mask = vision_mask.view(-1)
- prev_mask = torch.cat(
- [torch.zeros(1, dtype=torch.bool, device=flat_mask.device), flat_mask[:-1]]
- )
- region_starts = torch.where(flat_mask & ~prev_mask)[0]
- seq_len = tokens.shape[1]
-
- positions = []
- for i in range(num_tokens_per_item.shape[0]):
- start = int(region_starts[i].item())
- n_tokens = int(num_tokens_per_item[i].item())
- positions.append((i, start // seq_len, start % seq_len, n_tokens))
- return positions
-
- def _get_vision_embeds(
- self,
- pixel_values: torch.Tensor,
- *,
- grid_thw: torch.Tensor,
- ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
- """Run vision encoder and return padded embeddings with token counts.
-
- Works for both images and videos — the ViT processes them identically.
-
- Args:
- pixel_values: Padded patches (num_items, max_num_patch, patch_dim)
- grid_thw: Grid dimensions (num_items, 3) for [t, h, w]
-
- Returns:
- merged_embeds: (num_items, max_tokens, dim) padded vision embeddings
- deepstack_features: List of (num_items, max_tokens, dim) per layer
- num_tokens_per_item: (num_items,) actual token count per item
- """
- pixel_values = pixel_values.to(
- self.vision_encoder.patch_embed.proj.weight.dtype
- )
- merged_embeds, deepstack_features = self.vision_encoder(
- pixel_values, grid_thw=grid_thw
- )
-
- merge_unit = self.vision_encoder.spatial_merge_unit
- num_tokens_per_item = grid_thw.prod(-1) // merge_unit
-
- return merged_embeds, deepstack_features, num_tokens_per_item
-
- def _scatter_vision_embeds(
- self,
- inputs_embeds: torch.Tensor,
- *,
- merged_embeds: torch.Tensor,
- vision_positions: list[tuple[int, int, int, int]],
- ) -> torch.Tensor:
- """Scatter vision embeddings into text embeddings at placeholder positions.
-
- Copies directly from the padded vision encoder output into the text
- sequence.
-
- Args:
- inputs_embeds: Text embeddings (batch, seq_len, dim)
- merged_embeds: Padded vision embeddings (num_items, max_tokens, dim)
- vision_positions: List of (item_idx, sample_idx, vision_start, n_tokens)
-
- Returns:
- Updated embeddings
- """
- for item_idx, sample_idx, vision_start, n_tokens in vision_positions:
- inputs_embeds[
- sample_idx, vision_start : vision_start + n_tokens, :
- ] = merged_embeds[item_idx, :n_tokens, :]
- return inputs_embeds
-
- def _deepstack_process(
- self,
- hidden_states: torch.Tensor,
- *,
- vision_positions: list[tuple[int, int, int, int]],
- deepstack_embeds: torch.Tensor,
- ) -> torch.Tensor:
- """Add vision embeddings to hidden states at vision token positions.
-
- Args:
- hidden_states: LLM hidden states (batch, seq_len, dim)
- vision_positions: List of (item_idx, sample_idx, vision_start, n_tokens)
- deepstack_embeds: Padded vision embeddings (num_items, max_tokens, dim)
-
- Returns:
- Updated hidden states
- """
- for item_idx, sample_idx, vision_start, n_tokens in vision_positions:
- hidden_states[
- sample_idx, vision_start : vision_start + n_tokens, :
- ] += deepstack_embeds[item_idx, :n_tokens, :]
- return hidden_states
-
- def _prepare_multimodal_embeds(
- self,
- tokens: torch.Tensor,
- *,
- pixel_values: torch.Tensor | None,
- pixel_values_videos: torch.Tensor | None,
- grid_thw: torch.Tensor | None,
- grid_thw_videos: torch.Tensor | None,
- special_tokens: dict[str, int],
- ) -> tuple[
- torch.Tensor,
- list[tuple[int, int, int, int]],
- list[tuple[int, int, int, int]],
- list[torch.Tensor] | None,
- list[torch.Tensor] | None,
- ]:
- """Embed tokens, run vision encoder, scatter vision into text, prepare DeepStack.
-
- Args:
- tokens: Input token IDs (batch_size, seq_len)
- pixel_values: Image patches or None
- pixel_values_videos: Video patches or None
- grid_thw: Grid dimensions for images or None
- grid_thw_videos: Grid dimensions for videos or None
- special_tokens: Special token definitions
-
- Returns:
- inputs_embeds: (batch, seq_len, dim) with vision tokens scattered in
- image_positions: List of (item_idx, sample_idx, vision_start, n_tokens)
- video_positions: List of (item_idx, sample_idx, vision_start, n_tokens)
- deepstack_image_features: List of (num_items, max_tokens, dim) or None
- deepstack_video_features: List of (num_items, max_tokens, dim) or None
- """
- image_token_id = special_tokens["image_id"]
- video_token_id = special_tokens["video_id"]
-
- inputs_embeds = (
- self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
- )
-
- # Process image inputs
- image_positions: list[tuple[int, int, int, int]] = []
- deepstack_image_features = None
- if pixel_values is not None and grid_thw is not None:
- merged_embeds, deepstack_features, num_tokens = self._get_vision_embeds(
- pixel_values, grid_thw=grid_thw
- )
- image_positions = self._compute_vision_positions(
- tokens, num_tokens, image_token_id
- )
- if image_positions:
- inputs_embeds = self._scatter_vision_embeds(
- inputs_embeds,
- merged_embeds=merged_embeds,
- vision_positions=image_positions,
- )
- deepstack_image_features = deepstack_features
-
- # Process video inputs
- video_positions: list[tuple[int, int, int, int]] = []
- deepstack_video_features = None
- if pixel_values_videos is not None and grid_thw_videos is not None:
- merged_embeds, deepstack_features, num_tokens = self._get_vision_embeds(
- pixel_values_videos, grid_thw=grid_thw_videos
- )
- video_positions = self._compute_vision_positions(
- tokens, num_tokens, video_token_id
- )
- if video_positions:
- inputs_embeds = self._scatter_vision_embeds(
- inputs_embeds,
- merged_embeds=merged_embeds,
- vision_positions=video_positions,
- )
- deepstack_video_features = deepstack_features
-
- return (
- inputs_embeds,
- image_positions,
- video_positions,
- deepstack_image_features,
- deepstack_video_features,
- )
-
- def forward( # pyrefly: ignore [bad-override]
- self,
- tokens: torch.Tensor,
- *,
- pixel_values: torch.Tensor | None = None,
- pixel_values_videos: torch.Tensor | None = None,
- grid_thw: torch.Tensor | None = None,
- grid_thw_videos: torch.Tensor | None = None,
- attention_masks: AttentionMasksType | None = None,
- positions: torch.Tensor | None = None,
- special_tokens: dict[str, int],
- ):
- """Forward pass of Qwen3-VL.
-
- Args:
- tokens: Input token IDs (batch_size, seq_len)
- pixel_values: Flattened image patches (num_images, max_num_patches, patch_dim)
- pixel_values_videos: Flattened video patches (num_videos, max_num_patches, patch_dim)
- grid_thw: Grid dimensions for images (num_images, 3)
- grid_thw_videos: Grid dimensions for videos (num_videos, 3)
- attention_masks: Attention masks for block_causal / flex attention
- positions: Per-token position IDs (batch_size, seq_len) for packed sequences.
- Each document's positions reset to 0. None means sequential positions.
- special_tokens: Special token definitions
-
- Returns:
- Output logits (batch_size, seq_len, vocab_size)
- """
- (
- inputs_embeds,
- image_positions,
- video_positions,
- deepstack_image_features,
- deepstack_video_features,
- ) = self._prepare_multimodal_embeds(
- tokens,
- pixel_values=pixel_values,
- pixel_values_videos=pixel_values_videos,
- grid_thw=grid_thw,
- grid_thw_videos=grid_thw_videos,
- special_tokens=special_tokens,
- )
-
- # Compute MRoPE position IDs when vision inputs are present.
- if grid_thw is not None or grid_thw_videos is not None:
- positions = self._compute_mrope_position_ids(
- tokens,
- grid_thw=grid_thw,
- grid_thw_videos=grid_thw_videos,
- special_tokens=special_tokens,
- positions=positions,
- )
-
- # Apply transformer layers with DeepStack
- hidden_states = inputs_embeds
- for layer_idx, layer in self.layers.items():
- hidden_states = layer(hidden_states, attention_masks, positions)
-
- # Apply DeepStack: add visual features to early layer hidden states
- layer_idx_int = int(layer_idx)
- if layer_idx_int < self.num_deepstack_layers:
- if (
- deepstack_image_features is not None
- and image_positions
- and layer_idx_int < len(deepstack_image_features)
- ):
- hidden_states = self._deepstack_process(
- hidden_states,
- vision_positions=image_positions,
- deepstack_embeds=deepstack_image_features[layer_idx_int],
- )
- if (
- deepstack_video_features is not None
- and video_positions
- and layer_idx_int < len(deepstack_video_features)
- ):
- hidden_states = self._deepstack_process(
- hidden_states,
- vision_positions=video_positions,
- deepstack_embeds=deepstack_video_features[layer_idx_int],
- )
-
- hidden_states = (
- self.norm(hidden_states) if self.norm is not None else hidden_states
- )
- if self._skip_lm_head:
- return hidden_states
- output = (
- self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
- )
- return output
diff --git a/torchtitan/models/qwen3_vl/parallelize.py b/torchtitan/models/qwen3_vl/parallelize.py
deleted file mode 100644
index 006d0ea82d..0000000000
--- a/torchtitan/models/qwen3_vl/parallelize.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Parallelization utilities for Qwen3-VL.
-
-This module applies PT-D parallelisms and various training techniques
-(activation checkpointing, compile, FSDP) to the Qwen3-VL model.
-"""
-
-import torch
-import torch.nn as nn
-
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed.tensor import distribute_tensor, Replicate
-from torch.distributed.tensor.parallel import (
- ColwiseParallel,
- parallelize_module,
- RowwiseParallel,
-)
-
-from torchtitan.config import (
- ActivationCheckpointConfig,
- CompileConfig,
- ParallelismConfig,
- TORCH_DTYPE_MAP,
- TrainingConfig,
-)
-
-from torchtitan.distributed import ParallelDims
-from torchtitan.distributed.activation_checkpoint import apply_ac
-from torchtitan.distributed.compile import apply_compile
-from torchtitan.distributed.fsdp import (
- apply_fsdp_to_decoder,
- get_fsdp_reshard_after_forward_policy,
-)
-from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp, NoParallel
-from torchtitan.models.qwen3_vl.model import Qwen3VLModel
-from torchtitan.tools.logging import logger
-
-
-def _apply_tp_to_vision_encoder(
- vision_encoder: nn.Module,
- tp_mesh: DeviceMesh,
-):
- """Apply tensor parallelism to the vision encoder.
-
- Hidden states flow as DTensor(Replicate) throughout — all ranks hold the
- full hidden_states. Only the linear layers (qkv, proj, fc1, fc2) are
- sharded via ColwiseParallel/RowwiseParallel to save memory. Norms operate
- on Replicate DTensors directly.
- """
- # NoParallel on patch_embed distributes its params as Replicate DTensors
- # on tp_mesh for FSDP mesh consistency. Its input hook wraps plain
- # pixel_values as DTensor(Replicate), and the output stays as DTensor
- # (Replicate) to flow through the rest of the vision encoder.
- parallelize_module(vision_encoder, tp_mesh, {"patch_embed": NoParallel()})
-
- # pos_embed is an nn.Parameter (not a submodule), so it can't be targeted
- # by parallelize_module's plan dict. We distribute it as Replicate DTensor
- # on tp_mesh for FSDP mesh consistency.
- vision_encoder.pos_embed = nn.Parameter(
- # pyrefly: ignore [bad-argument-type]
- distribute_tensor(vision_encoder.pos_embed.data, tp_mesh, [Replicate()]),
- requires_grad=vision_encoder.pos_embed.requires_grad,
- )
-
- # TP plan for each vision transformer block.
- # hidden_states flows through as DTensor (Replicate) so residual adds work.
- # NoParallel on norms sets their params as Replicate DTensors on tp_mesh
- # (for consistent (fsdp, tp) mesh after FSDP).
- # RowwiseParallel uses use_local_output=False to return DTensor (Replicate)
- # so residual connections (hidden_states + attn/mlp output) stay in DTensor
- # space. ColwiseParallel uses use_local_output=True (default) to return
- # local shards for the internal attention/MLP computation.
- layer_plan = {
- "norm1": NoParallel(),
- "norm2": NoParallel(),
- "attn.qkv": ColwiseParallel(), # needs plain tensor for reshape after qkv
- "attn.proj": RowwiseParallel(use_local_output=False),
- "mlp.linear_fc1": ColwiseParallel(use_local_output=False),
- "mlp.linear_fc2": RowwiseParallel(use_local_output=False),
- }
-
- # pyrefly: ignore [not-callable]
- for transformer_block in vision_encoder.layers.values():
- # pyrefly: ignore [bad-argument-type]
- parallelize_module(transformer_block, tp_mesh, layer_plan)
-
- # TP plan for patch mergers (main + deepstack).
- # Mergers output DTensor(Replicate) — the model passes padded embeddings
- # directly to vision scatter and DeepStack.
- merger_plan = {
- "norm": NoParallel(),
- "linear_fc1": ColwiseParallel(use_local_output=False),
- "linear_fc2": RowwiseParallel(use_local_output=False),
- }
-
- # pyrefly: ignore [bad-argument-type]
- parallelize_module(vision_encoder.merger, tp_mesh, merger_plan)
- # pyrefly: ignore [not-iterable]
- for merger in vision_encoder.deepstack_merger_list:
- # pyrefly: ignore [bad-argument-type]
- parallelize_module(merger, tp_mesh, merger_plan)
-
- logger.info("Applied Tensor Parallelism to the vision encoder")
-
-
-def _apply_fsdp_to_vision_encoder(
- vision_encoder: nn.Module,
- dp_mesh: DeviceMesh,
- param_dtype: torch.dtype,
- reduce_dtype: torch.dtype,
- reshard_after_forward_policy: str = "default",
- pp_enabled: bool = False,
-):
- """
- Apply FSDP to the vision encoder as a single unit.
-
- Wraps the entire vision encoder with one fully_shard call so all parameters
- are gathered in a single AllGather. The vision encoder's compute is small
- relative to the decoder, so per-layer sharding would launch many small
- AllGather kernels whose total overhead exceeds a single AllGather followed
- by computing all layers in one shot — even without overlap.
-
- Must be called before apply_fsdp on the decoder so the vision encoder is
- already sharded when the final fully_shard(model) is applied.
- """
- mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
- fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
- reshard_after_forward = get_fsdp_reshard_after_forward_policy(
- reshard_after_forward_policy, pp_enabled=pp_enabled
- )
-
- fully_shard(
- vision_encoder,
- **fsdp_config,
- reshard_after_forward=reshard_after_forward,
- )
-
-
-def parallelize_qwen3_vl(
- model: Qwen3VLModel,
- *,
- parallel_dims: ParallelDims,
- training: TrainingConfig,
- parallelism: ParallelismConfig,
- compile_config: CompileConfig,
- ac_config: ActivationCheckpointConfig,
- dump_folder: str,
-):
- """
- Apply tensor parallelism, activation checkpointing, torch.compile, and data
- parallelism to the Qwen3-VL model.
-
- NOTE: The passed-in model preferably should be on meta device. Otherwise,
- the model must fit on GPU or CPU memory.
- """
- if parallelism.spmd_backend == "full_dtensor":
- raise NotImplementedError("full_dtensor is not supported yet.")
-
- model_compile_enabled = (
- compile_config.enable and "model" in compile_config.components
- )
-
- if parallel_dims.cp_enabled:
- raise NotImplementedError("Context Parallel is not yet supported for Qwen3-VL.")
-
- if parallel_dims.tp_enabled:
- # TODO(@fegin): Apply TP to vision encoder (still uses parallelize_module path)
- if model.vision_encoder is not None:
- _apply_tp_to_vision_encoder(
- model.vision_encoder, parallel_dims.get_mesh("tp")
- )
-
- if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
- model.parallelize(parallel_dims)
-
- if parallel_dims.tp_enabled:
- maybe_enable_async_tp(parallelism, compile_config, parallel_dims.get_mesh("tp"))
-
- # Apply activation checkpointing
- if ac_config.mode != "none":
- apply_ac(
- model,
- ac_config,
- model_compile_enabled=model_compile_enabled,
- base_folder=dump_folder,
- )
- if model.vision_encoder is not None:
- apply_ac(
- model.vision_encoder,
- ac_config,
- model_compile_enabled=model_compile_enabled,
- base_folder=dump_folder,
- )
-
- # Apply torch.compile after AC wrapping and before FSDP
- if model_compile_enabled:
- apply_compile(model, compile_config)
- if model.vision_encoder is not None:
- apply_compile(model.vision_encoder, compile_config)
-
- # Apply FSDP / HSDP unconditionally (fully_shard handles dp_shard=1)
- dp_mesh_names = (
- ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
- )
- dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
-
- # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
- edp_mesh = None
- if parallel_dims.ep_enabled:
- edp_mesh_names = (
- ["dp_replicate", "efsdp"]
- if parallel_dims.dp_replicate_enabled
- else ["efsdp"]
- )
- edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names)
-
- # FSDP the vision encoder as a single unit (see _apply_fsdp_to_vision_encoder)
- if model.vision_encoder is not None:
- _apply_fsdp_to_vision_encoder(
- model.vision_encoder,
- dp_mesh,
- param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param],
- reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce],
- reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward,
- pp_enabled=parallel_dims.pp_enabled,
- )
-
- # FSDP the decoder with MoE-aware sharding
- apply_fsdp_to_decoder(
- model,
- dp_mesh,
- param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param],
- reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce],
- pp_enabled=parallel_dims.pp_enabled,
- cpu_offload=training.enable_cpu_offload,
- reshard_after_forward_policy=parallelism.fsdp_reshard_after_forward,
- ep_degree=parallel_dims.ep,
- edp_mesh=edp_mesh,
- )
-
- return model
diff --git a/torchtitan/models/qwen3_vl/sharding.py b/torchtitan/models/qwen3_vl/sharding.py
deleted file mode 100644
index d93f105de9..0000000000
--- a/torchtitan/models/qwen3_vl/sharding.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import TYPE_CHECKING
-
-from torchtitan.models.qwen3.sharding import set_qwen3_sharding_config
-
-if TYPE_CHECKING:
- from torchtitan.models.qwen3_vl.model import Qwen3VLModel
-
-
-def set_qwen3_vl_sharding_config(
- config: "Qwen3VLModel.Config",
- *,
- loss_parallel: bool,
- enable_ep: bool,
-) -> None:
- """Fill ``sharding_config`` on all Qwen3-VL sub-configs.
-
- Delegates to ``set_qwen3_sharding_config`` with ``enable_sp=False``
- because Qwen3-VL keeps hidden states as ``Replicate`` (not
- ``Shard(1)``) — no SequenceParallel due to vision scatter and
- DeepStack needing full-sequence access.
- """
- set_qwen3_sharding_config(
- config,
- loss_parallel=loss_parallel,
- enable_sp=False,
- enable_ep=enable_ep,
- )
diff --git a/torchtitan/models/qwen3_vl/state_dict_adapter.py b/torchtitan/models/qwen3_vl/state_dict_adapter.py
deleted file mode 100644
index be5c097f23..0000000000
--- a/torchtitan/models/qwen3_vl/state_dict_adapter.py
+++ /dev/null
@@ -1,341 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-State dict adapter for Qwen3-VL.
-
-Converts between HuggingFace Qwen3-VL checkpoint format and torchtitan format.
-
-MoE expert weights require two transformations:
-- **Transpose**: HF and TT use transposed layouts for grouped 3D expert weights.
- E.g. HF down_proj [E, hidden, dim] <-> TT w2 [E, dim, hidden].
-- **Fuse/split gate_up_proj**: HF fuses gate_proj and up_proj into a single
- gate_up_proj [E, dim, 2*hidden_dim]. TT stores them separately as
- w1 [E, hidden_dim, dim] and w3 [E, hidden_dim, dim].
-
-Other notable conversions:
-- Conv3d patch embedding (HF) <-> Linear (TT) via weight reshape
-- Vision block naming: HF `blocks` <-> TT `layers`
-"""
-
-import re
-from typing import Any
-
-import torch
-
-from torchtitan.models.common.attention import FusedQKVLinear
-from torchtitan.protocols.state_dict_adapter import StateDictAdapter
-from .model import Qwen3VLModel
-from .rope import MRoPE
-
-
-class Qwen3VLStateDictAdapter(StateDictAdapter):
- def __init__(self, model_config: Qwen3VLModel.Config, hf_assets_path: str | None):
- super().__init__(model_config, hf_assets_path)
- self.model_config = model_config
- self.fuse_qkv = isinstance(
- model_config.layers[0].attention.qkv_linear, FusedQKVLinear.Config
- )
-
- qkv_map: dict[str, str | None]
- if self.fuse_qkv:
- qkv_map = {
- "model.language_model.layers.{}.self_attn.q_proj.weight": None,
- "model.language_model.layers.{}.self_attn.k_proj.weight": None,
- "model.language_model.layers.{}.self_attn.v_proj.weight": None,
- }
- else:
- qkv_map = {
- "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.qkv_linear.wq.weight",
- "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.qkv_linear.wk.weight",
- "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.qkv_linear.wv.weight",
- }
-
- self.from_hf_map = {
- # ===== Language Model =====
- "model.language_model.embed_tokens.weight": "tok_embeddings.weight",
- # Attention
- **qkv_map,
- "model.language_model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
- "model.language_model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
- "model.language_model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
- "model.language_model.layers.{}.self_attn.rotary_emb.inv_freq": None,
- # Non-MoE MLP
- "model.language_model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
- "model.language_model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
- "model.language_model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
- # Layer norms
- "model.language_model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
- "model.language_model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
- # MoE (grouped 3D format, handled specially in to_hf/from_hf)
- # gate_up_proj is fused gate+up, mapped to w1+w3 via custom logic
- "model.language_model.layers.{}.mlp.experts.down_proj": "layers.{}.moe.experts.w2",
- "model.language_model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight",
- # Final norm and output
- "model.language_model.norm.weight": "norm.weight",
- "lm_head.weight": "lm_head.weight",
- # ===== Vision Encoder =====
- # Patch embedding (Conv3d in HF, Linear in TT - weight reshape needed)
- "model.visual.patch_embed.proj.weight": "vision_encoder.patch_embed.proj.weight",
- "model.visual.patch_embed.proj.bias": "vision_encoder.patch_embed.proj.bias",
- # Position embeddings
- "model.visual.pos_embed.weight": "vision_encoder.pos_embed",
- # Vision transformer blocks (HF: blocks, TT: layers)
- "model.visual.blocks.{}.norm1.weight": "vision_encoder.layers.{}.norm1.weight",
- "model.visual.blocks.{}.norm1.bias": "vision_encoder.layers.{}.norm1.bias",
- "model.visual.blocks.{}.norm2.weight": "vision_encoder.layers.{}.norm2.weight",
- "model.visual.blocks.{}.norm2.bias": "vision_encoder.layers.{}.norm2.bias",
- "model.visual.blocks.{}.attn.qkv.weight": "vision_encoder.layers.{}.attn.qkv.weight",
- "model.visual.blocks.{}.attn.qkv.bias": "vision_encoder.layers.{}.attn.qkv.bias",
- "model.visual.blocks.{}.attn.proj.weight": "vision_encoder.layers.{}.attn.proj.weight",
- "model.visual.blocks.{}.attn.proj.bias": "vision_encoder.layers.{}.attn.proj.bias",
- "model.visual.blocks.{}.mlp.linear_fc1.weight": "vision_encoder.layers.{}.mlp.linear_fc1.weight",
- "model.visual.blocks.{}.mlp.linear_fc1.bias": "vision_encoder.layers.{}.mlp.linear_fc1.bias",
- "model.visual.blocks.{}.mlp.linear_fc2.weight": "vision_encoder.layers.{}.mlp.linear_fc2.weight",
- "model.visual.blocks.{}.mlp.linear_fc2.bias": "vision_encoder.layers.{}.mlp.linear_fc2.bias",
- # Merger (maps vision dim to LLM dim)
- "model.visual.merger.norm.weight": "vision_encoder.merger.norm.weight",
- "model.visual.merger.norm.bias": "vision_encoder.merger.norm.bias",
- "model.visual.merger.linear_fc1.weight": "vision_encoder.merger.linear_fc1.weight",
- "model.visual.merger.linear_fc1.bias": "vision_encoder.merger.linear_fc1.bias",
- "model.visual.merger.linear_fc2.weight": "vision_encoder.merger.linear_fc2.weight",
- "model.visual.merger.linear_fc2.bias": "vision_encoder.merger.linear_fc2.bias",
- # DeepStack mergers
- "model.visual.deepstack_merger_list.{}.norm.weight": "vision_encoder.deepstack_merger_list.{}.norm.weight",
- "model.visual.deepstack_merger_list.{}.norm.bias": "vision_encoder.deepstack_merger_list.{}.norm.bias",
- "model.visual.deepstack_merger_list.{}.linear_fc1.weight": "vision_encoder.deepstack_merger_list.{}.linear_fc1.weight",
- "model.visual.deepstack_merger_list.{}.linear_fc1.bias": "vision_encoder.deepstack_merger_list.{}.linear_fc1.bias",
- "model.visual.deepstack_merger_list.{}.linear_fc2.weight": "vision_encoder.deepstack_merger_list.{}.linear_fc2.weight",
- "model.visual.deepstack_merger_list.{}.linear_fc2.bias": "vision_encoder.deepstack_merger_list.{}.linear_fc2.bias",
- }
-
- def _get_attention_dims(self) -> tuple[int, int, int]:
- """Return (n_heads, n_kv_heads, head_dim) from model config."""
- # pyrefly: ignore [missing-attribute]
- attn = self.model_config.layers[0].attention
- n_heads = attn.n_heads
- n_kv_heads = attn.n_kv_heads if attn.n_kv_heads is not None else n_heads
- head_dim = (
- attn.head_dim
- if attn.head_dim is not None
- else self.model_config.dim // n_heads # pyrefly: ignore [missing-attribute]
- )
- return n_heads, n_kv_heads, head_dim
-
- def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
- """Convert torchtitan state dict to HuggingFace Qwen3-VL format."""
-
- if self.fuse_qkv:
- to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None}
- n_heads, n_kv_heads, head_dim = self._get_attention_dims()
- else:
- to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None}
- hf_state_dict = {}
-
- # Collect MoE w1/w3 per layer to fuse into gate_up_proj
- moe_w1_by_layer: dict[str, Any] = {}
- moe_w3_by_layer: dict[str, Any] = {}
-
- for tt_key, value in state_dict.items():
- if "moe.experts" in tt_key:
- tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1)
- # pyrefly: ignore [missing-attribute]
- layer_num = re.search(r"\d+", tt_key).group(0)
-
- # Collect w1/w3 for fusing into gate_up_proj
- if tt_abstract_key == "layers.{}.moe.experts.w1":
- moe_w1_by_layer[layer_num] = value
- continue
- elif tt_abstract_key == "layers.{}.moe.experts.w3":
- moe_w3_by_layer[layer_num] = value
- continue
-
- # Handle down_proj: TT w2 [E, dim, hidden] -> HF [E, hidden, dim]
- if tt_abstract_key == "layers.{}.moe.experts.w2":
- hf_key = (
- f"model.language_model.layers.{layer_num}.mlp.experts.down_proj"
- )
- hf_state_dict[hf_key] = value.transpose(-2, -1)
- continue
-
- if tt_abstract_key not in to_hf_map:
- continue
- hf_key = to_hf_map[tt_abstract_key].format(layer_num)
- hf_state_dict[hf_key] = value
- # Indexed key: contains a layer/block/merger index (e.g. ".0.")
- elif re.search(r"\.\d+\.", tt_key):
- tt_abstract_key = re.sub(r"(\d+)", "{}", tt_key, count=1)
- # pyrefly: ignore [missing-attribute]
- layer_num = re.search(r"\d+", tt_key).group(0)
-
- # Handle fused QKV: split wqkv into separate q/k/v projections
- if (
- self.fuse_qkv
- and tt_abstract_key == "layers.{}.attention.qkv_linear.wqkv.weight"
- ):
- wq, wk, wv = self.fused_to_separate_qkv(
- value,
- n_heads, # pyrefly: ignore [unbound-name]
- n_kv_heads, # pyrefly: ignore [unbound-name]
- head_dim, # pyrefly: ignore [unbound-name]
- )
- hf_state_dict[
- f"model.language_model.layers.{layer_num}.self_attn.q_proj.weight"
- ] = wq
- hf_state_dict[
- f"model.language_model.layers.{layer_num}.self_attn.k_proj.weight"
- ] = wk
- hf_state_dict[
- f"model.language_model.layers.{layer_num}.self_attn.v_proj.weight"
- ] = wv
- continue
-
- if tt_abstract_key not in to_hf_map:
- continue
- hf_key = to_hf_map[tt_abstract_key].format(layer_num)
- hf_state_dict[hf_key] = value
-
- else:
- if tt_key not in to_hf_map:
- continue
- if (
- tt_key == "lm_head.weight"
- and self.model_config.enable_weight_tying # pyrefly: ignore [missing-attribute]
- ):
- continue
- hf_key = to_hf_map[tt_key]
- hf_value = value
- # Linear weight (out, C*T*H*W) -> Conv3d weight (out, C, T, H, W)
- # Plain reshape since both use channel-first (c pt ph pw) layout.
- if tt_key == "vision_encoder.patch_embed.proj.weight":
- # pyrefly: ignore [missing-attribute]
- patch_embed = self.model_config.vision_encoder.patch_embed
- hf_value = value.reshape(
- value.shape[0],
- patch_embed.in_channels,
- patch_embed.temporal_patch_size,
- patch_embed.patch_size,
- patch_embed.patch_size,
- )
- hf_state_dict[hf_key] = hf_value
-
- # Fuse w1 (gate) and w3 (up) into gate_up_proj per layer
- # TT w1/w3: [E, hidden_dim, dim] -> transpose to [E, dim, hidden_dim] -> cat on last dim
- for layer_num in moe_w1_by_layer:
- w1 = moe_w1_by_layer[layer_num].transpose(-2, -1) # [E, dim, hidden_dim]
- w3 = moe_w3_by_layer[layer_num].transpose(-2, -1) # [E, dim, hidden_dim]
- gate_up = torch.cat([w1, w3], dim=-1) # [E, dim, 2*hidden_dim]
- hf_key = f"model.language_model.layers.{layer_num}.mlp.experts.gate_up_proj"
- hf_state_dict[hf_key] = gate_up
-
- return hf_state_dict
-
- def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
- """Convert HuggingFace Qwen3-VL state dict to torchtitan format."""
- self._validate_hf_rope_config(MRoPE.Config)
- tt_state_dict = {}
- # Collect Q/K/V per layer for fusing (only used when fuse_qkv=True)
- pending_qkv: dict[str, dict[str, torch.Tensor]] = {}
-
- if self.fuse_qkv:
- n_heads, n_kv_heads, head_dim = self._get_attention_dims()
-
- # HF Qwen3-VL ties lm_head.weight with embed_tokens.weight,
- # so lm_head.weight may not be stored in the checkpoint.
- if "lm_head.weight" not in hf_state_dict:
- if "model.language_model.embed_tokens.weight" not in hf_state_dict:
- raise ValueError(
- "HF checkpoint missing both 'lm_head.weight' and "
- "'model.language_model.embed_tokens.weight'"
- )
- hf_state_dict["lm_head.weight"] = hf_state_dict[
- "model.language_model.embed_tokens.weight"
- ]
-
- for hf_key, value in hf_state_dict.items():
- # Indexed key: contains a layer/block/merger index (e.g. ".0.")
- if re.search(r"\.\d+\.", hf_key):
- hf_abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=1)
- # pyrefly: ignore [missing-attribute]
- idx = re.search(r"\d+", hf_key).group(0)
-
- # Handle fused QKV: collect q/k/v and fuse when all 3 are ready
- if self.fuse_qkv and hf_abstract_key in (
- "model.language_model.layers.{}.self_attn.q_proj.weight",
- "model.language_model.layers.{}.self_attn.k_proj.weight",
- "model.language_model.layers.{}.self_attn.v_proj.weight",
- ):
- if idx not in pending_qkv:
- pending_qkv[idx] = {}
- proj = hf_abstract_key.split(".")[-2] # q_proj, k_proj, v_proj
- pending_qkv[idx][proj] = value
- if len(pending_qkv[idx]) == 3:
- fused = self.separate_to_fused_qkv(
- pending_qkv[idx]["q_proj"],
- pending_qkv[idx]["k_proj"],
- pending_qkv[idx]["v_proj"],
- n_heads, # pyrefly: ignore [unbound-name]
- n_kv_heads, # pyrefly: ignore [unbound-name]
- head_dim, # pyrefly: ignore [unbound-name]
- )
- tt_state_dict[
- f"layers.{idx}.attention.qkv_linear.wqkv.weight"
- ] = fused
- del pending_qkv[idx]
- continue
-
- # Handle fused gate_up_proj: split and transpose
- # HF gate_up_proj: [E, dim, 2*hidden_dim] -> split -> transpose each to [E, hidden_dim, dim]
- if (
- hf_abstract_key
- == "model.language_model.layers.{}.mlp.experts.gate_up_proj"
- ):
- w1_hf, w3_hf = value.chunk(2, dim=-1) # each [E, dim, hidden_dim]
- tt_state_dict[f"layers.{idx}.moe.experts.w1"] = w1_hf.transpose(
- -2, -1
- )
- tt_state_dict[f"layers.{idx}.moe.experts.w3"] = w3_hf.transpose(
- -2, -1
- )
- continue
-
- # Handle down_proj transpose: HF [E, hidden, dim] -> TT w2 [E, dim, hidden]
- if (
- hf_abstract_key
- == "model.language_model.layers.{}.mlp.experts.down_proj"
- ):
- tt_state_dict[f"layers.{idx}.moe.experts.w2"] = value.transpose(
- -2, -1
- )
- continue
-
- if hf_abstract_key not in self.from_hf_map:
- continue
- tt_key = self.from_hf_map[hf_abstract_key]
- if tt_key is None:
- continue
- tt_key = tt_key.format(idx)
- tt_state_dict[tt_key] = value
-
- else:
- if hf_key not in self.from_hf_map:
- continue
- tt_key = self.from_hf_map[hf_key]
- if tt_key is None:
- continue
- tt_value = value
- # Conv3d weight (out, C, T, H, W) -> Linear weight (out, C*T*H*W)
- # Plain reshape since both use channel-first (c pt ph pw) layout.
- if hf_key == "model.visual.patch_embed.proj.weight":
- tt_value = value.reshape(value.shape[0], -1)
- tt_state_dict[tt_key] = tt_value
-
- if self.fuse_qkv and pending_qkv:
- raise ValueError(
- f"Incomplete Q/K/V projections for layers: {list(pending_qkv.keys())}"
- )
-
- return tt_state_dict
diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py
index 96b9ec0b00..5df66cc5d4 100644
--- a/torchtitan/models/utils.py
+++ b/torchtitan/models/utils.py
@@ -521,9 +521,15 @@ def get_moe_model_nparams_and_flops(
nparams_for_matmul = nparams_dense + nparams_sparse_active
else:
nparams_for_matmul = nparams_dense - nparams_embedding + nparams_sparse_active
+ # Only full attention layers contribute the quadratic O(L²) FLOPs
+ # term. Hybrid models mix full attention with linear attention
+ # layers whose block leaves ``attention=None``; standard decoders carry
+ # full attention on every layer, so this counts all of them.
+ num_full_attn = sum(
+ 1 for l in model_config.layers if getattr(l, "attention", None) is not None
+ )
num_flops_per_token = (
- 6 * nparams_for_matmul
- + 6 * len(model_config.layers) * n_heads * head_dims * seq_len
+ 6 * nparams_for_matmul + 6 * num_full_attn * n_heads * head_dims * seq_len
)
return nparams, num_flops_per_token
diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py
index d4209351b4..ec09dc502a 100644
--- a/torchtitan/trainer.py
+++ b/torchtitan/trainer.py
@@ -542,7 +542,7 @@ def batch_generator(
@sl.log_trace_span("post_dataloading_process")
def post_dataloading_process(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
- ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+ ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
"""
Post-processing hook after data loading and before model forward pass.
@@ -561,31 +561,21 @@ def post_dataloading_process(
labels: Target labels for the batch.
Returns:
- A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
+ A tuple of (inputs, labels, extra_kwargs) where:
- inputs: Main input tensor extracted from input_dict["input"].
- labels: Target labels (unchanged from input parameter).
- - extra_inputs: Dict of auxiliary input tensors from input_dict
- (excluding "input" and "positions"). These are passed to the
- model forward but are NOT forwarded across pipeline parallel
- stages.
- - extra_kwargs: Dict of additional keyword arguments for model
- forward (positions, attention_masks). These ARE forwarded
- across all pipeline parallel stages.
-
- Note:
- The distinction between extra_inputs and extra_kwargs is important for
- pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
- while extra_inputs are only available to the first stage. Positions
- always go into extra_kwargs so every stage can apply RoPE correctly.
+ - extra_kwargs: Additional keyword arguments for the model forward
+ (e.g. positions, attention_masks), forwarded to every
+ pipeline-parallel stage.
"""
inputs = input_dict["input"]
- extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
- # extra_kwargs are forwarded to all PP stages; extra_inputs are only
- # available to the first stage. Positions go into extra_kwargs so
- # every stage can apply RoPE correctly.
- extra_kwargs: dict[str, Any] = {}
+ # Everything else becomes a model-forward kwarg, forwarded to all PP
+ # stages by the schedule. ``positions`` is read (not popped) to build masks.
+ extra_kwargs: dict[str, Any] = {
+ k: v for k, v in input_dict.items() if k != "input"
+ }
- positions = extra_inputs.pop("positions", None)
+ positions = extra_kwargs.get("positions", None)
# positions and attention_masks are optional (Decoder.forward defaults
# both to None). Build attention masks only for the masked backends
@@ -594,7 +584,9 @@ def post_dataloading_process(
# tests) still receives positions for RoPE but no masks — it relies on
# is_causal instead.
if isinstance(self.model_config, Decoder.Config) and positions is not None:
- inner_attention = self.model_config.layers[0].attention.inner_attention
+ inner_attention = getattr(
+ self.model_config.first_attention, "inner_attention", None
+ )
if isinstance(
inner_attention, (FlexAttention.Config, VarlenAttention.Config)
):
@@ -603,8 +595,6 @@ def post_dataloading_process(
positions=positions,
)
- extra_kwargs["positions"] = positions
-
if self.parallel_dims.cp_enabled:
inputs, labels, extra_kwargs = prepare_context_parallel_input(
inputs,
@@ -624,7 +614,7 @@ def post_dataloading_process(
self.parallel_dims, inputs, labels, extra_kwargs
)
- return inputs, labels, extra_inputs, extra_kwargs
+ return inputs, labels, extra_kwargs
@sl.log_trace_span("fwd_bwd")
def forward_backward_step(
@@ -637,9 +627,7 @@ def forward_backward_step(
model_parts = self.model_parts
parallel_dims = self.parallel_dims
- inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
- input_dict, labels
- )
+ inputs, labels, extra_kwargs = self.post_dataloading_process(input_dict, labels)
if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
@@ -651,7 +639,6 @@ def forward_backward_step(
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs,
- **extra_inputs,
**extra_kwargs,
target=targets,
losses=losses,
@@ -679,7 +666,7 @@ def forward_backward_step(
# Non-PP forward / backward
assert len(model_parts) == 1
with self.train_context():
- pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
+ pred = model_parts[0](inputs, **extra_kwargs)
# Under non-full_dtensor, labels stay as plain tensors. See
# ``cross_entropy_loss`` for why pred must also be plain.
# Remove once non-full_dtensor is no longer supported.