[rl] continuous-batching generator + multi-turn rollouts#3593

Open
felipemello1 wants to merge 10 commits into pytorch:main from felipemello1:57-pr2-readable-cb

Conversation

@felipemello1

@felipemello1 felipemello1 commented Jun 9, 2026

Contributor

Note: It's a big PR. I reviewed it a few times, but there might still be rough edges and slop. Sorry/thank you.

How to review

a. actors/generator.py — the engine loop + LoopDecision (the core).
b. rollout/rollouter.py: run_group_rollouts / run_single_rollout.
c. rollout/utils.py: rollout_to_episodes + batcher.py.
d. trainer.py

I don't think that you should try to compare before/after. Just open the file and read the code.

Summary

A vLLM generator that drives many concurrent generations through ONE engine loop (continuous batching), plus the multi-turn rollout stack around it. A whole GRPO group's turns now coalesce into one batch instead of N blocking calls.

Mental model — request intake is decoupled from engine.step (the TP collective):

(generate / pull / close) --> rank0 queue --> _engine_loop --> engine.step
Each rollout awaits its own generate; the engine loop batches whatever is in flight.
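
The mental model above can be sketched with plain asyncio. This is a minimal illustration, not the PR's generator: `FakeEngine`, `Generator`, the `None` close sentinel, and `batch_sizes` are all hypothetical names, and the fake engine simply emits one token per active request per step. The point it shows is that callers only enqueue and await a future, while a single loop admits whatever has arrived and steps the whole batch.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeEngine:
    """Hypothetical stand-in for the real engine: one token per request per step."""
    active: dict[str, list[int]] = field(default_factory=dict)
    targets: dict[str, int] = field(default_factory=dict)

    def add_request(self, request_id: str, num_tokens: int) -> None:
        self.active[request_id] = []
        self.targets[request_id] = num_tokens

    def step(self) -> list[tuple[str, list[int]]]:
        finished = []
        for request_id, tokens in list(self.active.items()):
            tokens.append(len(tokens))  # "decode" one token
            if len(tokens) == self.targets[request_id]:
                finished.append((request_id, self.active.pop(request_id)))
        return finished


class Generator:
    def __init__(self) -> None:
        self.engine = FakeEngine()
        self.inbox: asyncio.Queue = asyncio.Queue()
        self.futures: dict[str, asyncio.Future] = {}
        self.batch_sizes: list[int] = []  # batch size seen by each engine step

    async def generate(self, request_id: str, num_tokens: int) -> list[int]:
        # Intake is decoupled from stepping: enqueue, then await our own future.
        future = asyncio.get_running_loop().create_future()
        self.futures[request_id] = future
        await self.inbox.put((request_id, num_tokens))
        return await future

    async def engine_loop(self) -> None:
        while True:
            # Block only when the engine is idle; otherwise just drain the inbox.
            pending = [] if self.engine.active else [await self.inbox.get()]
            while not self.inbox.empty():
                pending.append(self.inbox.get_nowait())
            for item in pending:
                if item is None:  # close sentinel (assumption of this sketch)
                    return
                self.engine.add_request(*item)
            self.batch_sizes.append(len(self.engine.active))
            for request_id, tokens in self.engine.step():
                self.futures.pop(request_id).set_result(tokens)
            await asyncio.sleep(0)  # yield so callers can enqueue more work


async def main() -> tuple[list[list[int]], list[int]]:
    gen = Generator()
    loop_task = asyncio.create_task(gen.engine_loop())
    results = await asyncio.gather(*(gen.generate(f"r{i}", 3 + i) for i in range(4)))
    await gen.inbox.put(None)
    await loop_task
    return results, gen.batch_sizes


results, batch_sizes = asyncio.run(main())
```

Four concurrent `generate` calls end up sharing engine steps instead of running as four blocking calls, which is the coalescing behaviour the summary describes.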

What's in it

  • Continuous-batching generator (actors/generator.py): one _engine_loop per rank; rank0 fires independent generate/pull/close and resolves per-request futures.

  • Rollout moved into a Rollouter (rollout/rollouter.py): it now also owns run_group_rollouts and run_single_rollout. The controller no longer drives rollouts; it just calls the rollouter.

  • Episodes + batcher (rollout/utils.py, batcher.py): adapted to multi-turn rollouts.

  • Weight sync / hotswap: the default is to swap weights in between tokens. We expose two flags so users can decide what to do with the KV cache. TODO: support drain-then-swap. This is not really relevant for this PR, since training is still synchronous.

  • GRPO loss: vLLM returned some NaN logprobs when cudagraph=True. I added clamping + nan_to_num + a metric. We should monitor this and figure out why it happens with cudagraph and not in eager mode.

  • Task: AlphabetSort multi-turn -- changed the default to max 3 turns (instead of 1) and adapted the prompt to match the original I copied it from (verifiers).

Test plan

AlphabetSort, Qwen3-0.6B, 6 GPUs (4 gen + 2 train), cudagraph on:

python torchtitan/experiments/rl/train.py --module rl --config rl_grpo_qwen3_0_6b_varlen --num_steps 30

@meta-cla meta-cla Bot added the CLA Signed label Jun 9, 2026
@felipemello1 felipemello1 force-pushed the 57-pr2-readable-cb branch 4 times, most recently from 8d33a12 to dd0478c Compare June 10, 2026 00:18
Felipe Mello added 2 commits June 9, 2026 23:11
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests +
multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream
GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version
tracking moves to the async PR. PR2 is the readable continuous-batching core only.

- generator.py: per-request `generate` (enqueue + await a future) driven by one background
  `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep.
  Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull
  symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The
  rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match
  + branching is a TODO.

Land the original goal — concurrent requests + multiturn — without duplicating the incoming router.

127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO

- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the
  router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
…ric cleanups

Pack a multi-turn rollout into ONE training Episode (loss-masking each assistant
turn) by token-prefix subtraction; branch into separate Episodes only where a
turn's prompt diverges from the prior prompt+completion (env edited history).
Episode.advantage becomes a per-token list aligned with token_ids.
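
A minimal sketch of the prefix-subtraction idea described in this commit message, under stated assumptions: `Turn` and `pack_turns` are hypothetical names, and the real `rollout_to_episodes` also carries logprobs and advantages, which are omitted here. Each turn's prompt is compared against the tokens packed so far; only the new suffix is appended (masked out of the loss), followed by the completion (included in the loss). A diverging prompt starts a new episode.

```python
from dataclasses import dataclass


@dataclass
class Turn:
    """Hypothetical per-turn record: the full prompt sent and the completion returned."""
    prompt_token_ids: list[int]
    completion_token_ids: list[int]


def pack_turns(turns: list[Turn]) -> list[tuple[list[int], list[bool]]]:
    """Return one (token_ids, loss_mask) pair per packed episode."""
    episodes: list[tuple[list[int], list[bool]]] = []
    token_ids: list[int] = []
    loss_mask: list[bool] = []
    for turn in turns:
        prefix_len = len(token_ids)
        if token_ids and turn.prompt_token_ids[:prefix_len] != token_ids:
            # The env edited history: close the current episode and branch.
            episodes.append((token_ids, loss_mask))
            token_ids, loss_mask, prefix_len = [], [], 0
        # Prefix subtraction: keep only the prompt tokens not already packed.
        new_prompt = turn.prompt_token_ids[prefix_len:]
        token_ids = token_ids + new_prompt + turn.completion_token_ids
        loss_mask = (
            loss_mask
            + [False] * len(new_prompt)                # env/user tokens: no loss
            + [True] * len(turn.completion_token_ids)  # assistant tokens: loss
        )
    if token_ids:
        episodes.append((token_ids, loss_mask))
    return episodes


# Turn 2's prompt extends turn 1's prompt + completion, so both pack together.
linear = pack_turns([Turn([1, 2], [3]), Turn([1, 2, 3, 4], [5])])
# Turn 2's prompt diverges, so it becomes its own episode.
branched = pack_turns([Turn([1, 2], [3]), Turn([9, 9], [5])])
```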

Metrics:
- drop generator/output_tokens (derivable from completion length)
- dedup reward -> rollout_reward (validation already used validation_reward)
- add rollout/num_turns and rollout/branches_per_rollout
- inflight_at_completion: add Mean alongside Max, namespace by request prefix
- rename batcher/packing_efficiency -> pct_pad_in_batch (true padding ratio)

Also fix the collection-loop token estimate to count the packed-episode length
per rollout (was summing per-turn, which N-counted the growing prefix and
collected too few rollouts for multi-turn); alphabet_sort max_turns default 1->5.
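
To make the over-counting concrete, here is a small arithmetic sketch. The turn lengths are made up, and `per_turn_sum` / `packed_length` are hypothetical helpers; it assumes a rollout with no branching, where each turn's prompt contains all earlier tokens.

```python
def per_turn_sum(turns: list[tuple[int, int]]) -> int:
    """Old estimate: add prompt + completion length for every turn."""
    return sum(prompt_len + completion_len for prompt_len, completion_len in turns)


def packed_length(turns: list[tuple[int, int]]) -> int:
    """Packed estimate: the last prompt already contains every earlier token."""
    last_prompt_len, last_completion_len = turns[-1]
    return last_prompt_len + last_completion_len


# (prompt_len, completion_len) per turn; each prompt grows by the previous
# completion (5 tokens) plus 4 new env tokens.
turns = [(10, 5), (19, 5), (28, 5)]
old_estimate = per_turn_sum(turns)   # counts the shared prefix once per turn
new_estimate = packed_length(turns)  # counts every token once
```

With these numbers the old estimate is more than double the packed length, so a token budget would be reached after far fewer rollouts than intended.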
Felipe Mello added 5 commits June 10, 2026 10:35
…trics + naming cleanups

- generator: engine-loop state split into request inbox / future outbox
  (GenerationRequest/WeightPullRequest/CloseRequest + GenerationFuture/_pull_future),
  _resolve_finished_requests, _fail_outstanding_futures, n=1 raise, weight-sync prefix-cache knobs
- rollouter/trainer: generate_fn + sample naming; GenerateFn moved to rollout/types
- metrics: output_tokens (mean/std/max) in prepare_rollout_metrics; fold reward/ into rollout_reward/
- tests updated
… guard, drop debug sample logging

- examples/alphabet_sort: copy verifiers' prompts verbatim (terse, self-contained
  per turn; <alphabetical_sorted> / <combined_alphabetical_sorted>; explicit
  "Mark any NEW names ... with `// new name!`"); fixed format example ending in `...`
  so the model doesn't copy a fixed row count; max_turns 5 -> 3 (verifiers default).
- trainer.py GRPOLoss: nan_to_num + clamp(log_ratio, -20, 20) before exp() so a
  non-finite vLLM cudagraph logprob can't NaN the loss; add metric
  loss/generator_logprob_nan_frac; TODO to record env input in the rollout recorder.
- rollout: remove _log_samples / last_completion_text / completion_text and the
  log_samples config; episode packing/branching docstring cleanup.
…g seam; simplify AlphabetSort example

- rollout/types.py: GenerateFn is now a Protocol with an explicit signature
  (prompt_token_ids/request_id/session_id/sampling_config -> Completion|None), not a loose Callable.
- session_id seam: run_single_rollout passes a stable per-rollout session_id (sticky-routing key)
  plus a per-turn request_id, threaded rollouter -> generate_fn -> generator.generate. A single
  generator ignores session_id; ready for the multi-generator router (pytorch#3583/pytorch#3625).
- run_single_rollout takes rollout_id (built in run_group_rollouts) instead of sample_idx.
- examples/alphabet_sort/env.py: fixed format example ending in "..." (dropped the randomized
  placeholder-row machinery); restored the original docstrings.
- docstring/comment cleanups (plain wording).
…pr2-readable-cb

# Conflicts:
#	torchtitan/experiments/rl/trainer.py
@felipemello1 felipemello1 changed the title from "[NOT READY][RL][Continuous batching]" to "[rl] continuous-batching generator + multi-turn rollouts" Jun 11, 2026
@felipemello1 felipemello1 marked this pull request as ready for review June 11, 2026 02:11
@felipemello1 felipemello1 requested review from tianyu-l and wwwjn June 11, 2026 02:12
Felipe Mello added 3 commits June 10, 2026 19:41
- Remove the _MAX_TURNS module constant + loop cap; the turn cap is deferred
  (TODO: add max_num_turns to TokenEnv.Config). The loop runs until the env is terminal.
- Docstrings: the Rollouter (not the controller) drives the rollouts; generate_fn "runs one
  generation" (routes through the generator, not "bound to one generator").

Contributor Author

not much to see here. I just updated the prompt to be more faithful to the original.

)
packing_metrics = [
m.Metric(
"batcher/packing_efficiency",

Contributor Author

this metric was confusing. I replaced it with pct_pad_in_batch

),
),
m.Metric(
"batcher/num_packed_rows",

Contributor Author

This didn't seem to be relevant.

Comment on lines -231 to 244
-    for ep in episodes:
-        prompt_len = len(ep.prompt_token_ids)
-        completion_len = len(ep.completion_token_ids)
-        raw_ids = ep.prompt_token_ids + ep.completion_token_ids
-        gen_lp = [0.0] * prompt_len + ep.completion_logprobs
-        loss_mask = [False] * prompt_len + [True] * completion_len
-        advantages = [0.0] * prompt_len + [ep.advantage] * completion_len
+    for episode in episodes:
         sample = {
-            "input_ids": raw_ids[:-1],
-            "labels": raw_ids[1:],
-            "generator_logprobs": gen_lp[1:],
-            "loss_mask": loss_mask[1:],
-            "advantages": advantages[1:],
+            "input_ids": episode.token_ids[:-1],
+            "labels": episode.token_ids[1:],
+            "generator_logprobs": episode.logprobs[1:],
+            "loss_mask": episode.loss_mask[1:],
+            "advantages": episode.advantage[1:],
         }

Contributor Author

Before, we received a prompt + response and put them together here. Now episodes arrive with the multi-turn token logic (token_ids, logprobs, loss_mask, per-token advantage) already built. This happens in rollout_to_episodes.
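
A small sketch of the shift-by-one construction in the new code path. `PackedEpisode` is a hypothetical stand-in for the PR's `Episode`, with only the four per-token fields the hunk reads; the values are made up. Inputs drop the last token, and every per-token target drops the first, so position i of `labels`, `loss_mask`, and `advantages` describes the token predicted from `input_ids[: i + 1]`.

```python
from dataclasses import dataclass


@dataclass
class PackedEpisode:
    """Hypothetical stand-in for Episode: all four lists are aligned per token."""
    token_ids: list[int]
    logprobs: list[float]
    loss_mask: list[bool]
    advantage: list[float]


def to_sample(episode: PackedEpisode) -> dict[str, list]:
    # Next-token prediction: inputs are tokens [0, n-1), targets are tokens [1, n).
    return {
        "input_ids": episode.token_ids[:-1],
        "labels": episode.token_ids[1:],
        "generator_logprobs": episode.logprobs[1:],
        "loss_mask": episode.loss_mask[1:],
        "advantages": episode.advantage[1:],
    }


# Two prompt tokens, one assistant token, one env token, one assistant token.
episode = PackedEpisode(
    token_ids=[11, 12, 21, 13, 22],
    logprobs=[0.0, 0.0, -0.4, 0.0, -0.7],
    loss_mask=[False, False, True, False, True],
    advantage=[0.0, 0.0, 0.5, 0.0, 0.5],
)
sample = to_sample(episode)
```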

Comment on lines -59 to -61
- "reward/_mean",
- "reward/_max",
- "reward/zero_std_frac",

Contributor Author

Before, we had both "reward" and "rollout_reward" prefixes, so I just unified them.

_ERROR = frozenset({"error_parse", "error_timeout", "error_abort", "error"})


class GenerateFn(Protocol):

Contributor Author

Check how it's used in trainer.py, but tl;dr: the rollout shouldn't need to know how to call a GeneratorRouter, which requires passing a bunch of args and calling it in a fancy way.

It's better if this callable can be anything: a Monarch endpoint, an HTTP endpoint, a mock, etc.
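
A sketch of what such a callable protocol might look like. The four parameter names and the `Completion | None` return come from the commit message above; everything else is an assumption of this sketch: that the call is async and keyword-only, the minimal `Completion` dataclass, and the `mock_generate` / `run_one_turn` helpers.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Any, Protocol


@dataclass
class Completion:
    """Hypothetical minimal completion; the real one also carries metrics etc."""
    token_ids: list[int]


class GenerateFn(Protocol):
    """Anything with this call signature can drive a rollout."""

    async def __call__(
        self,
        *,
        prompt_token_ids: list[int],
        request_id: str,
        session_id: str,
        sampling_config: Any,
    ) -> Completion | None: ...


async def mock_generate(
    *,
    prompt_token_ids: list[int],
    request_id: str,
    session_id: str,
    sampling_config: Any,
) -> Completion | None:
    # A mock backend: echo the prompt reversed. No router or engine involved.
    return Completion(token_ids=list(reversed(prompt_token_ids)))


async def run_one_turn(generate_fn: GenerateFn, prompt_token_ids: list[int]) -> Completion | None:
    # The rollout only knows the protocol, not what sits behind it.
    return await generate_fn(
        prompt_token_ids=prompt_token_ids,
        request_id="rollout0-turn0",  # per-turn id (illustrative value)
        session_id="rollout0",        # stable per-rollout id (illustrative value)
        sampling_config=None,
    )


completion = asyncio.run(run_one_turn(mock_generate, [1, 2, 3]))
```

Because the protocol is structural, a test can pass a plain async function like `mock_generate` without subclassing anything.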

felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 11, 2026
Rebase onto pytorch#3593 (continuous-batching + multi-generator router). The async loop now:
- routes generation through self.generator_router (route/fanout/pull) instead of self.generator;
  setup_async spawns N generators and builds the router; close() fanouts.
- imports GenerateFn from rollout.types; _generate accepts+forwards session_id.
- calls Rollouter.run_group_rollouts with generate_fn=/sample= (renamed kwargs).
- ports pytorch#3593's GRPOLoss nan-guard (nan_to_num + clamp + generator_logprob_nan_frac).
- inlines the deleted last_completion_text helper in _log_samples.
- adds the generator_router Config field (set hot_swap=True for non-draining weight sync).

Tests: 163 passed / 6 skipped.
init_prompt_messages=[{"role": "user", "content": prompt}]
)

async def step(self, completion_message: Message) -> MessageEnvStepOutput:

Contributor

add a del completion_message in the first line

# TODO: rename `Episode` -> `TrainingSample`
# and `rollout_to_episode` -> `rollout_to_training_sample`
@dataclass(kw_only=True, slots=True)
class Episode:

Contributor

why don't we rename this to TrainingSample?

metrics=[Metric("generator/queue_time_ms", ...)])
"""

policy_version: int

Contributor

With hot-swap, does this need to be richer?

Comment on lines +102 to +103
log_ratio = torch.nan_to_num(log_ratio)
log_ratio = torch.clamp(log_ratio, -20.0, 20.0)

Contributor

Should we clamp when the values are not infinity? Can we do torch.nan_to_num(log_ratio, nan=0.0, posinf=None, neginf=None)
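
For context on this question: as far as I understand the documented defaults, `torch.nan_to_num` already uses `nan=0.0`, and `posinf=None` / `neginf=None` map infinities to the dtype's largest and smallest finite values, so the suggested call behaves like the existing one and `exp()` of the result could still overflow without the clamp. The sketch below is a pure-Python scalar stand-in for the two tensor ops (not the PR's code; `guard_log_ratio` is a hypothetical name) showing what the guard does to each kind of input.

```python
import math


def guard_log_ratio(log_ratio: float, bound: float = 20.0) -> float:
    """Scalar stand-in for nan_to_num followed by clamp(-bound, bound)."""
    if math.isnan(log_ratio):
        log_ratio = 0.0  # NaN -> 0, i.e. a probability ratio of exp(0) = 1
    # Clamping also handles +/-inf and large finite values.
    return max(-bound, min(bound, log_ratio))


values = [0.3, float("nan"), float("inf"), float("-inf"), 35.0]
guarded = [guard_log_ratio(v) for v in values]
ratios = [math.exp(v) for v in guarded]  # always finite after the guard
```

Ordinary values such as 0.3 pass through unchanged; only NaN, infinities, and out-of-range values are altered.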

Comment on lines +122 to +126
# Fraction of response tokens whose generator (vLLM) logprob is nan
"loss/generator_logprob_nan_frac": (
(~torch.isfinite(generator_logprobs)).float() * loss_mask
).sum()
/ loss_denominator,

Contributor

Will this be summed over DP ranks before logging?
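
Why the aggregation order matters, as a pure-Python sketch with made-up per-rank data. It assumes each rank's denominator is its own count of loss-masked tokens, which may not match how `loss_denominator` is defined in the PR. Summing numerators and denominators across ranks before dividing gives the global fraction; averaging per-rank fractions weights ranks equally regardless of how many tokens they hold.

```python
import math


def nonfinite_count(generator_logprobs: list[float], loss_mask: list[bool]) -> int:
    """Number of loss-masked tokens whose generator logprob is NaN or +/-inf."""
    return sum(
        1
        for logprob, masked_in in zip(generator_logprobs, loss_mask)
        if masked_in and not math.isfinite(logprob)
    )


# (generator_logprobs, loss_mask) per DP rank; values are illustrative only.
rank_batches = [
    ([0.0, float("nan"), -0.2, -0.3, -0.4], [False, True, True, True, True]),
    ([0.0, float("-inf"), -0.1], [False, True, True]),
]

numerators = [nonfinite_count(logprobs, mask) for logprobs, mask in rank_batches]
denominators = [sum(mask) for _, mask in rank_batches]

# Reduce counts first, then divide: the fraction over all response tokens.
global_frac = sum(numerators) / sum(denominators)
# Divide per rank, then average: differs when ranks hold different token counts.
mean_of_rank_fracs = sum(n / d for n, d in zip(numerators, denominators)) / len(rank_batches)
```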

await pull_applied

@sl.log_trace_span("pull_weights_copy")
async def _pull_weights(self, version: int) -> None:

Contributor

If possible, I'd like to unify the naming and use model_state_dict instead of weights, at the cost of everything being longer.

Suggested change
- async def _pull_weights(self, version: int) -> None:
+ async def _pull_model_state_dict(self, version: int) -> None:

generation_future = self._generation_futures.pop(request_output.request_id)
metrics_prefix = generation_future.metrics_prefix
# Sanity check to avoid unwanted behavior.
if len(request_output.outputs) != 1:

Contributor

Hmm, one engine.step can't finish two requests at the same time? That sounds surprising.

)
return LoopDecision(action=LoopAction.STEP, requests=requests)

def _resolve_finished_requests(self, request_outputs: list[RequestOutput]) -> None:

Contributor

Suggested change
- def _resolve_finished_requests(self, request_outputs: list[RequestOutput]) -> None:
+ def _process_finished_requests(self, request_outputs: list[RequestOutput]) -> None:

request_output, prefix=metrics_prefix
)
# +1 for the just-popped request
inflight_at_completion = float(len(self._generation_futures) + 1)

Contributor

why +1, technically it's not in-flight any more? I guess please define in-flight.
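
One way to read the `+ 1`, sketched with a plain dict standing in for `_generation_futures` (the helper name and the request ids are hypothetical): the metric counts requests that were in flight at the moment this one completed, including itself. A side effect of popping inside the loop is that requests finishing in the same step observe different values.

```python
def resolve_finished(inflight: dict[str, str], finished_ids: list[str]) -> dict[str, float]:
    """Pop each finished request and record how many were in flight, itself included."""
    inflight_at_completion: dict[str, float] = {}
    for request_id in finished_ids:
        inflight.pop(request_id)
        # +1 puts the just-popped request back into the count.
        inflight_at_completion[request_id] = float(len(inflight) + 1)
    return inflight_at_completion


inflight = {"r0": "future0", "r1": "future1", "r2": "future2", "r3": "future3"}
# r1 and r3 finish in the same engine step.
observed = resolve_finished(inflight, ["r1", "r3"])
```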

f"{os.getpid()=} Generator pulled model state dict for policy v{version}"
)
if self.config.reset_prefix_cache_on_weight_sync:
# TODO(async): under hot-swap, prefer per-token weight-version tracking over a full

Contributor

hmm sounds tricky

  • When hot-swap is off, we definitely want this to be True. Otherwise a later request can reuse a prefix cache built by an earlier request, and then turning hot-swap off doesn't make sense.
  • When hot-swap is off, it doesn't matter whether reset_running_requests_on_weight_sync is True or False, because there are no in-flight requests before or after the pull.
  • @liangel-02 added this when debugging batch invariance, so we should keep that alignment: if batch invariance is enabled, reset_prefix_cache_on_weight_sync must be True.
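
The constraints in the bullets above could be enforced at config-construction time. This is only a sketch: `WeightSyncConfig`, the `batch_invariant` field, and all defaults are hypothetical; the two `reset_*_on_weight_sync` flag names and `hot_swap` come from the PR discussion, but where they actually live may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WeightSyncConfig:
    """Hypothetical config holder; defaults here are assumptions, not the PR's."""
    hot_swap: bool = True
    reset_prefix_cache_on_weight_sync: bool = True
    reset_running_requests_on_weight_sync: bool = False
    batch_invariant: bool = False

    def __post_init__(self) -> None:
        if self.batch_invariant and not self.reset_prefix_cache_on_weight_sync:
            raise ValueError(
                "batch invariance requires reset_prefix_cache_on_weight_sync=True"
            )
        if not self.hot_swap and not self.reset_prefix_cache_on_weight_sync:
            raise ValueError(
                "with hot-swap off, the prefix cache must be reset on weight sync"
            )


def is_valid(**kwargs: bool) -> bool:
    """Return True if the given flags pass validation."""
    try:
        WeightSyncConfig(**kwargs)
        return True
    except ValueError:
        return False


default_ok = is_valid()
hot_swap_keeps_cache_ok = is_valid(hot_swap=True, reset_prefix_cache_on_weight_sync=False)
drain_keeps_cache_ok = is_valid(hot_swap=False, reset_prefix_cache_on_weight_sync=False)
invariant_keeps_cache_ok = is_valid(batch_invariant=True, reset_prefix_cache_on_weight_sync=False)
```

No check is added for reset_running_requests_on_weight_sync with hot-swap off, since the second bullet says its value has no effect in that mode.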

Labels

ciflow/rl, ciflow/8gpu, CLA Signed

2 participants