[rl] continuous-batching generator + multi-turn rollouts #3593
felipemello1 wants to merge 10 commits
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests + multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version tracking moves to the async PR. PR2 is the readable continuous-batching core only.

- generator.py: per-request `generate` (enqueue + await a future) driven by one background `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep. Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match + branching is a TODO.

Land the original goal (concurrent requests + multiturn) without duplicating the incoming router. 127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO
- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
…ric cleanups

Pack a multi-turn rollout into ONE training Episode (loss-masking each assistant turn) by token-prefix subtraction; branch into separate Episodes only where a turn's prompt diverges from the prior prompt+completion (env edited history). Episode.advantage becomes a per-token list aligned with token_ids.

Metrics:
- drop generator/output_tokens (derivable from completion length)
- dedup reward -> rollout_reward (validation already used validation_reward)
- add rollout/num_turns and rollout/branches_per_rollout
- inflight_at_completion: add Mean alongside Max, namespace by request prefix
- rename batcher/packing_efficiency -> pct_pad_in_batch (true padding ratio)

Also fix the collection-loop token estimate to count the packed-episode length per rollout (was summing per-turn, which N-counted the growing prefix and collected too few rollouts for multi-turn); alphabet_sort max_turns default 1->5.
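A minimal sketch of the packing idea in this commit message, assuming each turn is a (prompt tokens, completion tokens) pair. `Turn`, `PackedEpisode`, and `pack_turns` are hypothetical names for illustration, not the PR's `rollout_to_episodes`; logprobs and advantages are left out to keep it short.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Turn:  # hypothetical stand-in for one generated turn
    prompt_token_ids: list[int]
    completion_token_ids: list[int]


@dataclass
class PackedEpisode:  # hypothetical; the PR's Episode carries more fields
    token_ids: list[int] = field(default_factory=list)
    loss_mask: list[bool] = field(default_factory=list)


def pack_turns(turns: list[Turn]) -> list[PackedEpisode]:
    episodes: list[PackedEpisode] = []
    current: PackedEpisode | None = None
    for turn in turns:
        prompt = turn.prompt_token_ids
        seen = len(current.token_ids) if current is not None else 0
        extends = current is not None and prompt[:seen] == current.token_ids
        if not extends:
            # First turn, or the env edited history: branch into a new episode.
            current = PackedEpisode()
            episodes.append(current)
            seen = 0
        # Prefix subtraction: only tokens after the already-packed prefix are new.
        new_prompt = prompt[seen:]
        current.token_ids += new_prompt + turn.completion_token_ids
        # Train only on assistant (completion) tokens.
        current.loss_mask += [False] * len(new_prompt)
        current.loss_mask += [True] * len(turn.completion_token_ids)
    return episodes


packed = pack_turns([Turn([1, 2], [3]), Turn([1, 2, 3, 4], [5, 6])])
branched = pack_turns([Turn([1, 2], [3]), Turn([9, 2, 3, 4], [5])])
```

In `packed`, the second prompt extends the first prompt+completion, so both turns land in one episode. In `branched`, the second prompt diverges at its first token, so a second episode is started.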
…trics + naming cleanups

- generator: engine-loop state split into request inbox / future outbox (GenerationRequest/WeightPullRequest/CloseRequest + GenerationFuture/_pull_future), _resolve_finished_requests, _fail_outstanding_futures, n=1 raise, weight-sync prefix-cache knobs
- rollouter/trainer: generate_fn + sample naming; GenerateFn moved to rollout/types
- metrics: output_tokens (mean/std/max) in prepare_rollout_metrics; fold reward/ into rollout_reward/
- tests updated
… guard, drop debug sample logging

- examples/alphabet_sort: copy verifiers' prompts verbatim (terse, self-contained per turn; <alphabetical_sorted> / <combined_alphabetical_sorted>; explicit "Mark any NEW names ... with `// new name!`"); fixed format example ending in `...` so the model doesn't copy a fixed row count; max_turns 5 -> 3 (verifiers default).
- trainer.py GRPOLoss: nan_to_num + clamp(log_ratio, -20, 20) before exp() so a non-finite vLLM cudagraph logprob can't NaN the loss; add metric loss/generator_logprob_nan_frac; TODO to record env input in the rollout recorder.
- rollout: remove _log_samples / last_completion_text / completion_text and the log_samples config; episode packing/branching docstring cleanup.
…g seam; simplify AlphabetSort example

- rollout/types.py: GenerateFn is now a Protocol with an explicit signature (prompt_token_ids/request_id/session_id/sampling_config -> Completion|None), not a loose Callable.
- session_id seam: run_single_rollout passes a stable per-rollout session_id (sticky-routing key) plus a per-turn request_id, threaded rollouter -> generate_fn -> generator.generate. A single generator ignores session_id; ready for the multi-generator router (pytorch#3583/pytorch#3625).
- run_single_rollout takes rollout_id (built in run_group_rollouts) instead of sample_idx.
- examples/alphabet_sort/env.py: fixed format example ending in "..." (dropped the randomized placeholder-row machinery); restored the original docstrings.
- docstring/comment cleanups (plain wording).
…pr2-readable-cb

# Conflicts:
#	torchtitan/experiments/rl/trainer.py
- Remove the _MAX_TURNS module constant + loop cap; the turn cap is deferred (TODO: add max_num_turns to TokenEnv.Config). The loop runs until the env is terminal.
- Docstrings: the Rollouter (not the controller) drives the rollouts; generate_fn "runs one generation" (routes through the generator, not "bound to one generator").
Not much to see here. I just updated the prompt to be more faithful to the original.
    )
    packing_metrics = [
        m.Metric(
            "batcher/packing_efficiency",
This metric was confusing. I replaced it with pct_pad_in_batch.
            ),
        ),
        m.Metric(
            "batcher/num_packed_rows",
This didn't seem to be relevant.
-    for ep in episodes:
-        prompt_len = len(ep.prompt_token_ids)
-        completion_len = len(ep.completion_token_ids)
-        raw_ids = ep.prompt_token_ids + ep.completion_token_ids
-        gen_lp = [0.0] * prompt_len + ep.completion_logprobs
-        loss_mask = [False] * prompt_len + [True] * completion_len
-        advantages = [0.0] * prompt_len + [ep.advantage] * completion_len
+    for episode in episodes:
         sample = {
-            "input_ids": raw_ids[:-1],
-            "labels": raw_ids[1:],
-            "generator_logprobs": gen_lp[1:],
-            "loss_mask": loss_mask[1:],
-            "advantages": advantages[1:],
+            "input_ids": episode.token_ids[:-1],
+            "labels": episode.token_ids[1:],
+            "generator_logprobs": episode.logprobs[1:],
+            "loss_mask": episode.loss_mask[1:],
+            "advantages": episode.advantage[1:],
         }
Before, we received a prompt + response and put them together here. Now, episodes already have the multi-turn token logic prebuilt. This happens in rollout_to_episodes.
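To illustrate the one-position shift in the new `sample` dict, here is a small sketch with made-up values. The field names mirror the diff above; the `build_sample` helper and the numbers are purely illustrative and not part of the PR.

```python
def build_sample(token_ids, logprobs, loss_mask, advantage):
    """Shift per-token fields so index i of each target-side list describes token i+1."""
    return {
        "input_ids": token_ids[:-1],        # model input: every token except the last
        "labels": token_ids[1:],            # next-token targets
        "generator_logprobs": logprobs[1:],
        "loss_mask": loss_mask[1:],
        "advantages": advantage[1:],
    }


# Two prompt tokens followed by two completion tokens (made-up values).
sample = build_sample(
    token_ids=[11, 12, 21, 22],
    logprobs=[0.0, 0.0, -0.5, -0.25],
    loss_mask=[False, False, True, True],
    advantage=[0.0, 0.0, 1.0, 1.0],
)
```

After the shift every list has the same length, and the mask is `True` exactly at the positions whose label is a completion token.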
    "reward/_mean",
    "reward/_max",
    "reward/zero_std_frac",
Before, we had both "reward" and "rollout_reward" prefixes, so I unified them.
    _ERROR = frozenset({"error_parse", "error_timeout", "error_abort", "error"})


    class GenerateFn(Protocol):
Check how it's used in trainer.py, but TL;DR: the rollout shouldn't need to know how to call a GeneratorRouter. Otherwise we would have to pass a bunch of args and call it in a fancy way.
It's better if this callable can be anything: a monarch endpoint, an HTTP endpoint, a mock, etc.
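A hedged sketch of what "this callable can be anything" could look like. The four parameter names come from the commit message above; making them keyword-only and the call `async` are assumptions, and `Completion`, `echo_generate`, and `run_one_turn` are stand-ins invented for illustration.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Any, Protocol


@dataclass
class Completion:  # stand-in; the real type lives in the PR
    token_ids: list[int]


class GenerateFn(Protocol):
    # Assumed shape: keyword-only and async are guesses, not confirmed by the PR.
    async def __call__(
        self,
        *,
        prompt_token_ids: list[int],
        request_id: str,
        session_id: str,
        sampling_config: Any,
    ) -> Completion | None: ...


async def echo_generate(*, prompt_token_ids, request_id, session_id, sampling_config):
    """Mock generator: returns the prompt reversed and ignores session_id."""
    return Completion(token_ids=list(reversed(prompt_token_ids)))


async def run_one_turn(generate_fn: GenerateFn, prompt: list[int]) -> Completion | None:
    # The rollout only needs the callable, not the router or endpoint behind it.
    return await generate_fn(
        prompt_token_ids=prompt,
        request_id="rollout-0/turn-0",
        session_id="rollout-0",
        sampling_config=None,
    )


result = asyncio.run(run_one_turn(echo_generate, [1, 2, 3]))
```

Because `Protocol` is structural, the mock needs no inheritance; any callable with a matching signature satisfies a type checker.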
Rebase onto pytorch#3593 (continuous-batching + multi-generator router). The async loop now:

- routes generation through self.generator_router (route/fanout/pull) instead of self.generator; setup_async spawns N generators and builds the router; close() fanouts.
- imports GenerateFn from rollout.types; _generate accepts+forwards session_id.
- calls Rollouter.run_group_rollouts with generate_fn=/sample= (renamed kwargs).
- ports pytorch#3593's GRPOLoss nan-guard (nan_to_num + clamp + generator_logprob_nan_frac).
- inlines the deleted last_completion_text helper in _log_samples.
- adds the generator_router Config field (set hot_swap=True for non-draining weight sync).

Tests: 163 passed / 6 skipped.
            init_prompt_messages=[{"role": "user", "content": prompt}]
        )

    async def step(self, completion_message: Message) -> MessageEnvStepOutput:
Add `del completion_message` as the first line.
    # TODO: rename `Episode` -> `TrainingSample`
    # and `rollout_to_episode` -> `rollout_to_training_sample`
    @dataclass(kw_only=True, slots=True)
    class Episode:
Why don't we rename this to TrainingSample?
            metrics=[Metric("generator/queue_time_ms", ...)])
        """

        policy_version: int
With hot-swap, does this need to be richer?
    log_ratio = torch.nan_to_num(log_ratio)
    log_ratio = torch.clamp(log_ratio, -20.0, 20.0)
Should we clamp when the values are not infinite? Can we do torch.nan_to_num(log_ratio, nan=0.0, posinf=None, neginf=None)?
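For context on this thread, a plain-Python analogue of the two guarded lines, written without torch so it runs anywhere. It is a sketch of the idea, not the PR's `GRPOLoss`; `guarded_ratio` and its `bound` parameter are hypothetical names. Note that in `torch.nan_to_num`, `posinf=None` and `neginf=None` are the defaults and map infinities to the largest and smallest finite values of the dtype, so that call alone would still overflow in `exp()`.

```python
import math


def guarded_ratio(log_ratio: float, bound: float = 20.0) -> float:
    """Plain-Python analogue of nan_to_num followed by clamp, then exp."""
    if math.isnan(log_ratio):
        log_ratio = 0.0  # NaN -> 0, so the ratio becomes exp(0) = 1
    # Clamping handles +/-inf and very large finite values in one step.
    log_ratio = max(-bound, min(bound, log_ratio))
    return math.exp(log_ratio)


ratios = [guarded_ratio(x) for x in (float("nan"), float("inf"), -float("inf"), 0.5)]
```

With a bound of 20, the ratio stays between exp(-20) and exp(20), and values already inside the bound pass through unchanged.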
    # Fraction of response tokens whose generator (vLLM) logprob is nan
    "loss/generator_logprob_nan_frac": (
        (~torch.isfinite(generator_logprobs)).float() * loss_mask
    ).sum()
    / loss_denominator,
Will this be summed over DP ranks before logging?
        await pull_applied

    @sl.log_trace_span("pull_weights_copy")
    async def _pull_weights(self, version: int) -> None:
If possible, I'd like to unify the naming, using model_state_dict instead of weights, at the cost of everything being longer.
-    async def _pull_weights(self, version: int) -> None:
+    async def _pull_model_state_dict(self, version: int) -> None:
    generation_future = self._generation_futures.pop(request_output.request_id)
    metrics_prefix = generation_future.metrics_prefix
    # Sanity check to avoid unwanted behavior.
    if len(request_output.outputs) != 1:
Hmm, one engine.step could not finish two requests at the same time? That sounds surprising.
        )
        return LoopDecision(action=LoopAction.STEP, requests=requests)

    def _resolve_finished_requests(self, request_outputs: list[RequestOutput]) -> None:
-    def _resolve_finished_requests(self, request_outputs: list[RequestOutput]) -> None:
+    def _process_finished_requests(self, request_outputs: list[RequestOutput]) -> None:
        request_output, prefix=metrics_prefix
    )
    # +1 for the just-popped request
    inflight_at_completion = float(len(self._generation_futures) + 1)
Why +1? Technically it's not in-flight any more. I guess, please define "in-flight".
        f"{os.getpid()=} Generator pulled model state dict for policy v{version}"
    )
    if self.config.reset_prefix_cache_on_weight_sync:
        # TODO(async): under hot-swap, prefer per-token weight-version tracking over a full
Hmm, sounds tricky.
- When hot-swap is off, we definitely want this to be true. Otherwise a later request can reuse an earlier request's prefix cache; doesn't make sense to turn hot-swap off.
- When hot-swap is off, it doesn't matter whether reset_running_requests_on_weight_sync is True or False, because there won't be in-flight requests before/after the pull.
- @liangel-02 added this when debugging batch invariance, so we should keep that alignment: "if batch invariance -> reset_prefix_cache_on_weight_sync must be True".
Note: It's a big PR. I reviewed it a few times, but there might still be rough edges and slop. Sorry/thank you.
How to review
a. `actors/generator.py`: the engine loop + `LoopDecision` (the core).
b. `rollout/rollouter.py`: `run_group_rollouts` / `run_single_rollout`.
c. `rollout/utils.py`: `rollout_to_episodes` + `batcher.py`.
d. `trainer.py`

I don't think that you should try to compare before/after. Just open the file and read the code.
Summary
A vLLM generator that drives many concurrent generations through ONE engine loop (continuous batching), plus the multi-turn rollout stack around it. A whole GRPO group's turns now coalesce into one batch instead of N blocking calls.
Mental model: request intake is decoupled from `engine.step` (the TP collective).

What's in it
- Continuous-batching generator (`actors/generator.py`): one `_engine_loop` per rank; rank0 fires independent generate/pull/close and resolves per-request futures.
- Rollout moved into a `Rollouter` (`rollout/rollouter.py`): it now also owns `run_group_rollouts` and `run_single_rollout`. The controller no longer drives rollouts; it just calls the rollouter.
- Episodes + batcher (`rollout/utils.py`, `batcher.py`): adapted to multi-turn.
- Weight sync / hotswap: the default is to swap weights in between tokens. We expose two flags so users can decide what to do with the KV cache. TODO: support drain then swap. This is not really relevant for this PR, since training is still synchronous.
- GRPO loss: got some NaNs from vLLM when cudagraph=True. I added clamping + nan_to_num + a metric. We should monitor and figure out why this happens with cudagraph and not in eager.
- Task: AlphabetSort multi-turn. Changed the default to max 3 turns (instead of 1) and adapted the prompt to the original that I copied from (verifiers).
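The "one engine loop, many futures" idea from the summary can be sketched as a toy, single-process asyncio program. It has no vLLM, no TP ranks, and no weight pulls; `ToyGenerator`, its fields, and the fake "token" it appends are all hypothetical and only illustrate how intake is decoupled from stepping.

```python
import asyncio


class ToyGenerator:
    """Toy stand-in: one background loop steps every active request together."""

    def __init__(self) -> None:
        self._inbox: asyncio.Queue = asyncio.Queue()
        self.batch_sizes: list[int] = []  # how many requests each step served

    async def generate(self, prompt: list[int], num_tokens: int) -> list[int]:
        # Intake only enqueues and awaits; it never steps the engine itself.
        future = asyncio.get_running_loop().create_future()
        await self._inbox.put((prompt, num_tokens, future))
        return await future

    async def engine_loop(self) -> None:
        active = []  # entries are [tokens_so_far, remaining, future]
        while True:
            if not active:
                # Idle: block until at least one request arrives.
                active.append(self._admit(await self._inbox.get()))
            while not self._inbox.empty():
                active.append(self._admit(self._inbox.get_nowait()))
            self.batch_sizes.append(len(active))
            for entry in active:  # one "engine step": one fake token per request
                entry[0].append(len(entry[0]))
                entry[1] -= 1
            for entry in [e for e in active if e[1] == 0]:
                entry[2].set_result(entry[0])  # resolve the per-request future
            active = [e for e in active if e[1] > 0]
            await asyncio.sleep(0)  # yield so callers can enqueue more work

    @staticmethod
    def _admit(item):
        prompt, num_tokens, future = item
        return [list(prompt), num_tokens, future]


async def main():
    gen = ToyGenerator()
    loop_task = asyncio.create_task(gen.engine_loop())
    results = await asyncio.gather(
        gen.generate([1], 2), gen.generate([1, 2], 3), gen.generate([7], 1)
    )
    loop_task.cancel()
    return results, gen.batch_sizes


results, batch_sizes = asyncio.run(main())
```

The three `generate` calls are independent awaits, yet `batch_sizes` shows them being served by shared steps rather than by three separate blocking calls.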
Test plan
AlphabetSort, Qwen3-0.6B, 6 GPUs (4 gen + 2 train), cudagraph on: