This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper *Pyramidal Patchification Flow for Visual Generation* (PPFlow).
Diffusion transformers (DiTs) adopt Patchify, mapping patch representations to token representations through linear projections, to adjust the number of tokens input to the DiT blocks and thus the computation cost. Instead of a single patch size for all timesteps, we introduce a Pyramidal Patchification Flow (PPFlow) approach: large patch sizes are used for high-noise timesteps and small patch sizes for low-noise timesteps; linear projections are learned for each patch size; and Unpatchify is modified accordingly. Unlike Pyramidal Flow, our approach operates over full latent representations rather than pyramid representations, and adopts the normal denoising process without requiring the renoising trick. We demonstrate the effectiveness of our approach through two training manners: training from scratch and initializing from a pre-trained SiT model.
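The per-patch-size projection idea above can be sketched in PyTorch. This is an illustrative sketch, not the repo's implementation: the class name, the patch sizes (4 and 2), and the timestep threshold are assumptions chosen for demonstration.

```python
import torch
import torch.nn as nn

class PyramidalPatchify(nn.Module):
    """Sketch of pyramidal patchification: one learned linear projection per
    patch size; the patch size is selected from the (noise) timestep."""

    def __init__(self, in_channels=4, hidden=256, patch_sizes=(4, 2), boundary=0.5):
        super().__init__()
        self.patch_sizes = patch_sizes
        self.boundary = boundary  # assumed threshold splitting high/low-noise timesteps
        # separate patchify / unpatchify projections for each patch size
        self.proj = nn.ModuleDict(
            {str(p): nn.Linear(in_channels * p * p, hidden) for p in patch_sizes})
        self.unproj = nn.ModuleDict(
            {str(p): nn.Linear(hidden, in_channels * p * p) for p in patch_sizes})

    def patch_size_for(self, t):
        # large patches (fewer tokens, cheaper blocks) at high noise, small at low noise
        return self.patch_sizes[0] if t >= self.boundary else self.patch_sizes[1]

    def patchify(self, x, t):
        p = self.patch_size_for(t)
        B, C, H, W = x.shape
        x = x.unfold(2, p, p).unfold(3, p, p)              # (B, C, H/p, W/p, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(B, (H // p) * (W // p), C * p * p)
        return self.proj[str(p)](x), p                     # tokens, patch size used

    def unpatchify(self, tokens, p, H, W):
        B, N, _ = tokens.shape
        x = self.unproj[str(p)](tokens)
        C = x.shape[-1] // (p * p)
        x = x.reshape(B, H // p, W // p, C, p, p).permute(0, 3, 1, 4, 2, 5)
        return x.reshape(B, C, H, W)
```

With a 32x32 latent, a high-noise timestep yields 64 tokens (patch size 4) while a low-noise timestep yields 256 tokens (patch size 2), which is where the compute saving comes from.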
This repository contains:
- Pre-trained class-conditional PPFlow models trained on ImageNet 256x256
- A PPFlow training script using PyTorch DDP
First, download and set up the repo:
```bash
git clone https://github.com/fudan-generative-vision/PPFlow.git
cd PPFlow
```

We provide an `environment.yml` file that can be used to create a Conda environment. If you only want to run pre-trained models locally on the CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.
```bash
conda env create -f environment.yml
conda activate PPFlow
```

**Pre-trained PPFlow checkpoints.** You can sample from our pre-trained models with `sample.py`. Weights for our pre-trained model will be
automatically downloaded depending on the model you use. The script has various arguments to adjust sampler configurations (ODE & SDE), sampling steps, change the classifier-free guidance scale, etc. For example, to sample from
our 256x256 PPFlow-XL-2 model with default ODE setting, you can use:
```bash
python sample.py ODE --image-size 256 --seed 1
```

Our pre-trained PPFlow models can be downloaded directly here as well:
| Model | Image Resolution | FID-50K | Inception Score |
|---|---|---|---|
| PPF-XL-2 | 256x256 | 1.99 | 271.62 |
| PPF-XL-3 | 256x256 | 2.23 | 286.67 |
**Custom PPFlow checkpoints.** If you've trained a new PPFlow model with `train.py` (see below), you can add the `--ckpt` argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 PPFlow-XL-2 model with the ODE sampler, run:
```bash
python sample.py ODE --model PPFlow_XL_2 --image-size 256 --ckpt /path/to/model.pt
```

We provide a training script for PPFlow in `train.py`. To launch PPF-XL-2 (256x256) training with N GPUs on
one node:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model PPF_XL_2 --data-path /path/to/imagenet/train
```

**Logging.** To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:
```bash
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"
```

Then add the `--wandb` flag to the training command:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model PPF_XL_2 --data-path /path/to/imagenet/train --wandb
```

**Initialize from a pre-trained SiT model.** To initialize training from a pre-trained SiT model:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model PPF_XL_2 --data-path /path/to/imagenet/train --initialize_from SiT-XL-2-256.pt
```

**Resume training.** To resume training from a custom checkpoint:
```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model PPF_XL_2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt
```

**Caution.** Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configs, to match those in the checkpoint.
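The caution above implies the checkpoint bundles several state dicts. A minimal sketch of such a layout follows; the key names and the `args` payload are hypothetical, and the repo's actual checkpoint format may differ.

```python
import torch

def save_checkpoint(path, model, ema, opt, args):
    # bundle model, EMA, and optimizer states plus training configs in one file
    torch.save({
        "model": model.state_dict(),
        "ema": ema.state_dict(),
        "opt": opt.state_dict(),
        "args": args,  # training configs restored on resume
    }, path)

def load_checkpoint(path, model, ema, opt):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    ema.load_state_dict(ckpt["ema"])
    opt.load_state_dict(ckpt["opt"])
    return ckpt["args"]
```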
We include a `sample_ddp.py` script which samples a large number of images from a PPFlow model in parallel. This script generates a folder of samples as well as a `.npz` file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. For example, to sample 50K images from our pre-trained PPF-XL-2 model over N GPUs under the default ODE sampler settings, run:
```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model PPF_XL_2 --num-fid-samples 50000
```

This project is under the MIT license.
We would like to thank the contributors to the SiT and DiT repositories for their open research and exploration. This code is built on SiT.
