Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions max/python/max/pipelines/architectures/umt5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from .model import UMT5Model

__all__ = ["UMT5Model"]
107 changes: 107 additions & 0 deletions max/python/max/pipelines/architectures/umt5/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from typing import Any

from max.driver import Device
from max.dtype import DType
from max.engine import InferenceSession, Model
from max.graph import DeviceRef, Graph, TensorType
from max.graph.weights import WeightData, Weights
from max.pipelines.lib import SupportedEncoding
from max.pipelines.lib.interfaces.component_model import ComponentModel

from .model_config import UMT5Config, UMT5ConfigBase
from .umt5 import UMT5EncoderModel


def _prepare_state_dict(
weights: Weights,
target_dtype: DType | None = None,
) -> dict[str, WeightData]:
"""Convert Weights to a raw state dict, normalizing tied embedding keys.

HF UMT5 ties ``shared.weight`` and ``encoder.embed_tokens.weight``.
Our module owns the embedding as ``shared``, so we normalize to that key
and drop the alias to avoid strict-mode validation failures.

If ``target_dtype`` is provided, all weights are cast to that dtype
(e.g. float32 → bfloat16 for Wan 2.1 checkpoints).
"""
state_dict: dict[str, WeightData] = {}
for key, value in weights.items():
wd = value.data()
if target_dtype is not None and wd.dtype != target_dtype:
wd = wd.astype(target_dtype)
state_dict[key] = wd

encoder_emb = state_dict.pop("encoder.embed_tokens.weight", None)
if "shared.weight" not in state_dict and encoder_emb is not None:
state_dict["shared.weight"] = encoder_emb

return state_dict


class UMT5Model(ComponentModel):
def __init__(
self,
config: dict[str, Any],
encoding: SupportedEncoding,
devices: list[Device],
weights: Weights,
session: InferenceSession | None = None,
) -> None:
super().__init__(config, encoding, devices, weights)
self.session = session or InferenceSession(devices=devices)
self.config: UMT5ConfigBase = UMT5Config.generate(
config,
encoding,
devices,
)
self.load_model()

def load_model(self) -> Model:
assert self.weights is not None, "Weights already freed"
# Force bfloat16 — some repos (Wan 2.1) declare float32 but
# should run in bfloat16 on GPU. Override both config and weights.
dtype = DType.bfloat16
self.config.dtype = dtype
Comment on lines +77 to +78

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding DType.bfloat16 here overrides the encoding and config settings provided during initialization. While the comment explains this is a workaround for specific Wan 2.1 checkpoints, it limits the reusability of the UMT5Model component for other models that might require different precision (e.g., float32 or float16). This precision-forcing logic should ideally be handled during configuration resolution or guarded by a check to ensure it only applies when necessary.

state_dict = _prepare_state_dict(self.weights, target_dtype=dtype)
dev = self.devices[0]
dev_ref = DeviceRef.from_device(dev)

# Build module and load weights
module = UMT5EncoderModel(self.config, dtype=dtype, device=dev_ref)
module.load_state_dict(state_dict, weight_alignment=1, strict=True)

# Build graph with symbolic sequence length
# attention_mask comes in as int64 from the pipeline
input_types = [
TensorType(DType.int64, ["batch", "seq_len"], device=dev),
TensorType(DType.int64, ["batch", "seq_len"], device=dev),
]
with Graph("umt5_encoder", input_types=input_types) as graph:
input_ids = graph.inputs[0].tensor
attention_mask = graph.inputs[1].tensor
out = module(input_ids, attention_mask)
graph.output(out)

self.model: Model = self.session.load(
graph, weights_registry=module.state_dict()
)
# Free raw weights after compilation
self.weights = None # type: ignore[assignment]
return self.model

def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
73 changes: 73 additions & 0 deletions max/python/max/pipelines/architectures/umt5/model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from typing import Any

from max.driver import Device
from max.dtype import DType
from max.graph import DeviceRef
from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
from max.pipelines.lib.config.config_enums import supported_encoding_dtype
from pydantic import Field


class UMT5ConfigBase(MAXModelConfigBase):
vocab_size: int = 256384
d_model: int = 4096
d_kv: int = 64
d_ff: int = 10240
num_layers: int = 24
num_decoder_layers: int | None = 24
num_heads: int = 64
relative_attention_num_buckets: int = 32
relative_attention_max_distance: int = 128
dropout_rate: float = 0.1
layer_norm_epsilon: float = 1e-6
initializer_factor: float = 1.0
feed_forward_proj: str = "gated-gelu"
dense_act_fn: str | None = Field(default=None, exclude=True)
is_gated_act: bool = Field(default=False, exclude=True)
is_decoder: bool = Field(default=False, exclude=True)
is_encoder_decoder: bool = True
use_cache: bool = True
output_past: bool = True
pad_token_id: int = 0
eos_token_id: int = 1
decoder_start_token_id: int = 0
classifier_dropout: float = 0.0
scalable_attention: bool = True
tie_word_embeddings: bool = False
tokenizer_class: str = "T5Tokenizer"
device: DeviceRef = Field(default_factory=DeviceRef.GPU)
dtype: DType = DType.bfloat16


class UMT5Config(UMT5ConfigBase):
@staticmethod
def generate(
config_dict: dict[str, Any],
encoding: SupportedEncoding,
devices: list[Device],
) -> UMT5ConfigBase:
init_dict = {
key: value
for key, value in config_dict.items()
if key in UMT5ConfigBase.__annotations__
}
init_dict.update(
{
"dtype": supported_encoding_dtype(encoding),
"device": DeviceRef.from_device(devices[0]),
}
)
return UMT5ConfigBase(**init_dict)
Loading
Loading