From 23b1c559cf9d4d88d5de855d352674108a4ebe7a Mon Sep 17 00:00:00 2001 From: byungchul-sqzb Date: Wed, 1 Apr 2026 02:42:36 +0000 Subject: [PATCH] [Pipelines] Refactor and optimize Z-Image modulev3 pipeline - Fix autoencoder import and image postprocessing - Absorb eager F.mul negate into compiled scheduler_step for z_image - Add batched CFG for Z-Image modulev3 pipeline - Optimize fused decode, scheduler caching, and eager reduction - Apply RoPE micro-optimizations stack-info: PR: https://github.com/SqueezeBits/modular/pull/20, branch: byungchul-sqzb/stack/1 --- .../architectures/z_image_modulev3/arch.py | 4 +- .../z_image_modulev3/layers/attention.py | 62 ++- .../z_image_modulev3/layers/embeddings.py | 49 +- .../z_image_modulev3/pipeline_z_image.py | 507 ++++++++++-------- .../z_image_modulev3/weight_adapters.py | 1 + .../architectures/z_image_modulev3/z_image.py | 38 +- 6 files changed, 355 insertions(+), 306 deletions(-) diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/arch.py b/max/python/max/pipelines/architectures/z_image_modulev3/arch.py index f82078d4206..9b5f578cb93 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/arch.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/arch.py @@ -49,10 +49,10 @@ def initialize( name="ZImagePipeline", task=PipelineTask.PIXEL_GENERATION, default_encoding="bfloat16", - supported_encodings={"bfloat16"}, + supported_encodings={"bfloat16", "float32"}, example_repo_ids=[ "Tongyi-MAI/Z-Image", - "Zyphra/Z-Image", + "Tongyi-MAI/Z-Image-Turbo", ], pipeline_model=ZImagePipeline, # type: ignore[arg-type] context_type=PixelContext, diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/layers/attention.py b/max/python/max/pipelines/architectures/z_image_modulev3/layers/attention.py index 93a54e0f98d..f510b908b3d 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/layers/attention.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/layers/attention.py @@ -13,17 +13,49 @@ import math +from max.dtype import DType from max.experimental import functional as F from max.experimental.nn import Linear, Module from max.experimental.nn.norm import RMSNorm -from max.experimental.nn.sequential import ModuleList from max.experimental.tensor import Tensor from max.nn.attention.mask_config import MHAMaskVariant from max.nn.kernels import flash_attention_gpu as _flash_attention_gpu - -from ...flux2_modulev3.layers.embeddings import apply_rotary_emb +from max.nn.kernels import ( + rope_ragged_with_position_ids as _rope_ragged_with_position_ids, +) flash_attention_gpu = F.functional(_flash_attention_gpu) +rope_ragged_with_position_ids = F.functional(_rope_ragged_with_position_ids) + + +def _apply_zimage_qk_rope( + query: Tensor, + key: Tensor, + freqs_cis: Tensor, +) -> tuple[Tensor, Tensor]: + """Apply RoPE using precomputed interleaved [cos, sin] frequencies.""" + batch_size = query.shape[0] + seq_len = query.shape[1] + num_heads = query.shape[2] + head_dim = query.shape[3] + + query_ragged = F.reshape(query, [batch_size * seq_len, num_heads, head_dim]) + key_ragged = F.reshape(key, [batch_size * seq_len, num_heads, head_dim]) + + position_ids = F.arange(0, seq_len, dtype=DType.uint32, device=query.device) + position_ids = F.broadcast_to(position_ids[None, :], [batch_size, seq_len]) + position_ids = F.reshape(position_ids, [batch_size * seq_len]) + + query_out = rope_ragged_with_position_ids( + query_ragged, freqs_cis, position_ids, interleaved=True + ) + key_out = rope_ragged_with_position_ids( + key_ragged, freqs_cis, position_ids, interleaved=True + ) + return ( + F.reshape(query_out, [batch_size, seq_len, num_heads, head_dim]), + F.reshape(key_out, [batch_size, seq_len, num_heads, head_dim]), + ) class ZImageAttention(Module[..., Tensor]): @@ -45,13 +77,12 @@ def __init__( self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None - # Keep ModuleList naming for diffusers-compatible key loading. - self.to_out = ModuleList([Linear(dim, dim, bias=False)]) + self.to_out = Linear(dim, dim, bias=False) def forward( self, hidden_states: Tensor, - freqs_cis: tuple[Tensor, Tensor], + freqs_cis: Tensor, ) -> Tensor: batch_size = hidden_states.shape[0] seq_len = hidden_states.shape[1] @@ -73,22 +104,7 @@ def forward( if self.norm_k is not None: key = self.norm_k(key) - query = apply_rotary_emb( - query, - freqs_cis, - use_real=True, - use_real_unbind_dim=-1, - sequence_dim=1, - ) - key = apply_rotary_emb( - key, - freqs_cis, - use_real=True, - use_real_unbind_dim=-1, - sequence_dim=1, - ) - query = query.cast(value.dtype) - key = key.cast(value.dtype) + query, key = _apply_zimage_qk_rope(query, key, freqs_cis) out = flash_attention_gpu( query, @@ -99,4 +115,4 @@ def forward( ) out = F.reshape(out, [batch_size, seq_len, self.inner_dim]) - return self.to_out[0](out) + return self.to_out(out) diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/layers/embeddings.py b/max/python/max/pipelines/architectures/z_image_modulev3/layers/embeddings.py index f7a925cd18e..35392026ba5 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/layers/embeddings.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/layers/embeddings.py @@ -18,8 +18,6 @@ from max.experimental.nn import Linear, Module from max.experimental.tensor import Tensor -from ...flux2_modulev3.layers.embeddings import get_1d_rotary_pos_embed - class TimestepEmbedder(Module[[Tensor], Tensor]): def __init__( @@ -67,7 +65,21 @@ def forward(self, t: Tensor) -> Tensor: return t_emb -class RopeEmbedder(Module[[Tensor], tuple[Tensor, Tensor]]): +def _get_1d_rope_interleaved( + dim: int, + pos: Tensor, + theta: float = 10000.0, +) -> Tensor: + """Compute 1-D RoPE in [cos, sin] interleaved pair format.""" + half = dim // 2 + freq_exp = F.arange(0, half, dtype=DType.float32, device=pos.device) / half + freq = 1.0 / (theta**freq_exp) + freqs = F.outer(pos, freq) + paired = F.stack([F.cos(freqs), F.sin(freqs)], axis=2) + return F.reshape(paired, [freqs.shape[0], dim]) + + +class RopeEmbedder(Module[[Tensor], Tensor]): def __init__( self, theta: float = 256.0, @@ -76,28 +88,15 @@ def __init__( self.theta = theta self.axes_dims = axes_dims - def forward(self, ids: Tensor) -> tuple[Tensor, Tensor]: - if ids.rank != 2: - raise ValueError(f"Expected 2D ids tensor, got rank={ids.rank}") - - if int(ids.shape[-1]) != len(self.axes_dims): - raise ValueError( - "ids last dimension must match axes_dims length " - f"({len(self.axes_dims)}), got {ids.shape[-1]}" - ) - + def forward(self, ids: Tensor) -> Tensor: pos = ids.cast(DType.float32) - cos_out = [] - sin_out = [] + parts = [] for i in range(len(self.axes_dims)): - cos_i, sin_i = get_1d_rotary_pos_embed( - self.axes_dims[i], - pos[:, i], - theta=self.theta, - use_real=True, - repeat_interleave_real=True, + parts.append( + _get_1d_rope_interleaved( + self.axes_dims[i], + pos[:, i], + theta=self.theta, + ) ) - cos_out.append(cos_i) - sin_out.append(sin_i) - - return F.concat(cos_out, axis=-1), F.concat(sin_out, axis=-1) + return F.concat(parts, axis=-1) diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/pipeline_z_image.py b/max/python/max/pipelines/architectures/z_image_modulev3/pipeline_z_image.py index f3225a860c3..d1797514c36 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/pipeline_z_image.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/pipeline_z_image.py @@ -22,15 +22,16 @@ import hashlib from dataclasses import MISSING, dataclass, field, fields -from typing import Any, Literal +from typing import Any, Literal, cast import numpy as np import numpy.typing as npt -from max.driver import CPU, Buffer, Device +from max.driver import Buffer, Device from max.dtype import DType from max.experimental import functional as F +from max.experimental.nn import Module as ExpModule from max.experimental.tensor import Tensor -from max.graph import TensorType +from max.graph import DeviceRef, TensorType from max.interfaces import TokenBuffer from max.pipelines.core import PixelContext from max.pipelines.lib.interfaces import ( @@ -40,7 +41,7 @@ from max.pipelines.lib.interfaces.diffusion_pipeline import max_compile from max.profiler import Tracer, traced -from ..autoencoders import AutoencoderKLModel +from ..autoencoders_modulev3 import AutoencoderKLModel from ..qwen3_modulev3.text_encoder import Qwen3TextEncoderZImageModel from .model import ZImageTransformerModel @@ -254,9 +255,9 @@ def init_remaining_components(self) -> None: ) self.build_preprocess_latents() - self.build_prepare_scheduler() self.build_scheduler_step() self.build_decode_latents() + self.build_batched_cfg_ops() self._init_cache_state( dtype=self.transformer.config.dtype, @@ -268,10 +269,9 @@ def init_remaining_components(self) -> None: self._cached_img_ids: dict[str, Tensor] = {} self._cached_img_ids_base_np: dict[str, np.ndarray] = {} self._cached_shape_carriers: dict[int, Tensor] = {} - self._cached_timesteps_batched: dict[str, Tensor] = {} - self._cached_timesteps_host: dict[str, np.ndarray] = {} self._cached_prompt_token_tensors: dict[str, Tensor] = {} self._cached_prompt_padding: dict[str, Tensor] = {} + self._cached_step_scalar_tensors: dict[str, list[Tensor]] = {} @traced(message="ZImagePipeline.prepare_inputs") def prepare_inputs( @@ -444,8 +444,8 @@ def run_transformer( def build_preprocess_latents(self) -> None: device = self.transformer.devices[0] - self.__dict__["_pack_latents_from_6d"] = max_compile( - self._pack_latents_from_6d, + self.__dict__["_patchify_and_pack"] = max_compile( + self._patchify_and_pack, input_types=[ TensorType( DType.float32, @@ -453,27 +453,13 @@ def build_preprocess_latents(self) -> None: "batch", "channels", "height", - 2, "width", - 2, ], device=device, ), ], ) - def build_prepare_scheduler(self) -> None: - self.__dict__["prepare_scheduler"] = max_compile( - self.prepare_scheduler, - input_types=[ - TensorType( - DType.float32, - shape=["num_sigmas"], - device=self.transformer.devices[0], - ), - ], - ) - def build_scheduler_step(self) -> None: dtype = self.transformer.config.dtype device = self.transformer.devices[0] @@ -490,27 +476,100 @@ def build_scheduler_step(self) -> None: ], ) + @traced(message="ZImagePipeline.build_decode_latents") def build_decode_latents(self) -> None: + device = self.transformer.devices[0] + dtype = self.transformer.config.dtype + device_ref = DeviceRef.from_device(device) + + fused_weights: dict[str, Any] = {} + for key, value in self.vae.weights.items(): + adapted_key = key + while adapted_key.startswith(("vae.", "model.")): + if adapted_key.startswith("vae."): + adapted_key = adapted_key.removeprefix("vae.") + continue + adapted_key = adapted_key.removeprefix("model.") + + weight_data = value.data() + if weight_data.dtype != dtype: + if weight_data.dtype.is_float() and dtype.is_float(): + weight_data = weight_data.astype(dtype) + + if adapted_key.startswith("decoder."): + fused_weights[adapted_key] = weight_data + elif adapted_key.startswith("post_quant_conv."): + fused_weights[f"decoder.{adapted_key}"] = weight_data + + from ..autoencoders_modulev3.autoencoder_kl import AutoencoderKL + + with F.lazy(): + autoencoder = AutoencoderKL(self.vae.config) + fused = _PostprocessAndDecodeKL( + decoder=autoencoder.decoder, + scaling_factor=float(self.vae.config.scaling_factor), + shift_factor=float(self.vae.config.shift_factor or 0.0), + device=device_ref, + dtype=dtype, + ) + fused.to(device) + self._fused_decode = fused.compile( + *fused.input_types(), weights=fused_weights + ) + + def build_batched_cfg_ops(self) -> None: dtype = self.transformer.config.dtype device = self.transformer.devices[0] - self.__dict__["_postprocess_latents"] = max_compile( - self._postprocess_latents, + self.__dict__["duplicate_batch"] = max_compile( + self.duplicate_batch, input_types=[ TensorType( dtype, - shape=[ - "batch", - "half_h", - "half_w", - 2, - 2, - "ch_4", - ], + shape=["batch", "seq", "channels"], + device=device, + ), + ], + ) + self.__dict__["cfg_finalize_batched"] = max_compile( + self.cfg_finalize_batched, + input_types=[ + TensorType( + dtype, + shape=["double_batch", "seq", "channels"], device=device, ), + TensorType(DType.float32, shape=[], device=device), ], ) + @staticmethod + def duplicate_batch(x: Tensor) -> Tensor: + """Duplicate batch: [B, S, C] -> [2B, S, C] via broadcast.""" + batch = x.shape[0] + seq = x.shape[1] + channels = x.shape[2] + x = F.unsqueeze(x, axis=0) + x = F.broadcast_to(x, [2, batch, seq, channels]) + return F.reshape(x, [batch * 2, seq, channels]) + + @staticmethod + def cfg_finalize_batched( + noise_pred_cfg: Tensor, + guidance_scale: Tensor, + ) -> tuple[Tensor, Tensor]: + """Split [2B,S,C] into pos/neg, apply CFG, return (pos, result).""" + batch2 = noise_pred_cfg.shape[0] + batch = batch2 // 2 + seq = noise_pred_cfg.shape[1] + channels = noise_pred_cfg.shape[2] + pos = F.rebind(noise_pred_cfg[:batch], [batch, seq, channels]) + neg = F.rebind(noise_pred_cfg[batch:], [batch, seq, channels]) + input_dtype = pos.dtype + diff = pos - neg + scaled = guidance_scale * diff + result = (pos + scaled).cast(input_dtype) + return pos, result + @staticmethod def _pack_latents(latents: Tensor) -> Tensor: batch_size, num_channels, height, width = map(int, latents.shape) @@ -537,45 +596,34 @@ def _pack_latents(latents: Tensor) -> Tensor: return latents @staticmethod - def _pack_latents_from_6d(latents: Tensor) -> Tensor: + def _patchify_and_pack(latents: Tensor) -> Tensor: batch_size = latents.shape[0] num_channels = latents.shape[1] height = latents.shape[2] - width = latents.shape[4] - latents = F.permute(latents, (0, 2, 4, 3, 5, 1)) + width = latents.shape[3] + latents = F.rebind( + latents, + [batch_size, num_channels, (height // 2) * 2, (width // 2) * 2], + ) latents = F.reshape( latents, ( batch_size, - height * width, - num_channels * 4, + num_channels, + height // 2, + 2, + width // 2, + 2, ), ) - return latents - - @staticmethod - def _unpack_latents( - latents: Tensor, - height: int, - width: int, - vae_scale_factor: int, - ) -> Tensor: - batch_size = int(latents.shape[0]) - ch_size = int(latents.shape[2]) - - height = 2 * (height // (vae_scale_factor * 2)) - width = 2 * (width // (vae_scale_factor * 2)) - - h2 = height // 2 - w2 = width // 2 - latents = F.reshape( - latents, - (batch_size, h2, w2, 2, 2, ch_size // 4), - ) - latents = F.permute(latents, (0, 5, 1, 3, 2, 4)) + latents = F.permute(latents, (0, 2, 4, 3, 5, 1)) latents = F.reshape( latents, - (batch_size, ch_size // 4, height, width), + ( + batch_size, + (height // 2) * (width // 2), + num_channels * 4, + ), ) return latents @@ -749,48 +797,11 @@ def decode_latents( output_type: Literal["np", "latent", "pil"] = "np", ) -> Tensor | np.ndarray: """Decode packed latents into image output.""" - latent_h = int(h_carrier.shape[0]) * 2 - latent_w = int(w_carrier.shape[0]) * 2 if output_type == "latent": return latents - batch_size = int(latents.shape[0]) - ch_size = int(latents.shape[2]) - latents = F.reshape( - latents, - ( - batch_size, - latent_h // 2, - latent_w // 2, - 2, - 2, - ch_size // 4, - ), - ) - - latents = self._postprocess_latents(latents) - decoded: Tensor = self.vae.decode(latents) - return self._to_numpy(decoded) - - def _postprocess_latents(self, latents: Tensor) -> Tensor: - batch_size = latents.shape[0] - half_h = latents.shape[1] - half_w = latents.shape[2] - c_quarter = latents.shape[5] - - latents = F.permute(latents, (0, 5, 1, 3, 2, 4)) - latents = F.reshape( - latents, (batch_size, c_quarter, half_h * 2, half_w * 2) - ) - latents = (latents / float(self.vae.config.scaling_factor)) + float( - self.vae.config.shift_factor or 0.0 - ) - return latents - - @staticmethod - def _to_numpy(image: Tensor) -> np.ndarray: - cpu_image: Tensor = image.cast(DType.float32).to(CPU()) - return np.from_dlpack(cpu_image) + decoded = self._fused_decode(latents, h_carrier, w_carrier) + return np.from_dlpack(decoded) @staticmethod def _vector_norm_per_sample(x: Tensor) -> Tensor: @@ -836,7 +847,7 @@ def scheduler_step( ) -> Tensor: latents_dtype = latents.dtype latents = latents.cast(DType.float32) - latents = latents + dt * noise_pred + latents = latents - dt * noise_pred latents = latents.cast(latents_dtype) return latents @@ -858,19 +869,7 @@ def preprocess_latents(self, latents: Tensor, dtype: DType) -> Tensor: ) with Tracer("patchify_and_pack"): - batch, channels, height, width = map(int, latents.shape) - latents = F.reshape( - latents, - ( - batch, - channels, - height // 2, - 2, - width // 2, - 2, - ), - ) - latents = self._pack_latents_from_6d(latents) + latents = self._patchify_and_pack(latents) return latents.cast(dtype) @@ -941,73 +940,27 @@ def prepare_img2img_latents( latents = sigma * noise_latents + (1.0 - sigma) * image_latents return latents.cast(noise_latents.dtype) - def _prepare_timestep_broadcast( + def _get_cached_step_scalar_tensors( self, - timesteps: np.ndarray, + values: np.ndarray, + key_prefix: str, device: Device, - cache_key: str | None = None, - ) -> tuple[Tensor, np.ndarray]: - transformed_timesteps = (1.0 - timesteps).astype(np.float32, copy=False) - - if cache_key is None: - num_timesteps = int(transformed_timesteps.shape[0]) - first_t = ( - float(transformed_timesteps[0]) if num_timesteps > 0 else 0.0 - ) - last_t = ( - float(transformed_timesteps[-1]) if num_timesteps > 0 else 0.0 - ) - cache_key = ( - f"timesteps::{num_timesteps}::{first_t:.8f}::{last_t:.8f}" - ) - - if ( - cache_key in self._cached_timesteps_batched - and cache_key in self._cached_timesteps_host - ): - return ( - self._cached_timesteps_batched[cache_key], - self._cached_timesteps_host[cache_key], + ) -> list[Tensor]: + values = np.ascontiguousarray(values.astype(np.float32, copy=False)) + digest = hashlib.sha1(values.tobytes()).hexdigest() + key = f"{key_prefix}::{values.shape[0]}::{digest}::{device}" + if key in self._cached_step_scalar_tensors: + return self._cached_step_scalar_tensors[key] + tensors = [ + Tensor( + storage=Buffer.from_dlpack( + np.array([float(v)], dtype=np.float32) + ).to(device) ) - - transformed_timesteps = np.ascontiguousarray(transformed_timesteps) - - timesteps_tensor = Tensor( - storage=Buffer.from_dlpack(transformed_timesteps).to(device) - ) - self._cached_timesteps_batched[cache_key] = timesteps_tensor - self._cached_timesteps_host[cache_key] = transformed_timesteps - return timesteps_tensor, transformed_timesteps - - def _prepare_scheduler_inputs( - self, - model_inputs: ZImageModelInputs, - sigmas: Tensor, - device: Device, - ) -> tuple[Any, Any, np.ndarray]: - _, all_dts = self.prepare_scheduler(sigmas) - dts_seq: Any = all_dts - if hasattr(dts_seq, "driver_tensor"): - dts_seq = dts_seq.driver_tensor - - timesteps = model_inputs.timesteps - num_timesteps = timesteps.shape[0] - timesteps_key = ( - f"timesteps::{num_timesteps}::{model_inputs.height}x" - f"{model_inputs.width}::{int(model_inputs.input_image is not None)}::" - f"{model_inputs.num_inference_steps}::{model_inputs.strength:.4f}::" - f"{float(getattr(self, '_scheduler_shift', 1.0)):.4f}" - ) - timesteps_seq, transformed_timesteps = self._prepare_timestep_broadcast( - timesteps=timesteps, - device=device, - cache_key=timesteps_key, - ) - timesteps_seq_any: Any = timesteps_seq - if hasattr(timesteps_seq_any, "driver_tensor"): - timesteps_seq_any = timesteps_seq_any.driver_tensor - - return timesteps_seq_any, dts_seq, transformed_timesteps + for v in values + ] + self._cached_step_scalar_tensors[key] = tensors + return tensors @traced(message="ZImagePipeline.execute") def execute( # type: ignore[override] @@ -1033,11 +986,10 @@ def execute( # type: ignore[override] tokens=model_inputs.negative_tokens_tensor, num_images_per_prompt=model_inputs.num_images_per_prompt, ) - if not model_inputs.explicit_negative_prompt: - negative_prompt_embeds = self._align_prompt_seq_len( - negative_prompt_embeds, - int(prompt_embeds.shape[1]), - ) + negative_prompt_embeds = self._align_prompt_seq_len( + negative_prompt_embeds, + int(prompt_embeds.shape[1]), + ) dtype = prompt_embeds.dtype latents = model_inputs.latents_tensor @@ -1082,13 +1034,20 @@ def execute( # type: ignore[override] # 3) Prepare scheduler tensors. with Tracer("prepare_scheduler"): - timesteps_seq, dts_seq, transformed_timesteps = ( - self._prepare_scheduler_inputs( - model_inputs=model_inputs, - sigmas=sigmas, - device=device, - ) + transformed_timesteps = np.ascontiguousarray( + (1.0 - model_inputs.timesteps).astype(np.float32, copy=False) + ) + timestep_scalars = self._get_cached_step_scalar_tensors( + transformed_timesteps, + key_prefix=( + f"step_t::{model_inputs.height}x{model_inputs.width}::" + f"{model_inputs.num_inference_steps}" + ), + device=device, ) + sigmas_host = np.asarray(model_inputs.sigmas, dtype=np.float32) + dt_values = np.ascontiguousarray(sigmas_host[1:] - sigmas_host[:-1]) + dts_seq = Buffer.from_dlpack(dt_values).to(device) cfg_active: np.ndarray | None = None if model_inputs.do_cfg: @@ -1099,11 +1058,39 @@ def execute( # type: ignore[override] else: cfg_active = np.ones(num_timesteps, dtype=np.bool_) + # Prepare batched CFG inputs (pos + neg concat). + use_batched_cfg = bool(model_inputs.do_cfg) + cfg_prompt_embeds: Tensor | None = None + cfg_timesteps: list[Tensor] | None = None + guidance_scale_tensor: Tensor | None = None + if use_batched_cfg: + assert negative_prompt_embeds is not None + cfg_prompt_embeds = F.concat( + [prompt_embeds, negative_prompt_embeds], axis=0 + ) + cfg_timesteps = [ + Tensor( + storage=Buffer.from_dlpack( + np.full( + (2 * batch_size,), + float(transformed_timesteps[i]), + dtype=np.float32, + ) + ).to(device) + ) + for i in range(num_timesteps) + ] + guidance_scale_tensor = Tensor( + storage=Buffer.from_dlpack( + np.array(model_inputs.guidance_scale, dtype=np.float32) + ).to(device) + ) + # 4) Denoising loop. with Tracer("denoising_loop"): for i in range(num_timesteps): with Tracer(f"denoising_step_{i}"): - timestep = timesteps_seq[i : i + 1] + timestep = timestep_scalars[i] apply_cfg = bool( model_inputs.do_cfg and cfg_active is not None @@ -1113,58 +1100,44 @@ def execute( # type: ignore[override] model_inputs.guidance_scale if apply_cfg else 0.0 ) - with Tracer("transformer"): - noise_pred = self.run_denoising_step( - step=i, - cache_state=cache_pos, - device=device, - latents=latents, - prompt_embeds=prompt_embeds, - timestep=timestep, - img_ids=img_ids, - txt_ids=txt_ids, - ) - - if apply_cfg: - assert negative_prompt_embeds is not None - assert cache_neg is not None - neg_img_ids = img_ids - neg_txt_ids = txt_ids - if model_inputs.explicit_negative_prompt: - assert ( - model_inputs.negative_img_ids_tensor is not None - ) - assert ( - model_inputs.negative_txt_ids_tensor is not None - ) - neg_img_ids = model_inputs.negative_img_ids_tensor - neg_txt_ids = model_inputs.negative_txt_ids_tensor - with Tracer("cfg_transformer"): - neg_noise_pred = self.run_denoising_step( - step=i, - cache_state=cache_neg, - device=device, - latents=latents, - prompt_embeds=negative_prompt_embeds, - timestep=timestep, - img_ids=neg_img_ids, - txt_ids=neg_txt_ids, - ) - pos_noise_pred = noise_pred - noise_delta = F.sub(noise_pred, neg_noise_pred) - noise_pred = F.add( - pos_noise_pred, - F.mul(noise_delta, current_guidance_scale), + if apply_cfg and use_batched_cfg: + # Single transformer call with batch=2B. + assert cfg_prompt_embeds is not None + assert cfg_timesteps is not None + assert guidance_scale_tensor is not None + with Tracer("transformer"): + latents_cfg = self.duplicate_batch(latents) + noise_pred_cfg = self.run_transformer( + cache_pos, + latents=latents_cfg, + prompt_embeds=cfg_prompt_embeds, + timestep=cfg_timesteps[i], + img_ids=img_ids, + txt_ids=txt_ids, + )[0] + pos_noise_pred, noise_pred = self.cfg_finalize_batched( + noise_pred_cfg, guidance_scale_tensor ) noise_pred = self._apply_cfg_renormalization( pos_noise_pred, noise_pred, model_inputs.cfg_normalization, ) + else: + with Tracer("transformer"): + noise_pred = self.run_denoising_step( + step=i, + cache_state=cache_pos, + device=device, + latents=latents, + prompt_embeds=prompt_embeds, + timestep=timestep, + img_ids=img_ids, + txt_ids=txt_ids, + ) with Tracer("scheduler_step"): - noise_pred = F.mul(noise_pred, -1.0) - dt = dts_seq[i : i + 1] + dt = cast(Tensor, dts_seq[i : i + 1]) latents = self.scheduler_step(latents, noise_pred, dt) with Tracer("decode_outputs"): @@ -1176,3 +1149,69 @@ def execute( # type: ignore[override] ) return ZImagePipelineOutput(images=outputs) + + +class _PostprocessAndDecodeKL(ExpModule[..., Tensor]): + """Fused unpack + denorm + decode + uint8 for z-image latents.""" + + def __init__( + self, + decoder: Any, + scaling_factor: float, + shift_factor: float, + *, + device: DeviceRef, + dtype: DType, + ) -> None: + super().__init__() + self.decoder = decoder + self.scaling_factor = scaling_factor + self.shift_factor = shift_factor + self._device = device + self._dtype = dtype + + def forward( + self, + latents_bsc: Tensor, + h_carrier: Tensor, + w_carrier: Tensor, + ) -> Tensor: + batch = latents_bsc.shape[0] + c = latents_bsc.shape[2] + half_h = h_carrier.shape[0] + half_w = w_carrier.shape[0] + + latents_bsc = F.rebind(latents_bsc, [batch, half_h * half_w, c]) + latents = F.reshape(latents_bsc, (batch, half_h, half_w, c)) + latents = F.rebind(latents, [batch, half_h, half_w, (c // 4) * 4]) + latents = F.reshape(latents, (batch, half_h, half_w, 2, 2, c // 4)) + latents = F.permute(latents, (0, 5, 1, 3, 2, 4)) + latents = F.reshape(latents, (batch, c // 4, half_h * 2, half_w * 2)) + latents = (latents / self.scaling_factor) + self.shift_factor + + decoded = self.decoder(latents, None) + decoded = F.permute(decoded, (0, 2, 3, 1)) + decoded = decoded * 0.5 + 0.5 + decoded = F.max(decoded, 0.0) + decoded = F.min(decoded, 1.0) + decoded = decoded * 255.0 + return F.transfer_to(F.cast(decoded, DType.uint8), DeviceRef.CPU()) + + def input_types(self) -> tuple[TensorType, ...]: + return ( + TensorType( + self._dtype, + shape=["batch", "seq", "channels"], + device=self._device, + ), + TensorType( + DType.float32, + shape=["half_h"], + device=self._device, + ), + TensorType( + DType.float32, + shape=["half_w"], + device=self._device, + ), + ) diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/weight_adapters.py b/max/python/max/pipelines/architectures/z_image_modulev3/weight_adapters.py index 73080b7aada..5a45852d219 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/weight_adapters.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/weight_adapters.py @@ -46,6 +46,7 @@ def convert_z_image_transformer_state_dict( key = _replace_prefix(key, "cap_embedder.0.", "cap_norm.") key = _replace_prefix(key, "cap_embedder.1.", "cap_proj.") key = key.replace("adaLN_modulation.0.", "adaLN_modulation.") + key = key.replace(".to_out.0.", ".to_out.") key = _replace_prefix( key, "final_layer.adaLN_modulation.1.", diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/z_image.py b/max/python/max/pipelines/architectures/z_image_modulev3/z_image.py index dc74d661a6f..7518daf0583 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/z_image.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/z_image.py @@ -60,6 +60,7 @@ def __init__( self.layer_id = layer_id self.modulation = modulation + self.dim = dim self.attention = ZImageAttention( dim=dim, @@ -84,7 +85,7 @@ def __init__( def forward( self, x: Tensor, - freqs_cis: tuple[Tensor, Tensor], + freqs_cis: Tensor, adaln_input: Tensor | None = None, ) -> Tensor: if self.modulation: @@ -93,14 +94,12 @@ def forward( if self.adaLN_modulation is None: raise ValueError("adaLN_modulation is not initialized") - mod = self.adaLN_modulation(adaln_input) - mod = F.unsqueeze(mod, 1) - scale_msa, gate_msa, scale_mlp, gate_mlp = F.chunk(mod, 4, axis=2) - - gate_msa = F.tanh(gate_msa) - gate_mlp = F.tanh(gate_mlp) - scale_msa = 1.0 + scale_msa - scale_mlp = 1.0 + scale_mlp + mod = F.unsqueeze(self.adaLN_modulation(adaln_input), 1) + d = self.dim + scale_msa = 1.0 + mod[:, :, :d] + gate_msa = F.tanh(mod[:, :, d : 2 * d]) + scale_mlp = 1.0 + mod[:, :, 2 * d : 3 * d] + gate_mlp = F.tanh(mod[:, :, 3 * d :]) attn_out = self.attention( self.attention_norm1(x) * scale_msa, @@ -332,24 +331,19 @@ def _forward_preamble( timestep: Tensor, img_ids: Tensor, txt_ids: Tensor, - ) -> tuple[Tensor, Any, Tensor, tuple[Tensor, Tensor]]: + ) -> tuple[Tensor, Any, Tensor, Tensor]: """Embed inputs, run refiners, return unified seq before main ``layers[0]``.""" x = self.x_embedder(hidden_states) t_emb = self.t_embedder(timestep * self.t_scale).cast(x.dtype) cap = self.cap_proj(self.cap_norm(encoder_hidden_states)) - if txt_ids.rank == 3: - txt_ids = txt_ids[0] - if img_ids.rank == 3: - img_ids = img_ids[0] + img_seq_len = img_ids.shape[0] + unified_ids = F.concat([img_ids, txt_ids], axis=0) + unified_freqs = self.rope_embedder(unified_ids).cast(x.dtype) - txt_freqs = self.rope_embedder(txt_ids) - img_freqs = self.rope_embedder(img_ids) - unified_freqs = ( - F.concat([img_freqs[0], txt_freqs[0]], axis=0), - F.concat([img_freqs[1], txt_freqs[1]], axis=0), - ) + img_freqs = unified_freqs[:img_seq_len] + txt_freqs = unified_freqs[img_seq_len:] for layer in self.noise_refiner: x = layer(x, freqs_cis=img_freqs, adaln_input=t_emb) @@ -365,7 +359,7 @@ def _run_first_main_layer( self, unified0: Tensor, t_emb: Tensor, - unified_freqs: tuple[Tensor, Tensor], + unified_freqs: Tensor, ) -> Tensor: return self.layers[0]( unified0, @@ -379,7 +373,7 @@ def _run_remaining_after_first( *, img_len: Any, t_emb: Tensor, - freqs_cis: tuple[Tensor, Tensor], + freqs_cis: Tensor, ) -> Tensor: u = unified for i in range(1, len(self.layers)):