Add StickySessionRoutingStrategy for GeneratorRouter#3625
Conversation
| if config.max_sessions <= 0: | ||
| raise ValueError( | ||
| f"max_sessions must be positive, got {config.max_sessions}" | ||
| ) |
| if len(self._sessions) > self._max_sessions: | ||
| self._sessions.popitem(last=False) |
There was a problem hiding this comment.
add to TODO to log this using structured logger @felipemello1
maybe logger.warning for now?
| estimated_cost: int = 1 | ||
| """Estimated request cost used by load-aware routing strategies.""" | ||
|
|
||
| session_id: str | None = None |
There was a problem hiding this comment.
where are you assigning this to a context?
also this is leaking unnecessary info to non-stick routing strategies
| if sticky_generator is not None and any( | ||
| h is sticky_generator for h in candidates | ||
| ): | ||
| self._sessions.move_to_end(routing_ctx.session_id) | ||
| return sticky_generator |
There was a problem hiding this comment.
It seems if a generator is in weight-sync, you'd choose a new session, instead of wait until this generator's weight-sync finishes.
Do we know if this trade-off is worth it, especially for extreme long horizon rollout?
|
|
||
| chosen = self._fallback_strategy.choose(routing_ctx, candidates) | ||
| self._sessions[routing_ctx.session_id] = chosen | ||
| self._sessions.move_to_end(routing_ctx.session_id) |
There was a problem hiding this comment.
please add more comments on the strategy here
| return min(candidates, key=lambda h: h.reserved_load) | ||
|
|
||
|
|
||
| class StickySessionRoutingStrategy(RoutingStrategy): |
There was a problem hiding this comment.
The logic around assigning context.session_id is not clear, as it's not used anywhere https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/rl/trainer.py#L592
Most importantly, we should make sure all the same prompts from a single GRPO group can route to the same generator.
There was a problem hiding this comment.
please enhance the multi-generator CI test with this feature
…g seam; simplify AlphabetSort example - rollout/types.py: GenerateFn is now a Protocol with an explicit signature (prompt_token_ids/request_id/session_id/sampling_config -> Completion|None), not a loose Callable. - session_id seam: run_single_rollout passes a stable per-rollout session_id (sticky-routing key) plus a per-turn request_id, threaded rollouter -> generate_fn -> generator.generate. A single generator ignores session_id; ready for the multi-generator router (pytorch#3583/pytorch#3625). - run_single_rollout takes rollout_id (built in run_group_rollouts) instead of sample_idx. - examples/alphabet_sort/env.py: fixed format example ending in "..." (dropped the randomized placeholder-row machinery); restored the original docstrings. - docstring/comment cleanups (plain wording).
As title.