diff --git a/max/python/max/interfaces/provider_options/modality/video.py b/max/python/max/interfaces/provider_options/modality/video.py index ef8a362a767..3a1911af880 100644 --- a/max/python/max/interfaces/provider_options/modality/video.py +++ b/max/python/max/interfaces/provider_options/modality/video.py @@ -66,3 +66,9 @@ class VideoProviderOptions(BaseModel): ), gt=0, ) + + guidance_scale_2: float | None = Field( + None, + description="Secondary guidance scale for boundary timestep switching.", + gt=0.0, + ) diff --git a/max/python/max/pipelines/architectures/wan/context.py b/max/python/max/pipelines/architectures/wan/context.py new file mode 100644 index 00000000000..dfdc291a276 --- /dev/null +++ b/max/python/max/pipelines/architectures/wan/context.py @@ -0,0 +1,35 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@dataclass(kw_only=True)
class WanContext(PixelContext):
    """PixelContext extended with Wan-specific video/MoE state.

    Every extra field is keyword-only and defaults to ``None``.
    """

    guidance_scale_2: float | None = None
    """Secondary guidance scale, for the low-noise expert of MoE models."""

    step_coefficients: npt.NDArray[np.float32] | None = None
    """Scheduler step coefficients computed ahead of time."""

    boundary_timestep: float | None = None
    """Timestep threshold for switching between high/low noise experts."""
# Anchors for the default flow-shift interpolation: a 480p frame
# (480 * 832 = 399 360 px) maps to a shift of 3.0 and a 720p frame
# (720 * 1280 = 921 600 px) maps to a shift of 5.0.
_WAN_FLOW_SHIFT_LOW_PIXELS = 480 * 832
_WAN_FLOW_SHIFT_HIGH_PIXELS = 720 * 1280
_WAN_FLOW_SHIFT_LOW = 3.0
_WAN_FLOW_SHIFT_HIGH = 5.0

# Temporal compression factor used to turn a frame count into latent frames.
_WAN_VAE_SCALE_FACTOR_TEMPORAL = 4

# Used when the scheduler does not expose ``num_train_timesteps``.
_WAN_DEFAULT_NUM_TRAIN_TIMESTEPS = 1000


class WanTokenizer(PixelGenerationTokenizer):
    """Wan-specific tokenizer that produces WanContext with video/MoE fields."""

    def _wan_latent_hw(self, height: int, width: int) -> tuple[int, int]:
        """Return the (latent_height, latent_width) for a pixel resolution.

        Each dimension is floor-divided by twice the VAE scale factor and
        doubled again, so both results are always even.
        """
        divisor = self._vae_scale_factor * 2
        return 2 * (int(height) // divisor), 2 * (int(width) // divisor)

    def _select_wan_flow_shift(self, height: int, width: int) -> float:
        """Pick the scheduler flow shift for the requested resolution.

        Args:
            height: Output height in pixels.
            width: Output width in pixels.

        Returns:
            The ``flow_shift`` from the scheduler config when it is set to
            anything other than 1.0; otherwise a value interpolated linearly
            on pixel count between the 480p and 720p anchors and clamped to
            that range.
        """
        scheduler_cfg = (
            self.diffusers_config.get("components", {})
            .get("scheduler", {})
            .get("config_dict", {})
        )
        # Use explicit flow_shift from scheduler config if set (user
        # override). A value of exactly 1.0 is treated as "not set".
        cfg_shift = scheduler_cfg.get("flow_shift")
        if cfg_shift is not None and float(cfg_shift) != 1.0:
            return float(cfg_shift)

        # Default: interpolate based on pixel count, clamped to [0, 1].
        pixels = height * width
        span = _WAN_FLOW_SHIFT_HIGH_PIXELS - _WAN_FLOW_SHIFT_LOW_PIXELS
        t = max(0.0, min(1.0, (pixels - _WAN_FLOW_SHIFT_LOW_PIXELS) / span))
        return _WAN_FLOW_SHIFT_LOW + t * (
            _WAN_FLOW_SHIFT_HIGH - _WAN_FLOW_SHIFT_LOW
        )

    async def new_context(
        self,
        request: OpenResponsesRequest,
        input_image: PIL.Image.Image | None = None,
    ) -> WanContext:
        """Build a WanContext for ``request``.

        The base context is created by the parent tokenizer and then
        extended with Wan-specific state:

        * If the scheduler has a truthy ``use_flow_sigmas`` attribute, its
          ``flow_shift`` is set from the resolution and the timesteps and
          sigmas are recomputed.
        * ``boundary_timestep`` is ``boundary_ratio * num_train_timesteps``
          when ``boundary_ratio`` is present in the diffusers config.
        * ``step_coefficients`` is taken from the scheduler when it offers
          ``build_step_coefficients``.
        * When the video options carry ``num_frames``, the base latents are
          replaced by freshly drawn 5-D latents.

        Args:
            request: Incoming generation request.
            input_image: Optional conditioning image forwarded to the parent.

        Returns:
            A WanContext carrying the base fields plus the Wan extras.
        """
        base = await super().new_context(request, input_image=input_image)

        video_options = request.body.provider_options.video
        num_frames: int | None = (
            video_options.num_frames if video_options else None
        )
        guidance_scale_2: float | None = (
            video_options.guidance_scale_2 if video_options else None
        )

        height = base.height
        width = base.width
        timesteps: npt.NDArray[np.float32] = base.timesteps
        sigmas: npt.NDArray[np.float32] = base.sigmas

        if getattr(self._scheduler, "use_flow_sigmas", False):
            # The scheduler lives on the tokenizer instance, so flow_shift is
            # assigned immediately before the schedule is rebuilt; no await
            # separates the two statements.
            self._scheduler.flow_shift = self._select_wan_flow_shift(
                height, width
            )
            latent_height, latent_width = self._wan_latent_hw(height, width)
            image_seq_len = (latent_height // 2) * (latent_width // 2)
            timesteps, sigmas = self._scheduler.retrieve_timesteps_and_sigmas(
                image_seq_len, base.num_inference_steps
            )

        boundary_timestep: float | None = None
        boundary_ratio = self.diffusers_config.get("boundary_ratio")
        if boundary_ratio is not None:
            boundary_timestep = float(boundary_ratio) * float(
                getattr(
                    self._scheduler,
                    "num_train_timesteps",
                    _WAN_DEFAULT_NUM_TRAIN_TIMESTEPS,
                )
            )

        step_coefficients: npt.NDArray[np.float32] | None = None
        if hasattr(self._scheduler, "build_step_coefficients"):
            step_coefficients = self._scheduler.build_step_coefficients()

        latents = base.latents
        if num_frames is not None:
            # NOTE(review): the floor division means a num_frames where
            # (num_frames - 1) is not a multiple of the temporal factor is
            # rounded down here, while the unrounded value is still stored on
            # the context below -- confirm downstream consumers expect that.
            latent_frames = (
                num_frames - 1
            ) // _WAN_VAE_SCALE_FACTOR_TEMPORAL + 1
            latent_height, latent_width = self._wan_latent_hw(height, width)
            shape_5d = (
                base.num_images_per_prompt,
                self._num_channels_latents,
                latent_frames,
                latent_height,
                latent_width,
            )
            # Presumably draws seeded Gaussian noise of the given shape;
            # _randn_tensor is defined on the parent class.
            latents = self._randn_tensor(shape_5d, request.body.seed)

        return WanContext(
            request_id=base.request_id,
            model_name=base.model_name,
            tokens=base.tokens,
            mask=base.mask,
            tokens_2=base.tokens_2,
            negative_tokens=base.negative_tokens,
            negative_mask=base.negative_mask,
            negative_tokens_2=base.negative_tokens_2,
            explicit_negative_prompt=base.explicit_negative_prompt,
            timesteps=timesteps,
            sigmas=sigmas,
            latents=latents,
            latent_image_ids=base.latent_image_ids,
            height=base.height,
            width=base.width,
            num_frames=num_frames,
            guidance_scale=base.guidance_scale,
            true_cfg_scale=base.true_cfg_scale,
            guidance_scale_2=guidance_scale_2,
            cfg_normalization=base.cfg_normalization,
            cfg_truncation=base.cfg_truncation,
            num_inference_steps=base.num_inference_steps,
            num_warmup_steps=base.num_warmup_steps,
            strength=base.strength,
            boundary_timestep=boundary_timestep,
            step_coefficients=step_coefficients,
            num_images_per_prompt=base.num_images_per_prompt,
            input_image=base.input_image,
            output_format=base.output_format,
            residual_threshold=base.residual_threshold,
            status=base.status,
        )
None uses pipeline default.""" + num_frames: int | None = field(default=None) + """Number of frames for video generation.""" status: GenerationStatus = field(default=GenerationStatus.ACTIVE) @property diff --git a/max/python/max/pipelines/lib/pixel_tokenizer.py b/max/python/max/pipelines/lib/pixel_tokenizer.py index b404c0a7af5..0d631042c1a 100644 --- a/max/python/max/pipelines/lib/pixel_tokenizer.py +++ b/max/python/max/pipelines/lib/pixel_tokenizer.py @@ -99,6 +99,8 @@ class PipelineClassName(str, Enum): FLUX2 = "Flux2Pipeline" FLUX2_KLEIN = "Flux2KleinPipeline" ZIMAGE = "ZImagePipeline" + WAN = "WanPipeline" + WAN_I2V = "WanImageToVideoPipeline" @classmethod def from_class_name(cls, class_name: str) -> PipelineClassName: @@ -239,7 +241,12 @@ def __init__( if self._pipeline_class_name == PipelineClassName.ZIMAGE: self._num_channels_latents = transformer_config["in_channels"] else: - self._num_channels_latents = transformer_config["in_channels"] // 4 + out_channels = transformer_config.get("out_channels") + self._num_channels_latents = ( + out_channels + if out_channels is not None + else transformer_config["in_channels"] // 4 + ) # Create scheduler from its component config. scheduler_config = models["scheduler"].huggingface_config @@ -902,9 +909,17 @@ async def new_context( " but may produce lower quality or unexpected results." ) + # Resolve negative_prompt: prefer video options for video pipelines. 
+ video_options = request.body.provider_options.video + negative_prompt_resolved = ( + video_options.negative_prompt + if video_options and video_options.negative_prompt + else None + ) or image_options.negative_prompt + if ( image_options.true_cfg_scale > 1.0 - and image_options.negative_prompt is None + and negative_prompt_resolved is None ): logger.warning( f"true_cfg_scale={image_options.true_cfg_scale} is set, but no negative_prompt " @@ -928,7 +943,7 @@ async def new_context( else: do_true_cfg = ( image_options.true_cfg_scale > 1.0 - and image_options.negative_prompt is not None + and negative_prompt_resolved is not None ) # 1. Tokenize prompts @@ -953,7 +968,7 @@ async def new_context( ) = await self._generate_tokens_ids( prompt, image_options.secondary_prompt, - image_options.negative_prompt, + negative_prompt_resolved, image_options.secondary_negative_prompt, do_true_cfg or do_zimage_cfg, images=images_for_tokenization, @@ -992,17 +1007,29 @@ async def new_context( self._pipeline_class_name != PipelineClassName.ZIMAGE ), ) - height = image_options.height or preprocessed_image.height - width = image_options.width or preprocessed_image.width + height = ( + (video_options and video_options.height) + or image_options.height + or preprocessed_image.height + ) + width = ( + (video_options and video_options.width) + or image_options.width + or preprocessed_image.width + ) preprocessed_image_array = np.array( preprocessed_image, dtype=np.uint8 ).copy() else: height = ( - image_options.height or default_sample_size * vae_scale_factor + (video_options and video_options.height) + or image_options.height + or default_sample_size * vae_scale_factor ) width = ( - image_options.width or default_sample_size * vae_scale_factor + (video_options and video_options.width) + or image_options.width + or default_sample_size * vae_scale_factor ) # 3. 
Resolve image dimensions using cached static values @@ -1010,10 +1037,19 @@ async def new_context( latent_width = 2 * (int(width) // (self._vae_scale_factor * 2)) image_seq_len = (latent_height // 2) * (latent_width // 2) + video_steps = ( + video_options.steps + if video_options and video_options.steps is not None + else None + ) num_inference_steps = ( - image_options.steps - if "steps" in image_options.model_fields_set - else self._default_num_inference_steps + video_steps + if video_steps is not None + else ( + image_options.steps + if "steps" in image_options.model_fields_set + else self._default_num_inference_steps + ) ) sigma_min = ( 0.0 @@ -1092,6 +1128,7 @@ async def new_context( input_image=preprocessed_image_array, # Pass numpy array instead of PIL.Image output_format=image_options.output_format, residual_threshold=image_options.residual_threshold, + num_frames=video_options.num_frames if video_options else None, ) return context