grad_norm=NaN During NFT Training on FLUX.1-dev #134

@rlustc

Description

grad_norm=NaN occurs frequently during DiffusionNFT training, which makes my training run completely infeasible. The full training YAML is below:

```yaml
# Environment Configuration
launcher: "accelerate"  # Options: accelerate
config_file: config/accelerate_configs/fsdp_full_shard.yaml  # Path to distributed config file (optional)
num_processes: 8  # Number of processes to launch (overrides config file)
main_process_port: 29500
mixed_precision: "bf16"  # Options: no, fp16, bf16

# Data Configuration
data:
  dataset_dir: "dataset/pickscore"  # Path to dataset folder
  preprocessing_batch_size: 8  # Batch size for preprocessing
  dataloader_num_workers: 16  # Number of workers for DataLoader
  force_reprocess: false  # Force reprocessing of the dataset
  cache_dir: "~/.cache/flow_factory/datasets"  # Cache directory for preprocessed datasets
  max_dataset_size: 1024  # Limit the maximum number of samples in the dataset
  sampler_type: "auto"  # Options: auto, distributed_k_repeat, group_contiguous

# Model Configuration
model:
  finetune_type: 'full'  # Options: full, lora
  target_modules: "default"  # Options: all, default, or a list of module names like ["to_k", "to_q", "to_v", "to_out.0"]
  model_name_or_path: "/data/aigc/liangyzh_intern/Lirui/Flux1.0-dev"  # HuggingFace model ID or local path
  model_type: "flux1"
  resume_path: null  # Path to load a previous checkpoint/LoRA adapter
  resume_type: null  # Options: lora, full, state. Null to auto-detect based on finetune_type

log:
  run_name: null  # Run name (auto: {model_type}_{finetune_type}_{trainer_type}_{timestamp})
  project: "Flow-Factory"  # Project name for logging
  logging_backend: "tensorboard"  # Options: wandb, swanlab, none
  save_dir: "saves/"  # Directory to save model checkpoints and logs
  save_freq: 40  # Save frequency in epochs (0 to disable)
  save_model_only: true  # Save only the model weights (not optimizer, scheduler, etc.)

# Training Configuration
train:
  # Trainer settings
  trainer_type: 'nft'
  advantage_aggregation: 'sum'  # Options: 'sum', 'gdpo'
  nft_beta: 1

  # Old-policy settings
  off_policy: true  # Whether to use EMA parameters for sampling off-policy data
  ema_decay_schedule: "piecewise_linear"  # Options: ['constant', 'power', 'linear', 'piecewise_linear', 'cosine', 'warmup_cosine']
  flat_steps: 0
  ramp_rate: 0.001
  ema_decay: 0.5  # EMA decay rate (0 to disable)
  ema_update_interval: 1  # EMA update interval (in epochs)
  ema_device: "cpu"  # Device to store the EMA model (options: cpu, cuda)

  # Training timestep distribution
  num_train_timesteps: 2  # Set to null to train on all steps
  time_sampling_strategy: discrete  # Options: uniform, logit_normal, discrete, discrete_with_init, discrete_wo_init
  time_shift: 3.0
  timestep_range: 0.7  # Fraction of timesteps to train on

  # KL divergence
  kl_type: 'v-based'
  kl_beta: 0  # KL divergence beta, 0 to disable
  ref_param_device: 'cpu'  # Options: cpu, cuda

  # Clipping
  adv_clip_range: 5.0  # Advantage clipping range

  # Sampling
  resolution: 384  # Can be an int or [height, width]
  num_inference_steps: 8  # Number of timesteps
  guidance_scale: 3.5  # Guidance scale for sampling

  # Batch and sampling
  per_device_batch_size: 1  # Batch size per device
  group_size: 16  # Group size for GRPO sampling
  global_std: false  # Use global std for advantage normalization
  unique_sample_num_per_epoch: 48  # Unique samples per group
  gradient_step_per_epoch: 1  # Gradient steps per epoch. The first step is on-policy, the rest are off-policy.
  gradient_accumulation_steps: auto  # Options: auto, or a positive integer. When set, gradient_step_per_epoch is ignored.

  # Optimization
  learning_rate: 1.0e-5  # Initial learning rate
  adam_weight_decay: 1.0e-4  # AdamW weight decay
  adam_betas: [0.9, 0.999]  # AdamW betas
  adam_epsilon: 1.0e-8  # AdamW epsilon
  max_grad_norm: 1.0  # Max gradient norm for clipping

  # Gradient checkpointing
  enable_gradient_checkpointing: true  # Enable gradient checkpointing to save memory at extra compute cost

  # Seed
  seed: 42  # Random seed

# Scheduler Configuration
scheduler:
  dynamics_type: "ODE"  # Options: Flow-SDE, Dance-SDE, CPS, ODE

# Evaluation settings
eval:
  resolution: 1024  # Evaluation resolution
  per_device_batch_size: 1  # Eval batch size
  guidance_scale: 3.5  # Guidance scale for sampling
  num_inference_steps: 28  # Number of eval timesteps
  eval_freq: 20  # Eval frequency in epochs (0 to disable)
  seed: 42  # Eval seed (defaults to training seed)

# Reward Model Configuration
rewards:
  - name: "hps"
    reward_model: "HPSv2"
    hps_ckpt_path: "/data/aigc/liangyzh_intern/zqni/DanceGRPO-main/HPSv2/ckpt_all/HPS_v2.1_compressed.pt"
    clip_pretrained_path: "/data/aigc/liangyzh_intern/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    hps_version: "v2.1"
    batch_size: 16
    dtype: bfloat16
    device: "cuda"
```
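For reference, while debugging I check whether the reported grad norm is finite before stepping, so a single NaN batch does not corrupt the weights. This is a generic plain-PyTorch sketch, not Flow-Factory's trainer code; the function name `clip_and_check_grads` is my own:

```python
import math
import torch

def clip_and_check_grads(model, max_grad_norm=1.0):
    """Clip gradients in place and report whether their norm is finite.

    Returns (grad_norm, is_finite). When is_finite is False, the caller
    can skip optimizer.step() for this batch instead of applying a
    NaN/Inf update to the weights.
    """
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    return grad_norm, math.isfinite(float(grad_norm))

# Usage inside a training loop:
#   loss.backward()
#   grad_norm, ok = clip_and_check_grads(model, max_grad_norm=1.0)
#   if ok:
#       optimizer.step()
#   optimizer.zero_grad(set_to_none=True)
```

Skipping the step at least keeps the run alive, but the NaNs still appear, so something upstream (loss or advantages) seems to be producing non-finite values.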
