Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions max/python/max/interfaces/provider_options/modality/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,9 @@ class VideoProviderOptions(BaseModel):
),
gt=0,
)

guidance_scale_2: float | None = Field(
None,
description="Secondary guidance scale for boundary timestep switching.",
gt=0.0,
)
35 changes: 35 additions & 0 deletions max/python/max/pipelines/architectures/wan/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Wan-specific pixel generation context."""

from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np
import numpy.typing as npt
from max.pipelines.core import PixelContext


@dataclass(kw_only=True)
class WanContext(PixelContext):
    """``PixelContext`` extended with the extra state Wan video/MoE models use."""

    guidance_scale_2: float | None = None
    """Guidance scale used by the low-noise expert of MoE models, if set."""

    step_coefficients: npt.NDArray[np.float32] | None = None
    """Scheduler step coefficients computed ahead of time, if available."""

    boundary_timestep: float | None = None
    """Timestep threshold at which the high/low noise experts are switched."""
141 changes: 141 additions & 0 deletions max/python/max/pipelines/architectures/wan/tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Wan-specific pixel generation tokenizer."""

from __future__ import annotations

import logging

import numpy as np
import numpy.typing as npt
import PIL.Image
from max.interfaces.request import OpenResponsesRequest
from max.pipelines.lib.pixel_tokenizer import PixelGenerationTokenizer

from .context import WanContext

logger = logging.getLogger("max.pipelines")


class WanTokenizer(PixelGenerationTokenizer):
"""Wan-specific tokenizer that produces WanContext with video/MoE fields."""

def _select_wan_flow_shift(self, height: int, width: int) -> float:
    """Pick the scheduler flow shift for a ``height`` x ``width`` request.

    A ``flow_shift`` found under ``components.scheduler.config_dict`` in
    ``self.diffusers_config`` wins, unless it is missing or equal to 1.0.
    Otherwise the shift is interpolated linearly on the pixel count between
    the 480p and 720p reference points and clamped to that range.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        The flow shift as a float.
    """
    # Reference points: 480p (480*832) maps to 3.0, 720p (720*1280) to 5.0.
    px_480p, px_720p = 399_360, 921_600
    shift_480p, shift_720p = 3.0, 5.0

    components = self.diffusers_config.get("components", {})
    scheduler_entry = components.get("scheduler", {})
    configured = scheduler_entry.get("config_dict", {}).get("flow_shift")

    # An explicit, non-1.0 value in the scheduler config is a user override.
    if configured is not None:
        override = float(configured)
        if override != 1.0:
            return override

    # No override: interpolate on pixel count, clamped to the [0, 1] range.
    raw_fraction = (height * width - px_480p) / (px_720p - px_480p)
    fraction = max(0.0, min(1.0, raw_fraction))
    return shift_480p + fraction * (shift_720p - shift_480p)

async def new_context(
    self,
    request: OpenResponsesRequest,
    input_image: PIL.Image.Image | None = None,
) -> WanContext:
    """Build a ``WanContext`` for one generation request.

    The parent ``new_context`` is awaited first; its result supplies every
    shared field. On top of that this method:

    * reads ``num_frames`` and ``guidance_scale_2`` from the request's
      video provider options (both ``None`` when no video options are set);
    * when the scheduler reports ``use_flow_sigmas``, sets its
      ``flow_shift`` from the requested resolution and re-derives
      ``timesteps`` and ``sigmas``;
    * derives ``boundary_timestep`` from the ``boundary_ratio`` entry of
      ``self.diffusers_config`` when that entry is present;
    * collects ``step_coefficients`` when the scheduler offers
      ``build_step_coefficients``;
    * replaces the latents with a 5-D random tensor when ``num_frames``
      is given.

    Args:
        request: The incoming request; its body carries the provider
            options and the seed.
        input_image: Optional input image, forwarded unchanged to the
            parent implementation.

    Returns:
        A ``WanContext`` holding the parent context's fields plus the
        Wan-specific values computed here.
    """
    # NOTE(review): PR feedback on this method says the parent new_context
    # may already perform the flow-shift / sigma / latent-shape work that is
    # redone below. Confirm against the base class before removing either
    # copy.
    base = await super().new_context(request, input_image=input_image)

    video_options = request.body.provider_options.video

    num_frames: int | None = (
        video_options.num_frames if video_options else None
    )
    guidance_scale_2: float | None = (
        video_options.guidance_scale_2 if video_options else None
    )

    height = base.height
    width = base.width
    timesteps: npt.NDArray[np.float32] = base.timesteps
    sigmas: npt.NDArray[np.float32] = base.sigmas

    # Latent spatial dims, rounded down to an even number. Computed once
    # here and shared by the sigma schedule and the video latent shape.
    latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))

    if getattr(self._scheduler, "use_flow_sigmas", False):
        # The flow shift depends on resolution, so the schedule obtained
        # from the parent call is recomputed after updating the scheduler.
        self._scheduler.flow_shift = self._select_wan_flow_shift(
            height, width
        )
        image_seq_len = (latent_height // 2) * (latent_width // 2)
        timesteps, sigmas = self._scheduler.retrieve_timesteps_and_sigmas(
            image_seq_len, base.num_inference_steps
        )

    boundary_timestep: float | None = None
    boundary_ratio = self.diffusers_config.get("boundary_ratio")
    if boundary_ratio is not None:
        # Scale the ratio by the scheduler's training timestep count,
        # falling back to 1000 when the scheduler does not expose one.
        boundary_timestep = float(boundary_ratio) * float(
            getattr(self._scheduler, "num_train_timesteps", 1000)
        )

    step_coefficients: npt.NDArray[np.float32] | None = None
    if hasattr(self._scheduler, "build_step_coefficients"):
        step_coefficients = self._scheduler.build_step_coefficients()

    latents = base.latents
    if num_frames is not None:
        # Hard-coded temporal factor of 4.
        # NOTE(review): presumably the Wan VAE's temporal compression;
        # confirm against the VAE config.
        vae_scale_factor_temporal = 4
        latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
        shape_5d = (
            base.num_images_per_prompt,
            self._num_channels_latents,
            latent_frames,
            latent_height,
            latent_width,
        )
        latents = self._randn_tensor(shape_5d, request.body.seed)

    return WanContext(
        request_id=base.request_id,
        model_name=base.model_name,
        tokens=base.tokens,
        mask=base.mask,
        tokens_2=base.tokens_2,
        negative_tokens=base.negative_tokens,
        negative_mask=base.negative_mask,
        negative_tokens_2=base.negative_tokens_2,
        explicit_negative_prompt=base.explicit_negative_prompt,
        timesteps=timesteps,
        sigmas=sigmas,
        latents=latents,
        latent_image_ids=base.latent_image_ids,
        height=base.height,
        width=base.width,
        num_frames=num_frames,
        guidance_scale=base.guidance_scale,
        true_cfg_scale=base.true_cfg_scale,
        guidance_scale_2=guidance_scale_2,
        cfg_normalization=base.cfg_normalization,
        cfg_truncation=base.cfg_truncation,
        num_inference_steps=base.num_inference_steps,
        num_warmup_steps=base.num_warmup_steps,
        strength=base.strength,
        boundary_timestep=boundary_timestep,
        step_coefficients=step_coefficients,
        num_images_per_prompt=base.num_images_per_prompt,
        input_image=base.input_image,
        output_format=base.output_format,
        residual_threshold=base.residual_threshold,
        status=base.status,
    )
2 changes: 2 additions & 0 deletions max/python/max/pipelines/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,8 @@ class PixelContext:
"""Image encoding format for the output (e.g., 'jpeg', 'png', 'webp')."""
residual_threshold: float | None = field(default=None)
"""Per-request residual threshold for FBCache. None uses pipeline default."""
num_frames: int | None = field(default=None)
"""Number of frames for video generation."""
status: GenerationStatus = field(default=GenerationStatus.ACTIVE)

@property
Expand Down
59 changes: 48 additions & 11 deletions max/python/max/pipelines/lib/pixel_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class PipelineClassName(str, Enum):
FLUX2 = "Flux2Pipeline"
FLUX2_KLEIN = "Flux2KleinPipeline"
ZIMAGE = "ZImagePipeline"
WAN = "WanPipeline"
WAN_I2V = "WanImageToVideoPipeline"

@classmethod
def from_class_name(cls, class_name: str) -> PipelineClassName:
Expand Down Expand Up @@ -239,7 +241,12 @@ def __init__(
if self._pipeline_class_name == PipelineClassName.ZIMAGE:
self._num_channels_latents = transformer_config["in_channels"]
else:
self._num_channels_latents = transformer_config["in_channels"] // 4
out_channels = transformer_config.get("out_channels")
self._num_channels_latents = (
out_channels
if out_channels is not None
else transformer_config["in_channels"] // 4
)

# Create scheduler from its component config.
scheduler_config = models["scheduler"].huggingface_config
Expand Down Expand Up @@ -902,9 +909,17 @@ async def new_context(
" but may produce lower quality or unexpected results."
)

# Resolve negative_prompt: prefer video options for video pipelines.
video_options = request.body.provider_options.video
negative_prompt_resolved = (
video_options.negative_prompt
if video_options and video_options.negative_prompt
else None
) or image_options.negative_prompt

if (
image_options.true_cfg_scale > 1.0
and image_options.negative_prompt is None
and negative_prompt_resolved is None
):
logger.warning(
f"true_cfg_scale={image_options.true_cfg_scale} is set, but no negative_prompt "
Expand All @@ -928,7 +943,7 @@ async def new_context(
else:
do_true_cfg = (
image_options.true_cfg_scale > 1.0
and image_options.negative_prompt is not None
and negative_prompt_resolved is not None
)

# 1. Tokenize prompts
Expand All @@ -953,7 +968,7 @@ async def new_context(
) = await self._generate_tokens_ids(
prompt,
image_options.secondary_prompt,
image_options.negative_prompt,
negative_prompt_resolved,
image_options.secondary_negative_prompt,
do_true_cfg or do_zimage_cfg,
images=images_for_tokenization,
Expand Down Expand Up @@ -992,28 +1007,49 @@ async def new_context(
self._pipeline_class_name != PipelineClassName.ZIMAGE
),
)
height = image_options.height or preprocessed_image.height
width = image_options.width or preprocessed_image.width
height = (
(video_options and video_options.height)
or image_options.height
or preprocessed_image.height
)
width = (
(video_options and video_options.width)
or image_options.width
or preprocessed_image.width
)
preprocessed_image_array = np.array(
preprocessed_image, dtype=np.uint8
).copy()
else:
height = (
image_options.height or default_sample_size * vae_scale_factor
(video_options and video_options.height)
or image_options.height
or default_sample_size * vae_scale_factor
)
width = (
image_options.width or default_sample_size * vae_scale_factor
(video_options and video_options.width)
or image_options.width
or default_sample_size * vae_scale_factor
)

# 3. Resolve image dimensions using cached static values
latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)

video_steps = (
video_options.steps
if video_options and video_options.steps is not None
else None
)
num_inference_steps = (
image_options.steps
if "steps" in image_options.model_fields_set
else self._default_num_inference_steps
video_steps
if video_steps is not None
else (
image_options.steps
if "steps" in image_options.model_fields_set
else self._default_num_inference_steps
)
)
sigma_min = (
0.0
Expand Down Expand Up @@ -1092,6 +1128,7 @@ async def new_context(
input_image=preprocessed_image_array, # Pass numpy array instead of PIL.Image
output_format=image_options.output_format,
residual_threshold=image_options.residual_threshold,
num_frames=video_options.num_frames if video_options else None,
)

return context
Loading