Skip to content

[rl] Search-R1: multi-turn retrieval-augmented GRPO#3602

Open
yichuan-w wants to merge 3 commits into
pytorch:mainfrom
yichuan-w:yichuan/search-r1-multiturn-rl
Open

[rl] Search-R1: multi-turn retrieval-augmented GRPO#3602
yichuan-w wants to merge 3 commits into
pytorch:mainfrom
yichuan-w:yichuan/search-r1-multiturn-rl

Conversation

@yichuan-w

@yichuan-w yichuan-w commented Jun 10, 2026

Copy link
Copy Markdown
Member

What

Adds a Search-R1 example (multi-turn <search>/<information>/<answer> open-domain
QA RL) to torchtitan/experiments/rl, plus the minimal multi-turn RL infra and the
RL stabilizers (KL-to-reference + clip-higher) needed to train it without collapse.
Modeled on slime's examples/search-r1, following PR #3582's example-folder format.

Why

The RL controller previously only supported single-turn rollouts and
rollout_to_episode rejected >1 turn. Search-R1 is inherently multi-turn.

Core multi-turn infra (single-turn path verified bit-identical: sum_digits loss with --debug.seed=42 --debug.deterministic)

  • Multi-turn controller: batched generate-per-turn loop until each rollout is terminal (trainer.py).
  • Episode flattening + per-token loss_mask masking env-injected (retrieved) tokens (types.py, rollout/utils.py, batcher.py).
  • Generator stop-strings + include_stop_str_in_output (actors/generator.py).
  • TokenEnv.max_num_turns cap (environment/token.py).
  • Fixed held-out eval: Rollouter.reset_validation() so every eval scores the same set (else the iterator advanced and each eval used a different subset).
  • Robustness: GRPOLoss drops tokens with non-finite generator logprobs (vLLM emits occasional NaN logprobs) + log-ratio clamp; batcher clamps grad-accum to available packed rows (no empty-microbatch torch.cat).

RL stabilizers (ported from slime — needed to avoid mode collapse)

  • KL-to-reference penalty (low_var_kl/k3, _compute_kl) + a frozen reference model built in PolicyTrainer when kl_coef>0 (ref_logprobs computed no-grad in forward_backward).
  • clip-higher (asymmetric clamp 1-clip_eps / 1+clip_eps_high).

Result (Qwen3-1.7B, fixed NQ-500 test EM)

EM (baseline → trained) response_length behavior
bare GRPO (no KL) 0.16 → ~0.21 140 → ~30 (collapse) drops <search>, bare <answer> (reward hacking)
+ KL + clip-higher 0.159 → ~0.22 plateau (peak 0.229, +44%) 140 → ~60–130 (no collapse) keeps "After searching, I found …"

KL holds the policy near the base model, so EM rises while preserving the search
behavior
(genuine Search-R1) instead of collapsing to terse-answer hacking. EM is
noisy (±0.02–0.03); read the plateau over 20 eval passes, not single points.

Tests

tests/test_search_r1.py (rubric + env parsing), tests/test_rollout_to_episode.py
(multi-turn flattening + loss-mask). pre-commit clean.

Requires a running local dense retrieval server + Search-R1 NQ/HotpotQA data; see examples/search_r1/README.md.

Add a Search-R1 example (multi-turn <search> / <information> / <answer> QA RL)
to torchtitan/experiments/rl, plus the minimal core infra it required. Modeled
on slime's examples/search-r1, following PR pytorch#3582's example-folder format.

Core RL infra (single-turn behavior unchanged; sum_digits loss verified
bit-identical before/after with --debug.seed=42 --debug.deterministic):
- Multi-turn rollout controller: batched generate-per-turn loop until each
  rollout is terminal (trainer.py), replacing the single generate+step path.
- Multi-turn episode flattening with a per-token loss_mask that masks
  env-injected (retrieved) tokens out of the GRPO loss (types.py,
  rollout/utils.py, batcher.py).
- Generator stop-strings + include_stop_str_in_output (actors/generator.py).
- TokenEnv.max_num_turns cap (environment/token.py).
- Periodic held-out validation (validation_freq); surface real exceptions
  before the hang-prone Monarch shutdown (train.py).
- GRPOLoss robustness: drop tokens whose generator logprob is non-finite
  (vLLM occasionally emits NaN logprobs) + clamp log-ratio -> no NaN loss.
- Batcher: clamp grad-accum to available packed rows (avoids empty-microbatch
  torch.cat on under-filled multi-turn steps).

Example examples/search_r1/: data / env / retrieval / rubric / rollouter + README.
Configs: rl_grpo_qwen3_0_6b_search_r1 (smoke), rl_grpo_qwen3_1_7b_search_r1.
Tests: rubric + env parsing; multi-turn rollout_to_episode flattening.

Result (Qwen3-1.7B, NQ test EM over 500 prompts): 0.157 baseline -> 0.257 peak
(+61%) by step 25, then overfits (train reward up, eval down) -> needs KL/clip-higher.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@yichuan-w yichuan-w requested review from shuhuayu and tianyu-l and removed request for tianyu-l June 10, 2026 03:45
validate() pulled from an advancing validation-dataset iterator, so each
validation pass scored a *different* subset of the benchmark — making the
eval-over-steps trend non-comparable (baseline vs later steps measured on
different questions). Add Rollouter.reset_validation() and call it at the
start of validate() so every pass scores the same fixed first-N prompts
(deterministic via the dataset seed). Matters for finite benchmarks like
NQ; harmless for the infinite generated datasets (sum_digits/alphabet).
completion_offset:next_completion_offset
]
completion_offset = next_completion_offset
else:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the if part of this else?

prompt_token_ids = [lr.next_prompt_token_ids for lr in active]
# TODO: pass the remaining budget (max_rollout_tokens - len(prompt)) to the
# sampling_config, to limit generation length in one turn.
completions, turn_metrics = self._get_rank_0_value(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems a barrier between two consecutive turns? So there's no continuous batching across turns.

@shuhuayu

Copy link
Copy Markdown
Contributor

thanks. can you briefly describe why search-r1 is a good example to stay in titan rl?

felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 10, 2026
Ports the Search-R1 example (upstream pytorch#3602) onto our CB + async + losses
stack as `examples/search_r1/`. The model <think>s, emits <search>query</search>
(the env injects <information> from a retriever), and finally <answer>s; reward
is exact-match + format.

Two infra adds the example needs:
- `TokenEnv.max_num_turns`: cap assistant turns; end TRUNCATED_LENGTH past it.
- `SamplingConfig.stop` / `include_stop_str_in_output`: pass vLLM stop strings so
  generation halts at </search> / </answer> and keeps the tag for parsing.

Example layer (data / env / retrieval / rubric / rollouter) matches our
alphabet_sort conventions unchanged; uses DAPO. Config
`rl_grpo_qwen3_0_6b_search_r1` (flex, 4 GPUs).

A real run needs the Search-R1 dense-retrieval server + NQ/HotpotQA parquet;
smoke-tested here with a synthetic parquet + mock retriever (8 steps: multi-turn
search->inject->answer flow, stop strings, max-turn truncation, EM reward, async
hotswap — all exercised). Tests: test_search_r1.py + suite; 189 passed.
…lapse

Without a KL penalty the bare GRPO clip lets the policy collapse to a degenerate
"emit a terse <answer> with no <think>/<search>" mode: NQ response_length drops
140->~30, the model stops searching, and EM "rises" only by answer-formatting
(reward hacking), not by retrieval reasoning.

Port slime's stabilizers (minimal):
- GRPOLoss: KL-to-reference penalty (low_var_kl/k3 estimator, _compute_kl ported
  from slime/OpenRLHF) + clip-higher (asymmetric clamp 1-clip_eps / 1+clip_eps_high).
- PolicyTrainer: build a frozen reference model (the initial/base policy) when
  kl_coef>0 and compute ref_logprobs (no-grad) in forward_backward.
- config: rl_grpo_qwen3_1_7b_search_r1 uses kl_coef=0.001, clip_eps_high=0.28,
  lr 1e-6 constant; trainer TP=1 + generator TP=4.

Result (Qwen3-1.7B, fixed NQ-500 eval): EM 0.159 -> ~0.22 plateau (peak 0.229,
+44%), held over 20 eval passes, and response_length stays ~60-130 (no collapse;
completions remain "After searching, I found that ..."). vs the no-KL run which
hit similar EM only via collapse to bare answers.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants