20 changes: 20 additions & 0 deletions training/examples/rl/README.md
@@ -0,0 +1,20 @@
# Async RL Examples

This directory includes two minimal async RL rollout examples:

- `single_turn_token_in`
- `multi_turn_message_in`

The examples keep sampler ownership in rollout code. The async training recipe
passes metadata and a fixed-width request gate through `RolloutContext`; it does
not inject a sampler, rollout provider, or rollout engine.
Users can subclass `RolloutContext` or pass `ctx_extras` when their rollout
engine needs more state.

`single_turn_token_in` expects pre-tokenized row fields such as
`prompt_token_ids`. `multi_turn_message_in` accepts OpenAI-style `messages` and
uses the generic TITO adapter so generated assistant token IDs are preserved
across turns.
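
For orientation, the two row shapes look roughly like the sketch below, one
JSONL row per example. Field values are made up: the `single_turn_token_in`
row shows only the documented `prompt_token_ids` field (any reward-target
fields depend on that example's reward function), while the
`multi_turn_message_in` row follows the shape its `prepare_data.py` writes.

```jsonl
{"prompt_token_ids": [151644, 872, 3838, 151645]}
{"id": "gsm8k-train-0", "messages": [{"role": "user", "content": "What is 6 * 7?\nPlease put your final answer within \\boxed{}."}], "answer": "6 * 7 = 42\n#### 42"}
```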

The `train.py` files use placeholder model and tokenizer names. Replace them
with your own public model and tokenizer identifiers before running.
3 changes: 3 additions & 0 deletions training/examples/rl/multi_turn_message_in/.gitignore
@@ -0,0 +1,3 @@
# Materialized GSM8K splits — regenerate with `python prepare_data.py`.
train.jsonl
test.jsonl
58 changes: 58 additions & 0 deletions training/examples/rl/multi_turn_message_in/README.md
@@ -0,0 +1,58 @@
# Multi-Turn GSM8K (message-in) RL example

A multi-turn math agent that ports
[AReaL's `examples/multi_turn_math/`](https://github.com/inclusionAI/AReaL/tree/main/examples/multi_turn_math)
to the cookbook's async RL recipe.

The model is asked a GSM8K problem and must put its final answer in
`\boxed{...}`. If the boxed answer is wrong (checked by a numeric match with a
`math_verify` fallback), the rollout appends a fixed user-feedback message and
lets the model try again, up to `--max-turns` attempts in total (default 2).

The full trajectory — prompt + first attempt + feedback + second attempt —
is packed into a single `RolloutSample`. `MessageTrajectoryAssembler`
keeps the per-token loss mask aligned: assistant tokens are trained on
across both turns; the original prompt and the user-feedback bridge tokens
are masked out.
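
Schematically, the flattened trajectory looks like the following minimal
sketch (segment names and lengths are illustrative, not the recipe's API):

```python
# Hypothetical token-ID segments for one two-turn trajectory.
prompt_ids = [101, 102, 103]    # original prompt (masked out)
turn1_output_ids = [201, 202]   # assistant attempt 1 (trained on)
feedback_ids = [301, 302]       # user-feedback bridge (masked out)
turn2_output_ids = [401, 402]   # assistant attempt 2 (trained on)

tokens = prompt_ids + turn1_output_ids + feedback_ids + turn2_output_ids
loss_mask = (
    [0] * len(prompt_ids)
    + [1] * len(turn1_output_ids)
    + [0] * len(feedback_ids)
    + [1] * len(turn2_output_ids)
)
assert len(tokens) == len(loss_mask)
```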

## Files

- `prepare_data.py` — downloads `openai/gsm8k` (config `main`) from HuggingFace
via `load_dataset("openai/gsm8k", "main")` and writes `train.jsonl`
(7473 rows) and `test.jsonl` (1319 rows) with `{messages, answer}` rows.
- `reward.py` — extracts `\boxed{...}` from the completion and verifies it
against the GSM8K ground truth (`#### N`) via a numeric match plus a
`math_verify` fallback (see the usage sketch after this list).
- `rollout.py` — per-sample `make_rollout_fn(setup) -> RolloutFn`; runs the
retry loop and returns one `RolloutSample` per trajectory.
- `train.py` — wires the dataset and the rollout factory into
`recipes/async_rl_loop.main`.
- `run.sh` — one-shot end-to-end (auto-runs `prepare_data.py` if needed).
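
To make the reward contract concrete, here is a usage sketch of `reward.py`
(the strings are illustrative, not dataset rows):

```python
from training.examples.rl.multi_turn_message_in.reward import gsm8k_reward

print(gsm8k_reward("So the total is \\boxed{42}.", "6 * 7 = 42\n#### 42"))  # 1.0
print(gsm8k_reward("So the total is \\boxed{41}.", "6 * 7 = 42\n#### 42"))  # 0.0 (wrong number)
print(gsm8k_reward("I think it is 42.", "6 * 7 = 42\n#### 42"))             # 0.0 (no \boxed{...})
```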

## Usage

```bash
# 1. Download GSM8K (writes train.jsonl + test.jsonl in this directory).
python prepare_data.py

# 2. Train.
python train.py \
--base-model accounts/fireworks/models/qwen3-1p5b-instruct \
--tokenizer-model Qwen/Qwen2.5-1.5B-Instruct \
--max-rows 512 \
--completions-per-prompt 4 \
--max-turns 2 \
--output-model-id accounts/<acct>/models/gsm8k-mt
```

Or `bash run.sh` for the canned configuration.

## Reference

- AReaL `examples/multi_turn_math/gsm8k_rl_mt.py` — the original
`MultiTurnMathAgent` whose retry loop, feedback prompt, and reward
function this example mirrors.
- AReaL `areal/experimental/openai/types.py::to_tensor_dict` — the
bridge-masking semantics (prompt and feedback positions get
`loss_mask=0`; only assistant tokens are trained on) that
`MessageTrajectoryAssembler`'s `trajectory.to_flat()` reproduces in this recipe.
1 change: 1 addition & 0 deletions training/examples/rl/multi_turn_message_in/__init__.py
@@ -0,0 +1 @@
"""Multi-turn message-in RL example."""
79 changes: 79 additions & 0 deletions training/examples/rl/multi_turn_message_in/prepare_data.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""Download GSM8K from HuggingFace and convert to JSONL for the recipe.

Pulls the real ``openai/gsm8k`` dataset (config ``main``) — both ``train`` and
``test`` splits in one call — and writes ``train.jsonl`` and ``test.jsonl``
next to this script (or wherever ``--output-dir`` points).

Output format per row::

{"id": "gsm8k-train-0",
"messages": [{"role": "user",
"content": "<question>\\nPlease put your final answer within \\boxed{}."}],
"answer": "<full GSM8K answer string ending in '#### N'>"}

Mirrors AReaL's ``examples/multi_turn_math/`` data shape so the reward function
can verify ``\\boxed{...}`` against the GSM8K ground-truth ``#### N`` token.

Usage::

python prepare_data.py # writes train.jsonl + test.jsonl
python prepare_data.py --split train # only train
python prepare_data.py --max-rows 100 # cap each split
"""

from __future__ import annotations

import argparse
import json
import os

from datasets import load_dataset

PROMPT_SUFFIX = "\nPlease put your final answer within \\boxed{}."

DEFAULT_DIR = os.path.dirname(os.path.abspath(__file__))


def _write_split(split_name: str, split_data, out_path: str, max_rows: int | None) -> int:
n = 0
with open(out_path, "w") as f:
for idx, row in enumerate(split_data):
if max_rows is not None and n >= max_rows:
break
entry = {
"id": f"gsm8k-{split_name}-{idx}",
"messages": [
{"role": "user", "content": row["question"] + PROMPT_SUFFIX},
],
"answer": row["answer"],
}
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
n += 1
return n


def main():
parser = argparse.ArgumentParser(description="Prepare GSM8K JSONL")
parser.add_argument("--split", default="all", choices=["all", "train", "test"],
help="Which split to write (default: both)")
parser.add_argument("--output-dir", default=DEFAULT_DIR,
help="Directory for {split}.jsonl files")
parser.add_argument("--max-rows", type=int, default=None,
help="Optional cap on rows per split")
args = parser.parse_args()

ds = load_dataset("openai/gsm8k", "main")
print(f"Loaded openai/gsm8k(main): "
f"train={len(ds['train'])} rows, test={len(ds['test'])} rows")

splits = ["train", "test"] if args.split == "all" else [args.split]
os.makedirs(args.output_dir, exist_ok=True)
for s in splits:
out_path = os.path.join(args.output_dir, f"{s}.jsonl")
n = _write_split(s, ds[s], out_path, args.max_rows)
print(f"Wrote {n} rows to {out_path}")


if __name__ == "__main__":
main()
87 changes: 87 additions & 0 deletions training/examples/rl/multi_turn_message_in/reward.py
@@ -0,0 +1,87 @@
"""GSM8K reward function for the multi-turn message-in recipe.

Mirrors AReaL's ``examples/multi_turn_math/gsm8k_rl_mt.py::gsm8k_reward_fn``:
parse the model's completion (``\\boxed{...}``) and the GSM8K ground-truth
answer string (whose final number follows ``#### ``) with ``math_verify`` and
return ``1.0`` on a verified match, ``0.0`` otherwise.

A pure-regex numeric fast path runs before ``math_verify`` and handles the
common case where ``math_verify``'s LaTeX parsing rejects an otherwise-correct
integer answer.
"""

from __future__ import annotations

import logging
import re

logger = logging.getLogger(__name__)

_BOXED_RE = re.compile(r"\\boxed\s*\{", re.DOTALL)
_GSM8K_GT_RE = re.compile(r"####\s*([\-\+]?[\d,]*\.?\d+)")
_NUMERIC_TOL = 1e-6


def _extract_boxed(text: str) -> str | None:
"""Return the content of the LAST ``\\boxed{...}`` in ``text``.

Walks brace depth so nested braces (``\\frac{1}{2}``) are preserved.
"""
matches = list(_BOXED_RE.finditer(text))
if not matches:
return None
last = matches[-1]
start = last.end()
depth = 1
i = start
while i < len(text) and depth > 0:
c = text[i]
if c == "{":
depth += 1
elif c == "}":
depth -= 1
i += 1
if depth != 0:
return None
return text[start : i - 1].strip()


def _gsm8k_ground_truth_number(answer: str) -> str | None:
"""Strip the GSM8K chain-of-thought; return the final number after ``#### ``."""
m = _GSM8K_GT_RE.search(answer)
if not m:
return None
return m.group(1).replace(",", "")


def _try_numeric_match(pred: str, gt: str) -> bool:
try:
return abs(float(pred.replace(",", "")) - float(gt.replace(",", ""))) < _NUMERIC_TOL
except (ValueError, OverflowError):
return False


def gsm8k_reward(completion: str, answer: str) -> float:
"""Return ``1.0`` if the model's boxed answer matches the GSM8K ground truth.

The cheap numeric path catches the GSM8K-typical case (integer answers).
Falls through to ``math_verify`` (a project dep) for fractions / surds /
LaTeX. Returns ``0.0`` on any failure -- never raises.
"""
pred = _extract_boxed(completion)
if pred is None:
return 0.0
gt_num = _gsm8k_ground_truth_number(answer)
if gt_num is not None and _try_numeric_match(pred, gt_num):
return 1.0

try:
from math_verify import parse as math_parse, verify as math_verify_fn

pred_parsed = math_parse(f"\\boxed{{{pred}}}")
gt_parsed = math_parse(answer)
if pred_parsed and gt_parsed and math_verify_fn(gt_parsed, pred_parsed):
return 1.0
except Exception:
logger.debug("math_verify failed; pred=%r gt=%r", pred, answer, exc_info=True)

return 0.0
111 changes: 111 additions & 0 deletions training/examples/rl/multi_turn_message_in/rollout.py
@@ -0,0 +1,111 @@
"""Multi-turn GSM8K rollout (per-sample) with retry-on-wrong feedback.

Mirrors AReaL's ``examples/multi_turn_math/gsm8k_rl_mt.py``: ask GSM8K, score
the boxed answer, and -- if the model is wrong on the first try -- append a
fixed user-feedback message and let it try once more (configurable via
``setup.extras["max_turns"]``, default ``2``).

The whole trajectory (prompt + assistant turn 1 + feedback + assistant turn 2)
is packed into a single ``RolloutSample``. ``MessageTrajectoryAssembler``
keeps the per-token loss mask aligned: ``1`` on every assistant-generated
token (across all turns), ``0`` everywhere else (original prompt, the
user-feedback bridge between turns). The scalar ``reward`` is the
last-turn verification result (``0.0`` or ``1.0``); rolling rewards up across
turns isn't useful in concat mode -- the GRPO advantage compares whole
trajectories, so "wrong then right" naturally beats "wrong then wrong" without
an explicit per-turn discount.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from training.examples.rl.multi_turn_message_in.reward import gsm8k_reward
from training.examples.rl.vanilla_sampler import build_deployment_sampler
from training.utils.rl.rollout import (
MessageTrajectoryAssembler,
RolloutSample,
TITOTokenizer,
)

if TYPE_CHECKING:
from training.recipes.async_rl_loop import RolloutFn, RolloutSetup

logger = logging.getLogger(__name__)

RETRY_PROMPT = (
"Your answer is either wrong or not parsable to the reward function. "
"You may misunderstand the original question. Please carefully read the "
"original question, check the previous errors, and try to answer it again."
)


def make_rollout_fn(setup: "RolloutSetup") -> "RolloutFn":
sampler = build_deployment_sampler(setup)
sample_kwargs = dict(setup.sample_kwargs)
tokenizer = setup.tokenizer
max_turns = int(setup.extras.get("max_turns", 2))
if max_turns < 1:
raise ValueError(f"max_turns must be >= 1, got {max_turns}")

async def rollout_fn(row: dict) -> RolloutSample | None:
messages = list(row.get("messages") or [])
answer = row.get("answer")
if not messages or answer is None:
return None

assembler = MessageTrajectoryAssembler(TITOTokenizer(tokenizer))
current_messages = messages
last_reward = 0.0

for turn in range(max_turns):
prompt_tokens = assembler.prepare_next_input(current_messages)
completions = await sampler.sample_with_prompt_tokens(
prompt_tokens, n=1, **sample_kwargs,
)
if not completions:
return None
completion = completions[0]

prompt_len = int(completion.prompt_len)
output_tokens = list(completion.full_tokens[prompt_len:])
output_logprobs = list(completion.inference_logprobs or [])
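            # Some engines echo logprobs for the prompt tokens as well; when
            # the lengths indicate that, drop the prompt-aligned prefix so the
            # logprobs line up one-to-one with output_tokens.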
if (
getattr(completion, "logprobs_echoed", False)
and len(output_logprobs) == len(completion.full_tokens)
):
output_logprobs = output_logprobs[prompt_len:]
if not output_tokens or len(output_logprobs) != len(output_tokens):
return None

assistant_text = getattr(completion, "text", "") or tokenizer.decode(output_tokens)
assistant_message = {"role": "assistant", "content": assistant_text}
assembler.add_assistant_response(
request_messages=current_messages,
assistant_message=assistant_message,
prompt_token_ids=prompt_tokens,
completion_token_ids=output_tokens,
completion_logprobs=output_logprobs,
finish_reason=getattr(completion, "finish_reason", "stop"),
)

last_reward = gsm8k_reward(assistant_text, str(answer))
if last_reward >= 1.0:
break

if turn + 1 < max_turns:
current_messages = current_messages + [
assistant_message,
{"role": "user", "content": RETRY_PROMPT},
]

tokens, logprobs, loss_mask = assembler.trajectory.to_flat()
return RolloutSample(
tokens=tokens,
logprobs=logprobs,
loss_mask=loss_mask,
reward=last_reward,
)

return rollout_fn
25 changes: 25 additions & 0 deletions training/examples/rl/multi_turn_message_in/run.sh
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -euo pipefail

HERE="$(cd "$(dirname "$0")" && pwd)"
REPO_ROOT="$(cd "$HERE/../../../.." && pwd)"
export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}"

if [[ ! -f "$HERE/train.jsonl" ]]; then
echo "train.jsonl not found; downloading openai/gsm8k from HuggingFace..."
python "$HERE/prepare_data.py"
fi
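
# OUTPUT_MODEL_ID overrides the produced model name; it defaults to a
# timestamped accounts/fireworks/models/gsm8k-mt-<epoch> id.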

python "$HERE/train.py" \
--base-model accounts/fireworks/models/qwen3-1p5b-instruct \
--tokenizer-model Qwen/Qwen2.5-1.5B-Instruct \
--dataset-path "$HERE/train.jsonl" \
--max-rows 512 \
--epochs 1 \
--completions-per-prompt 4 \
--prompt-groups-per-step 8 \
--max-completion-tokens 1024 \
--max-turns 2 \
--learning-rate 1.7e-5 \
--kl-beta 0.0 \
--output-model-id "${OUTPUT_MODEL_ID:-accounts/fireworks/models/gsm8k-mt-$(date +%s)}"