From 3d40d44f78536359b8a40403584aa9625e363bb9 Mon Sep 17 00:00:00 2001 From: jglee-sqbits Date: Tue, 24 Mar 2026 08:37:05 +0000 Subject: [PATCH] [Models] Add QwenImage runtime (rebased onto main) --- .../architectures/qwen2_5vl/tokenizer.py | 1 - .../qwen_image/layers/qwen_image_attention.py | 657 ++++++++++++++++++ .../architectures/qwen_image/qwen_image.py | 256 +++++++ max/python/max/pipelines/core/context.py | 8 + .../max/pipelines/lib/pixel_tokenizer.py | 528 ++++++++++---- .../max/pipelines/lib/qwen_image_processor.py | 139 ++++ 6 files changed, 1434 insertions(+), 155 deletions(-) create mode 100644 max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py create mode 100644 max/python/max/pipelines/architectures/qwen_image/qwen_image.py create mode 100644 max/python/max/pipelines/lib/qwen_image_processor.py diff --git a/max/python/max/pipelines/architectures/qwen2_5vl/tokenizer.py b/max/python/max/pipelines/architectures/qwen2_5vl/tokenizer.py index ad037fde376..95fcda6428e 100644 --- a/max/python/max/pipelines/architectures/qwen2_5vl/tokenizer.py +++ b/max/python/max/pipelines/architectures/qwen2_5vl/tokenizer.py @@ -95,7 +95,6 @@ def qwen2_5vl_image_preprocessing( grid_h = height // patch_size grid_w = width // patch_size - # Check if spatial merging is possible early if grid_h % merge_size != 0 or grid_w % merge_size != 0: raise ValueError( f"Spatial merging is not possible because grid_h {grid_h} % merge_size {merge_size} != 0 or grid_w {grid_w} % merge_size {merge_size} != 0" diff --git a/max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py b/max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py new file mode 100644 index 00000000000..9320fc564f8 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py @@ -0,0 +1,657 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage attention layers: dual-stream attention, FeedForward, and transformer block. + +Weight key naming follows HuggingFace diffusers conventions: +- Attention: attn.to_q, attn.to_k, attn.to_v, attn.to_out.0, attn.add_q_proj, etc. +- FeedForward: img_mlp.net.0.proj (SwiGLU), img_mlp.net.2 (output linear) +- Modulation: img_mod.1 (Linear after SiLU), txt_mod.1 +- Norms: img_norm1, img_norm2, txt_norm1, txt_norm2 (no affine, no weights) +""" + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn.attention.mask_config import MHAMaskVariant +from max.nn.kernels import flash_attention_gpu +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + +from .embeddings import apply_rotary_emb +from .normalizations import LayerNormNoAffine + +# --------------------------------------------------------------------------- +# FeedForward (matches diffusers naming: net.0.proj, net.2) +# --------------------------------------------------------------------------- + + +class _QwenImageGELU(Module): + """GELU approximate activation with a Linear projection. + + Weight key: `proj.weight`, `proj.bias` + In the block: `img_mlp.net.0.proj.weight` + """ + + def __init__( + self, + dim_in: int, + dim_out: int, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + self.proj = Linear( + in_dim=dim_in, + out_dim=dim_out, + dtype=dtype, + device=device, + has_bias=bias, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + return ops.gelu(self.proj(x)) + + +class _QwenImageDropout(Module): + """No-op dropout for inference. Occupies index 1 in FeedForward.net.""" + + def __init__(self): + super().__init__() + + def __call__(self, x: TensorValue) -> TensorValue: + return x + + +class QwenImageFeedForward(Module): + """FeedForward matching diffusers key naming. + + Produces keys: + net.0.proj.weight, net.0.proj.bias (GELU approximate projection) + net.2.weight, net.2.bias (output linear) + """ + + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: float = 4.0, + inner_dim: int | None = None, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out or dim + + self.net: LayerList = LayerList( + [ + _QwenImageGELU( + dim, inner_dim, bias=bias, dtype=dtype, device=device + ), + _QwenImageDropout(), + Linear( + in_dim=inner_dim, + out_dim=dim_out, + dtype=dtype, + device=device, + has_bias=bias, + ), + ] + ) + + def __call__(self, x: TensorValue) -> TensorValue: + x = self.net[0](x) # GELU projection + # net[1] is dropout (no-op at inference) + x = self.net[2](x) # output linear + return x + + +# --------------------------------------------------------------------------- +# Attention (matches diffusers key naming: to_q, to_k, to_v, to_out.0, ...) +# --------------------------------------------------------------------------- + + +class QwenImageAttention(Module): + """Dual-stream attention for QwenImage transformer blocks. + + Key naming matches HuggingFace diffusers: + - to_q.weight/bias, to_k.weight/bias, to_v.weight/bias + - to_out.0.weight/bias (LayerList for correct .0. indexing) + - add_q_proj.weight/bias, add_k_proj.weight/bias, add_v_proj.weight/bias + - to_add_out.weight/bias + - norm_q.weight, norm_k.weight, norm_added_q.weight, norm_added_k.weight + """ + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + bias: bool = True, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int | None = None, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.scale = 1.0 / (self.head_dim**0.5) + out_dim = out_dim if out_dim is not None else query_dim + + self.to_q = Linear( + in_dim=query_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=bias, + ) + self.to_k = Linear( + in_dim=query_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=bias, + ) + self.to_v = Linear( + in_dim=query_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=bias, + ) + + self.norm_q = RMSNorm(dim_head, dtype=dtype, eps=eps) + self.norm_k = RMSNorm(dim_head, dtype=dtype, eps=eps) + + # Use LayerList so key becomes to_out.0.weight (not to_out_0.weight) + self.to_out: LayerList = LayerList( + [ + Linear( + in_dim=self.inner_dim, + out_dim=out_dim, + dtype=dtype, + device=device, + has_bias=out_bias, + ) + ] + ) + + self.norm_added_q: RMSNorm | None + self.norm_added_k: RMSNorm | None + self.add_q_proj: Linear | None + self.add_k_proj: Linear | None + self.add_v_proj: Linear | None + self.to_add_out: Linear | None + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, dtype=dtype, eps=eps) + self.norm_added_k = RMSNorm(dim_head, dtype=dtype, eps=eps) + self.add_q_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=added_proj_bias, + ) + self.add_k_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=added_proj_bias, + ) + self.add_v_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=added_proj_bias, + ) + self.to_add_out = Linear( + in_dim=self.inner_dim, + out_dim=query_dim, + dtype=dtype, + device=device, + has_bias=out_bias, + ) + else: + self.norm_added_q = None + self.norm_added_k = None + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + self.to_add_out = None + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue | None = None, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + ) -> TensorValue | tuple[TensorValue, TensorValue]: + batch_size = hidden_states.shape[0] + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + seq_len = query.shape[1] + + query = ops.reshape( + query, [batch_size, seq_len, self.heads, self.head_dim] + ) + key = ops.reshape(key, [batch_size, seq_len, self.heads, self.head_dim]) + value = ops.reshape( + value, [batch_size, seq_len, self.heads, self.head_dim] + ) + + query = self.norm_q(query) + key = self.norm_k(key) + + if ( + encoder_hidden_states is not None + and self.added_kv_proj_dim is not None + ): + if ( + self.add_q_proj is None + or self.add_k_proj is None + or self.add_v_proj is None + ): + raise ValueError("Encoder projections not initialized") + encoder_query = self.add_q_proj(encoder_hidden_states) + encoder_key = self.add_k_proj(encoder_hidden_states) + encoder_value = self.add_v_proj(encoder_hidden_states) + encoder_seq_len = encoder_query.shape[1] + encoder_query = ops.reshape( + encoder_query, + [batch_size, encoder_seq_len, self.heads, self.head_dim], + ) + encoder_key = ops.reshape( + encoder_key, + [batch_size, encoder_seq_len, self.heads, self.head_dim], + ) + encoder_value = ops.reshape( + encoder_value, + [batch_size, encoder_seq_len, self.heads, self.head_dim], + ) + + if self.norm_added_q is None or self.norm_added_k is None: + raise ValueError("Encoder normalizations not initialized") + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = ops.concat([encoder_query, query], axis=1) + key = ops.concat([encoder_key, key], axis=1) + value = ops.concat([encoder_value, value], axis=1) + + original_dtype = query.dtype + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + if query.dtype != original_dtype: + query = ops.cast(query, original_dtype) + if key.dtype != original_dtype: + key = ops.cast(key, original_dtype) + + hidden_states = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=self.scale, + ) + + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + hidden_states = ops.reshape( + hidden_states, [batch_size, seq_len, self.inner_dim] + ) + if hidden_states.dtype != query.dtype: + hidden_states = ops.cast(hidden_states, query.dtype) + + if encoder_hidden_states is not None: + encoder_seq_len = encoder_hidden_states.shape[1] + encoder_out = hidden_states[:, :encoder_seq_len, :] + hidden_out = hidden_states[:, encoder_seq_len:, :] + + hidden_out = self.to_out[0](hidden_out) + if self.to_add_out is None: + raise ValueError("Encoder output projection not initialized") + encoder_out = self.to_add_out(encoder_out) + + return hidden_out, encoder_out + else: + hidden_states = self.to_out[0](hidden_states) + return hidden_states + + +# --------------------------------------------------------------------------- +# Per-block Modulation (matches diffusers: img_mod.1.weight, txt_mod.1.weight) +# --------------------------------------------------------------------------- + + +class _SiLUPlaceholder(Module): + """Placeholder at index 0 in LayerList; SiLU has no learnable params.""" + + def __init__(self): + super().__init__() + + def __call__(self, x: TensorValue) -> TensorValue: + return ops.silu(x) + + +def _make_block_modulation( + dim: int, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, +) -> LayerList: + """Create per-block modulation as LayerList[SiLU_placeholder, Linear]. + + Produces weight keys: `{attr_name}.1.weight` and `{attr_name}.1.bias` + matching the diffusers convention img_mod.1.weight / txt_mod.1.weight. + """ + return LayerList( + [ + _SiLUPlaceholder(), + Linear( + in_dim=dim, + out_dim=dim * 6, + dtype=dtype, + device=device, + has_bias=bias, + ), + ] + ) + + +# --------------------------------------------------------------------------- +# Transformer Block (per-block img_mod, txt_mod, img_mlp, txt_mlp) +# --------------------------------------------------------------------------- + + +class QwenImageTransformerBlock(Module): + """Dual-stream transformer block with per-block modulation. + + Weight key structure per block: + img_mod.1.{weight,bias} + txt_mod.1.{weight,bias} + attn.to_q.{weight,bias}, attn.to_k.{weight,bias}, ... + img_mlp.net.0.proj.{weight,bias}, img_mlp.net.2.{weight,bias} + txt_mlp.net.0.proj.{weight,bias}, txt_mlp.net.2.{weight,bias} + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + eps: float = 1e-6, + bias: bool = True, + *, + dtype: DType, + device: DeviceRef, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + # Per-block modulation (img_mod, txt_mod) + self.img_mod: LayerList = _make_block_modulation( + dim, bias=bias, dtype=dtype, device=device + ) + self.txt_mod: LayerList = _make_block_modulation( + dim, bias=bias, dtype=dtype, device=device + ) + + # Norms (no affine → no weights in state_dict) + self.img_norm1 = LayerNormNoAffine(eps=eps) + self.img_norm2 = LayerNormNoAffine(eps=eps) + self.txt_norm1 = LayerNormNoAffine(eps=eps) + self.txt_norm2 = LayerNormNoAffine(eps=eps) + + # Dual-stream attention + self.attn = QwenImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=bias, + added_proj_bias=bias, + out_bias=bias, + eps=eps, + dtype=dtype, + device=device, + ) + + # Feedforward (img_mlp, txt_mlp) + self.img_mlp = QwenImageFeedForward( + dim=dim, + dim_out=dim, + mult=mlp_ratio, + bias=bias, + dtype=dtype, + device=device, + ) + self.txt_mlp = QwenImageFeedForward( + dim=dim, + dim_out=dim, + mult=mlp_ratio, + bias=bias, + dtype=dtype, + device=device, + ) + + def _apply_modulation( + self, + x: TensorValue, + shift: TensorValue, + scale: TensorValue, + ) -> TensorValue: + """Apply shift/scale modulation: (1 + scale) * x + shift.""" + return (1 + scale) * x + shift + + def _apply_split_modulation( + self, + x: TensorValue, + mod_real: TensorValue, + mod_zero: TensorValue, + condition_token_mask: TensorValue, + mod_idx: int, + ) -> TensorValue: + """Apply different modulation to noise vs condition tokens.""" + condition_token_mask = ops.broadcast_to(condition_token_mask, x.shape) + + # mod has 6 chunks: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp + # We need shift and scale at mod_idx and mod_idx+1 + real_chunks = ops.chunk(mod_real, 6, axis=-1) + zero_chunks = ops.chunk(mod_zero, 6, axis=-1) + shift_r, scale_r = real_chunks[mod_idx], real_chunks[mod_idx + 1] + shift_z, scale_z = zero_chunks[mod_idx], zero_chunks[mod_idx + 1] + + noise_modulated = (1 + scale_r) * x + shift_r + condition_modulated = (1 + scale_z) * x + shift_z + return ops.where( + condition_token_mask, + condition_modulated, + noise_modulated, + ) + + def _apply_split_gate( + self, + x: TensorValue, + gate_real: TensorValue, + gate_zero: TensorValue, + condition_token_mask: TensorValue, + ) -> TensorValue: + """Apply different gate to noise vs condition tokens.""" + condition_token_mask = ops.broadcast_to(condition_token_mask, x.shape) + gated_noise = x * gate_real + gated_condition = x * gate_zero + return ops.where( + condition_token_mask, + gated_condition, + gated_noise, + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + temb: TensorValue, + image_rotary_emb: tuple[TensorValue, TensorValue] | None = None, + temb_zero: TensorValue | None = None, + condition_token_mask: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue]: + # Compute per-block modulation params from temb + # Compute silu once and reuse for both modulation projections. + temb_activated = ops.silu(temb) + img_mod = self.img_mod[1](temb_activated) + txt_mod = self.txt_mod[1](temb_activated) + + if len(img_mod.shape) == 2: + img_mod = ops.unsqueeze(img_mod, 1) + txt_mod = ops.unsqueeze(txt_mod, 1) + + # zero_cond_t path: separate modulation for condition tokens + img_mod_zero: TensorValue | None = None + if temb_zero is not None: + temb_zero_activated = ops.silu(temb_zero) + img_mod_zero = self.img_mod[1](temb_zero_activated) + if len(img_mod_zero.shape) == 2: + img_mod_zero = ops.unsqueeze(img_mod_zero, 1) + + img_mod_chunks = ops.chunk(img_mod, 6, axis=-1) + shift_msa, scale_msa, gate_msa = ( + img_mod_chunks[0], + img_mod_chunks[1], + img_mod_chunks[2], + ) + shift_mlp, scale_mlp, gate_mlp = ( + img_mod_chunks[3], + img_mod_chunks[4], + img_mod_chunks[5], + ) + + txt_mod_chunks = ops.chunk(txt_mod, 6, axis=-1) + c_shift_msa, c_scale_msa, c_gate_msa = ( + txt_mod_chunks[0], + txt_mod_chunks[1], + txt_mod_chunks[2], + ) + c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + txt_mod_chunks[3], + txt_mod_chunks[4], + txt_mod_chunks[5], + ) + + # Image stream - Attention + norm_hidden_states = self.img_norm1(hidden_states) + if img_mod_zero is not None and condition_token_mask is not None: + norm_hidden_states = self._apply_split_modulation( + norm_hidden_states, + img_mod, + img_mod_zero, + condition_token_mask, + 0, + ) + else: + norm_hidden_states = ( + 1 + scale_msa + ) * norm_hidden_states + shift_msa + + # Text stream - Attention + norm_encoder_hidden_states = self.txt_norm1(encoder_hidden_states) + norm_encoder_hidden_states = ( + 1 + c_scale_msa + ) * norm_encoder_hidden_states + c_shift_msa + + # Dual-stream attention + attn_output, context_attn_output = self.attn( + norm_hidden_states, + norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + # Image stream - Apply gate and residual + if img_mod_zero is not None and condition_token_mask is not None: + img_mod_zero_chunks = ops.chunk(img_mod_zero, 6, axis=-1) + attn_output = self._apply_split_gate( + attn_output, + gate_msa, + img_mod_zero_chunks[2], + condition_token_mask, + ) + else: + attn_output = gate_msa * attn_output + hidden_states = hidden_states + attn_output + + # Image stream - Feedforward + norm_hidden_states = self.img_norm2(hidden_states) + if img_mod_zero is not None and condition_token_mask is not None: + norm_hidden_states = self._apply_split_modulation( + norm_hidden_states, + img_mod, + img_mod_zero, + condition_token_mask, + 3, + ) + else: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp) + shift_mlp + ) + + ff_output = self.img_mlp(norm_hidden_states) + if img_mod_zero is not None and condition_token_mask is not None: + ff_output = self._apply_split_gate( + ff_output, + gate_mlp, + img_mod_zero_chunks[5], + condition_token_mask, + ) + else: + ff_output = gate_mlp * ff_output + hidden_states = hidden_states + ff_output + + # Text stream - Apply gate and residual + context_attn_output = c_gate_msa * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + # Text stream - Feedforward + norm_encoder_hidden_states = self.txt_norm2(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + ) + + context_ff_output = self.txt_mlp(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp * context_ff_output + ) + + if encoder_hidden_states.dtype == DType.float16: + encoder_hidden_states = ops.max(encoder_hidden_states, -65504.0) + encoder_hidden_states = ops.min(encoder_hidden_states, 65504.0) + + return encoder_hidden_states, hidden_states diff --git a/max/python/max/pipelines/architectures/qwen_image/qwen_image.py b/max/python/max/pipelines/architectures/qwen_image/qwen_image.py new file mode 100644 index 00000000000..b20d4805c4e --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen_image/qwen_image.py @@ -0,0 +1,256 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""QwenImage Transformer 2D Model. + +A 20B parameter MMDiT model for text-to-image generation with 60 dual-stream +blocks, 3D RoPE, and timestep-only embeddings (no guidance embedding). + +Weight key naming matches HuggingFace diffusers: +- img_in.{weight,bias} (input projection for image latents) +- txt_in.{weight,bias} (input projection for text embeddings) +- time_text_embed.timestep_embedder.{linear_1,linear_2}.{weight,bias} +- txt_norm.weight (RMSNorm for text output) +- transformer_blocks.{i}.* (per-block: img_mod, txt_mod, attn, img_mlp, txt_mlp) +- norm_out.linear.{weight,bias} (AdaLayerNormContinuous) +- proj_out.{weight,bias} (output projection) +""" + +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + +from .layers.embeddings import ( + QwenImagePosEmbed, + QwenImageTimestepProjEmbeddings, +) +from .layers.normalizations import AdaLayerNormContinuous +from .layers.qwen_image_attention import QwenImageTransformerBlock +from .model_config import QwenImageConfigBase + + +class QwenImageTransformer2DModel(Module): + """QwenImage Transformer with 60 dual-stream blocks. + + Key differences from Flux2: + - No guidance embedding (timestep only) + - No single-stream blocks (all 60 are dual-stream) + - 3D RoPE with axes [16, 56, 56] (T, H, W) + - Per-block modulation (img_mod, txt_mod per block) + - inner_dim = 24 * 128 = 3072 + """ + + def __init__( + self, + config: QwenImageConfigBase, + ): + super().__init__() + patch_size = config.patch_size + in_channels = config.in_channels + out_channels = config.out_channels + num_layers = config.num_layers + attention_head_dim = config.attention_head_dim + num_attention_heads = config.num_attention_heads + joint_attention_dim = config.joint_attention_dim + axes_dims_rope = config.axes_dims_rope + rope_theta = config.rope_theta + device = config.device + dtype = config.dtype + eps = config.eps + + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + # 1. Positional embeddings (3D RoPE: T, H, W) + self.pos_embed = QwenImagePosEmbed( + theta=rope_theta, axes_dim=axes_dims_rope + ) + + # 2. Timestep embeddings (no guidance) — key: time_text_embed.* + self.time_text_embed = QwenImageTimestepProjEmbeddings( + in_channels=256, + embedding_dim=self.inner_dim, + bias=True, + dtype=dtype, + device=device, + ) + + # 3. Input embeddings — keys: img_in.*, txt_in.* + self.img_in = Linear( + in_dim=in_channels, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.txt_in = Linear( + in_dim=joint_attention_dim, + out_dim=self.inner_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + + # 4. Text input norm — key: txt_norm.weight + self.txt_norm = RMSNorm(joint_attention_dim, dtype=dtype, eps=eps) + + # 5. Dual-stream transformer blocks (all 60 are dual-stream) + self.transformer_blocks: LayerList = LayerList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_ratio=4.0, + eps=eps, + bias=True, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers — keys: norm_out.linear.*, proj_out.* + self.norm_out = AdaLayerNormContinuous( + embedding_dim=self.inner_dim, + conditioning_embedding_dim=self.inner_dim, + dtype=dtype, + device=device, + eps=eps, + bias=True, + ) + self.proj_out = Linear( + in_dim=self.inner_dim, + out_dim=patch_size * patch_size * self.out_channels, + dtype=dtype, + device=device, + has_bias=True, + ) + + # Store config for input_types + self.max_device = device + self.max_dtype = dtype + self.in_channels = in_channels + self.joint_attention_dim = joint_attention_dim + self.zero_cond_t = config.zero_cond_t + + def input_types(self) -> tuple[TensorType, ...]: + hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "image_seq_len", self.in_channels], + device=self.max_device, + ) + encoder_hidden_states_type = TensorType( + self.max_dtype, + shape=["batch_size", "text_seq_len", self.joint_attention_dim], + device=self.max_device, + ) + timestep_type = TensorType( + self.max_dtype, shape=["batch_size"], device=self.max_device + ) + # 3D position IDs: (T, H, W) + img_ids_type = TensorType( + DType.int64, + shape=["batch_size", "image_seq_len", 3], + device=self.max_device, + ) + txt_ids_type = TensorType( + DType.int64, + shape=["batch_size", "text_seq_len", 3], + device=self.max_device, + ) + + result = ( + hidden_states_type, + encoder_hidden_states_type, + timestep_type, + img_ids_type, + txt_ids_type, + ) + + return result + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + timestep: TensorValue, + img_ids: TensorValue, + txt_ids: TensorValue, + ) -> tuple[TensorValue]: + """Forward pass through QwenImage Transformer. + + Args: + hidden_states: Image latents [B, img_seq, in_channels]. + encoder_hidden_states: Text embeddings [B, txt_len, joint_attention_dim]. + timestep: Denoising timestep [B] (scaled to [0, 1] range). + img_ids: Image position IDs [B, image_seq_len, 3] (T, H, W). + txt_ids: Text position IDs [B, text_seq_len, 3]. + + Returns: + Denoised output of shape [B, img_seq, patch_size^2 * out_channels]. + """ + # Handle batch dimension in ids + img_ids = img_ids[0] # [img_seq, 3] + txt_ids = txt_ids[0] # [txt_len, 3] + + # 1. Calculate timestep embedding + timestep_scaled = ops.cast(timestep * 1000.0, hidden_states.dtype) + temb = self.time_text_embed(timestep_scaled) + + # For zero_cond_t: compute temb for timestep=0 (condition tokens) + # and derive the condition-token mask from the image-token T + # coordinate so edit runs stay shape-dynamic. + temb_zero: TensorValue | None = None + condition_token_mask: TensorValue | None = None + if self.zero_cond_t: + zero_t = timestep_scaled * 0.0 + temb_zero = self.time_text_embed(zero_t) + token_types = img_ids[:, 0] + is_condition_token = ops.not_equal( + token_types, + ops.constant(0, DType.int64, device=token_types.device), + ) + condition_token_mask = ops.unsqueeze( + ops.unsqueeze(is_condition_token, 0), -1 + ) + + # 2. Input projection (txt_norm applied before txt_in projection) + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + # 3. Calculate RoPE embeddings + ids = ops.concat([txt_ids, img_ids], axis=0) + image_rotary_emb = self.pos_embed(ids) + + # 4. Dual-stream transformer blocks (all 60) + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + temb_zero=temb_zero, + condition_token_mask=condition_token_mask, + ) + + # 5. Output projection (image tokens only, discard text) + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return (output,) diff --git a/max/python/max/pipelines/core/context.py b/max/python/max/pipelines/core/context.py index 28cc3db4ce2..5166f9a6ed0 100644 --- a/max/python/max/pipelines/core/context.py +++ b/max/python/max/pipelines/core/context.py @@ -729,6 +729,14 @@ class PixelContext: num_images_per_prompt: int = field(default=1) input_image: npt.NDArray[np.uint8] | None = field(default=None) """Input image as numpy array (H, W, C) in uint8 format for image-to-image generation.""" + input_images: list[npt.NDArray[np.uint8]] | None = field(default=None) + """Optional list of input images for multi-image conditioning flows.""" + prompt_images: list[npt.NDArray[np.uint8]] | None = field(default=None) + """Optional prompt-encoder images for Qwen image edit flows.""" + vae_condition_images: list[npt.NDArray[np.uint8]] | None = field( + default=None + ) + """Optional VAE conditioning images for Qwen image edit flows.""" image: npt.NDArray[np.uint8] | None = field(default=None) """Decoded output image (H, W, C) uint8 [0, 255]. Set after generation completes.""" output_format: str = field(default="jpeg") diff --git a/max/python/max/pipelines/lib/pixel_tokenizer.py b/max/python/max/pipelines/lib/pixel_tokenizer.py index 07af959179e..04b8984f6c3 100644 --- a/max/python/max/pipelines/lib/pixel_tokenizer.py +++ b/max/python/max/pipelines/lib/pixel_tokenizer.py @@ -18,6 +18,7 @@ import asyncio import base64 import logging +import math import threading from collections.abc import Callable from enum import Enum @@ -41,6 +42,7 @@ from transformers import AutoTokenizer from .diffusion_schedulers import SchedulerFactory +from .qwen_image_processor import Qwen2_5VLPromptImageProcessor if TYPE_CHECKING: import PIL.Image @@ -48,6 +50,22 @@ logger = logging.getLogger("max.pipelines") +QWEN_EDIT_PROMPT_IMAGE_SIZE = 384 * 384 +QWEN_EDIT_VAE_IMAGE_SIZE = 1024 * 1024 +QWEN_EDIT_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, " + "texture, objects, background), then explain how the user's text " + "instruction should alter or modify the image. Generate a new image " + "that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWEN_EDIT_IMAGE_PROMPT_TEMPLATE = ( + "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" +) + async def run_with_default_executor( fn: Callable[..., Any], *args: Any, **kwargs: Any @@ -99,6 +117,18 @@ class PipelineClassName(str, Enum): FLUX2 = "Flux2Pipeline" FLUX2_KLEIN = "Flux2KleinPipeline" ZIMAGE = "ZImagePipeline" + QWENIMAGE = "QwenImagePipeline" + QWENIMAGE_EDIT = "QwenImageEditPipeline" + QWENIMAGE_EDIT_PLUS = "QwenImageEditPlusPipeline" + + @property + def is_qwen_image_family(self) -> bool: + """Returns whether the pipeline belongs to the Qwen image family.""" + return self in { + PipelineClassName.QWENIMAGE, + PipelineClassName.QWENIMAGE_EDIT, + PipelineClassName.QWENIMAGE_EDIT_PLUS, + } @classmethod def from_diffusers_config( @@ -110,6 +140,11 @@ def from_diffusers_config( raise KeyError( "diffusers_config is missing required key '_class_name'." ) + if raw in { + cls.QWENIMAGE_EDIT.value, + cls.QWENIMAGE_EDIT_PLUS.value, + }: + return cls.QWENIMAGE try: return cls(raw) except ValueError as e: @@ -137,6 +172,7 @@ class PixelGenerationTokenizer( max_length: Maximum sequence length for the primary tokenizer. secondary_max_length: Maximum sequence length for the secondary tokenizer, if used. trust_remote_code: Whether to trust remote code from the model. + context_validators: Optional list of validators to run on PixelContext. """ def __init__( @@ -151,6 +187,7 @@ def __init__( secondary_max_length: int | None = None, trust_remote_code: bool = False, default_num_inference_steps: int = 50, + context_validators: list[Callable[[PixelContext], None]] | None = None, **unused_kwargs, ) -> None: self.model_path = model_path @@ -205,6 +242,10 @@ def __init__( "- '--trust-remote-code' is needed but not set\n" ) from e + self._context_validators = ( + context_validators if context_validators else [] + ) + # Extract diffusers_config if not pipeline_config or not hasattr( pipeline_config.model, "diffusers_config" @@ -224,6 +265,30 @@ def __init__( self._pipeline_class_name = PipelineClassName.from_diffusers_config( self.diffusers_config ) + self._raw_pipeline_class_name = self.diffusers_config.get("_class_name") + self._is_qwen_image_edit_family = self._raw_pipeline_class_name in { + PipelineClassName.QWENIMAGE_EDIT.value, + PipelineClassName.QWENIMAGE_EDIT_PLUS.value, + } + self._qwen_edit_image_token_id: int | None = None + self._qwen_edit_image_processor: ( + Qwen2_5VLPromptImageProcessor | None + ) = None + if self._is_qwen_image_edit_family: + self._qwen_edit_image_token_id = ( + self.delegate.convert_tokens_to_ids("<|image_pad|>") + ) + text_encoder_config = ( + self.diffusers_config.get("components", {}) + .get("text_encoder", {}) + .get("config_dict", {}) + ) + vision_config = text_encoder_config.get("vision_config", {}) + self._qwen_edit_image_processor = Qwen2_5VLPromptImageProcessor( + patch_size=vision_config.get("patch_size", 14), + temporal_patch_size=vision_config.get("temporal_patch_size", 2), + merge_size=vision_config.get("spatial_merge_size", 2), + ) # Preserve tokenizer attention masks so downstream text encoders can # derive additive attention bias directly from tokenizer semantics. @@ -237,9 +302,12 @@ def __init__( # Compute static VAE scale factor block_out_channels = vae_config.get("block_out_channels", None) - self._vae_scale_factor = ( - 2 ** (len(block_out_channels) - 1) if block_out_channels else 8 - ) + if block_out_channels: + self._vae_scale_factor = 2 ** (len(block_out_channels) - 1) + elif self._pipeline_class_name.is_qwen_image_family: + self._vae_scale_factor = 8 + else: + self._vae_scale_factor = 8 # Store static model dimensions self._default_sample_size = 128 @@ -290,6 +358,21 @@ def _prepare_latent_image_ids( latent_image_ids[np.newaxis, :, :], (batch_size, 1, 1) ) return latent_image_ids + elif self._pipeline_class_name.is_qwen_image_family: + t_coords = np.zeros((height, width), dtype=np.int64) + h_centered = np.arange(height, dtype=np.int64) - ( + height - height // 2 + ) + w_centered = np.arange(width, dtype=np.int64) - (width - width // 2) + h_coords, w_coords = np.meshgrid( + h_centered, w_centered, indexing="ij" + ) + latent_image_ids = np.stack([t_coords, h_coords, w_coords], axis=-1) + latent_image_ids = latent_image_ids.reshape(-1, 3) + latent_image_ids = np.tile( + latent_image_ids[np.newaxis, :, :], (batch_size, 1, 1) + ) + return latent_image_ids.astype(np.float32, copy=False) else: latent_image_ids = np.zeros((height, width, 3)) latent_image_ids[..., 1] = ( @@ -344,16 +427,20 @@ def _resize_with_center_crop( def _preprocess_input_image( self, image: PIL.Image.Image | npt.NDArray[np.uint8], + target_height: int | None = None, + target_width: int | None = None, ) -> PIL.Image.Image: """Preprocess input image for image-to-image generation. - Matches diffusers FLUX2 behavior: + Matches the shared image-to-image behavior: - cap image area when needed - floor dimensions to multiples of vae_scale_factor * 2 - apply aspect-ratio preserving center-crop resize to the floored size Args: image: PIL Image or numpy array (uint8) to preprocess. + target_height: Target height for the image. If None, uses image's height. + target_width: Target width for the image. If None, uses image's width. Returns: Preprocessed PIL Image with adjusted dimensions. @@ -379,19 +466,141 @@ def _preprocess_input_image( ) image_width, image_height = image.size - image_width = max( - (image_width // multiple_of) * multiple_of, multiple_of + if target_height is None: + image_height = max( + (image_height // multiple_of) * multiple_of, multiple_of + ) + else: + image_height = max( + (target_height // multiple_of) * multiple_of, multiple_of + ) + + if target_width is None: + image_width = max( + (image_width // multiple_of) * multiple_of, multiple_of + ) + else: + image_width = max( + (target_width // multiple_of) * multiple_of, multiple_of + ) + + if image.size != (image_width, image_height): + if target_height is None and target_width is None: + image = self._resize_with_center_crop( + image, image_width, image_height + ) + else: + image = image.resize( + (image_width, image_height), + PIL.Image.Resampling.LANCZOS, + ) + + return image + + @staticmethod + def _resize_image_to_area( + image: npt.NDArray[np.uint8], target_area: int + ) -> npt.NDArray[np.uint8]: + width = math.sqrt(target_area * (image.shape[1] / image.shape[0])) + height = width / (image.shape[1] / image.shape[0]) + resized_width = round(width / 32) * 32 + resized_height = round(height / 32) * 32 + if (image.shape[1], image.shape[0]) == (resized_width, resized_height): + return image + + pil_image = PIL.Image.fromarray(image) + resized = pil_image.resize( + (resized_width, resized_height), PIL.Image.Resampling.LANCZOS ) - image_height = max( - (image_height // multiple_of) * multiple_of, multiple_of + return np.asarray(resized) + + def _prepare_qwen_edit_condition_images( + self, + input_images: list[npt.NDArray[np.uint8]] | None, + ) -> tuple[ + list[npt.NDArray[np.uint8]] | None, list[npt.NDArray[np.uint8]] | None + ]: + if not input_images: + return None, None + + prompt_images = [ + self._resize_image_to_area(image, QWEN_EDIT_PROMPT_IMAGE_SIZE) + for image in input_images + ] + vae_condition_images = [ + self._resize_image_to_area(image, QWEN_EDIT_VAE_IMAGE_SIZE) + for image in input_images + ] + return prompt_images, vae_condition_images + + def _prepare_qwen_edit_tokens( + self, + prompt: str, + prompt_images: list[npt.NDArray[np.uint8]] | None, + ) -> npt.NDArray[np.int64]: + if not prompt_images: + raise ValueError( + "prompt_images are required for qwen edit tokenization" + ) + if self._qwen_edit_image_processor is None: + raise ValueError("qwen edit image processor is not initialized") + if self._qwen_edit_image_token_id is None: + raise ValueError("qwen edit image token id is not initialized") + + from max.pipelines.architectures.qwen2_5vl.nn.qwen_vl_utils import ( + fetch_image, ) - if image.size != (image_width, image_height): - image = self._resize_with_center_crop( - image, image_width, image_height + processed_images = [ + fetch_image({"image": PIL.Image.fromarray(image).convert("RGB")}) + for image in prompt_images + ] + processed = self._qwen_edit_image_processor( + images=processed_images, + return_tensors="np", + ) + processed_dict = ( + processed[0] if isinstance(processed, tuple) else processed + ) + image_grid_thw = np.asarray( + processed_dict["image_grid_thw"], dtype=np.int64 + ) + + vision_prefix = "".join( + QWEN_EDIT_IMAGE_PROMPT_TEMPLATE.format(index + 1) + for index in range(len(prompt_images)) + ) + formatted_prompt = QWEN_EDIT_PROMPT_TEMPLATE.format( + vision_prefix + prompt + ) + encoded = self.delegate( + formatted_prompt, + padding=False, + return_tensors="np", + add_special_tokens=False, + ) + input_ids = encoded["input_ids"][0].astype(np.int64, copy=False) + + merge_len = self._qwen_edit_image_processor.merge_size**2 + image_token_indices = np.where( + input_ids == self._qwen_edit_image_token_id + )[0] + if len(image_token_indices) < len(image_grid_thw): + raise ValueError( + "not enough qwen edit image placeholder tokens were generated" ) - return image + for index, grid_thw in enumerate(reversed(image_grid_thw)): + token_index = image_token_indices[-(index + 1)] + t, h, w = grid_thw + num_image_tokens = int((t * h * w) // merge_len) + input_ids = np.insert( + input_ids, + token_index, + [self._qwen_edit_image_token_id] * (num_image_tokens - 1), + ) + + return input_ids.astype(np.int64, copy=False) def _prepare_latents( self, @@ -422,9 +631,7 @@ async def _generate_tokens_ids( npt.NDArray[np.int64], npt.NDArray[np.bool_], npt.NDArray[np.int64] | None, - npt.NDArray[np.bool_] | None, npt.NDArray[np.int64] | None, - npt.NDArray[np.bool_] | None, npt.NDArray[np.int64] | None, ]: """Tokenize prompt(s) with encoder model(s). @@ -438,36 +645,26 @@ async def _generate_tokens_ids( images: Optional list of images for image-to-image generation (Flux2 only). Returns: - Tuple of ( - token_ids, - attn_mask, - token_ids_2, - attn_mask_2, - negative_token_ids, - negative_attn_mask, - negative_token_ids_2, - ). + Tuple of (token_ids, attn_mask, token_ids_2, negative_token_ids, negative_token_ids_2). token_ids_2 and negative_token_ids_2 are None if no secondary tokenizer is configured. """ token_ids, attn_mask = await self.encode(prompt, images=images) token_ids_2: npt.NDArray[np.int64] | None = None - attn_mask_2: npt.NDArray[np.bool_] | None = None if self.delegate_2 is not None: - token_ids_2, attn_mask_2 = await self.encode( + token_ids_2, _attn_mask_2 = await self.encode( prompt_2 or prompt, use_secondary=True, ) negative_token_ids: npt.NDArray[np.int64] | None = None - negative_attn_mask: npt.NDArray[np.bool_] | None = None negative_token_ids_2: npt.NDArray[np.int64] | None = None if do_true_cfg: - negative_token_ids, negative_attn_mask = await self.encode( + negative_token_ids, _attn_mask_neg = await self.encode( negative_prompt or "" ) if self.delegate_2 is not None: - negative_token_ids_2, _negative_attn_mask_2 = await self.encode( + negative_token_ids_2, _attn_mask_neg_2 = await self.encode( negative_prompt_2 or negative_prompt or "", use_secondary=True, ) @@ -476,9 +673,7 @@ async def _generate_tokens_ids( token_ids, attn_mask, token_ids_2, - attn_mask_2, negative_token_ids, - negative_attn_mask, negative_token_ids_2, ) @@ -508,6 +703,9 @@ async def encode( tokenizer_output: Any + # Check if this is Flux2 pipeline (uses Mistral3Tokenizer with chat_template) + # Flux2 requires apply_chat_template for proper tokenization + def _encode_fn(prompt_str: str) -> Any: assert delegate is not None if self._pipeline_class_name == PipelineClassName.FLUX2: @@ -604,23 +802,25 @@ def _encode_fn(prompt_str: str) -> Any: add_special_tokens=add_special_tokens, return_attention_mask=True, ) - else: - # Validate prompt length before truncation. - # The tokenizer's truncation=True silently drops - # tokens beyond max_sequence_length; error early - # instead. - raw_ids = delegate.encode( - prompt_str, + elif self._pipeline_class_name.is_qwen_image_family: + # QwenImage wraps prompts in a Qwen2 chat template. + template = ( + "<|im_start|>system\n" + "Describe the image by detailing the color, shape, " + "size, texture, quantity, text, spatial relationships " + "of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + formatted = template.format(prompt_str) + return delegate( + formatted, + padding=False, + max_length=max_sequence_length, + truncation=True, add_special_tokens=add_special_tokens, ) - if max_sequence_length and len(raw_ids) > max_sequence_length: - raise ValueError( - f"Prompt is too long for this model's text" - f" encoder: {len(raw_ids)} tokens exceeds" - f" the maximum of {max_sequence_length}" - " tokens. Please shorten your prompt." - ) - + else: return delegate( prompt_str, padding="max_length", @@ -631,71 +831,38 @@ def _encode_fn(prompt_str: str) -> Any: tokenizer_output = await run_with_default_executor(_encode_fn, prompt) - # Extract input_ids and attention_mask. + # Extract input_ids and attention_mask if isinstance(tokenizer_output, dict): + # apply_chat_template returns a dict input_ids = tokenizer_output["input_ids"] attention_mask = tokenizer_output.get("attention_mask", None) + if attention_mask is None: + attention_mask = [1] * len(input_ids) + + # Extract real tokens only (using attention mask) for Flux2 + if self._pipeline_class_name == PipelineClassName.FLUX2: + # Filter to keep only real tokens (where mask == 1) + real_token_ids = [ + token_id + for token_id, mask in zip( + input_ids[0], attention_mask[0], strict=False + ) + if mask == 1 + ] + input_ids = [real_token_ids] + attention_mask = [[1] * len(real_token_ids)] else: + # Standard tokenizer output input_ids = tokenizer_output.input_ids attention_mask = tokenizer_output.attention_mask - input_ids_array = np.asarray(input_ids, dtype=np.int64) - if attention_mask is None: - attention_mask_array = np.ones_like(input_ids_array, dtype=np.bool_) - else: - attention_mask_array = np.asarray(attention_mask, dtype=np.bool_) - - # Tokenizers can return a batch dimension for a single prompt. - if input_ids_array.ndim == 2: - if input_ids_array.shape[0] != 1: - raise ValueError( - "Expected one prompt during tokenization, got " - f"batch size {input_ids_array.shape[0]}." - ) - input_ids_array = input_ids_array[0] - elif input_ids_array.ndim != 1: + if max_sequence_length and len(input_ids) > max_sequence_length: raise ValueError( - "Expected rank-1 or rank-2 input_ids, got " - f"shape {input_ids_array.shape}." + f"Input string is larger than tokenizer's max length ({len(input_ids)} > {max_sequence_length})." ) - if attention_mask_array.ndim == 2: - if attention_mask_array.shape[0] != 1: - raise ValueError( - "Expected one prompt attention_mask, got " - f"batch size {attention_mask_array.shape[0]}." - ) - attention_mask_array = attention_mask_array[0] - elif attention_mask_array.ndim != 1: - raise ValueError( - "Expected rank-1 or rank-2 attention_mask, got " - f"shape {attention_mask_array.shape}." - ) - - if attention_mask_array.shape[0] != input_ids_array.shape[0]: - raise ValueError( - "input_ids and attention_mask must have the same sequence " - f"length ({input_ids_array.shape[0]} != {attention_mask_array.shape[0]})." - ) - - # FLUX.2 uses compact token IDs; FLUX.2-Klein keeps full tokenizer output. - if self._pipeline_class_name == PipelineClassName.FLUX2: - input_ids_array = input_ids_array[attention_mask_array] - attention_mask_array = np.ones( - input_ids_array.shape[0], dtype=np.bool_ - ) - - if ( - max_sequence_length - and input_ids_array.shape[0] > max_sequence_length - ): - raise ValueError( - "Input string is larger than tokenizer's max length " - f"({input_ids_array.shape[0]} > {max_sequence_length})." - ) - - encoded_prompt = input_ids_array.astype(np.int64, copy=False) - attention_mask_array = attention_mask_array.astype(np.bool_, copy=False) + encoded_prompt = np.array(input_ids) + attention_mask_array = np.array(attention_mask).astype(np.bool_) return encoded_prompt, attention_mask_array @@ -784,46 +951,47 @@ def _retrieve_prompt(request: OpenResponsesRequest) -> str: ) @staticmethod - def _retrieve_image( + def _retrieve_images( request: OpenResponsesRequest, - ) -> PIL.Image.Image | None: - """Retrieve the input image from an OpenResponsesRequest. + ) -> list[PIL.Image.Image]: + """Retrieve all input images from an OpenResponsesRequest. - Extracts InputImageContent from the first message's content list and converts - the data URI to a PIL Image. + Extracts all InputImageContent items from the first message's content + list and converts data URIs to PIL Images. Args: - request: The OpenResponsesRequest to extract the image from. + request: The OpenResponsesRequest to extract images from. Returns: - PIL Image if found, None otherwise. + List of PIL Images (empty if none found). """ # Only check list inputs if not isinstance(request.body.input, list): - return None + return [] if not request.body.input: - return None + return [] first_message = request.body.input[0] # Only check list content if not isinstance(first_message.content, list): - return None + return [] - # Find first InputImageContent item + images: list[PIL.Image.Image] = [] for item in first_message.content: - if isinstance(item, InputImageContent): - # Parse data URI and convert to PIL Image - image_url = item.image_url + image_type = getattr(item, "type", None) + image_url = getattr(item, "image_url", None) + if ( + isinstance(item, InputImageContent) + or image_type == "input_image" + ) and isinstance(image_url, str): if image_url.startswith("data:"): - # Extract base64 data from data URI - # Format: data:image/png;base64, _, base64_data = image_url.split(",", 1) image_bytes = base64.b64decode(base64_data) - return PIL.Image.open(BytesIO(image_bytes)) + images.append(PIL.Image.open(BytesIO(image_bytes))) - return None + return images async def new_context( self, @@ -836,8 +1004,10 @@ async def new_context( if not prompt: raise ValueError("Prompt must be a non-empty string.") - # Extract input image from request content (takes precedence over input_image parameter) - input_image = self._retrieve_image(request) or input_image + # Extract input images from request content (takes precedence over input_image parameter) + input_images_list = self._retrieve_images(request) + if not input_images_list and input_image is not None: + input_images_list = [input_image] # Extract image provider options (always available via defaults) image_options = request.body.provider_options.image @@ -871,8 +1041,6 @@ async def new_context( is_distilled_klein = bool( self.diffusers_config.get("is_distilled", False) ) - # for non-distilled models, CFG is enabled - # whenever guidance_scale > 1.0; negative prompt defaults to "". do_true_cfg = ( image_options.guidance_scale > 1.0 and not is_distilled_klein ) @@ -881,25 +1049,26 @@ async def new_context( image_options.true_cfg_scale > 1.0 and image_options.negative_prompt is not None ) + import PIL.Image # 1. Tokenize prompts - # Convert input_image to list format for _generate_tokens_ids + # Convert input images to list format for _generate_tokens_ids images_for_tokenization: list[PIL.Image.Image] | None = None - if input_image is not None: - input_img: PIL.Image.Image - if isinstance(input_image, np.ndarray): - input_img = PIL.Image.fromarray(input_image.astype(np.uint8)) - else: - input_img = input_image - images_for_tokenization = [input_img] + if input_images_list: + images_for_tokenization = [] + for img in input_images_list: + if isinstance(img, np.ndarray): + images_for_tokenization.append( + PIL.Image.fromarray(img.astype(np.uint8)) + ) + else: + images_for_tokenization.append(img) ( token_ids, attn_mask, token_ids_2, - _attn_mask_2, negative_token_ids, - _negative_attn_mask, negative_token_ids_2, ) = await self._generate_tokens_ids( prompt, @@ -910,6 +1079,70 @@ async def new_context( images=images_for_tokenization, ) + default_sample_size = self._default_sample_size + vae_scale_factor = self._vae_scale_factor + + requested_height = image_options.height + requested_width = image_options.width + + height = ( + requested_height + if requested_height is not None + else default_sample_size * vae_scale_factor + ) + width = ( + requested_width + if requested_width is not None + else default_sample_size * vae_scale_factor + ) + + # 2. Preprocess input images if provided + preprocessed_image_arrays: list[npt.NDArray[np.uint8]] | None = None + if input_images_list: + if self._is_qwen_image_edit_family: + target_height = None + target_width = None + else: + target_height = ( + requested_height if requested_height is not None else None + ) + target_width = ( + requested_width if requested_width is not None else None + ) + preprocessed_image_arrays = [] + for img in input_images_list: + preprocessed_image = self._preprocess_input_image( + img, target_height, target_width + ) + if ( + not preprocessed_image_arrays + and not self._is_qwen_image_edit_family + ): + height = preprocessed_image.height + width = preprocessed_image.width + preprocessed_image_arrays.append( + np.array(preprocessed_image, dtype=np.uint8).copy() + ) + + prompt_images: list[npt.NDArray[np.uint8]] | None = None + vae_condition_images: list[npt.NDArray[np.uint8]] | None = None + if self._is_qwen_image_edit_family: + prompt_images, vae_condition_images = ( + self._prepare_qwen_edit_condition_images( + preprocessed_image_arrays + ) + ) + if prompt_images: + token_ids = self._prepare_qwen_edit_tokens( + prompt, prompt_images + ) + attn_mask = np.ones_like(token_ids, dtype=np.bool_) + if do_true_cfg and image_options.negative_prompt is not None: + negative_token_ids = self._prepare_qwen_edit_tokens( + image_options.negative_prompt, + prompt_images, + ) + token_buffer = TokenBuffer( array=token_ids.astype(np.int64, copy=False), ) @@ -929,26 +1162,6 @@ async def new_context( array=negative_token_ids_2.astype(np.int64, copy=False), ) - default_sample_size = self._default_sample_size - vae_scale_factor = self._vae_scale_factor - - # 2. Preprocess input image if provided - preprocessed_image_array = None - if input_image is not None: - preprocessed_image = self._preprocess_input_image(input_image) - height = image_options.height or preprocessed_image.height - width = image_options.width or preprocessed_image.width - preprocessed_image_array = np.array( - preprocessed_image, dtype=np.uint8 - ).copy() - else: - height = ( - image_options.height or default_sample_size * vae_scale_factor - ) - width = ( - image_options.width or default_sample_size * vae_scale_factor - ) - # 3. Resolve image dimensions using cached static values latent_height = 2 * (int(height) // (self._vae_scale_factor * 2)) latent_width = 2 * (int(width) // (self._vae_scale_factor * 2)) @@ -982,7 +1195,6 @@ async def new_context( mask=attn_mask, tokens_2=token_buffer_2, negative_tokens=negative_token_buffer, - negative_mask=_negative_attn_mask, negative_tokens_2=negative_token_buffer_2, timesteps=timesteps, sigmas=sigmas, @@ -996,8 +1208,16 @@ async def new_context( true_cfg_scale=image_options.true_cfg_scale, num_warmup_steps=num_warmup_steps, model_name=request.body.model, - input_image=preprocessed_image_array, # Pass numpy array instead of PIL.Image + input_image=preprocessed_image_arrays[0] + if preprocessed_image_arrays + else None, + input_images=preprocessed_image_arrays, + prompt_images=prompt_images, + vae_condition_images=vae_condition_images, output_format=image_options.output_format, ) + for validator in self._context_validators: + validator(context) + return context diff --git a/max/python/max/pipelines/lib/qwen_image_processor.py b/max/python/max/pipelines/lib/qwen_image_processor.py new file mode 100644 index 00000000000..711978e12a7 --- /dev/null +++ b/max/python/max/pipelines/lib/qwen_image_processor.py @@ -0,0 +1,139 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +"""MAX-native image preprocessing helpers for Qwen image pipelines.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import numpy.typing as npt +from max.pipelines.lib import float32_to_bfloat16_as_uint16 +from PIL import Image + +_IMAGENET_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32) +_IMAGENET_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32) +_NORM_SCALE = (1.0 / (255.0 * _IMAGENET_STD)).astype(np.float32) +_NORM_OFFSET = (-_IMAGENET_MEAN / _IMAGENET_STD).astype(np.float32) + + +def qwen2_5vl_prompt_image_preprocessing( + image: Image.Image, + *, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, +) -> tuple[npt.NDArray[np.uint16], tuple[int, int, int]]: + """Preprocess a prompt image for MAX-native Qwen image edit encoding.""" + if image.mode != "RGB": + image = image.convert("RGB") + + img_array = np.asarray(image, dtype=np.float32) + np.multiply(img_array, _NORM_SCALE, out=img_array) + np.add(img_array, _NORM_OFFSET, out=img_array) + + height, width = img_array.shape[:2] + grid_h = height // patch_size + grid_w = width // patch_size + + if grid_h % merge_size != 0 or grid_w % merge_size != 0: + raise ValueError( + f"Spatial merging is not possible because grid_h {grid_h} % merge_size {merge_size} != 0 or grid_w {grid_w} % merge_size {merge_size} != 0" + ) + + patches = img_array[np.newaxis, ...] + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] % temporal_patch_size != 0: + repeats = np.repeat( + patches[-1][np.newaxis], + temporal_patch_size - (patches.shape[0] % temporal_patch_size), + axis=0, + ) + patches = np.concatenate([patches, repeats], axis=0) + + channel = patches.shape[1] + grid_t = patches.shape[0] // temporal_patch_size + patches = patches.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flattened_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + return float32_to_bfloat16_as_uint16(flattened_patches), ( + grid_t, + grid_h, + grid_w, + ) + + +class Qwen2_5VLPromptImageProcessor: + """Process prompt images for MAX-native Qwen image edit pipelines.""" + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + ) -> None: + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + + def __call__( + self, + images: list[Image.Image] | Image.Image, + return_tensors: str = "np", + **kwargs: Any, + ) -> tuple[dict[str, npt.NDArray[Any]], list[npt.NDArray[np.uint16]]]: + """Preprocess one or more prompt images and return patch tensors.""" + del return_tensors, kwargs + if isinstance(images, Image.Image): + images = [images] + + pixel_values_list: list[npt.NDArray[np.uint16]] = [] + image_grid_thw_list: list[tuple[int, int, int]] = [] + + for image in images: + pixel_values, image_grid_thw = qwen2_5vl_prompt_image_preprocessing( + image, + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + merge_size=self.merge_size, + ) + pixel_values_list.append(pixel_values) + image_grid_thw_list.append(image_grid_thw) + + return { + "concatenated_pixel_values": np.vstack(pixel_values_list), + "image_grid_thw": np.array(image_grid_thw_list, dtype=np.int32), + }, pixel_values_list + + def preprocess( + self, + images: list[Image.Image] | Image.Image, + return_tensors: str = "np", + **kwargs: Any, + ) -> tuple[dict[str, npt.NDArray[Any]], list[npt.NDArray[np.uint16]]]: + """Alias matching the HuggingFace image processor API.""" + return self(images, return_tensors=return_tensors, **kwargs)