diff --git a/max/python/max/pipelines/architectures/wan/embeddings.py b/max/python/max/pipelines/architectures/wan/embeddings.py
new file mode 100644
index 00000000000..274a55afbf1
--- /dev/null
+++ b/max/python/max/pipelines/architectures/wan/embeddings.py
@@ -0,0 +1,255 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2026, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable
+from typing import Any
+
+from max.dtype import DType
+from max.graph import DeviceRef, TensorValue, ops
+from max.nn.activation import activation_function_from_name
+from max.nn.layer import Module
+from max.nn.linear import Linear
+
+
+def get_timestep_embedding(
+    timesteps: TensorValue,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+) -> TensorValue:
+    """Build sinusoidal timestep embeddings, computed in float32.
+
+    Args:
+        timesteps: Timestep values; an axis is inserted at position 1
+            before they are multiplied with the frequency table.
+        embedding_dim: Width of the returned embedding.
+        flip_sin_to_cos: If True, emit ``[cos, sin]`` instead of
+            ``[sin, cos]`` along the last axis.
+        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the
+            denominator of the frequency exponent.
+        scale: Factor applied to the angles before taking sin/cos.
+        max_period: Base of the geometric frequency progression.
+
+    Returns:
+        The float32 embedding tensor.
+    """
+    half_dim = embedding_dim // 2
+    # freq_i = exp(-ln(max_period) * i / (half_dim - downscale_freq_shift))
+    exponent = -math.log(max_period) * ops.range(
+        0, half_dim, dtype=DType.float32, device=timesteps.device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+    emb = ops.exp(exponent)
+    timesteps_expanded = ops.cast(ops.unsqueeze(timesteps, 1), DType.float32)
+    emb_expanded = ops.unsqueeze(emb, 0)
+    emb = scale * timesteps_expanded * emb_expanded
+    emb = ops.concat([ops.sin(emb), ops.cos(emb)], axis=-1)
+    if flip_sin_to_cos:
+        emb = ops.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
+    if embedding_dim % 2 == 1:
+        # sin/cos only fill 2 * half_dim columns; pad up to embedding_dim.
+        # NOTE(review): assumes ops.pad takes (before, after) per axis,
+        # outermost axis first -- confirm against the ops.pad docs.
+        emb = ops.pad(emb, (0, 0, 0, 1))
+    return emb
+
+
+def apply_rotary_emb(
+    x: TensorValue,
+    freqs_cis: tuple[TensorValue, TensorValue],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
+) -> TensorValue:
+    """Apply rotary position embeddings to ``x``.
+
+    Computes ``x * cos + rotate(x) * sin`` in float32 and casts the result
+    back to the dtype of ``x``.
+
+    Args:
+        x: Tensor to rotate; its last axis is split into (real, imag) parts.
+        freqs_cis: ``(cos, sin)`` tables; unit axes are inserted so they
+            line up with ``x`` according to ``sequence_dim``.
+        use_real: Must be True; the complex form is not implemented.
+        use_real_unbind_dim: ``-1`` treats the last axis as interleaved
+            pairs, ``-2`` as two contiguous halves.
+        sequence_dim: Sequence axis of ``x``; must be 1 or 2.
+
+    Raises:
+        NotImplementedError: If ``use_real`` is False.
+        ValueError: If ``sequence_dim`` or ``use_real_unbind_dim`` is
+            unsupported.
+    """
+    if not use_real:
+        raise NotImplementedError("Only use_real=True is supported")
+
+    cos, sin = freqs_cis
+    if sequence_dim == 2:
+        # Insert two leading unit axes.
+        cos = ops.unsqueeze(ops.unsqueeze(cos, 0), 0)
+        sin = ops.unsqueeze(ops.unsqueeze(sin, 0), 0)
+    elif sequence_dim == 1:
+        # Insert unit axes at positions 0 and 2.
+        cos = ops.unsqueeze(ops.unsqueeze(cos, 0), 2)
+        sin = ops.unsqueeze(ops.unsqueeze(sin, 0), 2)
+    else:
+        raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
+    input_dtype = x.dtype
+
+    if use_real_unbind_dim == -1:
+        # Interleaved layout: each (re, im) pair becomes (-im, re).
+        x_shape: list[Any] = list(x.shape)
+        new_shape: list[Any] = x_shape[:-1] + [x_shape[-1] // 2, 2]
+        x_reshaped = ops.reshape(x, new_shape)
+        x_real = x_reshaped[..., 0]
+        x_imag = x_reshaped[..., 1]
+        x_rotated_stacked = ops.stack([-x_imag, x_real], axis=-1)
+        x_rotated = ops.reshape(x_rotated_stacked, x_shape)
+    elif use_real_unbind_dim == -2:
+        # Half-split layout: [re | im] becomes [-im | re].
+        x_shape = list(x.shape)
+        new_shape = x_shape[:-1] + [2, x_shape[-1] // 2]
+        x_reshaped = ops.reshape(x, new_shape)
+        x_real = x_reshaped[..., 0, :]
+        x_imag = x_reshaped[..., 1, :]
+        x_rotated = ops.concat([-x_imag, x_real], axis=-1)
+    else:
+        raise ValueError(
+            f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
+        )
+
+    out = ops.cast(x, DType.float32) * ops.cast(cos, DType.float32) + ops.cast(
+        x_rotated, DType.float32
+    ) * ops.cast(sin, DType.float32)
+    return ops.cast(out, input_dtype)
+
+
+class Timesteps(Module):
+    """Module wrapper around :func:`get_timestep_embedding`."""
+
+    def __init__(
+        self,
+        num_channels: int,
+        flip_sin_to_cos: bool,
+        downscale_freq_shift: float,
+        scale: float = 1,
+    ):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale
+
+    def __call__(self, timesteps: TensorValue) -> TensorValue:
+        """Return the sinusoidal embedding of ``timesteps``."""
+        return get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
+        )
+
+
+class TimestepEmbedding(Module):
+    """Two-layer MLP applied to a timestep embedding.
+
+    Computes ``linear_2(act(linear_1(sample)))``, optionally followed by
+    ``post_act``. When ``cond_proj_dim`` is given, a bias-free ``cond_proj``
+    layer maps a conditioning tensor to ``in_channels`` so that it can be
+    added to ``sample`` before the MLP.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: int | None = None,
+        post_act_fn: str | None = None,
+        cond_proj_dim: int | None = None,
+        sample_proj_bias: bool = True,
+        *,
+        dtype: DType = DType.bfloat16,
+        device: DeviceRef = DeviceRef.CPU(),
+    ):
+        super().__init__()
+        self.linear_1 = Linear(
+            in_dim=in_channels,
+            out_dim=time_embed_dim,
+            dtype=dtype,
+            device=device,
+            has_bias=sample_proj_bias,
+        )
+        self.cond_proj: Linear | None
+        if cond_proj_dim is not None:
+            self.cond_proj = Linear(
+                in_dim=cond_proj_dim,
+                out_dim=in_channels,
+                dtype=dtype,
+                device=device,
+                has_bias=False,
+            )
+        else:
+            self.cond_proj = None
+        self.act = activation_function_from_name(act_fn)
+        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
+        self.linear_2 = Linear(
+            in_dim=time_embed_dim,
+            out_dim=time_embed_dim_out,
+            dtype=dtype,
+            device=device,
+            has_bias=sample_proj_bias,
+        )
+        self.post_act: Callable[[TensorValue], TensorValue] | None
+        if post_act_fn is not None:
+            self.post_act = activation_function_from_name(post_act_fn)
+        else:
+            self.post_act = None
+
+    def __call__(
+        self, sample: TensorValue, condition: TensorValue | None = None
+    ) -> TensorValue:
+        """Embed ``sample``, optionally conditioned on ``condition``.
+
+        Args:
+            sample: Input embedding whose last axis is ``in_channels``.
+            condition: Optional conditioning tensor whose last axis is
+                ``cond_proj_dim``. It is projected by ``cond_proj`` and
+                added to ``sample`` before the MLP.
+
+        Raises:
+            ValueError: If ``condition`` is given but the module was built
+                without ``cond_proj_dim``.
+        """
+        if condition is not None:
+            if self.cond_proj is None:
+                raise ValueError(
+                    "`condition` was provided but the module was built "
+                    "without `cond_proj_dim`."
+                )
+            # Project the conditioning input, not `sample`: cond_proj maps
+            # cond_proj_dim -> in_channels.
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+        sample = self.act(sample)
+        sample = self.linear_2(sample)
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
diff --git a/max/python/max/pipelines/architectures/wan/model.py b/max/python/max/pipelines/architectures/wan/model.py
new file mode 100644
index 00000000000..6d3952ebb04
--- /dev/null
+++ b/max/python/max/pipelines/architectures/wan/model.py
@@ -0,0 +1,605 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2026, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from __future__ import annotations
+
+import logging
+import threading
+from collections.abc import Callable
+from functools import lru_cache
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from max.driver import CPU, Buffer, Device
+from max.dtype import DType
+from max.engine import InferenceSession, Model
+from max.graph import DeviceRef, Graph, TensorType
+from max.graph.buffer_utils import cast_dlpack_to
+from max.graph.weights import WeightData, Weights
+from max.pipelines.lib import SupportedEncoding
+from max.pipelines.lib.interfaces.component_model import ComponentModel
+
+from .model_config import WanConfig
+from .wan_transformer import (
+    WanTransformerBlock,
+    WanTransformerPostProcess,
+    WanTransformerPreProcess,
+)
+
+logger = logging.getLogger(__name__)
+
+# Weight key remapping from diffusers -> MAX module naming
+_KEY_REMAP = [
+    (".attn1.to_out.0.", ".attn1.to_out."),
+    (".attn2.to_out.0.", ".attn2.to_out."),
+    (".ffn.net.0.proj.", ".ffn.proj."),
+    (".ffn.net.2.", ".ffn.linear_out."),
+    # Image embedder GELU FFN: diffusers nested structure → flat Linear layers
+    ("image_embedder.ff.net.0.proj.", "image_embedder.ff_proj."),
+    ("image_embedder.ff.net.2.", "image_embedder.ff_out."),
+]
+
+# Keys to skip (non-persistent buffers computed at runtime)
+_SKIP_PREFIXES = ("rope.freqs_cos", "rope.freqs_sin")
+
+# (pre, per-block, post) weight dicts, in execution order.
+_WeightGroups = tuple[dict[str, Any], list[dict[str, Any]], dict[str, Any]]
+
+
+def _to_float32_numpy(data: Any) -> np.ndarray:
+    """Cast ``data`` to float32 on CPU and view it as a numpy array."""
+    buf = data.to_buffer() if hasattr(data, "to_buffer") else data
+    f32 = cast_dlpack_to(buf, data.dtype, DType.float32, CPU())
+    return np.from_dlpack(f32)
+
+
+def _remap_state_dict(
+    weights: Weights,
+    target_dtype: DType = DType.bfloat16,
+) -> dict[str, Any]:
+    """Remap diffusers weight keys to MAX module naming, permute Conv3d,
+    and cast weights to target dtype.
+
+    Some WAN checkpoints store weights as float32 (A14B), others as
+    bfloat16 (5B). We cast all to target_dtype to match the module
+    parameter declarations.
+    """
+    state_dict: dict[str, Any] = {}
+
+    # First pass: collect all weights with key remapping.
+    raw_dict: dict[str, Any] = {}
+    for key, value in weights.items():
+        if any(key.startswith(prefix) for prefix in _SKIP_PREFIXES):
+            continue
+
+        new_key = key
+        for old, new in _KEY_REMAP:
+            new_key = new_key.replace(old, new)
+
+        tensor = value.data()
+
+        # Conv3d weight permutation for patch_embedding
+        # Diffusers: [F, C, D, H, W] (PyTorch FCDHW)
+        # MAX Conv3d(permute=False): [D, H, W, C, F] (QRSCF)
+        if new_key == "patch_embedding.weight" and len(tensor.shape) == 5:
+            raw_dict[new_key] = np.ascontiguousarray(
+                _to_float32_numpy(tensor).transpose(2, 3, 4, 1, 0)
+            )
+        else:
+            raw_dict[new_key] = tensor
+
+    # Second pass: fuse attn2.to_k + attn2.to_v into attn2.to_kv
+    fused_keys: set[str] = set()
+    for key in list(raw_dict.keys()):
+        if ".attn2.to_k." in key:
+            k_key = key
+            v_key = key.replace(".attn2.to_k.", ".attn2.to_v.")
+            kv_key = key.replace(".attn2.to_k.", ".attn2.to_kv.")
+            if v_key in raw_dict:
+                k_np = _to_float32_numpy(raw_dict[k_key])
+                v_np = _to_float32_numpy(raw_dict[v_key])
+                state_dict[kv_key] = np.ascontiguousarray(
+                    np.concatenate([k_np, v_np], axis=0)
+                )
+                fused_keys.add(k_key)
+                fused_keys.add(v_key)
+
+    for key, tensor in raw_dict.items():
+        if key not in fused_keys:
+            state_dict[key] = tensor
+
+    cpu_device = CPU()
+    for key in state_dict:
+        tensor = state_dict[key]
+        if isinstance(tensor, WeightData):
+            src_dtype = tensor.dtype
+            dlpack_obj = tensor.to_buffer()
+        else:
+            src_dtype = DType.float32
+            dlpack_obj = tensor
+        state_dict[key] = cast_dlpack_to(
+            dlpack_obj, src_dtype, target_dtype, cpu_device
+        )
+
+    return state_dict
+
+
+def _get_1d_rotary_pos_embed_np(
+    dim: int,
+    pos: np.ndarray,
+    theta: float = 10000.0,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Compute 1D rotary position embeddings (numpy, for eager RoPE)."""
+    freq_exponent = np.arange(0, dim, 2, dtype=np.float64) / dim
+    freqs = 1.0 / (theta**freq_exponent)
+    angles = np.outer(pos.astype(np.float64), freqs)
+    cos_emb = np.cos(angles).astype(np.float32)
+    sin_emb = np.sin(angles).astype(np.float32)
+    cos_emb = np.repeat(cos_emb, 2, axis=1)
+    sin_emb = np.repeat(sin_emb, 2, axis=1)
+    return cos_emb, sin_emb
+
+
+@lru_cache(maxsize=8)
+def _compute_wan_rope_cached(
+    num_frames: int,
+    height: int,
+    width: int,
+    patch_size: tuple[int, int, int],
+    head_dim: int,
+    theta: float = 10000.0,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Compute 3D RoPE cos/sin arrays for Wan transformer (cached by resolution)."""
+    p_t, p_h, p_w = patch_size
+    ppf = num_frames // p_t
+    pph = height // p_h
+    ppw = width // p_w
+
+    d_h = (head_dim // 3 // 2) * 2
+    d_w = d_h
+    d_t = head_dim - d_h - d_w
+
+    cos_t, sin_t = _get_1d_rotary_pos_embed_np(d_t, np.arange(ppf), theta)
+    cos_h, sin_h = _get_1d_rotary_pos_embed_np(d_h, np.arange(pph), theta)
+    cos_w, sin_w = _get_1d_rotary_pos_embed_np(d_w, np.arange(ppw), theta)
+
+    cos_t = np.broadcast_to(cos_t[:, None, None, :], (ppf, pph, ppw, d_t))
+    sin_t = np.broadcast_to(sin_t[:, None, None, :], (ppf, pph, ppw, d_t))
+    cos_h = np.broadcast_to(cos_h[None, :, None, :], (ppf, pph, ppw, d_h))
+    sin_h = np.broadcast_to(sin_h[None, :, None, :], (ppf, pph, ppw, d_h))
+    cos_w = np.broadcast_to(cos_w[None, None, :, :], (ppf, pph, ppw, d_w))
+    sin_w = np.broadcast_to(sin_w[None, None, :, :], (ppf, pph, ppw, d_w))
+
+    rope_cos = np.concatenate([cos_t, cos_h, cos_w], axis=-1)
+    rope_sin = np.concatenate([sin_t, sin_h, sin_w], axis=-1)
+
+    seq_len = ppf * pph * ppw
+    rope_cos = np.ascontiguousarray(rope_cos.reshape(seq_len, head_dim))
+    rope_sin = np.ascontiguousarray(rope_sin.reshape(seq_len, head_dim))
+    # lru_cache returns these same array objects on every hit: read-only.
+    return rope_cos, rope_sin
+
+
+class BlockLevelModel:
+    """Executes transformer forward pass as pre -> N blocks -> post.
+
+    Each component is a separately compiled graph, so only one block's
+    workspace is live at any time. This keeps peak VRAM low.
+    """
+
+    def __init__(
+        self,
+        pre: Model,
+        blocks: list[Model],
+        post: Model,
+    ) -> None:
+        self.pre = pre
+        self.blocks = blocks
+        self.post = post
+
+    def __call__(
+        self,
+        hidden_states: Buffer,
+        timestep: Buffer,
+        encoder_hidden_states: Buffer,
+        rope_cos: Buffer,
+        rope_sin: Buffer,
+        spatial_shape: Buffer,
+    ) -> Buffer:
+        """Run pre, then every block in order, then post; return the output."""
+        pre_out = self.pre.execute(
+            hidden_states, timestep, encoder_hidden_states
+        )
+        hs, temb, timestep_proj, text_emb = (
+            pre_out[0],
+            pre_out[1],
+            pre_out[2],
+            pre_out[3],
+        )
+        for block in self.blocks:
+            block_out = block.execute(
+                hs, text_emb, timestep_proj, rope_cos, rope_sin
+            )
+            hs = block_out[0]
+        post_out = self.post.execute(hs, temb, spatial_shape)
+        return post_out[0]
+
+
+class WanTransformerModel(ComponentModel):
+    """MAX-native Wan DiT interface with block-level compilation.
+
+    Each block is compiled independently so only one block's workspace
+    is live at any time, keeping peak VRAM low.
+    """
+
+    def __init__(
+        self,
+        config: dict[str, Any],
+        encoding: SupportedEncoding,
+        devices: list[Device],
+        weights: Weights,
+        session: InferenceSession | None = None,
+        eager_load: bool = True,
+        lora_path: str | Path | None = None,
+        lora_scale: float = 1.0,
+    ) -> None:
+        super().__init__(config, encoding, devices, weights)
+        self.config = WanConfig.generate(config, encoding, devices)
+        self.config.dtype = DType.bfloat16
+        self._state_dict: dict[str, Any] | None = None
+        self._lora_path = lora_path
+        self._lora_scale = lora_scale
+        self._lora_merged = False
+        self.model: BlockLevelModel | None = None
+        # id(state_dict) -> (state_dict, registries). Keeping the dict alive
+        # stops its id() from being recycled for another dict while cached.
+        self._weight_registry_cache: dict[
+            int, tuple[dict[str, Any], _WeightGroups]
+        ] = {}
+        self.session = session or InferenceSession(devices=devices)
+        self._load_lock = threading.Lock()
+        if eager_load:
+            self.prepare_state_dict()
+
+    def _ensure_state_dict(self) -> dict[str, Any]:
+        """Return the remapped state dict, building it on first use."""
+        if self._state_dict is None:
+            if self.weights is None:
+                raise RuntimeError(
+                    "WanTransformerModel weights are unavailable "
+                    "while state_dict is not initialized."
+                )
+            self._state_dict = _remap_state_dict(
+                self.weights, target_dtype=DType.bfloat16
+            )
+            self.weights = None  # type: ignore[assignment]
+
+        if self._lora_path and not self._lora_merged:
+            from .lora_utils import load_and_merge_lora
+
+            self._state_dict = load_and_merge_lora(
+                self._state_dict, self._lora_path, self._lora_scale
+            )
+            self._lora_merged = True
+
+        return self._state_dict
+
+    def prepare_state_dict(self) -> dict[str, Any]:
+        """Materialize the remapped state dict without compiling graphs."""
+        with self._load_lock:
+            return self._ensure_state_dict()
+
+    def _split_state_dict(
+        self, state_dict: dict[str, Any]
+    ) -> _WeightGroups:
+        """Split flat state dict into pre/block/post weight groups."""
+        pre_weights: dict[str, Any] = {}
+        post_weights: dict[str, Any] = {}
+        block_weights_list: list[dict[str, Any]] = [
+            {} for _ in range(self.config.num_layers)
+        ]
+
+        for key, value in state_dict.items():
+            if key.startswith("patch_embedding.") or key.startswith(
+                "condition_embedder."
+            ):
+                pre_weights[key] = value
+            elif key.startswith("blocks."):
+                rest = key[len("blocks.") :]
+                dot = rest.index(".")
+                block_idx = int(rest[:dot])
+                sub_key = rest[dot + 1 :]
+                block_weights_list[block_idx][sub_key] = value
+            else:
+                post_weights[key] = value
+
+        return pre_weights, block_weights_list, post_weights
+
+    def _build_weight_registries(
+        self, state_dict: dict[str, Any]
+    ) -> _WeightGroups:
+        """Build module-level weight registries for pre/block/post."""
+        dim = self.config.num_attention_heads * self.config.attention_head_dim
+        dtype = self.config.dtype
+        dev_ref = DeviceRef.from_device(self.config.device)
+        pre_weights, block_weights_list, post_weights = self._split_state_dict(
+            state_dict
+        )
+
+        pre_module = WanTransformerPreProcess(
+            self.config, dtype=dtype, device=dev_ref
+        )
+        pre_module.load_state_dict(pre_weights, weight_alignment=1, strict=True)
+
+        block_registries: list[dict[str, Any]] = []
+        block_module = WanTransformerBlock(
+            dim=dim,
+            ffn_dim=self.config.ffn_dim,
+            num_heads=self.config.num_attention_heads,
+            head_dim=self.config.attention_head_dim,
+            text_dim=dim,
+            cross_attn_norm=self.config.cross_attn_norm,
+            eps=self.config.eps,
+            added_kv_proj_dim=self.config.added_kv_proj_dim,
+            dtype=dtype,
+            device=dev_ref,
+        )
+        for block_weights in block_weights_list:
+            block_module.load_state_dict(
+                block_weights, weight_alignment=1, strict=True
+            )
+            block_registries.append(block_module.state_dict())
+
+        post_module = WanTransformerPostProcess(
+            self.config, dtype=dtype, device=dev_ref
+        )
+        post_module.load_state_dict(
+            post_weights, weight_alignment=1, strict=True
+        )
+
+        return (
+            pre_module.state_dict(),
+            block_registries,
+            post_module.state_dict(),
+        )
+
+    def _get_cached_weight_registries(
+        self, state_dict: dict[str, Any] | None = None
+    ) -> _WeightGroups:
+        """Return weight registries, cached per state-dict object.
+
+        ``None`` selects the model's own state dict. A dict passed
+        explicitly is used as given, even when it is empty.
+        """
+        if state_dict is None:
+            state_dict = self._ensure_state_dict()
+        cache_key = id(state_dict)
+        cached = self._weight_registry_cache.get(cache_key)
+        # Only trust an entry built from this very object; a bare id()
+        # match could belong to a different dict allocated later.
+        if cached is not None and cached[0] is state_dict:
+            return cached[1]
+
+        registries = self._build_weight_registries(state_dict)
+        self._weight_registry_cache[cache_key] = (state_dict, registries)
+        return registries
+
+    def reload_model_weights(
+        self, state_dict: dict[str, Any] | None = None
+    ) -> None:
+        """Reload weights into already-compiled models for MoE weight switching."""
+        with self._load_lock:
+            if self.model is None:
+                raise RuntimeError("Wan transformer model not compiled.")
+
+            pre_registry, block_registries, post_registry = (
+                self._get_cached_weight_registries(state_dict)
+            )
+
+            self.model.pre._load(pre_registry)
+            for compiled_block, block_registry in zip(
+                self.model.blocks, block_registries, strict=True
+            ):
+                compiled_block._load(block_registry)
+            self.model.post._load(post_registry)
+
+    def load_model(  # type: ignore[override]
+        self,
+        *,
+        seq_text_len: int,
+        seq_len: int,
+        batch_size: int = 1,
+    ) -> Callable[..., Any]:
+        """Compile the transformer as separate pre/block/post graphs.
+
+        Block graphs are compiled with symbolic ``seq_len`` and concrete
+        ``batch_size`` / ``seq_text_len``. Pre/post graphs use symbolic
+        spatial dims.
+        """
+        with self._load_lock:
+            if self.model is not None:
+                return self.__call__
+
+            state_dict = self._ensure_state_dict()
+
+            dim = (
+                self.config.num_attention_heads * self.config.attention_head_dim
+            )
+            dtype = self.config.dtype
+            dev = self.config.device
+            dev_ref = DeviceRef.from_device(dev)
+
+            pre_weights, block_weights_list, post_weights = (
+                self._split_state_dict(state_dict)
+            )
+            pre_input_types = [
+                TensorType(
+                    dtype,
+                    [
+                        "batch",
+                        self.config.in_channels,
+                        "frames",
+                        "height",
+                        "width",
+                    ],
+                    device=dev,
+                ),
+                TensorType(DType.float32, ["batch"], device=dev),
+                TensorType(
+                    dtype,
+                    ["batch", seq_text_len, self.config.text_dim],
+                    device=dev,
+                ),
+            ]
+            pre_module = WanTransformerPreProcess(
+                self.config, dtype=dtype, device=dev_ref
+            )
+            pre_module.load_state_dict(
+                pre_weights, weight_alignment=1, strict=True
+            )
+            with Graph("wan_pre", input_types=pre_input_types) as pre_graph:
+                outs = pre_module(*(v.tensor for v in pre_graph.inputs))
+                pre_graph.output(*outs)
+            pre_model = self.session.load(
+                pre_graph, weights_registry=pre_module.state_dict()
+            )
+            block_seq_len_dim: str = "seq_len"
+            block_input_types = [
+                TensorType(
+                    dtype, [batch_size, block_seq_len_dim, dim], device=dev
+                ),
+                TensorType(dtype, [batch_size, seq_text_len, dim], device=dev),
+                TensorType(dtype, [batch_size, 6, dim], device=dev),
+                TensorType(
+                    DType.float32,
+                    [block_seq_len_dim, self.config.attention_head_dim],
+                    device=dev,
+                ),
+                TensorType(
+                    DType.float32,
+                    [block_seq_len_dim, self.config.attention_head_dim],
+                    device=dev,
+                ),
+            ]
+            block_template = WanTransformerBlock(
+                dim=dim,
+                ffn_dim=self.config.ffn_dim,
+                num_heads=self.config.num_attention_heads,
+                head_dim=self.config.attention_head_dim,
+                text_dim=dim,
+                cross_attn_norm=self.config.cross_attn_norm,
+                eps=self.config.eps,
+                added_kv_proj_dim=self.config.added_kv_proj_dim,
+                dtype=dtype,
+                device=dev_ref,
+            )
+            block_template.load_state_dict(
+                block_weights_list[0], weight_alignment=1, strict=True
+            )
+            with Graph(
+                "wan_block", input_types=block_input_types
+            ) as block_graph:
+                block_out = block_template(
+                    *(v.tensor for v in block_graph.inputs)
+                )
+                block_graph.output(block_out)
+
+            block_models: list[Model] = [
+                self.session.load(
+                    block_graph,
+                    weights_registry=block_template.state_dict(),
+                )
+            ]
+            for i in range(1, self.config.num_layers):
+                block_template.load_state_dict(
+                    block_weights_list[i],
+                    weight_alignment=1,
+                    strict=True,
+                )
+                block_models.append(
+                    self.session.load(
+                        block_graph,
+                        weights_registry=block_template.state_dict(),
+                    )
+                )
+            logger.info(
+                "Compiled block graph (batch=%d, seq_len=symbolic "
+                "default=%d, seq_text=%d, %d layers)",
+                batch_size,
+                seq_len,
+                seq_text_len,
+                len(block_models),
+            )
+            post_input_types = [
+                TensorType(dtype, ["batch", "seq_len", dim], device=dev),
+                TensorType(dtype, ["batch", dim], device=dev),
+                TensorType(DType.int8, ["ppf", "pph", "ppw"], device=dev),
+            ]
+            post_module = WanTransformerPostProcess(
+                self.config, dtype=dtype, device=dev_ref
+            )
+            post_module.load_state_dict(
+                post_weights, weight_alignment=1, strict=True
+            )
+            with Graph("wan_post", input_types=post_input_types) as post_graph:
+                post_out = post_module(*(v.tensor for v in post_graph.inputs))
+                post_graph.output(post_out)
+            post_model = self.session.load(
+                post_graph, weights_registry=post_module.state_dict()
+            )
+            self.model = BlockLevelModel(pre_model, block_models, post_model)
+            return self.__call__
+
+    def compute_rope(
+        self,
+        num_frames: int,
+        height: int,
+        width: int,
+    ) -> tuple[Buffer, Buffer]:
+        """Compute 3D RoPE cos/sin tensors and transfer to device."""
+        rope_cos_np, rope_sin_np = _compute_wan_rope_cached(
+            num_frames,
+            height,
+            width,
+            self.config.patch_size,
+            self.config.attention_head_dim,
+        )
+        device = self.devices[0]
+        return (
+            Buffer.from_numpy(rope_cos_np).to(device),
+            Buffer.from_numpy(rope_sin_np).to(device),
+        )
+
+    def __call__(
+        self,
+        hidden_states: Buffer,
+        timestep: Buffer,
+        encoder_hidden_states: Buffer,
+        rope_cos: Buffer,
+        rope_sin: Buffer,
+        spatial_shape: Buffer,
+    ) -> Buffer:
+        """Run the compiled model; raises if load_model() was not called."""
+        if self.model is None:
+            raise RuntimeError(
+                "Wan transformer model not compiled. Call load_model() first."
+            )
+        return self.model(
+            hidden_states,
+            timestep,
+            encoder_hidden_states,
+            rope_cos,
+            rope_sin,
+            spatial_shape,
+        )
diff --git a/max/python/max/pipelines/architectures/wan/model_config.py b/max/python/max/pipelines/architectures/wan/model_config.py
new file mode 100644
index 00000000000..def60e6bf1f
--- /dev/null
+++ b/max/python/max/pipelines/architectures/wan/model_config.py
@@ -0,0 +1,74 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2026, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from typing import Any
+
+from max.driver import Device
+from max.dtype import DType
+from max.graph import DeviceRef
+from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
+from max.pipelines.lib.config.config_enums import supported_encoding_dtype
+from pydantic import Field
+
+
+class WanConfigBase(MAXModelConfigBase):
+    """Architecture hyperparameters for the Wan transformer."""
+
+    # Defaults mirror diffusers WanTransformer3DModel constructor defaults.
+    patch_size: tuple[int, int, int] = (1, 2, 2)
+    num_attention_heads: int = 40
+    attention_head_dim: int = 128
+    in_channels: int = 16
+    out_channels: int = 16
+    text_dim: int = 4096
+    freq_dim: int = 256
+    ffn_dim: int = 13824
+    num_layers: int = 40
+    cross_attn_norm: bool = True
+    qk_norm: str | None = "rms_norm_across_heads"
+    eps: float = 1e-6
+    image_dim: int | None = None
+    added_kv_proj_dim: int | None = None
+    rope_max_seq_len: int = 1024
+    pos_embed_seq_len: int | None = None
+    dtype: DType = DType.bfloat16
+    device: DeviceRef = Field(default_factory=DeviceRef.GPU)
+
+
+class WanConfig(WanConfigBase):
+    """Wan config with a factory that binds encoding and device."""
+
+    @staticmethod
+    def generate(
+        config_dict: dict[str, Any],
+        encoding: SupportedEncoding,
+        devices: list[Device],
+    ) -> "WanConfig":
+        """Build a config from a checkpoint config dict.
+
+        Keys that are not annotated on ``WanConfigBase`` are dropped.
+        ``dtype`` (from ``encoding``) and ``device`` (from ``devices[0]``)
+        always override any value present in ``config_dict``.
+        """
+        init_dict = {
+            key: value
+            for key, value in config_dict.items()
+            if key in WanConfigBase.__annotations__
+        }
+        init_dict.update(
+            {
+                "dtype": supported_encoding_dtype(encoding),
+                "device": DeviceRef.from_device(devices[0]),
+            }
+        )
+        return WanConfig(**init_dict)
diff --git a/max/python/max/pipelines/architectures/wan/wan_transformer.py b/max/python/max/pipelines/architectures/wan/wan_transformer.py
new file mode 100644
index 00000000000..4b543405e22
--- /dev/null
+++ b/max/python/max/pipelines/architectures/wan/wan_transformer.py
@@ -0,0 +1,894 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2026, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from math import prod + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, Weight, ops +from max.nn.attention.mask_config import MHAMaskVariant +from max.nn.kernels import flash_attention_gpu +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear + +from .embeddings import ( + TimestepEmbedding, + Timesteps, + apply_rotary_emb, +) +from .model_config import WanConfigBase + + +class WanConv3d(Module): + """3D conv for WAN patch embedding (NDHWC/QRSCF layout).""" + + def __init__( + self, + kernel_size: tuple[int, int, int], + in_channels: int, + out_channels: int, + stride: tuple[int, int, int], + dtype: DType, + device: DeviceRef, + has_bias: bool = True, + ) -> None: + super().__init__() + d, h, w = kernel_size + self.filter = Weight( + "weight", dtype, [d, h, w, in_channels, out_channels], device + ) + self.bias = ( + Weight("bias", dtype, [out_channels], device) if has_bias else None + ) + self.stride = stride + + def __call__(self, x: TensorValue) -> TensorValue: + return ops.conv3d(x, self.filter, stride=self.stride, bias=self.bias) + + +class WanLayerNorm(Module): + """LayerNorm using decomposed ops for float32 numerical stability. + + The built-in ``layer_norm_gpu_block`` kernel hits + ``CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES`` for dim=5120, so we decompose + into basic ops (mean, rsqrt, multiply) that each launch small kernels. 
+ """ + + def __init__( + self, + dim: int, + eps: float = 1e-5, + *, + elementwise_affine: bool = True, + use_bias: bool = True, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.dim = dim + self.eps = eps + self.has_weight = elementwise_affine + self.has_bias = elementwise_affine and use_bias + if elementwise_affine: + self.weight = Weight("weight", dtype, [dim], device) + if use_bias: + self.bias = Weight("bias", dtype, [dim], device) + + def __call__(self, x: TensorValue) -> TensorValue: + original_dtype = x.dtype + x = ops.cast(x, DType.float32) + mean = ops.mean(x, axis=-1) + x = x - mean + var = ops.mean(x * x, axis=-1) + x = x * ops.rsqrt(var + self.eps) + if self.has_weight: + x = x * ops.cast(self.weight, DType.float32) + if self.has_bias: + x = x + ops.cast(self.bias, DType.float32) + return ops.cast(x, original_dtype) + + +class WanRMSNorm(Module): + """RMSNorm using decomposed ops for float32 numerical stability. + + Same reason as WanLayerNorm: the built-in ``rms_norm`` custom kernel + may also hit resource limits for dim=5120. 
+ """ + + def __init__( + self, + dim: int, + eps: float = 1e-6, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.weight = Weight("weight", dtype, [dim], device) + self.eps = eps + + def __call__(self, x: TensorValue) -> TensorValue: + original_dtype = x.dtype + x = ops.cast(x, DType.float32) + rms = ops.mean(x * x, axis=-1) + x = x * ops.rsqrt(rms + self.eps) + x = x * ops.cast(self.weight, DType.float32) + return ops.cast(x, original_dtype) + + +class WanTextProjection(Module): + def __init__( + self, + in_features: int, + hidden_size: int, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.linear_1 = Linear( + in_dim=in_features, + out_dim=hidden_size, + dtype=dtype, + device=device, + has_bias=True, + ) + self.linear_2 = Linear( + in_dim=hidden_size, + out_dim=hidden_size, + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__(self, caption: TensorValue) -> TensorValue: + hidden_states = self.linear_1(caption) + hidden_states = ops.gelu(hidden_states, approximate="tanh") + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class WanImageEmbedder(Module): + """Image embedding for Wan 2.1 I2V: LayerNorm → GEGLU FFN → LayerNorm. + + Matches diffusers' FeedForward(image_dim, dim, mult=1, activation_fn="gelu") + with pre/post norms. 
Weight keys:: + + image_embedder.norm1.{weight,bias} + image_embedder.ff.net.0.proj.{weight,bias} (GEGLU gate+value) + image_embedder.ff.net.2.{weight,bias} (output linear) + image_embedder.norm2.{weight,bias} + """ + + def __init__( + self, + image_dim: int, + out_dim: int, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + # Matches diffusers FeedForward(image_dim, out_dim, mult=1, activation_fn="gelu"): + # norm1(image_dim) → Linear(image_dim→image_dim) → GELU → + # Linear(image_dim→out_dim) → norm2(out_dim) + self.norm1 = WanLayerNorm( + image_dim, + elementwise_affine=True, + use_bias=True, + dtype=dtype, + device=device, + ) + self.ff_proj = Linear( + in_dim=image_dim, + out_dim=image_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.ff_out = Linear( + in_dim=image_dim, + out_dim=out_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.norm2 = WanLayerNorm( + out_dim, + elementwise_affine=True, + use_bias=True, + dtype=dtype, + device=device, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + x = self.norm1(x) + x = ops.gelu(self.ff_proj(x)) + x = self.ff_out(x) + return self.norm2(x) + + +class WanTimeTextImageEmbedding(Module): + def __init__( + self, + dim: int, + freq_dim: int, + text_dim: int, + num_layers: int, + *, + image_dim: int | None = None, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.timesteps_proj = Timesteps( + num_channels=freq_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + ) + self.time_embedder = TimestepEmbedding( + in_channels=freq_dim, + time_embed_dim=dim, + dtype=dtype, + device=device, + ) + # Projects SiLU(temb) to 6 modulation params per block + self.time_proj = Linear( + in_dim=dim, + out_dim=dim * 6, + dtype=dtype, + device=device, + has_bias=True, + ) + self.text_embedder = WanTextProjection( + in_features=text_dim, + hidden_size=dim, + dtype=dtype, + 
device=device, + ) + # Optional image embedder (Wan 2.1 I2V) + self.image_embedder: WanImageEmbedder | None = None + if image_dim is not None: + self.image_embedder = WanImageEmbedder( + image_dim=image_dim, + out_dim=dim, + dtype=dtype, + device=device, + ) + + def __call__( + self, timestep: TensorValue, encoder_hidden_states: TensorValue + ) -> tuple[TensorValue, TensorValue, TensorValue]: + # Sinusoidal timestep embedding (computed in float32 for precision). + # Cast to the model's working dtype (bf16) for the MLP, matching + # diffusers' behavior: float32 embedding → cast to weight dtype → MLP. + timesteps_emb = self.timesteps_proj(timestep) # [B, freq_dim] float32 + timesteps_emb = ops.cast( + timesteps_emb, encoder_hidden_states.dtype + ) # → bf16 + temb = self.time_embedder(timesteps_emb) # [B, dim] + + # Timestep projection for modulation: SiLU then linear + timestep_proj = self.time_proj(ops.silu(temb)) # [B, dim*6] + # Reshape to [B, 6, dim] for per-block modulation + timestep_proj = ops.reshape( + timestep_proj, + [timestep_proj.shape[0], 6, timestep_proj.shape[1] // 6], + ) + + # Text projection + text_emb = self.text_embedder(encoder_hidden_states) # [B, S, dim] + + return temb, timestep_proj, text_emb + + +class WanSelfAttention(Module): + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + eps: float, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = dim + + self.to_q = Linear( + in_dim=dim, out_dim=dim, dtype=dtype, device=device, has_bias=True + ) + self.to_k = Linear( + in_dim=dim, out_dim=dim, dtype=dtype, device=device, has_bias=True + ) + self.to_v = Linear( + in_dim=dim, out_dim=dim, dtype=dtype, device=device, has_bias=True + ) + self.norm_q = WanRMSNorm(dim, eps=eps, dtype=dtype, device=device) + self.norm_k = WanRMSNorm(dim, eps=eps, dtype=dtype, device=device) + self.to_out = Linear( + 
in_dim=dim, out_dim=dim, dtype=dtype, device=device, has_bias=True + ) + + def __call__( + self, + hidden_states: TensorValue, + rotary_emb: tuple[TensorValue, TensorValue], + ) -> TensorValue: + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # QK-norm applied across all heads (before reshape) + query = self.norm_q(query) + key = self.norm_k(key) + + # Reshape to multi-head: [B, S, D] -> [B, S, H, head_dim] + batch_size = query.shape[0] + seq_len = query.shape[1] + query = ops.reshape( + query, [batch_size, seq_len, self.num_heads, self.head_dim] + ) + key = ops.reshape( + key, [batch_size, seq_len, self.num_heads, self.head_dim] + ) + value = ops.reshape( + value, [batch_size, seq_len, self.num_heads, self.head_dim] + ) + + # Apply RoPE + original_dtype = query.dtype + query = apply_rotary_emb( + query, + rotary_emb, + use_real=True, + use_real_unbind_dim=-1, + sequence_dim=1, + ) + key = apply_rotary_emb( + key, + rotary_emb, + use_real=True, + use_real_unbind_dim=-1, + sequence_dim=1, + ) + query = ops.cast(query, original_dtype) + key = ops.cast(key, original_dtype) + + # Flash attention + scale = 1.0 / (self.head_dim**0.5) + hidden_states = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=scale, + ) + + # Reshape back: [B, S, H, head_dim] -> [B, S, D] + hidden_states = ops.reshape( + hidden_states, + [hidden_states.shape[0], hidden_states.shape[1], self.inner_dim], + ) + hidden_states = ops.cast(hidden_states, original_dtype) + + return self.to_out(hidden_states) + + +class WanCrossAttention(Module): + def __init__( + self, + dim: int, + text_dim: int, + num_heads: int, + head_dim: int, + eps: float, + *, + added_kv_proj_dim: int | None = None, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = dim + self._has_added_kv = 
added_kv_proj_dim is not None + + self.to_q = Linear( + in_dim=dim, out_dim=dim, dtype=dtype, device=device, has_bias=True + ) + # Fused K+V projection from text embeddings + self.to_kv = Linear( + in_dim=text_dim, + out_dim=dim * 2, + dtype=dtype, + device=device, + has_bias=True, + ) + self.norm_q = WanRMSNorm(dim, eps=eps, dtype=dtype, device=device) + self.norm_k = WanRMSNorm(dim, eps=eps, dtype=dtype, device=device) + self.to_out = Linear( + in_dim=dim, out_dim=dim, dtype=dtype, device=device, has_bias=True + ) + + # Optional added KV projections for image conditioning (Wan 2.1 I2V) + if added_kv_proj_dim is not None: + self.add_k_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.add_v_proj = Linear( + in_dim=added_kv_proj_dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.norm_added_q = WanRMSNorm( + dim, eps=eps, dtype=dtype, device=device + ) + self.norm_added_k = WanRMSNorm( + dim, eps=eps, dtype=dtype, device=device + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + image_embeds: TensorValue | None = None, + ) -> TensorValue: + query = self.to_q(hidden_states) + + # Fused KV from text - use explicit slicing instead of chunk + kv = self.to_kv(encoder_hidden_states) + key = kv[:, :, : self.inner_dim] + value = kv[:, :, self.inner_dim :] + + # QK-norm across all heads (before reshape) + query = self.norm_q(query) + key = self.norm_k(key) + + # Added image KV (Wan 2.1 I2V) + if self._has_added_kv and image_embeds is not None: + added_key = self.norm_added_k(self.add_k_proj(image_embeds)) + added_value = self.add_v_proj(image_embeds) + # Concatenate image KV with text KV along sequence dim + key = ops.concat([key, added_key], axis=1) + value = ops.concat([value, added_value], axis=1) + + # Reshape to multi-head + batch_size = query.shape[0] + q_seq_len = query.shape[1] + kv_seq_len = key.shape[1] + query = 
ops.reshape( + query, [batch_size, q_seq_len, self.num_heads, self.head_dim] + ) + key = ops.reshape( + key, [batch_size, kv_seq_len, self.num_heads, self.head_dim] + ) + value = ops.reshape( + value, [batch_size, kv_seq_len, self.num_heads, self.head_dim] + ) + + # Flash attention (no RoPE for cross-attention) + original_dtype = query.dtype + scale = 1.0 / (self.head_dim**0.5) + hidden_states = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=scale, + ) + + # Reshape back + hidden_states = ops.reshape( + hidden_states, + [hidden_states.shape[0], hidden_states.shape[1], self.inner_dim], + ) + hidden_states = ops.cast(hidden_states, original_dtype) + + return self.to_out(hidden_states) + + +class WanFeedForward(Module): + def __init__( + self, + dim: int, + ffn_dim: int, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + # WAN uses "gelu-approximate" (simple GELU), NOT GEGLU. + # ffn_dim is the direct projection output size (no 2x expansion). 
+ self.proj = Linear( + in_dim=dim, + out_dim=ffn_dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.linear_out = Linear( + in_dim=ffn_dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + hidden = self.proj(x) + hidden = ops.gelu(hidden, approximate="tanh") + return self.linear_out(hidden) + + +class WanTransformerBlock(Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + head_dim: int, + text_dim: int, + cross_attn_norm: bool, + eps: float, + *, + added_kv_proj_dim: int | None = None, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.scale_shift_table = Weight( + "scale_shift_table", dtype, [1, 6, dim], device + ) + self.norm1 = WanLayerNorm( + dim, + eps=eps, + elementwise_affine=False, + dtype=dtype, + device=device, + ) + self.attn1 = WanSelfAttention( + dim, num_heads, head_dim, eps, dtype=dtype, device=device + ) + self.norm2 = WanLayerNorm( + dim, + eps=eps, + elementwise_affine=cross_attn_norm, + use_bias=cross_attn_norm, + dtype=dtype, + device=device, + ) + self.attn2 = WanCrossAttention( + dim, + text_dim, + num_heads, + head_dim, + eps, + added_kv_proj_dim=added_kv_proj_dim, + dtype=dtype, + device=device, + ) + self.norm3 = WanLayerNorm( + dim, + eps=eps, + elementwise_affine=False, + dtype=dtype, + device=device, + ) + self.ffn = WanFeedForward(dim, ffn_dim, dtype=dtype, device=device) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + timestep_proj: TensorValue, + rope_cos: TensorValue, + rope_sin: TensorValue, + image_embeds: TensorValue | None = None, + ) -> TensorValue: + rotary_emb = (rope_cos, rope_sin) + + # Modulation: scale_shift_table[1,6,D] + timestep_proj[B,6,D] + mod = self.scale_shift_table + timestep_proj # [B, 6, D] + + # Split into 6 modulation parameters + shift_sa, scale_sa, gate_sa = ( + mod[:, 0:1, :], + 
mod[:, 1:2, :], + mod[:, 2:3, :], + ) + shift_ff, scale_ff, gate_ff = ( + mod[:, 3:4, :], + mod[:, 4:5, :], + mod[:, 5:6, :], + ) + + # Self-attention + x = self.norm1(hidden_states) + x = x * (1 + scale_sa) + shift_sa + x = self.attn1(x, rotary_emb) + hidden_states = hidden_states + gate_sa * x + + # Cross-attention (with optional image KV for 2.1 I2V) + x = self.norm2(hidden_states) + x = self.attn2(x, encoder_hidden_states, image_embeds=image_embeds) + hidden_states = hidden_states + x + + # Feed-forward + x = self.norm3(hidden_states) + x = x * (1 + scale_ff) + shift_ff + x = self.ffn(x) + hidden_states = hidden_states + gate_ff * x + + return hidden_states + + +class WanTransformerPreProcess(Module): + """Patch embedding + condition embedding (compiled separately).""" + + def __init__( + self, + config: WanConfigBase, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + dim = config.num_attention_heads * config.attention_head_dim + self.inner_dim = dim + + self.patch_embedding = WanConv3d( + kernel_size=config.patch_size, + in_channels=config.in_channels, + out_channels=dim, + stride=config.patch_size, + dtype=dtype, + device=device, + ) + self.condition_embedder = WanTimeTextImageEmbedding( + dim=dim, + freq_dim=config.freq_dim, + text_dim=config.text_dim, + num_layers=config.num_layers, + image_dim=getattr(config, "image_dim", None), + dtype=dtype, + device=device, + ) + + def __call__( + self, + hidden_states: TensorValue, + timestep: TensorValue, + encoder_hidden_states: TensorValue, + ) -> tuple[TensorValue, TensorValue, TensorValue, TensorValue]: + batch_size = hidden_states.shape[0] + hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) + hs = self.patch_embedding(hs) + hs = ops.permute(hs, [0, 4, 1, 2, 3]) + seq_len = hs.shape[2] * hs.shape[3] * hs.shape[4] + hs = ops.reshape(hs, [batch_size, self.inner_dim, seq_len]) + hs = ops.permute(hs, [0, 2, 1]) + + temb, timestep_proj, text_emb = 
self.condition_embedder( + timestep, encoder_hidden_states + ) + return hs, temb, timestep_proj, text_emb + + +class WanTransformerPostProcess(Module): + """Output modulation + unpatchify (compiled separately).""" + + def __init__( + self, + config: WanConfigBase, + *, + dtype: DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + dim = config.num_attention_heads * config.attention_head_dim + self.inner_dim = dim + self.out_channels = config.out_channels + self.patch_size = config.patch_size + + self.scale_shift_table = Weight( + "scale_shift_table", dtype, [1, 2, dim], device + ) + self.norm_out = WanLayerNorm( + dim, + eps=config.eps, + elementwise_affine=False, + dtype=dtype, + device=device, + ) + self.proj_out = Linear( + in_dim=dim, + out_dim=config.out_channels * prod(config.patch_size), + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__( + self, + hidden_states: TensorValue, + temb: TensorValue, + spatial_shape: TensorValue, + ) -> TensorValue: + batch_size = hidden_states.shape[0] + p_t, p_h, p_w = self.patch_size + ppf = spatial_shape.shape[0] + pph = spatial_shape.shape[1] + ppw = spatial_shape.shape[2] + + mod = self.scale_shift_table + ops.reshape( + temb, [batch_size, 1, self.inner_dim] + ) + shift = mod[:, :1, :] + scale = mod[:, 1:, :] + hs = self.norm_out(hidden_states) * (1.0 + scale) + shift + hs = self.proj_out(hs) + hs = ops.rebind( + hs, + shape=[ + batch_size, + ppf * pph * ppw, + self.out_channels * p_t * p_h * p_w, + ], + ) + + hs = ops.reshape( + hs, + [batch_size, ppf, pph, ppw, p_t, p_h, p_w, self.out_channels], + ) + hs = ops.permute(hs, [0, 7, 1, 4, 2, 5, 3, 6]) + hs = ops.reshape( + hs, + [batch_size, self.out_channels, ppf * p_t, pph * p_h, ppw * p_w], + ) + return ops.cast(hs, DType.bfloat16) + + +class WanTransformer3DModel(Module): + """Full transformer (for reference / single-graph compilation).""" + + def __init__( + self, + config: WanConfigBase, + *, + dtype: 
DType = DType.bfloat16, + device: DeviceRef = DeviceRef.CPU(), + ) -> None: + super().__init__() + self.config = config + dim = config.num_attention_heads * config.attention_head_dim + self.inner_dim = dim + self.num_heads = config.num_attention_heads + self.head_dim = config.attention_head_dim + self.out_channels = config.out_channels + self.patch_size = config.patch_size + + self.patch_embedding = WanConv3d( + kernel_size=config.patch_size, + in_channels=config.in_channels, + out_channels=dim, + stride=config.patch_size, + dtype=dtype, + device=device, + ) + self.condition_embedder = WanTimeTextImageEmbedding( + dim=dim, + freq_dim=config.freq_dim, + text_dim=config.text_dim, + num_layers=config.num_layers, + image_dim=getattr(config, "image_dim", None), + dtype=dtype, + device=device, + ) + self.blocks = LayerList( + [ + WanTransformerBlock( + dim=dim, + ffn_dim=config.ffn_dim, + num_heads=config.num_attention_heads, + head_dim=config.attention_head_dim, + text_dim=dim, + cross_attn_norm=config.cross_attn_norm, + eps=config.eps, + dtype=dtype, + device=device, + ) + for _ in range(config.num_layers) + ] + ) + self.scale_shift_table = Weight( + "scale_shift_table", dtype, [1, 2, dim], device + ) + self.norm_out = WanLayerNorm( + dim, + eps=config.eps, + elementwise_affine=False, + dtype=dtype, + device=device, + ) + self.proj_out = Linear( + in_dim=dim, + out_dim=config.out_channels * prod(config.patch_size), + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__( + self, + hidden_states: TensorValue, + timestep: TensorValue, + encoder_hidden_states: TensorValue, + rope_cos: TensorValue, + rope_sin: TensorValue, + ) -> TensorValue: + batch_size = hidden_states.shape[0] + orig_T = hidden_states.shape[2] + orig_H = hidden_states.shape[3] + orig_W = hidden_states.shape[4] + p_t, p_h, p_w = self.patch_size + ppf = orig_T // p_t + pph = orig_H // p_h + ppw = orig_W // p_w + + hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) + hs = 
self.patch_embedding(hs) + hs = ops.permute(hs, [0, 4, 1, 2, 3]) + hs = ops.reshape(hs, [batch_size, self.inner_dim, ppf * pph * ppw]) + hs = ops.permute(hs, [0, 2, 1]) + + temb, timestep_proj, text_emb = self.condition_embedder( + timestep, encoder_hidden_states + ) + + # Rebind RoPE to match the sequence length derived from spatial dims. + seq_len = ppf * pph * ppw + rope_cos = ops.rebind(rope_cos, shape=[seq_len, self.head_dim]) + rope_sin = ops.rebind(rope_sin, shape=[seq_len, self.head_dim]) + + for block in self.blocks: + hs = block(hs, text_emb, timestep_proj, rope_cos, rope_sin) + + mod = self.scale_shift_table + ops.reshape( + temb, [batch_size, 1, self.inner_dim] + ) + shift = mod[:, :1, :] + scale = mod[:, 1:, :] + hs = self.norm_out(hs) * (1.0 + scale) + shift + hs = self.proj_out(hs) + + hs = ops.reshape( + hs, + [batch_size, ppf, pph, ppw, p_t, p_h, p_w, self.out_channels], + ) + hs = ops.permute(hs, [0, 7, 1, 4, 2, 5, 3, 6]) + hs = ops.reshape( + hs, + [batch_size, self.out_channels, ppf * p_t, pph * p_h, ppw * p_w], + ) + return ops.cast(hs, self.config.dtype)