Enhancing Pre-trained Diffusion Models with Reinforcement Learning and Adversarial Reward Functions

This is a course project for Computer Vision, a second-year course in the Yao Class at Tsinghua University. The work is well below paper level; it is meant to present an interesting idea.

Contributors

Final Results

The final report and poster are in the root directory. They are the same versions we submitted to the TA and presented in class.

Setup

conda env create -f environment.yaml
conda activate diffppogan

The dataset is downloaded automatically when you run the code. Refer to Generative Zoo for more details on the dataset.

Training

We do not provide a training script for the standard diffusion model. You can train it with the code in Generative Zoo, or use the pre-trained model provided in pretrained/base.pt.
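Since training depends on a reference checkpoint such as pretrained/base.pt, it can help to confirm the file is present before launching a long run. The following is a minimal sketch; `check_checkpoint` is a hypothetical helper and is not part of this repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of this repository): succeeds only if the
# given checkpoint path exists and is a non-empty file.
check_checkpoint() {
    local path="$1"
    if [ -s "$path" ]; then
        echo "found: $path"
        return 0
    fi
    echo "missing or empty: $path" >&2
    return 1
}

# Example call from the repository root:
#   check_checkpoint pretrained/base.pt || exit 1
```

If the check fails, either train the standard diffusion model first or make sure the pre-trained file was fetched along with the repository.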

Please refer to scripts/train.sh for the training script. You can run the following command to start training:

export DATA_DIR=$PWD/data

python -m src.train.train \
    cfg=adv_schedule_r3 \
    cfg.gpu_id=0 \
    cfg.fid.real_image_path=$DATA_DIR/real/cifar10/imgs \
    cfg.ref_model_path=pretrained/base.pt \
    cfg.wandb.name=WANDB_NAME

Evaluation

You can sample images from the pre-trained model using scripts/sample.sh or the following command:

export DATA_DIR=$PWD/data

python -m src.sample.sample \
    cfg=adv_schedule_r3 \
    cfg.gpu_id=0 \
    cfg.ref_model_path=pretrained/base.pt \
    cfg.fid.real_image_path=$DATA_DIR/real/cifar10/imgs \
    cfg.checkpoint=pretrained/best.pt

Run scripts/sample_fid.sh to sample images, then use scripts/eval_fid.sh to evaluate the FID score.
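The two FID scripts can be chained so that evaluation only starts if sampling succeeded. This is a sketch under the assumption that both scripts run without arguments from the repository root; `run_in_order` is a hypothetical helper, not something the repository provides.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of this repository): runs each script in
# the order given and stops at the first one that exits with an error.
run_in_order() {
    local script
    for script in "$@"; do
        echo "running: $script"
        bash "$script" || { echo "failed: $script" >&2; return 1; }
    done
}

# Example call from the repository root (assumes no arguments are needed):
#   run_in_order scripts/sample_fid.sh scripts/eval_fid.sh
```

Stopping at the first failure avoids computing an FID score over an incomplete or stale set of samples.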