forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
[MAX] Add Qwen2.5-VL encoder for Qwen-Image #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jglee-sqbits
wants to merge
1
commit into
main
Choose a base branch
from
add/qwen-image/encoder
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
17 changes: 17 additions & 0 deletions
17
max/python/max/pipelines/architectures/qwen2_5vl/encoder/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from .model import Qwen25VLEncoderModel | ||
| from .multimodal_encoder import Qwen25VLMultimodalEncoderModel | ||
|
|
||
| __all__ = ["Qwen25VLEncoderModel", "Qwen25VLMultimodalEncoderModel"] |
16 changes: 16 additions & 0 deletions
16
max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from .attention import Qwen25VLEncoderAttention | ||
|
|
||
| __all__ = ["Qwen25VLEncoderAttention"] |
127 changes: 127 additions & 0 deletions
127
max/python/max/pipelines/architectures/qwen2_5vl/encoder/layers/attention.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| """Qwen2.5-VL encoder-only attention with bias support (module v2).""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from max.dtype import DType | ||
| from max.graph import DeviceRef, TensorValue, ops | ||
| from max.nn.attention.mask_config import MHAMaskVariant | ||
| from max.nn.kernels import flash_attention_gpu | ||
| from max.nn.layer import Module | ||
| from max.nn.linear import Linear | ||
| from max.nn.rotary_embedding import RotaryEmbedding | ||
|
|
||
|
|
||
| class Qwen25VLEncoderAttention(Module): | ||
| """Encoder-only attention with bias for Qwen2.5-VL (module v2).""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_attention_heads: int, | ||
| num_key_value_heads: int, | ||
| hidden_size: int, | ||
| head_dim: int, | ||
| scale: float, | ||
| attention_bias: bool = True, | ||
| *, | ||
| dtype: DType, | ||
| device: DeviceRef, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.n_heads = num_attention_heads | ||
| self.n_kv_heads = num_key_value_heads | ||
| self.head_dim = head_dim | ||
| self.scale = scale | ||
|
|
||
| q_dim = head_dim * num_attention_heads | ||
| kv_dim = head_dim * num_key_value_heads | ||
|
|
||
| self.q_proj = Linear( | ||
| hidden_size, | ||
| q_dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=attention_bias, | ||
| ) | ||
| self.k_proj = Linear( | ||
| hidden_size, | ||
| kv_dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=attention_bias, | ||
| ) | ||
| self.v_proj = Linear( | ||
| hidden_size, | ||
| kv_dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=attention_bias, | ||
| ) | ||
| self.o_proj = Linear( | ||
| q_dim, | ||
| hidden_size, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=False, | ||
| ) | ||
|
|
||
| def _repeat_kv(self, x: TensorValue, n_rep: int) -> TensorValue: | ||
| if n_rep == 1: | ||
| return x | ||
| seq_len = x.shape[0] | ||
| n_kv_heads = x.shape[1] | ||
| head_dim = x.shape[2] | ||
| x = ops.unsqueeze(x, 2) | ||
| x = ops.broadcast_to(x, (seq_len, n_kv_heads, n_rep, head_dim)) | ||
| return ops.reshape(x, (seq_len, n_kv_heads * n_rep, head_dim)) | ||
|
|
||
| def __call__( | ||
| self, | ||
| x: TensorValue, | ||
| rope: RotaryEmbedding, | ||
| ) -> TensorValue: | ||
| total_seq_len = x.shape[0] | ||
|
|
||
| q = self.q_proj(x) | ||
| k = self.k_proj(x) | ||
| v = self.v_proj(x) | ||
|
|
||
| q = ops.reshape(q, (total_seq_len, self.n_heads, self.head_dim)) | ||
| k = ops.reshape(k, (total_seq_len, self.n_kv_heads, self.head_dim)) | ||
| v = ops.reshape(v, (total_seq_len, self.n_kv_heads, self.head_dim)) | ||
|
|
||
| q = ops.squeeze(rope(ops.unsqueeze(q, 0)), 0) | ||
| k = ops.squeeze(rope(ops.unsqueeze(k, 0)), 0) | ||
|
|
||
| if self.n_kv_heads != self.n_heads: | ||
| n_rep = self.n_heads // self.n_kv_heads | ||
| k = self._repeat_kv(k, n_rep) | ||
| v = self._repeat_kv(v, n_rep) | ||
|
|
||
| q = ops.unsqueeze(q, 0) | ||
| k = ops.unsqueeze(k, 0) | ||
| v = ops.unsqueeze(v, 0) | ||
|
|
||
| attn_out = flash_attention_gpu( | ||
| q, | ||
| k, | ||
| v, | ||
| mask_variant=MHAMaskVariant.CAUSAL_MASK, | ||
| scale=self.scale, | ||
| ) | ||
|
|
||
| attn_out = ops.squeeze(attn_out, 0) | ||
| attn_out = ops.reshape(attn_out, (total_seq_len, -1)) | ||
| return self.o_proj(attn_out) |
194 changes: 194 additions & 0 deletions
194
max/python/max/pipelines/architectures/qwen2_5vl/encoder/model.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| """Qwen2.5-VL encoder ComponentModel wrapper (module v2).""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from typing import Any | ||
|
|
||
| from max.driver import Device | ||
| from max.dtype import DType | ||
| from max.engine import InferenceSession, Model | ||
| from max.graph import DeviceRef, Graph, TensorType | ||
| from max.graph.weights import Weights | ||
| from max.nn.embedding import Embedding | ||
| from max.nn.layer import Module | ||
| from max.pipelines.architectures.llama3.weight_adapters import ( | ||
| LLAMA_SAFETENSOR_MAPPING as QWEN_SAFETENSOR_MAP, | ||
| ) | ||
| from max.pipelines.lib import SupportedEncoding | ||
| from max.pipelines.lib.interfaces.component_model import ComponentModel | ||
|
|
||
| from .model_config import Qwen25VLTextEncoderConfig | ||
| from .qwen25vl import Qwen25VLTextEncoderTransformer | ||
|
|
||
|
|
||
| class _EmbedOnly(Module): | ||
| """Token embedding only (module v2).""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| vocab_size: int, | ||
| hidden_size: int, | ||
| *, | ||
| dtype: DType, | ||
| device: DeviceRef, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.embed_tokens = Embedding( | ||
| vocab_size, | ||
| hidden_size, | ||
| dtype=dtype, | ||
| device=device, | ||
| ) | ||
|
|
||
| def __call__(self, tokens: Any) -> Any: | ||
| return self.embed_tokens(tokens) | ||
|
|
||
|
|
||
| class Qwen25VLEncoderModel(ComponentModel): | ||
| """Qwen2.5-VL language-side encoder ComponentModel wrapper (module v2).""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: dict[str, Any], | ||
| encoding: SupportedEncoding, | ||
| devices: list[Device], | ||
| weights: Weights, | ||
| session: InferenceSession | None = None, | ||
| ) -> None: | ||
| super().__init__(config, encoding, devices, weights) | ||
| self.config = Qwen25VLTextEncoderConfig.generate( | ||
| config, | ||
| encoding, | ||
| devices, | ||
| ) | ||
| self.session = session | ||
| self.load_model() | ||
|
|
||
| def load_model(self) -> Callable[..., Any]: | ||
| embed_state: dict[str, Any] = {} | ||
| transform_state: dict[str, Any] = {} | ||
|
|
||
| for key, value in self.weights.items(): | ||
| wd = value.data() | ||
|
|
||
| # Normalize floating-point weights to bf16 | ||
| if wd.dtype.is_float() and not wd.dtype.is_float8(): | ||
| is_scale = key.endswith(".weight_scale") or key.endswith( | ||
| ".input_scale" | ||
| ) | ||
| if not is_scale: | ||
| wd = wd.astype(DType.bfloat16) | ||
|
|
||
| # Key mapping | ||
| adapted_key = key | ||
| if adapted_key.startswith("model.language_model."): | ||
| adapted_key = adapted_key[len("model.language_model.") :] | ||
| else: | ||
| for before, after in QWEN_SAFETENSOR_MAP.items(): | ||
| adapted_key = adapted_key.replace(before, after) | ||
|
|
||
| # Skip vision weights | ||
| if adapted_key.startswith("visual.") or adapted_key.startswith( | ||
| "vision_encoder." | ||
| ): | ||
| continue | ||
|
|
||
| # Strip "model." prefix | ||
| adapted_key = adapted_key.removeprefix("model.") | ||
|
|
||
| if adapted_key.startswith("embed_tokens."): | ||
| embed_state[adapted_key] = wd | ||
| elif ( | ||
| adapted_key.startswith("layers.") | ||
| or adapted_key.startswith("norm.") | ||
| or adapted_key.startswith("rope.") | ||
| ): | ||
| transform_state[adapted_key] = wd | ||
|
|
||
| lc = self.config | ||
| device_ref = DeviceRef.from_device(self.devices[0]) | ||
|
|
||
| # --- Compile embed_tokens --- | ||
| embed_model = _EmbedOnly( | ||
| lc.vocab_size, | ||
| lc.hidden_size, | ||
| dtype=lc.dtype, | ||
| device=device_ref, | ||
| ) | ||
| embed_model.load_state_dict( | ||
| embed_state, weight_alignment=1, strict=True | ||
| ) | ||
| embed_input_types = [ | ||
| TensorType(DType.int64, shape=["total_seq_len"], device=device_ref), | ||
| ] | ||
| with Graph("qwen_te_embed", input_types=embed_input_types) as g: | ||
| out = embed_model(*(v.tensor for v in g.inputs)) | ||
| g.output(out) | ||
|
|
||
| session = self.session | ||
| if session is None: | ||
| session = InferenceSession(devices=self.devices) | ||
|
|
||
| self._embed_model: Model = session.load( | ||
| g, | ||
| weights_registry=embed_model.state_dict(), | ||
| ) | ||
|
|
||
| # --- Compile transformer layers + norm --- | ||
| transform_model = Qwen25VLTextEncoderTransformer(lc) | ||
| transform_model.load_state_dict( | ||
| transform_state, | ||
| weight_alignment=1, | ||
| strict=True, | ||
| ) | ||
| transform_input_types = [ | ||
| TensorType( | ||
| lc.dtype, | ||
| shape=["total_seq_len", lc.hidden_size], | ||
| device=device_ref, | ||
| ), | ||
| ] | ||
| with Graph("qwen_te_transform", input_types=transform_input_types) as g: | ||
| out = transform_model(*(v.tensor for v in g.inputs)) | ||
| g.output(out) | ||
| self._transform_model: Model = session.load( | ||
| g, | ||
| weights_registry=transform_model.state_dict(), | ||
| ) | ||
|
|
||
| return self._embed_model | ||
|
|
||
| def __call__(self, token_input: Any) -> tuple[Any]: | ||
| """Run text encoder: embed_tokens → transformer → normed output. | ||
|
|
||
| Accepts both Buffer (v2) and experimental Tensor (v3 compat). | ||
| Returns a tuple wrapping the result in the same type as input. | ||
| """ | ||
| # Extract Buffer from _Tensor if needed | ||
| is_tensor = hasattr(token_input, "driver_tensor") | ||
| buf = token_input.driver_tensor if is_tensor else token_input | ||
|
|
||
| embed_result = self._embed_model.execute(buf) | ||
| transform_result = self._transform_model.execute(embed_result[0]) | ||
| result_buf = transform_result[0] | ||
|
|
||
| if is_tensor: | ||
| from max.experimental.tensor import Tensor as _Tensor | ||
|
|
||
| return (_Tensor(storage=result_buf),) | ||
|
|
||
| return (result_buf,) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reusing
LLAMA_SAFETENSOR_MAPPINGfor a Qwen model is risky. While there might be similarities in layer naming, differences between model architectures could lead to incorrect weight loading or hard-to-debug errors. For clarity and safety, it's better to define a specificQWEN_SAFETENSOR_MAPfor this architecture. If the mapping is indeed identical, adding a comment to clarify this would be beneficial for future maintenance.