1 change: 1 addition & 0 deletions .agents/knowledge/architecture.md
@@ -100,6 +100,7 @@ All three registries map string keys → lazy import paths. Resolution: registry
| `dpo` | `DPOTrainer` | Decoupled | `BaseTrainer` |
| `nft` | `DiffusionNFTTrainer` | Decoupled | `BaseTrainer` |
| `awm` | `AWMTrainer` | Decoupled | `BaseTrainer` |
| `crd` | `CRDTrainer` | Decoupled | `BaseTrainer` |

**Model Adapters** (`models/registry.py`):
| Key | Class | Task |
128 changes: 128 additions & 0 deletions examples/crd/lora/sd3_5.yaml
@@ -0,0 +1,128 @@
# CRD (Centered Reward Distillation) with SD3.5 Medium + OCR
# Reference: https://arxiv.org/abs/2603.14128
# Environment Configuration
launcher: "accelerate" # Options: accelerate
config_file: config/accelerate_configs/multi_gpu.yaml # Path to distributed config file (optional)
num_processes: 8 # Number of processes to launch (overrides config file)
main_process_port: 29500
mixed_precision: "bf16" # Options: no, fp16, bf16

# Data Configuration
data:
  dataset_dir: "dataset/ocr" # Path to dataset folder
  preprocessing_batch_size: 8 # Batch size for preprocessing
  dataloader_num_workers: 16 # Number of workers for DataLoader
  force_reprocess: false # Force reprocessing of the dataset
  cache_dir: "~/.cache/flow_factory/datasets" # Cache directory for preprocessed datasets
  max_dataset_size: 1024 # Limit the maximum number of samples in the dataset
  sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous

# Model Configuration
model:
  finetune_type: 'lora' # Options: full, lora
  lora_rank: 32
  lora_alpha: 64
  target_modules: "default" # Options: all, default, or list of module names like ["to_k", "to_q", "to_v", "to_out.0"]
  model_name_or_path: "stabilityai/stable-diffusion-3.5-medium"
  model_type: "sd3-5"
  resume_path: null # Path to load previous checkpoint/lora adapter
  resume_type: null # Options: lora, full, state. Null to auto-detect based on `finetune_type`
  # attn_backend: '_flash_3_hub' # Attention backend for training.

log:
  run_name: null # Run name (auto: {model_type}_{finetune_type}_{timestamp})
  project: "Flow-Factory" # Project name for logging
  logging_backend: "wandb" # Options: wandb, swanlab, none
  save_dir: "saves/" # Directory to save model checkpoints and logs
  save_freq: 20 # Save frequency in epochs (0 to disable)
  save_model_only: true # Save only the model weights (not optimizer, scheduler, etc.)

# Training Configuration
train:
  # Trainer settings
  trainer_type: 'crd'
  advantage_aggregation: 'sum'

  # CRD-specific settings
  crd_beta: 1.0 # Beta scaling for reward matching loss
  crd_loss_type: 'mse' # Options: mse, bce
  use_old_for_loss: true # Use old model snapshot for implicit reward (core CRD feature)
  adaptive_logp: true # Adaptive weighting of implicit reward terms
  weight_temp: -1.0 # Softmax temperature for centering (-1 = uniform / inf temp)

  # Old model and sampling model decay schedules
  # Format: "start_step-start_value-slope-end_value" or int preset (see crd.py _DECAY_PRESETS)
  old_model_decay: "0-0.25-0.005-0.999" # Old model: starts at 0.25, ramps to 0.999 (paper Table 3 OCR)
  sampling_model_decay: "75-0.0-0.0075-0.999" # Sampling model: delayed start at step 75 (paper Table 3 OCR)

  # Training timestep distribution
  # num_train_timesteps: 0 # 0 = auto from num_inference_steps * timestep_range
  num_train_timesteps: 20 # Set to null to use all steps
  time_sampling_strategy: discrete # Options: uniform, logit_normal, discrete, discrete_with_init, discrete_wo_init
  time_shift: 3.0
  timestep_range: 0.99 # Original CRD default (top 99% of denoising steps)

  # KL regularization
  kl_type: 'v-based'
  kl_beta: 0.1
  kl_cfg: 4.5
  reward_adaptive_kl: true # Scale KL by reward signal
  ref_param_device: 'cuda'

  # Clipping
  adv_clip_range: 5.0 # Advantage clipping range

  # Sampling Settings
  resolution: 512 # Can be int or [height, width]
  num_inference_steps: 10 # Number of rollout timesteps
  guidance_scale: 1.0 # Guidance scale for sampling

  # Batch and sampling
  per_device_batch_size: 8 # Batch size per device. For image-to-image tasks, this always falls back to 1.
  group_size: 24 # Group size K=24 (paper Table 3)
  global_std: true # Use global std for advantage normalization
  unique_sample_num_per_epoch: 48 # 48 groups per batch (paper Table 3)
  gradient_step_per_epoch: 2 # Optimizer steps per batch = 2 (paper Table 3)
  gradient_accumulation_steps: auto # Options: auto, or a positive integer. When an integer is given, `gradient_step_per_epoch` is ignored.

  # Optimization
  learning_rate: 3.0e-4 # Initial learning rate
  adam_weight_decay: 1.0e-4 # AdamW weight decay
  adam_betas: [0.9, 0.999] # AdamW betas
  adam_epsilon: 1.0e-8 # AdamW epsilon
  max_grad_norm: 1.0 # Max gradient norm for clipping

  # Gradient Checkpointing
  enable_gradient_checkpointing: false # Enable gradient checkpointing to save memory with extra compute

  # Seed
  seed: 42 # Random seed

# Scheduler Configuration
scheduler:
  dynamics_type: "ODE" # Options: Flow-SDE, Dance-SDE, CPS, ODE

# Evaluation settings
eval:
  resolution: 512 # Evaluation resolution
  per_device_batch_size: 4 # Eval batch size
  guidance_scale: 1.0 # Guidance scale for sampling
  num_inference_steps: 40 # Number of eval timesteps
  eval_freq: 20 # Eval frequency in epochs (0 to disable)
  seed: 42 # Eval seed (defaults to training seed)

# Reward Model Configuration
rewards:
  - name: "ocr"
    reward_model: "OCR"
    weight: 1 # Weight of this reward model
    batch_size: 16
    device: "cuda"
    dtype: bfloat16

eval_rewards:
  - name: "ocr"
    reward_model: "OCR"
    batch_size: 32
    device: "cuda"
    dtype: bfloat16
63 changes: 62 additions & 1 deletion guidance/algorithms.md
@@ -17,6 +17,8 @@

- [AWM: Advantage Weighted Matching](#awm-advantage-weighted-matching)

- [CRD: Centered Reward Distillation](#crd-centered-reward-distillation)

- [References](#references)

## Overview
@@ -252,6 +254,64 @@ Here $\varepsilon$ is a small constant for numerical stability and $p$ denotes `
> **Tip**: `ghuber` with a small power (e.g., `0.25`) provides a good balance between robustness and gradient signal strength. `Uniform` is the simplest baseline and works well when reward signals are clean and low-variance.


## CRD: Centered Reward Distillation

This algorithm is introduced in [[10]](#ref10). **Centered Reward Distillation (CRD)** is a forward-process RL method that matches implicit model rewards (estimated from prediction error in velocity space) with centered external rewards. The key insight is that the unknown prompt-dependent normalizer cancels under *within-prompt centering*, yielding a well-posed reward-matching objective.
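
In schematic form (our reading of the description above, not the paper's exact notation), the implicit reward of a sample $x$ for prompt $c$ is a velocity-space prediction-error difference plus an unknown prompt-dependent normalizer $Z(c)$:

$$
\hat{r}_\theta(x, c) \;=\; \beta\left(\left\|v_{\text{old}}(x_t, t, c) - v^{*}\right\|^2 - \left\|v_\theta(x_t, t, c) - v^{*}\right\|^2\right) + Z(c),
$$

where $v^{*}$ is the flow-matching velocity target and $v_{\text{old}}$ is the prediction of the "old" snapshot introduced below. Subtracting the within-prompt mean over a group of $K$ samples, $\tilde{r}^{(i)} = \hat{r}^{(i)} - \tfrac{1}{K}\sum_{j=1}^{K}\hat{r}^{(j)}$, removes $Z(c)$, and the loss matches these centered implicit rewards to the group's centered external rewards (advantages).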

CRD maintains two named parameter snapshots alongside the current model:
- **Old model** (`_crd_old`): used to estimate implicit rewards via prediction error difference.
- **Sampling model** (`_crd_sampling`): used for off-policy rollout generation, blended toward the current model over time.
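
Both snapshots are pulled toward the current weights as training progresses. A minimal sketch of such an EMA-style update, assuming `decay` is the value produced by the schedules configured below (the actual `CRDTrainer` bookkeeping may differ):

```python
import torch

@torch.no_grad()
def blend_snapshot(snapshot: torch.nn.Module, model: torch.nn.Module, decay: float) -> None:
    # Hypothetical helper: keep `decay` of the snapshot, mix in `1 - decay` of the current weights.
    for p_snap, p_cur in zip(snapshot.parameters(), model.parameters()):
        p_snap.mul_(decay).add_(p_cur, alpha=1.0 - decay)
```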

To use this algorithm, set:

```yaml
train:
  trainer_type: 'crd'
```

### Key Hyperparameters

```yaml
train:
  trainer_type: 'crd'

  # CRD loss
  crd_beta: 1.0 # Scaling factor for reward-matching loss
  crd_loss_type: 'mse' # Options: mse, bce
  use_old_for_loss: true # Use old model snapshot for implicit reward (recommended)
  adaptive_logp: true # Adaptive per-sample weighting of implicit reward terms
  weight_temp: -1.0 # Softmax temperature τ for centering (-1 = uniform/τ→∞)

  # Model snapshot decay schedules
  # Format: "start_step-start_value-slope-end_value" or int preset key
  old_model_decay: "0-0.25-0.005-0.999" # Paper (OCR): min(0.25 + 0.005t, 0.999)
  sampling_model_decay: "75-0.0-0.0075-0.999" # Paper (OCR): delayed start at step 75

  # KL regularization anchored to CFG-guided pretrained reference
  kl_beta: 0.1 # KL coefficient
  kl_cfg: 4.5 # CFG scale for teacher reference model
  reward_adaptive_kl: true # Scale KL by reward to accelerate early learning
  ref_param_device: 'cuda'

  # Timestep sampling
  timestep_range: 0.99 # Top 99% of denoising steps (original CRD default)
  num_train_timesteps: 20
  time_sampling_strategy: discrete
  time_shift: 3.0

  # Advantage clipping
  adv_clip_range: 5.0
```
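
The decay strings encode a clipped linear ramp. A minimal sketch of how such a schedule could be evaluated, following the `min(0.25 + 0.005t, 0.999)` reading in the comments above (the real parser and the integer presets live in `crd.py`, so treat this as illustrative):

```python
def decay_at(schedule: str, step: int) -> float:
    """Evaluate a 'start_step-start_value-slope-end_value' schedule at a training step."""
    start_step, start_value, slope, end_value = (float(p) for p in schedule.split("-"))
    if step < start_step:
        return start_value  # schedule has not started yet
    return min(start_value + slope * (step - start_step), end_value)

decay_at("0-0.25-0.005-0.999", 100)    # -> 0.75
decay_at("75-0.0-0.0075-0.999", 100)   # -> 0.1875
```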

### Centering Modes (`weight_temp`)

| `weight_temp` | Mode | Description |
|---|---|---|
| `< 0` | Uniform (τ→∞) | Simple mean centering; recommended default |
| `== 0` | Hard selection | Positive pool (adv > 0) vs negative pool (adv < 0) |
| `> 0` | Softmax temperature | Dual-direction: `softmax(adv/τ)` and `softmax(-adv/τ)` |


## References

* <a name="ref1"></a>[1] [**Flow-GRPO:** Training Flow Matching Models via Online RL](https://arxiv.org/abs/2505.05470)
@@ -262,4 +322,5 @@ Here $\varepsilon$ is a small constant for numerical stability and `
* <a name="ref6"></a>[6] [**PaCo-RL**: Advancing Reinforcement Learning for Consistent Image Generation with Pairwise Reward Modeling](https://arxiv.org/abs/2512.04784)
* <a name="ref7"></a>[7] [**DiffusionNFT**: Online Diffusion Reinforcement with Forward Process](https://arxiv.org/abs/2509.16117)
* <a name="ref8"></a>[8] [**<u>C</u>oefficients-<u>P</u>reserving <u>S</u>ampling** for Reinforcement Learning with Flow Matching](https://arxiv.org/abs/2509.05952)
* <a name="ref9"></a>[9] [**<u>A</u>dvantage <u>W</u>eighted <u>M</u>atching**: Aligning RL with Pretraining in Diffusion Models](https://arxiv.org/abs/2509.25050)
* <a name="ref9"></a>[9] [**<u>A</u>dvantage <u>W</u>eighted <u>M</u>atching**: Aligning RL with Pretraining in Diffusion Models](https://arxiv.org/abs/2509.25050)
* <a name="ref10"></a>[10] [**CRD**: Diffusion Reinforcement Learning via Centered Reward Distillation](https://arxiv.org/abs/2603.14128)
2 changes: 2 additions & 0 deletions src/flow_factory/hparams/__init__.py
@@ -25,6 +25,7 @@
    NFTTrainingArguments,
    AWMTrainingArguments,
    DPOTrainingArguments,
    CRDTrainingArguments,
    get_training_args_class,
)
from .reward_args import RewardArguments, MultiRewardArguments
@@ -41,6 +42,7 @@
"NFTTrainingArguments",
"AWMTrainingArguments",
"DPOTrainingArguments",
"CRDTrainingArguments",
"get_training_args_class",
"RewardArguments",
"MultiRewardArguments",
123 changes: 123 additions & 0 deletions src/flow_factory/hparams/training_args.py
@@ -681,6 +681,128 @@ def get_num_train_timesteps(self, args: Any) -> int:
        return self.num_train_timesteps


@dataclass
class CRDTrainingArguments(TrainingArguments):
r"""Training arguments for Centered Reward Distillation (CRD).

Reference:
Diffusion Reinforcement Learning via Centered Reward Distillation
https://arxiv.org/abs/2603.14128
"""

# Group-wise advantage normalization
global_std: bool = field(
default=True,
metadata={"help": "Whether to use global std for advantage normalization."},
)
advantage_aggregation: Literal['sum', 'gdpo'] = field(
default='gdpo',
metadata={"help": "Method to aggregate advantages within each group. Options: ['sum', 'gdpo']."},
)

# CRD core
crd_beta: float = field(
default=1.0,
metadata={"help": "Beta scaling for CRD reward matching loss. Controls implicit vs external reward balance."},
)
crd_loss_type: Literal['mse', 'bce'] = field(
default='mse',
metadata={"help": "Loss type for CRD reward distillation. 'mse': squared error, 'bce': binary cross-entropy."},
)
use_old_for_loss: bool = field(
default=True,
metadata={"help": "Use 'old' model snapshot (instead of ref) for implicit reward estimation."},
)
adaptive_logp: bool = field(
default=True,
metadata={"help": "Adaptively weight implicit reward terms by prediction error magnitude."},
)
weight_temp: float = field(
default=-1.0,
metadata={"help": "Temperature for softmax weighting of advantages in CRD. Negative means uniform (inf temp)."},
)
# Decay schedules for model snapshots
old_model_decay: str = field(
default="0-0.25-0.005-0.999",
metadata={"help": "Decay schedule for old model blending: 'start_step-start_value-slope-end_value' or preset name."},
)
sampling_model_decay: Union[str, int] = field(
default="75-0.0-0.0075-0.999",
metadata={"help": "Decay schedule for sampling model blending. Same format as old_model_decay, or int preset."},
)

# Clipping / KL
adv_clip_range: tuple[float, float] = field(
default=(-5.0, 5.0),
metadata={"help": "Clipping range for advantages."},
)
kl_type: Literal['v-based'] = field(
default='v-based',
metadata={"help": "Type of KL divergence. CRD uses 'v-based' (velocity space)."},
)
kl_beta: float = field(
default=0.1,
metadata={"help": "KL penalty beta for regularization against the reference model."},
)
    kl_cfg: float = field(
        default=4.5,
        metadata={
            "help": (
                "CFG scale for the teacher (reference) model during KL computation. "
                "If > 1.0, the reference forward pass uses classifier-free guidance: "
                "``noise_pred = uncond + kl_cfg * (cond - uncond)``. "
                "Set to 1.0 to disable CFG on the teacher."
            )
        },
    )
    reward_adaptive_kl: bool = field(
        default=True,
        metadata={"help": "Dynamically adjust KL strength based on reward signal."},
    )
    ref_param_device: Literal["cpu", "cuda"] = field(
        default="cuda",
        metadata={"help": "Device to store reference model parameters."},
    )

    # Timestep control
    num_train_timesteps: int = field(
        default=0,
        metadata={"help": "Number of training timesteps. 0 = auto from num_inference_steps * timestep_range."},
    )
    time_sampling_strategy: Literal['uniform', 'logit_normal', 'discrete', 'discrete_with_init', 'discrete_wo_init'] = field(
        default='discrete',
        metadata={"help": "Time sampling strategy for training."},
    )
    time_shift: float = field(
        default=3.0,
        metadata={"help": "Time shift for logit normal time sampling."},
    )
    timestep_range: Union[float, Tuple[float, float]] = field(
        default=0.99,
        metadata={
            "help": "Fraction range along denoise axis 1000→0. Default 0.99 matches original CRD's timestep_fraction."
        },
    )

    def __post_init__(self):
        super().__post_init__()
        self.timestep_range = _standardize_timestep_range(self.timestep_range)
        if not self.num_train_timesteps or self.num_train_timesteps <= 0:
            self.num_train_timesteps = max(1, int(
                self.num_inference_steps * (self.timestep_range[1] - self.timestep_range[0])
            ))
        self.adv_clip_range = _standardize_clip_range(self.adv_clip_range, 'adv_clip_range')

    @property
    def requires_ref_model(self) -> bool:
        """CRD always needs a reference model for KL and implicit reward."""
        return True

    def get_num_train_timesteps(self, args: Any) -> int:
        assert self.num_train_timesteps is not None
        return self.num_train_timesteps


# ============================================================================
# Training Arguments Registry
# ============================================================================
@@ -691,6 +813,7 @@ def get_num_train_timesteps(self, args: Any) -> int:
    'nft': NFTTrainingArguments,
    'awm': AWMTrainingArguments,
    'dpo': DPOTrainingArguments,
    'crd': CRDTrainingArguments,
}

