[rl] Search-R1: multi-turn retrieval-augmented GRPO#3602
Open
yichuan-w wants to merge 3 commits into
Open
Conversation
Add a Search-R1 example (multi-turn <search> / <information> / <answer> QA RL) to torchtitan/experiments/rl, plus the minimal core infra it required. Modeled on slime's examples/search-r1, following PR pytorch#3582's example-folder format. Core RL infra (single-turn behavior unchanged; sum_digits loss verified bit-identical before/after with --debug.seed=42 --debug.deterministic): - Multi-turn rollout controller: batched generate-per-turn loop until each rollout is terminal (trainer.py), replacing the single generate+step path. - Multi-turn episode flattening with a per-token loss_mask that masks env-injected (retrieved) tokens out of the GRPO loss (types.py, rollout/utils.py, batcher.py). - Generator stop-strings + include_stop_str_in_output (actors/generator.py). - TokenEnv.max_num_turns cap (environment/token.py). - Periodic held-out validation (validation_freq); surface real exceptions before the hang-prone Monarch shutdown (train.py). - GRPOLoss robustness: drop tokens whose generator logprob is non-finite (vLLM occasionally emits NaN logprobs) + clamp log-ratio -> no NaN loss. - Batcher: clamp grad-accum to available packed rows (avoids empty-microbatch torch.cat on under-filled multi-turn steps). Example examples/search_r1/: data / env / retrieval / rubric / rollouter + README. Configs: rl_grpo_qwen3_0_6b_search_r1 (smoke), rl_grpo_qwen3_1_7b_search_r1. Tests: rubric + env parsing; multi-turn rollout_to_episode flattening. Result (Qwen3-1.7B, NQ test EM over 500 prompts): 0.157 baseline -> 0.257 peak (+61%) by step 25, then overfits (train reward up, eval down) -> needs KL/clip-higher.
validate() pulled from an advancing validation-dataset iterator, so each validation pass scored a *different* subset of the benchmark — making the eval-over-steps trend non-comparable (baseline vs later steps measured on different questions). Add Rollouter.reset_validation() and call it at the start of validate() so every pass scores the same fixed first-N prompts (deterministic via the dataset seed). Matters for finite benchmarks like NQ; harmless for the infinite generated datasets (sum_digits/alphabet).
tianyu-l
reviewed
Jun 10, 2026
| completion_offset:next_completion_offset | ||
| ] | ||
| completion_offset = next_completion_offset | ||
| else: |
Contributor
There was a problem hiding this comment.
Where is the if part of this else?
| prompt_token_ids = [lr.next_prompt_token_ids for lr in active] | ||
| # TODO: pass the remaining budget (max_rollout_tokens - len(prompt)) to the | ||
| # sampling_config, to limit generation length in one turn. | ||
| completions, turn_metrics = self._get_rank_0_value( |
Contributor
There was a problem hiding this comment.
There seems a barrier between two consecutive turns? So there's no continuous batching across turns.
Contributor
|
thanks. can you briefly describe why search-r1 is a good example to stay in titan rl? |
felipemello1
pushed a commit
to felipemello1/torchtitan
that referenced
this pull request
Jun 10, 2026
Ports the Search-R1 example (upstream pytorch#3602) onto our CB + async + losses stack as `examples/search_r1/`. The model <think>s, emits <search>query</search> (the env injects <information> from a retriever), and finally <answer>s; reward is exact-match + format. Two infra adds the example needs: - `TokenEnv.max_num_turns`: cap assistant turns; end TRUNCATED_LENGTH past it. - `SamplingConfig.stop` / `include_stop_str_in_output`: pass vLLM stop strings so generation halts at </search> / </answer> and keeps the tag for parsing. Example layer (data / env / retrieval / rubric / rollouter) matches our alphabet_sort conventions unchanged; uses DAPO. Config `rl_grpo_qwen3_0_6b_search_r1` (flex, 4 GPUs). A real run needs the Search-R1 dense-retrieval server + NQ/HotpotQA parquet; smoke-tested here with a synthetic parquet + mock retriever (8 steps: multi-turn search->inject->answer flow, stop strings, max-turn truncation, EM reward, async hotswap — all exercised). Tests: test_search_r1.py + suite; 189 passed.
…lapse Without a KL penalty the bare GRPO clip lets the policy collapse to a degenerate "emit a terse <answer> with no <think>/<search>" mode: NQ response_length drops 140->~30, the model stops searching, and EM "rises" only by answer-formatting (reward hacking), not by retrieval reasoning. Port slime's stabilizers (minimal): - GRPOLoss: KL-to-reference penalty (low_var_kl/k3 estimator, _compute_kl ported from slime/OpenRLHF) + clip-higher (asymmetric clamp 1-clip_eps / 1+clip_eps_high). - PolicyTrainer: build a frozen reference model (the initial/base policy) when kl_coef>0 and compute ref_logprobs (no-grad) in forward_backward. - config: rl_grpo_qwen3_1_7b_search_r1 uses kl_coef=0.001, clip_eps_high=0.28, lr 1e-6 constant; trainer TP=1 + generator TP=4. Result (Qwen3-1.7B, fixed NQ-500 eval): EM 0.159 -> ~0.22 plateau (peak 0.229, +44%), held over 20 eval passes, and response_length stays ~60-130 (no collapse; completions remain "After searching, I found that ..."). vs the no-KL run which hit similar EM only via collapse to bare answers.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Adds a Search-R1 example (multi-turn
<search>/<information>/<answer>open-domainQA RL) to
torchtitan/experiments/rl, plus the minimal multi-turn RL infra and theRL stabilizers (KL-to-reference + clip-higher) needed to train it without collapse.
Modeled on slime's
examples/search-r1, following PR #3582's example-folder format.Why
The RL controller previously only supported single-turn rollouts and
rollout_to_episoderejected >1 turn. Search-R1 is inherently multi-turn.Core multi-turn infra (single-turn path verified bit-identical: sum_digits loss with
--debug.seed=42 --debug.deterministic)trainer.py).loss_maskmasking env-injected (retrieved) tokens (types.py,rollout/utils.py,batcher.py).include_stop_str_in_output(actors/generator.py).TokenEnv.max_num_turnscap (environment/token.py).Rollouter.reset_validation()so every eval scores the same set (else the iterator advanced and each eval used a different subset).torch.cat).RL stabilizers (ported from slime — needed to avoid mode collapse)
low_var_kl/k3,_compute_kl) + a frozen reference model built inPolicyTrainerwhenkl_coef>0(ref_logprobscomputed no-grad inforward_backward).1-clip_eps/1+clip_eps_high).Result (Qwen3-1.7B, fixed NQ-500 test EM)
<search>, bare<answer>(reward hacking)KL holds the policy near the base model, so EM rises while preserving the search
behavior (genuine Search-R1) instead of collapsing to terse-answer hacking. EM is
noisy (±0.02–0.03); read the plateau over 20 eval passes, not single points.
Tests
tests/test_search_r1.py(rubric + env parsing),tests/test_rollout_to_episode.py(multi-turn flattening + loss-mask).
pre-commitclean.