Skip to content

Add a router for multiple generators#3583

Merged
tianyu-l merged 7 commits into
pytorch:mainfrom
pzhan9:router
Jun 10, 2026
Merged

Add a router for multiple generators#3583
tianyu-l merged 7 commits into
pytorch:mainfrom
pzhan9:router

Conversation

@pzhan9

@pzhan9 pzhan9 commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

For large scale RL, we normally use more than 1 generators. In order to spread the load more evenly, we need a router to route requests among them, with a pre-selected routing strategy.

This PR adds a GeneratorRouter class for this purpose. In addition, it adds two routing strategies:

  • round robin
  • least_loaded

GeneratorRouter also has a sync_weights method which can be used to sync all the generators' weight. It supports 2 modes:

  • hot swap:
    • in this mode, the weight sync does not interrupt generation
  • drain (i.e. not hot swap):
    • in this mode, the weight sync will wait until the in-flight generation completes.

@pzhan9 pzhan9 requested a review from felipemello1 June 9, 2026 02:43
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 9, 2026
@pzhan9 pzhan9 requested a review from tianyu-l June 9, 2026 02:43
Comment on lines 191 to 195
trainer_mesh, generator_mesh = spawn_proc_mesh(
trainer_world_size,
generator_world_size,
host_meshes=None,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it is good to make it work with single node? e.g. trainer 4 GPUs, generator 2 x 2 GPUs.

This way we get to guard on it even with 8 GPU CI.

@pzhan9 pzhan9 Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I I do plan to work on that as part of the "PlacementConfig" work. Right now, without that config option, there is no good place to express 2 generators in the Trainer.Config.

Note that I do not think we should add the generator_num option to GeneratorRouter.Config, since imo, router should be agnostic to the placement.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believed I understood the entire logic, it looks very good.

It'd be really important to enable this for single-node and thus guarded by CI, as @felipemello1 will develop using the router setup. Not having it runnable on single-node or in CI is a big downside.

Could you do some temporary change to unblock?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I can do that. Do you prefer to do it in this PR, or in a different one? I prefer either a prep PR, or a follow-up PR, since that leaves the change history cleaner.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can do in another PR!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the temporary solution in: #3624

Comment thread torchtitan/experiments/rl/trainer.py Outdated
Comment thread torchtitan/experiments/rl/trainer.py Outdated
Comment thread torchtitan/experiments/rl/trainer.py Outdated
Comment thread torchtitan/experiments/rl/trainer.py Outdated
Comment thread torchtitan/experiments/rl/router.py Outdated
Comment thread torchtitan/experiments/rl/config_registry.py Outdated
Comment thread torchtitan/experiments/rl/router.py Outdated
Comment thread torchtitan/experiments/rl/router.py Outdated
Comment thread torchtitan/experiments/rl/generator_router.py Outdated
felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 9, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent
requests + multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing
move to the upstream GeneratorRouter (pytorch#3583, which calls the same
pull_model_state_dict endpoint); per-token version tracking moves to the async PR.
PR2 is the readable continuous-batching core only.

## Summary
- generator.py: per-request `generate` (enqueue + await a future) driven by one
  background `_engine_loop`; rank 0 decides + broadcasts a LoopCommand, all ranks step
  in TP lockstep. New code at the file bottom (style.md §39).
- `pull_model_state_dict`: weight pull routed through the loop as LoopAction.PULL
  (push/pull symmetry with the trainer); the engine never self-drains. In the
  synchronous loop it runs on an idle engine, so the controller calls it directly.
- Metrics ride with the rollout (`RolloutTurn.metrics`); `pop_metrics` removed.
- Multi-turn: `run_single_rollout` loops generate -> env.step; `rollout_to_episodes`
  flattens one episode per turn (prefix-match branching deferred, TODO).

## Why
Land the original goal — concurrent requests + multiturn — without duplicating the
incoming router. Drain/hotswap is the router's job (pytorch#3583); off-policy/version tracking
is the async PR's. See discussions/57_async_rl_pipeline/pr2_readable_cb_hotswap/.

## Validation
127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

## Review focus
- generator.py readability: one tick-trace + literal names, no walkthrough needed
- per-request `generate` + the engine loop (two-barrier TP lockstep)
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics ride-along (RolloutTurn.metrics) replacing pop_metrics

## Risks / open questions
- Ships with NO weight-sync hold. Safe for PR2 because the synchronous loop drains
  naturally (no generate is in flight at sync time, so the pull hits an idle engine).
  The hold returns via the router; the async PR (PR3) must integrate the router or
  re-add a hold before generation overlaps weight sync.
felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 9, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent
requests + multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing
move to the upstream GeneratorRouter (pytorch#3583, which calls the same
pull_model_state_dict endpoint); per-token version tracking moves to the async PR.
PR2 is the readable continuous-batching core only.

## Summary
- generator.py: per-request `generate` (enqueue + await a future) driven by one
  background `_engine_loop`; rank 0 decides + broadcasts a LoopCommand, all ranks step
  in TP lockstep. New code at the file bottom (style.md §39).
- `pull_model_state_dict`: weight pull routed through the loop as LoopAction.PULL
  (push/pull symmetry with the trainer); the engine never self-drains. In the
  synchronous loop it runs on an idle engine, so the controller calls it directly.
- Metrics ride with the rollout (`RolloutTurn.metrics`); `pop_metrics` removed.
- Multi-turn: `run_single_rollout` loops generate -> env.step; `rollout_to_episodes`
  flattens one episode per turn (prefix-match branching deferred, TODO).

## Why
Land the original goal — concurrent requests + multiturn — without duplicating the
incoming router. Drain/hotswap is the router's job (pytorch#3583); off-policy/version tracking
is the async PR's. See discussions/57_async_rl_pipeline/pr2_readable_cb_hotswap/.

## Validation
127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

## Review focus
- generator.py readability: one tick-trace + literal names, no walkthrough needed
- per-request `generate` + the engine loop (two-barrier TP lockstep)
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics ride-along (RolloutTurn.metrics) replacing pop_metrics

## Risks / open questions
- Ships with NO weight-sync hold. Safe for PR2 because the synchronous loop drains
  naturally (no generate is in flight at sync time, so the pull hits an idle engine).
  The hold returns via the router; the async PR (PR3) must integrate the router or
  re-add a hold before generation overlaps weight sync.
@pzhan9

pzhan9 commented Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

@tianyu-l I addressed most of the comments, while leaving some open for discussion. Please take another look.

felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 9, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests +
multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream
GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version
tracking moves to the async PR. PR2 is the readable continuous-batching core only.

## Summary
- generator.py: per-request `generate` (enqueue + await a future) driven by one background
  `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep.
  Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull
  symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The
  rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match
  + branching is a TODO.

## Why
Land the original goal — concurrent requests + multiturn — without duplicating the incoming router.

## Validation
127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

## Review focus
- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO

## Risks / open questions
- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the
  router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 9, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests +
multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream
GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version
tracking moves to the async PR. PR2 is the readable continuous-batching core only.

## Summary
- generator.py: per-request `generate` (enqueue + await a future) driven by one background
  `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep.
  Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull
  symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The
  rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match
  + branching is a TODO.

## Why
Land the original goal — concurrent requests + multiturn — without duplicating the incoming router.

## Validation
127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

## Review focus
- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO

## Risks / open questions
- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the
  router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
Comment thread torchtitan/experiments/rl/generator_router.py Outdated
Comment thread torchtitan/experiments/rl/generator_router.py Outdated
Comment thread torchtitan/experiments/rl/generator_router.py Outdated

await self._serving.wait()
candidates = self._candidates()
assert candidates, "serving event was set with no serving generators"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can cause problem?

When all generators are in weight sync, sending a request will crash.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fine. It is protected by await self._serving.wait() above.

Basically, there are 2 layers of asyncio.Event in the router:

  • GeneratorHandle.state, which is at the single generator level
  • GeneratorRouter._serving, which is at the global level, and basically does book-keeping on wether there is any generator still serving.

The usage of the 2 layers is somewhat convoluted. But the benefit it brings is, now in the drain mode of weight sync, we do not need to stop/resume the generators all together. Instead, we can update them individually. e.g.:

  • generator A takes 2 minutes to drain and update weights;
  • generator B takes 1 minute.

Then B will start serving after it is updated, while A is still updating.

Of course at this stage, I am not quite sure if we need this level of optimization. If you find it is too convoluted, I could change it to a single global lock, which will stop/resume generators together. In that case, generator B will idle for 1 minute. But the overall code should be simpler. We could add the per-generator layer in the future when we find it is necessary.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yes, sorry I made the comment when I didn't understand the code, later when I undertstood I forgot to delete. Thanks for the explanation! I think the code quality is quite high and the logic is good.

Comment thread torchtitan/experiments/rl/generator_router.py Outdated
Comment thread torchtitan/experiments/rl/generator_router.py
Comment thread torchtitan/experiments/rl/generator_router.py Outdated
Comment on lines 191 to 195
trainer_mesh, generator_mesh = spawn_proc_mesh(
trainer_world_size,
generator_world_size,
host_meshes=None,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believed I understood the entire logic, it looks very good.

It'd be really important to enable this for single-node and thus guarded by CI, as @felipemello1 will develop using the router setup. Not having it runnable on single-node or in CI is a big downside.

Could you do some temporary change to unblock?

felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 9, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests +
multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream
GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version
tracking moves to the async PR. PR2 is the readable continuous-batching core only.

- generator.py: per-request `generate` (enqueue + await a future) driven by one background
  `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep.
  Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull
  symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The
  rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match
  + branching is a TODO.

Land the original goal — concurrent requests + multiturn — without duplicating the incoming router.

127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO

- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the
  router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 10, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests +
multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream
GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version
tracking moves to the async PR. PR2 is the readable continuous-batching core only.

- generator.py: per-request `generate` (enqueue + await a future) driven by one background
  `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep.
  Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull
  symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The
  rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match
  + branching is a TODO.

Land the original goal — concurrent requests + multiturn — without duplicating the incoming router.

127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO

- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the
  router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 10, 2026
Rewrite generator.py fresh from main for continuous batching: accept concurrent requests +
multiturn. The weight-sync hold (drain/hotswap) + multi-generator routing move to the upstream
GeneratorRouter (pytorch#3583, same pull_model_state_dict endpoint); per-token version
tracking moves to the async PR. PR2 is the readable continuous-batching core only.

- generator.py: per-request `generate` (enqueue + await a future) driven by one background
  `_engine_loop`; rank 0 decides a `LoopDecision` + broadcasts it, all ranks step in TP lockstep.
  Broadcast / admit / step are inlined into the loop; `max_steps_per_iteration` is configurable.
- `pull_model_state_dict`: weight pull routed through the loop as `LoopAction.PULL` (push/pull
  symmetry); the engine never self-drains. In the sync loop it runs on an idle engine.
- Metrics ride ON the completion (`Completion.metrics`); `generate -> Completion | None`. The
  rollout logger excludes `metrics` for now (Metric isn't JSON-serializable; TODO to fix).
- `rollout_to_episode`: one Episode per rollout (its last completed turn); multi-turn prefix-match
  + branching is a TODO.

Land the original goal — concurrent requests + multiturn — without duplicating the incoming router.

127 rl unit tests pass. GPU 20-step curve vs the PR1 baseline: PENDING.

- generator.py readability (names + the top-of-file flow), the two-barrier TP lockstep
- `pull_model_state_dict` routed through the loop (collective ordering, no self-drain)
- metrics on `Completion` + the logger `metrics` exclusion (TODO: make Metric JSON-friendly)
- `rollout_to_episode` single-episode behavior + the prefix-matching TODO

- Ships with no weight-sync hold: safe for PR2 (sync loop drains naturally); hold returns via the
  router. The async PR must integrate the router or re-add a hold before generation overlaps sync.
pzhan9 added 6 commits June 10, 2026 08:32
Pure rename of the generator-router module and its test
(router.py -> generator_router.py, test_router.py ->
test_generator_router.py) plus the import-path updates in trainer.py,
config_registry.py, and the tests. No logic changes.

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I learned a lot from this PR

Comment on lines 191 to 195
trainer_mesh, generator_mesh = spawn_proc_mesh(
trainer_world_size,
generator_world_size,
host_meshes=None,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can do in another PR!

@tianyu-l tianyu-l merged commit 7f0749e into pytorch:main Jun 10, 2026
11 of 12 checks passed
@pzhan9 pzhan9 deleted the router branch June 10, 2026 20:04
felipemello1 pushed a commit to felipemello1/torchtitan that referenced this pull request Jun 11, 2026
…g seam; simplify AlphabetSort example

- rollout/types.py: GenerateFn is now a Protocol with an explicit signature
  (prompt_token_ids/request_id/session_id/sampling_config -> Completion|None), not a loose Callable.
- session_id seam: run_single_rollout passes a stable per-rollout session_id (sticky-routing key)
  plus a per-turn request_id, threaded rollouter -> generate_fn -> generator.generate. A single
  generator ignores session_id; ready for the multi-generator router (pytorch#3583/pytorch#3625).
- run_single_rollout takes rollout_id (built in run_group_rollouts) instead of sample_idx.
- examples/alphabet_sort/env.py: fixed format example ending in "..." (dropped the randomized
  placeholder-row machinery); restored the original docstrings.
- docstring/comment cleanups (plain wording).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants