[RL] - MessageEnv, Rollout types, Rubric, Renderer#3453
Conversation
- Add shared pack() function in torchtitan/components/dataloading/utils.py - Add Batcher class with Configurable pattern: packs episodes into fixed [B, seq_length] TrainBatches with gradient accumulation support - Refactor TrainBatch to use [B, L] tensors instead of [1, total_tokens] - Update PolicyTrainer.forward_backward to work with [B, L] microbatches - Simplify compute_logprobs/verify_logprob_identity in actors/utils.py
- envs/: MessageEnv ABC + MessageReset/MessageStep; RendererEnv wrapper + RendererEnvConfig + TokenizedTurn; Rollout / RolloutTurn / RolloutStatus / DatasetOutput types. - rubrics/: Rubric base class with score_one / score_group, typed Reward dataclass; Rubric owns truncation_reward / error_reward (doc 37 Option B). - recipes/: Task base + SumDigitsTask (dataset + env + grader + recipe). - grpo.py: _run_rollouts refactored (Option G — inline orchestration + _do_group_step helper + per-group failure isolation); validate() reuses the same path. - Drops: sum_digits.py (orphan), test_grpo_metrics.py (broken), Trajectory + Step from types.py.
- rollouts/: new folder holding Rollout / RolloutTurn / RolloutStatus / DatasetOutput (types.py) and last_assistant_text / rollout_to_episode / prepare_rollout_metrics (utils.py). envs/types.py deleted. - Rubric: Configurable with register_funcs() hook + lazy @cached_property weight normalization (sum-to-1). New RewardFn dataclass (fn + weight). truncation_reward / error_reward become Optional[float] short-circuit knobs (None = run reward fns on partial response). - SumDigitsRubric subclass added in grader.py; SumDigitsDataset is now Configurable; SumDigitsTask.Config composes nested sub-configs (dataset / rubric / env_limits) matching PolicyTrainer.Config style. - MessageReset/Step -> MsgResponseReset/Step. - RendererEnvConfig -> EnvLimits; kwarg config -> limits. - recipes/sum_digits/recipe.py -> task.py. - _run_rollouts: overflow-aware prompt routing -- initial-prompt overflow builds TRUNCATED_OVERFLOW rollouts directly instead of sending an empty prompt to the generator. - prepare_rollout_metrics consolidates the rollout/validation metric blocks; _prepare_reward_metrics removed. - Cleanups: RolloutStatus is_*() via frozenset; validate_rollout_output dropped; Rollout.group_id / sample_idx required; _log_samples Episode-only; reward_correct stops asserting env_input type.
Task surface
- DatasetOutput.env_name -> DatasetOutput.task; SumDigitsDataset
ENV_NAME -> TASK_NAME.
- Dataset moves off Task onto RLTrainer.Config:
train: Task.Config -> train_dataset + tasks dict keyed by task name.
Same for validation.
Rows route to the matching Task via example.task.
- SumDigitsTask: drops dataset + sample_example; only rubric,
env_limits, and make_envs remain.
- Task base gains score_group(rollouts, env_input) -> list[Rollout]:
default delegates to self.rubric.score_group + fills reward and
reward_components. Step logic stays in grpo._do_group_step.
- TODOs: continuous-batching Task.do_single_rollout; revisit Camp A
vs B (dataset on Task vs framework).
Style + cleanups (same files)
- Dataclasses converted from per-field docstrings to Args-at-top with
inline shape comments. Drops double-backticks and stale PR
references.
- Bare # title block comments in _run_rollouts, _build_episodes,
validate, train, _do_group_step. Dashed banners removed in train.
- Dead code: _shard_episodes removed (Batcher does sharding).
- Bug fixes:
- prepare_rollout_metrics total_lens computed per-rollout
(multi-turn safe).
- _do_group_step initial.next_token_ids raises on None instead of
silent default to [].
- Rubric: register_funcs result checked for unique fn names.
- renderer_env: env_step.status defaulting uses replace() not
in-place mutation; _terminal accepts last_response_messages
instead of post-construction mutation.
Smoke: imports, config build, rubric e2e (3 statuses + short-circuit
modes), task.score_group fills rewards.
- Dataclasses: per-field docstrings; single backticks; drop internal-doc refs. - RendererEnv.step_completion parses before classifying length/abort, so truncated/aborted/timed-out turns keep their response message + tokens (partial-reward grading + debugging). - Completion.text removed (text comes from rollout messages); RolloutTurn gains reward/reward_components; TokenizedTurn carries terminal status. - Read pad/eos from the renderer's tokenizer; drop the standalone tokenizer. - Task.score_group returns list[Reward]; controller applies them. - Condense _run_rollouts; README installs renderers.
Conflicts resolved: - batcher.py / TrainingBatch labels / compute_logprobs -> take main (landed batcher PR). - grpo.py / config_registry / types Completion+Episode -> keep v8 (our rollout/env/rubric/recipe layer), adapted to main's batcher API. - config_registry: keep recipes structure (tasks/dataset), import BatchConfig from rl.batcher, enable_wandb=True (matches main). - Dropped stale test_grpo_metrics.py (tested the removed Step/Trajectory API).
renderers 0.1.8.dev37 replaced the name-based factory with a typed-config one. RendererConfig.build picks the config variant for name from the public discriminated union and passes only supported knobs; adds an enable_thinking field. Removes the private-attr workaround.
The merge with the landed batcher PR pulled in changes that don't belong
to this PR:
- Drop half-done-batcher-base leftovers: BatchConfig in config/configs.py
(the landed batcher has its own in rl.batcher), config/__init__ export,
components/dataloading/{__init__,utils}.py, and the text_datasets.py
rewrite. RL never imports any of these.
- grpo.py: revert train() docstring, reworded comments, and the
per-microbatch metrics rewrite back to upstream's aggregation; revert the
setup_async docstring; re-add _shard_episodes (dead in upstream too).
Keep only the genuine v8 substitutions (Rollout types, async
_collect_rollouts, _build_episodes).
- _run_rollouts/_collect_rollouts return list[RolloutGroup]; _PendingGroup build struct; whole-group drop on prompt overflow; flat n=1 generate with completions zipped 1:1 (no rebucket). - _do_group_step -> _do_single_rollout (one env+completion -> Rollout) with try/except inside so partial turns survive as an ERROR rollout. - RolloutStatus gains ONGOING + ERROR + is_terminal(); TRUNCATED_OVERFLOW -> TRUNCATED_PROMPT_OVERFLOW; TokenizedResponseStep.status is required. - _is_overflow -> _is_prompt_overflow (prompt_len >= max_rollout_tokens); drop unused EnvLimits.max_generation_tokens. - TokenizedTurn -> TokenizedResponseStep; Rubric.register_funcs is @abc.abstractmethod; logging TODOs on Rollout/RolloutTurn.
- config_registry: enable_thinking=True, max_tokens=700 for all rl_grpo_qwen3_* configs. Qwen3-0.6B one-shots sum_digits with thinking + enough budget (reward ~0.95-1.0 from step 1); 100/200 tokens truncate mid-<think>. - rollouts/types.py: shorten RolloutStatus docstring; group RolloutTurn fields.
…build - recipes/ -> tasks/; DatasetOutput.task -> task_name. - env carriers: MsgResponseReset/Step -> ResetOutput/StepOutput; TokenizedResponseStep -> TokenizedStepOutput. StepOutput drops status (keeps done); RendererEnv owns RolloutStatus. - EnvLimits -> RendererEnvConfig (field renderer_env_config; RendererEnv(config=)). - next_token_ids/next_messages -> next_prompt_token_ids/next_prompt_messages. - last_response_messages/response_messages split -> assistant_message + env_messages. - RolloutTurn gains prompt_messages; restore env-set reward path: StepOutput.reward_components -> TokenizedStepOutput.env_reward_components -> RolloutTurn.reward_components. - renderer.py build() via pydantic TypeAdapter; drop defensive list copy; log parse/timeout exceptions; MessageEnv docstring + Example; revert 9 unrelated RLTrainer.Config docstrings; 'policy' -> 'limits'.
- _build_episodes: drop a group when any sibling has no turns (turn-less ERROR rollout) instead of checking reward-is-None. The old check never fired (score_group always sets rewards) and let a turn-less rollout reach rollout_to_episode, which raises (requires exactly one turn). - rollout_to_episode: derive text from the rollout (last_assistant_text) instead of taking it as a param; drop the now-unused import in grpo.
tianyu-l
left a comment
There was a problem hiding this comment.
Thanks, did a first pass. Overall I feel the logic among Env, Task, Rollout, Dataset, Rubric could be more clear.
| return sum(s.reward for _, s in self.transitions) | ||
|
|
||
|
|
||
| # TODO: rename `Episode` -> `TrainSample` and `rollout_to_episode` -> |
There was a problem hiding this comment.
nit: TrainingSample, to be consistent with TrainingBatch
| uv pip install torchmonarch==0.4.1 | ||
| uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main" | ||
| uv pip install pygtrie portpicker | ||
| uv pip install "git+https://github.com/PrimeIntellect-ai/renderers.git@main" |
There was a problem hiding this comment.
would this be used for sft as well?
There was a problem hiding this comment.
Do we need to depend on renders lastest main?
| preserve_thinking_between_tool_calls: bool = False | ||
|
|
||
| def build(self, *, model_path: str) -> Renderer: | ||
| from transformers import AutoTokenizer |
There was a problem hiding this comment.
Can we avoid this dependency?
There was a problem hiding this comment.
| `model_path` and constructs the `renderers` config matching `name`. | ||
|
|
||
| Args: | ||
| name: Renderer name (e.g. `"qwen3"`, `"auto"`). |
There was a problem hiding this comment.
Not clear what this means. We should make it explicitly referring to torchtitan model (and their tokenizer) and avoid transformers dependency.
| top_p=_sampling_config.top_p, | ||
| max_tokens=_sampling_config.max_tokens, | ||
| n=_sampling_config.n, | ||
| stop_token_ids=list(_sampling_config.stop_token_ids) or None, |
There was a problem hiding this comment.
| stop_token_ids=list(_sampling_config.stop_token_ids) or None, | |
| stop_token_ids=_sampling_config.stop_token_ids or None, |
| """The env's reply messages this turn (tool / user).""" | ||
|
|
||
| # For rubrics | ||
| reward_components: dict[str, float] = field(default_factory=dict) |
There was a problem hiding this comment.
can we call it reward which should always decompose into "components"
There was a problem hiding this comment.
I think that we want to have a simple field "reward: float", because thats what the final loss will use. No ambiguity here. And components is the breakdown for logging or weight averaging or some specific advantage computation.
There was a problem hiding this comment.
If single field reward is always coming from the components, can we have cached_property that derive the single field from components. If not (instead single field is the only thing RL algorithm eventually care, and everything else is only for logging), then I'm fine with it.
| assistant_message: Message | None = None | ||
| """The model's parsed turn (generator output as a message).""" | ||
|
|
||
| env_messages: list[Message] = field(default_factory=list) # [M_env] |
There was a problem hiding this comment.
how is env_messages related to prompt_messages?
There was a problem hiding this comment.
Prompt_messages: input to the generator (all history up to that point)
assistant_message: output of genereator
env_message: output of the environment, e.g. tool calls, new user message, etc
There was a problem hiding this comment.
butsomewhere else you used next_prompt_...
Is it the same as env_message or a subset?
There was a problem hiding this comment.
so "prompt" is previous [env + assistant] history, but next_prompt only means partial data in a turn
assistant vs. generator used interchangeably?
|
|
||
|
|
||
| @dataclass(frozen=True, kw_only=True, slots=True) | ||
| class DatasetOutput: |
There was a problem hiding this comment.
"Output" sounds confusing, it can be input to rollout / grader
There was a problem hiding this comment.
maybe DataSample?
| from torchtitan.experiments.rl.tasks.sum_digits.grader import SumDigitsRubric | ||
|
|
||
|
|
||
| class SumDigitsTask(Task): |
There was a problem hiding this comment.
The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.
I feel we should do this. The RLTrainer should just own trainer, generator, define task / env and maybe pass generator to them to obtain rollouts.
| Returns: | ||
| One `Reward` per rollout, in input order. | ||
| """ | ||
| return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts)) |
There was a problem hiding this comment.
After this PR we are still in Sync RL, but defining these functions here as they are basic for async RL?
There was a problem hiding this comment.
we do asyncio.gather because the reward_fns can be multiple LLMs, for example. So we can run them in parallel. This is orthogonal to async/sync RL. Does this make sense?
| return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts)) | ||
|
|
||
|
|
||
| def _fn_name(fn: Callable) -> str: |
There was a problem hiding this comment.
What's this helper function for? Can we just inline it?
There was a problem hiding this comment.
We should set up CPU CI test for RL to guard these tests
| """Ground-truth total digit sum.""" | ||
|
|
||
|
|
||
| class SumDigitsDataset(Configurable): |
There was a problem hiding this comment.
What's the consideration of putting the dataset a field of SumDigitsTask.Config() vs. connecting them using TASK_NAME? A Task can be the container of Env, Rubric and Dataset?
There was a problem hiding this comment.
Can we rename to rubrics.py which more aligned with our naming now
| self._validation_dataset = config.validation_dataset.build() | ||
| self._str2task_map: dict[str, Task] = { | ||
| name: cfg.build() for name, cfg in config.tasks.items() | ||
| } |
There was a problem hiding this comment.
We only have one task in our current RL loop now, are you going to support data mix soon? I guess we can simplify in this PR and only support one task for now, and handle multi-tasks together with data-mix PR
| completions, generation_metrics = self._get_rank_0_value( | ||
| self.generator.generate.call(tokenized_prompts).get() | ||
| group_size = self.config.generator.sampling.n | ||
| sampling_cfg = replace( |
There was a problem hiding this comment.
Is stop_token_ids same as eos_ids?
There was a problem hiding this comment.
thats my understanding, but we should get it directly from the renderer.tokenizer. Also, maybe the user can have some specific logic to stop when some token T appears
|
|
||
| Steps: | ||
| 1. Get examples from dataset | ||
| 2. For each example, find associated task, e.g. CodingTask, SearchTask, etc |
There was a problem hiding this comment.
So the reason we don't put dataset a subfield of Task is because of data-mixing? We will do datamixing at DatasetOutput level, not Task level? What does other repo model these concepts?
There was a problem hiding this comment.
What does other repo model these concepts?
I dont recall. I can take a look. Another option is to play it by year, so what we are comfortable with and refactor later when we try datamix. For now, lets put it inside of the Task, since thats a pattern i have seen as well and both and tianyu shared that it made more sense to you. I will make the changes.
| group_id=f"{example.task_name}/step={step}/group={group_offset + group_idx}", | ||
| example=example, | ||
| task=task, | ||
| envs=task.make_envs( |
There was a problem hiding this comment.
Do we need to create enviornments repeatedly for each sample? Eg, spin up a docker for each single CodingTask?
Or can we create a Enviornment for each Task?
There was a problem hiding this comment.
its up to the user to decide what happens inside of "make_envs", i.e. create a fresh new one or pull from a pool
| # 4. For each env, get initial prompt (n_groups * n_rollouts_per_group) | ||
| initial_steps: list[list[TokenizedStepOutput]] = await asyncio.gather( | ||
| *( | ||
| asyncio.gather(*(env.initial_prompt() for env in group.envs)) |
There was a problem hiding this comment.
What is initial_prompt here? Can you give an example? I'm confused why the prompt doesn't come from dataset
There was a problem hiding this comment.
the sample (history of messages) comes from the dataset. The env adds the system message and adds tool calls. The prompt is the final constructed input for the generator. But notice that we are not opinionated about it: if the user wants the env_input to include the system prompt as well, they can do it.
- env_types: MessageEnv (messages) + RendererWrapperEnv (tokens); Message{Reset,Step}Output
- rubrics: Configurable RewardFn + concrete Rubric + Reward; SumDigits grader -> rubric
- rollouts/types: RolloutTurn/Rollout/RolloutGroup; RolloutStatus incl TRUNCATED_PROMPT_TOO_LONG
- tasks: Task ABC owns rubric; SumDigits dataset/env/rubric/task; datasets on Task
- grpo/generator/renderer/config_registry: group_size, request_idx, renderer ModelSpec map
tianyu-l
left a comment
There was a problem hiding this comment.
Overall structure looks reasonable. I still ranted a lot on variable naming, as I think that's the key to hackability.
There was a problem hiding this comment.
In general I would prefer full spelling over shorthand.
Could we do
- rl/environments/message.py
- rl/environments/renderer.py
There was a problem hiding this comment.
I wanted to signal somehow that those are "base" envs, not something like a SumDigitsEnvs, thats why i put env_types. I am cool with environments, but i still think that we should signal somehow that those are types. Any ideas?
There was a problem hiding this comment.
I see. If we follow strictly the sft counterpart, for classes / interfaces you would expect users to implement, it probably should be
protocols/environmentsprotocols/rubrics- etc.
Meanwhile, components for things that torchtitan implement for you (but also overridable).
|
|
||
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class MessageResetOutput: |
There was a problem hiding this comment.
Can we call this MessageInitialOutput or MessageInitOutput.
reset sounds cleaning up from an old state which this message shouldn't care.
If you want strong consistency with env.reset, I proposal change that as well to env.init
There was a problem hiding this comment.
It also sounds unnatural when MessageOutput owns messages and tools.
Maybe should call RawOutput / EnvOutput / UnrenderedOutput vs. RenderedOutput
| class MessageResetOutput: | ||
| """Initial prompt messages + tool specs from `MessageEnv.reset`.""" | ||
|
|
||
| prompt_messages: list[Message] # [M_prompt] |
There was a problem hiding this comment.
Reserving prompt for the first message sounds ambiguous.
Can we call it init_messages or just messages?
| # env replies are tool/user turns; the assistant turn comes from the generator | ||
| if any(m.get("role") == "assistant" for m in self.env_messages): | ||
| raise ValueError( | ||
| "MessageStepOutput.env_messages may not contain assistant messages" |
There was a problem hiding this comment.
this check should be added for init output as well?
| """ | ||
|
|
||
| @abc.abstractmethod | ||
| async def reset(self) -> MessageResetOutput: |
There was a problem hiding this comment.
is this too controversal
| async def reset(self) -> MessageResetOutput: | |
| async def init(self) -> MessageResetOutput: |
|
|
||
| gen_metrics: list[m.Metric] = [] | ||
| # 3. Reset every env to get its first prompt. | ||
| prompt_steps_per_group_state = await asyncio.gather( # [G][group_size] |
| request_ids=request_ids, | ||
| sampling_config=sampling, | ||
| metrics_prefix=generation_metrics_prefix, | ||
| ).get() |
| 5. Run one batched `generate` over valid groups (n=1; pre-expanded) | ||
| 6. For each group: step generated rollouts if needed, then score with `task.score_group` |
There was a problem hiding this comment.
What's the consideration for these two phases. I understand that each rollout may finish in different number of turns -- how would we deal with the inefficiency of "finish one rollout at a time" later?
There was a problem hiding this comment.
when we land the generator refactor, there wont be a "batch" concept. We will change this function to continuously feed new examples to a queue that produces RolloutGroup in its own stream and put in the buffer. does this answer the question?
| @@ -238,19 +252,21 @@ class Config(Configurable.Config): | |||
|
|
|||
| num_prompts_per_step: int = 5 | |||
There was a problem hiding this comment.
| num_prompts_per_step: int = 5 | |
| num_groups_per_iteration: int = 5 |
| self.config.num_prompts_per_step * self.config.group_size, | ||
| self.config.num_validation_samples, |
There was a problem hiding this comment.
Add comment on how this number is derived. It looks you are using asyncio at multiple places, does this number capture all the concurrency?
|
@tianyu-l , i saw that you added more comments. I will read and address them. Regarding my latest changes: Thanks for the thorough review @tianyu-l, @wwwjn. I pushed some changes. Here is the before/after: Goal for the code (we cannot do it until we fix the generator) async def do_single_rollout(sampler, env) -> Rollout:
turns = []
prev_step = await env.reset() # TokenizedStepOutput (first prompt)
while not prev_step.status.is_terminal():
completion = await sampler.generate(prev_step.next_prompt_token_ids)
step = await env.step(completion)
turns.append(
RolloutTurn(
prompt_token_ids=prev_step.next_prompt_token_ids, # input -> prompt_*
prompt_messages=prev_step.next_prompt_messages,
assistant_token_ids=completion.token_ids, # model -> assistant_*
assistant_logprobs=completion.token_logprobs,
assistant_message=step.assistant_message,
env_messages=step.env_messages, # env -> env_*
env_rewards=step.env_rewards,
)
)
prev_step = step
return Rollout(turns=turns, status=prev_step.status)
async def do_group_rollout(task, task, group_size) -> RolloutGroup:
example = task.sample_train_example()
envs = task.make_envs(group_size=group_size)
for env in envs:
rollout = do_single_rollout(env, example)
rollout_group = RolloutGroup(rollouts=rollouts)
rewards = task.score_group(rollout_group)
...The overall picture — what a 1. The env is two explicit layers: messages vs tokens
Change: The old Before class MessageEnv(abc.ABC):
async def reset(self) -> ResetOutput: ...
async def step_message(self, msg: Message) -> StepOutput: ...
class RendererEnv: # "an Env that renders"?
async def initial_prompt(self) -> TokenizedStepOutput: ...
async def step_completion(self, completion) -> TokenizedStepOutput: ...
class RolloutTurn:
prompt_token_ids: list[int]
response_token_ids: list[int]
response_logprobs: list[float]
policy_version: int
prompt_messages: list[Message]
assistant_message: Message | None
env_messages: list[Message]
reward_components: dict[str, float]After class MessageEnv(Configurable, abc.ABC): # user code: messages only, never tokens
async def reset(self) -> MessageResetOutput: ...
async def step(self, assistant_message: Message) -> MessageStepOutput: ...
class RendererWrapperEnv: # wraps a MessageEnv: messages <-> tokens via a Renderer
async def reset(self) -> TokenizedStepOutput: ...
async def step(self, completion: Completion) -> TokenizedStepOutput: ...
# full snapshot of tokens + messages
class RolloutTurn: # the full per-turn snapshot
prompt_token_ids: list[int] # input
prompt_messages: list[Message]
assistant_token_ids: list[int] # model output (was response_token_ids)
assistant_logprobs: list[float] # (was response_logprobs)
assistant_message: Message | None
env_messages: list[Message] # env reply
env_rewards: dict[str, float] # (was reward_components)
policy_version: int | NoneThe wrapper does 2. The Task owns its dataset, env, and rubric (config-driven)
Change: A Before — datasets on the trainer, # RLTrainer.Config
train_dataset = SumDigitsDataset.Config(seed=42)
validation_dataset = SumDigitsDataset.Config(seed=99)
tasks = {"sum_digits": SumDigitsTask.Config()}
class Task(Configurable, abc.ABC):
class Config:
rubric: Rubric.Config # base owns ONLY the rubric
@abc.abstractmethod
def make_envs(self, *, example, group_size, renderer): ...
class SumDigitsTask(Task):
class Config(Task.Config):
rubric = SumDigitsRubric.Config()
renderer_env_config = RendererEnvConfig() # loose, on the subclass
def make_envs(self, *, example, group_size, renderer):
return [RendererEnv(message_env=SumDigitsEnv(env_input=example.env_input),
renderer=renderer, config=self.renderer_env_config)
for _ in range(group_size)]After — one # RLTrainer.Config
task = SumDigitsTask.Config()
group_size = 8 # explicit (was overloaded onto sampling.n, see §4)
# TODO: support multiple tasks for data mixing.
class Task(Configurable): # concrete, make_envs is generic
class Config(Configurable.Config):
train_dataset: Configurable.Config
validation_dataset: Configurable.Config
rubric: Rubric.Config
message_env: MessageEnv.Config # the env is part of the Task config
env_wrapper_cfg: RendererWrapperEnv.Config = field(default_factory=RendererWrapperEnv.Config)
def sample_train_example(self):
return next(self._train_dataset) # datasets are iterators
def sample_validation_example(self):
return next(self._validation_dataset)
def make_envs(self, *, example, group_size, renderer):
return [RendererWrapperEnv(message_env=self._message_env_config.build(env_input=example),
renderer=renderer, config=self._env_wrapper_cfg)
for _ in range(group_size)]
# override methods to customize the Task beyond the config
# e.g. manage a pool or envs instead of making them fresh
class SumDigitsTask(Task):
class Config(Task.Config):
train_dataset = SumDigitsDataset.Config(seed=42)
validation_dataset = SumDigitsDataset.Config(seed=99)
rubric = Rubric.Config(reward_fns=[RewardCorrect.Config(weight=1.0), RewardFormat.Config(weight=0.3)])
message_env = SumDigitsEnv.Config()3. Reward functions are
|
- env_types/ -> environments/ (MessageEnv, TokenEnv); reset -> init - generator output -> completion_* / parsed_completion_message; MessageInitOutput.init_prompt_messages - make_envs -> make_env_group; env_wrapper_cfg -> token_env; num_prompts_per_step -> num_groups_per_rollout_batch - SamplingConfig drops n + stop_token_ids (n=1 hardwired; stop tokens injected into VLLMGenerator) - _terminal inlined; Rollout.status required; await the generate call; self._messages rebind; trimmed list() copies - gen_metrics -> generation_metrics; max_concurrent_rollouts derivation comment
- C14: spell out shape hints ([L_prompt] -> [num_prompt_tokens], [M_*] -> [num_*_messages], etc.) - C37: rename Task -> Rollouter (+ SumDigitsRollouter; RLTrainer.Config.task -> rollouter) - move the tasks/ package into rollouts/ (rollouts/rollouter.py + rollouts/sum_digits/); Rollouter is imported from the submodule (rollouts.rollouter) to avoid a circular import via rollouts/__init__ (environments imports rollouts.types) - TODO(naming): VLLMGenerator -> InferenceEngine left at the class
- environments/ -> environment/ (singular) - rollouts/ -> rollout/ (now also holds rollouter.py) - tasks/ dissolved -> examples/sum_digits/ (concrete problem); base Rollouter in rollout/rollouter.py - Rollouter imported from the submodule (rollout.rollouter), not re-exported from rollout/__init__, to avoid an environment<->rollout import cycle - update all import paths
…renderer - all RendererConfig fields default None; build() overrides only the knobs the user set (non-None and supported), constructing the typed config via config_type(**args) - name None/"auto"/unmapped -> auto-resolve or pass the renderer name straight through - log the chosen renderer + args
the new layoutGoal for the code (multi-turn, once the generator supports it): async def do_single_rollout(generator, env) -> Rollout:
turns = []
env_output = await env.init()
while not env_output.status.is_terminal():
prompt_token_ids = env_output.next_prompt_token_ids
prompt_messages = env_output.next_prompt_messages
completion = await generator.generate(prompt_token_ids)
env_output = await env.step(completion)
turns.append(
RolloutTurn(
prompt_token_ids=prompt_token_ids,
prompt_messages=prompt_messages,
completion_token_ids=completion.token_ids, # generator -> completion_*
completion_logprobs=completion.token_logprobs,
parsed_completion_message=env_output.parsed_completion_message,
env_messages=env_output.env_messages,
env_rewards=env_output.env_rewards,
)
)
return Rollout(turns=turns, status=env_output.status)1.
|
|
Thanks for the comments, it's much more clear
I might missed something, what's the gap for generator to support this "code goal"? |
…otocol - grpo.py: kept our refactored async _collect_rollouts (rollouter); upstream's fut.get()->await auto-merged - test_grpo_metrics.py: kept deleted (our refactor removed it; upstream's version imports removed Step/Trajectory) - test_generator.py: drop removed SamplingConfig(n=); set _stop_token_ids on the hand-built generator - test_shutdown.py: stub the renderer/generator/rollouter that __init__ now builds - RL tests: 91 passed, 3 skipped
allow multiple threads to call Currently we have to batch all requests because our generator doesn't have any I will have a PR to fix that probably tomorrow |
| if TYPE_CHECKING: | ||
| from torchtitan.experiments.rl.actors.generator import Completion |
There was a problem hiding this comment.
Why we need this type checking?
| """Render the initial conversation into the first generator prompt.""" | ||
| env_init_output = await self._message_env.init() | ||
|
|
||
| # Copy our running conversation so we avoid mutating previous states |
There was a problem hiding this comment.
What does this "running conversation" mean?
| if parsed.reasoning_content: | ||
| parsed_completion_message["reasoning_content"] = parsed.reasoning_content | ||
| if parsed.tool_calls: | ||
| parsed_completion_message["tool_calls"] = parsed.tool_calls |
There was a problem hiding this comment.
This means the model generate a <tool_use> token, right?
and I want to do a knowledge check of my mental model. say if we are training on a Coding task:
- Do we need to plug both SandBoxEnv and MessageEnv into the TokenEnv
- When there's a <tool_use>` token, TokenEnv will call the SandBoxEnv to execute the tool?
I guess my questions is: where will the tool calling happen in the future?
There was a problem hiding this comment.
Chatted with Claude, it says:
In this design, MessageEnv is the problem-specific environment that receives the
assistant’s parsed message and decides what happens next. For a coding task, that env
would own or call the sandbox.
VLLMGenerator
|
| completion tokens
v
TokenEnv
|
| Renderer.parse_response(...)
| parsed assistant message
v
CodingMessageEnv.step(parsed_completion_message)
|
| extract code / tool call / patch / command
v
Coding Sandbox
|
| run tests, execute code, inspect files, enforce timeout
v
CodingMessageEnv
|
| returns tool/user messages + optional env_rewards + done
v
TokenEnv
|
| render next prompt tokens if not done
v
VLLMGenerator
I updated my mental model: TokenEnv is plumbing between Generator and message space
tianyu-l
left a comment
There was a problem hiding this comment.
I promise this is last push.
| """Maximum number of tokens to generate per completion.""" | ||
|
|
||
|
|
||
| # TODO(naming): rename VLLMGenerator -> InferenceEngine. |
There was a problem hiding this comment.
I think after we name Rollouter, this is less necessary.
|
|
||
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class MessageInitOutput: |
There was a problem hiding this comment.
Since you have TokenEnvOutput, I would mirror this to MessageEnvInitOutput, especially considering it contains not just messages but also tools.
| """Initial messages + tool specs from `MessageEnv.init`.""" | ||
|
|
||
| init_prompt_messages: list[Message] # [num_prompt_messages] | ||
| """The opening messages (e.g. [system, user]).""" |
There was a problem hiding this comment.
mention that we don't post_init check because assistant is legit
| parsed_completion_message: Message | None = None | ||
| """This turn's parsed completion message. `None` on init and on parse failure.""" | ||
|
|
||
| env_messages: list[Message] = field(default_factory=list) # [num_env_messages] |
There was a problem hiding this comment.
Is this part of next_prompt_messages? like
next_prompt_messages contains
[previous prompt_messages, parsed_completion_message, env_messages]
There was a problem hiding this comment.
can we use completion_tokens vs completion_message?
There was a problem hiding this comment.
the reason i added "parsed" is to make it evident that it did NOT come from the vllm completion, but rather it was parsed by the tokenEnv. Does this make sense? I am ok with renaming, but we lose this detail
There was a problem hiding this comment.
That's fair, but for perfect consistency it sounds we should do
- prompt_message
- tokenized_prompt_token_ids
- completion_token_ids
- detokenized_completion_message
There are a lot of redundancy so I'm not sure.
I would like to at least avoid inventing another word "parsed"
| self._messages = ( | ||
| self._messages + [parsed_completion_message] + step_output.env_messages | ||
| ) |
There was a problem hiding this comment.
now that i am thinking about it, this is at the wrong place. The MessageEnv should say what the "next_messages" are, e.g. for compacting history. Let me see how i can move this logic to MessageEnv. I guess this would mean i have to delete "env_messages" and just have it return "next_prompt_messages"
| completion_token_ids=completion.token_ids, | ||
| completion_logprobs=completion.token_logprobs, | ||
| policy_version=completion.policy_version, |
There was a problem hiding this comment.
It seems you could've put these in TokenEnvOutput, so that we can consolidate RolloutTurn and TokenEnvOutput.
If so, I would strongly advocate that, in this version.
There was a problem hiding this comment.
class TokenEnvOutput:
# off by one
next_prompt_token_ids: list[int] | None # [num_prompt_tokens] or None
next_prompt_messages: list[Message] | None = None # [num_prompt_messages] or None
# different
status: RolloutStatus
# Same
parsed_completion_message: Message | None = None
env_messages: list[Message] = field(default_factory=list) # [num_env_messages]
env_rewards: dict[str, float] = field(default_factory=dict)
class RolloutTurn:
# off by one
prompt_token_ids: list[int] # [num_prompt_tokens]
prompt_messages: list[Message] = field(default_factory=list) # [num_prompt_messages]
# different
policy_version: int | None = None
completion_token_ids: list[int] # [num_completion_tokens]
completion_logprobs: list[float] # [num_completion_tokens]
# Same
parsed_completion_message: Message | None = None
env_messages: list[Message] = field(default_factory=list) # [num_env_messages]
env_rewards: dict[str, float] = field(default_factory=dict)RolloutTurn is aware of the whole rollout loop. There may be other generators (e.g. a teacher), more metrics, some crazy logic. While the TokenEnvOutput is concerned only about the Env.
I think it would be too opinionated to do consolidate them and would put unnecessary responsability on the TokenEnv.
I do agree that there is some overlap. We could create some dataclasses to hold some of the fields that are common, e.g. Completion or tokens_with_logprobs can be an output of the generator and we do RolloutTurn.completion.
NextPrompt could hold next_prompt_token_ids and next_prompt_messages.
cc: @wwwjn in case you have an opinion as well
There was a problem hiding this comment.
sounds good to leave them separate right now
| def make_env_group( | ||
| self, | ||
| *, | ||
| example: object, |
There was a problem hiding this comment.
not strongly opinionated
| example: object, | |
| sample: object, |
| self._token_env_config = config.token_env | ||
|
|
||
| # TODO: revisit this abstraction: should it return a sample or a dataset or an iterator? | ||
| def sample_train_example(self) -> object: |
There was a problem hiding this comment.
nit: use "training" instead of "train" to be consistent
| def sample_train_example(self) -> object: | |
| def get_training_sample(self) -> object: |
| completions, generation_metrics = self._get_rank_0_value( | ||
| self.generator.generate.call(tokenized_prompts).get() | ||
|
|
||
| rollouter: Rollouter = self._rollouter |
There was a problem hiding this comment.
prefer to not have this indirection.
| self, | ||
| *, | ||
| group_state: _RolloutGroupState, | ||
| rollouter: Rollouter, |
There was a problem hiding this comment.
don't need to pass this in, you have self._rollouter
There was a problem hiding this comment.
for datamix we would need it, but all of this code in grpo will go under the rollout anyway, and in this case, it would just use self. Either way, for now, we have 1 rollout, so i agree
| for group_state in valid_group_states: | ||
| for sample_idx, env_init_output in enumerate(group_state.env_init_outputs): | ||
| prompt_token_ids.append(env_init_output.next_prompt_token_ids or []) | ||
| request_ids.append(_sample_id(group_state.group_id, sample_idx)) |
There was a problem hiding this comment.
A side note: I read from somewhere else the request_id needs to contain the DP rank id, otherwise each DP rank will have the same set of request_ids, which will causing collision request id error in vllm engine. I think it's a good practice to make the request id globally unique.
We can put a TODO here, wdty?
There was a problem hiding this comment.
I am not sure if i follow: The id is unique. I guess i could add some random letters to make sure.
But would we send the same sample_id to different dp ranks? Shouldnt each rank get a different set of ramples?
There was a problem hiding this comment.
same prompt go to same DP rank for prefix cache reuse
| next_completion_offset = completion_offset + len( | ||
| group_state.env_init_outputs | ||
| ) | ||
| group_state.completions = completions[ |
There was a problem hiding this comment.
Is completions sorted?
Sort and then slicing to group samples into samples are error-prone, can we get group_id from request_id and then group using group id?
There was a problem hiding this comment.
They are sorted. I should fix it. But since this code will be delete/refactored in the next feel days, is it ok if i leave a todo for now?
| ) | ||
|
|
||
| @sl.log_trace_span("_run_single_rollout") | ||
| async def _run_single_rollout( |
There was a problem hiding this comment.
What's the consideration of putting _run_group_rollout, _run_single_rollout in controller vs. in Rollouter?
if we put it in rollouter, I would image controller could simply call
self._rollouter.run_group_rollout(generator, sample_id, xxx)
There was a problem hiding this comment.
yes, thats the direction. I thought about already removing the "collect_rollouts" from grpo and put it in the Rollouter. I was going to leave it for the next PR refactor. Maybe we should land as is to avoid another round of review that will be wasted anyway due to the refactor?
Now i feel like i should have fixed the generator first before doing this PR haha
…, sample), docstrings + TODOs
…datatypes-env-protocol # Conflicts: # torchtitan/experiments/rl/batcher.py # torchtitan/experiments/rl/config_registry.py
tianyu-l
left a comment
There was a problem hiding this comment.
Great work!!
Not sure if you missed #3453 (comment)
The benefit is that
- user knows what they should change (env, rubric, etc.), what they should not (generator / trainer)
- consistent with pretraining https://github.com/pytorch/torchtitan/tree/main/torchtitan/protocols
Shouldn't block this PR. Happy to discuss more.

How to review?
a. Read the contents in tasks/sum_digits
b. Read grpo.py:collect_rollouts
c. Read the rest
Summary
Our current script does not use messages or chat template. Now it will be the default. Users write
reset/step_message;a
RendererEnvwraps it and owns all message <-> token plumbing done by theRenderer.Typed rollout records:
RolloutGroup(Rollout(RolloutTurn)) replace the oldTrajectory/(Completion, Step)pairs. They now carry messages and tokens that support multi-turn.Rubric: Class to hold functions for scoring after rollout is finished
It also handles partial scoring in case of truncation and error.
a) create/store Envs for an specific task
b) Holds the
rubricassociated to that taskc) in the future will hold the rollout loop -- which can be customized by users if they want to.
This makes it trivial to do i) dataset mix; ii) share/import tasks
Why?
For single turn we can naively concatenate prompt+response. To enable multiturn we had to enable (1) and (2). Given i was refactoring it, i added (3) and (4)
Blockers/next steps: