forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
[MAX] Add UMT5 text encoder for Wan diffusion #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jglee-sqbits
wants to merge
1
commit into
jglee-sqbits/stack/1
Choose a base branch
from
jglee-sqbits/stack/2
base: jglee-sqbits/stack/1
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from .model import UMT5Model | ||
|
|
||
| __all__ = ["UMT5Model"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from typing import Any | ||
|
|
||
| from max.driver import Device | ||
| from max.dtype import DType | ||
| from max.engine import InferenceSession, Model | ||
| from max.graph import DeviceRef, Graph, TensorType | ||
| from max.graph.weights import WeightData, Weights | ||
| from max.pipelines.lib import SupportedEncoding | ||
| from max.pipelines.lib.interfaces.component_model import ComponentModel | ||
|
|
||
| from .model_config import UMT5Config, UMT5ConfigBase | ||
| from .umt5 import UMT5EncoderModel | ||
|
|
||
|
|
||
| def _prepare_state_dict( | ||
| weights: Weights, | ||
| target_dtype: DType | None = None, | ||
| ) -> dict[str, WeightData]: | ||
| """Convert Weights to a raw state dict, normalizing tied embedding keys. | ||
|
|
||
| HF UMT5 ties ``shared.weight`` and ``encoder.embed_tokens.weight``. | ||
| Our module owns the embedding as ``shared``, so we normalize to that key | ||
| and drop the alias to avoid strict-mode validation failures. | ||
|
|
||
| If ``target_dtype`` is provided, all weights are cast to that dtype | ||
| (e.g. float32 → bfloat16 for Wan 2.1 checkpoints). | ||
| """ | ||
| state_dict: dict[str, WeightData] = {} | ||
| for key, value in weights.items(): | ||
| wd = value.data() | ||
| if target_dtype is not None and wd.dtype != target_dtype: | ||
| wd = wd.astype(target_dtype) | ||
| state_dict[key] = wd | ||
|
|
||
| encoder_emb = state_dict.pop("encoder.embed_tokens.weight", None) | ||
| if "shared.weight" not in state_dict and encoder_emb is not None: | ||
| state_dict["shared.weight"] = encoder_emb | ||
|
|
||
| return state_dict | ||
|
|
||
|
|
||
| class UMT5Model(ComponentModel): | ||
| def __init__( | ||
| self, | ||
| config: dict[str, Any], | ||
| encoding: SupportedEncoding, | ||
| devices: list[Device], | ||
| weights: Weights, | ||
| session: InferenceSession | None = None, | ||
| ) -> None: | ||
| super().__init__(config, encoding, devices, weights) | ||
| self.session = session or InferenceSession(devices=devices) | ||
| self.config: UMT5ConfigBase = UMT5Config.generate( | ||
| config, | ||
| encoding, | ||
| devices, | ||
| ) | ||
| self.load_model() | ||
|
|
||
| def load_model(self) -> Model: | ||
| assert self.weights is not None, "Weights already freed" | ||
| # Force bfloat16 — some repos (Wan 2.1) declare float32 but | ||
| # should run in bfloat16 on GPU. Override both config and weights. | ||
| dtype = DType.bfloat16 | ||
| self.config.dtype = dtype | ||
| state_dict = _prepare_state_dict(self.weights, target_dtype=dtype) | ||
| dev = self.devices[0] | ||
| dev_ref = DeviceRef.from_device(dev) | ||
|
|
||
| # Build module and load weights | ||
| module = UMT5EncoderModel(self.config, dtype=dtype, device=dev_ref) | ||
| module.load_state_dict(state_dict, weight_alignment=1, strict=True) | ||
|
|
||
| # Build graph with symbolic sequence length | ||
| # attention_mask comes in as int64 from the pipeline | ||
| input_types = [ | ||
| TensorType(DType.int64, ["batch", "seq_len"], device=dev), | ||
| TensorType(DType.int64, ["batch", "seq_len"], device=dev), | ||
| ] | ||
| with Graph("umt5_encoder", input_types=input_types) as graph: | ||
| input_ids = graph.inputs[0].tensor | ||
| attention_mask = graph.inputs[1].tensor | ||
| out = module(input_ids, attention_mask) | ||
| graph.output(out) | ||
|
|
||
| self.model: Model = self.session.load( | ||
| graph, weights_registry=module.state_dict() | ||
| ) | ||
| # Free raw weights after compilation | ||
| self.weights = None # type: ignore[assignment] | ||
| return self.model | ||
|
|
||
| def __call__(self, *args, **kwargs): | ||
| return self.model(*args, **kwargs) | ||
73 changes: 73 additions & 0 deletions
73
max/python/max/pipelines/architectures/umt5/model_config.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from typing import Any | ||
|
|
||
| from max.driver import Device | ||
| from max.dtype import DType | ||
| from max.graph import DeviceRef | ||
| from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding | ||
| from max.pipelines.lib.config.config_enums import supported_encoding_dtype | ||
| from pydantic import Field | ||
|
|
||
|
|
||
| class UMT5ConfigBase(MAXModelConfigBase): | ||
| vocab_size: int = 256384 | ||
| d_model: int = 4096 | ||
| d_kv: int = 64 | ||
| d_ff: int = 10240 | ||
| num_layers: int = 24 | ||
| num_decoder_layers: int | None = 24 | ||
| num_heads: int = 64 | ||
| relative_attention_num_buckets: int = 32 | ||
| relative_attention_max_distance: int = 128 | ||
| dropout_rate: float = 0.1 | ||
| layer_norm_epsilon: float = 1e-6 | ||
| initializer_factor: float = 1.0 | ||
| feed_forward_proj: str = "gated-gelu" | ||
| dense_act_fn: str | None = Field(default=None, exclude=True) | ||
| is_gated_act: bool = Field(default=False, exclude=True) | ||
| is_decoder: bool = Field(default=False, exclude=True) | ||
| is_encoder_decoder: bool = True | ||
| use_cache: bool = True | ||
| output_past: bool = True | ||
| pad_token_id: int = 0 | ||
| eos_token_id: int = 1 | ||
| decoder_start_token_id: int = 0 | ||
| classifier_dropout: float = 0.0 | ||
| scalable_attention: bool = True | ||
| tie_word_embeddings: bool = False | ||
| tokenizer_class: str = "T5Tokenizer" | ||
| device: DeviceRef = Field(default_factory=DeviceRef.GPU) | ||
| dtype: DType = DType.bfloat16 | ||
|
|
||
|
|
||
| class UMT5Config(UMT5ConfigBase): | ||
| @staticmethod | ||
| def generate( | ||
| config_dict: dict[str, Any], | ||
| encoding: SupportedEncoding, | ||
| devices: list[Device], | ||
| ) -> UMT5ConfigBase: | ||
| init_dict = { | ||
| key: value | ||
| for key, value in config_dict.items() | ||
| if key in UMT5ConfigBase.__annotations__ | ||
| } | ||
| init_dict.update( | ||
| { | ||
| "dtype": supported_encoding_dtype(encoding), | ||
| "device": DeviceRef.from_device(devices[0]), | ||
| } | ||
| ) | ||
| return UMT5ConfigBase(**init_dict) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hardcoding
DType.bfloat16here overrides theencodingandconfigsettings provided during initialization. While the comment explains this is a workaround for specific Wan 2.1 checkpoints, it limits the reusability of theUMT5Modelcomponent for other models that might require different precision (e.g.,float32orfloat16). This precision-forcing logic should ideally be handled during configuration resolution or guarded by a check to ensure it only applies when necessary.