A framework-agnostic training and evaluation harness for continual learning benchmarks. Train language models on text corpora and evaluate memorization via QA — using Kauldron (JAX) or HuggingFace/TRL (PyTorch) as the backend, with the same config, same data, and same metrics.
Pick exactly one backend before you start. The repo supports Kauldron (JAX) and HuggingFace/TRL (PyTorch); the rest of this README is organized so the same YAML works for either, but the conda env you create installs only one of them. Mixing both in a single env isn't supported. If you want to try both, create two envs (e.g. cbe-kd and cbe-hf); see Environment Setup.
- TL;DR — End-to-end first run
- Environment Setup
- Authentication (HuggingFace + W&B)
- Downloading Data
- Formatting Data
- Training recipes
- Training
- Evaluation
- Logging and Monitoring
- Output Layout
- Adding a New Track
- Adding a New Model
- Known Limitations
If you just want to train Gemma3-1B-LoRA on the news track end-to-end with sane defaults, this is the full required path. Anything not in this list (formatting raw data, custom configs, sweeps, etc.) is optional and explained later.
# 1) Clone + install ONE backend (pick torch-gpu OR jax-gpu, not both)
git clone <repo-url>
cd ContinuousBenchEval
bash setup_env.sh torch-gpu wandb # for HF/TRL — env name: "cbe"
# (or) bash setup_env.sh jax-gpu wandb # for Kauldron — env name: "cbe"
# 2) Activate the env (every new shell needs this)
conda activate cbe
# 3) Get access to gated Gemma weights (one-time, on the HuggingFace website)
# Visit https://huggingface.co/google/gemma-3-1b-pt and click "Agree and access"
# (do this for every Gemma checkpoint you want to use: 1b, 4b, 12b, etc.)
# 4) Authenticate to HuggingFace + (optionally) W&B
hf auth login # paste a read token
wandb login # paste your W&B API key (skip if not using W&B)
# 5) Pull benchmark data (one-time per track)
python data/helper/load_data.py --track news # → data/news/{train,val,valqa,testqa}.jsonl
# 6) Format the news corpus (REQUIRED for the news task — see "Formatting Data")
python data/helper/format_news.py --input data/news/train.jsonl --output data/news/train.jsonl --overwrite
python data/helper/format_news.py --input data/news/val.jsonl --output data/news/val.jsonl --overwrite
# 7) Train
python train.py --config configs/tracks/news_gemma3_1b_lora128.yaml --framework hf
# (or --framework kd if you installed jax-gpu)
That's all of the required steps. Default configs already specify model, batch sizes, learning rate, eval cadence, etc., so you don't have to touch any YAML unless you want to. The remaining sections describe what each piece does and how to customize it.
You must pick exactly one of torch-gpu, jax-gpu, or jax-tpu per env. They install conflicting frameworks. If you want to try both backends, create two separate envs (different env_name).
git clone <repo-url>
cd ContinuousBenchEval
# Pick ONE of the following — each creates a fresh conda env named "cbe":
bash setup_env.sh torch-gpu # HuggingFace / TRL on GPU
bash setup_env.sh jax-gpu # Kauldron on GPU
bash setup_env.sh jax-tpu # Kauldron on TPU
Usage: setup_env.sh <backend> [extras] [env_name]
# 2nd arg = "wandb" → also installs Weights & Biases support (any backend)
bash setup_env.sh torch-gpu wandb
bash setup_env.sh jax-gpu wandb
# 3rd arg = custom env name (any backend; must pass empty 2nd arg if no extras)
bash setup_env.sh torch-gpu "" cbe-hf # HF env named "cbe-hf"
bash setup_env.sh jax-gpu "" cbe-kd # KD env named "cbe-kd"
bash setup_env.sh jax-gpu wandb cbe-kd # both wandb + custom name
Each invocation creates a fresh conda env with Python 3.11 and all backend-specific dependencies.
Don't forget to activate it. Every new terminal shell needs conda activate <env_name> (default: cbe) before running any of the train/eval/data-loader commands in this README.
| Backend | torch | jax | kauldron | gemma | trl/peft | Key pins |
|---|---|---|---|---|---|---|
| torch-gpu | 2.4-2.5 (cu124) | - | - | - | trl, peft<0.15 | setuptools<81 |
| jax-gpu | - | 0.8.2 (cuda12) | 1.3.0 | latest | - | typeguard==4.4.1, setuptools<81 |
| jax-tpu | - | latest (tpu) | 1.3.0 | latest | - | typeguard==4.4.1, setuptools<81 |
- KD on GPU: JAX auto-discovers all visible GPUs. FSDP shards params across them. No special launcher needed.
- KD on TPU: Native JAX, handles sharding automatically.
- HF on single GPU: python train.py --config ... --framework hf
- HF on multi-GPU: torchrun --nproc_per_node=N train.py --config ... --framework hf (DDP)
- GPU selection: CUDA_VISIBLE_DEVICES=0,1 python train.py ...
setup_env.sh registers a conda activation hook that puts pip-installed NVIDIA libs on LD_LIBRARY_PATH. If you create the env manually, you may need:
export LD_LIBRARY_PATH=$(find $CONDA_PREFIX/lib/python3.11/site-packages/nvidia -name lib -type d | tr '\n' ':')$LD_LIBRARY_PATH
Two services need credentials before training works. Both are one-time per machine; tokens persist to disk.
Gemma model weights are gated on HuggingFace. You must:
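- Click "Agree and access" once per Gemma model on the HF website. The repo defaults to Gemma3, so visit at minimum the page for the checkpoint you plan to train (for example https://huggingface.co/google/gemma-3-1b-pt for the 1B model).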
- Get a read token at https://huggingface.co/settings/tokens
- Persist it locally so subprocesses can read it:
hf auth login # paste token interactively (writes ~/.cache/huggingface/token)
# or:
export HF_TOKEN=hf_... # add to ~/.bashrc to make it permanent
The same token is used by data/helper/load_data.py to pull benchmark data and by the trainer to download Gemma weights at runtime. Without it, you'll see 401 Unauthorized or GatedRepoError when training starts.
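Before launching a long run, it can help to confirm that a token is actually visible. The following is a minimal sketch, not part of the repo: `find_hf_token` is a hypothetical helper that checks only the two locations named above (the HF_TOKEN variable and ~/.cache/huggingface/token), and it assumes the cache directory has not been relocated.

```python
import os
from pathlib import Path


def find_hf_token(env=None, token_file=None):
    """Return the token from HF_TOKEN or the cached token file, else None.

    Hypothetical helper: only the two locations mentioned in this README
    are checked.
    """
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if token:
        return token
    if token_file:
        path = Path(token_file)
    else:
        path = Path.home() / ".cache" / "huggingface" / "token"
    if path.is_file():
        # An empty file counts as "no token".
        return path.read_text().strip() or None
    return None


if __name__ == "__main__":
    if find_hf_token():
        print("HuggingFace token found")
    else:
        print("No token: run `hf auth login` or export HF_TOKEN")
```

A found token does not prove that the Gemma license was accepted; a GatedRepoError can still appear if the "Agree and access" step was skipped.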
If you installed the wandb extra (bash setup_env.sh <backend> wandb):
wandb login # paste your API key from https://wandb.ai/authorize
The credential persists to ~/.netrc. Runs land at wandb.ai/<your-username>/<project_name>, where project_name comes from the YAML config. Skip this entirely if you only want TensorBoard.
Benchmark data is hosted on HuggingFace:
- ContinuousBench/News (tag v5): news articles + QA
- ContinuousBench/Geminon (tag v9): Geminon articles + QA
The downloader script lives at data/helper/load_data.py, next to its recipe data/helper/download.yaml. Make sure you've run hf auth login first (see Authentication).
# Always run from the repo root, NOT from data/helper/.
# (output paths are repo-root-relative)
# Just one track (recommended — pass --track explicitly)
python data/helper/load_data.py --track news
python data/helper/load_data.py --track geminon
# Download all tracks listed in the recipe (no --track flag)
python data/helper/load_data.py
# Override corpus / QA size (small/medium/large where supported)
python data/helper/load_data.py --track geminon --corpus large --qa medium
# Debug: list every file in the HF repo for a track
python data/helper/load_data.py --list news
python data/helper/load_data.py --list geminon
The download recipe (data/helper/download.yaml) maps HF repo paths to local filenames. After running the loader, files always land at:
data/<track>/train.jsonl
data/<track>/val.jsonl
data/<track>/valqa.jsonl
data/<track>/testqa.jsonl
Files are written to data/<track>/, not data/helper/<track>/. If you see them in helper/, you're running an out-of-date version of the script — re-pull main. The track YAML configs hard-code these data/<track>/... paths, so they only work after the loader has run.
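Because the track configs depend on these exact paths, a quick pre-flight check can catch a missing download before training starts. This is a sketch only: `missing_track_files` is a hypothetical helper, not something the repo ships, and it assumes it is run from (or pointed at) the repo root.

```python
from pathlib import Path

# The four files the loader is documented to write for each track.
EXPECTED_FILES = ("train.jsonl", "val.jsonl", "valqa.jsonl", "testqa.jsonl")


def missing_track_files(track, repo_root="."):
    """Return the expected data files that are absent under data/<track>/."""
    track_dir = Path(repo_root) / "data" / track
    return [name for name in EXPECTED_FILES if not (track_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_track_files("news")
    if missing:
        print("Missing files, re-run load_data.py:", ", ".join(missing))
    else:
        print("All news files present")
```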
The news data on HuggingFace ships as multi-column JSONL (url, hostname, title, date, crawl_date, language, text). The train.py data pipeline expects a single-column {"text": "Title: ...\nDate: ...\nArticle: ..."} shape. So after load_data.py you must run the formatter before the news track will train correctly.
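To make the target shape concrete, here is a rough sketch of the per-record transformation. It illustrates the documented input columns and output shape only; it is not the code in data/helper/format_news.py, and details such as how missing fields are handled are assumptions.

```python
import json


def format_record(record):
    """Collapse a multi-column news record into the {"text": ...} shape.

    Assumption: absent fields become empty strings; the shipped formatter
    may treat them differently.
    """
    text = (
        f"Title: {record.get('title', '')}\n"
        f"Date: {record.get('date', '')}\n"
        f"Article: {record.get('text', '')}"
    )
    return {"text": text}


def format_lines(lines):
    """Yield one formatted JSON line per non-empty input JSONL line."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.dumps(format_record(json.loads(line)), ensure_ascii=False)
```

The other columns (url, hostname, crawl_date, language) are dropped in this sketch because the documented target shape contains only title, date, and article text.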
# In-place rewrite (recommended — keeps the original filenames)
python data/helper/format_news.py --input data/news/train.jsonl --output data/news/train.jsonl --overwrite
python data/helper/format_news.py --input data/news/val.jsonl --output data/news/val.jsonl --overwrite
# OR write to a new file and update train_path / val_path in the config
python data/helper/format_news.py --input data/news/train.jsonl --output data/news/train_formatted.jsonl
The QA files (valqa.jsonl, testqa.jsonl) do not need formatting; they're already in the right shape.
For raw / dirty input (not from ContinuousBench), pass --normalize. ContinuousBench/News is pre-cleaned during curation, so --normalize is a no-op on it.
Geminon data ships pre-formatted; the loader output is ready to train on directly.
The shipped track configs in configs/tracks/ are recipes — one YAML per (task × model × adapter) combo, with sensible defaults already baked in. You almost never need to touch these. Just pick one and run it.
# Available recipes
ls configs/tracks/
# geminon_gemma3_1b_full.yaml news_gemma3_1b_full.yaml
# geminon_gemma3_1b_lora128.yaml news_gemma3_1b_lora128.yaml
# geminon_gemma3_4b_full.yaml news_gemma3_4b_full.yaml
# geminon_gemma3_4b_lora128.yaml news_gemma3_4b_lora128.yaml
# Run one
python train.py --config configs/tracks/news_gemma3_1b_lora128.yaml --framework hf
# (or --framework kd if you installed jax-gpu)
The recipe's filename tells you everything: <task>_<model>_<adapter>.yaml. Each one inherits shared defaults from configs/base/{tasks,models}/, so the track file itself stays tiny (data + a few overrides like run name).
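The naming convention is regular enough to parse mechanically, which is handy when scripting over several recipes. The sketch below is illustrative: `parse_recipe_name` is a hypothetical helper, and it assumes the task and adapter parts contain no underscores, which holds for the shipped files listed above.

```python
from pathlib import Path


def parse_recipe_name(path):
    """Split '<task>_<model>_<adapter>.yaml' into its three parts.

    The model is taken to be everything between the first and last
    underscore, so 'gemma3_1b' stays in one piece.
    """
    stem = Path(path).stem
    task, rest = stem.split("_", 1)
    model, adapter = rest.rsplit("_", 1)
    return {"task": task, "model": model, "adapter": adapter}


if __name__ == "__main__":
    print(parse_recipe_name("configs/tracks/news_gemma3_1b_lora128.yaml"))
```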
If you do need to tweak something: any field can be overridden from the CLI (--override optimizer.lr=1e-4, etc.), or you can edit the recipe directly. The two batch-size knobs to know about are effective_batch_size (real samples per optimizer step; same meaning everywhere) and per_device_batch_size (memory knob; lower it if you OOM, raise it for fewer grad-accum steps). Defaults fit 1× or 2× 40 GB A100 for most (model, adapter) pairs. To lower it on the fly:
python train.py --config configs/tracks/news_gemma3_4b_full.yaml --framework hf \
  --override training.per_device_batch_size=2
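The interplay of the two batch-size knobs can be sketched as simple arithmetic. This assumes the usual relation effective_batch_size = per_device_batch_size × number of devices × gradient-accumulation steps; the harness may round or validate differently, and `grad_accum_steps` is a hypothetical name used only for illustration.

```python
def grad_accum_steps(effective_batch_size, per_device_batch_size, num_devices):
    """Micro-batches to accumulate so one optimizer step sees
    effective_batch_size samples (assumed relation, see lead-in)."""
    samples_per_micro_step = per_device_batch_size * num_devices
    if effective_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            "effective_batch_size must be a multiple of "
            "per_device_batch_size * num_devices"
        )
    return effective_batch_size // samples_per_micro_step


if __name__ == "__main__":
    # Illustrative numbers only, not the repo defaults.
    print(grad_accum_steps(64, per_device_batch_size=2, num_devices=4))
```

Under this relation, halving per_device_batch_size to avoid an OOM doubles the accumulation steps while leaving the samples per optimizer step unchanged.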
Always activate the env first. Every command in this section assumes conda activate <env_name> has already been run in your current shell (default env is cbe). Running train.py from a non-activated shell will hit ModuleNotFoundError: cbe or, worse, find a different Python and silently use it.
conda activate cbe # or whatever you named your env via setup_env.sh
# HF, single GPU
python train.py --config configs/tracks/news_gemma3_1b_lora128.yaml --framework hf
# HF, 4 GPUs
torchrun --nproc_per_node=4 train.py \
--config configs/tracks/news_gemma3_1b_lora128.yaml --framework hf
# Specific GPUs
CUDA_VISIBLE_DEVICES=2,3 torchrun --nproc_per_node=2 train.py \
--config configs/tracks/news_gemma3_1b_lora128.yaml --framework hf
DDP is HF Trainer's default when launched via torchrun. Only rank 0 logs to wandb/TB and writes metrics, so there are no duplicate entries.
python train.py --config configs/tracks/news_gemma3_1b_lora128.yaml --framework kd
# Use specific GPUs
CUDA_VISIBLE_DEVICES=0,1 python train.py \
--config configs/tracks/news_gemma3_1b_lora128.yaml --framework kd
JAX auto-discovers all visible devices and shards via FSDP. No special launcher needed.
# HuggingFace checkpoint (auto-detects LoRA from adapter_config.json)
python evaluate.py --framework hf \
--checkpoint outputs/cbe/geminon/.../checkpoints/checkpoint-2000 \
--model gemma3-1b-pt \
--qa_data data/geminon/testqa.jsonl \
--parser geminon \
--num_examples 10
# Kauldron checkpoint (with LoRA — does split/merge of base + adapter)
python evaluate.py --framework kd \
--checkpoint outputs/cbe/geminon/.../checkpoints/ckpt_2000 \
--model gemma3-1b-pt --lora_rank 128 \
--qa_data data/geminon/testqa.jsonl \
--parser geminon
# Save detailed per-example results
python evaluate.py --framework hf \
--checkpoint outputs/cbe/geminon/.../checkpoints/checkpoint-2000 \
--model gemma3-1b-pt \
--qa_data data/geminon/testqa.jsonl \
--parser geminon \
--save_details results.jsonl
Substring/exact-match metrics undercount paraphrased correct answers. llm_evaluate.py re-scores the eval_details/*.jsonl files produced during training with Gemini as a judge, adding an llm_match: bool field per record and writing a stratified summary.
pip install google-genai
cp secrets/gemini_keys.txt.example secrets/gemini_keys.txt # then add your keys
# Judge one per-example results file
python llm_evaluate.py \
--input outputs/<project>/<run>/eval_details/testqa_step_001000.jsonl
# → writes testqa_step_001000_llm_judged.jsonl + testqa_step_001000_summary.jsonl
The script reads API keys from secrets/gemini_keys.txt (one per line, multi-key round-robin recommended for higher quota), or the GEMINI_API_KEY / GOOGLE_API_KEY env vars. It uses gemini-2.5-flash-lite with temperature=0 (deterministic) by default. See python llm_evaluate.py --help for resume, concurrency, and stratification options.
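To see how much the judge changes the picture, the judged file can be tallied directly. This is a minimal sketch that relies only on the exact_match and llm_match fields described in this README; `summarize_judged` is a hypothetical helper and does not reproduce the stratified summary the script writes.

```python
import json


def summarize_judged(lines):
    """Count exact matches, judge matches, and answers that only the judge accepted."""
    total = exact = llm = rescued = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        is_exact = bool(rec.get("exact_match"))
        is_llm = bool(rec.get("llm_match"))
        total += 1
        exact += is_exact
        llm += is_llm
        # Correct per the judge but missed by exact match (e.g. a paraphrase).
        rescued += is_llm and not is_exact
    return {
        "total": total,
        "exact_match": exact,
        "llm_match": llm,
        "rescued_by_judge": rescued,
    }
```

Passing an open file handle for a *_llm_judged.jsonl file as `lines` works, since iterating a text file yields one line at a time.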
# View one run
tensorboard --logdir outputs/cbe/geminon/debug-kd --port 6006
# Compare all runs in a project
tensorboard --logdir outputs/cbe
# Remote machine — SSH tunnel
ssh -L 6006:localhost:6006 user@host
# then open http://localhost:6006
For KD runs, Kauldron writes to separate subdirs per evaluator (train/, eval_loss/, qa_valqa/, qa_testqa/). Point TB at the run root to see all of them.
W&B is opt-in. To enable it:
- Install the wandb extra at env-creation time (any backend):
  bash setup_env.sh torch-gpu wandb # or jax-gpu wandb / jax-tpu wandb
- Authenticate one-time per machine (see Authentication):
  wandb login # paste API key from https://wandb.ai/authorize
- Add wandb to the logging.backends list (already on by default in the news + geminon task configs):
  logging:
    backends: [tensorboard, wandb]
    project_name: cbe
    run_name: news/my-experiment
  Or pass via CLI:
  --override "logging.backends=[tensorboard,wandb]"
Runs upload to wandb.ai/<your-username>/<project_name>.
Terminal output is also tee'd to outputs/<project>/<run>/logs/train.log — tail it live with tail -f.
python scripts/plot_runs.py outputs/<project>/<task>
# writes <task>/runs_plot.png with 3 panels: valqa fuzzy match, eval loss, train loss
The script auto-discovers every subdir under the given task dir that has metrics/eval_results.jsonl, infers framework (HF vs KD) from file shape, and reads train/eval loss from HF's trainer_state.json log_history or KD's TB event files. Multiple TB event files per run (from resume) are merged automatically. X-axis is normalized to optimizer steps.
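The same discovery rule is easy to reuse when you want numbers instead of a plot. The sketch below is not the code in scripts/plot_runs.py; it assumes each line of metrics/eval_results.jsonl is a JSON object with a step field and metric fields such as valqa_fuzzy_match, and the helper names are hypothetical.

```python
import json
from pathlib import Path


def discover_runs(task_dir):
    """Return run dirs directly under task_dir that contain metrics/eval_results.jsonl."""
    pattern = "*/metrics/eval_results.jsonl"
    return sorted(p.parent.parent for p in Path(task_dir).glob(pattern))


def load_metrics(path):
    """Read an append-only JSONL metrics file into a list of dicts."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def best_step(records, metric="valqa_fuzzy_match"):
    """Return the record with the highest value of metric, or None if absent."""
    scored = [r for r in records if metric in r]
    return max(scored, key=lambda r: r[metric]) if scored else None


if __name__ == "__main__":
    for run in discover_runs("outputs/cbe/geminon"):
        best = best_step(load_metrics(run / "metrics" / "eval_results.jsonl"))
        if best:
            print(run.name, best["step"], best["valqa_fuzzy_match"])
```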
Every run writes to a standardized directory:
outputs/<project_name>/<run_name>/
├── config.yaml # Frozen copy of the resolved config
├── logs/
│ ├── train.log # Full stdout+stderr
│ └── tensorboard/ # TB event files (from MultiLogger)
├── train/ # KD-only: train loss events
├── eval_loss/ # KD-only: eval loss events
├── checkpoints/
│ ├── checkpoint-2000/ # HF naming
│ ├── ckpt_2000/ # KD naming
│ └── latest -> checkpoint-2000 # Symlink to most recent
├── metrics/
│ └── eval_results.jsonl # Append-only: {step, eval_loss, valqa_exact_match, ...}
└── eval_details/ # Per-example QA results (opt-in)
├── valqa_step_002000.jsonl
└── testqa_step_002000.jsonl
Record schemas (one JSON object per line):
// metrics/eval_results.jsonl
{"step": 2000, "timestamp": "...", "eval_loss": 3.298, "valqa_exact_match": 0.42, "valqa_fuzzy_match": 0.51, "testqa_exact_match": 0.38, "testqa_fuzzy_match": 0.47}
// eval_details/<set>_step_<N>.jsonl
{"prompt": "Q: ...\nA:", "question": "...", "raw_prediction": " ...", "parsed_prediction": "...", "ground_truth": "...", "exact_match": true, "fuzzy_match": true}
- Place data in data/<track>/{train,val,valqa,testqa}.jsonl (or add entries to data/helper/download.yaml and run python data/helper/load_data.py --track <track>)
- Create a task base file configs/base/tasks/<track>.yaml (copy configs/base/tasks/news.yaml as a template; update data.*_path, eval.parser, etc.)
- Create a track config configs/tracks/<track>_gemma3_<size>_<adapter>.yaml (copy any of the existing news_gemma3_* files); set its _base: list to point at the new task + the model base you want
- Train: python train.py --config configs/tracks/<track>_gemma3_<size>_<adapter>.yaml --framework hf
- HF: Set model.name to any HuggingFace hub ID (meta-llama/Llama-3.1-8B, mistralai/Mistral-7B-v0.3, etc.). Short Gemma names (gemma3-1b-pt) are auto-mapped to hub IDs.
- KD: Implement the JaxModelFactory protocol in src/cbe/models/kd_models.py. The Gemma factory there is a reference implementation.