From dbc1f3cc1ea74be3caecccbe13c248ef44381417 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Fri, 8 May 2026 00:26:21 +0530 Subject: [PATCH] feat: add KV caching support for Wan and VACE models, fix pyconfig getattr and fix jax tree map --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/configs/base_wan_1_3b.yml | 1 + src/maxdiffusion/configs/base_wan_27b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 4 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 4 +- src/maxdiffusion/generate_wan.py | 6 +- src/maxdiffusion/models/attention_flax.py | 237 +++++++++++++---- src/maxdiffusion/models/embeddings_flax.py | 9 +- .../models/modeling_flax_utils.py | 2 +- .../wan/transformers/transformer_wan.py | 240 ++++++++++++++++-- .../wan/transformers/transformer_wan_vace.py | 83 +++++- .../pipelines/wan/wan_pipeline.py | 103 ++++++-- .../pipelines/wan/wan_pipeline_2_1.py | 71 +++++- .../pipelines/wan/wan_pipeline_2_2.py | 198 ++++++++++++--- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 62 ++++- .../pipelines/wan/wan_pipeline_i2v_2p2.py | 194 ++++++++++++-- .../pipelines/wan/wan_vace_pipeline_2_1.py | 15 ++ src/maxdiffusion/pyconfig.py | 2 +- src/maxdiffusion/tests/wan_kv_cache_test.py | 217 ++++++++++++++++ 19 files changed, 1281 insertions(+), 170 deletions(-) create mode 100644 src/maxdiffusion/tests/wan_kv_cache_test.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index c2c83c9f7..f432928aa 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -355,6 +355,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False +use_kv_cache: False use_magcache: False magcache_thresh: 0.12 magcache_K: 2 diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 1fd384eb1..0e0552656 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -301,6 +301,7 @@ flow_shift: 3.0 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) use_cfg_cache: False +use_kv_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 1ce67a3cf..bf29fa867 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -331,7 +331,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False - +use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass # when predicted output change (based on accumulated latent/timestep drift) is small use_sen_cache: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 214cf5ce4..ca2d239ab 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -302,7 +302,7 @@ profiler_steps: 10 enable_jax_named_scopes: False # Generation parameters -prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. 
Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. They are raising their left arm for a thumbs up. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True @@ -318,7 +318,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False - +use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) use_sen_cache: False use_magcache: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index d2eb451d4..90799524c 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -303,7 +303,7 @@ profiler_steps: 10 enable_jax_named_scopes: False # Generation parameters -prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. They are raising their left arm for a thumbs up. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." 
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True @@ -330,7 +330,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False - +use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) use_sen_cache: False diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index c0f71c84a..78a58c2d6 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -104,6 +104,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): magcache_thresh=config.magcache_thresh, magcache_K=config.magcache_K, retention_ratio=config.retention_ratio, + use_kv_cache=config.use_kv_cache, ) elif model_key == WAN2_2: return pipeline( @@ -118,6 +119,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, use_sen_cache=config.use_sen_cache, + use_kv_cache=config.use_kv_cache, ) else: raise ValueError(f"Unsupported model_name for I2V in config: {model_key}") @@ -136,6 +138,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): magcache_thresh=config.magcache_thresh, magcache_K=config.magcache_K, retention_ratio=config.retention_ratio, + use_kv_cache=config.use_kv_cache, ) elif model_key == WAN2_2: return pipeline( @@ -149,9 +152,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, use_sen_cache=config.use_sen_cache, + use_kv_cache=config.use_kv_cache, ) else: - raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}") + raise ValueError(f"Unsupported model_name for T2V in config: {model_key}") def inference_generate_video(config, pipeline, filename_prefix=""): diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index ae938b541..6ff4b4fed 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,7 +15,7 @@ import contextlib import functools import math -from typing import Optional, Callable, Tuple +from typing import Optional, Callable, Tuple, Dict import flax.linen as nn from flax import nnx import jax @@ -30,6 +30,8 @@ from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base from einops import rearrange from .. import common_types, max_logging +from maxdiffusion.tpu_utils import get_tpu_type, TpuType + from ..kernels import custom_splash_attention as custom_splash from . import quantizations @@ -677,7 +679,13 @@ def wrap_ulysses_attention(query, key, value): # Restore the original layout expected by the rest of the model: # head-sharded / full-sequence -> sequence-sharded / full-heads. 
- attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + attention_output = jax.lax.all_to_all( + attention_output, + axis_name=axis_name, + split_axis=2, + concat_axis=1, + tiled=True, + ) return attention_output devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) @@ -739,7 +747,11 @@ def _apply_attention_dot( query_chunk_size = int(flatten_latent_dim) hidden_states = jax_memory_efficient_attention( - query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + query_states, + key_states, + value_states, + query_chunk_size=query_chunk_size, + key_chunk_size=4096 * 4, ) hidden_states = hidden_states.transpose(1, 0, 2) @@ -1040,7 +1052,12 @@ def chunk_scanner(chunk_idx): def jax_memory_efficient_attention( - query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 + query, + key, + value, + precision=jax.lax.Precision.HIGHEST, + query_chunk_size: int = 1024, + key_chunk_size: int = 4096, ): r""" Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 @@ -1072,11 +1089,20 @@ def chunk_scanner(chunk_idx, _): return ( chunk_idx + query_chunk_size, # unused ignore it - _query_chunk_attention(query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size), + _query_chunk_attention( + query=query_chunk, + key=key, + value=value, + precision=precision, + key_chunk_size=key_chunk_size, + ), ) _, res = jax.lax.scan( - f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter + f=chunk_scanner, + init=0, + xs=None, + length=math.ceil(num_q / query_chunk_size), # start counter # stop counter ) return jnp.concatenate(res, axis=-3) # fuse the chunked result back @@ -1357,6 +1383,8 @@ def __init__( attention_kernel = "tokamax_flash" # do not use ring attention for cross attention self.added_kv_proj_dim = added_kv_proj_dim # New for I2V self.image_seq_len = image_seq_len # New for I2V + tpu_type = get_tpu_type() + self.alignment = 256 if tpu_type in [TpuType.TPU_V6_LITE, TpuType.TPU_7X] else 128 self.attention_op = NNXAttentionOp( mesh=mesh, @@ -1547,6 +1575,7 @@ def __call__( encoder_attention_mask: Optional[jax.Array] = None, deterministic: bool = True, rngs: nnx.Rngs = None, + cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, ) -> jax.Array: axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) @@ -1566,16 +1595,22 @@ def __call__( if not is_i2v_cross_attention: with jax.named_scope("query_proj"): query_proj = self.query(hidden_states) - with jax.named_scope("key_proj"): - key_proj = self.key(encoder_hidden_states) - with jax.named_scope("value_proj"): - value_proj = self.value(encoder_hidden_states) if self.qk_norm: with self.conditional_named_scope("attn_q_norm"): query_proj = self.norm_q(query_proj) - with self.conditional_named_scope("attn_k_norm"): - key_proj = self.norm_k(key_proj) + + if not is_self_attention and cached_kv is not None and "text" in cached_kv: + key_proj, value_proj = cached_kv["text"] + else: + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) 
if rotary_emb is not None: with self.conditional_named_scope("attn_rope"): @@ -1591,7 +1626,10 @@ def __call__( with jax.named_scope("apply_attention"): attn_output = self.attention_op.apply_attention( - query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask + query_proj, + key_proj, + value_proj, + attention_mask=encoder_attention_mask, ) else: @@ -1599,19 +1637,15 @@ def __call__( with self.conditional_named_scope("proj_query"): query_proj_raw = self.query(hidden_states) - # Image embeddings are padded to multiples of 128 for TPU flash attention + # Image embeddings are padded to multiples of 128 (v5p and below) or 256 (v6e and above) for TPU flash attention # Calculate the padded length to correctly split image and text embeddings if self.added_kv_proj_dim is not None: - alignment = 128 + alignment = self.alignment if self.image_seq_len is not None: image_seq_len_actual = self.image_seq_len else: image_seq_len_actual = 257 padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment # 257 -> 384 - - if encoder_attention_mask is None: - padded_img_len = image_seq_len_actual - encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :] encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :] @@ -1635,22 +1669,28 @@ def __call__( query_proj_text = query_proj_raw # Text K/V - with self.conditional_named_scope("proj_key"): - key_proj_text = self.key(encoder_hidden_states_text) - if self.qk_norm: - with self.conditional_named_scope("attn_k_norm"): - key_proj_text = self.norm_k(key_proj_text) - with self.conditional_named_scope("proj_value"): - value_proj_text = self.value(encoder_hidden_states_text) + if cached_kv is not None and "text" in cached_kv: + key_proj_text, value_proj_text = cached_kv["text"] + else: + with self.conditional_named_scope("proj_key"): + key_proj_text = self.key(encoder_hidden_states_text) + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj_text = self.norm_k(key_proj_text) + with self.conditional_named_scope("proj_value"): + value_proj_text = self.value(encoder_hidden_states_text) # Image K/V (only if image embeddings are present) if encoder_hidden_states_img is not None: - with self.conditional_named_scope("add_proj_k"): - key_proj_img = self.add_k_proj(encoder_hidden_states_img) - with self.conditional_named_scope("norm_add_k"): - key_proj_img = self.norm_added_k(key_proj_img) - with self.conditional_named_scope("add_proj_v"): - value_proj_img = self.add_v_proj(encoder_hidden_states_img) + if cached_kv is not None and "image" in cached_kv: + key_proj_img, value_proj_img = cached_kv["image"] + else: + with self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) query_proj_img = query_proj_raw # Check norm_added_k too # Checkpointing @@ -1667,7 +1707,10 @@ def __call__( with self.conditional_named_scope("cross_attn_img_apply"): # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens attn_output_img = self.attention_op.apply_attention( - query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img + query_proj_img, + key_proj_img, + value_proj_img, + attention_mask=encoder_attention_mask_img, ) attn_output = attn_output_text + attn_output_img @@ -1689,6 +1732,64 @@ 
def __call__( hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return hidden_states + def compute_kv( + self, + encoder_hidden_states: jax.Array, + encoder_attention_mask: Optional[jax.Array] = None, + ) -> Dict[str, Tuple[jax.Array, jax.Array]]: + is_i2v_cross_attention = self.added_kv_proj_dim is not None + + if not is_i2v_cross_attention: + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) + + return {"text": (key_proj, value_proj)} + else: + # Image embeddings are padded to multiples of 128 (v5p and below) or 256 (v6e and above) for TPU flash attention + alignment = self.alignment + if self.image_seq_len is not None: + image_seq_len_actual = self.image_seq_len + else: + image_seq_len_actual = 257 + padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment + + if encoder_attention_mask is None: + padded_img_len = image_seq_len_actual + + encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :] + encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :] + + # Text K/V + with self.conditional_named_scope("proj_key"): + key_proj_text = self.key(encoder_hidden_states_text) + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj_text = self.norm_k(key_proj_text) + with self.conditional_named_scope("proj_value"): + value_proj_text = self.value(encoder_hidden_states_text) + + # Image K/V (only if image embeddings are present) + if encoder_hidden_states_img is not None: + with self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) + + return { + "text": (key_proj_text, value_proj_text), + "image": (key_proj_img, value_proj_img), + } + else: + return {"text": (key_proj_text, value_proj_text)} + class FlaxFluxAttention(nn.Module): query_dim: int @@ -1801,7 +1902,13 @@ def setup(self): param_dtype=self.weights_dtype, ) - def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): + def __call__( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + image_rotary_emb=None, + ): qkv_proj = self.qkv(hidden_states) B, L = hidden_states.shape[:2] H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3 @@ -1973,7 +2080,13 @@ def setup(self): ) self.dropout_layer = nn.Dropout(rate=self.dropout) - def __call__(self, hidden_states, context=None, deterministic=True, cross_attention_kwargs=None): + def __call__( + self, + hidden_states, + context=None, + deterministic=True, + cross_attention_kwargs=None, + ): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) @@ -2077,7 +2190,11 @@ def setup(self): quant=self.quant, ) self.ff = FlaxFeedForward( - dim=self.dim, dropout=self.dropout, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision + dim=self.dim, + dropout=self.dropout, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, ) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype, 
param_dtype=self.weights_dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) @@ -2089,11 +2206,16 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k residual = hidden_states if self.only_cross_attention: hidden_states = self.attn1( - self.norm1(hidden_states), context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + self.norm1(hidden_states), + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) else: hidden_states = self.attn1( - self.norm1(hidden_states), deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + self.norm1(hidden_states), + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) hidden_states = hidden_states + residual @@ -2101,7 +2223,10 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k # cross attention residual = hidden_states hidden_states = self.attn2( - self.norm2(hidden_states), context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + self.norm2(hidden_states), + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) hidden_states = hidden_states + residual @@ -2172,7 +2297,12 @@ class FlaxTransformer2DModel(nn.Module): quant: Quant = (None,) def setup(self): - self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) + self.norm = nn.GroupNorm( + num_groups=self.norm_num_groups, + epsilon=1e-5, + dtype=self.dtype, + param_dtype=self.weights_dtype, + ) conv_kernel_init = nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("keep_1", "keep_2", "conv_in", "conv_out") @@ -2255,7 +2385,10 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k for transformer_block in self.transformer_blocks: hidden_states = transformer_block( - hidden_states, context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + hidden_states, + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) if self.use_linear_projection: @@ -2298,8 +2431,19 @@ class FlaxFeedForward(nn.Module): def setup(self): # The second linear layer needs to be called # net_2 for now to match the index of the Sequential layer - self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype, self.weights_dtype, precision=self.precision) - self.net_2 = nn.Dense(self.dim, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.net_0 = FlaxGEGLU( + self.dim, + self.dropout, + self.dtype, + self.weights_dtype, + precision=self.precision, + ) + self.net_2 = nn.Dense( + self.dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) def __call__(self, hidden_states, deterministic=True): hidden_states = self.net_0(hidden_states, deterministic=deterministic) @@ -2329,7 +2473,12 @@ class FlaxGEGLU(nn.Module): def setup(self): inner_dim = self.dim * 4 - self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.proj = nn.Dense( + inner_dim * 2, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) self.dropout_layer = nn.Dropout(rate=self.dropout) def __call__(self, hidden_states, deterministic=True): diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 36bea9ea3..4b83ae4eb 100644 --- 
a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -21,6 +21,7 @@ from .modeling_flax_utils import get_activation from ..models.attention_flax import NNXSimpleFeedForward from ..models.normalization_flax import FP32LayerNorm +from maxdiffusion.tpu_utils import get_tpu_type, TpuType def get_sinusoidal_embeddings( @@ -260,7 +261,7 @@ def __init__( weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, - alignment: int = 128, + alignment: Optional[int] = None, flash_min_seq_length: int = 4096, ): self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6) @@ -275,7 +276,11 @@ def __init__( precision=precision, ) self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6) - self.alignment = alignment + if alignment is None: + tpu_type = get_tpu_type() + self.alignment = 256 if tpu_type in [TpuType.TPU_V6_LITE, TpuType.TPU_7X] else 128 + else: + self.alignment = alignment self.flash_min_seq_length = flash_min_seq_length if pos_embed_seq_len is not None: self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype)) diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index d346eef2f..935c6b392 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -93,7 +93,7 @@ def conditional_cast(param): return param if mask is None: - return jax.tree_map(conditional_cast, params) + return jax.tree.map(conditional_cast, params) flat_params = flatten_dict(params) flat_mask, _ = jax.tree_flatten(mask) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 7d721773e..f5057f50d 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -62,7 +62,13 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): class WanRotaryPosEmbed(nnx.Module): - def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): self.attention_head_dim = attention_head_dim self.patch_size = patch_size self.max_seq_len = max_seq_len @@ -152,18 +158,63 @@ def __init__( ) def __call__( - self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None + self, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + skip_embeddings: bool = False, ): timestep = self.timesteps_proj(timestep) temb = self.time_embedder(timestep) with jax.named_scope("time_proj"): timestep_proj = self.time_proj(self.act_fn(temb)) - encoder_hidden_states = self.text_embedder(encoder_hidden_states) - encoder_attention_mask = None - if encoder_hidden_states_image is not None: - encoder_hidden_states_image, encoder_attention_mask = self.image_embedder(encoder_hidden_states_image) - return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask + if not skip_embeddings: + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + encoder_attention_mask = None + if encoder_hidden_states_image is not None: + ( + encoder_hidden_states_image, + 
encoder_attention_mask, + ) = self.image_embedder(encoder_hidden_states_image) + else: + encoder_attention_mask = None + if ( + encoder_hidden_states_image is not None + and encoder_hidden_states_image.shape[-1] != encoder_hidden_states.shape[-1] + ): + img_dim = encoder_hidden_states_image.shape[-1] + text_dim = encoder_hidden_states.shape[-1] + if img_dim < text_dim: + pad_shape = ( + encoder_hidden_states_image.shape[0], + encoder_hidden_states_image.shape[1], + text_dim - img_dim, + ) + encoder_hidden_states_image = jnp.concatenate( + [ + encoder_hidden_states_image, + jnp.zeros(pad_shape, dtype=encoder_hidden_states_image.dtype), + ], + axis=-1, + ) + else: + pad_shape = ( + encoder_hidden_states.shape[0], + encoder_hidden_states.shape[1], + img_dim - text_dim, + ) + encoder_hidden_states = jnp.concatenate( + [encoder_hidden_states, jnp.zeros(pad_shape, dtype=encoder_hidden_states.dtype)], axis=-1 + ) + + return ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) class ApproximateGELU(nnx.Module): @@ -232,7 +283,13 @@ def __init__( self.act_fn = nnx.data(None) if activation_fn == "gelu-approximate": self.act_fn = ApproximateGELU( - rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision + rngs=rngs, + dim_in=dim, + dim_out=inner_dim, + bias=bias, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) else: raise NotImplementedError(f"{activation_fn} is not implemented.") @@ -259,7 +316,12 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() - def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: + def __call__( + self, + hidden_states: jax.Array, + deterministic: bool = True, + rngs: nnx.Rngs = None, + ) -> jax.Array: hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) hidden_states = checkpoint_name(hidden_states, "ffn_activation") if self.drop_out.rate > 0: @@ -381,6 +443,7 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, encoder_attention_mask: Optional[jax.Array] = None, + cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, ): with self.conditional_named_scope("transformer_block"): # Support both global [B, 6, dim] and per-token [B, seq_len, 6, dim] temb. @@ -396,7 +459,14 @@ def __call__( c_scale_msa = parts[4].squeeze(2) c_gate_msa = parts[5].squeeze(2) else: # Global: [B, 6, dim] - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + ( + shift_msa, + scale_msa, + gate_msa, + c_shift_msa, + c_scale_msa, + c_gate_msa, + ) = jnp.split( (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1, @@ -435,6 +505,7 @@ def __call__( deterministic=deterministic, rngs=rngs, encoder_attention_mask=encoder_attention_mask, + cached_kv=cached_kv, ) with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output @@ -453,6 +524,13 @@ def __call__( ) return hidden_states + def compute_kv( + self, + encoder_hidden_states: jax.Array, + encoder_attention_mask: Optional[jax.Array] = None, + ) -> Dict[str, Tuple[jax.Array, jax.Array]]: + return self.attn2.compute_kv(encoder_hidden_states, encoder_attention_mask) + class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -533,7 +611,11 @@ def __init__( # 3. 
Transformer blocks @nnx.split_rngs(splits=num_layers) - @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}) + @nnx.vmap( + in_axes=0, + out_axes=0, + transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}, + ) def init_block(rngs): return WanTransformerBlock( rngs=rngs, @@ -609,6 +691,61 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def compute_kv_cache( + self, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + timestep: Optional[jax.Array] = None, + ) -> Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Optional[jax.Array]]: + if timestep is None: + batch_size = encoder_hidden_states.shape[0] + timestep = jnp.zeros((batch_size,), dtype=jnp.int32) + + with self.conditional_named_scope("condition_embedder"): + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + if encoder_attention_mask is not None: + text_mask = jnp.ones( + ( + encoder_hidden_states.shape[0], + encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1], + ), + dtype=jnp.int32, + ) + encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) + + if self.scan_layers: + + @nnx.vmap( + in_axes=(0, None, None), + out_axes=0, + transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}, + ) + def _compute_kv(block, enc_states, enc_mask): + return block.compute_kv(enc_states, enc_mask) + + kv_cache = _compute_kv(self.blocks, encoder_hidden_states, encoder_attention_mask) + else: + kv_cache_list = [] + for block in self.blocks: + kv_cache_list.append(block.compute_kv(encoder_hidden_states, encoder_attention_mask)) + keys = kv_cache_list[0].keys() + kv_cache = {} + for k in keys: + k_list = [d[k][0] for d in kv_cache_list] + v_list = [d[k][1] for d in kv_cache_list] + kv_cache[k] = (jnp.stack(k_list, axis=0), jnp.stack(v_list, axis=0)) + + return kv_cache, encoder_attention_mask + @jax.named_scope("WanModel") def __call__( self, @@ -623,6 +760,9 @@ def __call__( skip_blocks: Optional[jax.Array] = None, cached_residual: Optional[jax.Array] = None, return_residual: bool = False, + kv_cache: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, + rotary_emb: Optional[jax.Array] = None, + encoder_attention_mask: Optional[jax.Array] = None, ) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]: hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) batch_size, _, num_frames, height, width = hidden_states.shape @@ -633,7 +773,8 @@ def __call__( hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) with self.conditional_named_scope("rotary_embedding"): - rotary_emb = self.rope(hidden_states) + if rotary_emb is None: + rotary_emb = self.rope(hidden_states) with self.conditional_named_scope("patch_embedding"): hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) @@ -652,34 +793,56 @@ def __call__( timestep_proj = self.condition_embedder.time_proj(self.condition_embedder.act_fn(temb)) # [B, sl, dim*6] timestep_proj = 
timestep_proj.reshape(bt, sl, 6, -1) # [B, sl, 6, dim] # Text processing - encoder_hidden_states = self.condition_embedder.text_embedder(encoder_hidden_states) - encoder_hidden_states_image = None - encoder_attention_mask = None + if kv_cache is None: + encoder_hidden_states_out = self.condition_embedder.text_embedder(encoder_hidden_states) + else: + encoder_hidden_states_out = encoder_hidden_states + encoder_hidden_states_image_out = None + encoder_attention_mask_out = None else: ( temb, timestep_proj, + encoder_hidden_states_out, + encoder_hidden_states_image_out, + encoder_attention_mask_out, + ) = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, - encoder_attention_mask, - ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + skip_embeddings=(kv_cache is not None), + ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) - if encoder_hidden_states_image is not None: - encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) - if encoder_attention_mask is not None: + if encoder_attention_mask is None: + encoder_attention_mask = encoder_attention_mask_out + + if encoder_hidden_states_image_out is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image_out, encoder_hidden_states_out], axis=1) + if kv_cache is None and encoder_attention_mask is not None: text_mask = jnp.ones( - (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), + ( + encoder_hidden_states.shape[0], + encoder_hidden_states.shape[1] - encoder_hidden_states_image_out.shape[1], + ), dtype=jnp.int32, ) encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) + else: + encoder_hidden_states = encoder_hidden_states_out.astype(hidden_states.dtype) def _run_all_blocks(h): if self.scan_layers: - def scan_fn(carry, block): + def scan_fn(carry, block_input): hidden_states_carry, rngs_carry = carry + if kv_cache is not None: + block, layer_kv_cache = block_input + else: + block = block_input + layer_kv_cache = None + hidden_states = block( hidden_states_carry, encoder_hidden_states, @@ -688,27 +851,40 @@ def scan_fn(carry, block): deterministic, rngs_carry, encoder_attention_mask, + cached_kv=layer_kv_cache, ) new_carry = (hidden_states, rngs_carry) return new_carry, None rematted_block_forward = self.gradient_checkpoint.apply( - scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + scan_fn, + self.names_which_can_be_saved, + self.names_which_can_be_offloaded, + prevent_cse=not self.scan_layers, ) initial_carry = (h, rngs) + + if kv_cache is not None: + scan_input = (self.blocks, kv_cache) + else: + scan_input = self.blocks + final_carry, _ = nnx.scan( rematted_block_forward, length=self.num_layers, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), - )(initial_carry, self.blocks) + )(initial_carry, scan_input) h_out, _ = final_carry else: h_out = h - for block in self.blocks: + for i, block in enumerate(self.blocks): + layer_kv_cache = None + if kv_cache is not None: + layer_kv_cache = jax.tree.map(lambda x: x[i], kv_cache) - def layer_forward(hidden_states): + def layer_forward(hidden_states, l_kv): return block( hidden_states, encoder_hidden_states, @@ -717,6 +893,7 @@ def layer_forward(hidden_states): deterministic, rngs, 
encoder_attention_mask=encoder_attention_mask, + cached_kv=l_kv, ) rematted_layer_forward = self.gradient_checkpoint.apply( @@ -725,7 +902,7 @@ def layer_forward(hidden_states): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - h_out = rematted_layer_forward(h_out) + h_out = rematted_layer_forward(h_out, layer_kv_cache) return h_out hidden_states_before_blocks = hidden_states @@ -752,7 +929,14 @@ def layer_forward(hidden_states): hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, ) hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index f3548602b..fcb9151f8 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -199,6 +199,9 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def compute_kv(self, encoder_hidden_states: jax.Array, encoder_attention_mask: Optional[jax.Array] = None): + return self.attn2.compute_kv(encoder_hidden_states, encoder_attention_mask) + def __call__( self, *, @@ -207,6 +210,8 @@ def __call__( control_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, + kv_cache: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, + encoder_attention_mask: Optional[jax.Array] = None, deterministic: bool = True, rngs: nnx.Rngs | None = None, ) -> Tuple[jax.Array, jax.Array]: @@ -253,6 +258,8 @@ def __call__( attn_output = self.attn2( hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, + cached_kv=kv_cache, + encoder_attention_mask=encoder_attention_mask, deterministic=deterministic, rngs=rngs, ) @@ -467,6 +474,66 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def compute_kv_cache( + self, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + timestep: Optional[jax.Array] = None, + ) -> Tuple[Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Dict[str, Tuple[jax.Array, jax.Array]]], Optional[jax.Array]]: + if timestep is None: + batch_size = encoder_hidden_states.shape[0] + timestep = jnp.zeros((batch_size,), dtype=jnp.int32) + + with self.conditional_named_scope("condition_embedder"): + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + if encoder_attention_mask is not None: + text_mask = jnp.ones( + ( + encoder_hidden_states.shape[0], + encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1], + ), + dtype=jnp.int32, + ) + encoder_attention_mask = 
jnp.concatenate([encoder_attention_mask, text_mask], axis=1) + + if self.scan_layers: + raise NotImplementedError("scan_layers is not supported yet") + else: + # VACE blocks + vace_kv_cache_list = [] + for block in self.vace_blocks: + vace_kv_cache_list.append(block.compute_kv(encoder_hidden_states, encoder_attention_mask)) + vace_kv_cache = {} + if vace_kv_cache_list: + keys = vace_kv_cache_list[0].keys() + for k in keys: + k_list = [d[k][0] for d in vace_kv_cache_list] + v_list = [d[k][1] for d in vace_kv_cache_list] + vace_kv_cache[k] = (jnp.stack(k_list, axis=0), jnp.stack(v_list, axis=0)) + + # Main blocks + main_kv_cache_list = [] + for block in self.blocks: + main_kv_cache_list.append(block.compute_kv(encoder_hidden_states, encoder_attention_mask)) + main_kv_cache = {} + if main_kv_cache_list: + keys = main_kv_cache_list[0].keys() + for k in keys: + k_list = [d[k][0] for d in main_kv_cache_list] + v_list = [d[k][1] for d in main_kv_cache_list] + main_kv_cache[k] = (jnp.stack(k_list, axis=0), jnp.stack(v_list, axis=0)) + + return (vace_kv_cache, main_kv_cache), encoder_attention_mask + @jax.named_scope("WanVACEModel") def __call__( self, @@ -478,6 +545,8 @@ def __call__( encoder_hidden_states_image: Optional[jax.Array] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache: Optional[Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Dict[str, Tuple[jax.Array, jax.Array]]]] = None, + encoder_attention_mask: Optional[jax.Array] = None, deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: @@ -524,7 +593,7 @@ def __call__( encoder_hidden_states_image, _, ) = self.condition_embedder( # We will need to mask out the text embedding. - timestep, encoder_hidden_states, encoder_hidden_states_image + timestep, encoder_hidden_states, encoder_hidden_states_image, skip_embeddings=(kv_cache is not None) ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) @@ -534,9 +603,14 @@ def __call__( if self.scan_layers: raise NotImplementedError("scan_layers is not supported yet") else: + vace_kv_cache, main_kv_cache = kv_cache if kv_cache is not None else (None, None) + # Prepare VACE hints control_hidden_states_list = [] for i, vace_block in enumerate(self.vace_blocks): + layer_kv_cache = None + if vace_kv_cache is not None: + layer_kv_cache = jax.tree.map(lambda x: x[i], vace_kv_cache) def layer_forward(hidden_states, control_hidden_states, rngs): return vace_block( @@ -545,6 +619,8 @@ def layer_forward(hidden_states, control_hidden_states, rngs): control_hidden_states=control_hidden_states, temb=timestep_proj, rotary_emb=rotary_emb, + kv_cache=layer_kv_cache, + encoder_attention_mask=encoder_attention_mask, deterministic=deterministic, rngs=rngs, ) @@ -561,6 +637,9 @@ def layer_forward(hidden_states, control_hidden_states, rngs): control_hidden_states_list = control_hidden_states_list[::-1] for i, block in enumerate(self.blocks): + layer_kv_cache = None + if main_kv_cache is not None: + layer_kv_cache = jax.tree.map(lambda x: x[i], main_kv_cache) def layer_forward_vace(hidden_states, rngs): return block( @@ -570,6 +649,8 @@ def layer_forward_vace(hidden_states, rngs): rotary_emb, deterministic, rngs, + encoder_attention_mask=encoder_attention_mask, + cached_kv=layer_kv_cache, ) rematted_layer_forward = self.gradient_checkpoint.apply( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 608f7282d..17c211b34 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ 
b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -173,7 +173,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): ) params = jax.tree_util.tree_map_with_path( - lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params + lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), + params, ) for path, val in flax.traverse_util.flatten_dict(params).items(): if restored_checkpoint: @@ -291,7 +292,9 @@ def load_image_encoder(cls, config: HyperParameters): image_processor = CLIPImageProcessor.from_pretrained(config.pretrained_model_name_or_path, subfolder="image_processor") try: image_encoder = FlaxCLIPVisionModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 + config.pretrained_model_name_or_path, + subfolder="image_encoder", + dtype=jnp.float32, ) except Exception as e: max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") @@ -300,7 +303,12 @@ def load_image_encoder(cls, config: HyperParameters): @classmethod def load_vae( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, vae_logical_axis_rules: tuple = None + cls, + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + vae_logical_axis_rules: tuple = None, ): def create_model(rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( @@ -403,7 +411,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: return None @classmethod - def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): + def quantize_transformer( + cls, + config: HyperParameters, + model: WanModel, + pipeline: "WanPipeline", + mesh: Mesh, + ): """Quantizes the transformer model.""" q_rules = cls.get_qt_provider(config) if not q_rules: @@ -484,7 +498,8 @@ def _get_t5_prompt_embeds( prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], + dim=0, ) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -503,8 +518,11 @@ def encode_prompt( prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) + if prompt is not None: + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] // num_videos_per_prompt if negative_prompt is None: negative_prompt = [""] * batch_size @@ -587,12 +605,26 @@ def prepare_latents_i2v_base( if last_image is None: video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 + [ + image, + jnp.zeros( + (image.shape[0], image.shape[1], num_frames - 1, height, width), + dtype=image.dtype, + ), + ], + axis=2, ) else: last_image = last_image[:, :, jnp.newaxis, :, :] video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], + [ + image, + jnp.zeros( + (image.shape[0], image.shape[1], num_frames - 2, height, width), + dtype=image.dtype, + ), + last_image, + ], axis=2, 
) @@ -679,7 +711,11 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): with vae_mesh: wan_vae, vae_cache = cls.load_vae( - devices_array=devices_array, mesh=vae_mesh, rngs=rngs, config=config, vae_logical_axis_rules=vae_logical_axis_rules + devices_array=devices_array, + mesh=vae_mesh, + rngs=rngs, + config=config, + vae_logical_axis_rules=vae_logical_axis_rules, ) components = { @@ -703,7 +739,10 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): components["text_encoder"] = cls.load_text_encoder(config=config) components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) if i2v and config.model_name == "wan2.1": - components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) + ( + components["image_processor"], + components["image_encoder"], + ) = cls.load_image_encoder(config) return components @abstractmethod @@ -803,7 +842,7 @@ def _prepare_model_inputs( if prompt is not None and isinstance(prompt, str): prompt = [prompt] - batch_size = len(prompt) + batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt with jax.named_scope("Encode-Prompt"): prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -836,10 +875,18 @@ def _prepare_model_inputs( negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + self.scheduler_state, + num_inference_steps=num_inference_steps, + shape=latents.shape, ) - return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames + return ( + latents, + prompt_embeds, + negative_prompt_embeds, + scheduler_state, + num_frames, + ) @abstractmethod def __call__(self, **kwargs): @@ -847,7 +894,15 @@ def __call__(self, **kwargs): pass -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale", "return_residual", "skip_blocks")) +@partial( + jax.jit, + static_argnames=( + "do_classifier_free_guidance", + "guidance_scale", + "return_residual", + "skip_blocks", + ), +) def transformer_forward_pass( graphdef, sharded_state, @@ -861,6 +916,9 @@ def transformer_forward_pass( skip_blocks=None, cached_residual=None, return_residual=False, + kv_cache=None, + rotary_emb=None, + encoder_attention_mask=None, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) outputs = wan_transformer( @@ -871,6 +929,9 @@ def transformer_forward_pass( skip_blocks=skip_blocks, cached_residual=cached_residual, return_residual=return_residual, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) if return_residual: @@ -901,6 +962,9 @@ def transformer_forward_pass_full_cfg( prompt_embeds_combined: jnp.array, guidance_scale: float, encoder_hidden_states_image=None, + kv_cache=None, + rotary_emb=None, + encoder_attention_mask=None, ): """Full CFG forward pass. 
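Reviewer note: the new `kv_cache` / `rotary_emb` / `encoder_attention_mask` arguments thread step-invariant tensors into the jitted forward passes. A minimal, self-contained sketch of the underlying pattern (illustrative names and shapes only, not the MaxDiffusion API): the prompt embeddings are fixed for the whole denoising loop, so their cross-attention K/V projections can be computed once and reused at every step.

import jax
import jax.numpy as jnp

B, T_TXT, T_LATENT, D = 2, 512, 4096, 64  # illustrative shapes

def compute_kv(wk, wv, text_embeds):
  # Project the static text embeddings once, before the loop.
  return text_embeds @ wk, text_embeds @ wv

def cross_attention(q, cached_k, cached_v):
  # Only Q changes per step; K and V come from the cache.
  scores = jnp.einsum("bqd,bkd->bqk", q, cached_k) / jnp.sqrt(D)
  return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), cached_v)

key = jax.random.PRNGKey(0)
wk = jax.random.normal(key, (D, D))
wv = jax.random.normal(key, (D, D))
text = jax.random.normal(key, (B, T_TXT, D))
k, v = compute_kv(wk, wv, text)  # computed once
for step in range(3):  # denoising loop
  q = jax.random.normal(jax.random.fold_in(key, step), (B, T_LATENT, D))
  out = cross_attention(q, k, v)  # per-step K/V projections skipped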
@@ -919,6 +983,9 @@ def transformer_forward_pass_full_cfg( skip_blocks=False, cached_residual=None, return_residual=False, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) noise_cond = noise_pred[:bsz] noise_uncond = noise_pred[bsz:] @@ -940,6 +1007,9 @@ def transformer_forward_pass_cfg_cache( w1: float = 1.0, w2: float = 1.0, encoder_hidden_states_image=None, + kv_cache=None, + rotary_emb=None, + encoder_attention_mask=None, ): """CFG-Cache forward pass with FFT frequency-domain compensation. @@ -965,6 +1035,9 @@ def transformer_forward_pass_cfg_cache( timestep=timestep_cond, encoder_hidden_states=prompt_cond_embeds, encoder_hidden_states_image=encoder_hidden_states_image, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) # FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W] diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index e0a2f05e6..355ba6ae6 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -73,7 +73,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) return pipeline @@ -100,6 +106,7 @@ def __call__( magcache_thresh: Optional[float] = None, magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, + use_kv_cache: bool = False, ): config = getattr(self, "config", None) if magcache_thresh is None: @@ -114,7 +121,6 @@ def __call__( f"use_cfg_cache=True requires guidance_scale > 1.0 (got {guidance_scale}). " "CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0." ) - trace = {} t_cond_start = time.perf_counter() @@ -152,6 +158,7 @@ def __call__( height=height, mag_ratios_base=getattr(config, "mag_ratios_base", None), config=self.config, + use_kv_cache=use_kv_cache, ) t_denoise_start = time.perf_counter() @@ -196,6 +203,7 @@ def run_inference_2_1( height: int = 480, mag_ratios_base: Optional[List[float]] = None, config=None, + use_kv_cache: bool = False, ): """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. 
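Reviewer note: `run_inference_2_1` below computes the rotary embedding once from a dummy zeros tensor because it depends only on the latent shape, never on the latent values. A simplified sketch of why that hoisting is safe (illustrative helpers with a plain pairwise rotation; the model's actual RoPE splits the head dim across t/h/w axes):

import jax.numpy as jnp

def rope_freqs(seq_len, head_dim, theta=10000.0):
  # Depends only on (seq_len, head_dim), so it can be hoisted out of the loop.
  pos = jnp.arange(seq_len)[:, None]
  inv = 1.0 / theta ** (jnp.arange(0, head_dim, 2) / head_dim)
  ang = pos * inv[None, :]
  return jnp.cos(ang), jnp.sin(ang)

def apply_rope(x, cos, sin):
  # x: [batch, seq, head_dim]; rotate feature pairs by position-dependent angles.
  x1, x2 = jnp.split(x, 2, axis=-1)
  return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

cos, sin = rope_freqs(seq_len=1024, head_dim=64)       # computed once
q = apply_rope(jnp.ones((2, 1024, 64)), cos, sin)      # reused every step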
@@ -267,6 +275,26 @@ def run_inference_2_1( cached_noise_cond = None cached_noise_uncond = None + transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros(( + latents.shape[0], + latents.shape[2], + latents.shape[3], + latents.shape[4], + latents.shape[1], + )) + rotary_emb = transformer_obj.rope(dummy_hidden_states) + + kv_cache = None + encoder_attention_mask = None + + if use_kv_cache: + kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache( + prompt_embeds_combined if do_cfg else prompt_cond_embeds + ) + if use_magcache and do_cfg: magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) accumulated_state = magcache_init[:6] @@ -276,7 +304,11 @@ def run_inference_2_1( first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 - last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + last_profiling_step = np.clip( + first_profiling_step + profiler_steps - 1, + first_profiling_step, + num_inference_steps - 1, + ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False @@ -299,6 +331,9 @@ def scan_body(carry, t): timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) else: timestep = jnp.broadcast_to(t, bsz) @@ -311,6 +346,9 @@ def scan_body(carry, t): prompt_cond_embeds, do_classifier_free_guidance=False, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) new_latents, new_scheduler_state = scheduler.step( @@ -338,7 +376,12 @@ def scan_body(carry, t): timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz) skip_blocks, accumulated_state = magcache_step( - step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup + step, + mag_ratios, + accumulated_state, + magcache_thresh, + magcache_K, + skip_warmup, ) noise_pred, latents, residual_x_cur = transformer_forward_pass( @@ -353,6 +396,9 @@ def scan_body(carry, t): skip_blocks=bool(skip_blocks), cached_residual=cached_residual, return_residual=True, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) if not skip_blocks: @@ -364,6 +410,8 @@ def scan_body(carry, t): if is_cache_step: w1, w2 = step_w1w2[step] timestep = jnp.broadcast_to(t, bsz) + kv_cache_cond = jax.tree.map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None + encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( graphdef, sharded_state, @@ -376,12 +424,19 @@ def scan_body(carry, t): guidance_scale=guidance_scale, w1=jnp.float32(w1), w2=jnp.float32(w2), + kv_cache=kv_cache_cond, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask_cond, ) elif do_cfg: latents_doubled = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, bsz * 2) - noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + ( + noise_pred, + cached_noise_cond, + cached_noise_uncond, + ) = transformer_forward_pass_full_cfg( graphdef, sharded_state, rest_of_state, @@ -389,6 +444,9 @@ def scan_body(carry, t): timestep, prompt_embeds_combined, 
guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) else: @@ -402,6 +460,9 @@ def scan_body(carry, t): prompt_cond_embeds, do_classifier_free_guidance=False, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 77331d66d..2c294a124 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -25,6 +25,7 @@ import numpy as np from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler from ... import max_utils +from maxdiffusion import max_logging class WanPipeline2_2(WanPipeline): @@ -89,7 +90,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init( config, restored_checkpoint, vae_only, load_transformer ) @@ -116,6 +123,7 @@ def __call__( vae_only: bool = False, use_cfg_cache: bool = False, use_sen_cache: bool = False, + use_kv_cache: bool = False, ): if use_cfg_cache and use_sen_cache: raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") @@ -137,7 +145,13 @@ def __call__( trace = {} t_cond_start = time.perf_counter() - latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( + ( + latents, + prompt_embeds, + negative_prompt_embeds, + scheduler_state, + num_frames, + ) = self._prepare_model_inputs( prompt, negative_prompt, height, @@ -171,6 +185,7 @@ def __call__( use_cfg_cache=use_cfg_cache, use_sen_cache=use_sen_cache, height=height, + use_kv_cache=use_kv_cache, ) t_denoise_start = time.perf_counter() @@ -220,6 +235,7 @@ def run_inference_2_2( use_sen_cache: bool = False, height: int = 480, config=None, + use_kv_cache: bool = False, ): """Denoising loop for WAN 2.2 T2V with optional caching acceleration. 
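Threading kv_cache through every forward-pass variant is sound because Wan's cross-attention keys and values are projections of the text embeddings alone, which are constant across denoising steps. A minimal sketch of what a compute_kv_cache-style precomputation amounts to, with illustrative weight names and layout (the real method lives in transformer_wan.py and also returns the encoder attention mask):

import jax.numpy as jnp

def precompute_text_kv(k_proj_weights, v_proj_weights, encoder_hidden_states):
  """Project the step-invariant text embeddings once, for all blocks.

  k_proj_weights / v_proj_weights: assumed [num_blocks, D_text, D_attn]
  stacks of the cross-attention to_k / to_v weights.
  """
  # Result shape [num_blocks, B, S_text, D_attn]: batch lands on axis 1,
  # which is what makes the cond-only cache-step slice
  # jax.tree.map(lambda x: x[:, :bsz], kv_cache) line up with a cache
  # built over the doubled [cond; uncond] batch.
  keys = jnp.einsum("bsd,lde->lbse", encoder_hidden_states, k_proj_weights)
  values = jnp.einsum("bsd,lde->lbse", encoder_hidden_states, v_proj_weights)
  return keys, values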
@@ -239,6 +255,33 @@ def run_inference_2_2( do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + prompt_embeds_combined = ( + jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds + ) + + low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros(( + latents.shape[0], + latents.shape[2], + latents.shape[3], + latents.shape[4], + latents.shape[1], + )) + rotary_emb = low_transformer.rope(dummy_hidden_states) + + kv_cache_low = None + encoder_attention_mask_low = None + kv_cache_high = None + encoder_attention_mask_high = None + + if use_kv_cache: + kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(prompt_embeds_combined) + + high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) + kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined) + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) @@ -262,8 +305,6 @@ def run_inference_2_2( # uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches. num_train_timesteps = float(scheduler.config.num_train_timesteps) - prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - # SenCache state ref_noise_pred = None # y^r: cached denoiser output ref_latent = None # x^r: latent at last cache refresh @@ -279,11 +320,23 @@ def run_inference_2_2( # Select transformer and guidance scale if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: - graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + graphdef, state, rest = ( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + ) guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low # Force full compute: warmup, first 30%, last 10%, or transformer boundary is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] @@ -302,6 +355,9 @@ def run_inference_2_2( timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) ref_noise_pred = noise_pred ref_latent = latents @@ -338,6 +394,9 @@ def run_inference_2_2( timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) ref_noise_pred = noise_pred ref_latent = latents @@ -348,7 +407,7 @@ def run_inference_2_2( latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - print( + max_logging.log( f"[SenCache] Cached {cache_count}/{num_inference_steps} steps " f"({100*cache_count/num_inference_steps:.1f}% cache ratio)" ) @@ -374,7 +433,6 @@ def run_inference_2_2( # Pre-split embeds once prompt_cond_embeds = prompt_embeds - prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) # Determine the first low-noise step (boundary transition). 
# In Wan 2.2 the boundary IS the structural→detail transition, so @@ -420,16 +478,30 @@ def run_inference_2_2( # Select transformer and guidance scale based on precomputed schedule if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: - graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + graphdef, state, rest = ( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + ) guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low if is_cache_step: # ── Cache step: cond-only forward + FFT frequency compensation ── w1, w2 = step_w1w2[step] timestep = jnp.broadcast_to(t, bsz) + kv_cache_cond = jax.tree.map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None + encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( graphdef, state, @@ -442,12 +514,19 @@ def run_inference_2_2( guidance_scale=guidance_scale, w1=jnp.float32(w1), w2=jnp.float32(w2), + kv_cache=kv_cache_cond, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask_cond, ) else: # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── latents_doubled = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, bsz * 2) - noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + ( + noise_pred, + cached_noise_cond, + cached_noise_uncond, + ) = transformer_forward_pass_full_cfg( graphdef, state, rest, @@ -455,6 +534,9 @@ def run_inference_2_2( timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() @@ -467,40 +549,64 @@ def run_inference_2_2( timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] - prompt_embeds_combined = ( - jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds - ) - first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 - last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + last_profiling_step = np.clip( + first_profiling_step + profiler_steps - 1, + first_profiling_step, + num_inference_steps - 1, + ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False - def high_noise_branch(ops): - model_latents_in, timestep_in = ops + def high_noise_branch(operands): + ( + latents_input, + ts_input, + pe_input, + kv_cache_high, + _, + r_emb, + mask_high, + _, + ) = operands return transformer_forward_pass( high_noise_graphdef, high_noise_state, high_noise_rest, - model_latents_in, - timestep_in, - prompt_embeds_combined, - do_classifier_free_guidance, - guidance_scale_high, + latents_input, + ts_input, + pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale_high, + kv_cache=kv_cache_high, + 
rotary_emb=r_emb, + encoder_attention_mask=mask_high, ) - def low_noise_branch(ops): - model_latents_in, timestep_in = ops + def low_noise_branch(operands): + ( + latents_input, + ts_input, + pe_input, + _, + kv_cache_low, + r_emb, + _, + mask_low, + ) = operands return transformer_forward_pass( low_noise_graphdef, low_noise_state, low_noise_rest, - model_latents_in, - timestep_in, - prompt_embeds_combined, - do_classifier_free_guidance, - guidance_scale_low, + latents_input, + ts_input, + pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale_low, + kv_cache=kv_cache_low, + rotary_emb=r_emb, + encoder_attention_mask=mask_low, ) if scan_diffusion_loop: @@ -519,7 +625,21 @@ def scan_body(carry, t): timestep = jnp.broadcast_to(t, model_latents.shape[0]) use_high_noise = jnp.greater_equal(t, boundary) - noise_pred, latents_out = jax.lax.cond(use_high_noise, high_noise_branch, low_noise_branch, (model_latents, timestep)) + noise_pred, latents_out = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + ( + model_latents, + timestep, + prompt_embeds_combined, + kv_cache_high, + kv_cache_low, + rotary_emb, + encoder_attention_mask_high, + encoder_attention_mask_low, + ), + ) new_latents, new_scheduler_state = scheduler.step( current_scheduler_state, noise_pred, t, latents_out, return_dict=False @@ -543,11 +663,19 @@ def scan_body(carry, t): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low if do_classifier_free_guidance: latents_doubled = jnp.concatenate([latents] * 2) @@ -560,6 +688,9 @@ def scan_body(carry, t): timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) else: timestep = jnp.broadcast_to(t, bsz) @@ -572,6 +703,9 @@ def scan_body(carry, t): prompt_embeds, do_classifier_free_guidance, guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0abe4fa5b..aa4bbba27 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -77,7 +77,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) return pipeline @@ -113,7 +119,13 @@ def prepare_latents( latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial - 
shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + shape = ( + batch_size, + num_latent_frames, + latent_height, + latent_width, + num_channels_latents, + ) if latents is None: latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) @@ -129,7 +141,12 @@ def prepare_latents( first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) mask_lat_size = mask_lat_size.reshape( - batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width + batch_size, + 1, + num_latent_frames, + self.vae_scale_factor_temporal, + latent_height, + latent_width, ) mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) @@ -158,6 +175,7 @@ def __call__( magcache_thresh: Optional[float] = None, magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, + use_kv_cache: bool = False, ): config = getattr(self, "config", None) if magcache_thresh is None: @@ -180,7 +198,6 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) - trace = {} t_cond_start = time.perf_counter() @@ -231,7 +248,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): trace["conditioning"] = time.perf_counter() - t_cond_start scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + self.scheduler_state, + num_inference_steps=num_inference_steps, + shape=latents.shape, ) graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) 
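The num_frames adjustment above (num_frames // scale * scale + 1) snaps the frame count to 1 modulo the temporal VAE stride, matching the causal VAE's "first frame plus whole temporal groups" layout. A quick worked example, assuming the usual Wan temporal stride of 4 and the conventional (n - 1) // stride + 1 latent-frame count (that formula is not shown in this hunk):

vae_scale_factor_temporal = 4  # assumed default; the pipeline reads it from the VAE config
for requested in (9, 16, 30, 81):
  adjusted = requested // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
  num_latent_frames = (adjusted - 1) // vae_scale_factor_temporal + 1
  print(f"{requested} -> {adjusted} frames, {num_latent_frames} latent frames")
# 9 -> 9, 3; 16 -> 17, 5; 30 -> 29, 8; 81 -> 81, 21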
@@ -262,6 +281,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): height=height, mag_ratios_base=self.config.mag_ratios_base_720p if height >= 720 else self.config.mag_ratios_base_480p, config=self.config, + use_kv_cache=use_kv_cache, ) t_denoise_start = time.perf_counter() @@ -311,6 +331,7 @@ def run_inference_2_1_i2v( height: int = 480, mag_ratios_base: Optional[List[float]] = None, config=None, + use_kv_cache: bool = False, ): do_cfg = guidance_scale > 1.0 @@ -330,9 +351,25 @@ def run_inference_2_1_i2v( image_embeds_combined = image_embeds condition_combined = condition + transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros(latents.shape) + rotary_emb = transformer_obj.rope(dummy_hidden_states) + + kv_cache = None + encoder_attention_mask = None + + if use_kv_cache: + kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds_combined, image_embeds_combined) + first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 - last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + last_profiling_step = np.clip( + first_profiling_step + profiler_steps - 1, + first_profiling_step, + num_inference_steps - 1, + ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False @@ -365,6 +402,9 @@ def scan_body(carry, t): skip_blocks=None, cached_residual=None, return_residual=False, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) noise_pred, _ = outputs @@ -393,7 +433,12 @@ def scan_body(carry, t): skip_blocks = False if use_magcache and do_cfg: skip_blocks, accumulated_state = magcache_step( - step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup + step, + mag_ratios, + accumulated_state, + magcache_thresh, + magcache_K, + skip_warmup, ) latents_input = latents @@ -417,6 +462,9 @@ def scan_body(carry, t): skip_blocks=bool(skip_blocks) if use_magcache and do_cfg else None, cached_residual=cached_residual if use_magcache and do_cfg else None, return_residual=True if use_magcache and do_cfg else False, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) if use_magcache and do_cfg: noise_pred, _, residual_x_cur = outputs diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index f466ec574..1ba54f2eb 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -95,7 +95,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, _, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) return pipeline @@ -126,7 +132,13 @@ def prepare_latents( latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial - shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + shape = ( + batch_size, + num_latent_frames, + 
latent_height, + latent_width, + num_channels_latents, + ) if latents is None: latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) @@ -144,7 +156,12 @@ def prepare_latents( first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) mask_lat_size = mask_lat_size.reshape( - batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width + batch_size, + 1, + num_latent_frames, + self.vae_scale_factor_temporal, + latent_height, + latent_width, ) mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) @@ -172,6 +189,7 @@ def __call__( rng: Optional[jax.Array] = None, use_cfg_cache: bool = False, use_sen_cache: bool = False, + use_kv_cache: bool = False, ): if use_cfg_cache and use_sen_cache: raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") @@ -202,7 +220,6 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) - trace = {} t_cond_start = time.perf_counter() @@ -256,7 +273,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): trace["conditioning"] = time.perf_counter() - t_cond_start scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + self.scheduler_state, + num_inference_steps=num_inference_steps, + shape=latents.shape, ) low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) 
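In the run_inference hunks below, both experts' caches and masks are packed into one operand tuple and each lax.cond branch discards the other's entries with `_`. That shape is forced by jax.lax.cond: both branches are traced against the same operand pytree, so the high- and low-noise paths must accept identical inputs even though each uses only its own cache. A stripped-down sketch of the pattern (the additions stand in for the transformer forward passes):

import jax
import jax.numpy as jnp

def select_expert(t, boundary, latents, cache_high, cache_low):
  def high_branch(ops):
    x, c_high, _ = ops  # the low-noise cache is ignored here
    return x + c_high   # stand-in for the high-noise forward pass

  def low_branch(ops):
    x, _, c_low = ops   # the high-noise cache is ignored here
    return x + c_low    # stand-in for the low-noise forward pass

  # Both branches receive the identical operand tuple.
  return jax.lax.cond(t >= boundary, high_branch, low_branch, (latents, cache_high, cache_low))

out = select_expert(jnp.int32(900), 875, jnp.zeros((2, 4)), jnp.ones((2, 4)), 2 * jnp.ones((2, 4)))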
@@ -288,6 +307,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): use_sen_cache=use_sen_cache, height=height, config=self.config, + use_kv_cache=use_kv_cache, ) t_denoise_start = time.perf_counter() @@ -344,10 +364,42 @@ def run_inference_2_2_i2v( use_sen_cache: bool = False, height: int = 480, config=None, + use_kv_cache: bool = False, ): do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + prompt_embeds_combined = ( + jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds + ) + if image_embeds is not None: + image_embeds_combined = ( + jnp.concatenate([image_embeds, image_embeds], axis=0) if do_classifier_free_guidance else image_embeds + ) + else: + image_embeds_combined = None + + low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros(latents.shape) + rotary_emb = low_transformer.rope(dummy_hidden_states) + + kv_cache_low = None + encoder_attention_mask_low = None + kv_cache_high = None + encoder_attention_mask_high = None + + if use_kv_cache: + kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache( + prompt_embeds_combined, image_embeds_combined + ) + + high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) + kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache( + prompt_embeds_combined, image_embeds_combined + ) + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) @@ -365,11 +417,6 @@ def run_inference_2_2_i2v( nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio)) num_train_timesteps = float(scheduler.config.num_train_timesteps) - prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - if image_embeds is not None: - image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) - else: - image_embeds_combined = None condition_doubled = jnp.concatenate([condition] * 2) # SenCache state @@ -386,11 +433,23 @@ def run_inference_2_2_i2v( t_float = float(timesteps_np[step]) / num_train_timesteps if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: - graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + graphdef, state, rest = ( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + ) guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] force_compute = ( @@ -411,6 +470,9 @@ def run_inference_2_2_i2v( prompt_embeds_combined, guidance_scale=guidance_scale, encoder_hidden_states_image=image_embeds_combined, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) ref_noise_pred = noise_pred @@ -447,6 +509,9 @@ def run_inference_2_2_i2v( prompt_embeds_combined, guidance_scale=guidance_scale, encoder_hidden_states_image=image_embeds_combined, + kv_cache=kv_cache, + 
rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) ref_noise_pred = noise_pred @@ -458,7 +523,7 @@ def run_inference_2_2_i2v( latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - print( + max_logging.log( f"[SenCache] Cached {cache_count}/{num_inference_steps} steps " f"({100*cache_count/num_inference_steps:.1f}% cache ratio)" ) @@ -483,14 +548,11 @@ def run_inference_2_2_i2v( # Pre-split embeds prompt_cond_embeds = prompt_embeds - prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if image_embeds is not None: image_embeds_cond = image_embeds - image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) else: image_embeds_cond = None - image_embeds_combined = None # Keep condition in both single and doubled forms condition_cond = condition @@ -534,11 +596,23 @@ def run_inference_2_2_i2v( is_cache_step = step_is_cache[step] if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: - graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + graphdef, state, rest = ( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + ) guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low if is_cache_step: # ── Cache step: cond-only forward + FFT frequency compensation ── @@ -547,6 +621,8 @@ def run_inference_2_2_i2v( latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1) latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) timestep = jnp.broadcast_to(t, bsz) + kv_cache_cond = jax.tree.map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None + encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( graphdef, state, @@ -560,6 +636,9 @@ def run_inference_2_2_i2v( w1=jnp.float32(w1), w2=jnp.float32(w2), encoder_hidden_states_image=image_embeds_cond, + kv_cache=kv_cache_cond, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask_cond, ) else: # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── @@ -567,7 +646,11 @@ def run_inference_2_2_i2v( latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) timestep = jnp.broadcast_to(t, bsz * 2) - noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + ( + noise_pred, + cached_noise_cond, + cached_noise_uncond, + ) = transformer_forward_pass_full_cfg( graphdef, state, rest, @@ -576,6 +659,9 @@ def run_inference_2_2_i2v( prompt_embeds_combined, guidance_scale=guidance_scale, encoder_hidden_states_image=image_embeds_combined, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) # BCFHW -> BFHWC @@ -584,7 +670,17 @@ def run_inference_2_2_i2v( # ── Original non-cache path ── def high_noise_branch(operands): - latents_input, ts_input, pe_input, ie_input = operands + ( + latents_input, + ts_input, + pe_input, + ie_input, 
+ kv_cache_high, + _, + r_emb, + mask_high, + _, + ) = operands latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) noise_pred, latents_out = transformer_forward_pass( high_noise_graphdef, @@ -596,11 +692,24 @@ def high_noise_branch(operands): do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_high, encoder_hidden_states_image=ie_input, + kv_cache=kv_cache_high, + rotary_emb=r_emb, + encoder_attention_mask=mask_high, ) return noise_pred, latents_out def low_noise_branch(operands): - latents_input, ts_input, pe_input, ie_input = operands + ( + latents_input, + ts_input, + pe_input, + ie_input, + _, + kv_cache_low, + r_emb, + _, + mask_low, + ) = operands latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) noise_pred, latents_out = transformer_forward_pass( low_noise_graphdef, @@ -612,19 +721,22 @@ def low_noise_branch(operands): do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_low, encoder_hidden_states_image=ie_input, + kv_cache=kv_cache_low, + rotary_emb=r_emb, + encoder_attention_mask=mask_low, ) return noise_pred, latents_out if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder - if image_embeds is not None: - image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) condition = jnp.concatenate([condition] * 2) first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 - last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + last_profiling_step = np.clip( + first_profiling_step + profiler_steps - 1, + first_profiling_step, + num_inference_steps - 1, + ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False @@ -644,7 +756,20 @@ def scan_body(carry, t): use_high_noise = jnp.greater_equal(t, boundary) noise_pred, _ = jax.lax.cond( - use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds) + use_high_noise, + high_noise_branch, + low_noise_branch, + ( + latent_model_input, + timestep, + prompt_embeds_combined, + image_embeds_combined, + kv_cache_high, + kv_cache_low, + rotary_emb, + encoder_attention_mask_high, + encoder_attention_mask_low, + ), ) noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) new_latents, new_scheduler_state = scheduler.step( @@ -675,7 +800,20 @@ def scan_body(carry, t): use_high_noise = jnp.greater_equal(t, boundary) noise_pred, _ = jax.lax.cond( - use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds) + use_high_noise, + high_noise_branch, + low_noise_branch, + ( + latent_model_input, + timestep, + prompt_embeds_combined, + image_embeds_combined, + kv_cache_high, + kv_cache_low, + rotary_emb, + encoder_attention_mask_high, + encoder_attention_mask_low, + ), ) noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index ac721189c..b12ae4142 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -499,6 +499,7 @@ def __call__( 
prompt_embeds: jax.Array | None = None, negative_prompt_embeds: jax.Array | None = None, vae_only: bool = False, + use_kv_cache: bool = False, ): """Runs the VACE model for the given inputs. @@ -639,6 +640,7 @@ def __call__( num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, + use_kv_cache=use_kv_cache, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -742,6 +744,8 @@ def transformer_forward_pass( control_hidden_states_scale: jax.Array, do_classifier_free_guidance: bool, guidance_scale: float, + kv_cache=None, + encoder_attention_mask=None, ): """Performs a forward pass on the transformer.""" wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) @@ -751,6 +755,8 @@ def transformer_forward_pass( encoder_hidden_states=prompt_embeds, control_hidden_states=control_hidden_states, control_hidden_states_scale=control_hidden_states_scale, + kv_cache=kv_cache, + encoder_attention_mask=encoder_attention_mask, ) if do_classifier_free_guidance: bsz = latents.shape[0] // 2 @@ -775,12 +781,19 @@ def run_inference( scheduler_state, control_hidden_states, control_hidden_states_scale, + use_kv_cache: bool = False, ): do_classifier_free_guidance = guidance_scale > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) control_hidden_states = jnp.concatenate([control_hidden_states] * 2) + transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state) + kv_cache = None + encoder_attention_mask = None + if use_kv_cache: + kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds) + for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if do_classifier_free_guidance: @@ -798,6 +811,8 @@ def run_inference( control_hidden_states_scale=control_hidden_states_scale, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, + kv_cache=kv_cache, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 16ed86357..67b053360 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -313,7 +313,7 @@ def __init__(self): def __getattr__(self, attr): if attr not in _config.keys: - raise ValueError(f"Requested key {attr}, not in config") + raise AttributeError(f"Requested key {attr}, not in config") return _config.keys[attr] def __setattr__(self, attr, value): diff --git a/src/maxdiffusion/tests/wan_kv_cache_test.py b/src/maxdiffusion/tests/wan_kv_cache_test.py new file mode 100644 index 000000000..38589db12 --- /dev/null +++ b/src/maxdiffusion/tests/wan_kv_cache_test.py @@ -0,0 +1,217 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import unittest +from unittest.mock import MagicMock, patch + +import flax +import jax + +import jax.numpy as jnp +import numpy as np + +from maxdiffusion import pyconfig +from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 +from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class WanKvCacheTest(unittest.TestCase): + + def setUp(self): + # Initialize pyconfig with base_wan_1_3b.yml and overrides some parameters for speed + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_1_3b.yml"), + "pretrained_model_name_or_path=Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "num_inference_steps=2", # Reduced steps for speed + "height=240", # Reduced resolution for speed (divisible by 16) + "width=416", # Reduced resolution for speed (divisible by 16) + "num_frames=9", # Reduced num_frames for speed + "attention=flash", + "scan_layers=False", + "jit_initializers=False", + "skip_jax_distributed_system=True", + ], + unittest=True, + ) + self.config = pyconfig.config + + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanModel.load_config") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.AutoencoderKLWan.load_config") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.load_wan_transformer") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.load_wan_vae") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_tokenizer") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_text_encoder") + @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_scheduler") + def test_wan_2_1_kv_cache( + self, + mock_load_scheduler_fn, + mock_load_text_encoder_fn, + mock_load_tokenizer_fn, + mock_load_wan_vae_fn, + mock_load_wan_transformer_fn, + mock_vae_load_config_fn, + mock_transformer_load_config_fn, + ): + # Mock transformer config + def mock_transformer_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs): + config_dict = { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "image_dim": None, + "in_channels": 16, + "num_attention_heads": 12, + "num_layers": 2, + "out_channels": 16, + "patch_size": [1, 2, 2], + "pos_embed_seq_len": None, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 1024, + "text_dim": 4096, + } + if return_unused_kwargs: + return config_dict, kwargs + return config_dict + + mock_transformer_load_config_fn.side_effect = mock_transformer_load_config + + # Mock VAE config + def mock_vae_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs): + config_dict = { + "attn_scales": [], + "base_dim": 96, + "dim_mult": [1, 2, 4, 4], + "dropout": 0.0, + "latents_mean": [0.0] * 16, + "latents_std": [1.0] * 16, + "num_res_blocks": 2, + "temperal_downsample": [False, True, True], + "z_dim": 16, + } + if return_unused_kwargs: + return config_dict, kwargs + return config_dict + + mock_vae_load_config_fn.side_effect = mock_vae_load_config + + # Mock weight loaders + def mock_load_wan_transformer(pretrained_model_name_or_path, eval_shapes, *args, **kwargs): + cpu = jax.local_devices(backend="cpu")[0] + flat_shapes = flax.traverse_util.flatten_dict(eval_shapes) + flat_params = {} + key = jax.random.key(42) + for k, shape_struct in flat_shapes.items(): + dtype = shape_struct.dtype + shape = shape_struct.shape + key, subkey = jax.random.split(key) + val = 
jax.random.normal(subkey, shape, dtype=dtype) + flat_params[k] = jax.device_put(val, device=cpu) + return flax.traverse_util.unflatten_dict(flat_params) + + mock_load_wan_transformer_fn.side_effect = mock_load_wan_transformer + + def mock_load_wan_vae(pretrained_model_name_or_path, eval_shapes, *args, **kwargs): + cpu = jax.local_devices(backend="cpu")[0] + flat_shapes = flax.traverse_util.flatten_dict(eval_shapes) + flat_params = {} + key = jax.random.key(42) + for k, shape_struct in flat_shapes.items(): + dtype = shape_struct.dtype + shape = shape_struct.shape + key, subkey = jax.random.split(key) + val = jax.random.normal(subkey, shape, dtype=dtype) + flat_params[k] = jax.device_put(val, device=cpu) + return flax.traverse_util.unflatten_dict(flat_params) + + mock_load_wan_vae_fn.side_effect = mock_load_wan_vae + + # Mock scheduler + def mock_load_scheduler(config): + scheduler = FlaxUniPCMultistepScheduler.from_config({ + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "flow_shift": config.flow_shift, + "num_train_timesteps": 1000, + "prediction_type": "flow_prediction", + "timestep_spacing": "linspace", + "use_flow_sigmas": True, + }) + state = scheduler.create_state() + return scheduler, state + + mock_load_scheduler_fn.side_effect = mock_load_scheduler + + mock_load_tokenizer_fn.return_value = MagicMock() + mock_load_text_encoder_fn.return_value = MagicMock() + + pipeline = WanPipeline2_1.from_pretrained(self.config) + + batch_size = 1 + height = self.config.height + width = self.config.width + num_frames = self.config.num_frames + + prompt_embeds = jnp.zeros((batch_size, 512, 4096), dtype=self.config.weights_dtype) + negative_prompt_embeds = jnp.zeros((batch_size, 512, 4096), dtype=self.config.weights_dtype) + + # Run without cache + video_no_cache, _ = pipeline( + prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt=None, + negative_prompt_embeds=negative_prompt_embeds, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=self.config.num_inference_steps, + use_kv_cache=False, + ) + + # Run with cache + video_with_cache, _ = pipeline( + prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt=None, + negative_prompt_embeds=negative_prompt_embeds, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=self.config.num_inference_steps, + use_kv_cache=True, + ) + + self.assertEqual(len(video_no_cache), batch_size) + self.assertEqual(video_no_cache[0].shape, (num_frames, height, width, 3)) + + self.assertEqual(len(video_with_cache), batch_size) + self.assertEqual(video_with_cache[0].shape, (num_frames, height, width, 3)) + + # Compare outputs + np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=0.7) + + +if __name__ == "__main__": + unittest.main()
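For reference, enabling the feature end to end is a plain config override, mirroring how this test drives it (whether generate_wan.py forwards config.use_kv_cache into the pipeline call is implied by the diffstat rather than shown in this excerpt):

from maxdiffusion import pyconfig

pyconfig.initialize(
    [
        None,
        "src/maxdiffusion/configs/base_wan_1_3b.yml",
        "use_kv_cache=True",
    ],
    unittest=True,
)
assert pyconfig.config.use_kv_cache
# Unknown keys now raise AttributeError (not ValueError), so hasattr/getattr probes behave as expected.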