Skip to content

china10s/flash-opd

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

30 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

FlashOPD

6 个文件 · 650 行核心代码 · 完整实现 On-Policy Distillation

对比 Swift (72K+ 行)、LlamaFactory (45K+ 行)、SLIME (30K+ 行) 的 OPD 实现, FlashOPD 用 不到 1/50 的代码量 提供了 同等完整度 的 On-Policy Distillation 能力。


Why FlashOPD?

特性 FlashOPD Swift LlamaFactory EasyDistill
OPD 核心代码量 ~650 行 72K+ 行 45K+ 行 15K+ 行
安装时间 < 30 秒 (源码) 5-10 分钟 3-5 分钟 3-5 分钟
学习曲线 读完 6 个文件 需理解框架体系 需理解框架体系 需理解多模块
单文件可读 CleanRL 风格 分散在数十个模块 分散在数十个模块 多层封装
API Teacher vLLM OpenAI API 需自行集成 需自行集成 部分支持
KL 类型 forward/reverse/JSD 部分 部分 部分
动态 Loss 平衡 EMA 自适应 手动 手动 手动
DeepSpeed/FSDP 原生支持 原生支持 原生支持 原生支持

设计哲学

FlashOPD 融合了五个顶级代码库的设计精华:

  • CleanRL 的极简主义:每个文件自成一体,copy-paste 即可 hack
  • Tianshou 的抽象能力:干净的接口分离(Teacher / Loss / Rollout / Balancer)
  • EasyDistill 的易用性:一份 YAML 搞定所有配置
  • SDPO 的 LLM 适配:原生支持 KV-cache rollout、分块 KL、API teacher
  • OPSD 的学术严谨:完整的 forward/reverse/JSD KL、温度缩放、top-k sparse

Quick Start

安装(< 30 秒)

git clone https://github.com/china10s/flashopd.git && cd flashopd
pip install -e .               # 开发模式,改代码立即生效
pip install -e ".[all]"        # 全量安装(含 LoRA + DeepSpeed + API)

20 行代码启动蒸馏

from flashopd import OPDConfig, OPDTrainer, create_teacher
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset

cfg = OPDConfig(
    student_model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    kl_type="reverse", kl_coef=0.1, max_new_tokens=256,
)
tokenizer = AutoTokenizer.from_pretrained(cfg.student_model, trust_remote_code=True)
student = AutoModelForCausalLM.from_pretrained(cfg.student_model, torch_dtype="auto", trust_remote_code=True)
teacher = create_teacher(cfg, student_tokenizer=tokenizer)

trainer = OPDTrainer(
    opd_config=cfg, teacher=teacher, model=student,
    args=TrainingArguments(output_dir="./out", num_train_epochs=1, bf16=True),
    train_dataset=load_dataset("json", data_files="data.jsonl", split="train"),
    tokenizer=tokenizer,
)
trainer.train()

或者,一行命令

flashopd --config configs/default.yaml

分布式训练

# 单机多卡
NUM_GPUS=8 bash scripts/run_distill.sh configs/default.yaml

# DeepSpeed ZeRO-2
bash scripts/run_distill_deepspeed.sh configs/default.yaml configs/ds_zero2.json

Architecture

flashopd/
├── config.py      (100 行)  所有超参集中管理,支持 YAML 读写
├── loss.py        (137 行)  KL 散度:forward / reverse / JSD / sparse / top-k
├── rollout.py     ( 91 行)  KV-cache 加速的 student 自回归生成
├── teacher.py     (177 行)  Teacher 抽象:本地模型 / vLLM API 统一接口
├── balancer.py    ( 56 行)  CE+KL 动态平衡:fixed / EMA 自适应
├── trainer.py     (190 行)  核心 Trainer:粘合 rollout→teacher→loss→balance
├── train.py       (110 行)  完整训练流程:模型加载→数据→训练→保存
└── cli.py         ( 51 行)  命令行入口

核心蒸馏逻辑仅 651 行,每个文件可独立阅读和修改。


全部配置项

FlashOPD 将所有 OPD 超参收进一个 OPDConfig dataclass, 对应你的 shell 脚本配置 一一映射

Shell 变量 OPDConfig 字段 说明
OPD_CE_COEF ce_coef CE loss 系数
OPD_KL_COEF kl_coef KL loss 系数
OPD_LOSS_BALANCE loss_balance fixed / dynamic
OPD_KL_TYPE kl_type reverse / forward / jsd
OPD_TEMPERATURE temperature softmax 温度
OPD_MAX_NEW_TOKENS max_new_tokens Rollout 最大长度
OPD_KL_MIN kl_clip_min KL 下界截断
OPD_KL_MAX kl_clip_max KL 上界截断
OPD_KL_TOP_K kl_top_k KL top-k token
OPD_ROLLOUT_TOP_K rollout_top_k Rollout top-k 采样
OPD_ROLLOUT_TOP_P rollout_top_p Rollout top-p 核采样
OPD_ROLLOUT_TEMPERATURE rollout_temperature Rollout 采样温度
OPD_DISABLE_AFTER_RATIO disable_after_ratio OPD 关闭阈值
TEACHER_MODEL_PATH teacher_model Teacher 模型路径
OPD_TEACHER_API_URL teacher_api_url vLLM API 地址
OPD_TEACHER_API_MODEL teacher_api_model API 模型名
OPD_TEACHER_API_LOGPROBS teacher_api_logprobs API logprobs 数量

Key Design Decisions

1. KV-Cache Rollout(比 naive forward 快 5-10x)

rollout.py: student_rollout()
├── 首次 forward: 处理完整 prompt,缓存 KV
├── 逐 token 生成: 每步只 forward 1 个 token embedding
└── 支持 greedy / top-k / top-p / temperature 采样

2. 分块 KL 计算(避免 OOM)

对 (B, L, V) 的 logits 沿 sequence 维度分块计算 KL, 避免在长序列 + 大词表场景下 OOM。默认 chunk_size=128。

3. Teacher 后端解耦

TeacherBackend (抽象基类)
├── LocalTeacher   → 同进程加载,返回完整 logits
└── APITeacher     → vLLM OpenAI API,返回 sparse logprobs

两种后端对 Trainer 暴露统一接口,切换只需改一个配置。

4. 动态 Loss 平衡

CE 和 KL 的量级可能差异很大,dynamic 模式用 EMA 跟踪两者的 running average,自动缩放 KL 权重使两者贡献对齐。


Extending FlashOPD

添加新的 KL 类型

loss.py 中添加函数,在 trainer.pycompute_loss 中路由即可。

自定义 Teacher

继承 TeacherBackend,实现 get_logitsget_sparse_logprobs

自定义数据流

OPDTrainer 继承自 HuggingFace Trainer, 你可以 override compute_loss 来注入任意前处理逻辑。


License

Apache-2.0

About

6 files · 650 lines of core code · Full implementation of On-Policy Distillation

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors