From 396dd6e28a6f1a43d3767c950b79976e89f1a072 Mon Sep 17 00:00:00 2001 From: yuanzhi-zhu Date: Fri, 10 Apr 2026 12:12:41 +0200 Subject: [PATCH] [trainer,hparams,docs] feat: add CRD (Centered Reward Distillation) algorithm Implements Centered Reward Distillation (arXiv:2603.14128) as a new decoupled RL trainer for flow-matching models. Key changes: - `trainers/crd.py`: Full CRDTrainer implementation with old/sampling model snapshots, dual-direction centering loss, adaptive KL, and per-step velocity-space implicit reward estimation - `hparams/training_args.py`: CRDTrainingArguments with paper-aligned defaults (decay schedules, kl_beta=0.1, kl_cfg=4.5) - `trainers/registry.py`: Register 'crd' key - `hparams/__init__.py`: Export CRDTrainingArguments - `examples/crd/lora/sd3_5.yaml`: SD3.5 + OCR example config matching paper Table 3 hyperparameters (K=24, 2 grad steps, timestep_range=0.99) - `guidance/algorithms.md`: CRD section with hyperparameter reference and centering modes table - `.agents/knowledge/architecture.md`: Add CRD to trainer registry table Co-Authored-By: Claude Sonnet 4.6 --- .agents/knowledge/architecture.md | 1 + examples/crd/lora/sd3_5.yaml | 128 ++++ guidance/algorithms.md | 63 +- src/flow_factory/hparams/__init__.py | 2 + src/flow_factory/hparams/training_args.py | 123 ++++ src/flow_factory/trainers/crd.py | 848 ++++++++++++++++++++++ src/flow_factory/trainers/registry.py | 1 + 7 files changed, 1165 insertions(+), 1 deletion(-) create mode 100644 examples/crd/lora/sd3_5.yaml create mode 100644 src/flow_factory/trainers/crd.py diff --git a/.agents/knowledge/architecture.md b/.agents/knowledge/architecture.md index a9bfacef..c69e3000 100644 --- a/.agents/knowledge/architecture.md +++ b/.agents/knowledge/architecture.md @@ -100,6 +100,7 @@ All three registries map string keys → lazy import paths. 
Resolution: registry | `dpo` | `DPOTrainer` | Decoupled | `BaseTrainer` | | `nft` | `DiffusionNFTTrainer` | Decoupled | `BaseTrainer` | | `awm` | `AWMTrainer` | Decoupled | `BaseTrainer` | +| `crd` | `CRDTrainer` | Decoupled | `BaseTrainer` | **Model Adapters** (`models/registry.py`): | Key | Class | Task | diff --git a/examples/crd/lora/sd3_5.yaml b/examples/crd/lora/sd3_5.yaml new file mode 100644 index 00000000..cea6a766 --- /dev/null +++ b/examples/crd/lora/sd3_5.yaml @@ -0,0 +1,128 @@ +# CRD (Centered Reward Distillation) with SD3.5 Medium + OCR +# Reference: https://arxiv.org/abs/2603.14128 +# Environment Configuration +launcher: "accelerate" # Options: accelerate +config_file: config/accelerate_configs/multi_gpu.yaml # Path to distributed config file (optional) +num_processes: 8 # Number of processes to launch (overrides config file) +main_process_port: 29500 +mixed_precision: "bf16" # Options: no, fp16, bf16 + +# Data Configuration +data: + dataset_dir: "dataset/ocr" # Path to dataset folder + preprocessing_batch_size: 8 # Batch size for preprocessing + dataloader_num_workers: 16 # Number of workers for DataLoader + force_reprocess: false # Force reprocessing of the dataset + cache_dir: "~/.cache/flow_factory/datasets" # Cache directory for preprocessed datasets + max_dataset_size: 1024 # Limit the maximum number of samples in the dataset + sampler_type: "auto" # Options: auto, distributed_k_repeat, group_contiguous + +# Model Configuration +model: + finetune_type: 'lora' # Options: full, lora + lora_rank : 32 + lora_alpha : 64 + target_modules: "default" # Options: all, default, or list of module names like ["to_k", "to_q", "to_v", "to_out.0"] + model_name_or_path: "stabilityai/stable-diffusion-3.5-medium" + model_type: "sd3-5" + resume_path: null # Path to load previous checkpoint/lora adapter + resume_type: null # Options: lora, full, state. Null to auto-detect based on `finetune_type` + # attn_backend: '_flash_3_hub' # Attention backend for training. + +log: + run_name: null # Run name (auto: {model_type}_{finetune_type}_{timestamp}) + project: "Flow-Factory" # Project name for logging + logging_backend: "wandb" # Options: wandb, swanlab, none + save_dir: "saves/" # Directory to save model checkpoints and logs + save_freq: 20 # Save frequency in epochs (0 to disable) + save_model_only: true # Save only the model weights (not optimizer, scheduler, etc.) 
+
+# Training Configuration
+train:
+  # Trainer settings
+  trainer_type: 'crd'
+  advantage_aggregation: 'sum'
+
+  # CRD-specific settings
+  crd_beta: 1.0 # Beta scaling for reward matching loss
+  crd_loss_type: 'mse' # Options: mse, bce
+  use_old_for_loss: true # Use old model snapshot for implicit reward (core CRD feature)
+  adaptive_logp: true # Adaptive weighting of implicit reward terms
+  weight_temp: -1.0 # Softmax temperature for centering (negative = uniform / infinite temperature)
+
+  # Old model and sampling model decay schedules
+  # Format: "start_step-start_value-slope-end_value" or int preset (see crd.py _DECAY_PRESETS)
+  old_model_decay: "0-0.25-0.005-0.999" # Old model: starts at 0.25, ramps to 0.999 (paper Table 3 OCR)
+  sampling_model_decay: "75-0.0-0.0075-0.999" # Sampling model: delayed start at step 75 (paper Table 3 OCR)
+
+  # Training timestep distribution
+  # num_train_timesteps: 0 # 0 = auto from num_inference_steps * timestep_range
+  num_train_timesteps: 20 # Set to null to train on all rollout steps
+  time_sampling_strategy: discrete # Options: uniform, logit_normal, discrete, discrete_with_init, discrete_wo_init
+  time_shift: 3.0
+  timestep_range: 0.99 # Original CRD default (top 99% of denoising steps)
+
+  # KL regularization
+  kl_type: 'v-based'
+  kl_beta: 0.1
+  kl_cfg: 4.5
+  reward_adaptive_kl: true # Scale KL by reward signal
+  ref_param_device: 'cuda'
+
+  # Clipping
+  adv_clip_range: 5.0 # Advantage clipping range
+
+  # Sampling Settings
+  resolution: 512 # Can be int or [height, width]
+  num_inference_steps: 10 # Number of timesteps of rollout
+  guidance_scale: 1.0 # Guidance scale for sampling
+
+  # Batch and sampling
+  per_device_batch_size: 8 # Batch size per device. For image-to-image tasks, this always falls back to 1.
+  group_size: 24 # Group size K=24 (paper Table 3)
+  global_std: true # Use global std for advantage normalization
+  unique_sample_num_per_epoch: 48 # 48 unique prompt groups per epoch (paper Table 3)
+  gradient_step_per_epoch: 2 # Optimizer steps per epoch = 2 (paper Table 3)
+  gradient_accumulation_steps: auto # Options: auto, or positive integer (if an integer is given, `gradient_step_per_epoch` is ignored)
+ + # Optimization + learning_rate: 3.0e-4 # Initial learning rate + adam_weight_decay: 1.0e-4 # AdamW weight decay + adam_betas: [0.9, 0.999] # AdamW betas + adam_epsilon: 1.0e-8 # AdamW epsilon + max_grad_norm: 1.0 # Max gradient norm for clipping + + # Gradient Checkpointing + enable_gradient_checkpointing: false # Enable gradient checkpointing to save memory with extra compute + + # Seed + seed: 42 # Random seed + +# Scheduler Configuration +scheduler: + dynamics_type: "ODE" # Options: Flow-SDE, Dance-SDE, CPS, ODE + +# Evaluation settings +eval: + resolution: 512 # Evaluation resolution + per_device_batch_size: 4 # Eval batch size + guidance_scale: 1.0 # Guidance scale for sampling + num_inference_steps: 40 # Number of eval timesteps + eval_freq: 20 # Eval frequency in epochs (0 to disable) + seed: 42 # Eval seed (defaults to training seed) + +# Reward Model Configuration +rewards: + - name: "ocr" + reward_model: "OCR" + weight: 1 # Weight of this reward model + batch_size: 16 + device: "cuda" + dtype: bfloat16 + +eval_rewards: + - name: "ocr" + reward_model: "OCR" + batch_size: 32 + device: "cuda" + dtype: bfloat16 \ No newline at end of file diff --git a/guidance/algorithms.md b/guidance/algorithms.md index 211a9eea..f664efee 100644 --- a/guidance/algorithms.md +++ b/guidance/algorithms.md @@ -17,6 +17,8 @@ - [AWM: Advantage Weighted Matching](#awm-advantage-weighted-matching) +- [CRD: Centered Reward Distillation](#crd-centered-reward-distillation) + - [References](#references) ## Overview @@ -252,6 +254,64 @@ Here $\varepsilon$ is a small constant for numerical stability and $p$ denotes ` > **Tip**: `ghuber` with a small power (e.g., `0.25`) provides a good balance between robustness and gradient signal strength. `Uniform` is the simplest baseline and works well when reward signals are clean and low-variance. +## CRD: Centered Reward Distillation + +This algorithm is introduced in [[10]](#ref10). **Centered Reward Distillation (CRD)** is a forward-process RL method that matches implicit model rewards (estimated from prediction error in velocity space) with centered external rewards. The key insight is that the unknown prompt-dependent normalizer cancels under *within-prompt centering*, yielding a well-posed reward-matching objective. + +CRD maintains two named parameter snapshots alongside the current model: +- **Old model** (`_crd_old`): used to estimate implicit rewards via prediction error difference. +- **Sampling model** (`_crd_sampling`): used for off-policy rollout generation, blended toward the current model over time. 
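+
+Concretely, at each sampled training timestep the implicit reward is the reduction in velocity-space prediction error relative to the old snapshot,
+
+$$r_\theta = -\left(\lVert v_\theta(x_t, t) - v^\ast\rVert^2 - \lVert v_{\text{old}}(x_t, t) - v^\ast\rVert^2\right), \qquad v^\ast = \varepsilon - x_1,$$
+
+and the default `mse` loss matches the group-centered implicit reward to the group-centered external advantage, $\big(\beta\,(r_\theta - \bar r_\theta) - (A - \bar A)\big)^2$. Because both sides are centered within the same group, the unknown prompt-dependent reward normalizer cancels. Below is a minimal single-rank sketch of this default objective (uniform centering, `mse` loss, no adaptive weighting, KL term and loss rescaling omitted); `crd_loss` is an illustrative helper rather than part of the library API, and `adv` is assumed to hold the clipped, $[0, 1]$-normalized advantages:
+
+```python
+import torch
+
+def crd_loss(v_pred, v_old, v_target, adv, crd_beta=1.0):
+    # v_pred / v_old / v_target: (B, C, H, W) velocity predictions and target (noise - clean).
+    # adv: (B,) group-normalized external advantages.
+    # Implicit reward: reduction in prediction error vs. the old snapshot, per sample.
+    r_theta = -(((v_pred - v_target) ** 2) - ((v_old - v_target) ** 2))
+    r_theta = r_theta.mean(dim=tuple(range(1, r_theta.ndim)))
+    # Center both sides within the group and match them with a squared error.
+    r_c = r_theta - r_theta.detach().mean()
+    a_c = adv - adv.mean()
+    return ((crd_beta * r_c - a_c) ** 2).mean()
+```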
+ +To use this algorithm, set: + +```yaml +train: + trainer_type: 'crd' +``` + +### Key Hyperparameters + +```yaml +train: + trainer_type: 'crd' + + # CRD loss + crd_beta: 1.0 # Scaling factor for reward-matching loss + crd_loss_type: 'mse' # Options: mse, bce + use_old_for_loss: true # Use old model snapshot for implicit reward (recommended) + adaptive_logp: true # Adaptive per-sample weighting of implicit reward terms + weight_temp: -1.0 # Softmax temperature τ for centering (-1 = uniform/τ→∞) + + # Model snapshot decay schedules + # Format: "start_step-start_value-slope-end_value" or int preset key + old_model_decay: "0-0.25-0.005-0.999" # Paper (OCR): min(0.25 + 0.005t, 0.999) + sampling_model_decay: "75-0.0-0.0075-0.999" # Paper (OCR): delayed start at step 75 + + # KL regularization anchored to CFG-guided pretrained reference + kl_beta: 0.1 # KL coefficient + kl_cfg: 4.5 # CFG scale for teacher reference model + reward_adaptive_kl: true # Scale KL by reward to accelerate early learning + ref_param_device: 'cuda' + + # Timestep sampling + timestep_range: 0.99 # Top 99% of denoising steps (original CRD default) + num_train_timesteps: 20 + time_sampling_strategy: discrete + time_shift: 3.0 + + # Advantage clipping + adv_clip_range: 5.0 +``` + +### Centering Modes (`weight_temp`) + +| `weight_temp` | Mode | Description | +|---|---|---| +| `< 0` | Uniform (τ→∞) | Simple mean centering; recommended default | +| `== 0` | Hard selection | Positive pool (adv > 0) vs negative pool (adv < 0) | +| `> 0` | Softmax temperature | Dual-direction: `softmax(adv/τ)` and `softmax(-adv/τ)` | + + ## References * [1] [**Flow-GRPO:** Training Flow Matching Models via Online RL](https://arxiv.org/abs/2505.05470) @@ -262,4 +322,5 @@ Here $\varepsilon$ is a small constant for numerical stability and $p$ denotes ` * [6] [**PaCo-RL**: Advancing Reinforcement Learning for Consistent Image Generation with Pairwise Reward Modeling](https://arxiv.org/abs/2512.04784) * [7] [**DiffusionNFT**: Online Diffusion Reinforcement with Forward Process](https://arxiv.org/abs/2509.16117) * [8] [**Coefficients-Preserving Sampling** for Reinforcement Learning with Flow Matching](https://arxiv.org/abs/2509.05952) -* [9] [**Advantage Weighted Matching**: Aligning RL with Pretraining in Diffusion Models](https://arxiv.org/abs/2509.25050) \ No newline at end of file +* [9] [**Advantage Weighted Matching**: Aligning RL with Pretraining in Diffusion Models](https://arxiv.org/abs/2509.25050) +* [10] [**CRD**: Diffusion Reinforcement Learning via Centered Reward Distillation](https://arxiv.org/abs/2603.14128) diff --git a/src/flow_factory/hparams/__init__.py b/src/flow_factory/hparams/__init__.py index 59227f55..8591844b 100644 --- a/src/flow_factory/hparams/__init__.py +++ b/src/flow_factory/hparams/__init__.py @@ -25,6 +25,7 @@ NFTTrainingArguments, AWMTrainingArguments, DPOTrainingArguments, + CRDTrainingArguments, get_training_args_class, ) from .reward_args import RewardArguments, MultiRewardArguments @@ -41,6 +42,7 @@ "NFTTrainingArguments", "AWMTrainingArguments", "DPOTrainingArguments", + "CRDTrainingArguments", "get_training_args_class", "RewardArguments", "MultiRewardArguments", diff --git a/src/flow_factory/hparams/training_args.py b/src/flow_factory/hparams/training_args.py index e8690242..23f45b35 100644 --- a/src/flow_factory/hparams/training_args.py +++ b/src/flow_factory/hparams/training_args.py @@ -681,6 +681,128 @@ def get_num_train_timesteps(self, args: Any) -> int: return self.num_train_timesteps +@dataclass +class 
CRDTrainingArguments(TrainingArguments):
+    r"""Training arguments for Centered Reward Distillation (CRD).
+
+    Reference:
+        Diffusion Reinforcement Learning via Centered Reward Distillation
+        https://arxiv.org/abs/2603.14128
+    """
+
+    # Group-wise advantage normalization
+    global_std: bool = field(
+        default=True,
+        metadata={"help": "Whether to use global std for advantage normalization."},
+    )
+    advantage_aggregation: Literal['sum', 'gdpo'] = field(
+        default='gdpo',
+        metadata={"help": "Method to aggregate advantages within each group. Options: ['sum', 'gdpo']."},
+    )
+
+    # CRD core
+    crd_beta: float = field(
+        default=1.0,
+        metadata={"help": "Beta scaling for CRD reward matching loss. Controls implicit vs external reward balance."},
+    )
+    crd_loss_type: Literal['mse', 'bce'] = field(
+        default='mse',
+        metadata={"help": "Loss type for CRD reward distillation. 'mse': squared error, 'bce': binary cross-entropy."},
+    )
+    use_old_for_loss: bool = field(
+        default=True,
+        metadata={"help": "Use 'old' model snapshot (instead of ref) for implicit reward estimation."},
+    )
+    adaptive_logp: bool = field(
+        default=True,
+        metadata={"help": "Adaptively weight implicit reward terms by prediction error magnitude."},
+    )
+    weight_temp: float = field(
+        default=-1.0,
+        metadata={"help": "Temperature for softmax weighting of advantages in CRD. A negative value means uniform weighting (infinite temperature)."},
+    )
+    # Decay schedules for model snapshots
+    old_model_decay: str = field(
+        default="0-0.25-0.005-0.999",
+        metadata={"help": "Decay schedule for old model blending: 'start_step-start_value-slope-end_value' or preset name."},
+    )
+    sampling_model_decay: Union[str, int] = field(
+        default="75-0.0-0.0075-0.999",
+        metadata={"help": "Decay schedule for sampling model blending. Same format as old_model_decay, or int preset."},
+    )
+
+    # Clipping / KL
+    adv_clip_range: tuple[float, float] = field(
+        default=(-5.0, 5.0),
+        metadata={"help": "Clipping range for advantages."},
+    )
+    kl_type: Literal['v-based'] = field(
+        default='v-based',
+        metadata={"help": "Type of KL divergence. CRD uses 'v-based' (velocity space)."},
+    )
+    kl_beta: float = field(
+        default=0.1,
+        metadata={"help": "KL penalty beta for regularization against the reference model."},
+    )
+    kl_cfg: float = field(
+        default=4.5,
+        metadata={
+            "help": (
+                "CFG scale for the teacher (reference) model during KL computation. "
+                "If > 1.0, the reference forward pass uses classifier-free guidance: "
+                "``noise_pred = uncond + kl_cfg * (cond - uncond)``. "
+                "Set to 1.0 to disable CFG on the teacher."
+            )
+        },
+    )
+    reward_adaptive_kl: bool = field(
+        default=True,
+        metadata={"help": "Dynamically adjust KL strength based on reward signal."},
+    )
+    ref_param_device: Literal["cpu", "cuda"] = field(
+        default="cuda",
+        metadata={"help": "Device to store reference model parameters."},
+    )
+
+    # Timestep control
+    num_train_timesteps: int = field(
+        default=0,
+        metadata={"help": "Number of training timesteps. 0 = auto from num_inference_steps * timestep_range."},
+    )
+    time_sampling_strategy: Literal['uniform', 'logit_normal', 'discrete', 'discrete_with_init', 'discrete_wo_init'] = field(
+        default='discrete',
+        metadata={"help": "Time sampling strategy for training."},
+    )
+    time_shift: float = field(
+        default=3.0,
+        metadata={"help": "Time shift for logit normal time sampling."},
+    )
+    timestep_range: Union[float, Tuple[float, float]] = field(
+        default=0.99,
+        metadata={
+            "help": "Fraction range along denoise axis 1000→0.
Default 0.99 matches original CRD's timestep_fraction." + }, + ) + + def __post_init__(self): + super().__post_init__() + self.timestep_range = _standardize_timestep_range(self.timestep_range) + if not self.num_train_timesteps or self.num_train_timesteps <= 0: + self.num_train_timesteps = max(1, int( + self.num_inference_steps * (self.timestep_range[1] - self.timestep_range[0]) + )) + self.adv_clip_range = _standardize_clip_range(self.adv_clip_range, 'adv_clip_range') + + @property + def requires_ref_model(self) -> bool: + """CRD always needs a reference model for KL and implicit reward.""" + return True + + def get_num_train_timesteps(self, args: Any) -> int: + assert self.num_train_timesteps is not None + return self.num_train_timesteps + + # ============================================================================ # Training Arguments Registry # ============================================================================ @@ -691,6 +813,7 @@ def get_num_train_timesteps(self, args: Any) -> int: 'nft': NFTTrainingArguments, 'awm': AWMTrainingArguments, 'dpo': DPOTrainingArguments, + 'crd': CRDTrainingArguments, } diff --git a/src/flow_factory/trainers/crd.py b/src/flow_factory/trainers/crd.py new file mode 100644 index 00000000..f4007db2 --- /dev/null +++ b/src/flow_factory/trainers/crd.py @@ -0,0 +1,848 @@ +# Copyright 2026 Jayce-Ping +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# src/flow_factory/trainers/crd.py +""" +Centered Reward Distillation (CRD) Trainer. 
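+
+CRD matches implicit model rewards (estimated from the change in velocity-space prediction
+error relative to an "old" parameter snapshot) against group-centered external rewards, and
+regularizes the policy with a v-based KL term towards the (optionally CFG-guided) reference
+model. Rollouts are drawn from a separate "sampling" snapshot that is gradually blended
+towards the current weights. Select it with ``trainer_type: 'crd'``.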
+Reference: +[1] Diffusion Reinforcement Learning via Centered Reward Distillation + - https://arxiv.org/abs/2603.14128 +""" +import os +from typing import List, Dict, Any, Union, Optional +from functools import partial +from collections import defaultdict +from contextlib import contextmanager +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.utils.torch_utils import randn_tensor +import tqdm as tqdm_ + +tqdm = partial(tqdm_.tqdm, dynamic_ncols=True) + +from .abc import BaseTrainer +from ..hparams import CRDTrainingArguments +from ..samples import BaseSample +from ..rewards import RewardBuffer +from ..utils.base import filter_kwargs, create_generator, create_generator_by_prompt, to_broadcast_tensor +from ..utils.logger_utils import setup_logger +from ..utils.noise_schedule import TimeSampler, flow_match_sigma +from ..utils.dist import reduce_loss_info + +logger = setup_logger(__name__) + + +# ========================= Decay Utilities ========================= + +# Predefined decay presets: (start_step, start_value, slope, end_value) +_DECAY_PRESETS = { + 0: (0, 0.0, 0.0, 0.0), + 1: (0, 0.0, 0.001, 0.5), + 2: (75, 0.0, 0.0075, 0.999), + 3: (0, 1.0, 0.0, 1.0), + 4: (0, 0.0, 0.02, 0.99), + 5: (0, 0.0, 0.01, 0.5), + 6: (0, 0.0, 0.0075, 0.999), + 'none': (0, 0.0, 0.0, 0.0), + 'slow': (0, 0.0, 0.001, 0.5), + 'medium': (75, 0.0, 0.0075, 0.999), + 'offline': (0, 1.0, 0.0, 1.0), + 'fast': (0, 0.0, 0.02, 0.99), + 'moderate': (0, 0.0, 0.01, 0.5), +} + + +def compute_decay(step: int, decay_type) -> float: + """ + Compute a decay value at the given step. + + Args: + step: Current training step. + decay_type: An int/str preset key, or a string ``"start_step-start_value-slope-end_value"``. + + Returns: + Decay value (float in [0, 1]). + """ + # Try int conversion for string digits like "0", "1", etc. + if isinstance(decay_type, str): + try: + decay_type = int(decay_type) + except ValueError: + pass + + if decay_type in _DECAY_PRESETS: + start_step, start_value, slope, end_value = _DECAY_PRESETS[decay_type] + elif isinstance(decay_type, str) and '-' in decay_type: + parts = decay_type.split('-') + assert len(parts) == 4, ( + f"Decay string format must be 'start_step-start_value-slope-end_value', got: {decay_type}" + ) + start_step, start_value, slope, end_value = float(parts[0]), float(parts[1]), float(parts[2]), float(parts[3]) + start_step = int(start_step) + else: + raise ValueError( + f"Invalid decay_type: {decay_type}. " + f"Valid options: {list(_DECAY_PRESETS.keys())} or 'start_step-start_value-slope-end_value'" + ) + + if step < start_step: + return start_value + return min(start_value + (step - start_step) * slope, end_value) + + +# ============================ CRD Trainer ============================ + +class CRDTrainer(BaseTrainer): + """ + Centered Reward Distillation (CRD) Trainer. + + Core algorithm: match centered external rewards with implicit model rewards + estimated from prediction error in velocity space. + + Key features (matching the original CRD implementation): + - Loss is based on centered reward distillation (not contrastive positive/negative). + - Maintains an "old" model snapshot for implicit reward estimation (decay_type). + - Maintains a "sampling" model snapshot for off-policy rollouts (decay_type2). + - Supports dual-direction centering with temperature-weighted softmax. + - Supports adaptive KL based on reward signals. + + Model snapshots: + - Current model: trainable parameters (LoRA "default" in original CRD). 
+ - Old model: named parameter snapshot for implicit reward estimation. + - Sampling model: named parameter snapshot for rollout generation. + - Reference model: original pre-trained weights (LoRA disabled / base model). + + Reference: https://arxiv.org/abs/2603.14128 + """ + + _OLD_PARAMS_NAME = '_crd_old' + _SAMPLING_PARAMS_NAME = '_crd_sampling' + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.training_args: CRDTrainingArguments + + # CRD-specific config + self.crd_beta = self.training_args.crd_beta + self.crd_loss_type = self.training_args.crd_loss_type + self.use_old_for_loss = self.training_args.use_old_for_loss + self.adaptive_logp = self.training_args.adaptive_logp + self.weight_temp = self.training_args.weight_temp + + # Decay schedules + self.old_model_decay = self.training_args.old_model_decay + self.sampling_model_decay = self.training_args.sampling_model_decay + + # KL + self.kl_beta = self.training_args.kl_beta + self.kl_cfg = self.training_args.kl_cfg + self.reward_adaptive_kl = self.training_args.reward_adaptive_kl + + # Timestep sampling + self.time_sampling_strategy = self.training_args.time_sampling_strategy + self.time_shift = self.training_args.time_shift + self.num_train_timesteps = self.training_args.num_train_timesteps + self.timestep_range = self.training_args.timestep_range + + self.kl_type = self.training_args.kl_type + if self.kl_type != 'v-based': + logger.warning( + f"CRD-Trainer only supports 'v-based' KL loss, got {self.kl_type}, switching to 'v-based'." + ) + self.kl_type = 'v-based' + + # Initialize model snapshots: "old" (for implicit reward) and "sampling" (for rollout) + self._init_model_snapshots() + + # ========================= Initialization ========================= + + def _init_model_snapshots(self): + """ + Initialize both model snapshots by storing copies of current trainable parameters. + + In the original CRD, this corresponds to: + - ``transformer.add_adapter("old", ...)`` + copy from "default" + - ``transformer.add_adapter("sampling", ...)`` + copy from "default" + """ + ref_device = self.training_args.ref_param_device + + # Old model snapshot (for implicit reward estimation) + self.adapter.add_named_parameters( + name=self._OLD_PARAMS_NAME, + device=ref_device, + ) + logger.info("CRD: Initialized 'old' model snapshot for implicit reward estimation.") + + # Sampling model snapshot (for off-policy rollout generation) + self.adapter.add_named_parameters( + name=self._SAMPLING_PARAMS_NAME, + device=ref_device, + ) + logger.info("CRD: Initialized 'sampling' model snapshot for rollout generation.") + + @property + def enable_kl_loss(self) -> bool: + return self.kl_beta > 0.0 + + @contextmanager + def sampling_context(self): + """ + Use the sampling model snapshot for rollout generation. + + In the original CRD, this corresponds to ``transformer_ddp.module.set_adapter("sampling")``. + The sampling model is a separate snapshot blended towards current weights with + ``sampling_model_decay`` (decay_type2 in the original). + """ + with self.adapter.use_named_parameters(self._SAMPLING_PARAMS_NAME): + yield + + # ========================= Timestep Sampling ========================= + + def _sample_timesteps(self, batch_size: int) -> torch.Tensor: + """ + Sample continuous or discrete timesteps based on configured `time_sampling_strategy`. + + Returns: + Tensor of shape ``(num_train_timesteps, batch_size)`` with scheduler-scale ``t`` in ``[0, 1000]``. 
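+
+        Example:
+            With ``time_sampling_strategy='discrete'`` and ``num_train_timesteps=20``, the
+            returned tensor has shape ``(20, batch_size)``, with values taken from
+            ``self.adapter.scheduler.timesteps`` (the rollout grid) restricted to the
+            configured ``timestep_range``.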
+ """ + device = self.accelerator.device + time_sampling_strategy = self.time_sampling_strategy.lower() + available = ['logit_normal', 'uniform', 'discrete', 'discrete_with_init', 'discrete_wo_init'] + + if time_sampling_strategy == 'logit_normal': + return TimeSampler.logit_normal_shifted( + batch_size=batch_size, + num_timesteps=self.num_train_timesteps, + timestep_range=self.timestep_range, + time_shift=self.time_shift, + device=device, + stratified=True, + ) + elif time_sampling_strategy == 'uniform': + return TimeSampler.uniform( + batch_size=batch_size, + num_timesteps=self.num_train_timesteps, + timestep_range=self.timestep_range, + time_shift=self.time_shift, + device=device, + ) + elif time_sampling_strategy.startswith('discrete'): + # Map time_sampling_strategy to (include_init, force_init) + discrete_config = { + 'discrete': (True, False), + 'discrete_with_init': (True, True), + 'discrete_wo_init': (False, False), + } + if time_sampling_strategy not in discrete_config: + raise ValueError(f"Unknown time_sampling_strategy: {time_sampling_strategy}. Available: {available}") + include_init, force_init = discrete_config[time_sampling_strategy] + return TimeSampler.discrete( + batch_size=batch_size, + num_train_timesteps=self.num_train_timesteps, + scheduler_timesteps=self.adapter.scheduler.timesteps, + timestep_range=self.timestep_range, + include_init=include_init, + force_init=force_init, + ) + else: + raise ValueError(f"Unknown time_sampling_strategy: {time_sampling_strategy}. Available: {available}") + + # ========================= Evaluation Loop ========================= + + def evaluate(self) -> None: + """Evaluation loop.""" + if self.test_dataloader is None: + return + + self.adapter.eval() + self.eval_reward_buffer.clear() + + with torch.no_grad(), self.autocast(), self.adapter.use_ema_parameters(): + all_samples: List[BaseSample] = [] + + for batch in tqdm( + self.test_dataloader, + desc='Evaluating', + disable=not self.show_progress_bar, + ): + generator = create_generator_by_prompt(batch['prompt'], self.training_args.seed) + inference_kwargs = { + 'compute_log_prob': False, + 'generator': generator, + 'trajectory_indices': None, # No need to store trajectories during evaluation + **self.eval_args, + } + inference_kwargs.update(**batch) + inference_kwargs = filter_kwargs(self.adapter.inference, **inference_kwargs) + samples = self.adapter.inference(**inference_kwargs) + all_samples.extend(samples) + self.eval_reward_buffer.add_samples(samples) + + rewards = self.eval_reward_buffer.finalize(store_to_samples=True, split='pointwise') + + # Gather and log rewards + rewards = {key: torch.as_tensor(value).to(self.accelerator.device) for key, value in rewards.items()} + gathered_rewards = { + key: self.accelerator.gather(value).cpu().numpy() + for key, value in rewards.items() + } + + # Log statistics + if self.accelerator.is_main_process: + _log_data = {f'eval/reward_{key}_mean': np.mean(value) for key, value in gathered_rewards.items()} + _log_data.update({f'eval/reward_{key}_std': np.std(value) for key, value in gathered_rewards.items()}) + _log_data['eval_samples'] = all_samples + self.log_data(_log_data, step=self.step) + self.accelerator.wait_for_everyone() + + # ========================= Advantage Computation ========================= + + def compute_advantages( + self, + samples: List[BaseSample], + rewards: Dict[str, torch.Tensor], + store_to_samples: bool = True, + aggregation_func=None, + ) -> torch.Tensor: + """Compute advantages — delegates to AdvantageProcessor. 
+ + Args: + samples: List of BaseSample instances + rewards: Dict of reward_name to reward tensors aligned with samples + store_to_samples: Whether to store computed advantages back to samples' extra_kwargs + aggregation_func: Method to aggregate advantages within each group. + Options: 'sum' (default GRPO), 'gdpo' (GDPO-style), or a custom callable. + Returns: + advantages: Tensor of shape (num_samples, ) with computed advantages + """ + aggregation_func = aggregation_func or self.training_args.advantage_aggregation + return self.advantage_processor.compute_advantages( + samples=samples, + rewards=rewards, + store_to_samples=store_to_samples, + aggregation_func=aggregation_func, + ) + + # ========================= Main Training Loop ========================= + + def start(self): + """Main training loop.""" + while self.should_continue_training(): + self.adapter.scheduler.set_seed(self.epoch + self.training_args.seed) + + # Save checkpoint + if ( + self.log_args.save_freq > 0 + and self.epoch % self.log_args.save_freq == 0 + and self.log_args.save_dir + ): + save_dir = os.path.join( + self.log_args.save_dir, + str(self.log_args.run_name), + 'checkpoints', + ) + self.save_checkpoint(save_dir, epoch=self.epoch) + + # Evaluation + if ( + self.eval_args.eval_freq > 0 + and self.epoch % self.eval_args.eval_freq == 0 + ): + self.evaluate() + + # Sampling: always use the "sampling" model snapshot + with self.sampling_context(): + samples = self.sample() + + self.prepare_feedback(samples) + self.optimize(samples) + + # Update EMA (if enabled), old model, and sampling model + self.adapter.ema_step(step=self.epoch) + self._update_old_model() + self._update_sampling_model() + + self.epoch += 1 + + def _blend_named_params(self, name: str, decay: float): + """ + Blend a named parameter snapshot towards the current trainable parameters. + + Formula: ``snapshot = decay * snapshot + (1 - decay) * current`` + + Args: + name: Name of the parameter snapshot. + decay: Blending coefficient. 0.0 = full copy, 1.0 = no change. + """ + if decay <= 0.0: + # Full copy from current params (no blending) + self.adapter.update_named_parameters(name) + elif decay >= 1.0: + # Keep snapshot unchanged (fully offline) + pass + else: + # Exponential blending: snapshot = decay * snapshot + (1 - decay) * current + info = self.adapter._named_parameters[name] + current_params = self.adapter._get_component_parameters(info.target_components) + with torch.no_grad(): + for ema_param, param in zip(info.ema_wrapper.ema_parameters, current_params, strict=True): + ema_param.data.mul_(decay).add_( + param.detach().to(ema_param.device), alpha=(1.0 - decay) + ) + + def _update_old_model(self): + """ + Blend the old model snapshot towards the current trainable parameters. + + In the original CRD, controlled by ``decay_type`` (default: ``"0-0.25-0.001-0.5"``). + """ + decay = compute_decay(self.step, self.old_model_decay) + self._blend_named_params(self._OLD_PARAMS_NAME, decay) + + # Log decay value + if self.accelerator.is_main_process: + self.log_data({'train/old_model_decay': decay}, step=self.step) + + def _update_sampling_model(self): + """ + Blend the sampling model snapshot towards the current trainable parameters. + + In the original CRD, controlled by ``decay_type2`` (default: preset 1 = ``(0, 0.0, 0.001, 0.5)``). 
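+
+        For example, with ``sampling_model_decay="75-0.0-0.0075-0.999"`` (the OCR example
+        config), the snapshot is refreshed with a full copy of the current weights
+        (decay 0.0) for the first 75 steps, after which the blend coefficient grows as
+        ``min(0.0075 * (step - 75), 0.999)``.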
+ """ + decay = compute_decay(self.step, self.sampling_model_decay) + self._blend_named_params(self._SAMPLING_PARAMS_NAME, decay) + + # Log decay value + if self.accelerator.is_main_process: + self.log_data({'train/sampling_model_decay': decay}, step=self.step) + + # ========================= Sampling ========================= + + def sample(self) -> List[BaseSample]: + """Generate rollouts. Only keeps final latents (like NFT).""" + self.adapter.rollout() + self.reward_buffer.clear() + samples = [] + data_iter = iter(self.dataloader) + + with torch.no_grad(), self.autocast(): + for _ in tqdm( + range(self.training_args.num_batches_per_epoch), + desc=f'Epoch {self.epoch} Sampling', + disable=not self.show_progress_bar, + ): + batch = next(data_iter) + sample_kwargs = { + **self.training_args, + 'compute_log_prob': False, + 'trajectory_indices': [-1], # Only keep final latents + **batch, + } + sample_kwargs = filter_kwargs(self.adapter.inference, **sample_kwargs) + sample_batch = self.adapter.inference(**sample_kwargs) + samples.extend(sample_batch) + self.reward_buffer.add_samples(sample_batch) + + return samples + + # ========================= Forward Pass Helpers ========================= + + def _compute_crd_output( + self, + batch: Dict[str, Any], + timestep: torch.Tensor, + noised_latents: torch.Tensor, + guidance_scale: Optional[float] = None, + ) -> Dict[str, torch.Tensor]: + """ + Compute CRD forward pass for a single timestep. + + Args: + batch: Batch dict with prompt embeddings etc. + timestep: (B,) tensor in scheduler scale ``[0, 1000]``. + noised_latents: Interpolated latents ``x_t = (1-σ) x_1 + σ noise`` with ``σ = t/1000``. + guidance_scale: Override CFG scale. If None, uses the value from training_args + (typically 1.0 for student training). Pass ``self.kl_cfg`` for teacher + CFG inference — the model adapter will automatically do the double forward + pass using ``negative_prompt_embeds`` / ``negative_pooled_prompt_embeds`` + from the batch if ``guidance_scale > 1.0``. + + Returns: + Dict with ``noise_pred`` (velocity prediction), shape ``(B, C, H, W)``. + """ + t_b = timestep.view(-1) # Scheduler scale [0, 1000] + device = self.accelerator.device + + forward_kwargs = { + **self.training_args, + 't': t_b, + 't_next': torch.zeros_like(t_b), + 'latents': noised_latents, + 'compute_log_prob': False, + 'return_kwargs': ['noise_pred'], + 'noise_level': 0.0, + **{ + k: (v.to(device) if isinstance(v, torch.Tensor) else v) + for k, v in batch.items() + if k not in ['all_latents', 'timesteps', 'advantage'] + }, + } + if guidance_scale is not None: + forward_kwargs['guidance_scale'] = guidance_scale + forward_kwargs = filter_kwargs(self.adapter.forward, **forward_kwargs) + output = self.adapter.forward(**forward_kwargs) + return {'noise_pred': output.noise_pred} + + def prepare_feedback(self, samples: List[BaseSample]) -> None: + """Finalize rewards, compute advantages, and log advantage metrics.""" + rewards = self.reward_buffer.finalize(store_to_samples=True, split='all') + self.compute_advantages(samples, rewards, store_to_samples=True) + adv_metrics = self.advantage_processor.pop_advantage_metrics() + if adv_metrics: + self.log_data(adv_metrics, step=self.step) + + # ========================= CRD Centering Loss ========================= + + def _compute_crd_loss( + self, + adv_cur: torch.Tensor, + adv_cur_rank: torch.Tensor, + r_theta_gathered: torch.Tensor, + r_theta_local: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the centered reward distillation (CRD) loss. 
+ + Supports three modes depending on ``weight_temp``: + - **Uniform** (``weight_temp < 0`` -> inf): Simple mean centering (single direction). + - **Hard selection** (``weight_temp == 0``): Separate positive/negative sample pools. + - **Softmax temperature** (``weight_temp > 0``): Dual-direction centering with + ``softmax(adv/T)`` for positive direction and ``softmax(-adv/T)`` for negative direction. + + In the non-uniform case (``weight_temp >= 0``), the loss is the average of two + directions: one centered on high-reward samples, one centered on low-reward samples. + + Args: + adv_cur: Gathered advantages across all GPUs, shape ``(N,)``. + adv_cur_rank: Local advantages for this rank, shape ``(B,)``. + r_theta_gathered: Gathered implicit rewards across all GPUs, shape ``(N,)``. + r_theta_local: Local implicit rewards for this rank, shape ``(B,)``. + + Returns: + Unscaled CRD policy loss (scalar). + """ + device = adv_cur.device + weight_temp = torch.inf if self.weight_temp < 0 else self.weight_temp + + if weight_temp == torch.inf: + # ---- Uniform weighting (single-direction centering) ---- + softmax_p = torch.softmax(adv_cur / weight_temp, dim=0) # uniform + adv_cur_avg = (adv_cur * softmax_p).sum(dim=0, keepdim=True) + r_theta_avg = (r_theta_gathered * softmax_p).sum(dim=0, keepdim=True) + + Rc = adv_cur_rank - adv_cur_avg + R_theta_c = r_theta_local - r_theta_avg.detach() + + if self.crd_loss_type == 'bce': + ori_policy_loss = F.binary_cross_entropy_with_logits( + self.crd_beta * R_theta_c, + torch.sigmoid(Rc.detach()), + reduction='mean', + ) + else: + diff = self.crd_beta * R_theta_c - Rc + ori_policy_loss = (diff ** 2).mean() + + else: + # ---- Non-uniform: Dual-direction centering ---- + # Positive direction: weight towards higher-reward samples + if weight_temp == 0: + # Hard selection: only positive-advantage samples + adv_plus_mask = (adv_cur > 0.0) + if adv_plus_mask.sum() == 0: + softmax_p = torch.ones_like(adv_cur) / adv_cur.shape[0] + else: + masked_adv = adv_cur.where( + adv_plus_mask, torch.tensor(float('-inf'), device=device) + ) + softmax_p = torch.softmax(masked_adv, dim=0) + else: + softmax_p = torch.softmax(adv_cur / weight_temp, dim=0) + + # Negative direction: weight towards lower-reward samples + if weight_temp == 0: + # Hard selection: only negative-advantage samples + adv_minus_mask = (adv_cur < 0.0) + if adv_minus_mask.sum() == 0: + softmax_p_minus = torch.ones_like(adv_cur) / adv_cur.shape[0] + else: + masked_adv = adv_cur.where( + adv_minus_mask, torch.tensor(float('-inf'), device=device) + ) + softmax_p_minus = torch.softmax(masked_adv, dim=0) + else: + softmax_p_minus = torch.softmax(-adv_cur / weight_temp, dim=0) + + # Positive direction centering + adv_cur_avg = (adv_cur * softmax_p).sum(dim=0, keepdim=True) + r_theta_avg = (r_theta_gathered * softmax_p).sum(dim=0, keepdim=True) + Rc = adv_cur_rank - adv_cur_avg + R_theta_c = r_theta_local - r_theta_avg.detach() + + # Negative direction centering + adv_cur_avg_minus = (adv_cur * softmax_p_minus).sum(dim=0, keepdim=True) + r_theta_avg_minus = (r_theta_gathered * softmax_p_minus).sum(dim=0, keepdim=True) + Rc_minus = adv_cur_rank - adv_cur_avg_minus + R_theta_c_minus = r_theta_local - r_theta_avg_minus.detach() + + if self.crd_loss_type == 'bce': + ori_policy_loss = 0.5 * F.binary_cross_entropy_with_logits( + self.crd_beta * R_theta_c, + torch.sigmoid(Rc.detach()), + reduction='mean', + ) + 0.5 * F.binary_cross_entropy_with_logits( + self.crd_beta * R_theta_c_minus, + torch.sigmoid(Rc_minus.detach()), + 
reduction='mean', + ) + else: + diff = self.crd_beta * R_theta_c - Rc + diff_minus = self.crd_beta * R_theta_c_minus - Rc_minus + ori_policy_loss = 0.5 * (diff ** 2).mean() + 0.5 * (diff_minus ** 2).mean() + + return ori_policy_loss + + # ========================= Optimization ========================= + + def optimize(self, samples: List[BaseSample]) -> None: + """ + CRD optimization loop. + + For each timestep: + 1. Compute velocity predictions from current model, old model, and reference model. + 2. Estimate implicit reward r_theta from prediction errors. + 3. Center both external and implicit rewards (with optional dual-direction centering). + 4. Compute CRD loss matching centered rewards. + 5. Add KL regularization (with optional reward-adaptive scaling). + """ + for inner_epoch in range(self.training_args.num_inner_epochs): + # CRD does not shuffle samples (needs same-prompt grouping for centering) + # Re-group samples into batches + sample_batches: List[Dict[str, Union[torch.Tensor, Any, List[Any]]]] = [ + BaseSample.stack(samples[i:i + self.training_args.per_device_batch_size]) + for i in range(0, len(samples), self.training_args.per_device_batch_size) + ] + + # ==================== Pre-compute: Old V Predictions ==================== + self.adapter.rollout() + with torch.no_grad(), self.autocast(): + for batch in tqdm( + sample_batches, + total=len(sample_batches), + desc=f'Epoch {self.epoch} Pre-computing Old V Predictions', + position=0, + disable=not self.show_progress_bar, + ): + batch_size = batch['all_latents'].shape[0] + clean_latents = batch['all_latents'][:, -1] + + # Sample timesteps: (T, B) in scheduler scale [0, 1000] + all_timesteps = self._sample_timesteps(batch_size) + batch['_all_timesteps'] = all_timesteps + batch['_all_random_noise'] = [] + + # Pre-compute old model predictions + old_v_pred_list = [] + for t_idx in range(self.num_train_timesteps): + t_flat = all_timesteps[t_idx] # (B,) scheduler scale [0, 1000] + sigma_broadcast = to_broadcast_tensor(flow_match_sigma(t_flat), clean_latents) + noise = randn_tensor( + clean_latents.shape, + device=clean_latents.device, + dtype=clean_latents.dtype, + ) + batch['_all_random_noise'].append(noise) + noised_latents = (1 - sigma_broadcast) * clean_latents + sigma_broadcast * noise + + if self.use_old_for_loss: + # Use old model snapshot + with self.adapter.use_named_parameters(self._OLD_PARAMS_NAME): + old_output = self._compute_crd_output(batch, t_flat, noised_latents) + else: + # Use reference model (original weights) + with self.adapter.use_ref_parameters(): + old_output = self._compute_crd_output(batch, t_flat, noised_latents) + old_v_pred_list.append(old_output['noise_pred'].detach()) + + batch['_old_v_pred_list'] = old_v_pred_list + + # ==================== Training Loop ==================== + self.adapter.train() + loss_info = defaultdict(list) + + with self.autocast(): + for batch in tqdm( + sample_batches, + total=len(sample_batches), + desc=f'Epoch {self.epoch} Training', + position=0, + disable=not self.show_progress_bar, + ): + # Retrieve pre-computed data + batch_size = batch['all_latents'].shape[0] + clean_latents = batch['all_latents'][:, -1] + all_timesteps = batch['_all_timesteps'] + all_random_noise = batch['_all_random_noise'] + old_v_pred_list = batch['_old_v_pred_list'] + # Iterate through timesteps + for t_idx in tqdm( + range(self.num_train_timesteps), + desc=f'Epoch {self.epoch} Timestep', + position=1, + leave=False, + disable=not self.show_progress_bar, + ): + with 
self.accelerator.accumulate(*self.adapter.trainable_components): + # 1. Prepare inputs + t_flat = all_timesteps[t_idx] # (B,) scheduler scale [0, 1000] + sigma_broadcast = to_broadcast_tensor(flow_match_sigma(t_flat), clean_latents) + noise = all_random_noise[t_idx] + noised_latents = (1 - sigma_broadcast) * clean_latents + sigma_broadcast * noise + old_v_pred = old_v_pred_list[t_idx] + v_target = noise - clean_latents + + # 2. Current model forward pass + output = self._compute_crd_output(batch, t_flat, noised_latents) + forward_pred = output['noise_pred'] + + # 3. Reference model forward pass (for KL) + # If kl_cfg > 1.0, the adapter's forward() will do CFG automatically: + # it concatenates [neg_embeds, pos_embeds] and computes: + # noise_pred = uncond + kl_cfg * (cond - uncond) + # The negative embeddings come from the batch (negative_prompt_embeds, + # negative_pooled_prompt_embeds stored by SD3_5Sample during rollout). + with torch.no_grad(), self.adapter.use_ref_parameters(): + cfg = self.kl_cfg if self.kl_cfg > 1.0 else None + ref_output = self._compute_crd_output(batch, t_flat, noised_latents, guidance_scale=cfg) + ref_pred = ref_output['noise_pred'] + + # 4. Compute implicit reward: r_theta = -(||pred_theta - v_target||^2 - ||pred_old - v_target||^2) + if self.adaptive_logp: + with torch.no_grad(): + weight_theta = ( + torch.abs(forward_pred.double() - v_target.double()) + .mean(dim=tuple(range(1, forward_pred.ndim)), keepdim=True) + .clip(min=1e-5) + ) + weight_old = ( + torch.abs(old_v_pred.double() - v_target.double()) + .mean(dim=tuple(range(1, old_v_pred.ndim)), keepdim=True) + .clip(min=1e-5) + ) + r_theta = -( + (forward_pred - v_target) ** 2 / weight_theta + - (old_v_pred - v_target) ** 2 / weight_old + ) + else: + r_theta = -( + (forward_pred - v_target) ** 2 + - (old_v_pred - v_target) ** 2 + ) + + # Reduce spatial dims to per-sample scalar + r_theta_local = r_theta.mean(dim=tuple(range(1, r_theta.ndim))) + + # Gather r_theta across all GPUs for centering + r_theta_gathered = self.accelerator.gather(r_theta_local.detach()).to( + self.accelerator.device + ) + + # 5. Compute advantages for CRD centering + adv = batch['advantage'] + adv_clip_range = self.training_args.adv_clip_range + adv_clipped = torch.clamp(adv, adv_clip_range[0], adv_clip_range[1]) + + # Normalize to [0, 1] + normalized_adv = (adv_clipped / max(adv_clip_range)) / 2.0 + 0.5 + adv_cur_rank = torch.clamp(normalized_adv, 0, 1) + + # Gather advantages across all GPUs + adv_cur = self.accelerator.gather(adv_cur_rank.detach()).to( + self.accelerator.device + ) + + # 6. Centered Reward Distillation loss (supports dual-direction centering) + ori_policy_loss = self._compute_crd_loss( + adv_cur=adv_cur, + adv_cur_rank=adv_cur_rank, + r_theta_gathered=r_theta_gathered, + r_theta_local=r_theta_local, + ) + + # Scale by adv_clip_max / beta for gradient magnitude normalization + policy_loss = (ori_policy_loss * adv_clip_range[1] / max(self.crd_beta, 1e-8)).mean() + loss = policy_loss + + # 7. KL regularization against reference model + kl_div = ((forward_pred - ref_pred) ** 2).mean( + dim=tuple(range(1, forward_pred.ndim)) + ) + + if self.reward_adaptive_kl: + # Linearly scale KL based on reward value + raw_reward = adv_cur_rank # Already in [0, 1] + base_beta = 1e-4 + min_coef = base_beta / max(self.kl_beta, 1e-8) + kl_loss = self.kl_beta * torch.mean((min_coef + raw_reward * (1 - min_coef)) * kl_div) + else: + kl_loss = self.kl_beta * kl_div.mean() + + loss = loss + kl_loss + + # 8. 
Logging + loss_info['policy_loss'].append(policy_loss.detach()) + loss_info['unweighted_policy_loss'].append(ori_policy_loss.mean().detach()) + loss_info['kl_div'].append(kl_div.mean().detach()) + loss_info['kl_loss'].append(kl_loss.detach()) + loss_info['r_theta_mean'].append(r_theta_local.mean().detach()) + loss_info['loss'].append(loss.detach()) + + if self.use_old_for_loss: + old_kl = ((old_v_pred - ref_pred) ** 2).mean( + dim=tuple(range(1, old_v_pred.ndim)) + ).mean() + loss_info['old_kl_div'].append(old_kl.detach()) + old_deviate = ((forward_pred - old_v_pred) ** 2).mean() + loss_info['old_deviate'].append(old_deviate.detach()) + + # 9. Backward and optimizer step + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + grad_norm = self.accelerator.clip_grad_norm_( + self.adapter.get_trainable_parameters(), + self.training_args.max_grad_norm, + ) + self.optimizer.step() + self.optimizer.zero_grad() + # Log accumulated loss info + loss_info = reduce_loss_info(self.accelerator, loss_info) + loss_info['grad_norm'] = grad_norm + self.log_data( + {f'train/{k}': v for k, v in loss_info.items()}, + step=self.step, + ) + self.step += 1 + loss_info = defaultdict(list) diff --git a/src/flow_factory/trainers/registry.py b/src/flow_factory/trainers/registry.py index 8e6b5085..19b1618a 100644 --- a/src/flow_factory/trainers/registry.py +++ b/src/flow_factory/trainers/registry.py @@ -32,6 +32,7 @@ 'nft': 'flow_factory.trainers.nft.DiffusionNFTTrainer', 'awm': 'flow_factory.trainers.awm.AWMTrainer', 'dpo': 'flow_factory.trainers.dpo.DPOTrainer', + 'crd': 'flow_factory.trainers.crd.CRDTrainer', }