FairGen is a fairness-aware diffusion framework for medical image synthesis and downstream diagnosis. It targets three medical imaging settings (dermatology, brain MRI, and chest X-ray) and focuses on improving coverage for underrepresented demographic subgroups while preserving image quality and diagnostic utility.
This repository includes training, inference, and downstream evaluation code for FairGen and related baselines such as Vanilla Stable Diffusion, CBCB, and CBDM.
FairGen pipeline overview. Starting from imbalanced medical datasets, FairGen combines subgroup-aware data balancing, preference-aligned diffusion training, and downstream augmentation to improve fairness across sensitive attributes such as skin tone, age, and gender.
- Train diffusion backbones for skin, MRI, and chest X-ray synthesis.
- Align generation with physician preferences using DPO-based supervision.
- Generate balanced synthetic datasets for underrepresented subgroups.
- Train downstream diagnostic classifiers with real and synthetic data.
This section provides step-by-step instructions for training FairGen and related baselines for medical image synthesis.
Ensure you have the necessary dependencies installed. It is recommended to use a virtual environment (e.g., Conda).
If you have the requirements.txt file provided in this repository:
# Create and activate environment
conda create -n fairgen python=3.9
conda activate fairgen
# Install dependencies
pip install -r requirements.txt

If you prefer to install packages manually, or requirements.txt is not available:
# Create and activate environment
conda create -n fairgen python=3.9
conda activate fairgen
# Install core dependencies (ensure CUDA compatibility)
pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/cu118
pip install diffusers["torch"] transformers accelerate datasets
pip install wandb umap-learn scikit-learn

The training script expects a standard ImageFolder structure or a HuggingFace dataset format.
dataset_root/
├── train/
│ ├── metadata.jsonl # Contains {"file_name": "img1.jpg", "text": "prompt..."}
│ ├── img1.jpg
│ ├── img2.jpg
│ └── ...
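The metadata.jsonl file can be generated with a short script. A minimal sketch, assuming you already have a mapping from file names to text prompts (the `prompts` argument here is hypothetical, not part of the repository):

```python
import json
from pathlib import Path

def write_metadata(train_dir, prompts):
    """Write one JSON record per image into train_dir/metadata.jsonl.

    prompts: hypothetical mapping {"img1.jpg": "Skin lesion dark skin", ...}
    """
    out = Path(train_dir) / "metadata.jsonl"
    with out.open("w") as f:
        for file_name, text in prompts.items():
            f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
```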
Required only for FairGen. You must prepare a JSONL file containing physician-annotated pairs.
File location: /path/to/dpo_folder/physician_preferences.jsonl
Format:
{"text": "Demented Age Above 75", "image_w": "path/to/winner.jpg", "image_l": "path/to/loser.jpg"}
{"text": "Skin lesion dark skin", "image_w": "path/to/winner.jpg", "image_l": "path/to/loser.jpg"}

We provide unified training scripts that handle the different modalities (Skin, MRI, Chest X-ray) via the --modality flag, located in the diffusers/examples/text_to_image folder.
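A quick validator can catch malformed preference records before a long training run. A minimal sketch, checking only the three fields from the format above (the helper name is ours, not part of the repository):

```python
import json

REQUIRED_KEYS = {"text", "image_w", "image_l"}

def validate_preferences(jsonl_path):
    """Raise ValueError if any line of the JSONL file is missing a required key."""
    with open(jsonl_path) as f:
        for i, line in enumerate(f, 1):
            record = json.loads(line)
            missing = REQUIRED_KEYS - record.keys()
            if missing:
                raise ValueError(f"line {i} is missing keys: {sorted(missing)}")
    return True
```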
Use this command to train baseline models (e.g., CBCB) without DPO alignment.
Example: Training Skin Modality (CBCB)
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="/path/to/your/dataset_skin"
export OUTPUT_DIR="./checkpoints/sd_skin_cbcb"
accelerate launch --mixed_precision="fp16" /path/to/your/train_text_to_imagecbcb.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--modality="skin" \
--use_ema \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-5 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir=$OUTPUT_DIR

Example: Training MRI Modality (CBCB)
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="/path/to/your/dataset_mri"
export OUTPUT_DIR="./checkpoints/sd_mri_cbcb"
accelerate launch --mixed_precision="fp16" /path/to/your/train_text_to_imagecbcb.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--modality="mri" \
--use_ema \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-5 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir=$OUTPUT_DIR

FairGen utilizes a dual-stream training process:
- Regularization Stream: Maintains image fidelity using the original dataset.
- Alignment Stream: Optimizes for physician preference using DPO.
Key Flags:
--enable_dpo: Activates the DPO loss calculation.
--dpo_data_dir: Path to the folder containing physician_preferences.jsonl.
--beta_dpo: The λ parameter in Eq. 8 (Controls preference strength). Default is 0.5.
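For intuition, the preference term behaves like a Diffusion-DPO-style objective: the model is rewarded when its denoising error on the physician-preferred (winner) image drops relative to a frozen reference model, and penalized when it improves on the loser instead. A scalar sketch of that idea, not FairGen's exact implementation of Eq. 8 (in practice the errors are batched tensors):

```python
import math

def dpo_preference_loss(err_w, err_l, ref_err_w, ref_err_l, beta=0.5):
    """Sketch of a DPO-style preference loss for diffusion models.

    err_*: denoising MSE of the trained model on winner / loser images.
    ref_err_*: the same errors under the frozen reference model.
    Lower error than the reference means higher implicit likelihood, so the
    winner's advantage minus the loser's advantage drives the logit.
    """
    logits = beta * ((ref_err_w - err_w) - (ref_err_l - err_l))
    return -math.log(1.0 / (1.0 + math.exp(-logits)))  # -log(sigmoid(logits))
```

When the model matches the reference on both images, the loss sits at log 2; improving on the winner (or degrading on the loser) pushes it down, scaled by beta.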
Example: Training Skin Modality (FairGen)
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
# Ideally, load a pre-trained baseline checkpoint to converge faster:
# export MODEL_NAME="./checkpoints/sd_skin_cbcb"
export TRAIN_DIR="/path/to/your/dataset_skin"
export DPO_DIR="/path/to/your/physician_preference_data"
export OUTPUT_DIR="./checkpoints/sd_skin_fairgen"
accelerate launch --mixed_precision="fp16" /path/to/your/train_text_to_image_FairGen.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--modality="skin" \
--enable_dpo \
--dpo_data_dir=$DPO_DIR \
--beta_dpo=0.5 \
--use_ema \
--resolution=512 \
--center_crop \
--random_flip \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-5 \
--max_grad_norm=1 \
--lr_scheduler="cosine" \
--lr_warmup_steps=0 \
--output_dir=$OUTPUT_DIR

The --modality flag automatically adjusts internal parameters (e.g., the number of classes for the balancing loss).
| Modality | Flag | Internal num_class | Key Attributes |
|---|---|---|---|
| Dermatology | --modality="skin" | 15 (3 tones × 5 diseases) | Skin Tone, Disease Type |
| Brain MRI | --modality="mri" | 4 (2 ages × 2 statuses) | Age Group, Dementia Status |
| Chest X-ray | --modality="chest" | 10 (2 genders × 5 findings) | Gender, Finding Type |
--beta_dpo (Lambda):
- Range: 0.1 to 1.0.
- Increase (e.g., 1.0): If the generated images do not sufficiently reflect physician preferences (e.g., structural features are still generic).
- Decrease (e.g., 0.1): If the training becomes unstable or image quality degrades (artifacts appear).
--learning_rate:
- For DPO fine-tuning, a lower learning rate (e.g., 1e-5 or 5e-6) is often more stable than the rate used when training from scratch.
This guide explains how to generate synthetic medical images using trained FairGen models (or baselines like CBCB, CBDM, Vanilla SD).
We provide a universal inference script src/inference.py.
Skin modality includes 15 subgroups (3 skin tones × 5 disease types).
export UNET_PATH="/path/to/your/checkpoints/sd_skin_fairgen/checkpoint-15000/unet"
export OUT_DIR="/path/to/your/output/DPO_sd_skin/fairgen"
python /path/to/your/inference.py \
--modality="skin" \
--model_path=$UNET_PATH \
--output_dir=$OUT_DIR \
--num_images_per_class=1000 \
--batch_size=4

This guide outlines the process for training downstream diagnostic classifiers using datasets augmented by FairGen. We provide specialized scripts for three medical imaging modalities: Chest X-ray, Dermatology, and Brain MRI.
The downstream task involves training a Vision Transformer (ViT) classifier to diagnose diseases. To address class imbalance and demographic bias, our training pipeline incorporates:
- Dual-Label Dataset: Handles both disease labels (for classification) and sensitive attribute labels (for fairness evaluation).
- Stratified Splitting: Ensures train/val/test sets maintain demographic distributions.
- Reweighting Strategy: Dynamically adjusts sampling weights based on subgroup performance (Adaptive Inverse-Performance Reweighting).
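The reweighting step can be illustrated with a small sketch: subgroups with lower validation accuracy receive proportionally larger sampling weights, so the sampler revisits them more often. The exact update rule in the training scripts may differ; this helper name is ours:

```python
def inverse_performance_weights(subgroup_acc, eps=1e-3):
    """Adaptive inverse-performance reweighting sketch.

    subgroup_acc: current per-subgroup accuracy, e.g. {"female_COVID19": 0.91, ...}
    Returns normalized sampling weights; worst-performing subgroups get the
    largest weights. eps guards against division by zero for 0% accuracy.
    """
    raw = {g: 1.0 / max(acc, eps) for g, acc in subgroup_acc.items()}
    total = sum(raw.values())
    return {g: w / total for g, w in raw.items()}
```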
The training scripts expect the data to be organized in an ImageFolder format, where subfolder names contain both demographic and disease information.
input_dataset_root/
├── downstream_skin/
│ ├── African_people_allergic_contact_dermatitis/
│ ├── African_people_basal_cell_carcinoma/
│ ├── African_people_lichen_planus/
│ ├── African_people_psoriasis/
│ ├── African_people_squamous_cell_carcinoma/
│ ├── Asian_people_allergic_contact_dermatitis/
│ ├── ... (other Asian subfolders)
│ ├── Caucasian_people_allergic_contact_dermatitis/
│ └── ... (other Caucasian subfolders)
├── downstream_xray/
│ ├── female_COVID19/
│ ├── female_Edema/
│ ├── female_Lung_Opacity/
│ ├── female_No_Finding/
│ ├── female_Pleural_Effusion/
│ ├── male_COVID19/
│ ├── male_Edema/
│ ├── male_Lung_Opacity/
│ ├── male_No_Finding/
│ └── male_Pleural_Effusion/
└── downstream_mri/
├── Demented_Age_Above_75/
├── Demented_Age_Below_75/
├── Nondemented_Age_Above_75/
└── Nondemented_Age_Below_75/
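Because each subfolder name encodes both labels, a small parser can recover the dual labels for the Dual-Label Dataset. A sketch, with the prefix lists taken from the layout above (the function itself is illustrative, not a repository API):

```python
def parse_subfolder(name, attribute_prefixes):
    """Split a folder name like "female_COVID19" into (attribute, remaining label)."""
    for attr in attribute_prefixes:
        if name.startswith(attr + "_"):
            return attr, name[len(attr) + 1:]
    raise ValueError(f"no known attribute prefix in {name!r}")
```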
You should also make sure the directory structure of your augmentation dataset stays the same; you can keep it in sync when running inference with the trained diffusion model.
We take brain MRI as an example, where subfolder names encode Dementia status (Demented vs. Nondemented) and age group.
Note: The MRI script uses a lower default learning rate (1e-5) for stability.
Command:
python src/downstream/classify_reweight_mri.py \
--data "/path/to/real_mri_data" \
--aug_data "/path/to/fairgen_mri_data" \
--lr 1e-5 \
--batchsize 64 \
--epochs 10 \
--best_worst
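We read --best_worst as reporting the spread between the strongest and weakest subgroups; assuming that, the summary statistic reduces to something like the following sketch (the helper name is hypothetical):

```python
def best_worst_gap(subgroup_acc):
    """Gap between the best- and worst-performing subgroups; smaller is fairer."""
    vals = list(subgroup_acc.values())
    return max(vals) - min(vals)
```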