Skip to content

[MAX] Add Wan VAE and refactor autoencoder module#15

Draft
jglee-sqbits wants to merge 1 commit into
jglee-sqbits/stack/2from
jglee-sqbits/stack/3
Draft

[MAX] Add Wan VAE and refactor autoencoder module#15
jglee-sqbits wants to merge 1 commit into
jglee-sqbits/stack/2from
jglee-sqbits/stack/3

Conversation

@jglee-sqbits
Copy link
Copy Markdown
Collaborator

@jglee-sqbits jglee-sqbits commented Apr 1, 2026

Stacked PRs:


[MAX] Add Wan VAE and refactor autoencoder module

Summary

Add a Wan VAE (3D causal video autoencoder) and restructure the autoencoder module to separate Module V2 and V3 implementations.

Description

Wan VAE

  • Implements the Wan 3D causal VAE with temporal caching for chunked encode/decode
  • Encoder: processes video in temporal chunks (first frame + subsequent chunks) with cached convolution state to maintain temporal consistency
  • Decoder: same chunked approach with 3 specialized graphs (post-quant conv, first frame, subsequent frames)
  • Uses symbolic spatial dims for resolution flexibility
  • Adds 3D convolution support via cuDNN (conv.mojo) with depth-tiled execution for large volumes

Autoencoder restructuring

  • Moves existing Module V3 (Flux) autoencoder files to autoencoders_modulev3/
  • The autoencoders/ directory now contains Module V2 graph-based implementations (Wan VAE, Qwen Image VAE)
  • Updates flux1_modulev3 and flux2_modulev3 import paths accordingly

This follows the same pattern as modular#6278 which established the V2/V3 split.

Dependencies

Should be merged before modular#6301 (transformer) and modular#6302 (pipeline-t2v), which import from autoencoders.

Checklist

  • PR is small and focused
  • I ran ./bazelw run format to format my changes

Assisted-by: Claude Code

Assisted-by: Claude Code

## Summary

Add a Wan VAE (3D causal video autoencoder) and restructure the autoencoder module to separate Module V2 and V3 implementations.

## Description

### Wan VAE
- Implements the Wan 3D causal VAE with temporal caching for chunked encode/decode
- Encoder: processes video in temporal chunks (first frame + subsequent chunks) with cached convolution state to maintain temporal consistency
- Decoder: same chunked approach with 3 specialized graphs (post-quant conv, first frame, subsequent frames)
- Uses symbolic spatial dims for resolution flexibility
- Adds 3D convolution support via cuDNN (`conv.mojo`) with depth-tiled execution for large volumes

### Autoencoder restructuring
- Moves existing Module V3 (Flux) autoencoder files to `autoencoders_modulev3/`
- The `autoencoders/` directory now contains Module V2 graph-based implementations (Wan VAE, Qwen Image VAE)
- Updates `flux1_modulev3` and `flux2_modulev3` import paths accordingly

This follows the same pattern as modular#6278 which established the V2/V3 split.

## Dependencies

Should be merged **before** modular#6301 (transformer) and modular#6302 (pipeline-t2v), which import from `autoencoders`.

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code

Assisted-by: Claude Code

stack-info: PR: #15, branch: jglee-sqbits/stack/3
Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for 3D convolutions using cuDNN and adds a new Wan VAE autoencoder implementation. The changes include new cuDNN descriptor APIs, depth-tiled convolution for large tensors, and a new architecture module for Wan. My feedback highlights that the hardcoded total_cache_slots calculation in vae.py is fragile and should be dynamic, and that the comptime if in conv.mojo is redundant and can be simplified to a standard if statement.

"""Apply Decoder forward pass.
@property
def total_cache_slots(self) -> int:
return 1 + sum(self._block_cache_slots) + 4 + 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The total_cache_slots property calculation is hardcoded and fragile. If the number of blocks or cache slots changes in the future, this will lead to silent errors. Consider calculating this dynamically based on the actual number of modules in down_blocks and mid_block.

Comment on lines +4651 to +4661
comptime if filter_is_fcrs:
conv3d_cudnn[input_type, filter_type, output_type](
input_lt,
filter_lt,
output_lt,
rebind[IndexList[3]](stride),
rebind[IndexList[3]](dilation),
rebind[IndexList[3]](symmetric_padding),
num_groups,
ctx,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of comptime if inside the conv_gpu function for algorithm selection is redundant here, as filter_is_fcrs is already a compile-time constant. This can be simplified to a standard if statement for better readability without impacting performance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant