6 个文件 · 650 行核心代码 · 完整实现 On-Policy Distillation
对比 Swift (72K+ 行)、LlamaFactory (45K+ 行)、SLIME (30K+ 行) 的 OPD 实现, FlashOPD 用 不到 1/50 的代码量 提供了 同等完整度 的 On-Policy Distillation 能力。
| 特性 | FlashOPD | Swift | LlamaFactory | EasyDistill |
|---|---|---|---|---|
| OPD 核心代码量 | ~650 行 | 72K+ 行 | 45K+ 行 | 15K+ 行 |
| 安装时间 | < 30 秒 (源码) | 5-10 分钟 | 3-5 分钟 | 3-5 分钟 |
| 学习曲线 | 读完 6 个文件 | 需理解框架体系 | 需理解框架体系 | 需理解多模块 |
| 单文件可读 | CleanRL 风格 | 分散在数十个模块 | 分散在数十个模块 | 多层封装 |
| API Teacher | vLLM OpenAI API | 需自行集成 | 需自行集成 | 部分支持 |
| KL 类型 | forward/reverse/JSD | 部分 | 部分 | 部分 |
| 动态 Loss 平衡 | EMA 自适应 | 手动 | 手动 | 手动 |
| DeepSpeed/FSDP | 原生支持 | 原生支持 | 原生支持 | 原生支持 |
FlashOPD 融合了五个顶级代码库的设计精华:
- CleanRL 的极简主义:每个文件自成一体,copy-paste 即可 hack
- Tianshou 的抽象能力:干净的接口分离(Teacher / Loss / Rollout / Balancer)
- EasyDistill 的易用性:一份 YAML 搞定所有配置
- SDPO 的 LLM 适配:原生支持 KV-cache rollout、分块 KL、API teacher
- OPSD 的学术严谨:完整的 forward/reverse/JSD KL、温度缩放、top-k sparse
git clone https://github.com/china10s/flashopd.git && cd flashopd
pip install -e . # 开发模式,改代码立即生效
pip install -e ".[all]" # 全量安装(含 LoRA + DeepSpeed + API)from flashopd import OPDConfig, OPDTrainer, create_teacher
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset
cfg = OPDConfig(
student_model="Qwen/Qwen2.5-1.5B-Instruct",
teacher_model="Qwen/Qwen2.5-7B-Instruct",
kl_type="reverse", kl_coef=0.1, max_new_tokens=256,
)
tokenizer = AutoTokenizer.from_pretrained(cfg.student_model, trust_remote_code=True)
student = AutoModelForCausalLM.from_pretrained(cfg.student_model, torch_dtype="auto", trust_remote_code=True)
teacher = create_teacher(cfg, student_tokenizer=tokenizer)
trainer = OPDTrainer(
opd_config=cfg, teacher=teacher, model=student,
args=TrainingArguments(output_dir="./out", num_train_epochs=1, bf16=True),
train_dataset=load_dataset("json", data_files="data.jsonl", split="train"),
tokenizer=tokenizer,
)
trainer.train()flashopd --config configs/default.yaml# 单机多卡
NUM_GPUS=8 bash scripts/run_distill.sh configs/default.yaml
# DeepSpeed ZeRO-2
bash scripts/run_distill_deepspeed.sh configs/default.yaml configs/ds_zero2.jsonflashopd/
├── config.py (100 行) 所有超参集中管理,支持 YAML 读写
├── loss.py (137 行) KL 散度:forward / reverse / JSD / sparse / top-k
├── rollout.py ( 91 行) KV-cache 加速的 student 自回归生成
├── teacher.py (177 行) Teacher 抽象:本地模型 / vLLM API 统一接口
├── balancer.py ( 56 行) CE+KL 动态平衡:fixed / EMA 自适应
├── trainer.py (190 行) 核心 Trainer:粘合 rollout→teacher→loss→balance
├── train.py (110 行) 完整训练流程:模型加载→数据→训练→保存
└── cli.py ( 51 行) 命令行入口
核心蒸馏逻辑仅 651 行,每个文件可独立阅读和修改。
FlashOPD 将所有 OPD 超参收进一个 OPDConfig dataclass,
对应你的 shell 脚本配置 一一映射:
| Shell 变量 | OPDConfig 字段 | 说明 |
|---|---|---|
OPD_CE_COEF |
ce_coef |
CE loss 系数 |
OPD_KL_COEF |
kl_coef |
KL loss 系数 |
OPD_LOSS_BALANCE |
loss_balance |
fixed / dynamic |
OPD_KL_TYPE |
kl_type |
reverse / forward / jsd |
OPD_TEMPERATURE |
temperature |
softmax 温度 |
OPD_MAX_NEW_TOKENS |
max_new_tokens |
Rollout 最大长度 |
OPD_KL_MIN |
kl_clip_min |
KL 下界截断 |
OPD_KL_MAX |
kl_clip_max |
KL 上界截断 |
OPD_KL_TOP_K |
kl_top_k |
KL top-k token |
OPD_ROLLOUT_TOP_K |
rollout_top_k |
Rollout top-k 采样 |
OPD_ROLLOUT_TOP_P |
rollout_top_p |
Rollout top-p 核采样 |
OPD_ROLLOUT_TEMPERATURE |
rollout_temperature |
Rollout 采样温度 |
OPD_DISABLE_AFTER_RATIO |
disable_after_ratio |
OPD 关闭阈值 |
TEACHER_MODEL_PATH |
teacher_model |
Teacher 模型路径 |
OPD_TEACHER_API_URL |
teacher_api_url |
vLLM API 地址 |
OPD_TEACHER_API_MODEL |
teacher_api_model |
API 模型名 |
OPD_TEACHER_API_LOGPROBS |
teacher_api_logprobs |
API logprobs 数量 |
rollout.py: student_rollout()
├── 首次 forward: 处理完整 prompt,缓存 KV
├── 逐 token 生成: 每步只 forward 1 个 token embedding
└── 支持 greedy / top-k / top-p / temperature 采样
对 (B, L, V) 的 logits 沿 sequence 维度分块计算 KL, 避免在长序列 + 大词表场景下 OOM。默认 chunk_size=128。
TeacherBackend (抽象基类)
├── LocalTeacher → 同进程加载,返回完整 logits
└── APITeacher → vLLM OpenAI API,返回 sparse logprobs
两种后端对 Trainer 暴露统一接口,切换只需改一个配置。
CE 和 KL 的量级可能差异很大,dynamic 模式用 EMA 跟踪两者的
running average,自动缩放 KL 权重使两者贡献对齐。
在 loss.py 中添加函数,在 trainer.py 的 compute_loss 中路由即可。
继承 TeacherBackend,实现 get_logits 或 get_sparse_logprobs。
OPDTrainer 继承自 HuggingFace Trainer,
你可以 override compute_loss 来注入任意前处理逻辑。
Apache-2.0