From 6eb37d4d6b152d052cac6d7ca31ee83b692c76de Mon Sep 17 00:00:00 2001 From: jglee-sqbits Date: Tue, 24 Mar 2026 08:37:05 +0000 Subject: [PATCH] [Models] Add QwenImage encoder (rebased onto main) --- .../qwen2_5vl/encoder/__init__.py | 17 ++ .../qwen2_5vl/encoder/layers/__init__.py | 16 ++ .../qwen2_5vl/encoder/layers/attention.py | 127 ++++++++++++ .../architectures/qwen2_5vl/encoder/model.py | 194 ++++++++++++++++++ .../qwen2_5vl/encoder/model_config.py | 87 ++++++++ .../qwen2_5vl/encoder/qwen25vl.py | 187 +++++++++++++++++ 6 files changed, 628 insertions(+) create mode 100644 max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py create mode 100644 max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/__init__.py create mode 100644 max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/attention.py create mode 100644 max/python/max/pipelines/architectures/qwen2_5vl/encoder/model.py create mode 100644 max/python/max/pipelines/architectures/qwen2_5vl/encoder/model_config.py create mode 100644 max/python/max/pipelines/architectures/qwen2_5vl/encoder/qwen25vl.py diff --git a/max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py new file mode 100644 index 00000000000..d18e229b147 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py @@ -0,0 +1,17 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + +from .model import Qwen25VLEncoderModel +from .multimodal_encoder import Qwen25VLMultimodalEncoderModel + +__all__ = ["Qwen25VLEncoderModel", "Qwen25VLMultimodalEncoderModel"] diff --git a/max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/__init__.py b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/__init__.py new file mode 100644 index 00000000000..08ac3ec6638 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/__init__.py @@ -0,0 +1,16 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .attention import Qwen25VLEncoderAttention + +__all__ = ["Qwen25VLEncoderAttention"] diff --git a/max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/attention.py b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/attention.py new file mode 100644 index 00000000000..04bbd7761f3 --- /dev/null +++ b/max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/attention.py @@ -0,0 +1,127 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. 
"""Qwen2.5-VL encoder-only attention with bias support (module v2)."""

from __future__ import annotations

from max.dtype import DType
from max.graph import DeviceRef, TensorValue, ops
from max.nn.attention.mask_config import MHAMaskVariant
from max.nn.kernels import flash_attention_gpu
from max.nn.layer import Module
from max.nn.linear import Linear
from max.nn.rotary_embedding import RotaryEmbedding


class Qwen25VLEncoderAttention(Module):
    """Self-attention with optional Q/K/V bias for Qwen2.5-VL (module v2).

    The layer projects a packed ``(total_seq_len, hidden_size)`` input to
    queries, keys and values, applies the supplied rotary embedding to Q and
    K, repeats the key/value heads when fewer KV heads than query heads are
    configured, and calls ``flash_attention_gpu`` with a causal mask. The
    output projection is always created without a bias.
    """

    def __init__(
        self,
        num_attention_heads: int,
        num_key_value_heads: int,
        hidden_size: int,
        head_dim: int,
        scale: float,
        attention_bias: bool = True,
        *,
        dtype: DType,
        device: DeviceRef,
    ) -> None:
        """Create the Q/K/V/O projection layers.

        Args:
            num_attention_heads: Number of query heads.
            num_key_value_heads: Number of key/value heads. Must be positive
                and evenly divide ``num_attention_heads``.
            hidden_size: Feature width of the layer input and output.
            head_dim: Feature width of a single attention head.
            scale: Scaling factor forwarded to the attention kernel.
            attention_bias: Whether the Q, K and V projections carry a bias.
            dtype: Data type of the projection weights.
            device: Device on which the projection weights are placed.

        Raises:
            ValueError: If ``num_key_value_heads`` is not positive or does
                not evenly divide ``num_attention_heads``.
        """
        # Fail fast: a non-divisible head layout would otherwise be floored
        # silently in __call__ and only surface as a shape error at
        # graph-build time.
        if (
            num_key_value_heads <= 0
            or num_attention_heads % num_key_value_heads != 0
        ):
            raise ValueError(
                "num_attention_heads must be a positive multiple of "
                f"num_key_value_heads, got {num_attention_heads} and "
                f"{num_key_value_heads}"
            )

        super().__init__()
        self.n_heads = num_attention_heads
        self.n_kv_heads = num_key_value_heads
        self.head_dim = head_dim
        self.scale = scale

        q_dim = head_dim * num_attention_heads
        kv_dim = head_dim * num_key_value_heads

        self.q_proj = Linear(
            hidden_size,
            q_dim,
            dtype=dtype,
            device=device,
            has_bias=attention_bias,
        )
        self.k_proj = Linear(
            hidden_size,
            kv_dim,
            dtype=dtype,
            device=device,
            has_bias=attention_bias,
        )
        self.v_proj = Linear(
            hidden_size,
            kv_dim,
            dtype=dtype,
            device=device,
            has_bias=attention_bias,
        )
        self.o_proj = Linear(
            q_dim,
            hidden_size,
            dtype=dtype,
            device=device,
            has_bias=False,
        )

    def _repeat_kv(self, x: TensorValue, n_rep: int) -> TensorValue:
        """Repeat every key/value head ``n_rep`` times along the head axis.

        Args:
            x: Tensor of shape ``(seq_len, n_kv_heads, head_dim)``.
            n_rep: Number of copies of each head.

        Returns:
            ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
            ``(seq_len, n_kv_heads * n_rep, head_dim)`` in which the copies
            of one KV head occupy consecutive head slots.
        """
        if n_rep == 1:
            return x
        seq_len = x.shape[0]
        n_kv_heads = x.shape[1]
        head_dim = x.shape[2]
        # Insert a repeat axis after the head axis, broadcast along it, then
        # fold it back into the head axis.
        x = ops.unsqueeze(x, 2)
        x = ops.broadcast_to(x, (seq_len, n_kv_heads, n_rep, head_dim))
        return ops.reshape(x, (seq_len, n_kv_heads * n_rep, head_dim))

    def __call__(
        self,
        x: TensorValue,
        rope: RotaryEmbedding,
    ) -> TensorValue:
        """Apply causal self-attention to a packed sequence.

        Args:
            x: Input whose first axis is the sequence axis and whose last
                axis matches the ``hidden_size`` given at construction.
            rope: Rotary embedding applied to the queries and keys.

        Returns:
            The attention output after ``o_proj``, with the same leading
            sequence length as ``x``.
        """
        total_seq_len = x.shape[0]

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = ops.reshape(q, (total_seq_len, self.n_heads, self.head_dim))
        k = ops.reshape(k, (total_seq_len, self.n_kv_heads, self.head_dim))
        v = ops.reshape(v, (total_seq_len, self.n_kv_heads, self.head_dim))

        # rope is invoked on a tensor with a leading singleton axis, which is
        # removed again right after.
        q = ops.squeeze(rope(ops.unsqueeze(q, 0)), 0)
        k = ops.squeeze(rope(ops.unsqueeze(k, 0)), 0)

        # _repeat_kv is a no-op for n_rep == 1, so no separate guard is
        # needed for the case where the head counts already match.
        n_rep = self.n_heads // self.n_kv_heads
        k = self._repeat_kv(k, n_rep)
        v = self._repeat_kv(v, n_rep)

        # The kernel is called with a leading singleton axis.
        q = ops.unsqueeze(q, 0)
        k = ops.unsqueeze(k, 0)
        v = ops.unsqueeze(v, 0)

        attn_out = flash_attention_gpu(
            q,
            k,
            v,
            mask_variant=MHAMaskVariant.CAUSAL_MASK,
            scale=self.scale,
        )

        attn_out = ops.squeeze(attn_out, 0)
        # Merge the head and head_dim axes before the output projection.
        attn_out = ops.reshape(attn_out, (total_seq_len, -1))
        return self.o_proj(attn_out)
"""Qwen2.5-VL encoder ComponentModel wrapper (module v2)."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from max.driver import Device
from max.dtype import DType
from max.engine import InferenceSession, Model
from max.graph import DeviceRef, Graph, TensorType
from max.graph.weights import Weights
from max.nn.embedding import Embedding
from max.nn.layer import Module
from max.pipelines.architectures.llama3.weight_adapters import (
    LLAMA_SAFETENSOR_MAPPING as QWEN_SAFETENSOR_MAP,
)
from max.pipelines.lib import SupportedEncoding
from max.pipelines.lib.interfaces.component_model import ComponentModel

from .model_config import Qwen25VLTextEncoderConfig
from .qwen25vl import Qwen25VLTextEncoderTransformer

# Checkpoint-key suffixes whose tensors are kept in their stored dtype.
_SCALE_SUFFIXES = (".weight_scale", ".input_scale")
# Adapted-key prefixes of the vision tower, which this encoder never loads.
_VISION_PREFIXES = ("visual.", "vision_encoder.")
# Adapted-key prefixes that belong to the transformer graph.
_TRANSFORM_PREFIXES = ("layers.", "norm.", "rope.")


def _adapt_weight_key(key: str) -> str | None:
    """Translate a checkpoint key into this package's module key.

    Args:
        key: Weight name as stored in the checkpoint.

    Returns:
        The adapted key with a leading ``model.`` removed, or ``None`` when
        the key belongs to the vision tower.
    """
    language_prefix = "model.language_model."
    if key.startswith(language_prefix):
        adapted = key[len(language_prefix) :]
    else:
        adapted = key
        # Every entry of the mapping is applied as a substring replacement.
        for before, after in QWEN_SAFETENSOR_MAP.items():
            adapted = adapted.replace(before, after)

    # The vision check intentionally runs before "model." is stripped, which
    # is the order the routing below relies on.
    if adapted.startswith(_VISION_PREFIXES):
        return None
    return adapted.removeprefix("model.")


class _EmbedOnly(Module):
    """Token embedding only (module v2).

    Wraps a single ``Embedding`` under the attribute name ``embed_tokens``
    so that it can be compiled as its own graph.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        *,
        dtype: DType,
        device: DeviceRef,
    ) -> None:
        """Create the embedding table.

        Args:
            vocab_size: Number of rows in the embedding table.
            hidden_size: Width of each embedding vector.
            dtype: Data type of the embedding weights.
            device: Device on which the embedding weights are placed.
        """
        super().__init__()
        self.embed_tokens = Embedding(
            vocab_size,
            hidden_size,
            dtype=dtype,
            device=device,
        )

    def __call__(self, tokens: Any) -> Any:
        """Return the result of ``embed_tokens`` applied to ``tokens``."""
        return self.embed_tokens(tokens)


class Qwen25VLEncoderModel(ComponentModel):
    """Qwen2.5-VL language-side encoder ComponentModel wrapper (module v2).

    The encoder is compiled as two separate graphs: one that embeds token
    ids and one that runs the transformer blocks plus the final norm.
    Both graphs are built and loaded while the object is constructed.
    """

    def __init__(
        self,
        config: dict[str, Any],
        encoding: SupportedEncoding,
        devices: list[Device],
        weights: Weights,
        session: InferenceSession | None = None,
    ) -> None:
        """Build the config and compile both graphs.

        Args:
            config: Raw model configuration dictionary.
            encoding: Encoding used to derive the model dtype.
            devices: Devices to run on; the first one hosts both graphs.
            weights: Checkpoint weights to load.
            session: Inference session used for compilation. When ``None``
                a new session is created for ``devices`` in ``load_model``.
        """
        super().__init__(config, encoding, devices, weights)
        self.config = Qwen25VLTextEncoderConfig.generate(
            config,
            encoding,
            devices,
        )
        self.session = session
        self.load_model()

    def load_model(self) -> Callable[..., Any]:
        """Load the weights and compile the embedding and transformer graphs.

        Checkpoint keys are adapted and routed first; only tensors that end
        up in one of the two graphs are materialised and converted. Kept
        floating-point tensors that are not float8 and are not scale tensors
        are cast to bfloat16.

        Returns:
            The compiled embedding model. The compiled transformer model is
            stored on ``self._transform_model``.
        """
        embed_state: dict[str, Any] = {}
        transform_state: dict[str, Any] = {}

        for key, value in self.weights.items():
            adapted_key = _adapt_weight_key(key)
            if adapted_key is None:
                continue

            if adapted_key.startswith("embed_tokens."):
                target = embed_state
            elif adapted_key.startswith(_TRANSFORM_PREFIXES):
                target = transform_state
            else:
                # Not used by either graph; skip before touching the data.
                continue

            wd = value.data()

            # Normalize floating-point weights to bf16; scale tensors keep
            # their stored dtype.
            if (
                wd.dtype.is_float()
                and not wd.dtype.is_float8()
                and not key.endswith(_SCALE_SUFFIXES)
            ):
                wd = wd.astype(DType.bfloat16)

            target[adapted_key] = wd

        lc = self.config
        device_ref = DeviceRef.from_device(self.devices[0])

        session = self.session
        if session is None:
            session = InferenceSession(devices=self.devices)

        # --- Compile embed_tokens ---
        embed_model = _EmbedOnly(
            lc.vocab_size,
            lc.hidden_size,
            dtype=lc.dtype,
            device=device_ref,
        )
        embed_model.load_state_dict(
            embed_state, weight_alignment=1, strict=True
        )
        embed_input_types = [
            TensorType(DType.int64, shape=["total_seq_len"], device=device_ref),
        ]
        with Graph("qwen_te_embed", input_types=embed_input_types) as embed_graph:
            embed_out = embed_model(*(v.tensor for v in embed_graph.inputs))
            embed_graph.output(embed_out)

        self._embed_model: Model = session.load(
            embed_graph,
            weights_registry=embed_model.state_dict(),
        )

        # --- Compile transformer layers + norm ---
        transform_model = Qwen25VLTextEncoderTransformer(lc)
        transform_model.load_state_dict(
            transform_state,
            weight_alignment=1,
            strict=True,
        )
        transform_input_types = [
            TensorType(
                lc.dtype,
                shape=["total_seq_len", lc.hidden_size],
                device=device_ref,
            ),
        ]
        with Graph(
            "qwen_te_transform", input_types=transform_input_types
        ) as transform_graph:
            transform_out = transform_model(
                *(v.tensor for v in transform_graph.inputs)
            )
            transform_graph.output(transform_out)

        self._transform_model: Model = session.load(
            transform_graph,
            weights_registry=transform_model.state_dict(),
        )

        return self._embed_model

    def __call__(self, token_input: Any) -> tuple[Any]:
        """Run text encoder: embed_tokens → transformer → normed output.

        Accepts both Buffer (v2) and experimental Tensor (v3 compat).

        Args:
            token_input: Token ids. An object exposing ``driver_tensor`` is
                treated as an experimental Tensor and unwrapped first.

        Returns:
            A one-element tuple holding the encoder output, wrapped in an
            experimental Tensor when the input was one and returned as the
            raw execution result otherwise.
        """
        # Extract Buffer from _Tensor if needed
        is_tensor = hasattr(token_input, "driver_tensor")
        buf = token_input.driver_tensor if is_tensor else token_input

        embed_result = self._embed_model.execute(buf)
        transform_result = self._transform_model.execute(embed_result[0])
        result_buf = transform_result[0]

        if is_tensor:
            # Imported lazily so the experimental module is only needed for
            # Tensor inputs.
            from max.experimental.tensor import Tensor as _Tensor

            return (_Tensor(storage=result_buf),)

        return (result_buf,)
"""Configuration for Qwen2.5-VL text encoder used in QwenImage pipeline."""

from __future__ import annotations

import math
from typing import Any

from max.driver import Device
from max.dtype import DType
from max.graph import DeviceRef
from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
from max.pipelines.lib.config.config_enums import supported_encoding_dtype
from pydantic import Field

# Hugging Face config key -> field name used by the config classes below.
_HF_KEY_MAP = {
    "max_position_embeddings": "max_seq_len",
}


class Qwen25VLTextEncoderConfigBase(MAXModelConfigBase):
    """Field set and defaults of the Qwen2.5-VL text encoder.

    Compared with Qwen3, attention projections are biased by default
    (``attention_bias=True``) and the default dimensions follow
    Qwen2.5-VL-7B-Instruct.
    """

    hidden_size: int = 3584
    num_attention_heads: int = 28
    num_key_value_heads: int = 4
    num_hidden_layers: int = 28
    head_dim: int = 128
    vocab_size: int = 152064
    intermediate_size: int = 18944
    rope_theta: float = 1000000.0
    max_seq_len: int = 128000
    rms_norm_eps: float = 1e-6
    attention_bias: bool = True
    dtype: DType = DType.bfloat16
    device: DeviceRef = Field(default_factory=DeviceRef.GPU)

    @property
    def attention_multiplier(self) -> float:
        """Attention scaling factor, ``sqrt(1 / head_dim)``."""
        return math.sqrt(1.0 / self.head_dim)


class Qwen25VLTextEncoderConfig(Qwen25VLTextEncoderConfigBase):
    """Factory for building the encoder config from a raw config dictionary."""

    @staticmethod
    def generate(
        config_dict: dict[str, Any],
        encoding: SupportedEncoding,
        devices: list[Device],
    ) -> Qwen25VLTextEncoderConfigBase:
        """Build a config from a raw dictionary, an encoding and devices.

        Args:
            config_dict: Raw configuration. When it has a ``text_config``
                entry that entry is used, otherwise the dictionary itself.
            encoding: Encoding from which the ``dtype`` field is derived.
            devices: Device list; the first entry becomes the ``device``
                field.

        Returns:
            A ``Qwen25VLTextEncoderConfigBase`` populated from the known
            keys of the raw configuration. Unknown keys are ignored, and
            ``dtype``/``device`` always come from the arguments.
        """
        source = config_dict.get("text_config", config_dict)
        known_fields = Qwen25VLTextEncoderConfigBase.__annotations__

        # Rename HF keys first, then keep only declared fields.
        renamed = (
            (_HF_KEY_MAP.get(name, name), val) for name, val in source.items()
        )
        kwargs = {name: val for name, val in renamed if name in known_fields}

        # Derive the per-head width when the raw configuration omits it.
        if "head_dim" not in kwargs:
            width = kwargs.get("hidden_size", 3584)
            heads = kwargs.get("num_attention_heads", 28)
            kwargs["head_dim"] = width // heads

        # Placement and precision are dictated by the caller and override
        # anything of the same name found in the raw configuration.
        kwargs.update(
            dtype=supported_encoding_dtype(encoding),
            device=DeviceRef.from_device(devices[0]),
        )

        return Qwen25VLTextEncoderConfigBase(**kwargs)
from __future__ import annotations

from typing import TYPE_CHECKING

from max.graph import TensorValue, ops
from max.nn.layer import LayerList, Module
from max.nn.linear import Linear
from max.nn.norm.rms_norm import RMSNorm
from max.nn.rotary_embedding import RotaryEmbedding

from .layers import Qwen25VLEncoderAttention

if TYPE_CHECKING:
    from max.dtype import DType
    from max.graph import DeviceRef

    from .model_config import Qwen25VLTextEncoderConfigBase


class Qwen25VLMLP(Module):
    """Gated feed-forward network with SiLU activation (module v2).

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``. None of the
    three projections has a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        *,
        dtype: DType,
        device: DeviceRef,
    ) -> None:
        """Create the gate, up and down projections.

        Args:
            hidden_size: Width of the block input and output.
            intermediate_size: Width of the gated hidden representation.
            dtype: Data type of the projection weights.
            device: Device on which the projection weights are placed.
        """
        super().__init__()
        shared = {"dtype": dtype, "device": device, "has_bias": False}
        self.gate_proj = Linear(hidden_size, intermediate_size, **shared)
        self.up_proj = Linear(hidden_size, intermediate_size, **shared)
        self.down_proj = Linear(intermediate_size, hidden_size, **shared)

    def __call__(self, hidden_states: TensorValue) -> TensorValue:
        """Apply the gated feed-forward network to ``hidden_states``."""
        return self.down_proj(
            ops.silu(self.gate_proj(hidden_states))
            * self.up_proj(hidden_states)
        )


class Qwen25VLEncoderTransformerBlock(Module):
    """One pre-norm transformer block of the Qwen2.5-VL encoder (module v2).

    Each of the two sub-layers (self-attention, MLP) is preceded by an
    RMSNorm and followed by a residual addition.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        intermediate_size: int,
        rms_norm_eps: float,
        scale: float,
        attention_bias: bool = True,
        *,
        dtype: DType,
        device: DeviceRef,
    ) -> None:
        """Create the attention, MLP and norm sub-layers.

        Args:
            hidden_size: Width of the block input and output.
            num_heads: Number of query heads.
            num_kv_heads: Number of key/value heads.
            head_dim: Width of a single attention head.
            intermediate_size: Hidden width of the MLP.
            rms_norm_eps: Epsilon used by both RMSNorm layers.
            scale: Scaling factor forwarded to the attention layer.
            attention_bias: Whether the Q/K/V projections carry a bias.
            dtype: Data type of all weights in the block.
            device: Device for the attention and MLP weights.
        """
        super().__init__()
        self.self_attn = Qwen25VLEncoderAttention(
            num_attention_heads=num_heads,
            num_key_value_heads=num_kv_heads,
            hidden_size=hidden_size,
            head_dim=head_dim,
            scale=scale,
            attention_bias=attention_bias,
            dtype=dtype,
            device=device,
        )
        self.mlp = Qwen25VLMLP(
            hidden_size,
            intermediate_size,
            dtype=dtype,
            device=device,
        )
        norm_kwargs = {"dtype": dtype, "eps": rms_norm_eps}
        self.input_layernorm = RMSNorm(hidden_size, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(hidden_size, **norm_kwargs)

    def __call__(
        self,
        x: TensorValue,
        rope: RotaryEmbedding,
    ) -> TensorValue:
        """Run attention and MLP, each with a residual connection.

        Args:
            x: Block input hidden states.
            rope: Rotary embedding forwarded to the attention layer.

        Returns:
            Hidden states after both residual sub-layers.
        """
        # Attention sub-layer.
        attn_out = self.self_attn(self.input_layernorm(x), rope)
        x = x + attn_out

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out


class Qwen25VLTextEncoderTransformer(Module):
    """Transformer stack of the Qwen2.5-VL text encoder (module v2).

    Holds the rotary embedding, the list of transformer blocks and the final
    RMSNorm. Token embedding is not part of this module: ``__call__`` takes
    hidden states that have already been embedded.
    """

    def __init__(self, config: Qwen25VLTextEncoderConfigBase) -> None:
        """Build the rotary embedding, the blocks and the final norm.

        Args:
            config: Encoder configuration supplying all dimensions, the
                dtype and the device.
        """
        super().__init__()
        dtype = config.dtype
        device = config.device

        self.rope = RotaryEmbedding(
            dim=config.hidden_size,
            n_heads=config.num_attention_heads,
            theta=config.rope_theta,
            max_seq_len=config.max_seq_len,
            head_dim=config.head_dim,
            interleaved=False,
        )

        blocks = []
        for _ in range(config.num_hidden_layers):
            blocks.append(
                Qwen25VLEncoderTransformerBlock(
                    hidden_size=config.hidden_size,
                    num_heads=config.num_attention_heads,
                    num_kv_heads=config.num_key_value_heads,
                    head_dim=config.head_dim,
                    intermediate_size=config.intermediate_size,
                    rms_norm_eps=config.rms_norm_eps,
                    scale=config.attention_multiplier,
                    attention_bias=config.attention_bias,
                    dtype=dtype,
                    device=device,
                )
            )
        self.layers = LayerList(blocks)

        self.norm = RMSNorm(
            config.hidden_size, dtype=dtype, eps=config.rms_norm_eps
        )

    def __call__(self, hidden_states: TensorValue) -> TensorValue:
        """Run transformer layers + norm on pre-embedded hidden states."""
        states = hidden_states
        for block in self.layers:
            states = block(states, self.rope)
        return self.norm(states)