From 621af712aedfb1ee6ffcff2d76c032705978e98e Mon Sep 17 00:00:00 2001 From: jglee-sqbits Date: Tue, 24 Mar 2026 08:37:27 +0000 Subject: [PATCH] [Models] Add QwenImage pipeline (rebased onto main) --- .../diffusion/simple_offline_generation.py | 113 ++- .../max/pipelines/architectures/__init__.py | 5 + .../architectures/qwen_image/__init__.py | 17 + .../architectures/qwen_image/arch.py | 61 ++ .../qwen_image/layers/__init__.py | 25 + .../qwen_image/layers/embeddings.py | 274 ++++++ .../qwen_image/layers/normalizations.py | 83 ++ .../qwen_image/layers/qwen_image_attention.py | 641 ++++++++++++++ .../architectures/qwen_image/model.py | 88 ++ .../architectures/qwen_image/model_config.py | 59 ++ .../qwen_image/pipeline_qwen_image.py | 828 ++++++++++++++++++ .../architectures/qwen_image/qwen_image.py | 252 ++++++ .../qwen_image/weight_adapters.py | 31 + .../architectures/qwen_image/BUILD.bazel | 50 ++ .../architectures/qwen_image/conftest.py | 191 ++++ .../qwen_image/test_attention.py | 192 ++++ .../qwen_image/test_scheduler_parity.py | 181 ++++ .../qwen_image/test_text_encoder_parity.py | 220 +++++ .../qwen_image/testdata/BUILD.bazel | 16 + .../qwen_image/testdata/config.json | 12 + 20 files changed, 3312 insertions(+), 27 deletions(-) create mode 100644 max/python/max/pipelines/architectures/qwen_image/__init__.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/arch.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/layers/__init__.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/layers/embeddings.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/layers/normalizations.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/model.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/model_config.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/pipeline_qwen_image.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/qwen_image.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/weight_adapters.py create mode 100644 max/tests/integration/architectures/qwen_image/BUILD.bazel create mode 100644 max/tests/integration/architectures/qwen_image/conftest.py create mode 100644 max/tests/integration/architectures/qwen_image/test_attention.py create mode 100644 max/tests/integration/architectures/qwen_image/test_scheduler_parity.py create mode 100644 max/tests/integration/architectures/qwen_image/test_text_encoder_parity.py create mode 100644 max/tests/integration/architectures/qwen_image/testdata/BUILD.bazel create mode 100644 max/tests/integration/architectures/qwen_image/testdata/config.json diff --git a/max/examples/diffusion/simple_offline_generation.py b/max/examples/diffusion/simple_offline_generation.py index 4f983c43799..969220b500b 100644 --- a/max/examples/diffusion/simple_offline_generation.py +++ b/max/examples/diffusion/simple_offline_generation.py @@ -83,6 +83,17 @@ "Flux2Pipeline_ModuleV3", "Flux2KleinPipeline_ModuleV3", } +QWEN_IMAGE_ARCH_NAMES = { + "QwenImagePipeline", + "QwenImageEditPipeline", + "QwenImageEditPlusPipeline", +} +QWEN_IMAGE_EDIT_ARCH_NAMES = { + "QwenImageEditPipeline", + "QwenImageEditPlusPipeline", +} +QWEN_DEFAULT_GUIDANCE_SCALE = 1.0 +QWEN_DEFAULT_TRUE_CFG_SCALE = 4.0 def parse_args(argv: list[str] | None = None) -> argparse.Namespace: @@ -158,8 +169,21 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser.add_argument( "--guidance-scale", type=float, - default=3.5, - help="Guidance scale for classifier-free guidance. Set to 1.0 to disable CFG.", + default=None, + help=( + "Guidance scale for classifier-free guidance. " + "If omitted, defaults to 1.0 for QwenImage family and 3.5 otherwise." + ), + ) + parser.add_argument( + "--true-cfg-scale", + type=float, + default=None, + help=( + "True classifier-free guidance scale. " + "If omitted, defaults to 4.0 for QwenImage family when negative prompt is provided, " + "and 1.0 otherwise." + ), ) parser.add_argument( "--seed", @@ -188,8 +212,9 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser.add_argument( "--input-image", type=str, + action="append", default=None, - help="Input image for image-to-image generation.", + help="Input image for image-to-image generation. Can be specified multiple times.", ) parser.add_argument( "--profile-timings", @@ -263,16 +288,19 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: assert args.num_inference_steps > 0, ( "num-inference-steps must be a positive integer." ) - assert args.guidance_scale > 0.0, "guidance-scale must be positive." - if args.residual_threshold is not None: + if args.guidance_scale is not None: + assert args.guidance_scale > 0.0, "guidance-scale must be positive." + if args.true_cfg_scale is not None: + assert args.true_cfg_scale > 0.0, "true-cfg-scale must be positive." + if hasattr(args, 'residual_threshold') and args.residual_threshold is not None: assert args.residual_threshold >= 0.0, ( "residual-threshold must be non-negative." ) - if args.taylorseer_cache_interval is not None: + if hasattr(args, 'taylorseer_cache_interval') and args.taylorseer_cache_interval is not None: assert args.taylorseer_cache_interval >= 1, ( "taylorseer-cache-interval must be >= 1." ) - if args.taylorseer_warmup_steps is not None: + if hasattr(args, 'taylorseer_warmup_steps') and args.taylorseer_warmup_steps is not None: assert args.taylorseer_warmup_steps >= 1, ( "taylorseer-warmup-steps must be >= 1." ) @@ -381,6 +409,8 @@ async def generate_image(args: argparse.Namespace) -> None: ) if arch.name in _FLUX2_ARCH_NAMES: max_length = 512 + elif arch.name in QWEN_IMAGE_ARCH_NAMES: + max_length = 512 print(f"Using max length: {max_length} for tokenizer") if ( @@ -433,27 +463,50 @@ async def generate_image(args: argparse.Namespace) -> None: print(f"Generating image for prompt: '{args.prompt}'") # Step 4: Create an OpenResponsesRequest - # Load input image if provided and convert to data URI - input_image_data_uri = load_image_as_data_uri(args.input_image) + # Load input images if provided and convert to data URIs + input_image_data_uris: list[str] = [] + if args.input_image: + for img_path in args.input_image: + uri = load_image_as_data_uri(img_path) + if uri is not None: + input_image_data_uris.append(uri) + + is_qwen_image_family = arch.name in QWEN_IMAGE_ARCH_NAMES + guidance_scale = args.guidance_scale + if guidance_scale is None: + guidance_scale = ( + QWEN_DEFAULT_GUIDANCE_SCALE if is_qwen_image_family else 3.5 + ) + + true_cfg_scale = args.true_cfg_scale + if true_cfg_scale is None: + if is_qwen_image_family and args.negative_prompt is not None: + true_cfg_scale = QWEN_DEFAULT_TRUE_CFG_SCALE + else: + true_cfg_scale = 1.0 - # Create request with structured message if image is provided - if input_image_data_uri: + # Create request with structured message if images are provided + if input_image_data_uris: # Image-to-image: Use structured message with InputImageContent + InputTextContent + image_content_items: list[InputImageContent | InputTextContent] = [ + InputImageContent( + type="input_image", + image_url=uri, + ) + for uri in input_image_data_uris + ] + image_content_items.append( + InputTextContent( + type="input_text", + text=args.prompt, + ) + ) body = OpenResponsesRequestBody( model=args.model, input=[ UserMessage( role="user", - content=[ - InputImageContent( - type="input_image", - image_url=input_image_data_uri, - ), - InputTextContent( - type="input_text", - text=args.prompt, - ), - ], + content=image_content_items, ) ], seed=args.seed, @@ -463,7 +516,8 @@ async def generate_image(args: argparse.Namespace) -> None: height=args.height, width=args.width, steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, ) ), ) @@ -479,7 +533,8 @@ async def generate_image(args: argparse.Namespace) -> None: height=args.height, width=args.width, steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, ) ), ) @@ -487,7 +542,8 @@ async def generate_image(args: argparse.Namespace) -> None: request = OpenResponsesRequest(request_id=RequestID(), body=body) print( - f"Parameters: steps={args.num_inference_steps}, guidance={args.guidance_scale}" + "Parameters: " + f"steps={args.num_inference_steps}, guidance={guidance_scale}, true_cfg={true_cfg_scale}" ) # Step 5: Create a PixelContext object from the request @@ -545,16 +601,19 @@ async def generate_image(args: argparse.Namespace) -> None: height=args.height, width=args.width, steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, ) ), ) request_warmup = OpenResponsesRequest( request_id=RequestID(), body=body_warmup ) - input_image = Image.open(args.input_image) if args.input_image else None + warmup_image = ( + Image.open(args.input_image[0]) if args.input_image else None + ) context_warmup = await tokenizer.new_context( - request_warmup, input_image=input_image + request_warmup, input_image=warmup_image ) inputs_warmup = PixelGenerationInputs[PixelContext]( batch={context_warmup.request_id: context_warmup} diff --git a/max/python/max/pipelines/architectures/__init__.py b/max/python/max/pipelines/architectures/__init__.py index fe7e2771c56..61c0a49b417 100644 --- a/max/python/max/pipelines/architectures/__init__.py +++ b/max/python/max/pipelines/architectures/__init__.py @@ -79,6 +79,8 @@ def register_all_models() -> None: from .qwen3_embedding import qwen3_embedding_arch from .qwen3_embedding_modulev3 import qwen3_embedding_modulev3_arch from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch + from .qwen_image import qwen_image_arch + from .qwen_image_edit import qwen_image_edit_arch, qwen_image_edit_plus_arch from .unified_eagle_llama3 import unified_eagle_llama3_arch from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch @@ -131,6 +133,9 @@ def register_all_models() -> None: qwen3_embedding_modulev3_arch, qwen3vl_arch, qwen3vl_moe_arch, + qwen_image_arch, + qwen_image_edit_arch, + qwen_image_edit_plus_arch, unified_eagle_llama3_arch, unified_mtp_deepseekV3_arch, ] diff --git a/max/python/max/pipelines/architectures/qwen_image/__init__.py b/max/python/max/pipelines/architectures/qwen_image/__init__.py new file mode 100644 index 00000000000..d56d2ba2d67 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/__init__.py @@ -0,0 +1,17 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .arch import qwen_image_arch +from .model import QwenImageTransformerModel + +__all__ = ["QwenImageTransformerModel", "qwen_image_arch"] diff --git a/max/python/max/pipelines/architectures/qwen_image/arch.py b/max/python/max/pipelines/architectures/qwen_image/arch.py new file mode 100644 index 00000000000..8e4779f101d --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/arch.py @@ -0,0 +1,61 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from dataclasses import dataclass + +from max.graph.weights import WeightsFormat +from max.interfaces import PipelineTask +from max.pipelines.core import PixelContext +from max.pipelines.lib import ( + PixelGenerationTokenizer, + SupportedArchitecture, +) +from max.pipelines.lib.config import PipelineConfig +from max.pipelines.lib.interfaces import ArchConfig +from typing_extensions import Self + +from .pipeline_qwen_image import QwenImagePipeline + + +@dataclass(kw_only=True) +class QwenImageArchConfig(ArchConfig): + """Pipeline-level config for QwenImage (implements ArchConfig; no KV cache).""" + + pipeline_config: PipelineConfig + + def get_max_seq_len(self) -> int: + return 0 # Not used for pixel generation. + + @classmethod + def initialize(cls, pipeline_config: PipelineConfig) -> Self: + if len(pipeline_config.model.device_specs) != 1: + raise ValueError("QwenImage is only supported on a single device") + return cls(pipeline_config=pipeline_config) + + +qwen_image_arch = SupportedArchitecture( + name="QwenImagePipeline", + task=PipelineTask.PIXEL_GENERATION, + default_encoding="bfloat16", + supported_encodings={"bfloat16": []}, + example_repo_ids=[ + "Qwen/Qwen-Image-2512", + ], + pipeline_model=QwenImagePipeline, # type: ignore[arg-type] + context_type=PixelContext, + default_weights_format=WeightsFormat.safetensors, + tokenizer=PixelGenerationTokenizer, + config=QwenImageArchConfig, +) diff --git a/max/python/max/pipelines/architectures/qwen_image/layers/__init__.py b/max/python/max/pipelines/architectures/qwen_image/layers/__init__.py new file mode 100644 index 00000000000..7e25fdb6f46 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/layers/__init__.py @@ -0,0 +1,25 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .embeddings import QwenImagePosEmbed, QwenImageTimestepProjEmbeddings +from .qwen_image_attention import ( + QwenImageFeedForward, + QwenImageTransformerBlock, +) + +__all__ = [ + "QwenImageFeedForward", + "QwenImagePosEmbed", + "QwenImageTimestepProjEmbeddings", + "QwenImageTransformerBlock", +] diff --git a/max/python/max/pipelines/architectures/qwen_image/layers/embeddings.py b/max/python/max/pipelines/architectures/qwen_image/layers/embeddings.py new file mode 100644 index 00000000000..f137a135d2c --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/layers/embeddings.py @@ -0,0 +1,274 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Embeddings for QwenImage transformer: timestep projection and 3D RoPE.""" + +import math + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn.layer import Module +from max.nn.linear import Linear + + +def get_timestep_embedding( + timesteps: TensorValue, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> TensorValue: + """Create sinusoidal timestep embeddings.""" + half_dim = embedding_dim // 2 + + exponent = -math.log(max_period) * ops.range( + 0, half_dim, dtype=DType.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = ops.exp(exponent) + timesteps_f32 = ops.cast(timesteps, DType.float32) + emb = ops.outer(timesteps_f32, emb) * scale + emb = ops.concat([ops.sin(emb), ops.cos(emb)], axis=-1) + + if flip_sin_to_cos: + emb = ops.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1) + + if embedding_dim % 2 == 1: + emb = ops.pad(emb, [0, 0, 0, 1]) + + return emb + + +def apply_rotary_emb( + x: TensorValue, + freqs_cis: tuple[TensorValue, TensorValue], + sequence_dim: int = 1, +) -> TensorValue: + """Apply rotary embeddings to input tensor (complex-multiply path). + + Matches diffusers' ``use_real=False`` path: + view x as complex pairs, multiply by ``cos + i·sin``, flatten back. + + Because MAX graph has no complex dtype we expand the multiplication + manually: ``(x_re + i·x_im)(cos + i·sin)`` + = ``(x_re·cos - x_im·sin) + i·(x_re·sin + x_im·cos)`` + + Args: + x: Input tensor [B, S, H, D] (sequence_dim=1). + freqs_cis: ``(cos, sin)`` each of shape ``[S, D//2]``. + sequence_dim: Dimension index for sequence length (1 or 2). + """ + cos, sin = freqs_cis # [S, D//2] + + # Broadcast freqs to match x layout + if sequence_dim == 2: + # x: [B, H, S, D] → cos/sin: [1, 1, S, D//2] + cos = ops.unsqueeze(ops.unsqueeze(cos, 0), 0) + sin = ops.unsqueeze(ops.unsqueeze(sin, 0), 0) + elif sequence_dim == 1: + # x: [B, S, H, D] → cos/sin: [1, S, 1, D//2] + cos = ops.unsqueeze(ops.unsqueeze(cos, 0), 2) + sin = ops.unsqueeze(ops.unsqueeze(sin, 0), 2) + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + + input_dtype = x.dtype + x_shape = list(x.shape) + + # Split last dim into (D//2, 2) pairs — real and imaginary parts + x_pairs = ops.reshape(x, x_shape[:-1] + [x_shape[-1] // 2, 2]) + x_re = x_pairs[..., 0] # [B, S, H, D//2] + x_im = x_pairs[..., 1] # [B, S, H, D//2] + + # Complex multiply in float32 + x_re = ops.cast(x_re, DType.float32) + x_im = ops.cast(x_im, DType.float32) + cos = ops.cast(cos, DType.float32) + sin = ops.cast(sin, DType.float32) + + out_re = x_re * cos - x_im * sin + out_im = x_re * sin + x_im * cos + + # Interleave back: [B, S, H, D//2, 2] → [B, S, H, D] + out = ops.stack([out_re, out_im], axis=-1) + out = ops.reshape(out, x_shape) + return ops.cast(out, input_dtype) + + +def get_1d_rotary_pos_embed( + dim: int, + pos: TensorValue, + theta: float = 10000.0, +) -> tuple[TensorValue, TensorValue]: + """Precompute rotary position embeddings for one axis. + + Returns ``(cos, sin)`` each of shape ``[S, dim // 2]``, matching the + complex-multiply convention used by diffusers. + """ + if dim % 2 != 0: + raise ValueError(f"dim must be even, got {dim}") + freq_exponent = ( + ops.range( + 0, + dim, + 2, + dtype=DType.float32, + device=pos.device, + ) + / dim + ) + freq = 1.0 / (theta**freq_exponent) + freqs = ops.outer(pos, freq) # [S, dim // 2] + return ops.cos(freqs), ops.sin(freqs) + + +class Timesteps(Module): + def __init__( + self, + num_channels: int, + flip_sin_to_cos: bool, + downscale_freq_shift: float, + scale: float = 1.0, + ): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def __call__(self, timesteps: TensorValue) -> TensorValue: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class TimestepEmbedding(Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + sample_proj_bias: bool = True, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.GPU(), + ) -> None: + super().__init__() + self.linear_1 = Linear( + in_dim=in_channels, + out_dim=time_embed_dim, + dtype=dtype, + device=device, + has_bias=sample_proj_bias, + ) + self.linear_2 = Linear( + in_dim=time_embed_dim, + out_dim=time_embed_dim, + dtype=dtype, + device=device, + has_bias=sample_proj_bias, + ) + + def __call__(self, sample: TensorValue) -> TensorValue: + sample = self.linear_1(sample) + sample = ops.silu(sample) + sample = self.linear_2(sample) + return sample + + +class QwenImageTimestepProjEmbeddings(Module): + """Timestep-only projection embeddings (no guidance embedding). + + Unlike Flux2 which combines timestep + guidance, QwenImage only uses timestep + since guidance_embeds=False. + """ + + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 3072, + bias: bool = False, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.GPU(), + ): + super().__init__() + self.time_proj = Timesteps( + num_channels=in_channels, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=embedding_dim, + sample_proj_bias=bias, + dtype=dtype, + device=device, + ) + + def __call__(self, timestep: TensorValue) -> TensorValue: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + ops.cast(timesteps_proj, timestep.dtype) + ) + return timesteps_emb + + +class QwenImagePosEmbed(Module): + """3D Rotary Position Embeddings for QwenImage. + + Uses axes_dims_rope = (16, 56, 56) for (T, H, W) dimensions, + compared to Flux2's 4D (32, 32, 32, 32). + """ + + theta: int + axes_dim: tuple[int, ...] + + def __init__(self, theta: int, axes_dim: tuple[int, ...]): + super().__init__() + self.theta = theta + self.axes_dim = tuple(axes_dim) + + def __call__(self, ids: TensorValue) -> tuple[TensorValue, TensorValue]: + """Compute rotary position embeddings from position IDs. + + Args: + ids: Position IDs of shape [S, len(axes_dim)] (3D: T, H, W). + + Returns: + Tuple of (cos, sin) tensors of shape [S, sum(axes_dim)//2]. + """ + cos_out = [] + sin_out = [] + + pos = ops.cast(ids, DType.float32) + + for i in range(len(self.axes_dim)): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[..., i], + theta=self.theta, + ) + cos_out.append(cos) + sin_out.append(sin) + + freqs_cos = ops.concat(cos_out, axis=-1) + freqs_sin = ops.concat(sin_out, axis=-1) + + return freqs_cos, freqs_sin diff --git a/max/python/max/pipelines/architectures/qwen_image/layers/normalizations.py b/max/python/max/pipelines/architectures/qwen_image/layers/normalizations.py new file mode 100644 index 00000000000..ffb12af1c49 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/layers/normalizations.py @@ -0,0 +1,83 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Normalization layers for QwenImage transformer (module v2).""" + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn.layer import Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + + +class LayerNormNoAffine(Module): + """LayerNorm over the last dimension without learned affine parameters.""" + + def __init__(self, eps: float = 1e-5) -> None: + super().__init__() + self.eps = eps + + def __call__(self, x: TensorValue) -> TensorValue: + dim = x.shape[-1] + gamma = ops.broadcast_to( + ops.constant(1.0, x.dtype, device=x.device), + shape=[dim], + ) + beta = ops.broadcast_to( + ops.constant(0.0, x.dtype, device=x.device), + shape=[dim], + ) + return ops.layer_norm( + x, + gamma=gamma, + beta=beta, + epsilon=self.eps, + ) + + +class AdaLayerNormContinuous(Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + *, + dtype: DType, + device: DeviceRef, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "layer_norm", + ) -> None: + super().__init__() + self.linear = Linear( + in_dim=conditioning_embedding_dim, + out_dim=embedding_dim * 2, + dtype=dtype, + device=device, + has_bias=bias, + ) + if norm_type == "layer_norm": + self.norm: Module = LayerNormNoAffine(eps=eps) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, dtype=dtype, eps=eps) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def __call__( + self, + x: TensorValue, + conditioning_embedding: TensorValue, + ) -> TensorValue: + emb = self.linear(ops.cast(ops.silu(conditioning_embedding), x.dtype)) + width = x.shape[-1] + scale, shift = ops.split(emb, [width, width], axis=1) + return self.norm(x) * (1 + scale[:, None, :]) + shift[:, None, :] diff --git a/max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py b/max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py new file mode 100644 index 00000000000..41f8a966374 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py @@ -0,0 +1,641 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage attention layers: dual-stream attention, FeedForward, and transformer block. + +Weight key naming follows HuggingFace diffusers conventions: +- Attention: attn.to_q, attn.to_k, attn.to_v, attn.to_out.0, attn.add_q_proj, etc. +- FeedForward: img_mlp.net.0.proj (SwiGLU), img_mlp.net.2 (output linear) +- Modulation: img_mod.1 (Linear after SiLU), txt_mod.1 +- Norms: img_norm1, img_norm2, txt_norm1, txt_norm2 (no affine, no weights) +""" + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn.attention.mask_config import MHAMaskVariant +from max.nn.kernels import flash_attention_gpu +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + +from .embeddings import apply_rotary_emb +from .normalizations import LayerNormNoAffine + +# --------------------------------------------------------------------------- +# FeedForward (matches diffusers naming: net.0.proj, net.2) +# --------------------------------------------------------------------------- + + +class _QwenImageGELU(Module): + """GELU approximate activation with a Linear projection. + + Weight key: `proj.weight`, `proj.bias` + In the block: `img_mlp.net.0.proj.weight` + """ + + def __init__( + self, + dim_in: int, + dim_out: int, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + self.proj = Linear( + in_dim=dim_in, + out_dim=dim_out, + dtype=dtype, + device=device, + has_bias=bias, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + return ops.gelu(self.proj(x)) + + +class _QwenImageDropout(Module): + """No-op dropout for inference. Occupies index 1 in FeedForward.net.""" + + def __init__(self): + super().__init__() + + def __call__(self, x: TensorValue) -> TensorValue: + return x + + +class QwenImageFeedForward(Module): + """FeedForward matching diffusers key naming. + + Produces keys: + net.0.proj.weight, net.0.proj.bias (GELU approximate projection) + net.2.weight, net.2.bias (output linear) + """ + + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: float = 4.0, + inner_dim: int | None = None, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + self.net: LayerList = LayerList( + [ + _QwenImageGELU( + dim, inner_dim, bias=bias, dtype=dtype, device=device + ), + _QwenImageDropout(), + Linear( + in_dim=inner_dim, + out_dim=dim_out, + dtype=dtype, + device=device, + has_bias=bias, + ), + ] + ) + + def __call__(self, x: TensorValue) -> TensorValue: + x = self.net[0](x) # GELU projection + # net[1] is dropout (no-op at inference) + x = self.net[2](x) # output linear + return x + + +# --------------------------------------------------------------------------- +# Attention (matches diffusers key naming: to_q, to_k, to_v, to_out.0, ...) +# --------------------------------------------------------------------------- + + +class QwenImageAttention(Module): + """Dual-stream attention for QwenImage transformer blocks. + + Key naming matches HuggingFace diffusers: + - to_q.weight/bias, to_k.weight/bias, to_v.weight/bias + - to_out.0.weight/bias (LayerList for correct .0. indexing) + - add_q_proj.weight/bias, add_k_proj.weight/bias, add_v_proj.weight/bias + - to_add_out.weight/bias + - norm_q.weight, norm_k.weight, norm_added_q.weight, norm_added_k.weight + """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + bias: bool = True, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int | None = None, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.scale = 1.0 / (self.head_dim**0.5) + out_dim = out_dim if out_dim is not None else query_dim + + self.to_q = Linear( + in_dim=query_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=bias, + ) + self.to_k = Linear( + in_dim=query_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=bias, + ) + self.to_v = Linear( + in_dim=query_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=bias, + ) + + self.norm_q = RMSNorm(dim_head, dtype=dtype, eps=eps) + self.norm_k = RMSNorm(dim_head, dtype=dtype, eps=eps) + + # Use LayerList so key becomes to_out.0.weight (not to_out_0.weight) + self.to_out: LayerList = LayerList( + [ + Linear( + in_dim=self.inner_dim, + out_dim=out_dim, + dtype=dtype, + device=device, + has_bias=out_bias, + ) + ] + ) + + self.norm_added_q: RMSNorm | None + self.norm_added_k: RMSNorm | None + self.add_q_proj: Linear | None + self.add_k_proj: Linear | None + self.add_v_proj: Linear | None + self.to_add_out: Linear | None + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, dtype=dtype, eps=eps) + self.norm_added_k = RMSNorm(dim_head, dtype=dtype, eps=eps) + self.add_q_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=added_proj_bias, + ) + self.add_k_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=added_proj_bias, + ) + self.add_v_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=added_proj_bias, + ) + self.to_add_out = Linear( + in_dim=self.inner_dim, + out_dim=query_dim, + dtype=dtype, + device=device, + has_bias=out_bias, + ) + else: + self.norm_added_q = None + self.norm_added_k = None + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + self.to_add_out = None + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue | None = None, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + ) -> TensorValue | tuple[TensorValue, TensorValue]: + batch_size = hidden_states.shape[0] + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + seq_len = query.shape[1] + + query = ops.reshape( + query, [batch_size, seq_len, self.heads, self.head_dim] + ) + key = ops.reshape(key, [batch_size, seq_len, self.heads, self.head_dim]) + value = ops.reshape( + value, [batch_size, seq_len, self.heads, self.head_dim] + ) + + query = self.norm_q(query) + key = self.norm_k(key) + + if ( + encoder_hidden_states is not None + and self.added_kv_proj_dim is not None + ): + if ( + self.add_q_proj is None + or self.add_k_proj is None + or self.add_v_proj is None + ): + raise ValueError("Encoder projections not initialized") + encoder_query = self.add_q_proj(encoder_hidden_states) + encoder_key = self.add_k_proj(encoder_hidden_states) + encoder_value = self.add_v_proj(encoder_hidden_states) + encoder_seq_len = encoder_query.shape[1] + encoder_query = ops.reshape( + encoder_query, + [batch_size, encoder_seq_len, self.heads, self.head_dim], + ) + encoder_key = ops.reshape( + encoder_key, + [batch_size, encoder_seq_len, self.heads, self.head_dim], + ) + encoder_value = ops.reshape( + encoder_value, + [batch_size, encoder_seq_len, self.heads, self.head_dim], + ) + + if self.norm_added_q is None or self.norm_added_k is None: + raise ValueError("Encoder normalizations not initialized") + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = ops.concat([encoder_query, query], axis=1) + key = ops.concat([encoder_key, key], axis=1) + value = ops.concat([encoder_value, value], axis=1) + + original_dtype = query.dtype + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + if query.dtype != original_dtype: + query = ops.cast(query, original_dtype) + if key.dtype != original_dtype: + key = ops.cast(key, original_dtype) + + hidden_states = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=self.scale, + ) + + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + hidden_states = ops.reshape( + hidden_states, [batch_size, seq_len, self.inner_dim] + ) + if hidden_states.dtype != query.dtype: + hidden_states = ops.cast(hidden_states, query.dtype) + + if encoder_hidden_states is not None: + encoder_seq_len = encoder_hidden_states.shape[1] + encoder_out = hidden_states[:, :encoder_seq_len, :] + hidden_out = hidden_states[:, encoder_seq_len:, :] + + hidden_out = self.to_out[0](hidden_out) + if self.to_add_out is None: + raise ValueError("Encoder output projection not initialized") + encoder_out = self.to_add_out(encoder_out) + + return hidden_out, encoder_out + else: + hidden_states = self.to_out[0](hidden_states) + return hidden_states + + +# --------------------------------------------------------------------------- +# Per-block Modulation (matches diffusers: img_mod.1.weight, txt_mod.1.weight) +# --------------------------------------------------------------------------- + + +class _SiLUPlaceholder(Module): + """Placeholder at index 0 in LayerList; SiLU has no learnable params.""" + + def __init__(self): + super().__init__() + + def __call__(self, x: TensorValue) -> TensorValue: + return ops.silu(x) + + +def _make_block_modulation( + dim: int, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, +) -> LayerList: + """Create per-block modulation as LayerList[SiLU_placeholder, Linear]. + + Produces weight keys: `{attr_name}.1.weight` and `{attr_name}.1.bias` + matching the diffusers convention img_mod.1.weight / txt_mod.1.weight. + """ + return LayerList( + [ + _SiLUPlaceholder(), + Linear( + in_dim=dim, + out_dim=dim * 6, + dtype=dtype, + device=device, + has_bias=bias, + ), + ] + ) + + +# --------------------------------------------------------------------------- +# Transformer Block (per-block img_mod, txt_mod, img_mlp, txt_mlp) +# --------------------------------------------------------------------------- + + +class QwenImageTransformerBlock(Module): + """Dual-stream transformer block with per-block modulation. + + Weight key structure per block: + img_mod.1.{weight,bias} + txt_mod.1.{weight,bias} + attn.to_q.{weight,bias}, attn.to_k.{weight,bias}, ... + img_mlp.net.0.proj.{weight,bias}, img_mlp.net.2.{weight,bias} + txt_mlp.net.0.proj.{weight,bias}, txt_mlp.net.2.{weight,bias} + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + eps: float = 1e-6, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + # Per-block modulation (img_mod, txt_mod) + self.img_mod: LayerList = _make_block_modulation( + dim, bias=bias, dtype=dtype, device=device + ) + self.txt_mod: LayerList = _make_block_modulation( + dim, bias=bias, dtype=dtype, device=device + ) + + # Norms (no affine → no weights in state_dict) + self.img_norm1 = LayerNormNoAffine(eps=eps) + self.img_norm2 = LayerNormNoAffine(eps=eps) + self.txt_norm1 = LayerNormNoAffine(eps=eps) + self.txt_norm2 = LayerNormNoAffine(eps=eps) + + # Dual-stream attention + self.attn = QwenImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + dtype=dtype, + device=device, + ) + + # Feedforward (img_mlp, txt_mlp) + self.img_mlp = QwenImageFeedForward( + dim=dim, + dim_out=dim, + mult=mlp_ratio, + bias=bias, + dtype=dtype, + device=device, + ) + self.txt_mlp = QwenImageFeedForward( + dim=dim, + dim_out=dim, + mult=mlp_ratio, + bias=bias, + dtype=dtype, + device=device, + ) + + def _apply_modulation( + self, + x: TensorValue, + shift: TensorValue, + scale: TensorValue, + ) -> TensorValue: + """Apply shift/scale modulation: (1 + scale) * x + shift.""" + return (1 + scale) * x + shift + + def _apply_split_modulation( + self, + x: TensorValue, + mod_real: TensorValue, + mod_zero: TensorValue, + num_noise: int, + mod_idx: int, + ) -> TensorValue: + """Apply different modulation to noise vs condition tokens. + + Splits x along seq dim, applies mod_real to noise tokens and + mod_zero to condition tokens, then concatenates back. + Avoids broadcasting [B,1,D] to [B,seq,D]. + """ + # mod has 6 chunks: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp + # We need shift and scale at mod_idx and mod_idx+1 + real_chunks = ops.chunk(mod_real, 6, axis=-1) + zero_chunks = ops.chunk(mod_zero, 6, axis=-1) + shift_r, scale_r = real_chunks[mod_idx], real_chunks[mod_idx + 1] + shift_z, scale_z = zero_chunks[mod_idx], zero_chunks[mod_idx + 1] + + x_noise = x[:, :num_noise, :] + x_cond = x[:, num_noise:, :] + + x_noise = (1 + scale_r) * x_noise + shift_r + x_cond = (1 + scale_z) * x_cond + shift_z + + return ops.concat([x_noise, x_cond], axis=1) + + def _apply_split_gate( + self, + x: TensorValue, + gate_real: TensorValue, + gate_zero: TensorValue, + num_noise: int, + ) -> TensorValue: + """Apply different gate to noise vs condition tokens.""" + x_noise = x[:, :num_noise, :] * gate_real + x_cond = x[:, num_noise:, :] * gate_zero + return ops.concat([x_noise, x_cond], axis=1) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + temb: TensorValue, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + temb_zero: TensorValue | None = None, + num_noise_tokens: int | None = None, + ) -> tuple[TensorValue, TensorValue]: + # Compute per-block modulation params from temb + # Compute silu once and reuse for both modulation projections. + temb_activated = ops.silu(temb) + img_mod = self.img_mod[1](temb_activated) + txt_mod = self.txt_mod[1](temb_activated) + + if len(img_mod.shape) == 2: + img_mod = ops.unsqueeze(img_mod, 1) + txt_mod = ops.unsqueeze(txt_mod, 1) + + # zero_cond_t path: separate modulation for condition tokens + img_mod_zero: TensorValue | None = None + if temb_zero is not None: + temb_zero_activated = ops.silu(temb_zero) + img_mod_zero = self.img_mod[1](temb_zero_activated) + if len(img_mod_zero.shape) == 2: + img_mod_zero = ops.unsqueeze(img_mod_zero, 1) + + img_mod_chunks = ops.chunk(img_mod, 6, axis=-1) + shift_msa, scale_msa, gate_msa = ( + img_mod_chunks[0], + img_mod_chunks[1], + img_mod_chunks[2], + ) + shift_mlp, scale_mlp, gate_mlp = ( + img_mod_chunks[3], + img_mod_chunks[4], + img_mod_chunks[5], + ) + + txt_mod_chunks = ops.chunk(txt_mod, 6, axis=-1) + c_shift_msa, c_scale_msa, c_gate_msa = ( + txt_mod_chunks[0], + txt_mod_chunks[1], + txt_mod_chunks[2], + ) + c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + txt_mod_chunks[3], + txt_mod_chunks[4], + txt_mod_chunks[5], + ) + + # Image stream - Attention + norm_hidden_states = self.img_norm1(hidden_states) + if img_mod_zero is not None and num_noise_tokens is not None: + norm_hidden_states = self._apply_split_modulation( + norm_hidden_states, img_mod, img_mod_zero, num_noise_tokens, 0 + ) + else: + norm_hidden_states = ( + 1 + scale_msa + ) * norm_hidden_states + shift_msa + + # Text stream - Attention + norm_encoder_hidden_states = self.txt_norm1(encoder_hidden_states) + norm_encoder_hidden_states = ( + 1 + c_scale_msa + ) * norm_encoder_hidden_states + c_shift_msa + + # Dual-stream attention + attn_output, context_attn_output = self.attn( + norm_hidden_states, + norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # Image stream - Apply gate and residual + if img_mod_zero is not None and num_noise_tokens is not None: + img_mod_zero_chunks = ops.chunk(img_mod_zero, 6, axis=-1) + attn_output = self._apply_split_gate( + attn_output, gate_msa, img_mod_zero_chunks[2], num_noise_tokens + ) + else: + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + # Image stream - Feedforward + norm_hidden_states = self.img_norm2(hidden_states) + if img_mod_zero is not None and num_noise_tokens is not None: + norm_hidden_states = self._apply_split_modulation( + norm_hidden_states, img_mod, img_mod_zero, num_noise_tokens, 3 + ) + else: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp) + shift_mlp + ) + + ff_output = self.img_mlp(norm_hidden_states) + if img_mod_zero is not None and num_noise_tokens is not None: + ff_output = self._apply_split_gate( + ff_output, gate_mlp, img_mod_zero_chunks[5], num_noise_tokens + ) + else: + ff_output = gate_mlp * ff_output + hidden_states = hidden_states + ff_output + + # Text stream - Apply gate and residual + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + # Text stream - Feedforward + norm_encoder_hidden_states = self.txt_norm2(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + ) + + context_ff_output = self.txt_mlp(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp * context_ff_output + ) + + if encoder_hidden_states.dtype == DType.float16: + encoder_hidden_states = ops.max(encoder_hidden_states, -65504.0) + encoder_hidden_states = ops.min(encoder_hidden_states, 65504.0) + + return encoder_hidden_states, hidden_states diff --git a/max/python/max/pipelines/architectures/qwen_image/model.py b/max/python/max/pipelines/architectures/qwen_image/model.py new file mode 100644 index 00000000000..8501ba161ee --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/model.py @@ -0,0 +1,88 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from collections.abc import Callable +from typing import Any + +from max.driver import Buffer, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.component_model import ComponentModel + +from .model_config import QwenImageConfig +from .qwen_image import QwenImageTransformer2DModel + + +class QwenImageTransformerModel(ComponentModel): + def __init__( + self, + config: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + session: InferenceSession, + ) -> None: + super().__init__( + config, + encoding, + devices, + weights, + ) + self.session = session + self.config = QwenImageConfig.generate( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Callable[..., Any]: + state_dict = {key: value.data() for key, value in self.weights.items()} + + nn_model = QwenImageTransformer2DModel(self.config) + nn_model.load_state_dict(state_dict, weight_alignment=1, strict=True) + self.state_dict = nn_model.state_dict() + + with Graph( + "qwen_image_transformer", + input_types=nn_model.input_types(), + ) as graph: + outputs = nn_model(*(value.tensor for value in graph.inputs)) + if isinstance(outputs, tuple): + graph.output(*outputs) + else: + graph.output(outputs) + + self.model: Model = self.session.load( + graph, + weights_registry=self.state_dict, + ) + return self.model.execute + + def __call__( + self, + hidden_states: Buffer, + encoder_hidden_states: Buffer, + timestep: Buffer, + img_ids: Buffer, + txt_ids: Buffer, + ) -> Any: + return self.model.execute( + hidden_states, + encoder_hidden_states, + timestep, + img_ids, + txt_ids, + ) diff --git a/max/python/max/pipelines/architectures/qwen_image/model_config.py b/max/python/max/pipelines/architectures/qwen_image/model_config.py new file mode 100644 index 00000000000..fbf71097102 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/model_config.py @@ -0,0 +1,59 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import Any + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from max.pipelines.lib.config.config_enums import supported_encoding_dtype +from pydantic import Field + + +class QwenImageConfigBase(MAXModelConfigBase): + patch_size: int = 2 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 60 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 3584 + guidance_embeds: bool = False + axes_dims_rope: tuple[int, ...] = (16, 56, 56) + rope_theta: int = 10000 + zero_cond_t: bool = False + eps: float = 1e-6 + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + +class QwenImageConfig(QwenImageConfigBase): + @staticmethod + def generate( + config_dict: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + ) -> QwenImageConfigBase: + init_dict = { + key: value + for key, value in config_dict.items() + if key in QwenImageConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": supported_encoding_dtype(encoding), + "device": DeviceRef.from_device(devices[0]), + } + ) + return QwenImageConfigBase(**init_dict) diff --git a/max/python/max/pipelines/architectures/qwen_image/pipeline_qwen_image.py b/max/python/max/pipelines/architectures/qwen_image/pipeline_qwen_image.py new file mode 100644 index 00000000000..f335da8c871 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/pipeline_qwen_image.py @@ -0,0 +1,828 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage diffusion pipeline. + +Key differences from Flux2Pipeline: +- True CFG with two forward passes (positive + negative prompts) +- No guidance embedding (timestep only, not timestep+guidance) +- Latent normalization via latents_mean/latents_std instead of BatchNorm +- Text encoder returns last hidden state (not multiple layers) +- 3D position IDs (T, H, W) instead of 4D (T, H, W, L) +""" + +from dataclasses import dataclass, field +from typing import Any, Literal + +import numpy as np +import numpy.typing as npt +from max.driver import CPU, Buffer, Device +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.graph.ops import shape_to_tensor +from max.interfaces import TokenBuffer +from max.pipelines.core import PixelContext +from max.pipelines.lib.bfloat16_utils import float32_to_bfloat16_as_uint16 +from max.pipelines.lib.interfaces import DiffusionPipeline, PixelModelInputs +from max.pipelines.lib.interfaces.diffusion_pipeline import max_compile +from max.profiler import Tracer, traced + +from ..autoencoders.autoencoder_kl_qwen_image import AutoencoderKLQwenImageModel +from ..qwen2_5vl.encoder import Qwen25VLEncoderModel +from .model import QwenImageTransformerModel + + +@dataclass(kw_only=True) +class QwenImageModelInputs(PixelModelInputs): + """QwenImage-specific PixelModelInputs. + + QwenImage is not guidance-distilled — use ``--true-cfg-scale`` + (not ``--guidance-scale``) to control classifier-free guidance. + """ + + width: int = 1024 + height: int = 1024 + true_cfg_scale: float = 4.0 + num_inference_steps: int = 50 + num_images_per_prompt: int = 1 + + +@dataclass +class QwenImagePipelineOutput: + """Container for QwenImage pipeline results.""" + + images: np.ndarray | list + + +@dataclass +class QwenImageCache: + """Runtime cache for reusable Qwen image buffers.""" + + sigmas: dict[str, Buffer] = field(default_factory=dict) + text_ids: dict[str, Buffer] = field(default_factory=dict) + shape_carriers: dict[int, Buffer] = field(default_factory=dict) + cfg_scales: dict[float, Buffer] = field(default_factory=dict) + latent_image_ids: dict[tuple[int, int, int], Buffer] = field( + default_factory=dict + ) + prompt_tokens: dict[tuple[int, ...], Buffer] = field(default_factory=dict) + + +class QwenImagePipeline(DiffusionPipeline): + """Diffusion pipeline for QwenImage text-to-image generation. + + Wires together: + - Qwen2.5-VL text encoder + - QwenImage transformer denoiser (60 dual-stream blocks) + - QwenImage 3D VAE (with latents_mean/std normalization) + """ + + vae: AutoencoderKLQwenImageModel + text_encoder: Qwen25VLEncoderModel + transformer: QwenImageTransformerModel + + components = { + "vae": AutoencoderKLQwenImageModel, + "text_encoder": Qwen25VLEncoderModel, + "transformer": QwenImageTransformerModel, + } + + def init_remaining_components(self) -> None: + """Initialize derived attributes that depend on loaded components.""" + # QwenImage VAE uses dim_mult [1,2,4,4] with 3 downsample stages + # Spatial scale factor = 2^3 = 8 + self.vae_scale_factor = 8 + + self._compile_runtime_helpers() + self._compile_cfg_fastpath_helpers() + self.cache: QwenImageCache = QwenImageCache() + + def prepare_inputs(self, context: PixelContext) -> QwenImageModelInputs: # type: ignore[override] + """Convert a PixelContext into QwenImageModelInputs.""" + return QwenImageModelInputs.from_context(context) + + def _compile_runtime_helpers(self) -> None: + """Compile the core runtime helper graphs used by QwenImage.""" + device = self.transformer.devices[0] + self.cached_patchify_and_pack = max_compile( + self._patchify_and_pack, + input_types=[ + TensorType( + DType.float32, + shape=["batch", "channels", "height", 2, "width", 2], + device=device, + ), + ], + ) + + self.cached_prepare_scheduler = max_compile( + self.prepare_scheduler, + input_types=[ + TensorType( + DType.float32, + shape=["num_sigmas"], + device=device, + ), + ], + ) + + dtype = self.transformer.config.dtype + packed_channels = self.transformer.config.in_channels + self.cached_scheduler_step = max_compile( + self.scheduler_step, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, + shape=["batch", "pred_seq", "channels"], + device=device, + ), + TensorType(DType.float32, shape=[1], device=device), + ], + ) + + z_dim = 16 # VAE latent channels + self.cached_postprocess_latents = max_compile( + self._postprocess_latents, + input_types=[ + TensorType( + dtype, + shape=["batch", "height", "width", packed_channels], + device=device, + ), + TensorType(dtype, shape=[z_dim], device=device), + TensorType(dtype, shape=[z_dim], device=device), + ], + ) + + self.cached_cfg_blend = max_compile( + self._cfg_blend, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", "channels"], + device=device, + ), + TensorType( + dtype, + shape=["batch", "seq", "channels"], + device=device, + ), + TensorType(DType.float32, shape=[1], device=device), + ], + ) + + self.cached_reshape_latents = max_compile( + self._reshape_latents, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", packed_channels], + device=device, + ), + TensorType(DType.float32, shape=["packed_h"], device=CPU()), + TensorType(DType.float32, shape=["packed_w"], device=CPU()), + ], + ) + + hidden_size = self.text_encoder.config.hidden_size + text_dtype = self.text_encoder.config.dtype + text_device = self.text_encoder.devices[0] + self.cached_trim_prompt_embeddings = max_compile( + self._trim_prompt_embeddings, + input_types=[ + TensorType( + text_dtype, + shape=["seq", hidden_size], + device=text_device, + ) + ], + ) + + def duplicate_batch(value: TensorValue) -> TensorValue: + return ops.concat([value, value], axis=0) + + self.cached_duplicate_prompt_embeddings = max_compile( + duplicate_batch, + input_types=[ + TensorType( + text_dtype, + shape=[1, "trimmed_seq_len", hidden_size], + device=text_device, + ) + ], + ) + + def _compile_cfg_fastpath_helpers(self) -> None: + """Compile the small helper graphs used by the CFG fast path.""" + + def duplicate_batch(value: TensorValue) -> TensorValue: + return ops.concat([value, value], axis=0) + + def concat_batch_pair( + first_value: TensorValue, + second_value: TensorValue, + ) -> TensorValue: + return ops.concat([first_value, second_value], axis=0) + + def split_cfg_predictions( + batched_predictions: TensorValue, + ) -> tuple[TensorValue, TensorValue]: + positive_prediction = ops.slice_tensor( + batched_predictions, + [slice(0, 1), slice(None), slice(None)], + ) + negative_prediction = ops.slice_tensor( + batched_predictions, + [slice(1, 2), slice(None), slice(None)], + ) + return positive_prediction, negative_prediction + + text_dtype = self.text_encoder.config.dtype + text_device = self.text_encoder.devices[0] + hidden_size = self.text_encoder.config.hidden_size + self.cached_concat_cfg_prompt_embeddings = max_compile( + concat_batch_pair, + input_types=[ + TensorType( + text_dtype, + shape=[1, "trimmed_seq_len", hidden_size], + device=text_device, + ), + TensorType( + text_dtype, + shape=[1, "trimmed_seq_len", hidden_size], + device=text_device, + ), + ], + ) + + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + packed_channels = self.transformer.config.in_channels + self.cached_duplicate_cfg_latents = max_compile( + duplicate_batch, + input_types=[ + TensorType( + dtype, + shape=[1, "seq", packed_channels], + device=device, + ) + ], + ) + + self.cached_concat_cfg_ids = max_compile( + concat_batch_pair, + input_types=[ + TensorType( + DType.int64, + shape=[1, "seq", 3], + device=device, + ), + TensorType( + DType.int64, + shape=[1, "seq", 3], + device=device, + ), + ], + ) + + self.cached_duplicate_cfg_timesteps = max_compile( + duplicate_batch, + input_types=[TensorType(dtype, shape=[1], device=device)], + ) + + self.cached_duplicate_cfg_ids = max_compile( + duplicate_batch, + input_types=[ + TensorType( + DType.int64, + shape=[1, "seq", 3], + device=device, + ) + ], + ) + + self.cached_split_cfg_predictions = max_compile( + split_cfg_predictions, + input_types=[ + TensorType( + dtype, + shape=[2, "seq", packed_channels], + device=device, + ) + ], + ) + + def _cfg_blend( + self, + cond_pred: TensorValue, + uncond_pred: TensorValue, + cfg_scale: TensorValue, + ) -> TensorValue: + scale = ops.cast(cfg_scale, cond_pred.dtype) + return uncond_pred + scale * (cond_pred - uncond_pred) + + # Number of chat template prefix tokens to drop from encoder output. + # Matches diffusers' prompt_template_encode_start_idx = 34. + PROMPT_TEMPLATE_DROP_IDX = 34 + + def _trim_prompt_embeddings( + self, hidden_states: TensorValue + ) -> TensorValue: + trimmed = ops.slice_tensor( + hidden_states, + [slice(self.PROMPT_TEMPLATE_DROP_IDX, None), slice(None)], + ) + return ops.unsqueeze(trimmed, 0) + + def prepare_prompt_embeddings( + self, + tokens: TokenBuffer, + num_images_per_prompt: int = 1, + ) -> Buffer: + """Create prompt embeddings from tokens. + + QwenImage uses the last hidden state from the text encoder (layer -1). + The tokens include a chat template prefix (~34 tokens) that must be + dropped from the encoder output to match diffusers' behavior. + """ + device = self.text_encoder.devices[0] + text_input_ids_np = np.asarray(tokens.array).flatten() + token_key = tuple(int(token) for token in text_input_ids_np.tolist()) + if token_key not in self.cache.prompt_tokens: + self.cache.prompt_tokens[token_key] = Buffer.from_dlpack( + np.ascontiguousarray(text_input_ids_np) + ).to(device) + token_buf = self.cache.prompt_tokens[token_key] + + hidden_states_all = self.text_encoder(token_buf) + hs_buf = hidden_states_all[-1] + + trimmed = self.cached_trim_prompt_embeddings(hs_buf) + if num_images_per_prompt == 1: + return trimmed + if num_images_per_prompt == 2: + return self.cached_duplicate_prompt_embeddings(trimmed) + + hs_cpu = hs_buf.to(CPU()) + if self.text_encoder.config.dtype == DType.bfloat16: + hs_u16 = np.from_dlpack( + hs_cpu.view(dtype=DType.uint16, shape=hs_cpu.shape) + ) + hs_np = (hs_u16.astype(np.uint32) << 16).view(np.float32) + else: + hs_np = np.from_dlpack(hs_cpu).astype(np.float32) + + hs_np = hs_np[self.PROMPT_TEMPLATE_DROP_IDX :] + hs_np = hs_np[np.newaxis, :, :] + + if num_images_per_prompt != 1: + hs_np = np.broadcast_to( + hs_np, + (num_images_per_prompt, hs_np.shape[1], hs_np.shape[2]), + ).copy() + + if self.text_encoder.config.dtype == DType.bfloat16: + result_u16 = float32_to_bfloat16_as_uint16( + np.ascontiguousarray(hs_np) + ) + buf = Buffer.from_numpy(result_u16).to(device) + return buf.view(dtype=DType.bfloat16, shape=hs_np.shape) + + return Buffer.from_numpy(np.ascontiguousarray(hs_np)).to(device) + + @staticmethod + def _prepare_text_ids( + batch_size: int, + seq_len: int, + device: Device, + max_vid_index: int = 0, + ) -> Buffer: + """Create 3D text position IDs in (T, H, W) format. + + QwenImage text tokens use positions [max_vid_index, max_vid_index+1, ...] + for all 3 axes (matching diffusers scale_rope=True convention). + """ + tok_positions = np.arange(seq_len, dtype=np.int64) + max_vid_index + coords = np.stack( + [tok_positions, tok_positions, tok_positions], axis=-1 + ) + text_ids = np.broadcast_to( + coords[np.newaxis, :, :], + (batch_size, coords.shape[0], coords.shape[1]), + ).copy() + return Buffer.from_dlpack(text_ids).to(device) + + def _reshape_latents( + self, + latents_bsc: TensorValue, + h_carrier: TensorValue, + w_carrier: TensorValue, + ) -> TensorValue: + batch = latents_bsc.shape[0] + h = h_carrier.shape[0] + w = w_carrier.shape[0] + channels = latents_bsc.shape[2] + latents_bsc = ops.rebind(latents_bsc, [batch, h * w, channels]) + return ops.reshape(latents_bsc, (batch, h, w, channels)) + + def _get_shape_carriers( + self, h_latent: int, w_latent: int + ) -> tuple[Buffer, Buffer]: + for n in (h_latent, w_latent): + if n not in self.cache.shape_carriers: + self.cache.shape_carriers[n] = Buffer.from_dlpack( + np.empty(n, dtype=np.float32) + ) + return ( + self.cache.shape_carriers[h_latent], + self.cache.shape_carriers[w_latent], + ) + + def decode_latents( + self, + latents: Buffer, + height: int, + width: int, + output_type: Literal["np", "latent"] = "np", + ) -> np.ndarray | Buffer: + """Decode packed latents into an image array.""" + if output_type == "latent": + return latents + + h_latent = height // (self.vae_scale_factor * 2) + w_latent = width // (self.vae_scale_factor * 2) + + latents_mean = self.vae.latents_mean_tensor + latents_std = self.vae.latents_std_tensor + if latents_mean is None or latents_std is None: + raise ValueError("VAE latents_mean/latents_std not loaded.") + + h_carrier, w_carrier = self._get_shape_carriers(h_latent, w_latent) + latents_bhwc = self.cached_reshape_latents( + latents, h_carrier, w_carrier + ) + + latents_decoded = self.cached_postprocess_latents( + latents_bhwc, latents_mean, latents_std + ) + + decoded = self.vae.decode(latents_decoded) + return self._image_to_flat_hwc(self._to_numpy(decoded)) + + def _postprocess_latents( + self, + latents_bhwc: TensorValue, + latents_mean: TensorValue, + latents_std: TensorValue, + ) -> TensorValue: + """Unpatchify and denormalize latents for VAE decoding.""" + batch = latents_bhwc.shape[0] + h = latents_bhwc.shape[1] + w = latents_bhwc.shape[2] + c = latents_bhwc.shape[3] + z_dim = c // 4 # 16 + + # Permute (B, H, W, C) -> (B, C, H, W) + latents = ops.permute(latents_bhwc, (0, 3, 1, 2)) + + # Unpatchify first: (B, C, H, W) -> (B, z_dim, H*2, W*2) + latents = ops.reshape(latents, (batch, z_dim, 2, 2, h, w)) + latents = ops.permute(latents, (0, 1, 4, 2, 5, 3)) + latents = ops.reshape(latents, (batch, z_dim, h * 2, w * 2)) + + # Then denormalize using latents_mean/std (shape [z_dim]) + mean_r = ops.reshape(latents_mean, (1, z_dim, 1, 1)) + std_r = ops.reshape(latents_std, (1, z_dim, 1, 1)) + latents = latents * std_r + mean_r + + return latents + + def _to_numpy(self, image: Any) -> np.ndarray: + cpu_image = image.to(CPU()) + try: + return np.from_dlpack(cpu_image).astype(np.float32) + except (RuntimeError, TypeError): + # bfloat16 not supported by numpy, cast via v1 Tensor + from max.experimental.tensor import Tensor as _Tensor + + if isinstance(cpu_image, _Tensor): + return np.from_dlpack(cpu_image.cast(DType.float32)).astype( + np.float32 + ) + # Buffer bfloat16: wrap in v1 Tensor to cast + t = _Tensor(storage=cpu_image) + return np.from_dlpack(t.cast(DType.float32)).astype(np.float32) + + @staticmethod + def _image_to_flat_hwc(image: np.ndarray) -> np.ndarray: + img = np.asarray(image) + while img.ndim > 3: + img = img.squeeze(0) + if img.ndim == 3 and img.shape[0] == 3: + img = np.transpose(img, (1, 2, 0)) + return img.astype(np.float32, copy=False) + + def preprocess_latents( + self, + latents: npt.NDArray[np.float32], + latent_image_ids: npt.NDArray[np.float32], + ) -> tuple[Buffer, Buffer]: + latents_np = np.asarray(latents) + batch = latents_np.shape[0] + c = latents_np.shape[1] + h = latents_np.shape[2] + w = latents_np.shape[3] + latents_6d = latents_np.reshape(batch, c, h // 2, 2, w // 2, 2) + latents_6d_buf = Buffer.from_dlpack( + np.ascontiguousarray(latents_6d) + ).to(self.transformer.devices[0]) + latents_packed = self.cached_patchify_and_pack(latents_6d_buf) + + latent_image_ids_int64 = np.asarray(latent_image_ids, dtype=np.int64) + ids_key = ( + int(latent_image_ids_int64.shape[0]), + int(latent_image_ids_int64.shape[1]), + int(latent_image_ids_int64.shape[2]), + ) + if ids_key not in self.cache.latent_image_ids: + self.cache.latent_image_ids[ids_key] = Buffer.from_dlpack( + latent_image_ids_int64 + ).to(self.transformer.devices[0]) + latent_image_ids_buf = self.cache.latent_image_ids[ids_key] + return latents_packed, latent_image_ids_buf + + def _patchify_and_pack(self, latents: TensorValue) -> TensorValue: + """Patchify (B,C,H,W)->(B,C*4,H//2,W//2) then pack to (B,H//2*W//2,C*4).""" + latents = ops.cast(latents, self.transformer.config.dtype) + batch = latents.shape[0] + c = latents.shape[1] + h2 = latents.shape[2] + w2 = latents.shape[4] + + latents = ops.permute(latents, (0, 1, 3, 5, 2, 4)) + latents = ops.reshape(latents, (batch, c * 4, h2, w2)) + + c4 = c * 4 + latents = ops.reshape(latents, (batch, c4, h2 * w2)) + latents = ops.permute(latents, (0, 2, 1)) + + return latents + + def scheduler_step( + self, + latents: TensorValue, + noise_pred: TensorValue, + dt: TensorValue, + ) -> TensorValue: + """Apply a single Euler update step.""" + num_noise_tokens = shape_to_tensor([latents.shape[1]]) + latents_sliced = ops.slice_tensor( + latents, + [ + slice(None), + (slice(0, num_noise_tokens), "num_tokens"), + slice(None), + ], + ) + noise_pred_sliced = ops.slice_tensor( + noise_pred, + [ + slice(None), + (slice(0, num_noise_tokens), "num_tokens"), + slice(None), + ], + ) + latents_dtype = latents_sliced.dtype + latents_sliced = ops.cast(latents_sliced, DType.float32) + latents_sliced = latents_sliced + dt * noise_pred_sliced + return ops.cast(latents_sliced, latents_dtype) + + def prepare_scheduler( + self, sigmas: TensorValue + ) -> tuple[TensorValue, TensorValue]: + """Precompute timesteps and dt values from sigmas.""" + sigmas_curr = ops.slice_tensor(sigmas, [slice(0, -1)]) + sigmas_next = ops.slice_tensor(sigmas, [slice(1, None)]) + all_dt = sigmas_next - sigmas_curr + all_timesteps = ops.cast(sigmas_curr, self.transformer.config.dtype) + return all_timesteps, all_dt + + @traced + def execute( # type: ignore[override] + self, + model_inputs: QwenImageModelInputs, + output_type: Literal["np", "latent"] = "np", + ) -> QwenImagePipelineOutput: + """Run the QwenImage denoising loop and decode outputs. + + Supports true classifier-free guidance with separate positive and + negative prompt forward passes. + """ + # Phase 1: prompt and latent preparation. + prompt_embeds = self.prepare_prompt_embeddings( + tokens=model_inputs.tokens, + num_images_per_prompt=model_inputs.num_images_per_prompt, + ) + batch_size = int(prompt_embeds.shape[0]) + device = self.transformer.devices[0] + + latents, latent_image_ids = self.preprocess_latents( + model_inputs.latents, model_inputs.latent_image_ids + ) + + h_latent = model_inputs.height // (self.vae_scale_factor * 2) + w_latent = model_inputs.width // (self.vae_scale_factor * 2) + max_vid_index = max(h_latent // 2, w_latent // 2) + text_ids_key = ( + f"{batch_size}_{int(prompt_embeds.shape[1])}_{max_vid_index}" + ) + if text_ids_key not in self.cache.text_ids: + self.cache.text_ids[text_ids_key] = self._prepare_text_ids( + batch_size=batch_size, + seq_len=int(prompt_embeds.shape[1]), + device=device, + max_vid_index=max_vid_index, + ) + text_ids = self.cache.text_ids[text_ids_key] + + # Phase 2: CFG setup. + do_true_cfg = ( + model_inputs.true_cfg_scale > 1.0 + and model_inputs.negative_tokens is not None + ) + negative_prompt_embeds: Buffer | None = None + negative_text_ids: Buffer | None = None + cfg_scale_buf: Buffer | None = None + batched_prompt_embeds: Buffer | None = None + batched_text_ids: Buffer | None = None + batched_latent_image_ids: Buffer | None = None + if do_true_cfg and model_inputs.negative_tokens is not None: + negative_prompt_embeds = self.prepare_prompt_embeddings( + tokens=model_inputs.negative_tokens, + num_images_per_prompt=model_inputs.num_images_per_prompt, + ) + negative_text_ids_key = f"{batch_size}_{int(negative_prompt_embeds.shape[1])}_{max_vid_index}" + if negative_text_ids_key not in self.cache.text_ids: + self.cache.text_ids[negative_text_ids_key] = ( + self._prepare_text_ids( + batch_size=batch_size, + seq_len=int(negative_prompt_embeds.shape[1]), + device=device, + max_vid_index=max_vid_index, + ) + ) + negative_text_ids = self.cache.text_ids[negative_text_ids_key] + + cfg_scale = float(model_inputs.true_cfg_scale) + if cfg_scale not in self.cache.cfg_scales: + self.cache.cfg_scales[cfg_scale] = Buffer.from_dlpack( + np.array([cfg_scale], dtype=np.float32) + ).to(device) + cfg_scale_buf = self.cache.cfg_scales[cfg_scale] + + if ( + batch_size == 1 + and prompt_embeds.shape[1] == negative_prompt_embeds.shape[1] + ): + batched_prompt_embeds = ( + self.cached_concat_cfg_prompt_embeddings( + prompt_embeds, + negative_prompt_embeds, + ) + ) + batched_text_ids = self.cached_concat_cfg_ids( + text_ids, + negative_text_ids, + ) + batched_latent_image_ids = self.cached_duplicate_cfg_ids( + latent_image_ids + ) + + # Phase 3: scheduler setup. + sigmas_key = ( + f"{model_inputs.num_inference_steps}_{int(latents.shape[1])}" + ) + if sigmas_key not in self.cache.sigmas: + self.cache.sigmas[sigmas_key] = Buffer.from_dlpack( + model_inputs.sigmas + ).to(device) + sigmas = self.cache.sigmas[sigmas_key] + + with Tracer("prepare_scheduler"): + all_timesteps, all_dts = self.cached_prepare_scheduler(sigmas) + timesteps_seq = all_timesteps.driver_tensor + dts_seq = all_dts.driver_tensor + + # Phase 4: denoising loop. + with Tracer("denoising_loop"): + for i in range(model_inputs.num_inference_steps): + with Tracer(f"denoising_step_{i}"): + timestep = timesteps_seq[i : i + 1] + dt = dts_seq[i : i + 1] + + if ( + batched_prompt_embeds is not None + and batched_text_ids is not None + and batched_latent_image_ids is not None + and cfg_scale_buf is not None + ): + with Tracer("transformer_cfg"): + batched_predictions = self.transformer( + self.cached_duplicate_cfg_latents(latents), + batched_prompt_embeds, + self.cached_duplicate_cfg_timesteps(timestep), + batched_latent_image_ids, + batched_text_ids, + )[0] + positive_prediction, negative_prediction = ( + self.cached_split_cfg_predictions( + batched_predictions + ) + ) + with Tracer("cfg_blend"): + noise_pred = self.cached_cfg_blend( + positive_prediction, + negative_prediction, + cfg_scale_buf, + ) + else: + with Tracer("transformer_pos"): + noise_pred = self.transformer( + latents, + prompt_embeds, + timestep, + latent_image_ids, + text_ids, + )[0] + + if ( + do_true_cfg + and negative_prompt_embeds is not None + and negative_text_ids is not None + and cfg_scale_buf is not None + ): + with Tracer("transformer_neg"): + negative_prediction = self.transformer( + latents, + negative_prompt_embeds, + timestep, + latent_image_ids, + negative_text_ids, + )[0] + with Tracer("cfg_blend"): + noise_pred = self.cached_cfg_blend( + noise_pred, + negative_prediction, + cfg_scale_buf, + ) + + with Tracer("scheduler_step"): + latents = self.cached_scheduler_step( + latents, + noise_pred, + dt, + ) + + # Phase 5: decode outputs. + image_list: list[np.ndarray | Buffer] = [] + if batch_size == 1: + image_list.append( + self.decode_latents( + latents, + model_inputs.height, + model_inputs.width, + output_type=output_type, + ) + ) + else: + latents_np = np.from_dlpack(latents.to(CPU())).astype(np.float32) + for batch_index in range(batch_size): + batch_latents = Buffer.from_dlpack( + np.ascontiguousarray( + latents_np[batch_index : batch_index + 1] + ) + ).to(device) + image_list.append( + self.decode_latents( + batch_latents, + model_inputs.height, + model_inputs.width, + output_type=output_type, + ) + ) + + return QwenImagePipelineOutput(images=image_list) diff --git a/max/python/max/pipelines/architectures/qwen_image/qwen_image.py b/max/python/max/pipelines/architectures/qwen_image/qwen_image.py new file mode 100644 index 00000000000..878b137fc08 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/qwen_image.py @@ -0,0 +1,252 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage Transformer 2D Model. + +A 20B parameter MMDiT model for text-to-image generation with 60 dual-stream +blocks, 3D RoPE, and timestep-only embeddings (no guidance embedding). + +Weight key naming matches HuggingFace diffusers: +- img_in.{weight,bias} (input projection for image latents) +- txt_in.{weight,bias} (input projection for text embeddings) +- time_text_embed.timestep_embedder.{linear_1,linear_2}.{weight,bias} +- txt_norm.weight (RMSNorm for text output) +- transformer_blocks.{i}.* (per-block: img_mod, txt_mod, attn, img_mlp, txt_mlp) +- norm_out.linear.{weight,bias} (AdaLayerNormContinuous) +- proj_out.{weight,bias} (output projection) +""" + +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + +from .layers.embeddings import ( + QwenImagePosEmbed, + QwenImageTimestepProjEmbeddings, +) +from .layers.normalizations import AdaLayerNormContinuous +from .layers.qwen_image_attention import QwenImageTransformerBlock +from .model_config import QwenImageConfigBase + + +class QwenImageTransformer2DModel(Module): + """QwenImage Transformer with 60 dual-stream blocks. + + Key differences from Flux2: + - No guidance embedding (timestep only) + - No single-stream blocks (all 60 are dual-stream) + - 3D RoPE with axes [16, 56, 56] (T, H, W) + - Per-block modulation (img_mod, txt_mod per block) + - inner_dim = 24 * 128 = 3072 + """ + + def __init__( + self, + config: QwenImageConfigBase, + ): + super().__init__() + patch_size = config.patch_size + in_channels = config.in_channels + out_channels = config.out_channels + num_layers = config.num_layers + attention_head_dim = config.attention_head_dim + num_attention_heads = config.num_attention_heads + joint_attention_dim = config.joint_attention_dim + axes_dims_rope = config.axes_dims_rope + rope_theta = config.rope_theta + device = config.device + dtype = config.dtype + eps = config.eps + + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Positional embeddings (3D RoPE: T, H, W) + self.pos_embed = QwenImagePosEmbed( + theta=rope_theta, axes_dim=axes_dims_rope + ) + + # 2. Timestep embeddings (no guidance) — key: time_text_embed.* + self.time_text_embed = QwenImageTimestepProjEmbeddings( + in_channels=256, + embedding_dim=self.inner_dim, + bias=True, + dtype=dtype, + device=device, + ) + + # 3. Input embeddings — keys: img_in.*, txt_in.* + self.img_in = Linear( + in_dim=in_channels, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.txt_in = Linear( + in_dim=joint_attention_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + + # 4. Text input norm — key: txt_norm.weight + self.txt_norm = RMSNorm(joint_attention_dim, dtype=dtype, eps=eps) + + # 5. Dual-stream transformer blocks (all 60 are dual-stream) + self.transformer_blocks: LayerList = LayerList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=4.0, + eps=eps, + bias=True, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers — keys: norm_out.linear.*, proj_out.* + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=self.inner_dim, + dtype=dtype, + device=device, + eps=eps, + bias=True, + ) + self.proj_out = Linear( + in_dim=self.inner_dim, + out_dim=patch_size * patch_size * self.out_channels, + dtype=dtype, + device=device, + has_bias=True, + ) + + # Store config for input_types + self.max_device = device + self.max_dtype = dtype + self.in_channels = in_channels + self.joint_attention_dim = joint_attention_dim + self.zero_cond_t = config.zero_cond_t + # Set before graph build for zero_cond_t split modulation + self.num_noise_tokens: int | None = None + + def input_types(self) -> tuple[TensorType, ...]: + hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "image_seq_len", self.in_channels], + device=self.max_device, + ) + encoder_hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "text_seq_len", self.joint_attention_dim], + device=self.max_device, + ) + timestep_type = TensorType( + self.max_dtype, shape=["batch_size"], device=self.max_device + ) + # 3D position IDs: (T, H, W) + img_ids_type = TensorType( + DType.int64, + shape=["batch_size", "image_seq_len", 3], + device=self.max_device, + ) + txt_ids_type = TensorType( + DType.int64, + shape=["batch_size", "text_seq_len", 3], + device=self.max_device, + ) + + result = ( + hidden_states_type, + encoder_hidden_states_type, + timestep_type, + img_ids_type, + txt_ids_type, + ) + + return result + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + timestep: TensorValue, + img_ids: TensorValue, + txt_ids: TensorValue, + ) -> tuple[TensorValue]: + """Forward pass through QwenImage Transformer. + + Args: + hidden_states: Image latents [B, img_seq, in_channels]. + encoder_hidden_states: Text embeddings [B, txt_len, joint_attention_dim]. + timestep: Denoising timestep [B] (scaled to [0, 1] range). + img_ids: Image position IDs [B, image_seq_len, 3] (T, H, W). + txt_ids: Text position IDs [B, text_seq_len, 3]. + num_noise_tokens: [1] scalar — number of noise tokens in the + image sequence. Condition tokens (positions >= this value) + receive timestep=0 modulation. Only for zero_cond_t=True. + + Returns: + Denoised output of shape [B, img_seq, patch_size^2 * out_channels]. + """ + # Handle batch dimension in ids + img_ids = img_ids[0] # [img_seq, 3] + txt_ids = txt_ids[0] # [txt_len, 3] + + # 1. Calculate timestep embedding + timestep_scaled = ops.cast(timestep * 1000.0, hidden_states.dtype) + temb = self.time_text_embed(timestep_scaled) + + # For zero_cond_t: compute temb for timestep=0 (condition tokens) + temb_zero: TensorValue | None = None + num_noise: int | None = None + if self.zero_cond_t and self.num_noise_tokens is not None: + zero_t = timestep_scaled * 0.0 + temb_zero = self.time_text_embed(zero_t) + num_noise = self.num_noise_tokens + + # 2. Input projection (txt_norm applied before txt_in projection) + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + # 3. Calculate RoPE embeddings + ids = ops.concat([txt_ids, img_ids], axis=0) + image_rotary_emb = self.pos_embed(ids) + + # 4. Dual-stream transformer blocks (all 60) + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + temb_zero=temb_zero, + num_noise_tokens=num_noise, + ) + + # 5. Output projection (image tokens only, discard text) + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return (output,) diff --git a/max/python/max/pipelines/architectures/qwen_image/weight_adapters.py b/max/python/max/pipelines/architectures/qwen_image/weight_adapters.py new file mode 100644 index 00000000000..2bc38b9ad69 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/weight_adapters.py @@ -0,0 +1,31 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Weight key remapping for QwenImage transformer. + +QwenImage HuggingFace weight keys follow the same pattern as Flux2 since +both use diffusers naming conventions. The keys map directly. +""" + +# The QwenImage transformer weights from HuggingFace use the same naming +# convention as the MAX implementation, so no remapping is needed. +# Weight keys like: +# transformer_blocks.0.attn.to_q.weight +# img_in.weight +# txt_in.weight +# norm_out.linear.weight +# proj_out.weight +# map directly to our Module attribute names. +# +# The only adaptation needed is in the ComponentModel.load_model() method +# which strips component prefixes during weight loading. diff --git a/max/tests/integration/architectures/qwen_image/BUILD.bazel b/max/tests/integration/architectures/qwen_image/BUILD.bazel new file mode 100644 index 00000000000..2bf4e7272ac --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/BUILD.bazel @@ -0,0 +1,50 @@ +# QwenImage parity tests: MAX vs diffusers. + +load( + "//bazel:api.bzl", + "modular_py_test", + "requirement", +) + +package(default_visibility = [ + "//:__pkg__", + "//SDK/integration-test:__subpackages__", + "//max/tests:__subpackages__", + "//oss/modular/max/tests:__subpackages__", +]) + +modular_py_test( + name = "qwen_image", + size = "large", + srcs = glob(["**/test_*.py"]) + ["conftest.py"], + data = [ + "//max/tests/integration/architectures/qwen_image/testdata", + ], + env = { + "PIPELINES_TESTDATA": "max/tests/integration/architectures/qwen_image/testdata", + "MODULAR_TORCH_MEMORY_PERCENT": "0.6", + }, + exec_properties = { + "test.resources:gpu-memory": "4", + }, + gpu_constraints = ["//:has_gpu"] + select({ + "//:apple_gpu": ["@platforms//:incompatible"], + "//conditions:default": [], + }), + tags = [ + "gpu", + "no-sandbox", + ], + deps = [ + "//max/python/max:tensor", + "//max/python/max/driver", + "//max/python/max/dtype", + "//max/python/max/graph", + "//max/python/max/pipelines/architectures", + "//max/python/max/pipelines/lib", + requirement("diffusers"), + requirement("numpy"), + requirement("torch"), + requirement("transformers"), + ], +) diff --git a/max/tests/integration/architectures/qwen_image/conftest.py b/max/tests/integration/architectures/qwen_image/conftest.py new file mode 100644 index 00000000000..9a97be3ccf8 --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/conftest.py @@ -0,0 +1,191 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Fixtures for QwenImage parity tests: config, input tensors, dummy weights.""" + +import json +import os +from pathlib import Path +from typing import Any + +import pytest +import torch + + +@pytest.fixture +def qwen_config() -> dict[str, Any]: + """Load QwenImage configuration from testdata.""" + path = os.environ["PIPELINES_TESTDATA"] + config_path = Path(path) / "config.json" + with open(config_path) as file: + return json.load(file) + + +@pytest.fixture +def hidden_states(qwen_config: dict[str, Any]) -> torch.Tensor: + """Random image latent hidden states. + + Shape: (batch_size, img_seq_len, inner_dim) + inner_dim = num_attention_heads * attention_head_dim = 24 * 128 = 3072 + """ + torch.manual_seed(42) + inner_dim = ( + qwen_config["num_attention_heads"] * qwen_config["attention_head_dim"] + ) + # Use 256 image tokens (small 128x128 latent, 64x64 patches) + return torch.randn(1, 256, inner_dim).to(torch.bfloat16).to("cuda") + + +@pytest.fixture +def encoder_hidden_states(qwen_config: dict[str, Any]) -> torch.Tensor: + """Random text encoder hidden states. + + Shape: (batch_size, txt_seq_len, inner_dim) + Note: in the block, text is already projected to inner_dim (3072), + not joint_attention_dim (3584). + """ + torch.manual_seed(43) + inner_dim = ( + qwen_config["num_attention_heads"] * qwen_config["attention_head_dim"] + ) + return torch.randn(1, 64, inner_dim).to(torch.bfloat16).to("cuda") + + +@pytest.fixture +def temb(qwen_config: dict[str, Any]) -> torch.Tensor: + """Random timestep embedding. + + Shape: (batch_size, inner_dim) + """ + torch.manual_seed(44) + inner_dim = ( + qwen_config["num_attention_heads"] * qwen_config["attention_head_dim"] + ) + return torch.randn(1, inner_dim).to(torch.bfloat16).to("cuda") + + +def _compute_rope_freqs( + seq_len: int, head_dim: int, theta: int = 10000 +) -> torch.Tensor: + """Compute complex-valued RoPE frequencies (same math as diffusers). + + Returns: complex tensor of shape [seq_len, head_dim // 2] + """ + pos_index = torch.arange(seq_len, dtype=torch.float32) + half_dim = head_dim // 2 + freq_base = 1.0 / torch.pow( + theta, + torch.arange(0, half_dim * 2, 2, dtype=torch.float32) / (half_dim * 2), + ) + freqs = torch.outer(pos_index, freq_base) + return torch.polar(torch.ones_like(freqs), freqs) + + +@pytest.fixture +def image_rotary_emb_diffusers( + qwen_config: dict[str, Any], + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """RoPE in diffusers format: (img_freqs, txt_freqs) as complex tensors. + + img_freqs: [img_seq_len, head_dim // 2] complex + txt_freqs: [txt_seq_len, head_dim // 2] complex + """ + head_dim = qwen_config["attention_head_dim"] + img_seq_len = hidden_states.shape[1] + txt_seq_len = encoder_hidden_states.shape[1] + + img_freqs = _compute_rope_freqs(img_seq_len, head_dim).to("cuda") + txt_freqs = _compute_rope_freqs(txt_seq_len, head_dim).to("cuda") + return (img_freqs, txt_freqs) + + +@pytest.fixture +def image_rotary_emb_max( + image_rotary_emb_diffusers: tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + """RoPE in MAX format: (cos, sin) as real tensors. + + Concatenated as [txt, img] to match MAX's concat order. + cos, sin: [txt_seq_len + img_seq_len, head_dim] float32 + """ + img_freqs, txt_freqs = image_rotary_emb_diffusers + # MAX concatenates text first, then image + full_freqs = torch.cat([txt_freqs, img_freqs], dim=0) + # Convert complex -> real: repeat_interleave to expand D//2 -> D + cos = full_freqs.real.repeat_interleave(2, dim=-1).float().to("cuda") + sin = full_freqs.imag.repeat_interleave(2, dim=-1).float().to("cuda") + return (cos, sin) + + +@pytest.fixture +def block_weights(qwen_config: dict[str, Any]) -> dict[str, torch.Tensor]: + """Random weights for a single QwenImageTransformerBlock. + + Uses realistic weight statistics (std, mean) from actual model weights. + """ + inner_dim = ( + qwen_config["num_attention_heads"] * qwen_config["attention_head_dim"] + ) + head_dim = qwen_config["attention_head_dim"] + mlp_hidden_dim = int(inner_dim * 4.0) + + # Format: {weight_name: (shape, std, mean)} + WEIGHT_STATS: dict[str, tuple[tuple[int, ...], float, float]] = { + # Per-block modulation + "img_mod.1.weight": ((6 * inner_dim, inner_dim), 0.02, 0.0), + "img_mod.1.bias": ((6 * inner_dim,), 0.01, 0.0), + "txt_mod.1.weight": ((6 * inner_dim, inner_dim), 0.02, 0.0), + "txt_mod.1.bias": ((6 * inner_dim,), 0.01, 0.0), + # Attention - main stream + "attn.to_q.weight": ((inner_dim, inner_dim), 0.032, 0.0), + "attn.to_q.bias": ((inner_dim,), 0.053, 0.0), + "attn.to_k.weight": ((inner_dim, inner_dim), 0.031, 0.0), + "attn.to_k.bias": ((inner_dim,), 0.065, 0.0), + "attn.to_v.weight": ((inner_dim, inner_dim), 0.023, 0.0), + "attn.to_v.bias": ((inner_dim,), 0.004, 0.0), + "attn.to_out.0.weight": ((inner_dim, inner_dim), 0.030, 0.0), + "attn.to_out.0.bias": ((inner_dim,), 0.020, 0.0), + "attn.norm_q.weight": ((head_dim,), 0.30, 0.86), + "attn.norm_k.weight": ((head_dim,), 0.21, 0.80), + # Attention - encoder stream + "attn.add_q_proj.weight": ((inner_dim, inner_dim), 0.036, 0.0), + "attn.add_q_proj.bias": ((inner_dim,), 0.041, 0.0), + "attn.add_k_proj.weight": ((inner_dim, inner_dim), 0.036, 0.0), + "attn.add_k_proj.bias": ((inner_dim,), 0.061, 0.0), + "attn.add_v_proj.weight": ((inner_dim, inner_dim), 0.027, 0.0), + "attn.add_v_proj.bias": ((inner_dim,), 0.028, 0.0), + "attn.to_add_out.weight": ((inner_dim, inner_dim), 0.035, 0.0), + "attn.to_add_out.bias": ((inner_dim,), 0.020, 0.0), + "attn.norm_added_q.weight": ((head_dim,), 0.076, 0.69), + "attn.norm_added_k.weight": ((head_dim,), 0.17, 0.74), + # Image MLP + "img_mlp.net.0.proj.weight": ((mlp_hidden_dim, inner_dim), 0.02, 0.0), + "img_mlp.net.0.proj.bias": ((mlp_hidden_dim,), 0.01, 0.0), + "img_mlp.net.2.weight": ((inner_dim, mlp_hidden_dim), 0.02, 0.0), + "img_mlp.net.2.bias": ((inner_dim,), 0.01, 0.0), + # Text MLP + "txt_mlp.net.0.proj.weight": ((mlp_hidden_dim, inner_dim), 0.02, 0.0), + "txt_mlp.net.0.proj.bias": ((mlp_hidden_dim,), 0.01, 0.0), + "txt_mlp.net.2.weight": ((inner_dim, mlp_hidden_dim), 0.02, 0.0), + "txt_mlp.net.2.bias": ((inner_dim,), 0.01, 0.0), + } + + torch.manual_seed(100) + weights = {} + for key, (shape, std, mean) in WEIGHT_STATS.items(): + weights[key] = ( + torch.randn(shape, dtype=torch.bfloat16).to("cuda") * std + mean + ) + return weights diff --git a/max/tests/integration/architectures/qwen_image/test_attention.py b/max/tests/integration/architectures/qwen_image/test_attention.py new file mode 100644 index 00000000000..64210a1eee5 --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/test_attention.py @@ -0,0 +1,192 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Test that MAX QwenImageTransformerBlock matches diffusers output.""" + +from typing import Any + +import torch +from diffusers.models.transformers.transformer_qwenimage import ( + QwenImageTransformerBlock as DiffusersBlock, +) +from max.driver import Accelerator +from max.dtype import DType +from max.experimental import functional as F +from max.experimental.tensor import Tensor +from max.graph import TensorType +from max.pipelines.architectures.qwen_image.layers.qwen_image_attention import ( + QwenImageTransformerBlock as MaxBlock, +) +from torch.utils.dlpack import from_dlpack + + +class QwenImageBlockWrapper(MaxBlock): + """Wrapper to flatten tuple inputs for MAX compiler.""" + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + def forward( # type: ignore[override] + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + temb: Tensor, + rotary_cos: Tensor, + rotary_sin: Tensor, + ) -> tuple[Tensor, Tensor]: + return super().forward( + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb=(rotary_cos, rotary_sin), + ) + + +@torch.no_grad() +def generate_torch_outputs( + qwen_config: dict[str, Any], + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + block_weights: dict[str, torch.Tensor], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + """Run diffusers QwenImageTransformerBlock and return outputs.""" + inner_dim = ( + qwen_config["num_attention_heads"] * qwen_config["attention_head_dim"] + ) + + layer = ( + DiffusersBlock( + dim=inner_dim, + num_attention_heads=qwen_config["num_attention_heads"], + attention_head_dim=qwen_config["attention_head_dim"], + qk_norm="rms_norm", + eps=qwen_config["eps"], + ) + .to(torch.bfloat16) + .to("cuda") + ) + layer.load_state_dict(block_weights) + + txt_out, img_out = layer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + return txt_out, img_out + + +def generate_max_outputs( + qwen_config: dict[str, Any], + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + block_weights: dict[str, torch.Tensor], + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + """Run MAX QwenImageTransformerBlock and return outputs.""" + device_ref = Accelerator() + inner_dim = ( + qwen_config["num_attention_heads"] * qwen_config["attention_head_dim"] + ) + + with F.lazy(): + block = QwenImageBlockWrapper( + dim=inner_dim, + num_attention_heads=qwen_config["num_attention_heads"], + attention_head_dim=qwen_config["attention_head_dim"], + mlp_ratio=4.0, + eps=qwen_config["eps"], + bias=True, + ) + block.to(device_ref) + + batch_size, img_seq_len, _ = hidden_states.shape + txt_seq_len = encoder_hidden_states.shape[1] + total_seq_len = txt_seq_len + img_seq_len + head_dim = qwen_config["attention_head_dim"] + + cos, sin = image_rotary_emb + + compiled = block.compile( + TensorType( + DType.bfloat16, [batch_size, img_seq_len, inner_dim], device_ref + ), + TensorType( + DType.bfloat16, [batch_size, txt_seq_len, inner_dim], device_ref + ), + TensorType(DType.bfloat16, [batch_size, inner_dim], device_ref), + TensorType(DType.float32, list(cos.shape), device_ref), + TensorType(DType.float32, list(sin.shape), device_ref), + weights=block_weights, + ) + + result = compiled( + Tensor.from_dlpack(hidden_states), + Tensor.from_dlpack(encoder_hidden_states), + Tensor.from_dlpack(temb), + Tensor.from_dlpack(cos), + Tensor.from_dlpack(sin), + ) + return result[0], result[1] + + +def test_qwen_image_block( + qwen_config: dict[str, Any], + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + block_weights: dict[str, torch.Tensor], + image_rotary_emb_diffusers: tuple[torch.Tensor, torch.Tensor], + image_rotary_emb_max: tuple[torch.Tensor, torch.Tensor], +) -> None: + """Test that MAX QwenImageTransformerBlock matches diffusers output.""" + torch_txt_out, torch_img_out = generate_torch_outputs( + qwen_config, + hidden_states, + encoder_hidden_states, + temb, + block_weights, + image_rotary_emb_diffusers, + ) + + max_txt_out, max_img_out = generate_max_outputs( + qwen_config, + hidden_states, + encoder_hidden_states, + temb, + block_weights, + image_rotary_emb_max, + ) + + max_txt_torch = from_dlpack(max_txt_out).to(torch.bfloat16) + max_img_torch = from_dlpack(max_img_out).to(torch.bfloat16) + + # Image stream output + torch.testing.assert_close( + torch_img_out.to(torch.bfloat16), + max_img_torch, + rtol=2 * torch.finfo(torch.bfloat16).eps, + atol=16 * torch.finfo(torch.bfloat16).eps, + ) + + # Text stream output + torch.testing.assert_close( + torch_txt_out.to(torch.bfloat16), + max_txt_torch, + rtol=2 * torch.finfo(torch.bfloat16).eps, + atol=16 * torch.finfo(torch.bfloat16).eps, + ) diff --git a/max/tests/integration/architectures/qwen_image/test_scheduler_parity.py b/max/tests/integration/architectures/qwen_image/test_scheduler_parity.py new file mode 100644 index 00000000000..5300e014ddf --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/test_scheduler_parity.py @@ -0,0 +1,181 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Test that MAX scheduler sigma schedule matches diffusers logic exactly. + +Computes the reference sigma schedule using the same math as diffusers +(without loading the full 20B model) and compares against MAX scheduler. +""" + +import numpy as np +from max.pipelines.lib.diffusion_schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) + +PATCH_SIZE = 2 +VAE_SCALE_FACTOR = 8 + +# QwenImage scheduler_config.json values +QWEN_BASE_IMAGE_SEQ_LEN = 256 +QWEN_MAX_IMAGE_SEQ_LEN = 8192 +QWEN_BASE_SHIFT = 0.5 +QWEN_MAX_SHIFT = 0.9 +QWEN_USE_DYNAMIC_SHIFTING = True +QWEN_SHIFT_TERMINAL = 0.02 + + +def _compute_reference_sigmas( + height: int, + width: int, + num_inference_steps: int, +) -> np.ndarray: + """Compute sigma schedule using the same math as diffusers. + + This replicates FlowMatchEulerDiscreteScheduler.set_timesteps() from + diffusers, including dynamic shifting and stretch_shift_to_terminal. + """ + latent_h = height // VAE_SCALE_FACTOR + latent_w = width // VAE_SCALE_FACTOR + image_seq_len = (latent_h // PATCH_SIZE) * (latent_w // PATCH_SIZE) + + # Base sigmas: linearly spaced from 1.0 to 1/N + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) + + # Dynamic shifting: compute mu from linear interpolation + slope = (QWEN_MAX_SHIFT - QWEN_BASE_SHIFT) / ( + QWEN_MAX_IMAGE_SEQ_LEN - QWEN_BASE_IMAGE_SEQ_LEN + ) + mu = slope * image_seq_len + ( + QWEN_BASE_SHIFT - slope * QWEN_BASE_IMAGE_SEQ_LEN + ) + + # Exponential time shift: sigma(t) = exp(mu) / (exp(mu) + (1/t - 1)) + t_safe = np.clip(sigmas.astype(np.float64), 1e-7, 1.0) + sigmas = (np.exp(mu) / (np.exp(mu) + (1.0 / t_safe - 1.0))).astype( + np.float32 + ) + + # Terminal stretching: stretch so last sigma = shift_terminal + shift_terminal = QWEN_SHIFT_TERMINAL + if shift_terminal is not None and shift_terminal > 0: + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1.0 - shift_terminal) + sigmas = (1.0 - (one_minus_z / scale_factor)).astype(np.float32) + + sigmas = np.append(sigmas, np.float32(0.0)) + return sigmas + + +def _get_max_sigmas( + height: int, + width: int, + num_inference_steps: int, +) -> np.ndarray: + """Get sigma schedule from MAX scheduler.""" + scheduler = FlowMatchEulerDiscreteScheduler( + base_image_seq_len=QWEN_BASE_IMAGE_SEQ_LEN, + max_image_seq_len=QWEN_MAX_IMAGE_SEQ_LEN, + base_shift=QWEN_BASE_SHIFT, + max_shift=QWEN_MAX_SHIFT, + use_dynamic_shifting=QWEN_USE_DYNAMIC_SHIFTING, + shift_terminal=QWEN_SHIFT_TERMINAL, + ) + image_seq_len = (height // (VAE_SCALE_FACTOR * PATCH_SIZE)) * ( + width // (VAE_SCALE_FACTOR * PATCH_SIZE) + ) + _, sigmas = scheduler.retrieve_timesteps_and_sigmas( + image_seq_len=image_seq_len, + num_inference_steps=num_inference_steps, + ) + return sigmas + + +def test_sigma_schedule_matches_reference() -> None: + """Verify MAX sigma schedule matches reference diffusers math (fp32).""" + height, width, steps = 1024, 1024, 20 + + ref_sigmas = _compute_reference_sigmas(height, width, steps) + max_sigmas = _get_max_sigmas(height, width, steps) + + assert ref_sigmas.shape == max_sigmas.shape, ( + f"Shape mismatch: ref={ref_sigmas.shape} vs MAX={max_sigmas.shape}" + ) + np.testing.assert_allclose( + max_sigmas, + ref_sigmas, + atol=1e-6, + rtol=1e-5, + err_msg="Sigma schedules differ between MAX and reference", + ) + + +def test_sigma_schedule_50_steps() -> None: + """Test sigma schedule with 50 steps at various resolutions.""" + for height, width in [(512, 512), (1024, 1024), (768, 1024)]: + ref = _compute_reference_sigmas(height, width, 50) + max_ = _get_max_sigmas(height, width, 50) + np.testing.assert_allclose( + max_, + ref, + atol=1e-6, + rtol=1e-5, + err_msg=f"Mismatch at {height}x{width}, 50 steps", + ) + + +def test_shift_terminal_effect() -> None: + """Verify shift_terminal stretches the last sigma correctly.""" + shift_terminal = 0.02 + scheduler = FlowMatchEulerDiscreteScheduler( + base_image_seq_len=256, + max_image_seq_len=8192, + base_shift=0.5, + max_shift=0.9, + use_dynamic_shifting=True, + shift_terminal=shift_terminal, + ) + _, sigmas = scheduler.retrieve_timesteps_and_sigmas( + image_seq_len=4096, + num_inference_steps=50, + ) + # Last non-zero sigma should equal shift_terminal + assert abs(sigmas[-2] - shift_terminal) < 1e-5, ( + f"Last non-zero sigma {sigmas[-2]} != shift_terminal {shift_terminal}" + ) + assert sigmas[-1] == 0.0 + + +def test_no_shift_terminal_preserves_behavior() -> None: + """Without shift_terminal, scheduler should behave as before.""" + scheduler_with = FlowMatchEulerDiscreteScheduler( + base_image_seq_len=256, + max_image_seq_len=4096, + base_shift=0.5, + max_shift=1.15, + use_dynamic_shifting=True, + shift_terminal=None, + ) + scheduler_without = FlowMatchEulerDiscreteScheduler( + base_image_seq_len=256, + max_image_seq_len=4096, + base_shift=0.5, + max_shift=1.15, + use_dynamic_shifting=True, + ) + _, sigmas_with = scheduler_with.retrieve_timesteps_and_sigmas( + image_seq_len=4096, num_inference_steps=50 + ) + _, sigmas_without = scheduler_without.retrieve_timesteps_and_sigmas( + image_seq_len=4096, num_inference_steps=50 + ) + np.testing.assert_array_equal(sigmas_with, sigmas_without) diff --git a/max/tests/integration/architectures/qwen_image/test_text_encoder_parity.py b/max/tests/integration/architectures/qwen_image/test_text_encoder_parity.py new file mode 100644 index 00000000000..470c9888ebf --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/test_text_encoder_parity.py @@ -0,0 +1,220 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Test that MAX Qwen2.5-VL text encoder matches HuggingFace Qwen2Model output. + +Uses a tiny Qwen2 config (2 layers, small dims) with random weights to verify +the forward pass matches exactly. This catches: +- Missing final RMSNorm +- Wrong RoPE interleaving convention +- Weight loading mismatches +- Attention scale differences +""" + +from __future__ import annotations + +import numpy as np +import torch +from max.driver import Accelerator +from max.dtype import DType +from max.experimental import functional as F +from max.experimental.tensor import Tensor +from max.graph import DeviceRef, TensorType +from max.pipelines.architectures.qwen2_5vl.encoder.model_config import ( + Qwen25VLTextEncoderConfigBase, +) +from max.pipelines.architectures.qwen2_5vl.encoder.qwen25vl import ( + Qwen25VLTextEncoderTransformer, +) +from torch.utils.dlpack import from_dlpack +from transformers import Qwen2Config, Qwen2Model + +# Small config for fast testing (2 layers, 256 hidden, 4 heads) +_HIDDEN_SIZE = 256 +_NUM_HEADS = 4 +_NUM_KV_HEADS = 2 +_NUM_LAYERS = 2 +_INTERMEDIATE_SIZE = 512 +_VOCAB_SIZE = 1024 +_RMS_NORM_EPS = 1e-6 +_ROPE_THETA = 1000000.0 +_MAX_SEQ_LEN = 512 +_HEAD_DIM = 64 + +_SEQ_LEN = 16 +_SEED = 42 + + +def _hf_config() -> Qwen2Config: + return Qwen2Config( + vocab_size=_VOCAB_SIZE, + hidden_size=_HIDDEN_SIZE, + intermediate_size=_INTERMEDIATE_SIZE, + num_hidden_layers=_NUM_LAYERS, + num_attention_heads=_NUM_HEADS, + num_key_value_heads=_NUM_KV_HEADS, + max_position_embeddings=_MAX_SEQ_LEN, + rms_norm_eps=_RMS_NORM_EPS, + rope_theta=_ROPE_THETA, + use_sliding_window=False, + ) + + +def _max_config() -> Qwen25VLTextEncoderConfigBase: + return Qwen25VLTextEncoderConfigBase( + hidden_size=_HIDDEN_SIZE, + num_attention_heads=_NUM_HEADS, + num_key_value_heads=_NUM_KV_HEADS, + num_hidden_layers=_NUM_LAYERS, + intermediate_size=_INTERMEDIATE_SIZE, + vocab_size=_VOCAB_SIZE, + rms_norm_eps=_RMS_NORM_EPS, + rope_theta=_ROPE_THETA, + max_seq_len=_MAX_SEQ_LEN, + head_dim=_HEAD_DIM, + device=DeviceRef.GPU(), + ) + + +def _build_hf_model() -> tuple[Qwen2Model, dict[str, torch.Tensor]]: + """Create HF Qwen2Model with random weights and return model + state_dict.""" + torch.manual_seed(_SEED) + model = Qwen2Model(_hf_config()).to(dtype=torch.bfloat16, device="cuda") + model.eval() + state_dict = dict(model.state_dict()) + return model, state_dict + + +def _build_max_model( + hf_state_dict: dict[str, torch.Tensor], +) -> object: + """Create MAX text encoder, compile with HF weights.""" + device_ref = Accelerator() + + with F.lazy(): + model = Qwen25VLTextEncoderTransformer(_max_config()) + model.to(device_ref) + + compiled = model.compile( + TensorType(DType.int64, shape=[_SEQ_LEN], device=device_ref), + weights=hf_state_dict, + ) + return compiled + + +@torch.no_grad() +def _run_hf(model: Qwen2Model, token_ids: torch.Tensor) -> torch.Tensor: + """Run HF model and return the last hidden state (after final norm).""" + out = model(token_ids, output_hidden_states=True) + # hidden_states[-1] is the output after the final RMSNorm + return out.hidden_states[-1] + + +def _run_max(compiled_model: object, token_ids_np: np.ndarray) -> np.ndarray: + """Run MAX model and return the last hidden state.""" + input_tensor = Tensor.from_dlpack( + torch.tensor(token_ids_np, dtype=torch.int64, device="cuda") + ) + result = compiled_model(input_tensor) # type: ignore[operator] + last_hs = result[-1] + return np.from_dlpack(from_dlpack(last_hs).float().cpu()) + + +def _random_tokens(seed_offset: int = 0) -> tuple[np.ndarray, torch.Tensor]: + """Generate random token IDs as both numpy and torch tensors.""" + rng = np.random.RandomState(_SEED + seed_offset) + token_ids_np = rng.randint(0, _VOCAB_SIZE, size=(_SEQ_LEN,)).astype( + np.int64 + ) + token_ids_torch = torch.tensor( + token_ids_np, dtype=torch.long, device="cuda" + ).unsqueeze(0) + return token_ids_np, token_ids_torch + + +def test_text_encoder_matches_hf() -> None: + """Verify MAX text encoder output matches HF Qwen2Model for random weights.""" + hf_model, hf_state_dict = _build_hf_model() + max_model = _build_max_model(hf_state_dict) + + token_ids_np, token_ids_torch = _random_tokens(seed_offset=1) + + hf_np = _run_hf(hf_model, token_ids_torch)[0].float().cpu().numpy() + max_np = _run_max(max_model, token_ids_np) + + assert hf_np.shape == max_np.shape, ( + f"Shape mismatch: HF={hf_np.shape} vs MAX={max_np.shape}" + ) + + # Per-token cosine similarity + for i in range(hf_np.shape[0]): + cos = float( + np.dot(hf_np[i], max_np[i]) + / (np.linalg.norm(hf_np[i]) * np.linalg.norm(max_np[i]) + 1e-10) + ) + assert cos > 0.99, f"Token {i}: cosine similarity {cos:.6f} < 0.99" + + # Global cosine similarity + cos_global = float( + np.dot(hf_np.flatten(), max_np.flatten()) + / (np.linalg.norm(hf_np.flatten()) * np.linalg.norm(max_np.flatten())) + ) + assert cos_global > 0.99, ( + f"Global cosine similarity {cos_global:.6f} < 0.99" + ) + + # Norm ratio should be close to 1.0 + hf_norms = np.linalg.norm(hf_np, axis=-1) + max_norms = np.linalg.norm(max_np, axis=-1) + norm_ratios = max_norms / (hf_norms + 1e-10) + assert np.all(norm_ratios > 0.9) and np.all(norm_ratios < 1.1), ( + f"Norm ratio out of [0.9, 1.1]: " + f"min={norm_ratios.min():.4f}, max={norm_ratios.max():.4f}" + ) + + +def test_text_encoder_norm_range() -> None: + """Verify output norms match; catches a missing final RMSNorm (~2x off).""" + hf_model, hf_state_dict = _build_hf_model() + max_model = _build_max_model(hf_state_dict) + + token_ids_np, token_ids_torch = _random_tokens(seed_offset=2) + + hf_norms = ( + _run_hf(hf_model, token_ids_torch)[0].float().norm(dim=-1).cpu().numpy() + ) + max_norms = np.linalg.norm(_run_max(max_model, token_ids_np), axis=-1) + + ratio = float(max_norms.mean()) / float(hf_norms.mean()) + assert 0.8 < ratio < 1.2, ( + f"Mean norm ratio {ratio:.4f} outside [0.8, 1.2]. " + f"HF={float(hf_norms.mean()):.1f}, MAX={float(max_norms.mean()):.1f}. " + f"Missing final RMSNorm?" + ) + + +def test_text_encoder_weight_count() -> None: + """Verify all HF Qwen2Model parameters have a match in the MAX module.""" + with F.lazy(): + model = Qwen25VLTextEncoderTransformer(_max_config()) + + max_param_names = set(name for name, _ in model.parameters) + + hf_model = Qwen2Model(_hf_config()) + hf_keys = set(hf_model.state_dict().keys()) + + missing = hf_keys - max_param_names + assert len(missing) == 0, ( + f"MAX model is missing {len(missing)} HF weight keys: " + f"{sorted(missing)[:10]}" + ) diff --git a/max/tests/integration/architectures/qwen_image/testdata/BUILD.bazel b/max/tests/integration/architectures/qwen_image/testdata/BUILD.bazel new file mode 100644 index 00000000000..bc81c0b5d5c --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/testdata/BUILD.bazel @@ -0,0 +1,16 @@ +package(default_visibility = [ + "//:__pkg__", + "//SDK/integration-test:__subpackages__", + "//max/tests:__subpackages__", + "//oss/modular/max/tests:__subpackages__", +]) + +filegroup( + name = "testdata", + testonly = True, + srcs = glob( + [ + "*.json", + ], + ), +) diff --git a/max/tests/integration/architectures/qwen_image/testdata/config.json b/max/tests/integration/architectures/qwen_image/testdata/config.json new file mode 100644 index 00000000000..52fc1fa0a68 --- /dev/null +++ b/max/tests/integration/architectures/qwen_image/testdata/config.json @@ -0,0 +1,12 @@ +{ + "patch_size": 2, + "in_channels": 64, + "num_layers": 60, + "attention_head_dim": 128, + "num_attention_heads": 24, + "joint_attention_dim": 3584, + "guidance_embeds": false, + "axes_dims_rope": [16, 56, 56], + "rope_theta": 10000, + "eps": 1e-6 +}