
feat(recipes): add per-step learning rate support to all training recipes#359

Open
renfeichen-fw wants to merge 1 commit into `main` from `feat/per-step-lr-schedule`

Conversation

@renfeichen-fw
Contributor

Description

Add per-step learning rate control to all cookbook training recipes (SFT, DPO, RL, ORPO).

The primary abstraction is lr_per_step: list[float] | None — an explicit list of LR values, one per optimizer step. Customers pass this list to recipes (they compute it however they want). The managed service can auto-generate it from schedule params (lr_schedule, warmup_ratio, min_lr_ratio) after data loading.

Architecture / Code Overview Diagram

```mermaid
flowchart TD
    subgraph customerPath["Customer Path"]
        CustCode["Customer Code"] -->|"lr_per_step=[...]"| Config["Recipe Config"]
    end
    subgraph managedPath["Managed / Schedule Path"]
        Params["lr_schedule + warmup_ratio + min_lr_ratio"] --> Builder["build_lr_per_step()"]
        Builder -->|"[0.0, 1e-6, ..., 1e-4, ..., 1e-5]"| Config
    end
    Config --> Loop["Recipe Training Loop"]
    Loop -->|"step i: AdamParams(lr=lr_per_step[i])"| Trainer["Tinker optim_step"]
```
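For reference, `build_lr_per_step()` could look roughly like the sketch below. The schedule parameter names (`lr_schedule`, `warmup_ratio`, `min_lr_ratio`) come from this PR; the exact signature, warmup shape, and decay formulas are assumptions:

```python
import math


def build_lr_per_step(
    base_lr: float,
    total_steps: int,
    lr_schedule: str = "constant",
    warmup_ratio: float = 0.0,
    min_lr_ratio: float = 0.0,
) -> list[float]:
    """Return one LR value per optimizer step (illustrative sketch)."""
    warmup_steps = int(total_steps * warmup_ratio)
    min_lr = base_lr * min_lr_ratio
    lrs: list[float] = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Linear warmup from 0 up to base_lr.
            lrs.append(base_lr * step / warmup_steps)
            continue
        # Progress through the post-warmup decay phase, in [0, 1].
        denom = max(total_steps - warmup_steps - 1, 1)
        progress = (step - warmup_steps) / denom
        if lr_schedule == "constant":
            lrs.append(base_lr)
        elif lr_schedule == "linear":
            lrs.append(min_lr + (base_lr - min_lr) * (1 - progress))
        elif lr_schedule == "cosine":
            lrs.append(
                min_lr
                + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
            )
        else:
            raise ValueError(f"unknown lr_schedule: {lr_schedule}")
    return lrs
```

With `warmup_ratio=0.2` and `min_lr_ratio=0.1`, a cosine schedule over 10 steps ramps from 0 to `base_lr` and decays to `0.1 * base_lr`, matching the shape shown in the diagram.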

Type of Change

  • Bug fix
  • New feature
  • Breaking change
  • Refactoring
  • Documentation
  • Infrastructure/DevOps

Testing

  • Added/updated tests
  • Tested manually
  • No testing needed

Unit tests: 18 tests in test_lr_schedule.py covering constant/cosine/linear schedules, warmup, min_lr_ratio, edge cases, and resolve_step_lr fallback.

Smoke test: Verified E2E on 8xH200 with Qwen3.5-35B-A3B — 5 optimizer steps with different LRs (1e-5, 5e-5, 1e-4, 7e-5, 5e-6) and varying grad accumulation counts (1, 2, 3). All steps completed, reported LR metrics matched exactly.

Import smoke tests: All 80 tests pass, including Config defaults and `__all__` resolution.

Surface Consistency

  • No customer-facing surface impact
  • Related surfaces checked — all consistent or follow-up filed
  • Inline "keep in sync" comments followed

Deployment Notes

  • Requires database migration
  • Requires config/env changes
  • Requires Terraform/K8s changes
  • No special deployment considerations

Change Size

  • Small (< 200 LOC)
  • Medium (200–999 LOC)
  • Large (≥ 1,000 LOC)

Checklist

  • Agent-reviewed the diff before committing
  • Self-reviewed my code
  • Change is the minimum necessary diff
  • Added tests for my changes
  • Updated relevant documentation
  • No new linter warnings/errors
  • No secrets or credentials in the diff
  • Checked surface consistency for customer-facing changes
  • Visual diagram included

Files Changed

| File | What changed |
| --- | --- |
| `training/utils/lr_schedule.py` | New `build_lr_per_step()` and `resolve_step_lr()` |
| `training/utils/__init__.py` | Export new symbols |
| `training/recipes/sft_loop.py` | `lr_per_step` + schedule fields, per-step LR in loop |
| `training/recipes/dpo_loop.py` | Same; removed static `adam_params` from `_train_loop` |
| `training/recipes/rl_loop.py` | Same; tail fallback for dynamic step counts |
| `training/recipes/orpo_loop.py` | `lr_per_step` override; replaced inline `_compute_lr` with the shared utility |
| `training/tests/unit/test_lr_schedule.py` | New — 18 unit tests |
| `training/tests/unit/test_dpo_loop.py` | Updated test mocks for the new `_train_loop` signature |

Made with Cursor

…ipes

Add `lr_per_step: list[float] | None` field to SFT, DPO, RL, and ORPO
recipe Configs. When set, each optimizer step uses `lr_per_step[step]`
instead of a flat `learning_rate`. Falls back to the last value if the
list is shorter than the actual step count.

Also adds `lr_schedule`, `warmup_ratio`, and `min_lr_ratio` convenience
fields that auto-generate `lr_per_step` from schedule params after data
loading (when the caller doesn't supply an explicit list).

Shared utility `build_lr_per_step()` extracted from ORPO's `_compute_lr()`
into `training/utils/lr_schedule.py` so all recipes and the managed
orchestrator share the same schedule logic.

All new fields default to backward-compatible values (None / "constant"
/ 0.0). No proto, Go, or trainer backend changes required.

Made-with: Cursor

```python
beta: float = 0.1
learning_rate: float = 1e-5
lr_per_step: list[float] | None = None
```

Why is this a list? I feel that nobody gives a list of LR precomputed per step as input...
