From b387126742ec595cadc0200c66ad9ba3d2198082 Mon Sep 17 00:00:00 2001 From: jglee-sqbits Date: Wed, 1 Apr 2026 04:15:11 +0000 Subject: [PATCH] [MAX] Add Wan VAE and refactor autoencoder module ## Summary Add a Wan VAE (3D causal video autoencoder) and restructure the autoencoder module to separate Module V2 and V3 implementations. ## Description ### Wan VAE - Implements the Wan 3D causal VAE with temporal caching for chunked encode/decode - Encoder: processes video in temporal chunks (first frame + subsequent chunks) with cached convolution state to maintain temporal consistency - Decoder: same chunked approach with 3 specialized graphs (post-quant conv, first frame, subsequent frames) - Uses symbolic spatial dims for resolution flexibility - Adds 3D convolution support via cuDNN (`conv.mojo`) with depth-tiled execution for large volumes ### Autoencoder restructuring - Moves existing Module V3 (Flux) autoencoder files to `autoencoders_modulev3/` - The `autoencoders/` directory now contains Module V2 graph-based implementations (Wan VAE, Qwen Image VAE) - Updates `flux1_modulev3` and `flux2_modulev3` import paths accordingly This follows the same pattern as #6278 which established the V2/V3 split. ## Dependencies Should be merged **before** #6301 (transformer) and #6302 (pipeline-t2v), which import from `autoencoders`. ## Checklist - [x] PR is small and focused - [x] I ran `./bazelw run format` to format my changes Assisted-by: Claude Code Assisted-by: Claude Code stack-info: PR: https://github.com/SqueezeBits/modular/pull/15, branch: jglee-sqbits/stack/3 --- max/kernels/src/nn/conv/conv.mojo | 714 ++++- .../architectures/autoencoders/__init__.py | 7 +- .../autoencoders/autoencoder_kl.py | 117 - .../autoencoders/autoencoder_kl_flux2.py | 280 -- .../autoencoders/autoencoder_kl_wan.py | 628 ++++ .../autoencoders/decode_step_flux2.py | 147 - .../autoencoders/layers/__init__.py | 17 - .../autoencoders/layers/attention.py | 145 - .../autoencoders/layers/downsampling.py | 165 - .../autoencoders/layers/resnet.py | 146 - .../autoencoders/layers/upsampling.py | 138 - .../architectures/autoencoders/model.py | 225 -- .../autoencoders/model_config.py | 77 + .../architectures/autoencoders/vae.py | 2683 ++++++++++++----- .../architectures/flux2/pipeline_flux2.py | 2 +- 15 files changed, 3406 insertions(+), 2085 deletions(-) delete mode 100644 max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_flux2.py create mode 100644 max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_wan.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/decode_step_flux2.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/layers/__init__.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/layers/attention.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/layers/downsampling.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/layers/resnet.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/layers/upsampling.py delete mode 100644 max/python/max/pipelines/architectures/autoencoders/model.py diff --git a/max/kernels/src/nn/conv/conv.mojo b/max/kernels/src/nn/conv/conv.mojo index 3288a76b6d3..c6e5dbae505 100644 --- a/max/kernels/src/nn/conv/conv.mojo +++ b/max/kernels/src/nn/conv/conv.mojo @@ -21,14 +21,17 @@ from std.sys.info import align_of, simd_width_of from _cudnn.cnn_infer import ( cudnnConvolutionForward, + cudnnConvolutionFwdAlgoPerfStruct, cudnnConvolutionMode_t, cudnnConvolutionStruct, cudnnCreateConvolutionDescriptor, cudnnDestroyConvolutionDescriptor, + cudnnFindConvolutionForwardAlgorithmEx, cudnnGetConvolutionForwardWorkspaceSize, cudnnSetConvolution2dDescriptor, cudnnSetConvolutionGroupCount, cudnnSetConvolutionMathType, + cudnnSetConvolutionNdDescriptor, cudnnGetConvolutionForwardAlgorithm_v7, cudnnConvolutionFwdAlgoPerf_t, ) @@ -45,8 +48,10 @@ from _cudnn.infer import ( cudnnFilterStruct, cudnnMathType_t, cudnnSetFilter4dDescriptor, + cudnnSetFilterNdDescriptor, cudnnSetStream, cudnnSetTensor4dDescriptor, + cudnnSetTensorNdDescriptorEx, cudnnStatus_t, cudnnTensorFormat_t, cudnnTensorStruct, @@ -4642,20 +4647,33 @@ def conv_gpu[ ) elif input_lt.rank == 5: - var grid_dim_x = ceildiv( - output_lt.dim[2]() * output_lt.dim[3](), block_size - ) # h * w / block size for 3d - ctx.enqueue_function[conv_gpu_3d, conv_gpu_3d]( - input_lt, - filter_lt, - output_lt, - stride, - dilation, - symmetric_padding, - num_groups, - grid_dim=(grid_dim_x, grid_dim_y, grid_dim_z), - block_dim=(block_size, block_size), - ) + + comptime if filter_is_fcrs: + conv3d_cudnn[input_type, filter_type, output_type]( + input_lt, + filter_lt, + output_lt, + rebind[IndexList[3]](stride), + rebind[IndexList[3]](dilation), + rebind[IndexList[3]](symmetric_padding), + num_groups, + ctx, + ) + else: + var grid_dim_x = ceildiv( + output_lt.dim[2]() * output_lt.dim[3](), block_size + ) # h * w / block size for 3d + ctx.enqueue_function[conv_gpu_3d, conv_gpu_3d]( + input_lt, + filter_lt, + output_lt, + stride, + dilation, + symmetric_padding, + num_groups, + grid_dim=(grid_dim_x, grid_dim_y, grid_dim_z), + block_dim=(block_size, block_size), + ) def conv3d_gpu_naive_ndhwc_qrscf[ @@ -4778,3 +4796,671 @@ def conv3d_gpu_naive_ndhwc_qrscf[ ), value.cast[output_type](), ) + + +# ===----------------------------------------------------------------------=== # +# GPU 3D Convolution using cuDNN (Nd APIs) # +# ===----------------------------------------------------------------------=== # + + +@fieldwise_init +struct _Conv3dAlgoCacheEntry(Copyable, Movable): + """Cached cuDNN algorithm selection result for a conv3d shape.""" + + var algo_value: Int8 + var workspace_size: Int + + def algo(self) -> cudnnConvolutionFwdAlgo_t: + return rebind[cudnnConvolutionFwdAlgo_t](self.algo_value) + + +def _conv3d_cudnn_depth_tiled[ + input_type: DType, + filter_type: DType, + output_type: DType, +]( + input: LayoutTensor[input_type, ...], + filter: LayoutTensor[filter_type, ...], + output: LayoutTensor[output_type, ...], + stride: IndexList[3], + dilation: IndexList[3], + padding: IndexList[3], + num_groups: Int, + ctx: DeviceContext, +) raises: + """Depth-tiled cuDNN 3D convolution for tensors exceeding INT32_MAX elements. + + Splits the computation along the depth dimension (dim[1] in NDHWC) into + tiles small enough for cuDNN's internal Int32 stride calculations. + Each tile uses a separate set of cuDNN descriptors. + """ + comptime INT32_MAX_VAL = 2147483647 + comptime FIND_WS_CAP = 256 * 1024 * 1024 + + var N = input.dim[0]() + var D_in = input.dim[1]() + var H = input.dim[2]() + var W = input.dim[3]() + var C = input.dim[4]() + + var K_d = filter.dim[2]() # kernel depth (Q in FCQRS) + var F_out = filter.dim[0]() # output channels + var D_out = output.dim[1]() + var H_out = output.dim[2]() + var W_out = output.dim[3]() + + var eff_k = (K_d - 1) * dilation[0] + 1 # effective kernel depth + + # Calculate max input depth per tile. + var per_frame_in = N * H * W * C + var max_d_in = INT32_MAX_VAL // per_frame_in + + # Also ensure output elements per tile fit in INT32. + var per_frame_out = N * H_out * W_out * F_out + var max_d_out = INT32_MAX_VAL // per_frame_out + # Output frames from max_d_in input frames: + var tile_d_out_from_in = (max_d_in + 2 * padding[0] - eff_k) // stride[ + 0 + ] + 1 + var tile_d_out = min(tile_d_out_from_in, max_d_out) + if tile_d_out < 1: + raise "conv3d: tensor too large even for single-frame tiling" + + # Input depth needed for tile_d_out output frames. + var tile_d_in = (tile_d_out - 1) * stride[0] + eff_k - 2 * padding[0] + + # Strides (in elements) along the depth dimension. + var in_d_stride = H * W * C # elements per depth frame + var out_d_stride = H_out * W_out * F_out + + var ptr_meta = _get_cudnn_meta(ctx) + + # Descriptor arrays (reused across tiles). + var input_dims = alloc[Int32](5) + var output_dims = alloc[Int32](5) + var filter_dims = alloc[Int32](5) + var pad_a = alloc[Int32](3) + var stride_a = alloc[Int32](3) + var dilation_a = alloc[Int32](3) + + # Filter dims (constant across tiles). + filter_dims[0] = Int32(filter.dim[0]()) + filter_dims[1] = Int32(filter.dim[1]()) + filter_dims[2] = Int32(filter.dim[2]()) + filter_dims[3] = Int32(filter.dim[3]()) + filter_dims[4] = Int32(filter.dim[4]()) + + check_cudnn_error( + cudnnSetFilterNdDescriptor( + ptr_meta[].ptr_filter_desc, + get_cudnn_dtype[filter_type](), + cudnnTensorFormat_t.CUDNN_TENSOR_NCHW, + Int16(5), + filter_dims.bitcast[NoneType](), + ) + ) + + # Convolution params (constant except padding for first tile). + stride_a[0] = Int32(stride[0]) + stride_a[1] = Int32(stride[1]) + stride_a[2] = Int32(stride[2]) + dilation_a[0] = Int32(dilation[0]) + dilation_a[1] = Int32(dilation[1]) + dilation_a[2] = Int32(dilation[2]) + + var alpha = Float32(1.0) + var beta = Float32(0.0) + + var d_out_start = 0 + while d_out_start < D_out: + var this_d_out = min(tile_d_out, D_out - d_out_start) + + # Determine input range for this output tile. + # First tile gets front padding, last tile gets back padding. + var d_in_start: Int + var this_d_in: Int + var tile_pad_front: Int + var tile_pad_back: Int + + if d_out_start == 0: + # First tile: include front padding. + tile_pad_front = padding[0] + d_in_start = 0 + this_d_in = ( + (this_d_out - 1) * stride[0] + eff_k - 2 * tile_pad_front + ) + # Adjust: no need for more input than available + if this_d_in > D_in: + this_d_in = D_in + tile_pad_back = 0 + else: + tile_pad_front = 0 + # For stride=1: input frame for output d is at d (with padding=0) + d_in_start = d_out_start * stride[0] - padding[0] + if d_in_start < 0: + tile_pad_front = -d_in_start + d_in_start = 0 + this_d_in = (this_d_out - 1) * stride[0] + eff_k - tile_pad_front + # Check if we need back padding + if d_in_start + this_d_in > D_in: + tile_pad_back = d_in_start + this_d_in - D_in + this_d_in = D_in - d_in_start + else: + tile_pad_back = 0 + + # --- Set up tile descriptors --- + # Input tile: [N, this_d_in, H, W, C] + input_dims[0] = Int32(N) + input_dims[1] = Int32(C) + input_dims[2] = Int32(this_d_in) + input_dims[3] = Int32(H) + input_dims[4] = Int32(W) + + check_cudnn_error( + cudnnSetTensorNdDescriptorEx( + ptr_meta[].ptr_input_desc, + cudnnTensorFormat_t.CUDNN_TENSOR_NHWC, + get_cudnn_dtype[input_type](), + Int16(5), + input_dims.bitcast[NoneType](), + ) + ) + + # Output tile: [N, this_d_out, H_out, W_out, F] + output_dims[0] = Int32(N) + output_dims[1] = Int32(F_out) + output_dims[2] = Int32(this_d_out) + output_dims[3] = Int32(H_out) + output_dims[4] = Int32(W_out) + + check_cudnn_error( + cudnnSetTensorNdDescriptorEx( + ptr_meta[].ptr_output_desc, + cudnnTensorFormat_t.CUDNN_TENSOR_NHWC, + get_cudnn_dtype[output_type](), + Int16(5), + output_dims.bitcast[NoneType](), + ) + ) + + # Convolution with tile-specific depth padding. + pad_a[0] = Int32(tile_pad_front) + pad_a[1] = Int32(padding[1]) + pad_a[2] = Int32(padding[2]) + + check_cudnn_error( + cudnnSetConvolutionNdDescriptor( + ptr_meta[].ptr_conv_desc, + Int16(3), + pad_a.bitcast[NoneType](), + stride_a.bitcast[NoneType](), + dilation_a.bitcast[NoneType](), + cudnnConvolutionMode_t.CUDNN_CROSS_CORRELATION, + cudnnDataType_t.CUDNN_DATA_FLOAT, + ) + ) + check_cudnn_error( + cudnnSetConvolutionGroupCount( + ptr_meta[].ptr_conv_desc, Int16(num_groups) + ) + ) + check_cudnn_error( + cudnnSetConvolutionMathType( + ptr_meta[].ptr_conv_desc, + cudnnMathType_t.CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION, + ) + ) + + # --- Algorithm selection (use GetWorkspaceSize for PRECOMP_GEMM) --- + var algo = ( + cudnnConvolutionFwdAlgo_t.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + ) + var ws_size: Int = 0 + var ws_st = cudnnGetConvolutionForwardWorkspaceSize( + ptr_meta[].ptr_handle, + ptr_meta[].ptr_input_desc, + ptr_meta[].ptr_filter_desc, + ptr_meta[].ptr_conv_desc, + ptr_meta[].ptr_output_desc, + algo, + UnsafePointer(to=ws_size), + ) + if ws_st != cudnnStatus_t.CUDNN_STATUS_SUCCESS or ws_size > FIND_WS_CAP: + # Fall back to IMPLICIT_GEMM (no workspace needed). + algo = ( + cudnnConvolutionFwdAlgo_t.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM + ) + ws_size = 0 + + # --- Execute tile --- + var workspace_buffer = ctx.enqueue_create_buffer[DType.uint8](ws_size) + + # Compute pointer offsets for input and output tiles. + var in_offset = d_in_start * in_d_stride + var out_offset = d_out_start * out_d_stride + var in_ptr = input.ptr + in_offset + var out_ptr = output.ptr + out_offset + + var fwd_status = cudnnConvolutionForward( + ptr_meta[].ptr_handle, + UnsafePointer(to=alpha).bitcast[NoneType](), + ptr_meta[].ptr_input_desc, + in_ptr.bitcast[NoneType](), + ptr_meta[].ptr_filter_desc, + filter.ptr.bitcast[NoneType](), + ptr_meta[].ptr_conv_desc, + algo, + workspace_buffer.unsafe_ptr().bitcast[NoneType](), + ws_size, + UnsafePointer(to=beta).bitcast[NoneType](), + ptr_meta[].ptr_output_desc, + out_ptr.bitcast[NoneType](), + ) + _ = workspace_buffer^ + + if fwd_status != cudnnStatus_t.CUDNN_STATUS_SUCCESS: + input_dims.free() + output_dims.free() + filter_dims.free() + pad_a.free() + stride_a.free() + dilation_a.free() + ctx.synchronize() + raise String("conv3d tiled forward failed: ", fwd_status) + + d_out_start += this_d_out + + # Clean up. + input_dims.free() + output_dims.free() + filter_dims.free() + pad_a.free() + stride_a.free() + dilation_a.free() + + +def _conv3d_cudnn[ + input_type: DType, + filter_type: DType, + output_type: DType, +]( + input: LayoutTensor[input_type, ...], + filter: LayoutTensor[filter_type, ...], + output: LayoutTensor[output_type, ...], + stride: IndexList[3], + dilation: IndexList[3], + padding: IndexList[3], + num_groups: Int, + ctx: DeviceContext, +) raises: + """cuDNN 3D convolution using Nd descriptor APIs. + + Expects: + - input: NDHWC layout [N, D, H, W, C] + - filter: FCQRS layout [F, C/groups, Q, R, S] + - output: NDHWC layout [N, D_out, H_out, W_out, F] + + Algorithm selection is cached per unique shape+params combination so that + the expensive FindEx search only runs once per shape. + + When the total number of elements exceeds INT32_MAX (~2.1B), cuDNN's + internal stride calculations overflow. In this case we tile along the + depth (D) dimension, processing each tile with a separate cuDNN call. + """ + comptime FIND_WS_CAP = 256 * 1024 * 1024 + comptime INT32_MAX_VAL = 2147483647 + + # --- Check if depth tiling is needed (INT32 stride overflow) --- + var total_in = ( + input.dim[0]() + * input.dim[1]() + * input.dim[2]() + * input.dim[3]() + * input.dim[4]() + ) + if total_in > INT32_MAX_VAL: + _conv3d_cudnn_depth_tiled( + input, + filter, + output, + stride, + dilation, + padding, + num_groups, + ctx, + ) + return + + var ptr_meta = _get_cudnn_meta(ctx) + + # --- Set up cuDNN descriptors (required every call — shared state) --- + # Input: NDHWC in memory, described as NHWC format with dims [N,C,D,H,W]. + var input_dims = alloc[Int32](5) + input_dims[0] = Int32(input.dim[0]()) # N + input_dims[1] = Int32(input.dim[4]()) # C + input_dims[2] = Int32(input.dim[1]()) # D + input_dims[3] = Int32(input.dim[2]()) # H + input_dims[4] = Int32(input.dim[3]()) # W + + check_cudnn_error( + cudnnSetTensorNdDescriptorEx( + ptr_meta[].ptr_input_desc, + cudnnTensorFormat_t.CUDNN_TENSOR_NHWC, + get_cudnn_dtype[input_type](), + Int16(5), + input_dims.bitcast[NoneType](), + ) + ) + + # Filter: FCQRS layout [F, C/groups, Q, R, S], described as NCHW format. + var filter_dims = alloc[Int32](5) + filter_dims[0] = Int32(filter.dim[0]()) # F (out_channels) + filter_dims[1] = Int32(filter.dim[1]()) # C (in_channels / groups) + filter_dims[2] = Int32(filter.dim[2]()) # Q (depth) + filter_dims[3] = Int32(filter.dim[3]()) # R (height) + filter_dims[4] = Int32(filter.dim[4]()) # S (width) + + check_cudnn_error( + cudnnSetFilterNdDescriptor( + ptr_meta[].ptr_filter_desc, + get_cudnn_dtype[filter_type](), + cudnnTensorFormat_t.CUDNN_TENSOR_NCHW, + Int16(5), + filter_dims.bitcast[NoneType](), + ) + ) + + # Convolution: 3 spatial dimensions. + var pad_a = alloc[Int32](3) + pad_a[0] = Int32(padding[0]) + pad_a[1] = Int32(padding[1]) + pad_a[2] = Int32(padding[2]) + + var stride_a = alloc[Int32](3) + stride_a[0] = Int32(stride[0]) + stride_a[1] = Int32(stride[1]) + stride_a[2] = Int32(stride[2]) + + var dilation_a = alloc[Int32](3) + dilation_a[0] = Int32(dilation[0]) + dilation_a[1] = Int32(dilation[1]) + dilation_a[2] = Int32(dilation[2]) + + check_cudnn_error( + cudnnSetConvolutionNdDescriptor( + ptr_meta[].ptr_conv_desc, + Int16(3), + pad_a.bitcast[NoneType](), + stride_a.bitcast[NoneType](), + dilation_a.bitcast[NoneType](), + cudnnConvolutionMode_t.CUDNN_CROSS_CORRELATION, + cudnnDataType_t.CUDNN_DATA_FLOAT, + ) + ) + + check_cudnn_error( + cudnnSetConvolutionGroupCount( + ptr_meta[].ptr_conv_desc, Int16(num_groups) + ) + ) + + # Output: NDHWC in memory, described as NHWC format with dims [N,C,D,H,W]. + var output_dims = alloc[Int32](5) + output_dims[0] = Int32(output.dim[0]()) # N + output_dims[1] = Int32(output.dim[4]()) # C (out_channels) + output_dims[2] = Int32(output.dim[1]()) # D_out + output_dims[3] = Int32(output.dim[2]()) # H_out + output_dims[4] = Int32(output.dim[3]()) # W_out + + check_cudnn_error( + cudnnSetTensorNdDescriptorEx( + ptr_meta[].ptr_output_desc, + cudnnTensorFormat_t.CUDNN_TENSOR_NHWC, + get_cudnn_dtype[output_type](), + Int16(5), + output_dims.bitcast[NoneType](), + ) + ) + + # Allow tensor-op math with automatic type conversion — required for + # bfloat16 3D convolutions on modern cuDNN (matches PR #5988 approach). + check_cudnn_error( + cudnnSetConvolutionMathType( + ptr_meta[].ptr_conv_desc, + cudnnMathType_t.CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION, + ) + ) + + # --- Algorithm selection (cached per shape) --- + var cache_key = String( + "CONV3D_ALGO_", + ctx.id(), + "_", + input.dim[0](), + "_", + input.dim[4](), + "_", + input.dim[1](), + "_", + input.dim[2](), + "_", + input.dim[3](), + "_F", + filter.dim[0](), + "_", + filter.dim[1](), + "_", + filter.dim[2](), + "_", + filter.dim[3](), + "_", + filter.dim[4](), + "_p", + padding[0], + "_", + padding[1], + "_", + padding[2], + "_s", + stride[0], + "_", + stride[1], + "_", + stride[2], + "_d", + dilation[0], + "_", + dilation[1], + "_", + dilation[2], + "_g", + num_groups, + ) + + var algo: cudnnConvolutionFwdAlgo_t + var workspace_size_var: Int + + if ptr_cached := _get_global_or_null(cache_key).bitcast[ + _Conv3dAlgoCacheEntry + ](): + # Cache hit — reuse previously selected algorithm. + algo = ptr_cached[].algo() + workspace_size_var = ptr_cached[].workspace_size + else: + # Cache miss — run FindEx to find the fastest algorithm. + var find_ws = ctx.enqueue_create_buffer[DType.uint8](FIND_WS_CAP) + + # CRITICAL: The Mojo cudnnConvolutionFwdAlgoPerfStruct uses Int8 for + # enum fields, but the C struct uses int (4 bytes). This causes a + # size mismatch: Mojo struct = ~32 bytes, C struct = 48 bytes. + # Allocating with the Mojo struct size would cause a buffer overflow + # when cuDNN writes 8 * 48 = 384 bytes. We allocate raw bytes with + # the correct C struct size and read fields at proper offsets. + comptime C_PERF_STRUCT_SIZE = 48 # sizeof(cudnnConvolutionFwdAlgoPerf_t) + comptime MAX_ALGOS = 8 + var perf_bytes = alloc[UInt8](MAX_ALGOS * C_PERF_STRUCT_SIZE) + + # returned_algo_count is int* in C (4 bytes), not Int16*. + # Use Int32 and bitcast the pointer. + var returned_count_i32 = Int32(0) + + var find_status = cudnnFindConvolutionForwardAlgorithmEx( + ptr_meta[].ptr_handle, + ptr_meta[].ptr_input_desc, + input.ptr.bitcast[NoneType](), + ptr_meta[].ptr_filter_desc, + filter.ptr.bitcast[NoneType](), + ptr_meta[].ptr_conv_desc, + ptr_meta[].ptr_output_desc, + output.ptr.bitcast[NoneType](), + Int16(MAX_ALGOS), + UnsafePointer(to=returned_count_i32).bitcast[Int16](), + perf_bytes.bitcast[cudnnConvolutionFwdAlgoPerfStruct](), + find_ws.unsafe_ptr().bitcast[NoneType](), + FIND_WS_CAP, + ) + _ = find_ws^ + + # Read the returned count (C int at offset 0 of returned_count_i32). + var returned_count = Int(returned_count_i32) + + # Pick the fastest successful algorithm within workspace cap. + # Read fields from raw bytes at correct C struct offsets: + # offset 0: algo (int32) + # offset 4: status (int32) + # offset 8: time (float32) + # offset 16: memory (size_t / int64) + algo = ( + cudnnConvolutionFwdAlgo_t.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM + ) + workspace_size_var = 0 + + var find_status_val = rebind[Int8](find_status) + if find_status_val == 0: # CUDNN_STATUS_SUCCESS + for i in range(returned_count): + var base = perf_bytes + i * C_PERF_STRUCT_SIZE + var algo_val = base.bitcast[Int32]()[] # offset 0 + var status_val = (base + 4).bitcast[Int32]()[] # offset 4 + var memory_val = (base + 16).bitcast[Int]()[] # offset 16 + if status_val == 0 and memory_val <= FIND_WS_CAP: + algo = rebind[cudnnConvolutionFwdAlgo_t](Int8(algo_val)) + workspace_size_var = memory_val + break + else: + print( + "conv3d FindEx FAILED: status=", + Int(find_status_val), + " input=[N=", + input.dim[0](), + " C=", + input.dim[4](), + " D=", + input.dim[1](), + " H=", + input.dim[2](), + " W=", + input.dim[3](), + "]", + ) + perf_bytes.free() + + # Fallback: if FindEx found nothing useful, try PRECOMP_GEMM via + # workspace size query (cheaper than FindEx). + if ( + algo + == cudnnConvolutionFwdAlgo_t.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM + and workspace_size_var == 0 + ): + var precomp = ( + cudnnConvolutionFwdAlgo_t.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + ) + var ws_size: Int = 0 + var ws_st = cudnnGetConvolutionForwardWorkspaceSize( + ptr_meta[].ptr_handle, + ptr_meta[].ptr_input_desc, + ptr_meta[].ptr_filter_desc, + ptr_meta[].ptr_conv_desc, + ptr_meta[].ptr_output_desc, + precomp, + UnsafePointer(to=ws_size), + ) + if ( + ws_st == cudnnStatus_t.CUDNN_STATUS_SUCCESS + and ws_size <= FIND_WS_CAP + ): + algo = precomp + workspace_size_var = ws_size + + # Store result in global cache. + var ptr_entry = alloc[_Conv3dAlgoCacheEntry](1) + ptr_entry.init_pointee_move( + _Conv3dAlgoCacheEntry( + algo_value=rebind[Int8](algo), + workspace_size=workspace_size_var, + ) + ) + external_call["KGEN_CompilerRT_InsertGlobal", NoneType]( + StringSlice(cache_key), + ptr_entry.bitcast[NoneType](), + ) + + # --- Execute convolution with cached/selected algorithm --- + var alpha = Float32(1.0) + var beta = Float32(0.0) + + var workspace_buffer = ctx.enqueue_create_buffer[DType.uint8]( + workspace_size_var + ) + var fwd_status = cudnnConvolutionForward( + ptr_meta[].ptr_handle, + UnsafePointer(to=alpha).bitcast[NoneType](), + ptr_meta[].ptr_input_desc, + input.ptr.bitcast[NoneType](), + ptr_meta[].ptr_filter_desc, + filter.ptr.bitcast[NoneType](), + ptr_meta[].ptr_conv_desc, + algo, + workspace_buffer.unsafe_ptr().bitcast[NoneType](), + workspace_size_var, + UnsafePointer(to=beta).bitcast[NoneType](), + ptr_meta[].ptr_output_desc, + output.ptr.bitcast[NoneType](), + ) + # Free workspace BEFORE sync to release the buffer back to the pool. + _ = workspace_buffer^ + + # Free temporary descriptor arrays. + input_dims.free() + filter_dims.free() + pad_a.free() + stride_a.free() + dilation_a.free() + output_dims.free() + + if fwd_status != cudnnStatus_t.CUDNN_STATUS_SUCCESS: + # Synchronize device to flush any pending GPU operations and free + # temporary cuDNN allocations, preventing VRAM accumulation. + print("conv3d FORWARD FAILED: ", fwd_status, " algo=", algo) + ctx.synchronize() + raise String("cudnnConvolutionForward failed: ", fwd_status) + + +def conv3d_cudnn[ + input_type: DType, + filter_type: DType, + output_type: DType, +]( + input: LayoutTensor[input_type, ...], + filter: LayoutTensor[filter_type, ...], + output: LayoutTensor[output_type, ...], + stride: IndexList[3], + dilation: IndexList[3], + padding: IndexList[3], + num_groups: Int, + ctx: DeviceContext, +) raises: + # Set `ctx`'s CUcontext as current to satisfy cudnn's stateful API. + with ctx.push_context() as ctx: + _conv3d_cudnn( + input, filter, output, stride, dilation, padding, num_groups, ctx + ) diff --git a/max/python/max/pipelines/architectures/autoencoders/__init__.py b/max/python/max/pipelines/architectures/autoencoders/__init__.py index 415932aa21e..81085487610 100644 --- a/max/python/max/pipelines/architectures/autoencoders/__init__.py +++ b/max/python/max/pipelines/architectures/autoencoders/__init__.py @@ -11,6 +11,9 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # -from .autoencoder_kl import AutoencoderKLModel -from .autoencoder_kl_flux2 import AutoencoderKLFlux2Model +from ..autoencoders_modulev3 import ( + AutoencoderKLFlux2Model, + AutoencoderKLModel, +) from .autoencoder_kl_qwen_image import AutoencoderKLQwenImageModel +from .autoencoder_kl_wan import AutoencoderKLWanModel diff --git a/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py b/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py deleted file mode 100644 index dfab42262d2..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py +++ /dev/null @@ -1,117 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -from typing import Any - -from max.driver import Device -from max.graph import TensorValue -from max.graph.weights import Weights -from max.nn.layer import Module -from max.pipelines.lib import SupportedEncoding - -from .model import BaseAutoencoderModel -from .model_config import AutoencoderKLConfig -from .vae import Decoder, Encoder - - -class AutoencoderKL(Module): - r"""A VAE model with KL loss for encoding images into latents and decoding latent representations into images.""" - - def __init__( - self, - config: AutoencoderKLConfig, - ) -> None: - """Initialize VAE AutoencoderKL model. - - Args: - config: Autoencoder configuration containing channel sizes, block - structure, normalization settings, and device/dtype information. - """ - super().__init__() - self.encoder = Encoder( - in_channels=config.in_channels, - out_channels=config.latent_channels, - down_block_types=tuple(config.down_block_types), - block_out_channels=tuple(config.block_out_channels), - layers_per_block=config.layers_per_block, - norm_num_groups=config.norm_num_groups, - act_fn=config.act_fn, - double_z=True, - mid_block_add_attention=config.mid_block_add_attention, - use_quant_conv=config.use_quant_conv, - device=config.device, - dtype=config.dtype, - ) - self.decoder = Decoder( - in_channels=config.latent_channels, - out_channels=config.out_channels, - up_block_types=tuple(config.up_block_types), - block_out_channels=tuple(config.block_out_channels), - layers_per_block=config.layers_per_block, - norm_num_groups=config.norm_num_groups, - act_fn=config.act_fn, - norm_type="group", - mid_block_add_attention=config.mid_block_add_attention, - use_post_quant_conv=config.use_post_quant_conv, - device=config.device, - dtype=config.dtype, - ) - - def __call__( - self, z: TensorValue, temb: TensorValue | None = None - ) -> TensorValue: - """Apply AutoencoderKL forward pass (decoding only). - - Args: - z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. - temb: Optional time embedding tensor. - - Returns: - Decoded image tensor of shape [N, C_out, H, W]. - """ - return self.decoder(z, temb) - - -class AutoencoderKLModel(BaseAutoencoderModel): - """ComponentModel wrapper for AutoencoderKL. - - This class provides the ComponentModel interface for AutoencoderKL, - handling configuration, weight loading, and model compilation. - """ - - def __init__( - self, - config: dict[str, Any], - encoding: SupportedEncoding, - devices: list[Device], - weights: Weights, - **kwargs: Any, - ) -> None: - """Initialize AutoencoderKLModel. - - Args: - config: Model configuration dictionary. - encoding: Supported encoding for the model. - devices: List of devices to use. - weights: Model weights. - **kwargs: Additional keyword arguments forwarded to ComponentModel. - """ - super().__init__( - config=config, - encoding=encoding, - devices=devices, - weights=weights, - config_class=AutoencoderKLConfig, - autoencoder_class=AutoencoderKL, - **kwargs, - ) diff --git a/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_flux2.py b/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_flux2.py deleted file mode 100644 index ddbe869a3b9..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_flux2.py +++ /dev/null @@ -1,280 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -from collections.abc import Callable -from types import SimpleNamespace -from typing import Any - -from max.driver import Buffer, Device -from max.graph import Graph, TensorType, TensorValue, ops -from max.graph.weights import Weights -from max.nn.layer import Module -from max.pipelines.lib import SupportedEncoding -from max.profiler import traced - -from .decode_step_flux2 import Flux2DecodeStep -from .model import BaseAutoencoderModel -from .model_config import AutoencoderKLFlux2Config -from .vae import Decoder, DiagonalGaussianDistribution, Encoder - - -class AutoencoderKLFlux2(Module): - r"""A VAE model with KL loss for encoding images into latents and decoding latent representations into images.""" - - def __init__( - self, - config: AutoencoderKLFlux2Config, - ) -> None: - """Initialize VAE AutoencoderKLFlux2 model. - - Args: - config: AutoencoderKLFlux2 configuration containing channel sizes, block - structure, normalization settings, BatchNorm parameters, and device/dtype information. - """ - super().__init__() - self.encoder = Encoder( - in_channels=config.in_channels, - out_channels=config.latent_channels, - down_block_types=tuple(config.down_block_types), - block_out_channels=tuple(config.block_out_channels), - layers_per_block=config.layers_per_block, - norm_num_groups=config.norm_num_groups, - act_fn=config.act_fn, - double_z=True, - mid_block_add_attention=config.mid_block_add_attention, - use_quant_conv=config.use_quant_conv, - device=config.device, - dtype=config.dtype, - ) - self.decoder = Decoder( - in_channels=config.latent_channels, - out_channels=config.out_channels, - up_block_types=tuple(config.up_block_types), - block_out_channels=tuple(config.block_out_channels), - layers_per_block=config.layers_per_block, - norm_num_groups=config.norm_num_groups, - act_fn=config.act_fn, - norm_type="group", - mid_block_add_attention=config.mid_block_add_attention, - use_post_quant_conv=config.use_post_quant_conv, - device=config.device, - dtype=config.dtype, - ) - - def __call__( - self, z: TensorValue, temb: TensorValue | None = None - ) -> TensorValue: - """Apply AutoencoderKLFlux2 forward pass (decoding only). - - Args: - z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. - temb: Optional time embedding tensor. - - Returns: - Decoded image tensor of shape [N, C_out, H, W]. - """ - return self.decoder(z, temb) - - -class AutoencoderKLFlux2Model(BaseAutoencoderModel): - """ComponentModel wrapper for AutoencoderKLFlux2. - - This class provides the ComponentModel interface for AutoencoderKLFlux2, - handling configuration, weight loading, model compilation, and BatchNorm - statistics for Flux2's latent patchification. - """ - - bn_running_mean: Buffer - bn_running_var: Buffer - - def __init__( - self, - config: dict[str, Any], - encoding: SupportedEncoding, - devices: list[Device], - weights: Weights, - **kwargs: Any, - ) -> None: - """Initialize AutoencoderKLFlux2Model. - - Args: - config: Model configuration dictionary. - encoding: Supported encoding for the model. - devices: List of devices to use. - weights: Model weights. - **kwargs: Additional keyword arguments forwarded to ComponentModel. - """ - super().__init__( - config=config, - encoding=encoding, - devices=devices, - weights=weights, - config_class=AutoencoderKLFlux2Config, - autoencoder_class=AutoencoderKLFlux2, - **kwargs, - ) - - @staticmethod - def _materialize(weight_data: Any) -> Buffer: - data = getattr(weight_data, "data", weight_data) - if isinstance(data, Buffer): - return data - return Buffer.from_dlpack(data) - - @staticmethod - def _extract_mean(moments: TensorValue) -> TensorValue: - return ops.chunk(moments, chunks=2, axis=1)[0] - - @traced(message="AutoencoderKLFlux2Model.load_model") - def load_model(self) -> Callable[..., Any]: - """Load encoder and BatchNorm statistics (skip standalone decoder). - - The standalone decoder compiled by the base class is not used in the - Flux2 pipeline, which decodes through build_fused_decode(). - - Returns: - Compiled encoder model callable. - """ - bn_stats: dict[str, Buffer] = {} - encoder_state_dict: dict[str, Any] = {} - target_dtype = self.config.dtype - - for key, value in self.weights.items(): - weight_data = value.data() - if weight_data.dtype != target_dtype: - if weight_data.dtype.is_float() and target_dtype.is_float(): - weight_data = weight_data.astype(target_dtype) - - if key in ("bn.running_mean", "bn.running_var"): - bn_stats[key] = self._materialize(weight_data).to( - self.devices[0] - ) - elif key.startswith("encoder."): - encoder_state_dict[key.removeprefix("encoder.")] = weight_data - elif key.startswith("quant_conv."): - encoder_state_dict[key] = weight_data - - bn_mean_data = bn_stats.get("bn.running_mean") - bn_var_data = bn_stats.get("bn.running_var") - if bn_mean_data is None or bn_var_data is None: - raise ValueError( - "BatchNorm statistics (running_mean, running_var) not loaded. " - "Make sure the model weights contain 'bn.running_mean' and 'bn.running_var'." - ) - - self.bn_running_mean = bn_mean_data - self.bn_running_var = bn_var_data - - autoencoder = AutoencoderKLFlux2(self.config) - self.encoder_model = self._compile_module( - autoencoder.encoder, - autoencoder.encoder.input_types(), - encoder_state_dict, - "autoencoder_kl_flux2_encoder_v2", - ) - - with Graph( - "autoencoder_kl_flux2_extract_mean", - input_types=( - TensorType( - self.config.dtype, - shape=[ - "batch", - self.config.latent_channels * 2, - "latent_height", - "latent_width", - ], - device=self.config.device, - ), - ), - ) as graph: - graph.output(self._extract_mean(graph.inputs[0].tensor)) - self._extract_mean_model = self.session.load(graph).execute - - return self.encoder_model - - @traced(message="AutoencoderKLFlux2Model.build_fused_decode") - def build_fused_decode( - self, device: Device, num_channels: int - ) -> Callable[..., Any]: - """Build a fused postprocess + VAE decode compiled graph. - - Combines BN denormalization, unpatchify, and VAE decoding into a single - compiled graph. - - Args: - device: Target device for the compiled graph. - num_channels: Number of latent channels after patchification. - - Returns: - Compiled callable taking packed latents and shape carriers. - """ - dtype = self.config.dtype - fused_state_dict: dict[str, Any] = {} - for key, value in self.weights.items(): - weight_data = value.data() - if weight_data.dtype != dtype: - if weight_data.dtype.is_float() and dtype.is_float(): - weight_data = weight_data.astype(dtype) - if key.startswith("decoder."): - fused_state_dict[f"decoder.{key.removeprefix('decoder.')}"] = ( - weight_data - ) - elif key.startswith("post_quant_conv."): - fused_state_dict[f"decoder.{key}"] = weight_data - - # The decode step receives BN tensors as runtime inputs, so do not load - # them as module weights. - autoencoder = AutoencoderKLFlux2(self.config) - decode_step = Flux2DecodeStep( - decoder=autoencoder.decoder, - batch_norm_eps=self.config.batch_norm_eps, - ) - compiled = self._compile_module( - decode_step, - decode_step.input_types(), - fused_state_dict, - "autoencoder_kl_flux2_decode_step_v2", - ) - - return lambda latents_bsc, h_carrier, w_carrier: self._unwrap_single( - compiled( - latents_bsc, - h_carrier, - w_carrier, - self.bn_running_mean, - self.bn_running_var, - ) - ) - - def encode( - self, sample: Buffer, return_dict: bool = True - ) -> dict[str, DiagonalGaussianDistribution] | DiagonalGaussianDistribution: - if self.encoder_model is None: - raise ValueError( - "Encoder not loaded. Check if encoder weights exist in the model." - ) - moments = self._unwrap_single(self.encoder_model(sample)) - mean = self._unwrap_single(self._extract_mean_model(moments)) - posterior = DiagonalGaussianDistribution(mean, moments) - if return_dict: - return {"latent_dist": posterior} - return posterior - - @property - def bn(self) -> SimpleNamespace: - """Access BatchNorm statistics in a diffusers-compatible shape.""" - return SimpleNamespace( - running_mean=self.bn_running_mean, - running_var=self.bn_running_var, - ) diff --git a/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_wan.py b/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_wan.py new file mode 100644 index 00000000000..45493402abe --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl_wan.py @@ -0,0 +1,628 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Wan VAE autoencoder -- slim ComponentModel with init-time graph compilation. + +All decoder/encoder graphs are compiled once at ``load_model()`` time with +symbolic spatial dims, so a single set of compiled models handles any +resolution without recompilation. + +Module classes live in ``vae.py``; this file only contains the +ComponentModel wrapper and numpy/buffer conversion helpers. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Callable +from typing import Any + +import numpy as np +from max.driver import CPU, Buffer, Device +from max.dtype import DType +from max.engine import InferenceSession, Model +from max.graph import DeviceRef, Graph, TensorType +from max.graph.buffer_utils import cast_dlpack_to +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.bfloat16_utils import float32_to_bfloat16_as_uint16 +from max.pipelines.lib.interfaces.component_model import ComponentModel +from max.profiler import Tracer + +from .model_config import AutoencoderKLWanConfig +from .vae import ( + WAN_DECODER_CACHE_SLOTS, + WAN_ENCODER_CHUNK_SIZE, + Decoder3dCached, + Encoder3dCached, + VAEDecoderFirstFrameCached, + VAEDecoderRestFrameCached, + VAEEncoderFirstChunk, + VAEEncoderRestChunk, + VAEPostQuantConv, + _use_nvidia_fcrs_conv3d, +) + +logger = logging.getLogger(__name__) + + +def _buffer_to_numpy_f32(buf: Buffer, cpu: CPU | None = None) -> np.ndarray: + """Convert a Buffer (possibly bf16) to f32 numpy on CPU.""" + cpu_buf = buf.to(cpu or CPU()) + if cpu_buf.dtype == DType.bfloat16: + u16 = np.from_dlpack( + cpu_buf.view(dtype=DType.uint16, shape=cpu_buf.shape) + ) + return (u16.astype(np.uint32) << 16).view(np.float32) + return np.from_dlpack(cpu_buf).astype(np.float32, copy=False) + + +def _numpy_f32_to_buffer( + arr: np.ndarray, target_dtype: DType, device: Device +) -> Buffer: + """Convert f32 numpy to Buffer on device with target dtype.""" + arr = np.ascontiguousarray(arr, dtype=np.float32) + if target_dtype == DType.bfloat16: + u16 = float32_to_bfloat16_as_uint16(arr) + return ( + Buffer.from_numpy(u16) + .to(device) + .view(dtype=DType.bfloat16, shape=arr.shape) + ) + return Buffer.from_numpy(arr).to(device) + + +class AutoencoderKLWanModel(ComponentModel): + """Wan VAE model using MAX-native 3D modules (decoder + optional encoder). + + All graphs are compiled once at ``load_model()`` time with symbolic spatial + dims, so a single set of compiled models handles any resolution. + """ + + def __init__( + self, + config: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + session: InferenceSession | None = None, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.config = AutoencoderKLWanConfig.generate(config, encoding, devices) + self.config.dtype = DType.bfloat16 + + self.pqc_model: Model | None = None + self.first_frame_model: Model | None = None + self.rest_frame_model: Model | None = None + self.first_chunk_encoder: Model | None = None + self.rest_chunk_encoder: Model | None = None + + self._session = session or InferenceSession(devices=devices) + self._load_lock = threading.Lock() + + self.load_model() + + def load_model(self) -> Callable[[Buffer], Buffer]: + """Load weights, remap layouts, and compile all graphs.""" + with self._load_lock: + if self.pqc_model is not None: + return self.decode_4d + + decoder_state_dict: dict[str, Any] = {} + encoder_state_dict: dict[str, Any] = {} + has_encoder = False + target_dtype = self.config.dtype + + assert self.weights is not None + weights_obj: Any = self.weights + + for key, value in weights_obj.items(): + is_decoder = key.startswith("decoder.") or key.startswith( + "post_quant_conv." + ) + is_encoder = key.startswith("encoder.") or key.startswith( + "quant_conv." + ) + if not (is_decoder or is_encoder): + continue + + weight_data = value.data() + + # -- 5D conv weights: PyTorch FCQRS -> MAX native QRSCF -- + if key.endswith(".weight") and len(weight_data.shape) == 5: + use_native_layout = not _use_nvidia_fcrs_conv3d( + self.config.device + ) + if "time_conv" in key and ( + key.startswith("encoder.") + or key.startswith("quant_conv.") + ): + use_native_layout = True + if use_native_layout: + buf = ( + weight_data.to_buffer() + if hasattr(weight_data, "to_buffer") + else weight_data + ) + t_f32 = cast_dlpack_to( + buf, weight_data.dtype, DType.float32, CPU() + ) + weight_data = np.ascontiguousarray( + np.from_dlpack(t_f32).transpose(2, 3, 4, 1, 0) + ) + + # -- 4D conv weights -- + if key.endswith(".weight") and len(weight_data.shape) == 4: + is_resample_conv = "resample" in key + if not is_resample_conv: + buf = ( + weight_data.to_buffer() + if hasattr(weight_data, "to_buffer") + else weight_data + ) + t_f32 = cast_dlpack_to( + buf, weight_data.dtype, DType.float32, CPU() + ) + weight_data = np.ascontiguousarray( + np.from_dlpack(t_f32).transpose(2, 3, 1, 0) + ) + + if is_decoder: + decoder_state_dict[key] = weight_data + if is_encoder: + encoder_state_dict[key] = weight_data + has_encoder = True + + # Cast all weights to target dtype. + cpu_device = CPU() + for sd in (decoder_state_dict, encoder_state_dict): + for key in sd: + tensor = sd[key] + if hasattr(tensor, "to_buffer") and hasattr( + tensor, "dtype" + ): + src_dtype = tensor.dtype + if src_dtype == target_dtype: + continue + buf = tensor.to_buffer() + else: + src_dtype = DType.float32 + if src_dtype == target_dtype: + continue + buf = tensor + sd[key] = cast_dlpack_to( + buf, src_dtype, target_dtype, cpu_device + ) + + # Compile decoder graphs with symbolic dims. + self._compile_decoder_graphs(decoder_state_dict) + + # Compile encoder graphs (optional). + if has_encoder: + self._compile_encoder_graphs(encoder_state_dict) + + self.weights = None # type: ignore[assignment] + return self.decode_4d + + def _compile_decoder_graphs( + self, decoder_state_dict: dict[str, Any] + ) -> None: + """Compile PQC + first-frame + rest-frame decoder with symbolic dims.""" + cfg = self.config + dtype = cfg.dtype + dev = self.devices[0] + dev_ref = DeviceRef.from_device(dev) + + pqc_module = VAEPostQuantConv(cfg) + pqc_module.load_state_dict( + decoder_state_dict, weight_alignment=1, strict=False + ) + pqc_input_types = [ + TensorType(dtype, [1, cfg.z_dim, 1, "height", "width"], device=dev) + ] + with Graph("wan_vae_pqc", input_types=pqc_input_types) as pqc_graph: + out = pqc_module(pqc_graph.inputs[0].tensor) + pqc_graph.output(out) + self.pqc_model = self._session.load( + pqc_graph, weights_registry=pqc_module.state_dict() + ) + + first_module = VAEDecoderFirstFrameCached(cfg) + first_module.load_state_dict( + decoder_state_dict, weight_alignment=1, strict=False + ) + first_input_types = [ + TensorType(dtype, [1, cfg.z_dim, 1, "height", "width"], device=dev) + ] + with Graph( + "wan_vae_first_frame", input_types=first_input_types + ) as first_graph: + outputs = first_module(first_graph.inputs[0].tensor) + first_graph.output(*outputs) + self.first_frame_model = self._session.load( + first_graph, weights_registry=first_module.state_dict() + ) + + rest_module = VAEDecoderRestFrameCached(cfg) + rest_module.load_state_dict( + decoder_state_dict, weight_alignment=1, strict=False + ) + + # Build cache input types with level-specific symbolic dim names. + # Caches at the same decoder level share dim names so concat in + # forward_cached sees matching dims on non-concat axes. + decoder_for_shapes = Decoder3dCached( + dim=cfg.base_dim, + z_dim=cfg.z_dim, + dim_mult=tuple(cfg.dim_mult), + num_res_blocks=cfg.num_res_blocks, + temporal_upsample=tuple(reversed(cfg.temporal_downsample)), + out_channels=cfg.out_channels, + is_residual=cfg.is_residual, + dtype=dtype, + device=dev_ref, + ) + cache_shape_info = decoder_for_shapes.cache_shapes( + batch_size=1, latent_height=1, latent_width=1 + ) + + # Map each cache to its decoder level for dim naming. + cache_dim_names: list[tuple[str, str]] = [] + h_name, w_name = "height", "width" + level = 0 + + # conv_in cache + cache_dim_names.append((h_name, w_name)) + # mid_block: 2 resnets x 2 caches = 4 + for _ in range(4): + cache_dim_names.append((h_name, w_name)) + # up_blocks + for up_block in decoder_for_shapes.up_blocks: + for _ in up_block.resnets: + cache_dim_names.append((h_name, w_name)) + cache_dim_names.append((h_name, w_name)) + if up_block.upsamplers is not None: + if up_block._has_temporal_upsample: + cache_dim_names.append((h_name, w_name)) + level += 1 + h_name = f"h{level}" + w_name = f"w{level}" + # conv_out cache + cache_dim_names.append((h_name, w_name)) + + assert len(cache_dim_names) == WAN_DECODER_CACHE_SLOTS + + rest_input_types = [ + TensorType(dtype, [1, cfg.z_dim, 1, "height", "width"], device=dev) + ] + for i, shape in enumerate(cache_shape_info): + channels = shape[1] + cache_t = shape[2] + ch, cw = cache_dim_names[i] + rest_input_types.append( + TensorType(dtype, [1, channels, cache_t, ch, cw], device=dev) + ) + + with Graph( + "wan_vae_rest_frame", input_types=rest_input_types + ) as rest_graph: + z_input = rest_graph.inputs[0].tensor + cache_inputs = tuple(inp.tensor for inp in rest_graph.inputs[1:]) + outputs = rest_module(z_input, *cache_inputs) + rest_graph.output(*outputs) + self.rest_frame_model = self._session.load( + rest_graph, weights_registry=rest_module.state_dict() + ) + + def _compile_encoder_graphs( + self, encoder_state_dict: dict[str, Any] + ) -> None: + """Compile first-chunk + rest-chunk encoder with symbolic dims.""" + cfg = self.config + dtype = cfg.dtype + dev = self.devices[0] + dev_ref = DeviceRef.from_device(dev) + + first_module = VAEEncoderFirstChunk(cfg) + first_module.load_state_dict( + encoder_state_dict, weight_alignment=1, strict=False + ) + first_input_type = TensorType( + dtype, [1, 3, 1, "height", "width"], device=dev + ) + with Graph( + "wan_vae_enc_first", input_types=[first_input_type] + ) as first_graph: + outputs = first_module(first_graph.inputs[0].tensor) + first_graph.output(*outputs) + self.first_chunk_encoder = self._session.load( + first_graph, weights_registry=first_module.state_dict() + ) + + rest_module = VAEEncoderRestChunk(cfg) + rest_module.load_state_dict( + encoder_state_dict, weight_alignment=1, strict=False + ) + + encoder_for_shapes = Encoder3dCached( + dim=cfg.base_dim, + z_dim=cfg.z_dim, + in_channels=3, + dim_mult=cfg.dim_mult, + num_res_blocks=cfg.num_res_blocks, + temporal_downsample=cfg.temporal_downsample, + dtype=dtype, + device=dev_ref, + ) + # Encoder has no upsample so cache shapes use same dims throughout. + cache_shape_info = encoder_for_shapes.cache_shapes( + batch_size=1, height=None, width=None + ) + + rest_input_types = [ + TensorType( + dtype, + [1, 3, WAN_ENCODER_CHUNK_SIZE, "height", "width"], + device=dev, + ) + ] + for i, shape in enumerate(cache_shape_info): + channels = shape[1] + cache_t = shape[2] + assert channels is not None and cache_t is not None + rest_input_types.append( + TensorType( + dtype, + [1, channels, cache_t, f"eh{i}", f"ew{i}"], + device=dev, + ) + ) + + with Graph( + "wan_vae_enc_rest", input_types=rest_input_types + ) as rest_graph: + rest_inputs = [inp.tensor for inp in rest_graph.inputs] + outputs = rest_module(rest_inputs[0], *rest_inputs[1:]) + rest_graph.output(*outputs) + self.rest_chunk_encoder = self._session.load( + rest_graph, weights_registry=rest_module.state_dict() + ) + + def decode_5d(self, latents_5d: Buffer) -> Buffer: + """Decode 5D latents [B, C, T, H, W] frame-by-frame.""" + if self.pqc_model is None: + self.load_model() + pqc_model = self.pqc_model + first_frame_model = self.first_frame_model + rest_frame_model = self.rest_frame_model + assert pqc_model is not None + assert first_frame_model is not None + assert rest_frame_model is not None + + t_total = int(latents_5d.shape[2]) + if t_total <= 0: + raise ValueError("Expected non-empty temporal dimension for decode") + + cpu = CPU() + latents_np = _buffer_to_numpy_f32(latents_5d, cpu) + device = self.devices[0] + target_dtype = self.config.dtype + + decoded_frames: list[np.ndarray] = [] + caches: list[Buffer] | None = None + + with Tracer("wan_vae_decode"): + for t_idx in range(t_total): + z_t_np = np.ascontiguousarray( + latents_np[:, :, t_idx : t_idx + 1, :, :] + ) + z_t_buf = _numpy_f32_to_buffer(z_t_np, target_dtype, device) + + # Post-quant conv + pqc_outputs = pqc_model.execute(z_t_buf) + if len(pqc_outputs) != 1: + raise ValueError( + f"Expected 1 output from post_quant_conv, " + f"got {len(pqc_outputs)}" + ) + z_t_buf = pqc_outputs[0] + + if t_idx == 0: + outputs = first_frame_model.execute(z_t_buf) + else: + if caches is None: + raise ValueError( + "Cached framewise decoder expected caches " + "after first frame." + ) + outputs = rest_frame_model.execute(z_t_buf, *caches) + + if len(outputs) != 1 + WAN_DECODER_CACHE_SLOTS: + raise ValueError( + "Cached framewise decoder produced " + f"{len(outputs)} tensors; " + f"expected {1 + WAN_DECODER_CACHE_SLOTS}." + ) + + decoded_buf = outputs[0] + caches = list(outputs[1:]) + decoded_frames.append(_buffer_to_numpy_f32(decoded_buf, cpu)) + + stitched = np.ascontiguousarray(np.concatenate(decoded_frames, axis=2)) + return Buffer.from_numpy(stitched) + + def decode_4d(self, latents_4d: Buffer) -> Buffer: + """Decode 4D latents by adding and removing a temporal dim.""" + shape_5d = ( + int(latents_4d.shape[0]), + int(latents_4d.shape[1]), + 1, + int(latents_4d.shape[2]), + int(latents_4d.shape[3]), + ) + z5d = latents_4d.view(dtype=latents_4d.dtype, shape=shape_5d) + decoded_5d = self.decode_5d(z5d) + # Remove temporal dimension from decoded output. + cpu = CPU() + decoded_np = _buffer_to_numpy_f32(decoded_5d, cpu) + return Buffer.from_numpy( + np.ascontiguousarray(decoded_np[:, :, 0, :, :]) + ) + + def decode( + self, latents: Buffer, return_dict: bool = False + ) -> tuple[Buffer]: + del return_dict + if latents.rank == 5: + return (self.decode_5d(latents),) + return (self.decode_4d(latents),) + + def encode(self, video: Buffer) -> Buffer: + """Encode a video tensor [B, 3, T, H, W] to latent space. + + Uses chunked encoding matching diffusers: first frame processed + separately, then 4-frame chunks with temporal caching. + + Returns the mean of the diagonal Gaussian (argmax mode), + shape [B, z_dim, T_latent, H_latent, W_latent]. + """ + if self.first_chunk_encoder is None: + self.load_model() + first_chunk_encoder = self.first_chunk_encoder + rest_chunk_encoder = self.rest_chunk_encoder + if first_chunk_encoder is None or rest_chunk_encoder is None: + raise RuntimeError( + "VAE encoder weights not available. " + "Ensure the model checkpoint includes encoder weights." + ) + + video_np = _buffer_to_numpy_f32(video, CPU()) + target_dtype = self.config.dtype + device = self.devices[0] + cpu = CPU() + + t_total = video_np.shape[2] + latent_chunks: list[np.ndarray] = [] + caches: list[Buffer] | None = None + num_chunks = 1 + (t_total - 1) // WAN_ENCODER_CHUNK_SIZE + + with Tracer("wan_vae_encode"): + for i in range(num_chunks): + if i == 0: + chunk_np = np.ascontiguousarray(video_np[:, :, :1]) + else: + start = 1 + WAN_ENCODER_CHUNK_SIZE * (i - 1) + end = 1 + WAN_ENCODER_CHUNK_SIZE * i + chunk_np = np.ascontiguousarray(video_np[:, :, start:end]) + + chunk_buf = _numpy_f32_to_buffer(chunk_np, target_dtype, device) + + if i == 0: + outputs = first_chunk_encoder.execute(chunk_buf) + else: + assert caches is not None + outputs = rest_chunk_encoder.execute(chunk_buf, *caches) + + latent_chunks.append(_buffer_to_numpy_f32(outputs[0], cpu)) + caches = list(outputs[1:]) + + full_latent = np.ascontiguousarray( + np.concatenate(latent_chunks, axis=2) + ) + return _numpy_f32_to_buffer(full_latent, target_dtype, device) + + def encode_zero_padded_video_condition( + self, + first_frame: np.ndarray, + *, + batch_size: int, + num_frames: int, + ) -> Buffer: + """Encode a zero-padded I2V conditioning video without materializing it. + + The conditioning path only contains a real first frame; all later + frames are zeros. Stream those chunks directly into the cached encoder + so we avoid allocating the full ``[B, 3, T, H, W]`` input tensor. + """ + if num_frames <= 0: + raise ValueError("num_frames must be positive for I2V encoding.") + if first_frame.ndim != 4: + raise ValueError( + "Expected first_frame with shape [B, 3, H, W], " + f"got {first_frame.shape}." + ) + + image_f32 = np.ascontiguousarray(first_frame, dtype=np.float32) + if image_f32.shape[0] == 1 and batch_size > 1: + image_f32 = np.repeat(image_f32, batch_size, axis=0) + elif image_f32.shape[0] != batch_size: + raise ValueError( + "first_frame batch dimension must be 1 or match batch_size, " + f"got {image_f32.shape[0]} and {batch_size}." + ) + + chunks = [image_f32[:, :, np.newaxis, :, :]] + if num_frames > 1: + _, channels, height, width = image_f32.shape + zero_chunk = np.zeros( + (batch_size, channels, WAN_ENCODER_CHUNK_SIZE, height, width), + dtype=np.float32, + ) + remaining_frames = num_frames - 1 + while remaining_frames > 0: + chunk_len = min(WAN_ENCODER_CHUNK_SIZE, remaining_frames) + chunks.append(zero_chunk[:, :, :chunk_len]) + remaining_frames -= chunk_len + + return self._encode_chunk_sequence(chunks) + + def _encode_chunk_sequence(self, chunks: list[np.ndarray]) -> Buffer: + """Encode a pre-split Wan VAE chunk sequence.""" + if self.first_chunk_encoder is None: + self.load_model() + first_chunk_encoder = self.first_chunk_encoder + rest_chunk_encoder = self.rest_chunk_encoder + if first_chunk_encoder is None or rest_chunk_encoder is None: + raise RuntimeError( + "VAE encoder weights not available. " + "Ensure the model checkpoint includes encoder weights." + ) + + target_dtype = self.config.dtype + device = self.devices[0] + cpu = CPU() + + latent_chunks: list[np.ndarray] = [] + caches: list[Buffer] | None = None + with Tracer("wan_vae_encode"): + for i, chunk_np in enumerate(chunks): + chunk_buf = _numpy_f32_to_buffer(chunk_np, target_dtype, device) + if i == 0: + outputs = first_chunk_encoder.execute(chunk_buf) + else: + assert caches is not None + outputs = rest_chunk_encoder.execute(chunk_buf, *caches) + + latent_chunks.append(_buffer_to_numpy_f32(outputs[0], cpu)) + caches = list(outputs[1:]) + + full_latent = np.ascontiguousarray( + np.concatenate(latent_chunks, axis=2) + ) + return _numpy_f32_to_buffer(full_latent, target_dtype, device) + + def __call__(self, latents: Buffer) -> Buffer: + if latents.rank == 5: + return self.decode_5d(latents) + return self.decode_4d(latents) diff --git a/max/python/max/pipelines/architectures/autoencoders/decode_step_flux2.py b/max/python/max/pipelines/architectures/autoencoders/decode_step_flux2.py deleted file mode 100644 index 344ba2ba781..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/decode_step_flux2.py +++ /dev/null @@ -1,147 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -"""Fused decode-step module for the Flux2 pipeline. - -Combines Flux2-specific BN denorm + unpatchify with the VAE decoder forward -pass into a single compiled graph, eliminating the inter-graph boundary that -previously existed between _postprocess_latents and vae.decode(). -""" - -from max.dtype import DType -from max.graph import DeviceRef, TensorType, TensorValue, ops -from max.nn.layer import Module - -from .vae import Decoder - -Tensor = TensorValue -F = ops - - -class Flux2DecodeStep(Module): - """Fused postprocess-and-decode: packed latents -> decoded image. - - Combines Flux2-specific BN denorm + unpatchify with the VAE decoder - forward pass into a single compiled graph, eliminating the inter-graph - boundary that previously existed between postprocess_latents and - vae.decode(). - - Accepts packed latents in (B, S, C) shape where S = latent_h * latent_w. - Spatial dimensions are conveyed via two 1-D shape-carrier tensors whose - *lengths* encode latent_h and latent_w as symbolic graph Dims, so a single - compiled graph handles any spatial size without recompilation. - """ - - def __init__(self, decoder: Decoder, batch_norm_eps: float) -> None: - """Initialize Flux2DecodeStep. - - Args: - decoder: Raw (uncompiled) Decoder sub-module. - batch_norm_eps: Epsilon value for BatchNorm denormalization. - """ - super().__init__() - self.decoder = decoder - self.batch_norm_eps = batch_norm_eps - - def input_types(self) -> tuple[TensorType, ...]: - """Return input TensorTypes for compilation. - - Returns: - Tuple of TensorType objects corresponding to the forward() signature: - (latents_bsc, h_carrier, w_carrier, bn_mean, bn_var). - """ - num_channels = self.decoder.in_channels * 4 # e.g. 32*4 = 128 - dtype = self.decoder.dtype - device = self.decoder.device - assert dtype is not None, "Decoder dtype must be set before compilation" - assert device is not None, ( - "Decoder device must be set before compilation" - ) - return ( - TensorType( - dtype, shape=["batch", "seq", num_channels], device=device - ), - # Shape carriers: lengths encode latent_h / latent_w as symbolic dims. - # Content is never read; only the shapes matter. - TensorType( - DType.float32, shape=["latent_h"], device=DeviceRef.CPU() - ), - TensorType( - DType.float32, shape=["latent_w"], device=DeviceRef.CPU() - ), - TensorType(dtype, shape=[num_channels], device=device), - TensorType(dtype, shape=[num_channels], device=device), - ) - - def forward( - self, - latents_bsc: Tensor, - h_carrier: Tensor, - w_carrier: Tensor, - bn_mean: Tensor, - bn_var: Tensor, - ) -> Tensor: - """Run BN denorm + unpatchify + VAE decode in one fused graph. - - Args: - latents_bsc: Packed latents of shape (B, S, C) where S = latent_h * latent_w. - h_carrier: 1-D shape carrier of length latent_h (content unused). - w_carrier: 1-D shape carrier of length latent_w (content unused). - bn_mean: BatchNorm running mean of shape (C,). - bn_var: BatchNorm running variance of shape (C,). - - Returns: - Decoded image tensor of shape (B, H, W, C) after post-processing. - """ - batch = latents_bsc.shape[0] - c = latents_bsc.shape[2] - # Extract spatial dims from carrier shapes (symbolic Dims, not runtime values) - h = h_carrier.shape[0] - w = w_carrier.shape[0] - - # Assert seq == latent_h * latent_w so the reshape verifier accepts it, - # then reshape packed (B, S, C) -> spatial (B, H, W, C). - latents_bsc = F.rebind(latents_bsc, [batch, h * w, c]) - latents_bhwc = F.reshape(latents_bsc, (batch, h, w, c)) - - # Permute: (B, H, W, C) -> (B, C, H, W) - latents = F.permute(latents_bhwc, [0, 3, 1, 2]) - - # BN denormalization - bn_mean_r = F.reshape(bn_mean, (1, c, 1, 1)) - bn_var_r = F.reshape(bn_var, (1, c, 1, 1)) - bn_std = F.sqrt(bn_var_r + self.batch_norm_eps) - latents = latents * bn_std + bn_mean_r - - # Unpatchify: (B, C, H, W) -> (B, C//4, H*2, W*2) - latents = F.reshape(latents, (batch, c // 4, 2, 2, h, w)) - latents = F.permute(latents, [0, 1, 4, 2, 5, 3]) - latents = F.reshape(latents, (batch, c // 4, h * 2, w * 2)) - - decoded = self.decoder(latents, None) - decoded = F.permute(decoded, [0, 2, 3, 1]) - decoded = decoded * 0.5 + 0.5 - decoded = F.max(decoded, 0.0) - decoded = F.min(decoded, 1.0) - decoded = decoded * 255.0 - return F.transfer_to(F.cast(decoded, DType.uint8), DeviceRef.CPU()) - - def __call__( - self, - latents_bsc: Tensor, - h_carrier: Tensor, - w_carrier: Tensor, - bn_mean: Tensor, - bn_var: Tensor, - ) -> Tensor: - return self.forward(latents_bsc, h_carrier, w_carrier, bn_mean, bn_var) diff --git a/max/python/max/pipelines/architectures/autoencoders/layers/__init__.py b/max/python/max/pipelines/architectures/autoencoders/layers/__init__.py deleted file mode 100644 index 548a3755f8a..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/layers/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -from .attention import VAEAttention -from .downsampling import Downsample2D -from .resnet import ResnetBlock2D -from .upsampling import Upsample2D diff --git a/max/python/max/pipelines/architectures/autoencoders/layers/attention.py b/max/python/max/pipelines/architectures/autoencoders/layers/attention.py deleted file mode 100644 index 3b794f83e29..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/layers/attention.py +++ /dev/null @@ -1,145 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -import math - -from max.dtype import DType -from max.graph import DeviceRef, TensorValue, ops -from max.nn.layer import LayerList, Module -from max.nn.linear import Linear -from max.nn.norm import GroupNorm - - -class VAEAttention(Module): - """Spatial attention module for VAE models. - - This module performs self-attention on 2D spatial features by: - 1. Converting [N, C, H, W] to [N, H*W, C] sequence format - 2. Applying scaled dot-product attention (optimized for small sequences) - 3. Converting back to [N, C, H, W] format - - Note: Manual attention is used instead of flash-attention style kernels - because VAE attention typically has small sequence lengths (H*W) where - launch overhead outweighs benefits. - """ - - def __init__( - self, - query_dim: int, - heads: int, - dim_head: int, - num_groups: int = 32, - eps: float = 1e-6, - device: DeviceRef | None = None, - dtype: DType | None = None, - ) -> None: - """Initialize VAE attention module. - - Args: - query_dim: Dimension of query (number of channels). - heads: Number of attention heads. - dim_head: Dimension of each attention head. - num_groups: Number of groups for GroupNorm. - eps: Epsilon value for GroupNorm. - device: Device reference. - dtype: Data type. - """ - super().__init__() - if dtype is None: - raise ValueError("dtype must be set for VAEAttention") - if device is None: - raise ValueError("device must be set for VAEAttention") - - self.query_dim = query_dim - self.heads = heads - self.dim_head = dim_head - self.inner_dim = heads * dim_head - self.group_norm = GroupNorm( - num_groups=num_groups, - num_channels=query_dim, - eps=eps, - affine=True, - device=device, - ) - self.to_q = Linear( - in_dim=query_dim, - out_dim=self.inner_dim, - dtype=dtype, - device=device, - has_bias=True, - ) - self.to_k = Linear( - in_dim=query_dim, - out_dim=self.inner_dim, - dtype=dtype, - device=device, - has_bias=True, - ) - self.to_v = Linear( - in_dim=query_dim, - out_dim=self.inner_dim, - dtype=dtype, - device=device, - has_bias=True, - ) - self.to_out = LayerList( - [ - Linear( - in_dim=self.inner_dim, - out_dim=query_dim, - dtype=dtype, - device=device, - has_bias=True, - ) - ] - ) - self.scale = 1.0 / math.sqrt(dim_head) - - def __call__(self, x: TensorValue) -> TensorValue: - """Apply spatial attention to a 2D image tensor. - - Args: - x: Input tensor of shape [N, C, H, W]. - - Returns: - Output tensor of shape [N, C, H, W] with residual connection. - """ - residual = x - x = self.group_norm(x) - - n, c, h, w = x.shape - seq_len = h * w - x = ops.reshape(x, [n, c, seq_len]) - x = ops.permute(x, [0, 2, 1]) - - q = self.to_q(x) - k = self.to_k(x) - v = self.to_v(x) - - q = ops.reshape(q, [n, seq_len, self.heads, self.dim_head]) - q = ops.permute(q, [0, 2, 1, 3]) - k = ops.reshape(k, [n, seq_len, self.heads, self.dim_head]) - k = ops.permute(k, [0, 2, 1, 3]) - v = ops.reshape(v, [n, seq_len, self.heads, self.dim_head]) - v = ops.permute(v, [0, 2, 1, 3]) - - attn = (q @ ops.permute(k, [0, 1, 3, 2])) * self.scale - attn = ops.softmax(attn, axis=-1) - out = attn @ v - - out = ops.permute(out, [0, 2, 1, 3]) - out = ops.reshape(out, [n, seq_len, self.inner_dim]) - out = self.to_out[0](out) - out = ops.permute(out, [0, 2, 1]) - out = ops.reshape(out, [n, c, h, w]) - return residual + out diff --git a/max/python/max/pipelines/architectures/autoencoders/layers/downsampling.py b/max/python/max/pipelines/architectures/autoencoders/layers/downsampling.py deleted file mode 100644 index 13552f0e94c..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/layers/downsampling.py +++ /dev/null @@ -1,165 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -"""Downsampling utilities for MAX framework.""" - -from max.dtype import DType -from max.graph import DeviceRef, TensorValue, ops -from max.graph.ops import avg_pool2d -from max.nn.conv import Conv2d -from max.nn.layer import Module -from max.nn.norm import LayerNorm, RMSNorm - - -class Downsample2D(Module): - """A 2D downsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - padding (`int`, default `1`): - padding for the convolution. - name (`str`, default `conv`): - name of the downsampling 2D layer. - kernel_size (`int`, default `3`): - kernel size for the convolution. - norm_type (`str`, optional): - normalization type. Supported: "ln_norm" (LayerNorm), "rms_norm" - (RMSNorm), or None. - eps (`float`, optional): - epsilon for normalization. Defaults to 1e-5 for LayerNorm, 1e-6 - for RMSNorm. - elementwise_affine (`bool`, optional): - elementwise affine for normalization. Only used for LayerNorm. - Defaults to True. - bias (`bool`, default `True`): - whether to use bias in the convolution. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - out_channels: int | None = None, - padding: int = 1, - name: str = "conv", - kernel_size: int = 3, - norm_type: str | None = None, - eps: float | None = None, - elementwise_affine: bool | None = None, - bias: bool = True, - device: DeviceRef | None = None, - dtype: DType | None = None, - ) -> None: - """Initialize 2D downsampling module. - - Args: - channels: Number of input channels. - use_conv: Whether to use convolution for downsampling. - out_channels: Number of output channels. If None, uses channels. - padding: Padding for the convolution. - name: Name for the convolution layer (unused, kept for compatibility). - kernel_size: Kernel size for the convolution. - norm_type: Normalization type ("ln_norm", "rms_norm", or None). - eps: Epsilon for normalization. Defaults to 1e-5 for LayerNorm, - 1e-6 for RMSNorm. - elementwise_affine: Elementwise affine for LayerNorm. - bias: Whether to use bias in the convolution. - device: Device reference for module placement. - dtype: Data type for module parameters. - """ - super().__init__() - stride = 2 - if dtype is None: - raise ValueError("dtype must be set for Downsample2D") - if device is None: - raise ValueError("device must be set for Downsample2D") - - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - self.name = name - - self.norm: LayerNorm | RMSNorm | None = None - if norm_type == "ln_norm": - self.norm = LayerNorm( - dims=channels, - devices=[device], - dtype=dtype, - eps=eps or 1e-5, - use_bias=( - True if elementwise_affine is None else elementwise_affine - ), - ) - elif norm_type == "rms_norm": - self.norm = RMSNorm( - dim=channels, - dtype=dtype, - eps=eps or 1e-6, - ) - elif norm_type is not None: - raise ValueError(f"unknown norm_type: {norm_type}") - - self.conv: Conv2d | None = None - if use_conv: - self.conv = Conv2d( - kernel_size=kernel_size, - in_channels=channels, - out_channels=self.out_channels, - dtype=dtype, - stride=stride, - padding=padding, - has_bias=bias, - device=device, - permute=True, - ) - elif channels != self.out_channels: - raise ValueError( - f"When use_conv=False, channels must equal out_channels. " - f"Got channels={channels}, out_channels={self.out_channels}" - ) - - def __call__(self, hidden_states: TensorValue) -> TensorValue: - """Apply 2D downsampling with optional convolution. - - Args: - hidden_states: Input tensor of shape [N, C, H, W]. - - Returns: - Downsampled tensor of shape [N, C_out, H//2, W//2]. - """ - if self.norm is not None: - hidden_states = ops.permute(hidden_states, [0, 2, 3, 1]) - hidden_states = self.norm(hidden_states) - hidden_states = ops.permute(hidden_states, [0, 3, 1, 2]) - - if self.use_conv and self.padding == 0: - hidden_states = ops.pad(hidden_states, [0, 0, 0, 0, 0, 1, 0, 1]) - - if self.use_conv: - assert self.conv is not None - return self.conv(hidden_states) - - hidden_states = ops.permute(hidden_states, [0, 2, 3, 1]) - hidden_states = avg_pool2d( - hidden_states, - kernel_size=(2, 2), - stride=2, - padding=0, - ) - return ops.permute(hidden_states, [0, 3, 1, 2]) diff --git a/max/python/max/pipelines/architectures/autoencoders/layers/resnet.py b/max/python/max/pipelines/architectures/autoencoders/layers/resnet.py deleted file mode 100644 index 2c0e7ca9048..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/layers/resnet.py +++ /dev/null @@ -1,146 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -from max.dtype import DType -from max.graph import DeviceRef, TensorValue -from max.nn.activation import activation_function_from_name -from max.nn.conv import Conv2d -from max.nn.layer import Module -from max.nn.norm import GroupNorm - - -class ResnetBlock2D(Module): - """Residual block for 2D VAE decoder. - - This module implements a residual block with two convolutional layers, - group normalization, and optional shortcut connection. It supports - time embedding conditioning and configurable activation functions. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int | None, - groups: int, - groups_out: int, - eps: float = 1e-6, - non_linearity: str = "silu", - use_conv_shortcut: bool = False, - conv_shortcut_bias: bool = True, - device: DeviceRef | None = None, - dtype: DType | None = None, - ) -> None: - """Initialize ResnetBlock2D module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - temb_channels: Number of time embedding channels (None if not used). - groups: Number of groups for first GroupNorm. - groups_out: Number of groups for second GroupNorm. - eps: Epsilon value for GroupNorm layers. - non_linearity: Activation function name (e.g., "silu"). - use_conv_shortcut: Whether to use convolutional shortcut. - conv_shortcut_bias: Whether to use bias in shortcut convolution. - device: Device reference for module placement. - dtype: Data type for module parameters. - """ - super().__init__() - del temb_channels - if dtype is None: - raise ValueError("dtype must be set for ResnetBlock2D") - if device is None: - raise ValueError("device must be set for ResnetBlock2D") - - self.in_channels = in_channels - self.out_channels = out_channels - self.use_conv_shortcut = use_conv_shortcut - self.activation = activation_function_from_name(non_linearity) - self.norm1 = GroupNorm( - num_groups=groups, - num_channels=in_channels, - eps=eps, - affine=True, - device=device, - ) - self.conv1 = Conv2d( - kernel_size=3, - in_channels=in_channels, - out_channels=out_channels, - dtype=dtype, - stride=1, - padding=1, - dilation=1, - num_groups=1, - has_bias=True, - device=device, - permute=True, - ) - self.norm2 = GroupNorm( - num_groups=groups_out, - num_channels=out_channels, - eps=eps, - affine=True, - device=device, - ) - self.conv2 = Conv2d( - kernel_size=3, - in_channels=out_channels, - out_channels=out_channels, - dtype=dtype, - stride=1, - padding=1, - dilation=1, - num_groups=1, - has_bias=True, - device=device, - permute=True, - ) - self.conv_shortcut: Conv2d | None = None - if self.use_conv_shortcut or in_channels != out_channels: - self.conv_shortcut = Conv2d( - kernel_size=1, - in_channels=in_channels, - out_channels=out_channels, - dtype=dtype, - stride=1, - padding=0, - dilation=1, - num_groups=1, - has_bias=conv_shortcut_bias, - device=device, - permute=True, - ) - - def __call__( - self, x: TensorValue, temb: TensorValue | None = None - ) -> TensorValue: - """Apply ResnetBlock2D forward pass. - - Args: - x: Input tensor of shape [N, C, H, W]. - temb: Optional time embedding tensor (currently unused). - - Returns: - Output tensor of shape [N, C_out, H, W] with residual connection. - """ - del temb - shortcut = ( - self.conv_shortcut(x) if self.conv_shortcut is not None else x - ) - h = self.activation(self.norm1(x)) - h = self.conv1(h) - h = self.activation(self.norm2(h)) - h = self.conv2(h) - return h + shortcut diff --git a/max/python/max/pipelines/architectures/autoencoders/layers/upsampling.py b/max/python/max/pipelines/architectures/autoencoders/layers/upsampling.py deleted file mode 100644 index f89efab6b27..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/layers/upsampling.py +++ /dev/null @@ -1,138 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -"""Upsampling utilities for MAX framework.""" - -from max.dtype import DType -from max.graph import DeviceRef, TensorValue, ops -from max.nn.conv import Conv2d -from max.nn.layer import Module - - -def interpolate_2d_nearest( - x: TensorValue, - scale_factor: int = 2, -) -> TensorValue: - """Upsamples a 2D tensor using nearest-neighbor interpolation. - - This is a workaround implementation because MAX framework resize does not - currently support NEAREST mode in this path. The workaround uses reshape - and broadcast operations to achieve nearest-neighbor upsampling by a factor - of 2. - - Note: - This workaround can be removed once native nearest-neighbor resize is - available in this path. - """ - - if x.rank != 4: - raise ValueError(f"Input tensor must have rank 4, got {x.rank}") - if scale_factor != 2: - raise NotImplementedError( - f"Only scale_factor=2 is currently supported, got {scale_factor}" - ) - - batch, channels, height, width = x.shape - x = ops.reshape(x, [batch, channels, height, 1, width, 1]) - ones = ops.broadcast_to( - ops.constant(1.0, dtype=x.dtype, device=x.device), - [1, 1, 1, scale_factor, 1, scale_factor], - ) - return ops.reshape( - x * ones, - [batch, channels, height * scale_factor, width * scale_factor], - ) - - -class Upsample2D(Module): - """2D upsampling module with optional convolution. - - This module performs 2D upsampling using nearest-neighbor interpolation - followed by an optional convolution layer. - """ - - def __init__( - self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, - out_channels: int | None = None, - name: str = "conv", - kernel_size: int | None = None, - padding: int = 1, - bias: bool = True, - interpolate: bool = True, - device: DeviceRef | None = None, - dtype: DType | None = None, - ) -> None: - """Initialize 2D upsampling module. - - Args: - channels: Number of input channels. - use_conv: Whether to apply a convolution after upsampling. - use_conv_transpose: Whether to use transposed convolution (not supported yet). - out_channels: Number of output channels. If None, uses channels. - name: Name for the convolution layer (unused, kept for compatibility). - kernel_size: Kernel size for the convolution. - padding: Padding for the convolution. - bias: Whether to use bias in the convolution. - interpolate: Whether to perform interpolation upsampling. - device: Device reference. - dtype: Data type. - """ - super().__init__() - if dtype is None: - raise ValueError("dtype must be set for Upsample2D") - if device is None: - raise ValueError("device must be set for Upsample2D") - if use_conv_transpose: - raise NotImplementedError( - "Upsample2D does not support use_conv_transpose=True yet." - ) - - self.channels = channels - self.out_channels = out_channels or channels - self.interpolate = interpolate - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.device = device - self.dtype = dtype - self.name = name - self.conv: Conv2d | None = None - if self.use_conv: - self.conv = Conv2d( - kernel_size=3 if kernel_size is None else kernel_size, - in_channels=self.channels, - out_channels=self.out_channels, - dtype=dtype, - stride=1, - padding=padding, - has_bias=bias, - device=device, - permute=True, - ) - - def __call__(self, x: TensorValue) -> TensorValue: - """Apply 2D upsampling with optional convolution. - - Args: - x: Input tensor of shape [N, C, H, W]. - - Returns: - Upsampled tensor, optionally convolved. - """ - if self.interpolate: - x = interpolate_2d_nearest(x, scale_factor=2) - if self.use_conv and self.conv is not None: - x = self.conv(x) - return x diff --git a/max/python/max/pipelines/architectures/autoencoders/model.py b/max/python/max/pipelines/architectures/autoencoders/model.py deleted file mode 100644 index 9f9714052fe..00000000000 --- a/max/python/max/pipelines/architectures/autoencoders/model.py +++ /dev/null @@ -1,225 +0,0 @@ -# ===----------------------------------------------------------------------=== # -# Copyright (c) 2026, Modular Inc. All rights reserved. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions: -# https://llvm.org/LICENSE.txt -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ===----------------------------------------------------------------------=== # - -from collections.abc import Callable, Mapping -from typing import Any - -from max.driver import Buffer, Device -from max.engine import InferenceSession, Model -from max.graph import Graph, TensorType -from max.graph.weights import WeightData, Weights -from max.nn.layer import Module -from max.pipelines.lib import SupportedEncoding -from max.pipelines.lib.interfaces.component_model import ComponentModel - -from .model_config import AutoencoderKLConfigBase -from .vae import DiagonalGaussianDistribution - - -class BaseAutoencoderModel(ComponentModel): - """Base class for autoencoder models with shared logic. - - This base class provides common functionality for loading and running - autoencoder decoders. Subclasses should specify the config and autoencoder - classes to use. - """ - - def __init__( - self, - config: dict[str, Any], - encoding: SupportedEncoding, - devices: list[Device], - weights: Weights, - config_class: type[AutoencoderKLConfigBase], - autoencoder_class: type, - **kwargs: Any, - ) -> None: - """Initialize base autoencoder model. - - Args: - config: Model configuration dictionary. - encoding: Supported encoding for the model. - devices: List of devices to use. - weights: Model weights. - config_class: Configuration class to use. - autoencoder_class: Autoencoder class to use. - **kwargs: Additional keyword arguments forwarded to ComponentModel. - """ - super().__init__(config, encoding, devices, weights, **kwargs) - self.config = config_class.generate(config, encoding, devices) # type: ignore[attr-defined] - self.autoencoder_class = autoencoder_class - self.session = InferenceSession(devices=[*devices]) - self.encoder_model: Callable[..., Any] | None = None - self.load_model() - - @staticmethod - def _unwrap_single(output: Any) -> Any: - if isinstance(output, (list, tuple)): - return output[0] - return output - - def _compile_module( - self, - module: Module, - input_types: tuple[TensorType, ...], - state_dict: Mapping[str, Buffer | WeightData | Any], - graph_name: str, - ) -> Callable[..., Any]: - normalized_state_dict = dict(state_dict) - for name, weight in module.raw_state_dict().items(): - if name not in normalized_state_dict: - continue - value = normalized_state_dict[name] - value_dtype = getattr(value, "dtype", None) - if value_dtype != weight.dtype: - if ( - value_dtype is not None - and value_dtype.is_float() - and weight.dtype.is_float() - and hasattr(value, "astype") - ): - normalized_state_dict[name] = value.astype(weight.dtype) - - module.load_state_dict( - normalized_state_dict, - weight_alignment=1, - strict=True, - ) - weights_registry = module.state_dict(auto_initialize=False) - - with Graph(graph_name, input_types=input_types) as graph: - output = module(*(value.tensor for value in graph.inputs)) - if isinstance(output, (list, tuple)): - graph.output(*output) - else: - graph.output(output) - - model: Model = self.session.load( - graph, weights_registry=weights_registry - ) - return model.execute - - def load_model(self) -> Callable[..., Any]: - """Load and compile decoder and encoder from full model weights. - - Splits weights by prefix (decoder/post_quant_conv vs encoder/quant_conv) - and compiles each subgraph. quant_conv is included in the encoder when - config.use_quant_conv is True. Encoder is compiled only when the model - has an encoder and encoder weights are present. - - Returns: - Compiled decoder model callable. - """ - decoder_state_dict = {} - encoder_state_dict = {} - target_dtype = self.config.dtype - - for key, value in self.weights.items(): - adapted_key = key - # Some checkpoints nest VAE params under a top-level module prefix. - # Normalize to raw autoencoder names before routing to encoder/decoder. - while adapted_key.startswith(("vae.", "model.")): - if adapted_key.startswith("vae."): - adapted_key = adapted_key.removeprefix("vae.") - continue - adapted_key = adapted_key.removeprefix("model.") - - weight_data = value.data() - if weight_data.dtype != target_dtype: - if weight_data.dtype.is_float() and target_dtype.is_float(): - weight_data = weight_data.astype(target_dtype) - # Non-float weights are left as-is and skipped for decoder/encoder - # state dicts if their prefixes do not match. - - if adapted_key.startswith("decoder."): - decoder_state_dict[adapted_key.removeprefix("decoder.")] = ( - weight_data - ) - elif adapted_key.startswith("post_quant_conv."): - decoder_state_dict[adapted_key] = weight_data - elif adapted_key.startswith("encoder."): - encoder_state_dict[adapted_key.removeprefix("encoder.")] = ( - weight_data - ) - elif adapted_key.startswith("quant_conv."): - encoder_state_dict[adapted_key] = weight_data - - autoencoder = self.autoencoder_class(self.config) - self.model = self._compile_module( - autoencoder.decoder, - autoencoder.decoder.input_types(), - decoder_state_dict, - type(autoencoder.decoder).__name__.lower(), - ) - if encoder_state_dict and hasattr(autoencoder, "encoder"): - self.encoder_model = self._compile_module( - autoencoder.encoder, - autoencoder.encoder.input_types(), - encoder_state_dict, - type(autoencoder.encoder).__name__.lower(), - ) - return self.model - - def encode( - self, sample: Buffer, return_dict: bool = True - ) -> dict[str, DiagonalGaussianDistribution] | DiagonalGaussianDistribution: - """Encode images to latent distribution using compiled encoder. - - Args: - sample: Input image tensor of shape [N, C_in, H, W]. - return_dict: If True, returns a dictionary with "latent_dist" key. - If False, returns DiagonalGaussianDistribution directly. - - Returns: - If return_dict=True: Dictionary with "latent_dist" key containing - DiagonalGaussianDistribution. - If return_dict=False: DiagonalGaussianDistribution directly. - - Raises: - ValueError: If encoder is not loaded. - """ - if self.encoder_model is None: - raise ValueError( - "Encoder not loaded. Check if encoder weights exist in the model." - ) - - moments = self._unwrap_single(self.encoder_model(sample)) - posterior = DiagonalGaussianDistribution(moments, moments) - if return_dict: - return {"latent_dist": posterior} - return posterior - - def decode(self, z: Buffer) -> Buffer: - """Decode latents to images using compiled decoder. - - Args: - z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. - - Returns: - Decoded image tensor. - """ - return self._unwrap_single(self.model(z)) - - def __call__(self, z: Buffer) -> Buffer: - """Call the decoder model to decode latents to images. - - This method provides a consistent interface with other ComponentModel - implementations. It is an alias for decode(). - - Args: - z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. - - Returns: - Decoded image tensor. - """ - return self.decode(z) diff --git a/max/python/max/pipelines/architectures/autoencoders/model_config.py b/max/python/max/pipelines/architectures/autoencoders/model_config.py index d91e439d250..5a37d2558a6 100644 --- a/max/python/max/pipelines/architectures/autoencoders/model_config.py +++ b/max/python/max/pipelines/architectures/autoencoders/model_config.py @@ -78,6 +78,83 @@ def generate( return AutoencoderKLConfig(**init_dict) +class AutoencoderKLWanConfigBase(MAXModelConfigBase): + # Defaults mirror Wan2.2 AutoencoderKLWan config. + base_dim: int = 96 + decoder_base_dim: int | None = None + z_dim: int = 16 + dim_mult: tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + attn_scales: tuple[float, ...] = () + temporal_downsample: tuple[bool, ...] = (False, True, True) + dropout: float = 0.0 + is_residual: bool = False + in_channels: int = 3 + out_channels: int = 3 + patch_size: int | None = None + scale_factor_temporal: int = 4 + scale_factor_spatial: int = 8 + latents_mean: tuple[float, ...] = ( + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ) + latents_std: tuple[float, ...] = ( + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ) + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + +class AutoencoderKLWanConfig(AutoencoderKLWanConfigBase): + @staticmethod + def generate( + config_dict: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + ) -> "AutoencoderKLWanConfig": + init_dict = { + key: value + for key, value in config_dict.items() + if key in AutoencoderKLWanConfigBase.__annotations__ + } + init_dict.update( + { + "dtype": supported_encoding_dtype(encoding), + "device": DeviceRef.from_device(devices[0]), + } + ) + return AutoencoderKLWanConfig(**init_dict) + + class AutoencoderKLQwenImageConfigBase(MAXModelConfigBase): """Configuration for the QwenImage 3D causal VAE (Wan-2.1 based).""" diff --git a/max/python/max/pipelines/architectures/autoencoders/vae.py b/max/python/max/pipelines/architectures/autoencoders/vae.py index eeb4e1267ca..0a4955b6bd0 100644 --- a/max/python/max/pipelines/architectures/autoencoders/vae.py +++ b/max/python/max/pipelines/architectures/autoencoders/vae.py @@ -11,809 +11,2116 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # -from dataclasses import dataclass +from __future__ import annotations +from itertools import pairwise + +from max.driver import accelerator_api from max.dtype import DType -from max.graph import DeviceRef, TensorType, TensorValue -from max.nn.activation import activation_function_from_name -from max.nn.conv import Conv2d +from max.graph import DeviceRef, TensorValue, Weight, ops +from max.graph.type import FilterLayout from max.nn.layer import LayerList, Module -from max.nn.norm import GroupNorm -from .layers import Downsample2D, ResnetBlock2D, Upsample2D, VAEAttention +from .model_config import AutoencoderKLWanConfig + +CACHE_T = 2 +WAN_DECODER_CACHE_SLOTS = 32 +WAN_ENCODER_CHUNK_SIZE = 4 # Frames per encoder chunk (matching diffusers) + + +def _use_nvidia_fcrs_conv3d(device: DeviceRef | None) -> bool: + return ( + device is not None and device.is_gpu() and accelerator_api() == "cuda" + ) + + +def _zero_cache_for(x: TensorValue) -> TensorValue: + """Create a zero cache tensor shaped for a causal conv input.""" + shape: list[int | str] = [x.shape[0], x.shape[1], CACHE_T, x.shape[3], x.shape[4]] # type: ignore[list-item] + return ops.constant(0.0, dtype=x.dtype, device=x.device).broadcast_to(shape) + + +class RMSNorm(Module): + """RMS norm used by Wan VAE blocks.""" + + def __init__( + self, + dim: int, + channel_first: bool = True, + images: bool = False, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + self.channel_first = channel_first + + broadcastable_dims = (1, 1) if images else (1, 1, 1) + shape = [dim, *broadcastable_dims] if channel_first else [dim] + dev_ref = device if device is not None else DeviceRef.CPU() + self.gamma = Weight( + "gamma", + dtype or DType.float32, + shape, + dev_ref, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + axis = 1 if self.channel_first else x.rank - 1 + rms = ops.mean(x * x, axis=axis) + inv = ops.rsqrt(rms + 1e-12) + gamma = ops.transfer_to(self.gamma, x.device) + return x * inv * gamma + + +class CausalConv3d(Module): + """3D causal convolution for Wan VAE. + + Temporal causality is implemented via asymmetric padding: the front + (temporal) dimension is padded on the left only, which the conv3d + padding parameter supports directly. + + Input is permuted from NCDHW to NDHWC before conv, and back after. + On NVIDIA GPUs, weights stay in PyTorch FCQRS layout to use the cuDNN + 3D conv dispatch path. Other backends use MAX's native QRSCF layout. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + dtype: DType | None = None, + device: DeviceRef | None = None, + has_bias: bool = True, + prefer_nvidia_fcrs: bool = True, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + pad_t = pad_h = pad_w = padding + else: + pad_t, pad_h, pad_w = padding + + self.in_channels = in_channels + self.out_channels = out_channels + self._stride = stride + # Causal: pad only the front of the temporal axis (left=2*pad_t, right=0). + self._padding = (2 * pad_t, 0, pad_h, pad_h, pad_w, pad_w) + + dev_ref = device if device is not None else DeviceRef.CPU() + dt = dtype or DType.float32 + d, h, w = kernel_size + self._use_nvidia_fcrs = prefer_nvidia_fcrs and _use_nvidia_fcrs_conv3d( + dev_ref + ) + filter_shape = ( + [out_channels, in_channels, d, h, w] + if self._use_nvidia_fcrs + else [d, h, w, in_channels, out_channels] + ) + self.filter = Weight("weight", dt, filter_shape, dev_ref) + self._has_bias = has_bias + if has_bias: + self.bias = Weight("bias", dt, [out_channels], dev_ref) + + def __call__(self, x: TensorValue) -> TensorValue: + # NCDHW -> NDHWC + x_ndhwc = ops.permute(x, [0, 2, 3, 4, 1]) + out = ops.conv3d( + x_ndhwc, + self.filter, + stride=self._stride, + padding=self._padding, + filter_layout=( + FilterLayout.FCRS + if self._use_nvidia_fcrs + else FilterLayout.QRSCF + ), + ) + # NDHWC -> NCDHW + out = ops.permute(out, [0, 4, 1, 2, 3]) + if self._has_bias: + bias_5d = ops.reshape(self.bias, [1, self.out_channels, 1, 1, 1]) + out = out + bias_5d + return out + + +class CausalConv3dCached(Module): + """3D causal convolution with explicit cache tensor I/O. + + Handles temporal causal padding separately via concat/pad before + calling the conv, while spatial padding is handled by conv3d. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + dtype: DType | None = None, + device: DeviceRef | None = None, + has_bias: bool = True, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + pad_t = pad_h = pad_w = padding + else: + pad_t, pad_h, pad_w = padding + + self.in_channels = in_channels + self.out_channels = out_channels + self._stride = stride + # Temporal causal padding: left=2*pad_t, right=0 + self._temporal_pad_left = 2 * pad_t + # Let conv3d handle spatial padding. Temporal padding = 0 here. + self._padding = (0, 0, pad_h, pad_h, pad_w, pad_w) + + dev_ref = device if device is not None else DeviceRef.CPU() + dt = dtype or DType.float32 + d, h, w = kernel_size + self._use_nvidia_fcrs = _use_nvidia_fcrs_conv3d(dev_ref) + filter_shape = ( + [out_channels, in_channels, d, h, w] + if self._use_nvidia_fcrs + else [d, h, w, in_channels, out_channels] + ) + self.filter = Weight("weight", dt, filter_shape, dev_ref) + self._has_bias = has_bias + if has_bias: + self.bias = Weight("bias", dt, [out_channels], dev_ref) + + def _apply_temporal_pad(self, x: TensorValue, pad_left: int) -> TensorValue: + """Zero-pad the temporal dimension (axis=2) on the left only.""" + if pad_left <= 0: + return x + # ops.pad expects 2*rank values: [d0_before, d0_after, d1_before, d1_after, ...] + # For 5D [B, C, T, H, W]: pad only dim 2 (T) on the left. + pad_vals = [0, 0, 0, 0, pad_left, 0, 0, 0, 0, 0] + return ops.pad(x, pad_vals) + + def _forward_conv(self, x: TensorValue) -> TensorValue: + # NCDHW -> NDHWC + x_ndhwc = ops.permute(x, [0, 2, 3, 4, 1]) + out = ops.conv3d( + x_ndhwc, + self.filter, + stride=self._stride, + padding=self._padding, + filter_layout=( + FilterLayout.FCRS + if self._use_nvidia_fcrs + else FilterLayout.QRSCF + ), + ) + # NDHWC -> NCDHW + out = ops.permute(out, [0, 4, 1, 2, 3]) + if self._has_bias: + bias_5d = ops.reshape(self.bias, [1, self.out_channels, 1, 1, 1]) + out = out + bias_5d + return out + + def __call__(self, x: TensorValue) -> TensorValue: + x = self._apply_temporal_pad(x, self._temporal_pad_left) + return self._forward_conv(x) + + def forward_cached( + self, x: TensorValue, cache_in: TensorValue + ) -> tuple[TensorValue, TensorValue]: + # Rebind cache spatial dims to match x so concat sees matching dims. + cache_in = ops.rebind( + cache_in, + shape=[ + cache_in.shape[0], + cache_in.shape[1], + cache_in.shape[2], + x.shape[3], + x.shape[4], + ], + ) + x = ops.concat([cache_in, x], axis=2) + cache_out = x[:, :, -CACHE_T:, :, :] + effective_pad = max(self._temporal_pad_left - CACHE_T, 0) + x = self._apply_temporal_pad(x, effective_pad) + return self._forward_conv(x), cache_out + + +class Conv2dPermuted(Module): + """2D convolution with NCHW input and FCRS weights (permute=True equivalent). + + Input is permuted from NCHW to NHWC before conv, and back after. + Weights stay in FCRS (PyTorch) layout. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dtype: DType | None = None, + device: DeviceRef | None = None, + has_bias: bool = True, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + if isinstance(stride, int): + self._stride = (stride, stride) + else: + self._stride = stride + if isinstance(padding, int): + self._padding = (padding, padding, padding, padding) + else: + self._padding = padding + + dev_ref = device if device is not None else DeviceRef.CPU() + dt = dtype or DType.float32 + self.filter = Weight( + "weight", + dt, + [out_channels, in_channels, kernel_size, kernel_size], + dev_ref, + ) + self._has_bias = has_bias + if has_bias: + self.bias = Weight("bias", dt, [out_channels], dev_ref) + + def __call__(self, x: TensorValue) -> TensorValue: + # NCHW -> NHWC + x_nhwc = ops.permute(x, [0, 2, 3, 1]) + out = ops.conv2d( + x_nhwc, + self.filter, + stride=self._stride, + padding=self._padding, + filter_layout=FilterLayout.FCRS, + ) + # NHWC -> NCHW + out = ops.permute(out, [0, 3, 1, 2]) + if self._has_bias: + bias_4d = ops.reshape(self.bias, [1, self.out_channels, 1, 1]) + out = out + bias_4d + return out + + +class Conv2d(Module): + """2D convolution with NHWC input and RSCF weights (permute=False equivalent). + + Input is already in NHWC layout. Weights are in RSCF layout + [H, W, in_channels, out_channels]. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dtype: DType | None = None, + device: DeviceRef | None = None, + has_bias: bool = True, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + if isinstance(stride, int): + self._stride = (stride, stride) + else: + self._stride = stride + if isinstance(padding, int): + self._padding = (padding, padding, padding, padding) + else: + self._padding = padding + + dev_ref = device if device is not None else DeviceRef.CPU() + dt = dtype or DType.float32 + self.filter = Weight( + "weight", + dt, + [kernel_size, kernel_size, in_channels, out_channels], + dev_ref, + ) + self._has_bias = has_bias + if has_bias: + self.bias = Weight("bias", dt, [out_channels], dev_ref) + + def __call__(self, x: TensorValue) -> TensorValue: + out = ops.conv2d( + x, + self.filter, + stride=self._stride, + padding=self._padding, + filter_layout=FilterLayout.RSCF, + bias=self.bias if self._has_bias else None, + ) + return out + + +class ResidualBlock(Module): + """Residual block used in Wan VAE decoder.""" + + def __init__( + self, + in_dim: int, + out_dim: int, + dtype: DType | None = None, + device: DeviceRef | None = None, + prefer_nvidia_fcrs: bool = True, + ) -> None: + super().__init__() + self.norm1 = RMSNorm( + in_dim, + images=False, + dtype=dtype, + device=device, + ) + self.conv1 = CausalConv3d( + in_dim, + out_dim, + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + prefer_nvidia_fcrs=prefer_nvidia_fcrs, + ) + self.norm2 = RMSNorm( + out_dim, + images=False, + dtype=dtype, + device=device, + ) + self.conv2 = CausalConv3d( + out_dim, + out_dim, + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + prefer_nvidia_fcrs=prefer_nvidia_fcrs, + ) + self.conv_shortcut = ( + CausalConv3d( + in_dim, + out_dim, + 1, + padding=0, + dtype=dtype, + device=device, + has_bias=True, + prefer_nvidia_fcrs=prefer_nvidia_fcrs, + ) + if in_dim != out_dim + else None + ) + + def __call__(self, x: TensorValue) -> TensorValue: + residual = ( + self.conv_shortcut(x) if self.conv_shortcut is not None else x + ) + x = ops.silu(self.norm1(x)) + x = self.conv1(x) + x = ops.silu(self.norm2(x)) + x = self.conv2(x) + return x + residual + + +class AttentionBlock(Module): + """Per-frame windowed self-attention used in Wan decoder mid block. + + Uses window attention instead of full (H*W)^2 attention to avoid OOM + at high resolutions. The spatial dimensions are partitioned into + non-overlapping windows of size ws*ws, and attention is computed + independently per window. + + Memory: O(b*t * num_windows * ws^2 * ws^2) instead of O(b*t * (H*W)^2). + At 720p latent (90x160) with ws=8: ~158MB vs ~2.5GB+ per chunk. + """ + + _WINDOW_SIZE: int = 8 + + def __init__( + self, + dim: int, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + self.dim = dim + self.norm = RMSNorm( + dim, + images=True, + dtype=dtype, + device=device, + ) + self.to_qkv = Conv2d( + in_channels=dim, + out_channels=dim * 3, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + has_bias=True, + ) + self.proj = Conv2d( + in_channels=dim, + out_channels=dim, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + identity = x + b = x.shape[0] + t = x.shape[2] + h = x.shape[3] + w = x.shape[4] + c = self.dim + ws = self._WINDOW_SIZE + + # [b, c, t, h, w] -> [b*t, c, h, w] + x2d = ops.permute(x, [0, 2, 1, 3, 4]) + x2d = ops.reshape(x2d, [b * t, c, h, w]) + x2d = self.norm(x2d) + + x2d_nhwc = ops.permute(x2d, [0, 2, 3, 1]) # [bt, h, w, c] + qkv = self.to_qkv(x2d_nhwc) # [bt, h, w, 3c] + + # Pad H and W up to the next multiple of ws. + # Use concat with zero tensors — always applied (no Python branching + # on symbolic dims). If already aligned, pad dims are 0-sized. + h_p = ((h + ws - 1) // ws) * ws + w_p = ((w + ws - 1) // ws) * ws + pad_w = w_p - w + pad_h = h_p - h + zero_w = ops.constant( + 0.0, dtype=qkv.dtype, device=qkv.device + ).broadcast_to([b * t, h, pad_w, 3 * c]) + qkv = ops.concat([qkv, zero_w], axis=2) + zero_h = ops.constant( + 0.0, dtype=qkv.dtype, device=qkv.device + ).broadcast_to([b * t, pad_h, w_p, 3 * c]) + qkv = ops.concat([qkv, zero_h], axis=1) + + hws = h_p // ws + wws = w_p // ws + nwin = hws * wws + tok = ws * ws + + q = qkv[:, :, :, :c] + k = qkv[:, :, :, c : 2 * c] + v = qkv[:, :, :, 2 * c : 3 * c] + + def to_windows(y: TensorValue) -> TensorValue: + y = ops.reshape(y, [b * t, hws, ws, wws, ws, c]) + y = ops.permute(y, [0, 1, 3, 2, 4, 5]) + return ops.reshape(y, [b * t, nwin, tok, c]) + + q_w = to_windows(q) + k_w = to_windows(k) + v_w = to_windows(v) + + attn_scores = ops.matmul( + q_w * (float(c) ** -0.5), ops.permute(k_w, [0, 1, 3, 2]) + ) + attn = ops.softmax(attn_scores, axis=-1) + out = ops.matmul(attn, v_w) # [bt, nwin, tok, c] + + out = ops.reshape(out, [b * t, hws, wws, ws, ws, c]) + out = ops.permute(out, [0, 1, 3, 2, 4, 5]) + out = ops.reshape(out, [b * t, h_p, w_p, c]) + + # Slice back to original spatial dims (remove padding). + out = out[:, :h, :w, :] + + out = self.proj(out) # [bt, h, w, c] + out = ops.permute(out, [0, 3, 1, 2]) # [bt, c, h, w] + out = ops.reshape(out, [b, t, c, h, w]) + out = ops.permute(out, [0, 2, 1, 3, 4]) + return out + identity + + +class MidBlock(Module): + """Middle decoder block with residual-attention-residual.""" + + def __init__( + self, + dim: int, + dtype: DType | None = None, + device: DeviceRef | None = None, + prefer_nvidia_fcrs: bool = True, + ) -> None: + super().__init__() + self.resnets = LayerList( + [ + ResidualBlock( + dim, + dim, + dtype=dtype, + device=device, + prefer_nvidia_fcrs=prefer_nvidia_fcrs, + ), + ResidualBlock( + dim, + dim, + dtype=dtype, + device=device, + prefer_nvidia_fcrs=prefer_nvidia_fcrs, + ), + ] + ) + self.attentions = LayerList( + [AttentionBlock(dim, dtype=dtype, device=device)] + ) + + def __call__(self, x: TensorValue) -> TensorValue: + x = self.resnets[0](x) + x = self.attentions[0](x) + x = self.resnets[1](x) + return x + + +class Upsample2d(Module): + """Nearest-neighbor 2D upsample by factor 2.""" + + def __init__(self) -> None: + super().__init__() + + def __call__(self, x: TensorValue) -> TensorValue: + n = x.shape[0] + c = x.shape[1] + h = x.shape[2] + w = x.shape[3] + # Nearest-neighbor 2x upsample: [N,C,H,W] → [N,C,H*2,W*2] + x = ops.reshape(x, [n, c, h, 1, w, 1]) + x = ops.concat([x, x], axis=3) # [N, C, H, 2, W, 1] + x = ops.concat([x, x], axis=5) # [N, C, H, 2, W, 2] + return ops.reshape(x, [n, c, h * 2, w * 2]) + + +class Resample(Module): + """Wan decoder upsampling module.""" + + def __init__( + self, + dim: int, + mode: str, + upsample_out_dim: int | None = None, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + self._out_c = upsample_out_dim + + self.time_conv: CausalConv3d | None = None + self.resample = LayerList( + [ + Upsample2d(), + Conv2dPermuted( + in_channels=dim, + out_channels=upsample_out_dim, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ), + ] + ) + + if mode == "upsample3d": + self.time_conv = CausalConv3d( + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + stride=1, + padding=(1, 0, 0), + dtype=dtype, + device=device, + has_bias=True, + ) + elif mode != "upsample2d": + raise ValueError(f"Unsupported Resample mode: {mode}") + + def __call__(self, x: TensorValue) -> TensorValue: + b = x.shape[0] + t = x.shape[2] + h = x.shape[3] + w = x.shape[4] + + if self.mode == "upsample3d": + if self.time_conv is None: + raise ValueError("time_conv is required for upsample3d mode") + x = self.time_conv(x) + # x: [b, 2*dim, t, h, w] -> interleave temporal frames + x = ops.reshape(x, [b, 2, self.dim, t, h, w]) + x = ops.permute(x, [0, 2, 3, 1, 4, 5]) # [b, dim, t, 2, h, w] + t = t * 2 + x = ops.reshape(x, [b, self.dim, t, h, w]) + + # Per-frame 2D upsample + conv + x = ops.permute(x, [0, 2, 1, 3, 4]) # [b, t, c, h, w] + x = ops.reshape(x, [b * t, self.dim, h, w]) + x = self.resample[0](x) # Upsample2d: [b*t, dim, h*2, w*2] + # Conv2dPermuted handles NCHW->NHWC->conv->NCHW internally. + x = self.resample[1](x) # [b*t, out_c, h*2, w*2] + + x = ops.reshape(x, [b, t, self._out_c, h * 2, w * 2]) + x = ops.permute(x, [0, 2, 1, 3, 4]) # [b, out_c, t, h*2, w*2] + return x + + +class UpBlock(Module): + """Wan decoder up block composed of residual blocks and optional upsample.""" + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + upsample_mode: str | None, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + resnets: list[ResidualBlock] = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + ResidualBlock( + current_dim, + out_dim, + dtype=dtype, + device=device, + ) + ) + current_dim = out_dim + self.resnets = LayerList(resnets) + + self.upsamplers: LayerList | None = None + if upsample_mode is not None: + self.upsamplers = LayerList( + [ + Resample( + out_dim, + mode=upsample_mode, + upsample_out_dim=None, + dtype=dtype, + device=device, + ) + ] + ) + + def __call__(self, x: TensorValue) -> TensorValue: + for resnet in self.resnets: + x = resnet(x) + + if self.upsamplers is not None: + x = self.upsamplers[0](x) + + return x + + +class Decoder3d(Module): + """Wan 3D decoder module.""" + + def __init__( + self, + dim: int = 96, + z_dim: int = 16, + dim_mult: tuple[int, ...] = (1, 2, 4, 4), + num_res_blocks: int = 2, + temporal_upsample: tuple[bool, ...] = (False, True, True), + out_channels: int = 3, + is_residual: bool = False, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + del is_residual + + dims = [dim * u for u in [dim_mult[-1], *dim_mult[::-1]]] + + self.conv_in = CausalConv3d( + z_dim, + dims[0], + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ) + + self.mid_block = MidBlock(dims[0], dtype=dtype, device=device) + + up_blocks: list[UpBlock] = [] + final_out_dim = dims[-1] + for i, (in_dim, out_dim) in enumerate(pairwise(dims)): + if i > 0: + in_dim = in_dim // 2 + + up_flag = i != len(dim_mult) - 1 + upsample_mode: str | None = None + if up_flag and temporal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + + up_blocks.append( + UpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + upsample_mode=upsample_mode, + dtype=dtype, + device=device, + ) + ) + final_out_dim = out_dim + + self.up_blocks = LayerList(up_blocks) + + self.norm_out = RMSNorm( + final_out_dim, + images=False, + dtype=dtype, + device=device, + ) + self.conv_out = CausalConv3d( + final_out_dim, + out_channels, + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + x = self.conv_in(x) + x = self.mid_block(x) + + for up_block in self.up_blocks: + x = up_block(x) + + x = self.norm_out(x) + x = ops.silu(x) + x = self.conv_out(x) + return x + + +class ResidualBlockCached(Module): + """Wan residual block with explicit cache I/O for conv1/conv2.""" + + def __init__( + self, + in_dim: int, + out_dim: int, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + self.norm1 = RMSNorm( + in_dim, + images=False, + dtype=dtype, + device=device, + ) + self.conv1 = CausalConv3dCached( + in_dim, + out_dim, + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ) + self.norm2 = RMSNorm( + out_dim, + images=False, + dtype=dtype, + device=device, + ) + self.conv2 = CausalConv3dCached( + out_dim, + out_dim, + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ) + self.conv_shortcut = ( + CausalConv3d( + in_dim, + out_dim, + 1, + padding=0, + dtype=dtype, + device=device, + has_bias=True, + ) + if in_dim != out_dim + else None + ) + + def __call__( + self, + x: TensorValue, + cache1_in: TensorValue | None = None, + cache2_in: TensorValue | None = None, + ) -> tuple[TensorValue, TensorValue, TensorValue]: + residual = ( + self.conv_shortcut(x) if self.conv_shortcut is not None else x + ) + + x = ops.silu(self.norm1(x)) + if cache1_in is None: + cache1_in = _zero_cache_for(x) + x, cache1_out = self.conv1.forward_cached(x, cache1_in) + + x = ops.silu(self.norm2(x)) + if cache2_in is None: + cache2_in = _zero_cache_for(x) + x, cache2_out = self.conv2.forward_cached(x, cache2_in) + return x + residual, cache1_out, cache2_out + + +class MidBlockCached(Module): + """Middle decoder block with cache threading.""" + + def __init__( + self, + dim: int, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + self.resnets = LayerList( + [ + ResidualBlockCached(dim, dim, dtype=dtype, device=device), + ResidualBlockCached(dim, dim, dtype=dtype, device=device), + ] + ) + self.attentions = LayerList( + [AttentionBlock(dim, dtype=dtype, device=device)] + ) + + def __call__( + self, x: TensorValue, *cache_inputs: TensorValue + ) -> tuple[TensorValue, TensorValue, TensorValue, TensorValue, TensorValue]: + if len(cache_inputs) not in (0, 4): + raise ValueError( + f"MidBlockCached expected 0 or 4 cache tensors, got {len(cache_inputs)}" + ) + + cache1_in = cache_inputs[0] if len(cache_inputs) == 4 else None + cache2_in = cache_inputs[1] if len(cache_inputs) == 4 else None + x, cache1_out, cache2_out = self.resnets[0](x, cache1_in, cache2_in) + x = self.attentions[0](x) + + cache3_in = cache_inputs[2] if len(cache_inputs) == 4 else None + cache4_in = cache_inputs[3] if len(cache_inputs) == 4 else None + x, cache3_out, cache4_out = self.resnets[1](x, cache3_in, cache4_in) + return x, cache1_out, cache2_out, cache3_out, cache4_out + + +class ResampleCached(Module): + """Wan upsample3d module with explicit cache I/O.""" + + def __init__( + self, + dim: int, + mode: str, + upsample_out_dim: int | None = None, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + if mode != "upsample3d": + raise ValueError("ResampleCached only supports mode='upsample3d'") + + self.dim = dim + self.mode = mode + + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + self._out_c = upsample_out_dim + + self.time_conv = CausalConv3dCached( + in_channels=dim, + out_channels=dim * 2, + kernel_size=(3, 1, 1), + stride=1, + padding=(1, 0, 0), + dtype=dtype, + device=device, + has_bias=True, + ) + self.resample = LayerList( + [ + Upsample2d(), + Conv2dPermuted( + in_channels=dim, + out_channels=upsample_out_dim, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ), + ] + ) + + def __call__( + self, + x: TensorValue, + cache_in: TensorValue | None = None, + first_chunk: bool = False, + ) -> tuple[TensorValue, TensorValue]: + b = x.shape[0] + t = x.shape[2] + h = x.shape[3] + w = x.shape[4] + + if cache_in is None: + cache_in = _zero_cache_for(x) + + if first_chunk: + cache_out = cache_in + else: + x, cache_out = self.time_conv.forward_cached(x, cache_in) + x = ops.reshape(x, [b, 2, self.dim, t, h, w]) + x = ops.permute(x, [0, 2, 3, 1, 4, 5]) + t = t * 2 + x = ops.reshape(x, [b, self.dim, t, h, w]) + + x = ops.permute(x, [0, 2, 1, 3, 4]) + x = ops.reshape(x, [b * t, self.dim, h, w]) + x = self.resample[0](x) + x = self.resample[1](x) + x = ops.reshape(x, [b, t, self._out_c, h * 2, w * 2]) + x = ops.permute(x, [0, 2, 1, 3, 4]) + return x, cache_out + + +class UpBlockCached(Module): + """Wan decoder up block with explicit cache threading.""" + + cache_slots: int + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + upsample_mode: str | None, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + resnets: list[ResidualBlockCached] = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + ResidualBlockCached( + current_dim, + out_dim, + dtype=dtype, + device=device, + ) + ) + current_dim = out_dim + self.resnets = LayerList(resnets) + + self._has_temporal_upsample = upsample_mode == "upsample3d" + self.cache_slots = len(resnets) * 2 + ( + 1 if self._has_temporal_upsample else 0 + ) + + self.upsamplers: LayerList | None = None + if upsample_mode is not None: + if upsample_mode == "upsample3d": + upsampler: Module = ResampleCached( + out_dim, + mode=upsample_mode, + upsample_out_dim=None, + dtype=dtype, + device=device, + ) + elif upsample_mode == "upsample2d": + upsampler = Resample( + out_dim, + mode=upsample_mode, + upsample_out_dim=None, + dtype=dtype, + device=device, + ) + else: + raise ValueError( + f"Unsupported UpBlockCached upsample mode: {upsample_mode}" + ) + + self.upsamplers = LayerList([upsampler]) + + def __call__( + self, + x: TensorValue, + *cache_inputs: TensorValue, + first_chunk: bool = False, + ) -> tuple[TensorValue, ...]: + if len(cache_inputs) not in (0, self.cache_slots): + raise ValueError( + f"UpBlockCached expected 0 or {self.cache_slots} cache tensors, got {len(cache_inputs)}" + ) + + use_cache_inputs = len(cache_inputs) == self.cache_slots + cache_outputs: list[TensorValue] = [] + cache_idx = 0 + + for resnet in self.resnets: + cache1_in = cache_inputs[cache_idx] if use_cache_inputs else None + cache2_in = ( + cache_inputs[cache_idx + 1] if use_cache_inputs else None + ) + x, cache1_out, cache2_out = resnet(x, cache1_in, cache2_in) + cache_outputs.extend([cache1_out, cache2_out]) + cache_idx += 2 + + if self.upsamplers is not None: + upsampler = self.upsamplers[0] + if self._has_temporal_upsample: + cache_in = cache_inputs[cache_idx] if use_cache_inputs else None + if not isinstance(upsampler, ResampleCached): + raise TypeError( + "Expected ResampleCached for temporal upsample" + ) + x, cache_out = upsampler( + x, + cache_in, + first_chunk=first_chunk, + ) + cache_outputs.append(cache_out) + else: + x = upsampler(x) + + return (x, *cache_outputs) + + +class Decoder3dCached(Module): + """Wan 3D decoder with explicit cache tensor I/O.""" + + def __init__( + self, + dim: int = 96, + z_dim: int = 16, + dim_mult: tuple[int, ...] = (1, 2, 4, 4), + num_res_blocks: int = 2, + temporal_upsample: tuple[bool, ...] = (False, True, True), + out_channels: int = 3, + is_residual: bool = False, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + del is_residual + + dims = [dim * u for u in [dim_mult[-1], *dim_mult[::-1]]] + + self.conv_in = CausalConv3dCached( + z_dim, + dims[0], + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ) + + self.mid_block = MidBlockCached(dims[0], dtype=dtype, device=device) + + up_blocks: list[UpBlockCached] = [] + final_out_dim = dims[-1] + for i, (in_dim, out_dim) in enumerate(pairwise(dims)): + if i > 0: + in_dim = in_dim // 2 + + up_flag = i != len(dim_mult) - 1 + upsample_mode: str | None = None + if up_flag and temporal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + + up_blocks.append( + UpBlockCached( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + upsample_mode=upsample_mode, + dtype=dtype, + device=device, + ) + ) + final_out_dim = out_dim + self.up_blocks = LayerList(up_blocks) + self.norm_out = RMSNorm( + final_out_dim, + images=False, + dtype=dtype, + device=device, + ) + self.conv_out = CausalConv3dCached( + final_out_dim, + out_channels, + 3, + padding=1, + dtype=dtype, + device=device, + has_bias=True, + ) -class DownEncoderBlock2D(Module): - """Downsampling encoder block for 2D VAE. + def __call__( + self, + x: TensorValue, + *cache_inputs: TensorValue, + first_chunk: bool = False, + ) -> tuple[TensorValue, ...]: + if len(cache_inputs) not in (0, WAN_DECODER_CACHE_SLOTS): + raise ValueError( + "Decoder3dCached expected 0 or " + f"{WAN_DECODER_CACHE_SLOTS} cache tensors, got {len(cache_inputs)}" + ) - This module consists of multiple ResNet blocks followed by an optional - downsampling layer. It progressively decreases spatial resolution while - processing features through residual connections. - """ + use_cache_inputs = len(cache_inputs) == WAN_DECODER_CACHE_SLOTS + cache_outputs: list[TensorValue] = [] + cache_idx = 0 + + conv_in_cache = cache_inputs[cache_idx] if use_cache_inputs else None + if conv_in_cache is None: + conv_in_cache = _zero_cache_for(x) + x, cache_out = self.conv_in.forward_cached(x, conv_in_cache) + cache_outputs.append(cache_out) + cache_idx += 1 + + mid_cache_inputs: tuple[TensorValue, ...] = ( + tuple(cache_inputs[cache_idx : cache_idx + 4]) + if use_cache_inputs + else () + ) + mid_outputs = self.mid_block(x, *mid_cache_inputs) + x = mid_outputs[0] + cache_outputs.extend(mid_outputs[1:]) + cache_idx += 4 - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - device: DeviceRef | None = None, - dtype: DType | None = None, - ) -> None: - """Initialize DownEncoderBlock2D module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - dropout: Dropout rate (currently unused). - num_layers: Number of ResNet blocks in this encoder block. - resnet_eps: Epsilon value for ResNet GroupNorm layers. - resnet_time_scale_shift: Time embedding scale/shift mode (not used - in encoder, temb=None). - resnet_act_fn: Activation function for ResNet blocks. - resnet_groups: Number of groups for ResNet GroupNorm. - resnet_pre_norm: Whether to apply normalization before ResNet. - output_scale_factor: Scaling factor for output (currently unused). - add_downsample: Whether to add downsampling layer after ResNet - blocks. - downsample_padding: Padding for the downsampling layer. - device: Device reference for module placement. - dtype: Data type for module parameters. - """ - super().__init__() - del dropout, resnet_pre_norm, output_scale_factor - resnets_list = [] + for up_block in self.up_blocks: + block_cache_inputs: tuple[TensorValue, ...] = ( + tuple( + cache_inputs[cache_idx : cache_idx + up_block.cache_slots] + ) + if use_cache_inputs + else () + ) + block_outputs = up_block( + x, + *block_cache_inputs, + first_chunk=first_chunk, + ) + x = block_outputs[0] + cache_outputs.extend(block_outputs[1:]) + cache_idx += up_block.cache_slots + + x = self.norm_out(x) + x = ops.silu(x) + conv_out_cache = cache_inputs[cache_idx] if use_cache_inputs else None + if conv_out_cache is None: + conv_out_cache = _zero_cache_for(x) + x, cache_out = self.conv_out.forward_cached(x, conv_out_cache) + cache_outputs.append(cache_out) + + if len(cache_outputs) != WAN_DECODER_CACHE_SLOTS: + raise ValueError( + "Decoder3dCached produced " + f"{len(cache_outputs)} cache tensors, expected {WAN_DECODER_CACHE_SLOTS}" + ) + return (x, *cache_outputs) - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels + def cache_shapes( + self, + batch_size: int, + latent_height: int, + latent_width: int, + ) -> list[list[int]]: + h = latent_height + w = latent_width + shapes: list[list[int]] = [ + [batch_size, self.conv_in.in_channels, CACHE_T, h, w] + ] + + for resnet in self.mid_block.resnets: + shapes.append([batch_size, resnet.conv1.in_channels, CACHE_T, h, w]) + shapes.append([batch_size, resnet.conv2.in_channels, CACHE_T, h, w]) - if resnet_time_scale_shift == "spatial": - raise NotImplementedError( - "resnet_time_scale_shift='spatial' is not supported in Max encoder. " - "Encoder uses temb=None, so only 'default' is supported." + for up_block in self.up_blocks: + for resnet in up_block.resnets: + shapes.append( + [batch_size, resnet.conv1.in_channels, CACHE_T, h, w] + ) + shapes.append( + [batch_size, resnet.conv2.in_channels, CACHE_T, h, w] ) - resnet = ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - groups=resnet_groups, - groups_out=resnet_groups, - eps=resnet_eps, - non_linearity=resnet_act_fn, - use_conv_shortcut=False, - conv_shortcut_bias=True, - device=device, - dtype=dtype, + if up_block.upsamplers is not None: + if up_block._has_temporal_upsample: + upsampler = up_block.upsamplers[0] + if not isinstance(upsampler, ResampleCached): + raise TypeError( + "Expected ResampleCached for temporal upsample" + ) + shapes.append( + [ + batch_size, + upsampler.time_conv.in_channels, + CACHE_T, + h, + w, + ] + ) + h *= 2 + w *= 2 + + shapes.append([batch_size, self.conv_out.in_channels, CACHE_T, h, w]) + if len(shapes) != WAN_DECODER_CACHE_SLOTS: + raise ValueError( + f"Expected {WAN_DECODER_CACHE_SLOTS} cache shapes, got {len(shapes)}" ) - resnets_list.append(resnet) + return shapes - self.resnets = LayerList(resnets_list) - self.downsamplers: LayerList | None = None - if add_downsample: - downsampler = Downsample2D( - channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - kernel_size=3, - norm_type=None, - bias=True, - device=device, - dtype=dtype, - ) - self.downsamplers = LayerList([downsampler]) - def __call__(self, hidden_states: TensorValue) -> TensorValue: - """Apply DownEncoderBlock2D forward pass. +class VAEPostQuantConv(Module): + """Standalone post-quant conv graph (k=1, frame-independent).""" - Args: - hidden_states: Input tensor of shape [N, C_in, H, W]. + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self.post_quant_conv = CausalConv3d( + in_channels=config.z_dim, + out_channels=config.z_dim, + kernel_size=1, + padding=0, + dtype=config.dtype, + device=config.device, + has_bias=True, + ) - Returns: - Output tensor of shape [N, C_out, H//2, W//2] (if downsampling) or - [N, C_out, H, W] (if no downsampling). - """ - for resnet in self.resnets: - hidden_states = resnet(hidden_states, None) - if self.downsamplers is not None: - hidden_states = self.downsamplers[0](hidden_states) - return hidden_states + def __call__(self, z: TensorValue) -> TensorValue: + return self.post_quant_conv(z) -class UpDecoderBlock2D(Module): - """Upsampling decoder block for 2D VAE. +class VAEDecoderFirstFrameCached(Module): + """First-frame decoder graph returning pixels + initialized caches.""" - This module consists of multiple ResNet blocks followed by an optional - upsampling layer. It progressively increases spatial resolution while - processing features through residual connections. - """ + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self.decoder = Decoder3dCached( + dim=config.base_dim, + z_dim=config.z_dim, + dim_mult=tuple(config.dim_mult), + num_res_blocks=config.num_res_blocks, + temporal_upsample=tuple(reversed(config.temporal_downsample)), + out_channels=config.out_channels, + is_residual=config.is_residual, + dtype=config.dtype, + device=config.device, + ) - def __init__( - self, - in_channels: int, - out_channels: int, - resolution_idx: int | None = None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: int | None = None, - device: DeviceRef | None = None, - dtype: DType | None = None, - ) -> None: - """Initialize UpDecoderBlock2D module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - resolution_idx: Optional resolution index for tracking. - dropout: Dropout rate (currently unused). - num_layers: Number of ResNet blocks in this decoder block. - resnet_eps: Epsilon value for ResNet GroupNorm layers. - resnet_time_scale_shift: Time embedding scale/shift mode. - resnet_act_fn: Activation function for ResNet blocks. - resnet_groups: Number of groups for ResNet GroupNorm. - resnet_pre_norm: Whether to apply normalization before ResNet. - output_scale_factor: Scaling factor for output (currently unused). - add_upsample: Whether to add upsampling layer after ResNet blocks. - temb_channels: Number of time embedding channels (None if not used). - device: Device reference for module placement. - dtype: Data type for module parameters. - """ + def __call__(self, z: TensorValue) -> tuple[TensorValue, ...]: + outputs = self.decoder(z, first_chunk=True) + x = outputs[0] + x = ops.max(x, -1.0) + x = ops.min(x, 1.0) + return (x, *outputs[1:]) + + +class VAEDecoderRestFrameCached(Module): + """Per-frame decoder graph with cache feedback for frames 1..T-1.""" + + def __init__(self, config: AutoencoderKLWanConfig) -> None: super().__init__() - resnets_list = [] - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnet = ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - groups=resnet_groups, - groups_out=resnet_groups, - eps=resnet_eps, - non_linearity=resnet_act_fn, - use_conv_shortcut=False, - conv_shortcut_bias=True, - device=device, - dtype=dtype, - ) - resnets_list.append(resnet) - self.resnets = LayerList(resnets_list) - - if add_upsample: - upsampler = Upsample2D( - channels=out_channels, - use_conv=True, - out_channels=out_channels, - name="conv", - kernel_size=3, - padding=1, - bias=True, - interpolate=True, - device=device, - dtype=dtype, - ) - self.upsamplers: LayerList | None = LayerList([upsampler]) - else: - self.upsamplers = None + self.decoder = Decoder3dCached( + dim=config.base_dim, + z_dim=config.z_dim, + dim_mult=tuple(config.dim_mult), + num_res_blocks=config.num_res_blocks, + temporal_upsample=tuple(reversed(config.temporal_downsample)), + out_channels=config.out_channels, + is_residual=config.is_residual, + dtype=config.dtype, + device=config.device, + ) def __call__( - self, hidden_states: TensorValue, temb: TensorValue | None = None - ) -> TensorValue: - """Apply UpDecoderBlock2D forward pass. + self, z: TensorValue, *cache_inputs: TensorValue + ) -> tuple[TensorValue, ...]: + outputs = self.decoder(z, *cache_inputs, first_chunk=False) + x = outputs[0] + x = ops.max(x, -1.0) + x = ops.min(x, 1.0) + return (x, *outputs[1:]) - Args: - hidden_states: Input tensor of shape [N, C_in, H, W]. - temb: Optional time embedding tensor. - Returns: - Output tensor of shape [N, C_out, H*2, W*2] (if upsampling) or - [N, C_out, H, W] (if no upsampling). - """ - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - if self.upsamplers is not None: - hidden_states = self.upsamplers[0](hidden_states) - return hidden_states +class VAEDecoder(Module): + """Wan VAE decoder graph used by AutoencoderKLWanModel.""" + + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self._config = config + self.post_quant_conv = CausalConv3d( + in_channels=config.z_dim, + out_channels=config.z_dim, + kernel_size=1, + padding=0, + dtype=config.dtype, + device=config.device, + has_bias=True, + ) + self.decoder = Decoder3d( + dim=config.base_dim, + z_dim=config.z_dim, + dim_mult=tuple(config.dim_mult), + num_res_blocks=config.num_res_blocks, + temporal_upsample=tuple(reversed(config.temporal_downsample)), + out_channels=config.out_channels, + is_residual=config.is_residual, + dtype=config.dtype, + device=config.device, + ) + def __call__(self, z: TensorValue) -> TensorValue: + x = self.post_quant_conv(z) + x = self.decoder(x) + x = ops.max(x, -1.0) + x = ops.min(x, 1.0) + return x -class MidBlock2D(Module): - """Middle block for 2D VAE. - This module processes features at the middle of the VAE architecture, - applying ResNet blocks with optional spatial attention mechanisms. - It maintains spatial dimensions while processing features through - residual connections and self-attention. +class VAEDecoderFirstFrame(Module): + """Wan VAE decoder for the FIRST latent frame. + + Identical to VAEDecoder but ALL temporal upsamples are replaced + with spatial-only upsample2d (time_conv is omitted). This means + T=1 in -> T=1 out, matching the diffusers feat_cache behavior where + the first frame skips temporal upsampling. + """ + + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self._config = config + self.post_quant_conv = CausalConv3d( + in_channels=config.z_dim, + out_channels=config.z_dim, + kernel_size=1, + padding=0, + dtype=config.dtype, + device=config.device, + has_bias=True, + ) + # Force all temporal upsamples to spatial-only. + self.decoder = Decoder3d( + dim=config.base_dim, + z_dim=config.z_dim, + dim_mult=tuple(config.dim_mult), + num_res_blocks=config.num_res_blocks, + temporal_upsample=(False,) * len(config.temporal_downsample), + out_channels=config.out_channels, + is_residual=config.is_residual, + dtype=config.dtype, + device=config.device, + ) + + def __call__(self, z: TensorValue) -> TensorValue: + x = self.post_quant_conv(z) + x = self.decoder(x) + x = ops.max(x, -1.0) + x = ops.min(x, 1.0) + return x + + +class DownResample(Module): + """Wan encoder downsampling module. + + Matches diffusers Resample downsample modes: + - downsample2d: ZeroPad2d + Conv2d(stride=2) per frame + - downsample3d: same spatial + CausalConv3d(stride=(2,1,1)) temporal """ def __init__( self, - in_channels: int, - temb_channels: int | None, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - device: DeviceRef | None = None, + dim: int, + mode: str, dtype: DType | None = None, + device: DeviceRef | None = None, + prefer_nvidia_fcrs: bool = True, ) -> None: - """Initialize MidBlock2D module. - - Args: - in_channels: Number of input channels. - temb_channels: Number of time embedding channels (None if not used). - dropout: Dropout rate (currently unused). - num_layers: Number of ResNet/attention layer pairs. - resnet_eps: Epsilon value for ResNet GroupNorm layers. - resnet_time_scale_shift: Time embedding scale/shift mode. - resnet_act_fn: Activation function for ResNet blocks. - resnet_groups: Number of groups for ResNet GroupNorm. - resnet_pre_norm: Whether to apply normalization before ResNet. - add_attention: Whether to add attention layers between ResNet - blocks. - attention_head_dim: Dimension of each attention head. - output_scale_factor: Scaling factor for output (currently unused). - device: Device reference for module placement. - dtype: Data type for module parameters. - """ super().__init__() - - resnets_list = [] - attentions_list: list[VAEAttention | None] = [] - - resnet = ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - groups=resnet_groups, - groups_out=resnet_groups, - eps=resnet_eps, - non_linearity=resnet_act_fn, - use_conv_shortcut=False, - conv_shortcut_bias=True, - device=device, - dtype=dtype, + self.dim = dim + self.mode = mode + + # Spatial: ZeroPad2d(0,1,0,1) + Conv2d(stride=2, padding=0) + # Asymmetric padding: right=1, bottom=1 only. + # Use index [1] to match state_dict key "resample.1". + self.resample = LayerList( + [ + Upsample2d(), # Dummy at index 0 (no weights, not called) + Conv2dPermuted( + in_channels=dim, + out_channels=dim, + kernel_size=3, + stride=2, + padding=0, # We do manual asymmetric pad in __call__ + dtype=dtype, + device=device, + has_bias=True, + ), + ] ) - resnets_list.append(resnet) - for _i in range(num_layers): - if add_attention: - attn = VAEAttention( - query_dim=in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - num_groups=resnet_groups, - eps=resnet_eps, - device=device, - dtype=dtype, - ) - attentions_list.append(attn) - else: - attentions_list.append(None) - - resnet = ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - groups=resnet_groups, - groups_out=resnet_groups, - eps=resnet_eps, - non_linearity=resnet_act_fn, - use_conv_shortcut=False, - conv_shortcut_bias=True, - device=device, + self.time_conv: CausalConv3d | None = None + if mode == "downsample3d": + self.time_conv = CausalConv3d( + in_channels=dim, + out_channels=dim, + kernel_size=(3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0), dtype=dtype, + device=device, + has_bias=True, + # Encoder temporal downsample is the only conv pattern that + # currently reproduces cuDNN aborts in VAE encode. + prefer_nvidia_fcrs=False, ) - resnets_list.append(resnet) + elif mode != "downsample2d": + raise ValueError(f"Unsupported DownResample mode: {mode}") + + def __call__(self, x: TensorValue) -> TensorValue: + b = x.shape[0] + t = x.shape[2] + h = x.shape[3] + w = x.shape[4] + + if self.mode == "downsample3d": + if self.time_conv is None: + raise ValueError("time_conv is required for downsample3d mode") + # Temporal downsample via strided causal conv + x = self.time_conv(x) + t = x.shape[2] + + # Per-frame spatial downsample: ZeroPad2d(0,1,0,1) + Conv2d(stride=2) + x = ops.permute(x, [0, 2, 1, 3, 4]) # [b, t, c, h, w] + x = ops.reshape(x, [b * t, self.dim, h, w]) + # ZeroPad2d(left=0, right=1, top=0, bottom=1) on NCHW + # paddings format: [N_before, N_after, C_before, C_after, H_before, H_after, W_before, W_after] + x = ops.pad(x, [0, 0, 0, 0, 0, 1, 0, 1]) + x = self.resample[1](x) # Conv2d stride=2, padding=0 + new_h = (h + 1) // 2 + new_w = (w + 1) // 2 + # Rebind so the compiler sees conv output shape matches our computation. + x = ops.rebind(x, shape=[b * t, self.dim, new_h, new_w]) + x = ops.reshape(x, [b, t, self.dim, new_h, new_w]) + x = ops.permute(x, [0, 2, 1, 3, 4]) # [b, dim, t, h/2, w/2] + + return x + + +class DownResampleCached(Module): + """Encoder downsample with temporal cache for chunked encoding. + + Matches diffusers' Resample cache behavior for the encoder: + - downsample2d: spatial only, no temporal cache + - downsample3d first chunk: spatial downsample, skip time_conv, cache last frame + - downsample3d rest chunk: spatial downsample, prepend cached frame, apply time_conv + + Spatial downsample is done FIRST (matching diffusers order), then temporal. + """ - self.resnets = LayerList(resnets_list) + cache_slots: int - if attentions_list: - non_none_attentions = [ - attn for attn in attentions_list if attn is not None + def __init__( + self, + dim: int, + mode: str, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + self.dim = dim + self.mode = mode + self._has_temporal = mode == "downsample3d" + self.cache_slots = 1 if self._has_temporal else 0 + + self.resample = LayerList( + [ + Upsample2d(), # Dummy at index 0 (match weight naming) + Conv2dPermuted( + in_channels=dim, + out_channels=dim, + kernel_size=3, + stride=2, + padding=0, + dtype=dtype, + device=device, + has_bias=True, + ), ] - if non_none_attentions: - self.attentions: LayerList | None = LayerList( - non_none_attentions - ) - self.attention_indices = { - i - for i, attn in enumerate(attentions_list) - if attn is not None - } - else: - self.attentions = None - self.attention_indices = set() - else: - self.attentions = None - self.attention_indices = set() + ) + + self.time_conv: CausalConv3d | None = None + if self._has_temporal: + self.time_conv = CausalConv3d( + in_channels=dim, + out_channels=dim, + kernel_size=(3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0), + dtype=dtype, + device=device, + has_bias=True, + prefer_nvidia_fcrs=False, + ) + elif mode != "downsample2d": + raise ValueError(f"Unsupported DownResampleCached mode: {mode}") def __call__( - self, hidden_states: TensorValue, temb: TensorValue | None = None - ) -> TensorValue: - """Apply MidBlock2D forward pass. + self, + x: TensorValue, + *, + cache_in: TensorValue | None = None, + first_chunk: bool = False, + ) -> tuple[TensorValue, ...]: + b = x.shape[0] + t = x.shape[2] + h = x.shape[3] + w = x.shape[4] + + # Spatial downsample first (matching diffusers order) + x = ops.permute(x, [0, 2, 1, 3, 4]) # [b, t, c, h, w] + x = ops.reshape(x, [b * t, self.dim, h, w]) + x = ops.pad(x, [0, 0, 0, 0, 0, 1, 0, 1]) # ZeroPad2d(0,1,0,1) + x = self.resample[1](x) # Conv2d stride=2 + new_h = (h + 1) // 2 + new_w = (w + 1) // 2 + # Rebind so the compiler sees conv output shape matches our computation. + x = ops.rebind(x, shape=[b * t, self.dim, new_h, new_w]) + x = ops.reshape(x, [b, t, self.dim, new_h, new_w]) + x = ops.permute(x, [0, 2, 1, 3, 4]) # [b, c, t, h', w'] + + if self._has_temporal: + assert self.time_conv is not None + cache_out = x[:, :, -1:, :, :] # Last frame after spatial + if first_chunk: + # Skip time_conv, return spatial output + cache + return x, cache_out + else: + assert cache_in is not None + # Rebind cache spatial dims to match x after spatial downsample. + cache_in = ops.rebind( + cache_in, + shape=[ + cache_in.shape[0], + cache_in.shape[1], + cache_in.shape[2], + x.shape[3], + x.shape[4], + ], + ) + # Prepend cached last frame, apply time_conv + x_cat = ops.concat([cache_in, x], axis=2) + x = self.time_conv(x_cat) + return x, cache_out - Args: - hidden_states: Input tensor of shape [N, C, H, W]. - temb: Optional time embedding tensor. + return (x,) - Returns: - Output tensor of shape [N, C, H, W] with same spatial dimensions. - """ - hidden_states = self.resnets[0](hidden_states, temb) - attention_idx = 0 - for i in range(len(self.resnets) - 1): - if self.attentions is not None and i in self.attention_indices: - hidden_states = self.attentions[attention_idx](hidden_states) - attention_idx += 1 - hidden_states = self.resnets[i + 1](hidden_states, temb) - return hidden_states - - -@dataclass -class DecoderOutput: - r"""Output of decoding method. - - Args: - sample (`TensorValue` of shape `(batch_size, num_channels, height, width)`): - The decoded output sample from the last layer of the model. - """ - sample: TensorValue - commit_loss: TensorValue | None = None +class DownBlock(Module): + """Wan encoder down block (mirror of UpBlock).""" + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + downsample_mode: str | None, + dtype: DType | None = None, + device: DeviceRef | None = None, + ) -> None: + super().__init__() + resnets: list[ResidualBlock] = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + ResidualBlock( + current_dim, + out_dim, + dtype=dtype, + device=device, + ) + ) + current_dim = out_dim + self.resnets = LayerList(resnets) + + self.downsamplers: LayerList | None = None + if downsample_mode is not None: + self.downsamplers = LayerList( + [ + DownResample( + out_dim, + mode=downsample_mode, + dtype=dtype, + device=device, + ) + ] + ) + + def __call__(self, x: TensorValue) -> TensorValue: + for resnet in self.resnets: + x = resnet(x) + + if self.downsamplers is not None: + x = self.downsamplers[0](x) + return x -class Encoder(Module): - r"""The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. - This module progressively downsamples the input through multiple encoder blocks, - applies a middle block for feature processing, and outputs encoded latents. +class Encoder3d(Module): + """Wan 3D encoder module (mirror of Decoder3d). - Args: - in_channels: The number of input channels. - out_channels: The number of output channels. - down_block_types: The types of down blocks to use. Currently only supports "DownEncoderBlock2D". - block_out_channels: The number of output channels for each block. - layers_per_block: The number of layers per block. - norm_num_groups: The number of groups for normalization. - act_fn: The activation function to use (e.g., "silu"). - double_z: Whether to double the number of output channels for the last block. - mid_block_add_attention: Whether to add attention in the middle block. - device: Device reference for module placement. - dtype: Data type for module parameters. + Uses a flat ModuleList for down_blocks to match the diffusers + safetensors key naming (encoder.down_blocks.{i}.{conv1,norm1,...}). """ def __init__( self, + dim: int = 96, + z_dim: int = 16, in_channels: int = 3, - out_channels: int = 3, - down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",), - block_out_channels: tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - mid_block_add_attention: bool = True, - use_quant_conv: bool = False, - device: DeviceRef | None = None, + dim_mult: tuple[int, ...] = (1, 2, 4, 4), + num_res_blocks: int = 2, + temporal_downsample: tuple[bool, ...] = (False, True, True), dtype: DType | None = None, + device: DeviceRef | None = None, ) -> None: - """Initialize Encoder module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - down_block_types: Tuple of down block types (currently only "DownEncoderBlock2D"). - block_out_channels: Tuple of block output channels. - layers_per_block: Number of layers per block. - norm_num_groups: Number of groups for normalization. - act_fn: Activation function name (e.g., "silu"). - double_z: Whether to double output channels for the last block. - mid_block_add_attention: Whether to add attention in the middle block. - use_quant_conv: Whether to add 1x1 conv after conv_out (encoder output -> latent moments). - device: Device reference for module placement. - dtype: Data type for module parameters. - """ super().__init__() - if dtype is None: - raise ValueError("dtype must be set for Encoder") - if device is None: - raise ValueError("device must be set for Encoder") - self.layers_per_block = layers_per_block - self.in_channels = in_channels - self.device = device - self.dtype = dtype - self.activation = activation_function_from_name(act_fn) - self.conv_in = Conv2d( - kernel_size=3, - in_channels=in_channels, - out_channels=block_out_channels[0], - dtype=dtype, - stride=1, + + dims = [dim * u for u in [1, *list(dim_mult)]] + + self.conv_in = CausalConv3d( + in_channels, + dims[0], + 3, padding=1, - has_bias=True, + dtype=dtype, device=device, - permute=True, + has_bias=True, + prefer_nvidia_fcrs=False, ) - output_channel = block_out_channels[0] - down_blocks_list = [] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - if down_block_type != "DownEncoderBlock2D": - raise ValueError( - f"Unsupported down_block_type: {down_block_type}. " - "Currently only 'DownEncoderBlock2D' is supported." + # Flat ModuleList matching diffusers weight naming: + # down_blocks.{0,1} = ResidualBlock (first level, 2 blocks) + # down_blocks.2 = Resample (downsample) + # down_blocks.{3,4} = ResidualBlock (second level) + # down_blocks.5 = Resample ...etc + down_blocks: list[Module] = [] + for i, (in_dim, out_dim) in enumerate(pairwise(dims)): + for j in range(num_res_blocks): + down_blocks.append( + ResidualBlock( + in_dim if j == 0 else out_dim, + out_dim, + dtype=dtype, + device=device, + prefer_nvidia_fcrs=False, + ) + ) + down_flag = i != len(dim_mult) - 1 + if down_flag: + mode = ( + "downsample3d" if temporal_downsample[i] else "downsample2d" + ) + down_blocks.append( + DownResample( + out_dim, + mode=mode, + dtype=dtype, + device=device, + prefer_nvidia_fcrs=False, + ) ) - down_block = DownEncoderBlock2D( - in_channels=input_channel, - out_channels=output_channel, - dropout=0.0, - num_layers=self.layers_per_block, - resnet_eps=1e-6, - resnet_time_scale_shift="default", - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - output_scale_factor=1.0, - add_downsample=not is_final_block, - downsample_padding=0, - device=device, - dtype=dtype, - ) - down_blocks_list.append(down_block) - - self.down_blocks = LayerList(down_blocks_list) - - self.mid_block = MidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=None, - dropout=0.0, - num_layers=1, - resnet_eps=1e-6, - resnet_time_scale_shift="default", - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - add_attention=mid_block_add_attention, - attention_head_dim=block_out_channels[-1], - output_scale_factor=1.0, - device=device, + self.down_blocks = LayerList(down_blocks) + + final_dim = dims[-1] + self.mid_block = MidBlock( + final_dim, dtype=dtype, - ) - self.conv_norm_out = GroupNorm( - num_groups=norm_num_groups, - num_channels=block_out_channels[-1], - eps=1e-6, - affine=True, device=device, + prefer_nvidia_fcrs=False, ) - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = Conv2d( - kernel_size=3, - in_channels=block_out_channels[-1], - out_channels=conv_out_channels, + + self.norm_out = RMSNorm( + final_dim, + images=False, dtype=dtype, - stride=1, + device=device, + ) + # Output 2*z_dim for mean + logvar + self.conv_out = CausalConv3d( + final_dim, + z_dim * 2, + 3, padding=1, - has_bias=True, + dtype=dtype, device=device, - permute=True, - ) - self.quant_conv: Conv2d | None = None - if use_quant_conv: - self.quant_conv = Conv2d( - kernel_size=1, - in_channels=conv_out_channels, - out_channels=conv_out_channels, - dtype=dtype, - stride=1, - padding=0, - has_bias=True, - device=device, - permute=True, - ) - - def __call__(self, sample: TensorValue) -> TensorValue: - r"""The forward method of the `Encoder` class. + has_bias=True, + prefer_nvidia_fcrs=False, + ) - Args: - sample: Input tensor of shape [N, C_in, H, W]. + def __call__(self, x: TensorValue) -> TensorValue: + x = self.conv_in(x) - Returns: - Output tensor of shape [N, C_out, H_latent, W_latent] (downsampled). - """ - sample = self.conv_in(sample) for down_block in self.down_blocks: - sample = down_block(sample) - sample = self.mid_block(sample, None) - sample = self.conv_norm_out(sample) - sample = self.activation(sample) - sample = self.conv_out(sample) - if self.quant_conv is not None: - sample = self.quant_conv(sample) - return sample - - def input_types(self) -> tuple[TensorType, ...]: - """Define input tensor types for the encoder model. - - Returns: - Tuple of TensorType specifications for encoder input. - """ - return ( - TensorType( - self.dtype, - shape=[ - "batch_size", - self.in_channels, - "image_height", - "image_width", - ], - device=self.device, - ), - ) + x = down_block(x) + x = self.mid_block(x) + x = self.norm_out(x) + x = ops.silu(x) + x = self.conv_out(x) + return x -class Decoder(Module): - """VAE decoder for generating images from latent representations. - This decoder progressively upsamples latent features through multiple - decoder blocks, applying ResNet layers and attention mechanisms to - reconstruct high-resolution images from compressed latent codes. +class Encoder3dCached(Module): + """Chunked encoder with explicit cache I/O for temporal context. + + Uses a flat ModuleList for down_blocks (matching Encoder3d weight naming). + Each chunk processes either 1 frame (first) or CHUNK_SIZE frames (rest). + Temporal context is maintained via cache tensors passed between chunks. """ def __init__( self, + dim: int = 96, + z_dim: int = 16, in_channels: int = 3, - out_channels: int = 3, - up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), - block_out_channels: tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", - mid_block_add_attention: bool = True, - use_post_quant_conv: bool = True, - device: DeviceRef | None = None, + dim_mult: tuple[int, ...] = (1, 2, 4, 4), + num_res_blocks: int = 2, + temporal_downsample: tuple[bool, ...] = (False, True, True), dtype: DType | None = None, + device: DeviceRef | None = None, ) -> None: - """Initialize Decoder module. - - Args: - in_channels: Number of input channels (latent channels). - out_channels: Number of output channels (image channels). - up_block_types: Tuple of upsampling block types. - block_out_channels: Tuple of channel counts for each decoder block. - layers_per_block: Number of ResNet layers per decoder block. - norm_num_groups: Number of groups for GroupNorm layers. - act_fn: Activation function name (e.g., "silu"). - norm_type: Normalization type ("group" or "spatial"). - mid_block_add_attention: Whether to add attention in middle block. - use_post_quant_conv: Whether to use post-quantization convolution. - device: Device reference for module placement. - dtype: Data type for module parameters. - """ super().__init__() - if dtype is None: - raise ValueError("dtype must be set for Decoder") - if device is None: - raise ValueError("device must be set for Decoder") - - self.layers_per_block = layers_per_block - self.in_channels = in_channels - self.device = device - self.dtype = dtype - self.activation = activation_function_from_name(act_fn) - - self.post_quant_conv: Conv2d | None = None - if use_post_quant_conv: - self.post_quant_conv = Conv2d( - kernel_size=1, - in_channels=in_channels, - out_channels=in_channels, - dtype=dtype, - stride=1, - padding=0, - has_bias=True, - device=device, - permute=True, - ) - - self.conv_in = Conv2d( - kernel_size=3, - in_channels=in_channels, - out_channels=block_out_channels[-1], - dtype=dtype, - stride=1, + self._dim = dim + self._in_channels = in_channels + self._dim_mult = dim_mult + self._num_res_blocks = num_res_blocks + self._temporal_downsample = temporal_downsample + + dims = [dim * u for u in [1, *list(dim_mult)]] + + self.conv_in = CausalConv3dCached( + in_channels, + dims[0], + 3, padding=1, - has_bias=True, - device=device, - permute=True, - ) - temb_channels = in_channels if norm_type == "spatial" else None - self.mid_block = MidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=temb_channels, - num_layers=1, - resnet_eps=1e-6, - resnet_time_scale_shift=( - "default" if norm_type == "group" else norm_type - ), - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - add_attention=mid_block_add_attention, - attention_head_dim=block_out_channels[-1], - output_scale_factor=1.0, - device=device, dtype=dtype, + device=device, + has_bias=True, ) - up_blocks_list = [] - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - if up_block_type == "UpDecoderBlock2D": - up_block = UpDecoderBlock2D( - in_channels=prev_output_channel, - out_channels=output_channel, - resolution_idx=i, - dropout=0.0, - num_layers=self.layers_per_block + 1, - resnet_eps=1e-6, - resnet_time_scale_shift=norm_type, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_pre_norm=True, - output_scale_factor=1.0, - add_upsample=not is_final_block, - temb_channels=temb_channels, - device=device, + # Flat list matching diffusers weight naming + down_blocks: list[Module] = [] + self._block_cache_slots: list[int] = [] + for i, (in_dim, out_dim) in enumerate(pairwise(dims)): + for j in range(num_res_blocks): + down_blocks.append( + ResidualBlockCached( + in_dim if j == 0 else out_dim, + out_dim, + dtype=dtype, + device=device, + ) + ) + self._block_cache_slots.append(2) + down_flag = i != len(dim_mult) - 1 + if down_flag: + mode = ( + "downsample3d" if temporal_downsample[i] else "downsample2d" + ) + ds = DownResampleCached( + out_dim, + mode=mode, dtype=dtype, + device=device, ) - up_blocks_list.append(up_block) - else: - raise ValueError(f"Unsupported up_block_type: {up_block_type}") + down_blocks.append(ds) + self._block_cache_slots.append(ds.cache_slots) - prev_output_channel = output_channel + self.down_blocks = LayerList(down_blocks) - self.up_blocks = LayerList(up_blocks_list) + final_dim = dims[-1] + self.mid_block = MidBlockCached(final_dim, dtype=dtype, device=device) - if norm_type == "spatial": - raise NotImplementedError("SpatialNorm not implemented in MAX VAE") - else: - self.conv_norm_out = GroupNorm( - num_groups=norm_num_groups, - num_channels=block_out_channels[0], - eps=1e-6, - affine=True, - device=device, - ) - self.conv_out = Conv2d( - kernel_size=3, - in_channels=block_out_channels[0], - out_channels=out_channels, + self.norm_out = RMSNorm( + final_dim, + images=False, dtype=dtype, - stride=1, + device=device, + ) + self.conv_out = CausalConv3dCached( + final_dim, + z_dim * 2, + 3, padding=1, - has_bias=True, + dtype=dtype, device=device, - permute=True, + has_bias=True, ) - def __call__( - self, z: TensorValue, temb: TensorValue | None = None - ) -> TensorValue: - """Apply Decoder forward pass. + @property + def total_cache_slots(self) -> int: + return 1 + sum(self._block_cache_slots) + 4 + 1 - Args: - z: Input latent tensor of shape [N, C_latent, H_latent, W_latent]. - temb: Optional time embedding tensor. + def cache_shapes( + self, + batch_size: int, + height: int | None = None, + width: int | None = None, + ) -> list[list[int | None]]: + """Compute cache shapes for this encoder configuration. - Returns: - Decoded image tensor of shape [N, C_out, H, W] where H and W are - upsampled from H_latent and W_latent. + If height/width are None, those dimensions are dynamic. """ - sample = ( - self.post_quant_conv(z) if self.post_quant_conv is not None else z + dims = [self._dim * u for u in [1, *list(self._dim_mult)]] + h: int | None = height + w: int | None = width + shapes: list[list[int | None]] = [] + + # conv_in cache + shapes.append([batch_size, self._in_channels, CACHE_T, h, w]) + + for i, (in_dim, out_dim) in enumerate(pairwise(dims)): + for j in range(self._num_res_blocks): + block_in = in_dim if j == 0 else out_dim + shapes.append([batch_size, block_in, CACHE_T, h, w]) + shapes.append([batch_size, out_dim, CACHE_T, h, w]) + + down_flag = i != len(self._dim_mult) - 1 + if down_flag: + new_h = (h + 1) // 2 if h is not None else None + new_w = (w + 1) // 2 if w is not None else None + if self._temporal_downsample[i]: + shapes.append([batch_size, out_dim, 1, new_h, new_w]) + h, w = new_h, new_w + + final_dim = dims[-1] + for _ in range(4): + shapes.append([batch_size, final_dim, CACHE_T, h, w]) + + shapes.append([batch_size, final_dim, CACHE_T, h, w]) + + assert len(shapes) == self.total_cache_slots, ( + f"cache_shapes produced {len(shapes)}, " + f"expected {self.total_cache_slots}" ) - sample = self.conv_in(sample) - sample = self.mid_block(sample, temb) - for up_block in self.up_blocks: - sample = up_block(sample, temb) - sample = self.conv_norm_out(sample) - sample = self.activation(sample) - sample = self.conv_out(sample) - return sample + return shapes - def input_types(self) -> tuple[TensorType, ...]: - """Define input tensor types for the decoder model. + def __call__( + self, + x: TensorValue, + *cache_inputs: TensorValue, + first_chunk: bool = False, + ) -> tuple[TensorValue, ...]: + use_cache = len(cache_inputs) == self.total_cache_slots + if len(cache_inputs) not in (0, self.total_cache_slots): + raise ValueError( + f"Encoder3dCached expected 0 or {self.total_cache_slots} " + f"cache tensors, got {len(cache_inputs)}" + ) - Returns: - Tuple of TensorType specifications for decoder input. - """ - return ( - TensorType( - self.dtype, - shape=[ - "batch_size", - self.in_channels, - "latent_height", - "latent_width", - ], - device=self.device, - ), + cache_outputs: list[TensorValue] = [] + idx = 0 + + # conv_in + c_in = cache_inputs[idx] if use_cache else _zero_cache_for(x) + x, c_out = self.conv_in.forward_cached(x, c_in) + cache_outputs.append(c_out) + idx += 1 + + # down_blocks (flat list of ResidualBlockCached and DownResampleCached) + for block in self.down_blocks: + if isinstance(block, ResidualBlockCached): + c1 = cache_inputs[idx] if use_cache else None + c2 = cache_inputs[idx + 1] if use_cache else None + x, co1, co2 = block(x, c1, c2) + cache_outputs.extend([co1, co2]) + idx += 2 + elif isinstance(block, DownResampleCached): + if block._has_temporal: + c = cache_inputs[idx] if use_cache else None + x, co = block(x, cache_in=c, first_chunk=first_chunk) + cache_outputs.append(co) + idx += 1 + else: + (x,) = block(x) + + # mid_block + mid_caches: tuple[TensorValue, ...] = ( + tuple(cache_inputs[idx : idx + 4]) if use_cache else () ) + mid_out = self.mid_block(x, *mid_caches) + x = mid_out[0] + cache_outputs.extend(mid_out[1:]) + idx += 4 + + # norm + silu + conv_out + x = ops.silu(self.norm_out(x)) + c_in = cache_inputs[idx] if use_cache else _zero_cache_for(x) + x, c_out = self.conv_out.forward_cached(x, c_in) + cache_outputs.append(c_out) + + assert len(cache_outputs) == self.total_cache_slots, ( + f"Produced {len(cache_outputs)} caches, " + f"expected {self.total_cache_slots}" + ) + return (x, *cache_outputs) -class DiagonalGaussianDistribution: - r"""Represents a diagonal Gaussian distribution for VAE latent space. +class VAEEncoder(Module): + """Wrapper for VAE encoder graph compilation. - This wrapper intentionally stays lightweight for the Buffer-based VAE path. + Includes quant_conv (1x1 conv applied after encoder output). """ - def __init__( - self, - mean: object, - moments: object | None = None, - ) -> None: - """Initialize DiagonalGaussianDistribution. + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self.encoder = Encoder3d( + dim=config.base_dim, + z_dim=config.z_dim, + in_channels=3, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + temporal_downsample=config.temporal_downsample, + dtype=config.dtype, + device=config.device, + ) + z2 = config.z_dim * 2 + self.quant_conv = CausalConv3d( + z2, + z2, + 1, + padding=0, + dtype=config.dtype, + device=config.device, + has_bias=True, + prefer_nvidia_fcrs=False, + ) - Args: - mean: Mean tensor or mode tensor. - moments: Optional raw moments tensor containing additional VAE - distribution parameters. - """ - self.mean = mean - self.parameters = moments + def __call__(self, x: TensorValue) -> TensorValue: + h = self.encoder(x) + return self.quant_conv(h) - def sample(self, generator: object | None = None) -> object: - """Sample from the distribution using reparameterization trick. - Generates a random sample from the distribution by sampling from a - standard normal distribution and transforming it using the mean and - standard deviation. +class VAEEncoderFirstChunk(Module): + """First-chunk encoder graph: 1 frame in, mean latent + caches out.""" - Args: - generator: Random number generator (currently unused in Max, - kept for compatibility with diffusers API). + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self._z_dim = config.z_dim + self.encoder = Encoder3dCached( + dim=config.base_dim, + z_dim=config.z_dim, + in_channels=3, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + temporal_downsample=config.temporal_downsample, + dtype=config.dtype, + device=config.device, + ) + z2 = config.z_dim * 2 + self.quant_conv = CausalConv3d( + z2, + z2, + 1, + padding=0, + dtype=config.dtype, + device=config.device, + has_bias=True, + ) - Returns: - Sampled tensor of shape [N, C, H, W] with same shape as mean. - """ - del generator - return self.mean + def __call__(self, x: TensorValue) -> tuple[TensorValue, ...]: + outputs = self.encoder(x, first_chunk=True) + moments = self.quant_conv(outputs[0]) + # Extract mean in-graph to avoid GPU->CPU transfer of full moments + mean = moments[:, : self._z_dim, :, :, :] + return (mean, *outputs[1:]) - def mode(self) -> object: - """Return the mode (mean) of the distribution.""" - return self.mean + +class VAEEncoderRestChunk(Module): + """Rest-chunk encoder graph: CHUNK_SIZE frames + caches in, mean latent + caches out.""" + + def __init__(self, config: AutoencoderKLWanConfig) -> None: + super().__init__() + self._z_dim = config.z_dim + self.encoder = Encoder3dCached( + dim=config.base_dim, + z_dim=config.z_dim, + in_channels=3, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + temporal_downsample=config.temporal_downsample, + dtype=config.dtype, + device=config.device, + ) + z2 = config.z_dim * 2 + self.quant_conv = CausalConv3d( + z2, + z2, + 1, + padding=0, + dtype=config.dtype, + device=config.device, + has_bias=True, + ) + + def __call__( + self, x: TensorValue, *cache_inputs: TensorValue + ) -> tuple[TensorValue, ...]: + outputs = self.encoder(x, *cache_inputs, first_chunk=False) + moments = self.quant_conv(outputs[0]) + mean = moments[:, : self._z_dim, :, :, :] + return (mean, *outputs[1:]) diff --git a/max/python/max/pipelines/architectures/flux2/pipeline_flux2.py b/max/python/max/pipelines/architectures/flux2/pipeline_flux2.py index 5b00d37bd0a..da1ec11bb3d 100644 --- a/max/python/max/pipelines/architectures/flux2/pipeline_flux2.py +++ b/max/python/max/pipelines/architectures/flux2/pipeline_flux2.py @@ -35,7 +35,7 @@ from .model import Flux2TransformerModel if TYPE_CHECKING: - from ..autoencoders.vae import DiagonalGaussianDistribution + from ..autoencoders_modulev3.vae import DiagonalGaussianDistribution @dataclass(kw_only=True)