diff --git a/README.md b/README.md
index ed40abb..4df2c38 100755
--- a/README.md
+++ b/README.md
@@ -89,6 +89,77 @@ torchrun --nnodes 1 --nproc_per_node=8 --master_port 17154 \
 
 Set `--nproc_per_node` to the number of GPUs you use. Logs and checkpoints go under `experiments//` (the `name` field in the YAML).
 
+## 🤗 Using with Hugging Face `diffusers`
+
+The `nvidia/AnyFlow-*-Diffusers` checkpoints can be loaded through the standard `diffusers` API:
+
+```python
+import torch
+from diffusers import AnyFlowPipeline
+from diffusers.utils import export_to_video
+
+pipe = AnyFlowPipeline.from_pretrained(
+    "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+video = pipe(
+    prompt="A red panda eating bamboo in a forest, cinematic lighting",
+    num_inference_steps=4,
+    num_frames=33,
+).frames[0]
+export_to_video(video, "anyflow_t2v.mp4", fps=16)
+```
+
+For the FAR variant (T2V / I2V / V2V via the `video` or pre-encoded `video_latents` kwarg):
+
+```python
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+    "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+video = pipe(
+    prompt="A red panda eating bamboo in a forest, cinematic lighting",
+    num_inference_steps=4,
+    num_frames=81,
+).frames[0]
+export_to_video(video, "anyflow_far_t2v.mp4", fps=16)
+```
+
+For image-to-video, pass a single-frame video tensor of shape `(B, T, C, H, W)` in `[0, 1]` via the `video` kwarg:
+
+```python
+import numpy as np
+import torch
+from diffusers import AnyFlowFARPipeline
+from diffusers.utils import export_to_video, load_image
+
+pipe = AnyFlowFARPipeline.from_pretrained(
+    "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+first_frame = load_image("path/to/first_frame.png").resize((832, 480))
+arr = np.asarray(first_frame).astype("float32") / 255.0
+context = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).unsqueeze(1).to("cuda")  # (1, 1, 3, 480, 832)
+
+video = pipe(
+    prompt="a cat walks across a sunlit lawn",
+    video=context,
+    num_inference_steps=4,
+    num_frames=81,
+).frames[0]
+export_to_video(video, "anyflow_far_i2v.mp4", fps=16)
+```
+
+For video-to-video continuation, use a multi-frame `video` tensor (with `T = 4n + 1` frames). To skip VAE encoding when the conditioning latents are already on disk, pass them as `video_latents=` instead. The same checkpoints also work with the `demo.py` and training entry points in this repository. See the [diffusers AnyFlow docs](https://huggingface.co/docs/diffusers/api/pipelines/anyflow) for the full reference.
+
+
 ## 📊 Evaluation
 
 Evaluation uses **`mode: eval`** configs under `options/test/anyflow/`.
diff --git a/far/models/transformer_far_wan_model.py b/far/models/transformer_far_wan_model.py
index 1be2ede..121979f 100644
--- a/far/models/transformer_far_wan_model.py
+++ b/far/models/transformer_far_wan_model.py
@@ -584,6 +584,16 @@ def forward(
         return hidden_states
 
 
+# Bind this class under the AnyFlow* names that `model_index.json` resolves via
+# `getattr(diffusers, ...)`. Idempotent: if the diffusers AnyFlow classes are
+# already importable, the existing bindings win.
+def _register_diffusers_aliases(cls):
+    import diffusers as _diffusers
+    for name in ('AnyFlowTransformer3DModel', 'AnyFlowFARTransformer3DModel'):
+        if not hasattr(_diffusers, name):
+            setattr(_diffusers, name, cls)
+
+
 @MODEL_REGISTRY.register()
 class FAR_Wan_Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
@@ -695,6 +705,58 @@ def __init__(
         if init_flowmap_model:
             self.setup_flowmap_model(gate_value=self.config.gate_value, deltatime_type=self.config.deltatime_type)
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Load checkpoints whose `transformer/config.json` omits `init_*_model`.
+
+        When the config was saved under one of the diffusers AnyFlow class names
+        and leaves a field out, that field is derived from `_class_name`:
+
+            AnyFlowTransformer3DModel    -> flow-map embedder
+            AnyFlowFARTransformer3DModel -> flow-map embedder + FAR patch embedding
+                                            + default `chunk_partition` for 81-frame inference
+
+        Each field is decided on its own: a value already present in the config
+        is passed through unchanged, and user kwargs always win. Configs saved
+        under any other class name are loaded without inferred defaults.
+
+        Args:
+            pretrained_model_name_or_path: Checkpoint location, forwarded to
+                `load_config` and to the parent `from_pretrained`.
+            **kwargs: Forwarded to the parent `from_pretrained`; the download
+                options among them are also used to read the config.
+
+        Returns:
+            The result of the parent `from_pretrained` call.
+        """
+        load_kwargs = {
+            k: kwargs[k]
+            for k in (
+                'subfolder', 'cache_dir', 'force_download', 'proxies',
+                'local_files_only', 'token', 'revision', 'variant'
+            )
+            if k in kwargs
+        }
+        try:
+            config_dict = cls.load_config(pretrained_model_name_or_path, **load_kwargs)
+        except Exception:
+            # Best effort: the config is only peeked at to infer defaults, so defer
+            # to the parent loader and let it report whatever went wrong.
+            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        cls_name = config_dict.get('_class_name', '') or ''
+        if cls_name in ('AnyFlowTransformer3DModel', 'AnyFlowFARTransformer3DModel'):
+            is_far = cls_name == 'AnyFlowFARTransformer3DModel'
+            inferred = {'init_flowmap_model': True, 'init_far_model': is_far}
+            if is_far:
+                # The pipeline in this repository reads `chunk_partition` from the
+                # transformer config; fall back to the 81-frame schedule when absent.
+                inferred['chunk_partition'] = [1, 3, 3, 3, 3, 3, 3, 2]
+            for key, value in inferred.items():
+                if key not in config_dict:
+                    kwargs.setdefault(key, value)
+        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
     def setup_flowmap_model(self, gate_value=0, deltatime_type='r'):
         inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
 
@@ -1214,3 +1276,7 @@ def _forward_bidirection(
             return (output,)
 
         return Transformer2DModelOutput(sample=output)
+
+
+# See _register_diffusers_aliases above.
+_register_diffusers_aliases(FAR_Wan_Transformer3DModel)
diff --git a/far/pipelines/pipeline_far_wan_anyflow.py b/far/pipelines/pipeline_far_wan_anyflow.py
index e2f152a..2b82815 100644
--- a/far/pipelines/pipeline_far_wan_anyflow.py
+++ b/far/pipelines/pipeline_far_wan_anyflow.py
@@ -112,6 +112,44 @@ def __init__(
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.use_mean_velocity = use_mean_velocity
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Build the pipeline with this repository's transformer and scheduler classes.
+
+        Supports checkpoints whose `model_index.json` refers to the diffusers AnyFlow
+        class names: unless the caller supplies them, the transformer and the scheduler
+        are instantiated here and handed to the parent `from_pretrained` as kwargs, with
+        the intent that the parent skips its class lookup for those two entries.
+        text_encoder / tokenizer / vae are left to the parent loader.
+        """
+        forwarded_names = (
+            'cache_dir', 'force_download', 'proxies', 'local_files_only',
+            'token', 'revision', 'variant'
+        )
+        hub_options = {}
+        for name in forwarded_names:
+            if name in kwargs:
+                hub_options[name] = kwargs[name]
+
+        if 'transformer' not in kwargs:
+            transformer = FAR_Wan_Transformer3DModel.from_pretrained(
+                pretrained_model_name_or_path,
+                subfolder='transformer',
+                torch_dtype=kwargs.get('torch_dtype'),
+                **hub_options,
+            )
+            kwargs['transformer'] = transformer
+
+        if 'scheduler' not in kwargs:
+            scheduler = FlowMapDiscreteScheduler.from_pretrained(
+                pretrained_model_name_or_path,
+                subfolder='scheduler',
+                **hub_options,
+            )
+            kwargs['scheduler'] = scheduler
+
+        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
     def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
diff --git a/far/pipelines/pipeline_wan_anyflow.py b/far/pipelines/pipeline_wan_anyflow.py
index ab84098..5803c63 100644
--- a/far/pipelines/pipeline_wan_anyflow.py
+++ b/far/pipelines/pipeline_wan_anyflow.py
@@ -111,6 +111,44 @@ def __init__(
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.use_mean_velocity = use_mean_velocity
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Build the pipeline with this repository's transformer and scheduler classes.
+
+        Supports checkpoints whose `model_index.json` refers to the diffusers AnyFlow
+        class names: unless the caller supplies them, the transformer and the scheduler
+        are instantiated here and handed to the parent `from_pretrained` as kwargs, with
+        the intent that the parent skips its class lookup for those two entries.
+        text_encoder / tokenizer / vae are left to the parent loader.
+        """
+        forwarded_names = (
+            'cache_dir', 'force_download', 'proxies', 'local_files_only',
+            'token', 'revision', 'variant'
+        )
+        hub_options = {}
+        for name in forwarded_names:
+            if name in kwargs:
+                hub_options[name] = kwargs[name]
+
+        if 'transformer' not in kwargs:
+            transformer = FAR_Wan_Transformer3DModel.from_pretrained(
+                pretrained_model_name_or_path,
+                subfolder='transformer',
+                torch_dtype=kwargs.get('torch_dtype'),
+                **hub_options,
+            )
+            kwargs['transformer'] = transformer
+
+        if 'scheduler' not in kwargs:
+            scheduler = FlowMapDiscreteScheduler.from_pretrained(
+                pretrained_model_name_or_path,
+                subfolder='scheduler',
+                **hub_options,
+            )
+            kwargs['scheduler'] = scheduler
+
+        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
     def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
diff --git a/far/schedulers/scheduling_flowmap_euler_discrete.py b/far/schedulers/scheduling_flowmap_euler_discrete.py
index 8d18f60..9290980 100755
--- a/far/schedulers/scheduling_flowmap_euler_discrete.py
+++ b/far/schedulers/scheduling_flowmap_euler_discrete.py
@@ -104,3 +104,15 @@ def step(
         r_timestep = r_timestep.view(*r_timestep.shape, *([1] * (model_output.ndim - r_timestep.ndim)))
         prev_sample = sample - (timestep - r_timestep) * model_output
         return prev_sample.to(model_output.dtype)
+
+
+# Expose this scheduler under the name used by the diffusers AnyFlow pipeline.
+FlowMapEulerDiscreteScheduler = FlowMapDiscreteScheduler
+
+# Bind the same alias on the `diffusers` package so it can be resolved via
+# `getattr`. Idempotent: if diffusers already provides this class, the existing
+# binding wins.
+import diffusers as _diffusers  # noqa: E402
+
+if not hasattr(_diffusers, 'FlowMapEulerDiscreteScheduler'):
+    _diffusers.FlowMapEulerDiscreteScheduler = FlowMapEulerDiscreteScheduler
diff --git a/scripts/convert_model/convert_anyflow_to_diffusers.py b/scripts/convert_model/convert_anyflow_to_diffusers.py
index 15860e2..a1a1f76 100644
--- a/scripts/convert_model/convert_anyflow_to_diffusers.py
+++ b/scripts/convert_model/convert_anyflow_to_diffusers.py
@@ -14,25 +14,94 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+"""Convert an AnyFlow training checkpoint into a diffusers ``save_pretrained`` directory.
+
+The AnyFlow training loop in this repository emits ``.pt`` files containing an ``ema`` key whose
+value is a flat transformer state dict (see ``far/main.py`` ``save_checkpoint``). This script:
+
+1. Loads the matching base Wan2.1 pipeline from the Hub (provides VAE, tokenizer, text encoder).
+2. Constructs an ``AnyFlowTransformer3DModel`` (bidirectional) or ``AnyFlowFARTransformer3DModel``
+   (FAR causal) **from the diffusers library** with the right config flags.
+3. Loads the selected weights (``ema`` by default, see ``source``) into the transformer.
+4. Wraps everything in an ``AnyFlowPipeline`` / ``AnyFlowFARPipeline`` (also from diffusers).
+5. Calls ``pipeline.save_pretrained(output_dir)``.
+
+Unlike earlier revisions of this script, the resulting directory is the **canonical diffusers
+layout** — ``model_index.json`` references ``AnyFlowPipeline`` / ``AnyFlowFARPipeline`` /
+``AnyFlowTransformer3DModel`` / ``AnyFlowFARTransformer3DModel`` /
+``FlowMapEulerDiscreteScheduler``, all importable from ``diffusers`` directly — so the output
+loads via ``AnyFlowPipeline.from_pretrained(...)`` without any compat shim.
+
+Requires diffusers ≥ 0.36 (with AnyFlow merged via huggingface/diffusers#13745).
+
+CLI:
+    python -m scripts.convert_model.convert_anyflow_to_diffusers \\
+        model_type=AnyFlow-Wan2.1-T2V-14B-Diffusers \\
+        model_path=experiments/pretrained_models/AnyFlow_Demo/anyflow_v1.0/anyflow-wan-14b.pt \\
+        model_save_dir=experiments/pretrained_models/AnyFlow-Diffusers-V1.0
+"""
+
+import logging
 import os
 from dataclasses import dataclass
-from omegaconf import MISSING, OmegaConf
-
-import decord
 import torch
-from diffusers.utils import export_to_video
-from PIL import Image
-from torchvision import transforms
-
-from far.models.transformer_far_wan_model import FAR_Wan_Transformer3DModel
-from far.pipelines.pipeline_far_wan_anyflow import FARWanAnyFlowPipeline
-from far.pipelines.pipeline_wan_anyflow import WanAnyFlowPipeline
-from far.schedulers.scheduling_flowmap_euler_discrete import FlowMapDiscreteScheduler
-from far.utils.video_util import select_frame_indices
-from far.utils.vis_util import draw_rectangle
+from omegaconf import MISSING, OmegaConf
 
-
-decord.bridge.set_bridge('torch')
+try:
+    from diffusers import (
+        AnyFlowFARPipeline,
+        AnyFlowFARTransformer3DModel,
+        AnyFlowPipeline,
+        AnyFlowTransformer3DModel,
+        FlowMapEulerDiscreteScheduler,
+    )
+except ImportError as exc:
+    raise ImportError(
+        'This conversion script requires the diffusers AnyFlow classes '
+        '(huggingface/diffusers#13745, available in diffusers ≥ 0.36). '
+        'Upgrade with `pip install -U diffusers`.'
+    ) from exc
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
+
+
+# Per-variant configuration. `base_model` is fetched from the Hub to source the matching VAE /
+# tokenizer / text encoder; the transformer is rebuilt locally and the AnyFlow weights are loaded.
+VARIANTS = {
+    'AnyFlow-FAR-Wan2.1-1.3B-Diffusers': {
+        'base_model': 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers',
+        'transformer_cls': AnyFlowFARTransformer3DModel,
+        'transformer_kwargs': {
+            'full_chunk_limit': 3,
+            'compressed_patch_size': [1, 4, 4],
+        },
+        'pipeline_cls': AnyFlowFARPipeline,
+    },
+    'AnyFlow-FAR-Wan2.1-14B-Diffusers': {
+        'base_model': 'Wan-AI/Wan2.1-T2V-14B-Diffusers',
+        'transformer_cls': AnyFlowFARTransformer3DModel,
+        'transformer_kwargs': {
+            'full_chunk_limit': 3,
+            'compressed_patch_size': [1, 4, 4],
+        },
+        'pipeline_cls': AnyFlowFARPipeline,
+    },
+    'AnyFlow-Wan2.1-T2V-1.3B-Diffusers': {
+        'base_model': 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers',
+        'transformer_cls': AnyFlowTransformer3DModel,
+        'transformer_kwargs': {},
+        'pipeline_cls': AnyFlowPipeline,
+    },
+    'AnyFlow-Wan2.1-T2V-14B-Diffusers': {
+        'base_model': 'Wan-AI/Wan2.1-T2V-14B-Diffusers',
+        'transformer_cls': AnyFlowTransformer3DModel,
+        'transformer_kwargs': {},
+        'pipeline_cls': AnyFlowPipeline,
+    },
+}
 
 
 @dataclass
@@ -44,72 +113,97 @@ class ConvertAnyflowToDiffusersConfig:
     model_type: str = 'AnyFlow-FAR-Wan2.1-1.3B-Diffusers'
     # Path to the AnyFlow .pt checkpoint (expects an `ema` entry in the dict).
     model_path: str = MISSING
-    # Output directory for `pipeline.save_pretrained`.
+    # Output directory; the variant name is appended automatically (matches the released layout).
     model_save_dir: str = MISSING
-
-
-def build_causal_pipeline(model_type, model_path):
-
-    far_config = {
-        'full_chunk_limit': 3,
-        'chunk_partition': [1, 3, 3, 3, 3, 3, 3, 2],
-        'compressed_patch_size': [1, 4, 4]
-    }
-
-    if model_type == 'AnyFlow-FAR-Wan2.1-1.3B-Diffusers':
-        base_model_name = 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'
-    elif model_type == 'AnyFlow-FAR-Wan2.1-14B-Diffusers':
-        base_model_name = 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
-    else:
-        raise NotImplementedError
-
-    transformer = FAR_Wan_Transformer3DModel.from_pretrained(
-        base_model_name,
-        chunk_partition=far_config['chunk_partition'],
-        full_chunk_limit=far_config['full_chunk_limit'],
-        compressed_patch_size=far_config['compressed_patch_size'],
-        subfolder='transformer'
+    # Scheduler hyperparameters. Defaults match the released AnyFlow distillation recipe; override
+    # when converting a checkpoint trained with a different schedule (e.g. a higher-resolution run
+    # that re-tuned `shift`). Reading from the training yaml is also fine — just pass the value.
+    shift: float = 5.0
+    num_train_timesteps: int = 1000
+    # Which state dict to read from the .pt. `ema` is the released setting; `model_state_dict_g`
+    # is the raw (non-EMA) generator weights, useful for ablation runs.
+    source: str = 'ema'
+
+
+def build_pipeline(
+    model_type: str,
+    model_path: str,
+    shift: float = 5.0,
+    num_train_timesteps: int = 1000,
+    source: str = 'ema',
+):
+    """Assemble the diffusers AnyFlow pipeline for `model_type` from a training checkpoint.
+
+    Args:
+        model_type: Key of `VARIANTS` selecting base model, transformer and pipeline classes.
+        model_path: Path to the training `.pt` checkpoint.
+        shift: `shift` passed to `FlowMapEulerDiscreteScheduler`.
+        num_train_timesteps: `num_train_timesteps` passed to the scheduler.
+        source: Top-level checkpoint key holding the state dict to load.
+
+    Returns:
+        The assembled pipeline, moved to CUDA when available and to CPU otherwise.
+
+    Raises:
+        ValueError: If `model_type` or `source` is not one of the supported choices.
+        KeyError: If the checkpoint has no `source` entry.
+        RuntimeError: If the checkpoint lacks keys the transformer expects.
+    """
+    if model_type not in VARIANTS:
+        raise ValueError(f'Unknown model_type {model_type!r}. Choices: {list(VARIANTS)}.')
+    if source not in {'ema', 'model_state_dict_g'}:
+        raise ValueError(f"Unknown source {source!r}. Choices: 'ema', 'model_state_dict_g'.")
+    spec = VARIANTS[model_type]
+
+    # Construct the diffusers transformer with the variant's per-variant config (FAR adds
+    # compressed_patch_size + full_chunk_limit). gate_value / deltatime_type match the
+    # released AnyFlow distillation recipe.
+    transformer = spec['transformer_cls'].from_pretrained(
+        spec['base_model'],
+        subfolder='transformer',
+        gate_value=0.25,
+        deltatime_type='r',
+        **spec['transformer_kwargs'],
     )
-    transformer.setup_far_model()
-    transformer.setup_flowmap_model(gate_value=0.25, deltatime_type="r")
-    transformer.register_to_config(init_far_model=True, init_flowmap_model=True, deltatime_type='r', gate_value=0.25)
-
-    # load model
-    state_dict = torch.load(model_path)['ema']
-    transformer.load_state_dict(state_dict)
-    transformer = transformer.to('cuda', dtype=torch.bfloat16)
 
-    scheduler = FlowMapDiscreteScheduler(shift=5, num_train_timesteps=1000)
+    # AnyFlow training checkpoints store the EMA state alongside a bit of Python metadata, so
+    # `weights_only=False` is required for the unpickle. Only run this script on checkpoints you
+    # trust. `strict=False` accommodates the EMA bookkeeping keys; tensor keys are bit-exact
+    # compatible between FAR_Wan_Transformer3DModel and the diffusers AnyFlow classes (verified
+    # against the released NVlabs checkpoints).
+    raw = torch.load(model_path, map_location='cpu', weights_only=False)
+    if source not in raw:
+        raise KeyError(
+            f"Checkpoint at {model_path!r} has no key {source!r}. "
+            f"Available top-level keys: {list(raw.keys())[:8]}."
+        )
+    state_dict = raw[source]
+    # A key missing from the checkpoint would silently keep the value the base-model
+    # transformer was built with, so it is an error; surplus keys are only reported.
+    missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
+    if unexpected:
+        head = ', '.join(unexpected[:5]) + ('...' if len(unexpected) > 5 else '')
+        logger.warning('Unexpected keys in state dict (ignored): %s', head)
+    if missing:
+        head = ', '.join(missing[:5]) + ('...' if len(missing) > 5 else '')
+        raise RuntimeError(
+            f'Checkpoint at {model_path!r} (source={source!r}) is missing '
+            f'{len(missing)} transformer key(s): {head}'
+        )
 
-    pipeline = FARWanAnyFlowPipeline.from_pretrained(base_model_name, transformer=transformer, scheduler=scheduler)
-    pipeline.to('cuda')
-
-    return pipeline
-
-
-def build_bidirectional_pipeline(model_type='AnyFlow-Wan2.1-T2V-1.3B-Diffusers', model_path=None):
-
-    if model_type == 'AnyFlow-Wan2.1-T2V-1.3B-Diffusers':
-        base_model_name = 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'
-    elif model_type == 'AnyFlow-Wan2.1-T2V-14B-Diffusers':
-        base_model_name = 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
-    else:
-        raise NotImplementedError
-
-    transformer = FAR_Wan_Transformer3DModel.from_pretrained(base_model_name, subfolder='transformer')
-    transformer.setup_flowmap_model(gate_value=0.25, deltatime_type='r')
-    transformer.register_to_config(init_flowmap_model=True, deltatime_type='r', gate_value=0.25)
-
-    # load model
-    state_dict = torch.load(model_path)['ema']
-    transformer.load_state_dict(state_dict)
-    transformer = transformer.to('cuda', dtype=torch.bfloat16)
+    # Conversion does not need a GPU; fall back to CPU when CUDA is unavailable
+    # instead of failing on a hard-coded device.
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    transformer = transformer.to(device, dtype=torch.bfloat16)
 
-    scheduler = FlowMapDiscreteScheduler(shift=5, num_train_timesteps=1000)
+    scheduler = FlowMapEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps, shift=shift)
 
-    pipeline = WanAnyFlowPipeline.from_pretrained(base_model_name, transformer=transformer, scheduler=scheduler)
+    pipeline = spec['pipeline_cls'].from_pretrained(
+        spec['base_model'],
+        transformer=transformer,
+        scheduler=scheduler,
+    )
-    pipeline.to('cuda')
-
+    pipeline.to(device)
     return pipeline
 
 
@@ -119,29 +213,18 @@ def build_bidirectional_pipeline(model_type='AnyFlow-Wan2.1-T2V-1.3B-Diffusers',
         OmegaConf.from_cli(),
     )
 
-    cfg.model_save_dir = os.path.join(cfg.model_save_dir, cfg.model_type)
-
-    os.makedirs(cfg.model_save_dir, exist_ok=True)
-
-    if 'FAR' in cfg.model_type:
-        pipeline = build_causal_pipeline(cfg.model_type, cfg.model_path)
-    else:
-        pipeline = build_bidirectional_pipeline(cfg.model_type, model_path=cfg.model_path)
-
-    pipeline.save_pretrained(cfg.model_save_dir)
+    save_dir = os.path.join(cfg.model_save_dir, cfg.model_type)
+    os.makedirs(save_dir, exist_ok=True)
 
-"""
-Convert AnyFlow checkpoint weights into a Diffusers pipeline on disk.
-
-CLI variables:
-  model_type — Causal: AnyFlow-FAR-Wan2.1-1.3B-Diffusers, AnyFlow-FAR-Wan2.1-14B-Diffusers;
-               Bidirectional: AnyFlow-Wan2.1-T2V-1.3B-Diffusers, AnyFlow-Wan2.1-T2V-14B-Diffusers.
-  model_path — Input .pt checkpoint (lora-merged state_dict and contains `ema`).
-  model_save_dir — Output directory for `pipeline.save_pretrained` (required).
-
-Example:
-python -m scripts.convert_model.convert_anyflow_to_diffusers \
-  model_type=AnyFlow-Wan2.1-T2V-14B-Diffusers \
-  model_path=experiments/pretrained_models/AnyFlow_Demo/anyflow_v1.0/anyflow-wan-14b.pt \
-  model_save_dir=experiments/pretrained_models/AnyFlow-Diffusers-V1.0/
-"""
\ No newline at end of file
+    pipeline = build_pipeline(
+        cfg.model_type,
+        cfg.model_path,
+        shift=cfg.shift,
+        num_train_timesteps=cfg.num_train_timesteps,
+        source=cfg.source,
+    )
+    pipeline.save_pretrained(save_dir)
+    logger.info(
+        'Saved %s pipeline to %s (scheduler: shift=%s num_train_timesteps=%s, source=%s)',
+        cfg.model_type, save_dir, cfg.shift, cfg.num_train_timesteps, cfg.source,
+    )