Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ci/docker/requirements-vlm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ av
einops
pillow
torchvision
flash-linear-attention
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ testpaths = ["tests"]

[tool.pyrefly]
project-excludes = ["torchtitan/experiments", "**/tests/**"]
replace-imports-with-any = ["torchao.*", "torchft", "torchvision.*", "deep_ep.*", "jinja2.*"] # optional dependencies
replace-imports-with-any = ["torchao.*", "torchft", "torchvision.*", "deep_ep.*", "jinja2.*", "fla.*"] # optional dependencies
search-path = ["../pytorch"] # local built pytorch
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,42 @@
# LICENSE file in the root directory of this source tree.

"""
End-to-end numerical correctness test for Qwen3-VL checkpoint conversion.
End-to-end numerical correctness test for Qwen3.5 checkpoint conversion.

Compares HuggingFace and TorchTitan next-token logits on multimodal inputs
(random image + text prompt). Each pipeline uses its own image preprocessing
so the test validates the full path: pixels → vision encoder → decoder → logits.

Usage:
python -m scripts.checkpoint_conversion.numerical_tests_qwen3_vl \
--hf_model_path /path/to/Qwen3-VL-2B-Instruct \
--tt_checkpoint_path /path/to/qwen3_vl_2b_dcp

python -m scripts.checkpoint_conversion.numerical_tests_qwen3_vl \
--hf_model_path /path/to/Qwen3-VL-30B-A3B-Instruct \
--tt_checkpoint_path /path/to/qwen3_vl_30b_a3b_dcp \
--model_flavor 30B-A3B
python -m scripts.checkpoint_conversion.numerical_tests_qwen3_5 \
--hf_model_path hf_assets/Qwen/Qwen3.5-2B \
--tt_checkpoint_path outputs/Qwen/qwen3_5_2b_dcp \
--model_flavor 2B
"""

import argparse
import os
import types
from typing import Any

import einops as E
import torch
import torch._dynamo
import torch.distributed.checkpoint as dcp
import torch.nn.functional as F
from PIL import Image

torch._dynamo.config.disable = True

from torchtitan.components.checkpoint import ModelWrapper
from torchtitan.models.qwen3_vl import model_registry
from transformers import AutoProcessor
from torchtitan.hf_datasets.multimodal.mm_collator import MultiModalCollator
from torchtitan.hf_datasets.multimodal.utils.image import (
process_image,
vision_to_patches,
)
from torchtitan.models.common.attention import ScaledDotProductAttention
from torchtitan.models.qwen3_5 import model_registry, QWEN3_5_SPECIAL_TOKENS
from transformers import AutoModelForImageTextToText, AutoProcessor


# ============================================================
Expand All @@ -44,7 +50,7 @@


def kl_divergence(logits_a, logits_b):
"""KL(a || b) between two logit tensors."""
"""KL(softmax(b) || softmax(a)) — F.kl_div(log Q, P) computes KL(P || Q)."""
return F.kl_div(
F.log_softmax(logits_a, dim=-1),
F.softmax(logits_b, dim=-1),
Expand Down Expand Up @@ -77,25 +83,32 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):

Returns:
hf_inputs: list of dicts (processor output, ready for HF model)
tt_inputs: list of (input_ids, pixel_values, grid_thw)
tt_inputs: list of (input_ids, pixel_values, grid_thw, mrope_positions)
pixel_comparisons: list of per-sample pixel diff stats
special_tokens: {"image_id", "video_id"} from the tokenizer
"""
import einops as E
from PIL import Image

from torchtitan.hf_datasets.multimodal.utils.image import (
process_image,
vision_to_patches,
)

processor = AutoProcessor.from_pretrained(hf_model_path, trust_remote_code=True)
# Annotate as Any: AutoProcessor.from_pretrained is typed Optional, which
# trips the .apply_chat_template call on environments without transformers stubs.
processor: Any = AutoProcessor.from_pretrained(hf_model_path)

model_config = model_registry(model_flavor).model
encoder_config = model_config.vision_encoder # pyrefly: ignore[missing-attribute]
patch_size = encoder_config.patch_embed.patch_size
temporal_patch_size = encoder_config.patch_embed.temporal_patch_size
# pyrefly: ignore [missing-attribute]
encoder_config = model_config.vision_encoder
patch_size = encoder_config.patch_size
temporal_patch_size = encoder_config.temporal_patch_size
merge_size = encoder_config.spatial_merge_size

image_token_id = processor.tokenizer.convert_tokens_to_ids(
QWEN3_5_SPECIAL_TOKENS["image_token"]
)
video_token_id = processor.tokenizer.convert_tokens_to_ids(
QWEN3_5_SPECIAL_TOKENS["video_token"]
)
special_tokens = {"image_id": image_token_id, "video_id": video_token_id}
# Reuse the collator's MRoPE builder (only needs spatial_merge_size) so the
# 3D position IDs match the training path exactly.
mrope_builder = types.SimpleNamespace(spatial_merge_size=merge_size)

hf_inputs, tt_inputs, pixel_comparisons = [], [], []

for i in range(num_samples):
Expand All @@ -117,7 +130,7 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
],
}
]
hf_in = processor.apply_chat_template( # pyrefly: ignore[missing-attribute]
hf_in = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
Expand All @@ -137,8 +150,23 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
temporal_patch_size,
merge_size,
)
# 3D MRoPE positions (1, S, 3); positions=None → single document.
mrope_positions = MultiModalCollator._build_mrope_positions(
mrope_builder, # pyrefly: ignore [bad-argument-type]
hf_in["input_ids"],
grid_thw.unsqueeze(0),
None,
None,
image_token_id=image_token_id,
video_token_id=video_token_id,
)
tt_inputs.append(
(hf_in["input_ids"], patches.unsqueeze(0), grid_thw.unsqueeze(0))
(
hf_in["input_ids"],
patches.unsqueeze(0),
grid_thw.unsqueeze(0),
mrope_positions,
)
)

# --- Compare pixel values in image space ---
Expand All @@ -162,7 +190,7 @@ def build_inputs(hf_model_path, model_flavor, num_samples, image_size=224):
tt_img = E.rearrange(patches, pattern, **kwargs)
pixel_comparisons.append(_compare_images(hf_img[:1], tt_img[:1], i))

return hf_inputs, tt_inputs, pixel_comparisons
return hf_inputs, tt_inputs, pixel_comparisons, special_tokens


def _compare_images(hf_img, tt_img, sample_idx):
Expand Down Expand Up @@ -207,13 +235,11 @@ def print_pixel_comparisons(comparisons):
@torch.no_grad()
def run_hf(model_path, hf_inputs, device):
"""Run HF model, return last-token logits per sample."""
from transformers import AutoModelForImageTextToText

print(f"Loading HuggingFace model on {device} ...")
model = AutoModelForImageTextToText.from_pretrained(
model_path,
device_map=device,
torch_dtype=torch.float16,
dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True,
)
Expand All @@ -240,15 +266,15 @@ def run_hf(model_path, hf_inputs, device):


@torch.no_grad()
def run_tt(model_flavor, checkpoint_path, tt_inputs, device):
def run_tt(model_flavor, checkpoint_path, tt_inputs, special_tokens, device):
"""Run TT model, return last-token logits per sample."""
print(f"Loading TorchTitan model on {device} ...")

model_config = model_registry(model_flavor).model
with torch.device("meta"):
model = model_config.build()
model.to_empty(device="cpu")
model.init_weights(buffer_device=torch.device("cpu"))
model.init_states(buffer_device=torch.device("cpu"))
model.half()

state_dict = ModelWrapper(model)._get_state_dict()
Expand All @@ -258,11 +284,9 @@ def run_tt(model_flavor, checkpoint_path, tt_inputs, device):

# Replace FlexAttention with SDPA for single-process inference
# (unfused FlexAttention without torch.compile has poor fp16 numerics).
from torchtitan.models.common.attention import ScaledDotProductAttention

for layer in model.layers.values():
layer.attention.attn_backend = "sdpa"
layer.attention.inner_attention = ScaledDotProductAttention.Config().build()
if layer.full_attn:
layer.attn.inner_attention = ScaledDotProductAttention.Config().build()

class _BidirectionalSDPA(torch.nn.Module):
def forward(self, q, k, v, **kwargs):
Expand All @@ -278,20 +302,13 @@ def forward(self, q, k, v, **kwargs):

model.eval()

special_tokens = {
"image_id": 151655,
"video_id": 151656,
"vision_start_id": 151652,
"vision_end_id": 151653,
"pad_id": 151643,
}

outputs = []
for i, (tokens, pixel_values, grid_thw) in enumerate(tt_inputs):
for i, (tokens, pixel_values, grid_thw, mrope_positions) in enumerate(tt_inputs):
logits = model(
tokens.to(device),
pixel_values=pixel_values.half().to(device),
grid_thw=grid_thw.to(device),
mrope_positions=mrope_positions.to(device),
special_tokens=special_tokens,
)
outputs.append(logits[:, -1:, :].cpu())
Expand Down Expand Up @@ -319,12 +336,13 @@ def compare(hf_outputs, tt_outputs):
kl = kl_divergence(hf, tt).item()
top1, top5 = top_k_match(hf, tt)
cos = F.cosine_similarity(hf.flatten(), tt.flatten(), dim=0).item()
max_diff = (hf - tt).abs().max().item()
total_kl += kl
total_top1 += top1
total_top5 += top5
print(
f" Sample {i + 1:2d}: KL={kl:.4e} cos={cos:.6f} "
f"top1={top1:.0%} top5={top5:.0%}"
f"max_diff={max_diff:.4e} top1={top1:.0%} top5={top5:.0%}"
)

n = len(hf_outputs)
Expand All @@ -351,16 +369,11 @@ def compare(hf_outputs, tt_outputs):

def main():
parser = argparse.ArgumentParser(
description="End-to-end numerical correctness test for Qwen3-VL.",
description="End-to-end numerical correctness test for Qwen3.5.",
)
parser.add_argument("--hf_model_path", type=str, required=True)
parser.add_argument("--tt_checkpoint_path", type=str, required=True)
parser.add_argument(
"--model_flavor",
type=str,
default="2B",
choices=["2B", "8B", "30B-A3B", "235B-A22B"],
)
parser.add_argument("--model_flavor", type=str, default="4B")
parser.add_argument("--num_samples", type=int, default=10)
args = parser.parse_args()

Expand All @@ -373,7 +386,7 @@ def main():
print(f"Using {torch.cuda.get_device_name(0)}")

print(f"\nBuilding {args.num_samples} test samples ...")
hf_inputs, tt_inputs, pixel_comparisons = build_inputs(
hf_inputs, tt_inputs, pixel_comparisons, special_tokens = build_inputs(
args.hf_model_path,
args.model_flavor,
args.num_samples,
Expand All @@ -388,6 +401,7 @@ def main():
args.model_flavor,
args.tt_checkpoint_path,
tt_inputs,
special_tokens,
device,
)

Expand Down
Loading
Loading