Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions max/kernels/src/nn/conv/conv.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -5280,12 +5280,13 @@ def _conv3d_cudnn[
var algo: cudnnConvolutionFwdAlgo_t
var workspace_size_var: Int

if ptr_cached := _get_global_or_null(cache_key).bitcast[
_Conv3dAlgoCacheEntry
]():
if ptr_cached := _get_global_or_null(cache_key):
var cached = ptr_cached.unsafe_value().bitcast[
_Conv3dAlgoCacheEntry
]()
# Cache hit — reuse previously selected algorithm.
algo = ptr_cached[].algo()
workspace_size_var = ptr_cached[].workspace_size
algo = cached[].algo()
workspace_size_var = cached[].workspace_size
else:
# Cache miss — run FindEx to find the fastest algorithm.
var find_ws = ctx.enqueue_create_buffer[DType.uint8](FIND_WS_CAP)
Expand Down
6 changes: 6 additions & 0 deletions max/python/max/interfaces/provider_options/modality/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,9 @@ class VideoProviderOptions(BaseModel):
),
gt=0,
)

guidance_scale_2: float | None = Field(
None,
description="Secondary guidance scale for boundary timestep switching.",
gt=0.0,
)
3 changes: 3 additions & 0 deletions max/python/max/pipelines/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def register_all_models() -> None:
from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch
from .unified_eagle_llama3 import unified_eagle_llama3_arch
from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch
from .wan import wan_arch, wan_i2v_arch
from .z_image_modulev3 import z_image_arch

architectures = [
Expand Down Expand Up @@ -137,6 +138,8 @@ def register_all_models() -> None:
qwen3vl_moe_arch,
unified_eagle_llama3_arch,
unified_mtp_deepseekV3_arch,
wan_arch,
wan_i2v_arch,
z_image_arch,
]

Expand Down
25 changes: 25 additions & 0 deletions max/python/max/pipelines/architectures/wan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from .arch import wan_arch, wan_i2v_arch
from .model import WanTransformerModel
from .pipeline_wan import WanPipeline
from .pipeline_wan_i2v import WanI2VPipeline

# Names re-exported by this package: the two architecture registrations
# imported from .arch, plus the model and pipeline classes imported above.
__all__ = [
    "WanI2VPipeline",
    "WanPipeline",
    "WanTransformerModel",
    "wan_arch",
    "wan_i2v_arch",
]
87 changes: 87 additions & 0 deletions max/python/max/pipelines/architectures/wan/arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from __future__ import annotations

from dataclasses import dataclass

from max.graph.weights import WeightsFormat
from max.interfaces import PipelineTask
from max.pipelines.core import PixelContext
from max.pipelines.lib import (
PixelGenerationTokenizer,
SupportedArchitecture,
)
from max.pipelines.lib.config import MAXModelConfig, PipelineConfig
from max.pipelines.lib.interfaces import ArchConfig
from typing_extensions import Self

from .pipeline_wan import WanPipeline
from .pipeline_wan_i2v import WanI2VPipeline


@dataclass(kw_only=True)
class WanArchConfig(ArchConfig):
    """Arch-level configuration for the Wan pipelines.

    Implements the ``ArchConfig`` interface and carries only the owning
    ``PipelineConfig``; Wan does not use a KV cache.
    """

    pipeline_config: PipelineConfig

    @classmethod
    def initialize(
        cls,
        pipeline_config: PipelineConfig,
        model_config: MAXModelConfig | None = None,
    ) -> Self:
        """Build the config after checking the device layout.

        Args:
            pipeline_config: Pipeline configuration stored on the instance.
            model_config: Model configuration whose ``device_specs`` are
                checked. When falsy, ``pipeline_config.model`` is used.

        Returns:
            A new instance holding ``pipeline_config``.

        Raises:
            ValueError: If the selected model config does not list exactly
                one device spec.
        """
        effective_model = model_config or pipeline_config.model
        device_count = len(effective_model.device_specs)
        if device_count == 1:
            return cls(pipeline_config=pipeline_config)
        raise ValueError("Wan is only supported on a single device")

    def get_max_seq_len(self) -> int:
        """Return the fixed maximum sequence length (512)."""
        # Tokenizer padding length — matches diffusers __call__ default.
        return 512


# Architecture entry registered under the name "WanPipeline".
wan_arch = SupportedArchitecture(
    # Identity and task.
    name="WanPipeline",
    task=PipelineTask.PIXEL_GENERATION,
    example_repo_ids=[
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        "Wan-AI/Wan2.1-T2V-14B-Diffusers",
        "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
        "yetter-ai/Wan2.2-TI2V-5B-Turbo-Diffusers",
    ],
    # Weight encodings and on-disk format.
    supported_encodings={"bfloat16", "float32"},
    default_encoding="bfloat16",
    default_weights_format=WeightsFormat.safetensors,
    # Classes wired into the pipeline.
    config=WanArchConfig,
    context_type=PixelContext,
    tokenizer=PixelGenerationTokenizer,
    pipeline_model=WanPipeline,  # type: ignore[arg-type]
)

# Architecture entry registered under the name "WanImageToVideoPipeline".
wan_i2v_arch = SupportedArchitecture(
    # Identity and task.
    name="WanImageToVideoPipeline",
    task=PipelineTask.PIXEL_GENERATION,
    example_repo_ids=[
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
        "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
    ],
    # Weight encodings and on-disk format.
    supported_encodings={"bfloat16", "float32"},
    default_encoding="bfloat16",
    default_weights_format=WeightsFormat.safetensors,
    # Classes wired into the pipeline.
    config=WanArchConfig,
    context_type=PixelContext,
    tokenizer=PixelGenerationTokenizer,
    pipeline_model=WanI2VPipeline,
)
Loading
Loading