
[RL] - MessageEnv, Rollout types, Rubric, Renderer#3453

Merged
felipemello1 merged 40 commits into
pytorch:mainfrom
felipemello1:v8-datatypes-env-protocol
Jun 5, 2026

Conversation

@felipemello1

@felipemello1 felipemello1 commented May 29, 2026

Contributor

How to review?
a. Read the contents in tasks/sum_digits
b. Read grpo.py:collect_rollouts
c. Read the rest

Summary

  1. A message-level env protocol:

Our current script does not use messages or a chat template; with this PR they become the default. Users write reset / step_message:

class SumDigitsEnv(MessageEnv):
    async def reset(self) -> MsgResponseReset:
        ...
    async def step_message(self, msg: Message) -> MsgResponseStep:
        ...

A RendererEnv wraps it and owns all message <-> token plumbing, which it delegates to the Renderer.

example = dataset.get_sample()
env = RendererEnv(
    message_env=SumDigitsEnv(env_input=example.env_input),
    renderer=renderer,
)
initial_turn = await env.reset()
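
To make the split concrete, here is a minimal, self-contained sketch of the two layers. Everything in it is a stand-in: `ToyMessageEnv`, `SumDigitsToyEnv`, `ToyRenderer` (a whitespace "tokenizer"), `ToyRendererEnv`, and the dict-shaped messages are hypothetical names that only mirror the shape described above, not the real signatures in this PR or in `renderers`.

```python
import asyncio
from abc import ABC, abstractmethod

Message = dict[str, str]  # assumption: {"role": ..., "content": ...}


class ToyMessageEnv(ABC):
    """Message-space env: user code never sees tokens."""

    @abstractmethod
    async def reset(self) -> list[Message]: ...

    @abstractmethod
    async def step_message(self, msg: Message) -> tuple[list[Message], bool]: ...


class SumDigitsToyEnv(ToyMessageEnv):
    def __init__(self, number: str) -> None:
        self.number = number

    async def reset(self) -> list[Message]:
        return [{"role": "user", "content": f"Sum the digits of {self.number}"}]

    async def step_message(self, msg: Message) -> tuple[list[Message], bool]:
        # Single-turn task: no env reply, done after one assistant message.
        return [], True


class ToyRenderer:
    """Hypothetical renderer: whitespace tokenizer with a growing vocab."""

    def __init__(self) -> None:
        self.vocab: dict[str, int] = {}
        self.words: list[str] = []

    def encode(self, text: str) -> list[int]:
        ids = []
        for word in text.split():
            if word not in self.vocab:
                self.vocab[word] = len(self.words)
                self.words.append(word)
            ids.append(self.vocab[word])
        return ids

    def decode(self, ids: list[int]) -> str:
        return " ".join(self.words[i] for i in ids)


class ToyRendererEnv:
    """Token-space wrapper: owns all message <-> token conversion."""

    def __init__(self, message_env: ToyMessageEnv, renderer: ToyRenderer) -> None:
        self.message_env = message_env
        self.renderer = renderer

    async def reset(self) -> list[int]:
        messages = await self.message_env.reset()
        return self.renderer.encode(" ".join(m["content"] for m in messages))

    async def step(self, completion_ids: list[int]) -> tuple[Message, bool]:
        text = self.renderer.decode(completion_ids)
        assistant = {"role": "assistant", "content": text}
        _, done = await self.message_env.step_message(assistant)
        return assistant, done


async def demo() -> tuple[list[int], Message, bool]:
    renderer = ToyRenderer()
    env = ToyRendererEnv(SumDigitsToyEnv("123"), renderer)
    prompt_ids = await env.reset()
    completion_ids = renderer.encode("6")  # pretend the generator produced this
    assistant, done = await env.step(completion_ids)
    return prompt_ids, assistant, done
```

Calling `asyncio.run(demo())` walks one turn end to end: the message env only ever handles dicts, while the wrapper is the only place that touches token ids.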
  2. Typed rollout records: RolloutGroup(Rollout(RolloutTurn)) replaces the old Trajectory / (Completion, Step) pairs. The new records carry both messages and tokens, so they support multi-turn.

  3. Rubric: A class that holds the functions used for scoring after a rollout is finished

class MyRubric(Rubric):
    def register_funcs(self) -> list[RewardFn]:
        return [
            RewardFn(fn=my_reward_fn1, weight=0.5),
            RewardFn(fn=my_reward_fn2, weight=0.5),
        ]

rubric = MyRubric(config=MyRubric.Config(truncation_reward=0.0))
rewards = await rubric.score_group(my_rollouts, env_input)
for reward, rollout in zip(rewards, my_rollouts):
    rollout.reward = reward.reward
    rollout.reward_components = reward.components

It also handles partial scoring in case of truncation and error.
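
As a rough illustration of the scoring path, the sketch below combines weighted reward functions and applies a truncation short-circuit. It assumes, following the commit notes in this PR, that weights are normalized to sum to 1 and that `truncation_reward=None` means "run the reward fns on the partial response". `ToyRewardFn`, `ToyReward`, `score_one`, and the two reward functions are hypothetical stand-ins, not the PR's `Rubric` API.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Optional

# Hypothetical signature: a reward fn takes the response text, returns a float.
RewardCallable = Callable[[str], Awaitable[float]]


@dataclass
class ToyRewardFn:
    fn: RewardCallable
    weight: float


@dataclass
class ToyReward:
    reward: float
    components: dict[str, float]


async def has_answer(text: str) -> float:
    return 1.0 if "answer" in text else 0.0


async def is_short(text: str) -> float:
    return 1.0 if len(text) < 40 else 0.0


async def score_one(
    text: str,
    truncated: bool,
    funcs: list[ToyRewardFn],
    truncation_reward: Optional[float],
) -> ToyReward:
    # Short-circuit: a configured float replaces the reward fns entirely.
    if truncated and truncation_reward is not None:
        return ToyReward(reward=truncation_reward, components={})
    total_weight = sum(f.weight for f in funcs)
    values = await asyncio.gather(*(f.fn(text) for f in funcs))
    components = {f.fn.__name__: v for f, v in zip(funcs, values)}
    # Normalize weights so the final reward is a weighted average.
    reward = sum(f.weight / total_weight * v for f, v in zip(funcs, values))
    return ToyReward(reward=reward, components=components)
```

With weights 1.0 and 3.0, a truncated response that only satisfies `is_short` scores 0.75 when `truncation_reward=None`, and exactly `truncation_reward` otherwise.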

  4. Task: A class that

a) creates/stores Envs for a specific task
b) holds the rubric associated with that task
c) in the future, will hold the rollout loop, which users can customize if they want to.

This makes it trivial to do i) dataset mix; ii) share/import tasks
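
A small sketch of why a dataset mix becomes easy: each dataset row names its task, and the controller looks the task up in a dict. `ToyExample`, `ToyTask`, and `make_env_groups` are hypothetical; only the `task_name` field and the `make_envs` method name are borrowed from this PR, and their real signatures may differ.

```python
from dataclasses import dataclass


@dataclass
class ToyExample:
    task_name: str  # which task this dataset row belongs to
    env_input: str


class ToyTask:
    def __init__(self, name: str) -> None:
        self.name = name

    def make_envs(self, env_input: str, group_size: int) -> list[str]:
        # A real task would return env objects; strings keep the sketch short.
        return [f"{self.name}({env_input})#{i}" for i in range(group_size)]


def make_env_groups(
    examples: list[ToyExample], tasks: dict[str, ToyTask], group_size: int
) -> list[list[str]]:
    # Route every row to its own task, so mixed batches need no special casing.
    return [
        tasks[ex.task_name].make_envs(ex.env_input, group_size) for ex in examples
    ]
```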

experiments/rl/
├── grpo.py                  # controller + rollout loop  (changed)
├── renderer.py              # NEW: RendererConfig -> a renderers.Renderer
├── env_types/                    # NEW: the env protocol
│   ├── message_env.py       #   MessageEnv (ABC), ResetOutput, StepOutput
│   └── renderer_env.py      #   RendererEnv, RendererEnvConfig, TokenizedStepOutput
├── rollouts/                # NEW: the datatypes
│   ├── types.py             #   Rollout, RolloutTurn, RolloutGroup, RolloutStatus, DatasetOutput
│   └── utils.py             #   rollout_to_episode, prepare_rollout_metrics
├── rubrics/rubric.py        # NEW: Rubric, RewardFn, Reward
└── tasks/                   # NEW (was sum_digits.py)
    ├── task.py              #   Task
    └── sum_digits/          #   the worked example
        └── data.py · env.py · grader.py · task.py

Why?

For single turn we can naively concatenate prompt+response. To enable multi-turn we had to add (1) and (2). Since I was already refactoring, I added (3) and (4).
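
To illustrate the difference, here is a hedged sketch of what multi-turn adds on top of plain concatenation: env tokens get interleaved with model tokens, so a per-token mask is needed to mark which positions the model produced. It assumes each turn's prompt is a strict extension of the previous prompt plus the previous assistant tokens, which a real renderer may not guarantee. `flatten_turns` is a hypothetical helper, not part of this PR (where `rollout_to_episode` currently requires exactly one turn).

```python
def flatten_turns(
    turns: list[tuple[list[int], list[int]]],
) -> tuple[list[int], list[int]]:
    """Flatten (prompt_token_ids, assistant_token_ids) turns into tokens + mask.

    mask[i] == 1 marks a model-generated token, 0 marks prompt/env tokens.
    """
    tokens: list[int] = []
    mask: list[int] = []
    for prompt_ids, assistant_ids in turns:
        # Only the part of the prompt we have not emitted yet is new
        # (the first prompt, or env/user tokens added since the last turn).
        new_prompt = prompt_ids[len(tokens):]
        tokens += new_prompt
        mask += [0] * len(new_prompt)
        tokens += assistant_ids
        mask += [1] * len(assistant_ids)
    return tokens, mask
```

With a single turn this reduces to prompt+response with a mask of zeros followed by ones; with two or more turns the env tokens in the middle are masked out.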

Blockers/next steps:

  1. Continuous Batching (CB): Necessary so multiturn Rollouts can progress independently
  2. Async rollout + AlphabetSort multiturn task: Small PR once this and continuous batching land

wwwjn and others added 18 commits May 20, 2026 13:19
- Add shared pack() function in torchtitan/components/dataloading/utils.py
- Add Batcher class with Configurable pattern: packs episodes into
  fixed [B, seq_length] TrainBatches with gradient accumulation support
- Refactor TrainBatch to use [B, L] tensors instead of [1, total_tokens]
- Update PolicyTrainer.forward_backward to work with [B, L] microbatches
- Simplify compute_logprobs/verify_logprob_identity in actors/utils.py
- envs/: MessageEnv ABC + MessageReset/MessageStep; RendererEnv wrapper
  + RendererEnvConfig + TokenizedTurn; Rollout / RolloutTurn / RolloutStatus
  / DatasetOutput types.
- rubrics/: Rubric base class with score_one / score_group, typed Reward
  dataclass; Rubric owns truncation_reward / error_reward (doc 37 Option B).
- recipes/: Task base + SumDigitsTask (dataset + env + grader + recipe).
- grpo.py: _run_rollouts refactored (Option G — inline orchestration +
  _do_group_step helper + per-group failure isolation); validate() reuses
  the same path.
- Drops: sum_digits.py (orphan), test_grpo_metrics.py (broken), Trajectory
  + Step from types.py.
- rollouts/: new folder holding Rollout / RolloutTurn / RolloutStatus /
  DatasetOutput (types.py) and last_assistant_text / rollout_to_episode /
  prepare_rollout_metrics (utils.py). envs/types.py deleted.
- Rubric: Configurable with register_funcs() hook + lazy @cached_property
  weight normalization (sum-to-1). New RewardFn dataclass (fn + weight).
  truncation_reward / error_reward become Optional[float] short-circuit
  knobs (None = run reward fns on partial response).
- SumDigitsRubric subclass added in grader.py; SumDigitsDataset is now
  Configurable; SumDigitsTask.Config composes nested sub-configs
  (dataset / rubric / env_limits) matching PolicyTrainer.Config style.
- MessageReset/Step -> MsgResponseReset/Step.
- RendererEnvConfig -> EnvLimits; kwarg config -> limits.
- recipes/sum_digits/recipe.py -> task.py.
- _run_rollouts: overflow-aware prompt routing -- initial-prompt overflow
  builds TRUNCATED_OVERFLOW rollouts directly instead of sending an empty
  prompt to the generator.
- prepare_rollout_metrics consolidates the rollout/validation metric
  blocks; _prepare_reward_metrics removed.
- Cleanups: RolloutStatus is_*() via frozenset; validate_rollout_output
  dropped; Rollout.group_id / sample_idx required; _log_samples
  Episode-only; reward_correct stops asserting env_input type.
Task surface
- DatasetOutput.env_name -> DatasetOutput.task; SumDigitsDataset
  ENV_NAME -> TASK_NAME.
- Dataset moves off Task onto RLTrainer.Config:
    train: Task.Config -> train_dataset + tasks dict keyed by task name.
    Same for validation.
  Rows route to the matching Task via example.task.
- SumDigitsTask: drops dataset + sample_example; only rubric,
  env_limits, and make_envs remain.
- Task base gains score_group(rollouts, env_input) -> list[Rollout]:
  default delegates to self.rubric.score_group + fills reward and
  reward_components. Step logic stays in grpo._do_group_step.
- TODOs: continuous-batching Task.do_single_rollout; revisit Camp A
  vs B (dataset on Task vs framework).

Style + cleanups (same files)
- Dataclasses converted from per-field docstrings to Args-at-top with
  inline shape comments. Drops double-backticks and stale PR
  references.
- Bare # title block comments in _run_rollouts, _build_episodes,
  validate, train, _do_group_step. Dashed banners removed in train.
- Dead code: _shard_episodes removed (Batcher does sharding).
- Bug fixes:
  - prepare_rollout_metrics total_lens computed per-rollout
    (multi-turn safe).
  - _do_group_step initial.next_token_ids raises on None instead of
    silent default to [].
  - Rubric: register_funcs result checked for unique fn names.
  - renderer_env: env_step.status defaulting uses replace() not
    in-place mutation; _terminal accepts last_response_messages
    instead of post-construction mutation.

Smoke: imports, config build, rubric e2e (3 statuses + short-circuit
modes), task.score_group fills rewards.
- Dataclasses: per-field docstrings; single backticks; drop internal-doc refs.
- RendererEnv.step_completion parses before classifying length/abort, so
  truncated/aborted/timed-out turns keep their response message + tokens
  (partial-reward grading + debugging).
- Completion.text removed (text comes from rollout messages); RolloutTurn gains
  reward/reward_components; TokenizedTurn carries terminal status.
- Read pad/eos from the renderer's tokenizer; drop the standalone tokenizer.
- Task.score_group returns list[Reward]; controller applies them.
- Condense _run_rollouts; README installs renderers.
Conflicts resolved:
- batcher.py / TrainingBatch labels / compute_logprobs -> take main (landed batcher PR).
- grpo.py / config_registry / types Completion+Episode -> keep v8 (our rollout/env/rubric/recipe layer), adapted to main's batcher API.
- config_registry: keep recipes structure (tasks/dataset), import BatchConfig from rl.batcher, enable_wandb=True (matches main).
- Dropped stale test_grpo_metrics.py (tested the removed Step/Trajectory API).
renderers 0.1.8.dev37 replaced the name-based factory with a typed-config one.
RendererConfig.build picks the config variant for name from the public
discriminated union and passes only supported knobs; adds an enable_thinking
field. Removes the private-attr workaround.
@meta-cla meta-cla Bot added the CLA Signed label May 29, 2026
Felipe Mello added 5 commits May 28, 2026 23:30
The merge with the landed batcher PR pulled in changes that don't belong
to this PR:

- Drop half-done-batcher-base leftovers: BatchConfig in config/configs.py
  (the landed batcher has its own in rl.batcher), config/__init__ export,
  components/dataloading/{__init__,utils}.py, and the text_datasets.py
  rewrite. RL never imports any of these.
- grpo.py: revert train() docstring, reworded comments, and the
  per-microbatch metrics rewrite back to upstream's aggregation; revert the
  setup_async docstring; re-add _shard_episodes (dead in upstream too).
  Keep only the genuine v8 substitutions (Rollout types, async
  _collect_rollouts, _build_episodes).
- _run_rollouts/_collect_rollouts return list[RolloutGroup]; _PendingGroup
  build struct; whole-group drop on prompt overflow; flat n=1 generate with
  completions zipped 1:1 (no rebucket).
- _do_group_step -> _do_single_rollout (one env+completion -> Rollout) with
  try/except inside so partial turns survive as an ERROR rollout.
- RolloutStatus gains ONGOING + ERROR + is_terminal(); TRUNCATED_OVERFLOW ->
  TRUNCATED_PROMPT_OVERFLOW; TokenizedResponseStep.status is required.
- _is_overflow -> _is_prompt_overflow (prompt_len >= max_rollout_tokens);
  drop unused EnvLimits.max_generation_tokens.
- TokenizedTurn -> TokenizedResponseStep; Rubric.register_funcs is
  @abc.abstractmethod; logging TODOs on Rollout/RolloutTurn.
- config_registry: enable_thinking=True, max_tokens=700 for all rl_grpo_qwen3_*
  configs. Qwen3-0.6B one-shots sum_digits with thinking + enough budget
  (reward ~0.95-1.0 from step 1); 100/200 tokens truncate mid-<think>.
- rollouts/types.py: shorten RolloutStatus docstring; group RolloutTurn fields.
…build

- recipes/ -> tasks/; DatasetOutput.task -> task_name.
- env carriers: MsgResponseReset/Step -> ResetOutput/StepOutput;
  TokenizedResponseStep -> TokenizedStepOutput. StepOutput drops status
  (keeps done); RendererEnv owns RolloutStatus.
- EnvLimits -> RendererEnvConfig (field renderer_env_config; RendererEnv(config=)).
- next_token_ids/next_messages -> next_prompt_token_ids/next_prompt_messages.
- last_response_messages/response_messages split -> assistant_message + env_messages.
- RolloutTurn gains prompt_messages; restore env-set reward path:
  StepOutput.reward_components -> TokenizedStepOutput.env_reward_components
  -> RolloutTurn.reward_components.
- renderer.py build() via pydantic TypeAdapter; drop defensive list copy;
  log parse/timeout exceptions; MessageEnv docstring + Example;
  revert 9 unrelated RLTrainer.Config docstrings; 'policy' -> 'limits'.
- _build_episodes: drop a group when any sibling has no turns (turn-less
  ERROR rollout) instead of checking reward-is-None. The old check never
  fired (score_group always sets rewards) and let a turn-less rollout reach
  rollout_to_episode, which raises (requires exactly one turn).
- rollout_to_episode: derive text from the rollout (last_assistant_text)
  instead of taking it as a param; drop the now-unused import in grpo.
@felipemello1 felipemello1 marked this pull request as ready for review May 29, 2026 21:52
@felipemello1 felipemello1 requested review from tianyu-l and wwwjn May 29, 2026 21:52
@felipemello1 felipemello1 changed the title [NOT READY][RL] - MessageEnv, Rollout types, Rubric, Renderer [RL] - MessageEnv, Rollout types, Rubric, Renderer May 29, 2026

@tianyu-l tianyu-l left a comment

Contributor

Thanks, did a first pass. Overall I feel the logic among Env, Task, Rollout, Dataset, Rubric could be clearer.

Comment thread torchtitan/experiments/rl/types.py Outdated
return sum(s.reward for _, s in self.transitions)


# TODO: rename `Episode` -> `TrainSample` and `rollout_to_episode` ->

Contributor

nit: TrainingSample, to be consistent with TrainingBatch

uv pip install torchmonarch==0.4.1
uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main"
uv pip install pygtrie portpicker
uv pip install "git+https://github.com/PrimeIntellect-ai/renderers.git@main"

Contributor

would this be used for sft as well?

Contributor

Do we need to depend on the latest main of renderers?

preserve_thinking_between_tool_calls: bool = False

def build(self, *, model_path: str) -> Renderer:
from transformers import AutoTokenizer

Contributor

Can we avoid this dependency?

Contributor Author

Comment thread torchtitan/experiments/rl/renderer.py Outdated
`model_path` and constructs the `renderers` config matching `name`.

Args:
name: Renderer name (e.g. `"qwen3"`, `"auto"`).

Contributor

Not clear what this means. We should make it explicitly referring to torchtitan model (and their tokenizer) and avoid transformers dependency.

top_p=_sampling_config.top_p,
max_tokens=_sampling_config.max_tokens,
n=_sampling_config.n,
stop_token_ids=list(_sampling_config.stop_token_ids) or None,

Contributor

Suggested change
stop_token_ids=list(_sampling_config.stop_token_ids) or None,
stop_token_ids=_sampling_config.stop_token_ids or None,

"""The env's reply messages this turn (tool / user)."""

# For rubrics
reward_components: dict[str, float] = field(default_factory=dict)

Contributor

can we call it reward which should always decompose into "components"

Contributor Author

I think we want a simple field "reward: float", because that's what the final loss will use. No ambiguity here. The components are the breakdown for logging, weighted averaging, or some specific advantage computation.

Contributor

If the single reward field always comes from the components, can we have a cached_property that derives it from the components? If not (i.e. the single field is the only thing the RL algorithm eventually cares about, and everything else is only for logging), then I'm fine with it.
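
For reference, the cached_property variant could look like the sketch below. `ToyTurnReward` and its `weights` field are hypothetical and are not the PR's `Reward` or `RolloutTurn`. Two caveats apply: `functools.cached_property` stores its value in the instance `__dict__`, so it does not work on dataclasses declared with `slots=True` (as several in this PR are), and the cached value goes stale if `components` is mutated afterwards.

```python
from dataclasses import dataclass, field
from functools import cached_property


@dataclass
class ToyTurnReward:
    # Hypothetical stand-in: per-component scores plus their weights.
    components: dict[str, float] = field(default_factory=dict)
    weights: dict[str, float] = field(default_factory=dict)

    @cached_property
    def reward(self) -> float:
        """Weighted sum of components; computed on first access, then cached."""
        return sum(
            self.weights.get(name, 0.0) * value
            for name, value in self.components.items()
        )
```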

assistant_message: Message | None = None
"""The model's parsed turn (generator output as a message)."""

env_messages: list[Message] = field(default_factory=list) # [M_env]

Contributor

how is env_messages related to prompt_messages?

@felipemello1 felipemello1 Jun 1, 2026

Contributor Author

prompt_messages: input to the generator (all history up to that point)
assistant_message: output of the generator
env_messages: output of the environment, e.g. tool calls, new user message, etc.

Contributor

but somewhere else you used next_prompt_...

Is it the same as env_message or a subset?

Contributor

so "prompt" is previous [env + assistant] history, but next_prompt only means partial data in a turn

assistant vs. generator used interchangeably?

Comment thread torchtitan/experiments/rl/rollouts/types.py Outdated


@dataclass(frozen=True, kw_only=True, slots=True)
class DatasetOutput:

Contributor

"Output" sounds confusing, it can be input to rollout / grader

@felipemello1 felipemello1 Jun 1, 2026

Contributor Author

maybe DataSample?

from torchtitan.experiments.rl.tasks.sum_digits.grader import SumDigitsRubric


class SumDigitsTask(Task):

Contributor

The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about its internals.

I feel we should do this. The RLTrainer should just own trainer, generator, define task / env and maybe pass generator to them to obtain rollouts.

Returns:
One `Reward` per rollout, in input order.
"""
return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts))

Contributor

After this PR we are still in sync RL. Are these functions defined as async here because they are basic building blocks for async RL?

Contributor Author

We use asyncio.gather because the reward_fns can be, for example, multiple LLMs, so we can run them concurrently. This is orthogonal to async/sync RL. Does this make sense?
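
A tiny sketch of that point, assuming each reward function is an I/O-bound coroutine (for instance a call to an LLM judge). `judge` and `score_rollout` are hypothetical names, and `asyncio.sleep` stands in for the network call. The event log shows that both judges start before either one finishes, even though the surrounding training loop can stay fully synchronous.

```python
import asyncio


async def judge(name: str, score: float, events: list[str]) -> float:
    events.append(f"start:{name}")
    await asyncio.sleep(0.01)  # stands in for a slow call to an LLM judge
    events.append(f"end:{name}")
    return score


async def score_rollout(events: list[str]) -> list[float]:
    # gather schedules both coroutines at once and preserves argument order
    # in the returned list.
    return await asyncio.gather(
        judge("a", 1.0, events),
        judge("b", 0.0, events),
    )
```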

return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts))


def _fn_name(fn: Callable) -> str:

Contributor

What's this helper function for? Can we just inline it?

Contributor

We should set up CPU CI test for RL to guard these tests

"""Ground-truth total digit sum."""


class SumDigitsDataset(Configurable):

Contributor

What's the consideration for putting the dataset as a field of SumDigitsTask.Config() vs. connecting them using TASK_NAME? Can a Task be the container of Env, Rubric and Dataset?

Contributor

Can we rename to rubrics.py, which is more aligned with our naming now?

Comment thread torchtitan/experiments/rl/grpo.py Outdated
self._validation_dataset = config.validation_dataset.build()
self._str2task_map: dict[str, Task] = {
name: cfg.build() for name, cfg in config.tasks.items()
}

Contributor

We only have one task in our current RL loop now, are you going to support data mix soon? I guess we can simplify in this PR and only support one task for now, and handle multi-tasks together with data-mix PR

Comment thread torchtitan/experiments/rl/grpo.py Outdated
completions, generation_metrics = self._get_rank_0_value(
self.generator.generate.call(tokenized_prompts).get()
group_size = self.config.generator.sampling.n
sampling_cfg = replace(

Contributor

Is stop_token_ids same as eos_ids?

Contributor Author

That's my understanding, but we should get it directly from renderer.tokenizer. Also, the user may have some specific logic to stop when some token T appears.

Comment thread torchtitan/experiments/rl/grpo.py Outdated

Steps:
1. Get examples from dataset
2. For each example, find associated task, e.g. CodingTask, SearchTask, etc

Contributor

So the reason we don't make dataset a subfield of Task is data-mixing? We will do data-mixing at the DatasetOutput level, not the Task level? How do other repos model these concepts?

Contributor Author

How do other repos model these concepts?

I don't recall. I can take a look. Another option is to play it by ear: do what we are comfortable with and refactor later when we try datamix. For now, let's put it inside the Task, since that's a pattern I have seen as well, and both you and tianyu shared that it made more sense to you. I will make the changes.

Comment thread torchtitan/experiments/rl/grpo.py Outdated
group_id=f"{example.task_name}/step={step}/group={group_offset + group_idx}",
example=example,
task=task,
envs=task.make_envs(

Contributor

Do we need to create environments repeatedly for each sample? E.g., spin up a Docker container for each single CodingTask?

Or can we create an Environment for each Task?

Contributor Author

It's up to the user to decide what happens inside "make_envs", i.e. create a fresh one or pull from a pool.

Comment thread torchtitan/experiments/rl/grpo.py Outdated
# 4. For each env, get initial prompt (n_groups * n_rollouts_per_group)
initial_steps: list[list[TokenizedStepOutput]] = await asyncio.gather(
*(
asyncio.gather(*(env.initial_prompt() for env in group.envs))

Contributor

What is initial_prompt here? Can you give an example? I'm confused why the prompt doesn't come from dataset

Contributor Author

the sample (history of messages) comes from the dataset. The env adds the system message and adds tool calls. The prompt is the final constructed input for the generator. But notice that we are not opinionated about it: if the user wants the env_input to include the system prompt as well, they can do it.

Felipe Mello added 2 commits June 1, 2026 21:54
- env_types: MessageEnv (messages) + RendererWrapperEnv (tokens); Message{Reset,Step}Output
- rubrics: Configurable RewardFn + concrete Rubric + Reward; SumDigits grader -> rubric
- rollouts/types: RolloutTurn/Rollout/RolloutGroup; RolloutStatus incl TRUNCATED_PROMPT_TOO_LONG
- tasks: Task ABC owns rubric; SumDigits dataset/env/rubric/task; datasets on Task
- grpo/generator/renderer/config_registry: group_size, request_idx, renderer ModelSpec map

@tianyu-l tianyu-l left a comment

Contributor

Overall structure looks reasonable. I still ranted a lot on variable naming, as I think that's the key to hackability.

Contributor

In general I would prefer full spelling over shorthand.

Could we do

  • rl/environments/message.py
  • rl/environments/renderer.py

Contributor Author

I wanted to signal somehow that those are "base" envs, not something like a SumDigitsEnv; that's why I put env_types. I am cool with environments, but I still think that we should signal somehow that those are types. Any ideas?

@tianyu-l tianyu-l Jun 3, 2026

Contributor

I see. If we follow strictly the sft counterpart, for classes / interfaces you would expect users to implement, it probably should be

  • protocols/environments
  • protocols/rubrics
  • etc.

Meanwhile, components would be for things that torchtitan implements for you (but that are also overridable).



@dataclass(kw_only=True, slots=True)
class MessageResetOutput:

Contributor

Can we call this MessageInitialOutput or MessageInitOutput?
reset sounds like cleaning up from an old state, which this message shouldn't care about.

If you want strong consistency with env.reset, I propose changing that as well, to env.init

Contributor

It also sounds unnatural when MessageOutput owns messages and tools.

Maybe should call RawOutput / EnvOutput / UnrenderedOutput vs. RenderedOutput

class MessageResetOutput:
"""Initial prompt messages + tool specs from `MessageEnv.reset`."""

prompt_messages: list[Message] # [M_prompt]

Contributor

Reserving prompt for the first message sounds ambiguous.

Can we call it init_messages or just messages?

# env replies are tool/user turns; the assistant turn comes from the generator
if any(m.get("role") == "assistant" for m in self.env_messages):
raise ValueError(
"MessageStepOutput.env_messages may not contain assistant messages"

Contributor

this check should be added for init output as well?

"""

@abc.abstractmethod
async def reset(self) -> MessageResetOutput:

Contributor

Is this too controversial?

Suggested change
async def reset(self) -> MessageResetOutput:
async def init(self) -> MessageResetOutput:

Comment thread torchtitan/experiments/rl/grpo.py Outdated

gen_metrics: list[m.Metric] = []
# 3. Reset every env to get its first prompt.
prompt_steps_per_group_state = await asyncio.gather( # [G][group_size]

Contributor

"steps" not clear

Comment thread torchtitan/experiments/rl/grpo.py Outdated
request_ids=request_ids,
sampling_config=sampling,
metrics_prefix=generation_metrics_prefix,
).get()

Contributor

shouldn't use get()?

Comment thread torchtitan/experiments/rl/grpo.py Outdated
Comment on lines +606 to +607
5. Run one batched `generate` over valid groups (n=1; pre-expanded)
6. For each group: step generated rollouts if needed, then score with `task.score_group`

Contributor

What's the consideration for these two phases? I understand that each rollout may finish in a different number of turns -- how would we deal with the inefficiency of "finish one rollout at a time" later?

Contributor Author

When we land the generator refactor, there won't be a "batch" concept. We will change this function to continuously feed new examples to a queue that produces RolloutGroups in its own stream and puts them in the buffer. Does this answer the question?

Comment thread torchtitan/experiments/rl/grpo.py Outdated
@@ -238,19 +252,21 @@ class Config(Configurable.Config):

num_prompts_per_step: int = 5

Contributor

Suggested change
num_prompts_per_step: int = 5
num_groups_per_iteration: int = 5

Comment thread torchtitan/experiments/rl/grpo.py Outdated
Comment on lines +451 to +452
self.config.num_prompts_per_step * self.config.group_size,
self.config.num_validation_samples,

Contributor

Add a comment on how this number is derived. It looks like you are using asyncio in multiple places; does this number capture all the concurrency?

@felipemello1

Contributor Author

@tianyu-l, I saw that you added more comments. I will read and address them.


Regarding my latest changes:

Thanks for the thorough review @tianyu-l, @wwwjn. I pushed some changes. Here is the before/after:

Goal for the code (we cannot do it until we fix the generator)

async def do_single_rollout(sampler, env) -> Rollout:
    turns = []
    prev_step = await env.reset()                           # TokenizedStepOutput (first prompt)
    while not prev_step.status.is_terminal():
        completion = await sampler.generate(prev_step.next_prompt_token_ids)
        step = await env.step(completion)
        turns.append(
            RolloutTurn(
                prompt_token_ids=prev_step.next_prompt_token_ids,    # input  -> prompt_*
                prompt_messages=prev_step.next_prompt_messages,
                assistant_token_ids=completion.token_ids,           # model  -> assistant_*
                assistant_logprobs=completion.token_logprobs,
                assistant_message=step.assistant_message,
                env_messages=step.env_messages,                     # env    -> env_*
                env_rewards=step.env_rewards,
            )
        )
        prev_step = step
    return Rollout(turns=turns, status=prev_step.status)

async def do_group_rollout(sampler, task, group_size) -> RolloutGroup:
    example = task.sample_train_example()
    envs = task.make_envs(example=example, group_size=group_size)
    rollouts = []
    for env in envs:
        rollouts.append(await do_single_rollout(sampler, env))
    rollout_group = RolloutGroup(rollouts=rollouts)
    rewards = task.score_group(rollout_group)
    ...

The overall picture — what a Task bundles, and who talks to whom:

Task — one config bundle
   ├─ train / validation datasets
   ├─ make_envs(env_input)
   ├─ rubric (RewardFn, …)
   └─ #TODO hold the rollout loop after generation fix

A rollout — the user side speaks messages, the model side speaks tokens:

   MessageEnv   ◄── messages ──►   RendererWrapperEnv   ◄── tokens ──►   generator (vLLM, n=1)

      • MessageEnv          user code, speaks messages, never sees tokens
      • RendererWrapperEnv  converts messages ⇄ tokens, and owns the
                            prompt-too-long / truncation / timeout checks

1. The env is two explicit layers: messages vs tokens

"step_completion sounds confusing" · "MessageEnv.step_message(msg) sounds confusing — is msg always from assistant?" · "what does 'a Renderer for tokens' mean" · "components sounds confusing, does rewards work?"

Change: The old RendererEnv read like an RL env, but it is a token-space wrapper around the user's message-space env. We named the two layers and gave both the same reset() / step() shape. Every model-output field is unified on assistant_*, every env field on env_*.

Before

class MessageEnv(abc.ABC):
    async def reset(self) -> ResetOutput: ...
    async def step_message(self, msg: Message) -> StepOutput: ...

class RendererEnv:                                    # "an Env that renders"?
    async def initial_prompt(self) -> TokenizedStepOutput: ...
    async def step_completion(self, completion) -> TokenizedStepOutput: ...

class RolloutTurn:
    prompt_token_ids: list[int]
    response_token_ids: list[int]
    response_logprobs: list[float]
    policy_version: int
    prompt_messages: list[Message]
    assistant_message: Message | None
    env_messages: list[Message]
    reward_components: dict[str, float]

After

class MessageEnv(Configurable, abc.ABC):             # user code: messages only, never tokens
    async def reset(self) -> MessageResetOutput: ...
    async def step(self, assistant_message: Message) -> MessageStepOutput: ...

class RendererWrapperEnv:                             # wraps a MessageEnv: messages <-> tokens via a Renderer
    async def reset(self) -> TokenizedStepOutput: ...
    async def step(self, completion: Completion) -> TokenizedStepOutput: ...

# full snapshot of tokens + messages
class RolloutTurn:                               # the full per-turn snapshot
    prompt_token_ids: list[int]                  # input
    prompt_messages: list[Message]
    assistant_token_ids: list[int]               # model output  (was response_token_ids)
    assistant_logprobs: list[float]              # (was response_logprobs)
    assistant_message: Message | None
    env_messages: list[Message]                  # env reply
    env_rewards: dict[str, float]                # (was reward_components)
    policy_version: int | None

The wrapper does tokens → messages (parse the completion) and messages → tokens (render the next prompt), and owns the "prompt too long / generation truncated / env timed out" checks, so the user's MessageEnv only ever sees a clean assistant message.
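A minimal sketch of that split, under stated assumptions: ToyRenderer (a whitespace "tokenizer"), EchoEnv, ToyStepOutput, ToyTokenWrapper and max_prompt_tokens are hypothetical stand-ins, not the classes in this PR, and only the prompt-too-long check is shown.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Optional

Message = dict  # e.g. {"role": "user", "content": "..."}


class ToyRenderer:
    """Hypothetical renderer: one token id per whitespace-separated word."""

    def __init__(self) -> None:
        self._vocab: dict[str, int] = {}
        self._words: list[str] = []

    def encode(self, text: str) -> list[int]:
        ids = []
        for word in text.split():
            if word not in self._vocab:
                self._vocab[word] = len(self._words)
                self._words.append(word)
            ids.append(self._vocab[word])
        return ids

    def render(self, messages: list[Message]) -> list[int]:
        # messages -> tokens
        return self.encode(" ".join(f"<{m['role']}> {m['content']}" for m in messages))

    def parse(self, token_ids: list[int]) -> Message:
        # tokens -> one assistant message
        text = " ".join(self._words[t] for t in token_ids)
        return {"role": "assistant", "content": text}


class EchoEnv:
    """Hypothetical message-space env: finishes once the assistant says 'stop'."""

    async def reset(self) -> list[Message]:
        return [{"role": "user", "content": "say stop"}]

    async def step(self, assistant_message: Message) -> tuple[list[Message], bool]:
        done = "stop" in assistant_message["content"]
        return ([] if done else [{"role": "user", "content": "try again"}]), done


@dataclass
class ToyStepOutput:
    status: str  # "running", "completed", or "truncated_prompt_too_long"
    next_prompt_token_ids: Optional[list[int]] = None
    next_prompt_messages: list[Message] = field(default_factory=list)


class ToyTokenWrapper:
    """Token-space wrapper: the message env never sees token ids."""

    def __init__(self, message_env, renderer, max_prompt_tokens: int) -> None:
        self._env = message_env
        self._renderer = renderer
        self._max_prompt_tokens = max_prompt_tokens
        self._messages: list[Message] = []

    async def reset(self) -> ToyStepOutput:
        self._messages = await self._env.reset()
        return self._render_next_prompt()

    async def step(self, completion_token_ids: list[int]) -> ToyStepOutput:
        assistant_message = self._renderer.parse(completion_token_ids)
        env_messages, done = await self._env.step(assistant_message)
        # Rebind instead of mutating, so earlier snapshots stay intact.
        self._messages = self._messages + [assistant_message] + env_messages
        if done:
            return ToyStepOutput(status="completed")
        return self._render_next_prompt()

    def _render_next_prompt(self) -> ToyStepOutput:
        token_ids = self._renderer.render(self._messages)
        if len(token_ids) > self._max_prompt_tokens:
            return ToyStepOutput(status="truncated_prompt_too_long")
        return ToyStepOutput("running", token_ids, list(self._messages))
```

The message env here only exchanges dicts; everything token-related, including the length limit, lives in the wrapper.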

2. The Task owns its dataset, env, and rubric (config-driven)

"why this env doesn't own the sourcing of prompts / data?" · "why this is not RendererEnv.Config and owned by that class?" · "The relationship between Task and Env seems not clear" · "RLTrainer should just own trainer, generator, define task / env"

Change: A Task is now one config bundle — datasets, env, rubric, env-wrapper limits. The base builds everything and make_envs is generic. A concrete task is just typed defaults.

Before — datasets on the trainer, Task base owned only the rubric, a task_name → task map routed rows:

# RLTrainer.Config
train_dataset = SumDigitsDataset.Config(seed=42)
validation_dataset = SumDigitsDataset.Config(seed=99)
tasks = {"sum_digits": SumDigitsTask.Config()}

class Task(Configurable, abc.ABC):
    class Config:
        rubric: Rubric.Config                        # base owns ONLY the rubric
    @abc.abstractmethod
    def make_envs(self, *, example, group_size, renderer): ...

class SumDigitsTask(Task):
    class Config(Task.Config):
        rubric = SumDigitsRubric.Config()
        renderer_env_config = RendererEnvConfig()    # loose, on the subclass
    def make_envs(self, *, example, group_size, renderer):
        return [RendererEnv(message_env=SumDigitsEnv(env_input=example.env_input),
                            renderer=renderer, config=self.renderer_env_config)
                for _ in range(group_size)]

After — one task, the base is concrete, the env is part of the config:

# RLTrainer.Config
task = SumDigitsTask.Config()
group_size = 8                                        # explicit (was overloaded onto sampling.n, see §4)
# TODO: support multiple tasks for data mixing.

class Task(Configurable):                             # concrete, make_envs is generic
    class Config(Configurable.Config):
        train_dataset: Configurable.Config
        validation_dataset: Configurable.Config
        rubric: Rubric.Config
        message_env: MessageEnv.Config               # the env is part of the Task config
        env_wrapper_cfg: RendererWrapperEnv.Config = field(default_factory=RendererWrapperEnv.Config)
    def sample_train_example(self):
        return next(self._train_dataset)             # datasets are iterators
    def sample_validation_example(self):
        return next(self._validation_dataset)
    def make_envs(self, *, example, group_size, renderer):
        return [RendererWrapperEnv(message_env=self._message_env_config.build(env_input=example),
                                   renderer=renderer, config=self._env_wrapper_cfg)
                for _ in range(group_size)]

# override methods to customize the Task beyond the config
# e.g. manage a pool of envs instead of making them fresh
class SumDigitsTask(Task):
    class Config(Task.Config):
        train_dataset = SumDigitsDataset.Config(seed=42)
        validation_dataset = SumDigitsDataset.Config(seed=99)
        rubric = Rubric.Config(reward_fns=[RewardCorrect.Config(weight=1.0), RewardFormat.Config(weight=0.3)])
        message_env = SumDigitsEnv.Config()

3. Reward functions are Configurable and live in config

"can they just define a Rubric and use it in a config, instead of register" · "can we just put them in config" (weights) · "any relationship between reward and components… can reward be a cached_property of components?"

Change: A RewardFn is one configurable scoring criterion. A Rubric is the concrete scheme that builds and weights them. The common case needs no subclass and is CLI-tunable.

Before — subclass + register_funcs(), weights hardcoded:

class Rubric(Configurable, abc.ABC):
    @abc.abstractmethod
    def register_funcs(self) -> list[RewardFn]: ...           # RewardFn = dataclass(fn=…, weight=…)

class SumDigitsRubric(Rubric):
    def register_funcs(self):
        return [RewardFn(fn=reward_correct, weight=1.0), RewardFn(fn=reward_format, weight=0.3)]

After: RewardFn is a Configurable callable, Rubric is concrete, the rubric is just config:

class RewardFn(Configurable, abc.ABC):
    class Config(Configurable.Config):
        weight: float = 1.0                                   # + any args a stateful reward needs
    async def __call__(self, rollout, env_input) -> float: ...

class Rubric(Configurable):
    class Config(Configurable.Config):
        reward_fns: list[RewardFn.Config] = field(default_factory=list)
        truncation_reward: float | None = None
        error_reward: float | None = None

A concrete reward function — the whole thing is the __call__ and a weight from config:

class RewardCorrect(RewardFn):
    """1.0 if the last [ANSWER] <n> equals the target, else 0.0."""
    class Config(RewardFn.Config):
        pass                                                  # only needs `weight`
    async def __call__(self, rollout, env_input: SumDigitsExample) -> float:
        matches = _ANSWER_RE.findall(last_assistant_text(rollout))
        if not matches:
            return 0.0
        return 1.0 if int(matches[-1]) == env_input.target else 0.0

# config:
Rubric.Config(reward_fns=[RewardCorrect.Config(weight=1.0), RewardFormat.Config(weight=0.3)])

On reward vs reward_breakdown — can reward be a cached_property of the breakdown? No: the rubric decides reward, and a custom score_group can set it however it likes (pairwise, rank-normalized, an LLM judge), so reward is not a function of the breakdown. reward_breakdown is the raw per-fn values, kept as its own field because callers consume it directly (e.g. per-component advantage), not just for logging. By default the Rubric happens to set reward to a normalized weighted sum of those values, but that is only the default.
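To make the distinction concrete, here is a small sketch with two hypothetical helpers. default_reward assumes "normalized" means dividing the weighted sum by the total weight, which the thread does not spell out; rank_normalized_rewards is one example of a group-level scheme where the reward cannot be derived from a single rollout's breakdown.

```python
def default_reward(reward_breakdown: dict[str, float], weights: dict[str, float]) -> float:
    """Weighted sum of per-fn values divided by the total weight (assumed normalization)."""
    total_weight = sum(weights[name] for name in reward_breakdown)
    if total_weight == 0:
        return 0.0
    weighted = sum(weights[name] * value for name, value in reward_breakdown.items())
    return weighted / total_weight


def rank_normalized_rewards(scores: list[float]) -> list[float]:
    """Group-level alternative: each reward depends on the rollout's rank in its group."""
    if len(scores) < 2:
        return [0.0] * len(scores)
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    rewards = [0.0] * len(scores)
    for rank, idx in enumerate(order):
        rewards[idx] = rank / (len(scores) - 1)
    return rewards
```

With weights 1.0 (correct) and 0.3 (format), a rollout that is correct but badly formatted gets 1.0 / 1.3 under the default, while under the rank-based scheme the same breakdown can map to any value depending on the other rollouts in the group.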

4. Three small config knobs: group size, renderer, thread pool

"can we not let user config n and always fix it to be 1?"

done

· "Not clear what this means… avoid transformers dependency"

I rewrote the renderer config. Hopefully it is clear now.

· "This really sounds details that users shouldn't need to care" (the executor pool)

I removed it from the config and used a default: an educated guess based on the number of rollouts.

5. Completions map back by a stable id, and prompt-too-long is explicit

"comment is rollout_index, but code uses prompt_idx. What's the definition of prompt_idx?" · "not clear what 'overflow' means here"

Change: Each prompt is tagged with its rollout's sample_id.

Before

Completion.prompt_idx: int
Rollout.sample_idx: int
ordered = sorted(completions, key=lambda c: c.prompt_idx)        # brittle int sort

After

sample_id = f"{group_id}/sample={sample_idx}"                    # readable id at the boundary
Completion.request_id: str
Rollout.sample_id: str
returned_ids = [c.request_id for c in completions]
if returned_ids != request_ids:                                 # reorder / missing / n>1
    raise RuntimeError(...)

RolloutStatus.TRUNCATED_PROMPT_OVERFLOW → TRUNCATED_PROMPT_TOO_LONG (and _is_prompt_overflow → _is_prompt_too_long).
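A self-contained sketch of the id check, for illustration only. The id format is the one quoted in this comment; make_sample_id, check_returned_ids, the error message, and the "step=3/group=0" group id are hypothetical.

```python
def make_sample_id(group_id: str, sample_idx: int) -> str:
    # Readable id used at the generator boundary.
    return f"{group_id}/sample={sample_idx}"


def check_returned_ids(request_ids: list[str], returned_ids: list[str]) -> None:
    """Fail loudly if completions were reordered, dropped, or duplicated (e.g. n>1)."""
    if returned_ids == request_ids:
        return
    missing = sorted(set(request_ids) - set(returned_ids))
    unexpected = sorted(set(returned_ids) - set(request_ids))
    raise RuntimeError(
        f"completions do not match requests: missing={missing}, "
        f"unexpected={unexpected}, "
        f"reordered_or_duplicated={not missing and not unexpected}"
    )


request_ids = [make_sample_id("step=3/group=0", i) for i in range(3)]
check_returned_ids(request_ids, list(request_ids))  # passes silently
```

Comparing the full id lists, rather than sorting by an integer index, makes a mismatch an explicit error instead of a silent misalignment.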

Felipe Mello added 5 commits June 3, 2026 08:45
- env_types/ -> environments/ (MessageEnv, TokenEnv); reset -> init
- generator output -> completion_* / parsed_completion_message; MessageInitOutput.init_prompt_messages
- make_envs -> make_env_group; env_wrapper_cfg -> token_env; num_prompts_per_step -> num_groups_per_rollout_batch
- SamplingConfig drops n + stop_token_ids (n=1 hardwired; stop tokens injected into VLLMGenerator)
- _terminal inlined; Rollout.status required; await the generate call; self._messages rebind; trimmed list() copies
- gen_metrics -> generation_metrics; max_concurrent_rollouts derivation comment
- C14: spell out shape hints ([L_prompt] -> [num_prompt_tokens], [M_*] -> [num_*_messages], etc.)
- C37: rename Task -> Rollouter (+ SumDigitsRollouter; RLTrainer.Config.task -> rollouter)
- move the tasks/ package into rollouts/ (rollouts/rollouter.py + rollouts/sum_digits/);
  Rollouter is imported from the submodule (rollouts.rollouter) to avoid a circular import
  via rollouts/__init__ (environments imports rollouts.types)
- TODO(naming): VLLMGenerator -> InferenceEngine left at the class
- environments/ -> environment/ (singular)
- rollouts/ -> rollout/ (now also holds rollouter.py)
- tasks/ dissolved -> examples/sum_digits/ (concrete problem); base Rollouter in rollout/rollouter.py
- Rollouter imported from the submodule (rollout.rollouter), not re-exported from rollout/__init__,
  to avoid an environment<->rollout import cycle
- update all import paths
…renderer

- all RendererConfig fields default None; build() overrides only the knobs the user set
  (non-None and supported), constructing the typed config via config_type(**args)
- name None/"auto"/unmapped -> auto-resolve or pass the renderer name straight through
- log the chosen renderer + args
@felipemello1

felipemello1 commented Jun 4, 2026

Contributor Author

@tianyu-l

the new layout

torchtitan/experiments/rl/
├── environment/        # message space ⇄ token space   (was env_types/)
│   ├── message.py      #   MessageEnv  (+ MessageInitOutput / MessageStepOutput)
│   └── token.py        #   TokenEnv    (+ TokenEnvOutput)
├── rollout/            # rollout data + the rollout producer   (was rollouts/)
│   ├── types.py        #   Rollout / RolloutGroup / RolloutTurn / RolloutStatus
│   ├── utils.py
│   └── rollouter.py    #   Rollouter           (was tasks/task.py → Task)
├── examples/           # concrete problems      (was tasks/)
│   └── sum_digits/     #   data.py / env.py / rubric.py / rollouter.py (SumDigitsRollouter)
├── actors/
│   ├── generator.py    #   VLLMGenerator       (TODO: → InferenceEngine)
│   └── trainer.py      #   PolicyTrainer
├── rubrics/rubric.py   #   Rubric / RubricOutput
├── grpo.py             #   RLTrainer + GRPOLoss (controller / train loop)
└── renderer.py  batcher.py  generate.py  config_registry.py  types.py

Goal for the code (multi-turn, once the generator supports it):

async def do_single_rollout(generator, env) -> Rollout:
    turns = []
    env_output = await env.init()
    while not env_output.status.is_terminal():
        prompt_token_ids = env_output.next_prompt_token_ids
        prompt_messages = env_output.next_prompt_messages
        completion = await generator.generate(prompt_token_ids)
        env_output = await env.step(completion)
        turns.append(
            RolloutTurn(
                prompt_token_ids=prompt_token_ids,
                prompt_messages=prompt_messages,
                completion_token_ids=completion.token_ids,         # generator  -> completion_*
                completion_logprobs=completion.token_logprobs,
                parsed_completion_message=env_output.parsed_completion_message,
                env_messages=env_output.env_messages,
                env_rewards=env_output.env_rewards,
            )
        )
    return Rollout(turns=turns, status=env_output.status)

1. environment/ package; the token wrapper is TokenEnv

"prefer full spelling… rl/environments/message.py" · "if it's doing more than rendering, the additional part would justify a broader name" · "if you follow MessageEnv, this should be called TokenEnv"

  • env_types/ → environment/ (message.py + token.py) (note: called it "environment", singular, so it is not confused with a place to put environments like sum_digits)
  • RendererWrapperEnv → TokenEnv (renders + parses + owns the rollout checks; now a Configurable)
  • TokenizedStepOutput → TokenEnvOutput (not RawOutput/RenderedOutput; Message*/Token* names the space, which is the actual distinction)

2. init not reset; message vs prompt

"reset sounds like cleaning up an old state… maybe init" · "reserving prompt for the first message is ambiguous"

  • reset → init (these envs are single-use)
  • MessageResetOutput.prompt_messages → MessageInitOutput.init_prompt_messages (qualified with init so it's unambiguously the first-prompt messages)

3. the generator's output is completion

"assistant vs generator used interchangeably?" · "does completion.token_ids contain the full history?"

Everything the generator makes is completion_*:

  • assistant_token_ids / assistant_logprobs → completion_token_ids / completion_logprobs
  • assistant_message → parsed_completion_message

4. SamplingConfig is user knobs only

"can we not make this configurable" (n) · "I'd rather not put stop_token_ids in SamplingParams, send it from the engine constructor"

done — dropped both:

  • SamplingConfig.n → removed (n=1 is hardwired; group_size pre-expands the prompts)
  • SamplingConfig.stop_token_ids → removed (injected into the generator's constructor instead)
  • also killed the extra replace() in RLTrainer.__init__

5. Task → Rollouter

"is it too bad if we call this RolloutGenerator, mimicking Dataloader in SFT… I'm even willing to rename Generator to InferenceEngine for this."

  • renamed Task → Rollouter
  • also SumDigitsTask → SumDigitsRollouter
  • RLTrainer.Config.task → rollouter
  • package moves (see the layout up top): tasks/ dissolved. The base Rollouter now lives in rollout/rollouter.py, and concrete problems moved to examples/sum_digits/.
  • VLLMGenerator → InferenceEngine left as a # TODO(naming) at the class: the "rename Generator" you offered (broad cross-tree rename; separate PR).

6. rollouter construction

"make_env_group, as you have score_group" · "env is ambiguous"

  • make_envs → make_env_group
  • env_wrapper_cfg → token_env

7. the nits

  • _terminal deleted (inlined)
  • Rollout.status → required (no COMPLETED default)
  • num_prompts_per_step → num_groups_per_rollout_batch
  • gen_metrics → generation_metrics
  • added a comment on how max_concurrent_rollouts is derived (and what concurrency it bounds)
  • TokenEnvOutput locals → env_output / env_init_output
  • .get() → await on the rollout generate call
  • trimmed the unnecessary list() copies
  • truncation_reward/error_reward docstrings clarified
  • shape hints spelled out: [L_prompt]/[M_prompt]/[L_response] → [num_prompt_tokens]/[num_prompt_messages]/[num_completion_tokens] (also [num_env_messages]/[num_tools]/[num_turns])

8. kept on purpose (pushed back, with reasoning)

"prompt seems a term in the resource space; message is more of a user-facing thing -- can we not mix up the two?"

I decided to keep the prompt, since the MessageEnv is aware of the concept. For "init" it produces a prompt. For "step", it has to produce the next prompt. Let me know if you strongly disagree.

"this check should be added for init output as well?"

No — MessageStepOutput rejects assistant turns because a step reply is tool/user only. The init conversation can legitimately carry few-shot assistant turns.

"storing turns: list[RolloutTurn] in Rollout sounds quite redundant, O(n) -> O(n^2)."

Full-history-per-turn is a common layout (verifiers, rllm, tinker). We collapse these at Episode assembly. Keeping the full history helps to find divergences and to branch (e.g. for compaction). If memory becomes an issue, we can revisit. Let me know if you disagree.
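As a rough illustration of what the per-turn history buys, the sketch below detects the first turn whose prompt no longer extends the previous turn and collapses non-divergent turns into one sequence. ToyTurn, first_divergent_turn and collapse are hypothetical; the actual Episode assembly in this PR may work differently, and the prefix property only holds when the renderer is prefix-stable.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTurn:
    prompt_token_ids: list[int]      # full history up to this turn
    completion_token_ids: list[int]  # what the generator produced this turn


def first_divergent_turn(turns: list[ToyTurn]) -> Optional[int]:
    """Index of the first turn whose prompt does not extend the previous turn."""
    for i in range(1, len(turns)):
        previous = turns[i - 1].prompt_token_ids + turns[i - 1].completion_token_ids
        if turns[i].prompt_token_ids[: len(previous)] != previous:
            return i
    return None


def collapse(turns: list[ToyTurn]) -> tuple[list[int], list[bool]]:
    """Merge turns into one token sequence plus a mask that is True on completion tokens."""
    if first_divergent_turn(turns) is not None:
        raise ValueError("history diverged; these turns cannot share one sequence")
    token_ids: list[int] = []
    is_completion: list[bool] = []
    for turn in turns:
        new_prompt_tokens = turn.prompt_token_ids[len(token_ids):]
        token_ids += new_prompt_tokens
        is_completion += [False] * len(new_prompt_tokens)
        token_ids += turn.completion_token_ids
        is_completion += [True] * len(turn.completion_token_ids)
    return token_ids, is_completion
```

If a turn's history was rewritten (for example after compaction), first_divergent_turn reports where, which is the point at which a branch would start.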

"should they be put into a Reward"

we could do something like:
Rollout.reward = RubricOutput(reward, reward_breakdown)

I opted to keep it flat for now. Let me know if you strongly disagree.

"I'd rather we create a thin Renderer class inheriting prime-rl's renderer, and have this as Renderer.Config."

There is nothing really to subclass. config_from_name returns a model-specific class, e.g. Qwen3RendererConfig.
The correct API here would be to do:

renderer = create_renderer(tokenizer, Qwen3RendererConfig(enable_thinking=False))

We do:

class RendererConfig(Configurable.Config):
    def build(self, *, tokenizer_path: str) -> Renderer:
        renderer_config: Qwen3RendererConfig = config_from_name(renderer_name)
        ... # override params
        renderer = create_renderer(tokenizer, renderer_config_with_overrides)
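The "override only what the user set" step can be sketched in isolation. ToyQwen3RendererConfig, ToyRendererConfig, max_tools and build_model_config are hypothetical stand-ins; the real renderer library's config classes and fields are not reproduced here.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ToyQwen3RendererConfig:
    """Hypothetical stand-in for a model-specific renderer config."""
    enable_thinking: bool = True
    add_generation_prompt: bool = True


@dataclass
class ToyRendererConfig:
    """User-facing knobs; None means 'keep the model-specific default'."""
    enable_thinking: Optional[bool] = None
    add_generation_prompt: Optional[bool] = None
    max_tools: Optional[int] = None  # a knob the toy model config does not support

    def build_model_config(self, config_type):
        supported = {f.name for f in fields(config_type)}
        overrides = {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None and f.name in supported
        }
        # Only set, supported knobs are passed; everything else keeps its default.
        return config_type(**overrides)
```

For example, `ToyRendererConfig(enable_thinking=False, max_tools=3).build_model_config(ToyQwen3RendererConfig)` changes enable_thinking and silently ignores max_tools. Whether unsupported knobs should be ignored or rejected is a design choice this sketch does not settle.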

"Can we do this (Thread pool) in Renderer's constructor, after we create one."

I don't think that loop-wide concurrency should be hidden in the renderer. It's better to expose it in the controller, no? Let me know if you strongly disagree.

follow-ups (TODOs in code, separate PRs)

  • generator_logprobs → old_logprobs at the loss (vs policy_logprobs)
  • VLLMGenerator → InferenceEngine (TODO left at the class)

@wwwjn

wwwjn commented Jun 4, 2026

Contributor

Thanks for the comments, it's much more clear

Goal for the code (multi-turn, once the generator supports it)

I might have missed something: what's the gap for the generator to support this "code goal"?

…otocol

- grpo.py: kept our refactored async _collect_rollouts (rollouter); upstream's fut.get()->await auto-merged
- test_grpo_metrics.py: kept deleted (our refactor removed it; upstream's version imports removed Step/Trajectory)
- test_generator.py: drop removed SamplingConfig(n=); set _stop_token_ids on the hand-built generator
- test_shutdown.py: stub the renderer/generator/rollouter that __init__ now builds
- RL tests: 91 passed, 3 skipped
@felipemello1

felipemello1 commented Jun 4, 2026

Contributor Author

Thanks for the comments, it's much more clear

Goal for the code (multi-turn, once the generator supports it)

I might have missed something: what's the gap for the generator to support this "code goal"?

Allow multiple threads to call await generator.generate(prompt_request) independently.

Currently we have to batch all requests because our generator doesn't have any await that releases the lock for other threads to add more requests to the queue.

I will have a PR to fix that probably tomorrow
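For readers following along, a toy asyncio sketch of the target behaviour: each rollout awaits generate() on its own, and a background loop batches whatever requests are queued at that moment. ToyAsyncGenerator and demo are hypothetical, the "forward pass" is a stand-in that adds 1 to every token, and asyncio tasks are used here in place of the threads mentioned above.

```python
import asyncio


class ToyAsyncGenerator:
    """Hypothetical generator: callers await generate() independently."""

    def __init__(self, max_batch_size: int = 4) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        self._max_batch_size = max_batch_size
        self.batch_sizes: list[int] = []

    async def generate(self, prompt_token_ids: list[int]) -> list[int]:
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((prompt_token_ids, future))
        # Awaiting the future yields control, so other rollouts can enqueue.
        return await future

    async def run(self) -> None:
        while True:
            batch = [await self._queue.get()]
            while len(batch) < self._max_batch_size and not self._queue.empty():
                batch.append(self._queue.get_nowait())
            self.batch_sizes.append(len(batch))
            await asyncio.sleep(0)  # stand-in for the real forward pass
            for prompt, future in batch:
                future.set_result([t + 1 for t in prompt])


async def demo(num_rollouts: int = 3, num_turns: int = 2):
    generator = ToyAsyncGenerator()  # built inside the running event loop
    worker = asyncio.create_task(generator.run())

    async def rollout(start: int) -> list[int]:
        tokens = [start]
        for _ in range(num_turns):
            tokens = await generator.generate(tokens)
        return tokens

    results = await asyncio.gather(*(rollout(i) for i in range(num_rollouts)))
    worker.cancel()
    await asyncio.gather(worker, return_exceptions=True)
    return results, generator.batch_sizes
```

Running `asyncio.run(demo())` drives three two-turn rollouts through the same generator without the caller ever assembling a batch; rollouts that finish early simply stop enqueuing.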

Comment on lines +20 to +21
if TYPE_CHECKING:
    from torchtitan.experiments.rl.actors.generator import Completion

Contributor

Why do we need this type checking?

"""Render the initial conversation into the first generator prompt."""
env_init_output = await self._message_env.init()

# Copy our running conversation so we avoid mutating previous states

Contributor

What does this "running conversation" mean?

if parsed.reasoning_content:
    parsed_completion_message["reasoning_content"] = parsed.reasoning_content
if parsed.tool_calls:
    parsed_completion_message["tool_calls"] = parsed.tool_calls

Contributor

This means the model generated a <tool_use> token, right?

I also want to check my mental model. Say we are training on a coding task:

  • Do we need to plug both a SandBoxEnv and a MessageEnv into the TokenEnv?
  • When there's a <tool_use> token, will TokenEnv call the SandBoxEnv to execute the tool?

I guess my question is: where will the tool calling happen in the future?

Contributor

Chatted with Claude, it says:

 In this design, MessageEnv is the problem-specific environment that receives the
  assistant’s parsed message and decides what happens next. For a coding task, that env
  would own or call the sandbox.

  VLLMGenerator
    |
    | completion tokens
    v
  TokenEnv
    |
    | Renderer.parse_response(...)
    | parsed assistant message
    v
  CodingMessageEnv.step(parsed_completion_message)
    |
    | extract code / tool call / patch / command
    v
  Coding Sandbox
    |
    | run tests, execute code, inspect files, enforce timeout
    v
  CodingMessageEnv
    |
    | returns tool/user messages + optional env_rewards + done
    v
  TokenEnv
    |
    | render next prompt tokens if not done
    v
  VLLMGenerator

I updated my mental model: TokenEnv is plumbing between Generator and message space
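A toy version of that flow, to make the ownership explicit: the message env receives the parsed completion message, runs any tool calls in its own "sandbox", and returns tool messages. Everything here is hypothetical: ToyToolMessageEnv, toy_sandbox_run, the tool-call dict shape ({"name", "arguments"}), and the (env_messages, env_rewards, done) return tuple are simplifications, not the PR's API.

```python
import asyncio

Message = dict  # e.g. {"role": "tool", "content": "..."}


def toy_sandbox_run(name: str, arguments: dict) -> str:
    """Hypothetical sandbox that only knows an 'add' tool."""
    if name == "add":
        return str(arguments["a"] + arguments["b"])
    return f"error: unknown tool {name!r}"


class ToyToolMessageEnv:
    """Message-space env that owns the sandbox; it never sees token ids."""

    def __init__(self, target: int) -> None:
        self._target = target

    async def init(self) -> list[Message]:
        return [{"role": "user", "content": "Use the add tool, then answer."}]

    async def step(self, parsed_completion_message: Message):
        tool_calls = parsed_completion_message.get("tool_calls") or []
        if tool_calls:
            # Execute each requested tool and reply with tool messages.
            env_messages = [
                {"role": "tool", "content": toy_sandbox_run(c["name"], c["arguments"])}
                for c in tool_calls
            ]
            return env_messages, {}, False
        # No tool call: treat the content as the final answer and score it.
        answer = parsed_completion_message.get("content", "").strip()
        return [], {"correct": float(answer == str(self._target))}, True
```

In this picture the token-space wrapper stays problem-agnostic: it parses tokens into a message, hands it over, and renders whatever messages come back.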

@tianyu-l tianyu-l left a comment

Contributor

I promise this is the last push.

"""Maximum number of tokens to generate per completion."""


# TODO(naming): rename VLLMGenerator -> InferenceEngine.

Contributor

I think after we name Rollouter, this is less necessary.



@dataclass(kw_only=True, slots=True)
class MessageInitOutput:

Contributor

Since you have TokenEnvOutput, I would mirror this to MessageEnvInitOutput, especially considering it contains not just messages but also tools.

"""Initial messages + tool specs from `MessageEnv.init`."""

init_prompt_messages: list[Message] # [num_prompt_messages]
"""The opening messages (e.g. [system, user])."""

Contributor

Mention that we don't do a __post_init__ check here because assistant messages are legitimate in the init conversation.

parsed_completion_message: Message | None = None
"""This turn's parsed completion message. `None` on init and on parse failure."""

env_messages: list[Message] = field(default_factory=list) # [num_env_messages]

Contributor

Is this part of next_prompt_messages? like

next_prompt_messages contains
[previous prompt_messages, parsed_completion_message, env_messages]

Contributor

can we use completion_tokens vs completion_message?

Contributor Author

The reason I added "parsed" is to make it evident that it did NOT come from the vLLM completion, but rather was parsed by the TokenEnv. Does this make sense? I am OK with renaming, but we lose this detail.

Contributor

That's fair, but for perfect consistency it sounds like we should do

  • prompt_message
  • tokenized_prompt_token_ids
  • completion_token_ids
  • detokenized_completion_message

There is a lot of redundancy, so I'm not sure.

I would like to at least avoid inventing another word "parsed"

Comment on lines +218 to +220
self._messages = (
    self._messages + [parsed_completion_message] + step_output.env_messages
)

Contributor

Looks like I'm mostly correct.

@felipemello1 felipemello1 Jun 4, 2026

Contributor Author

Now that I am thinking about it, this is in the wrong place. The MessageEnv should say what the "next_messages" are, e.g. for compacting history. Let me see how I can move this logic to MessageEnv. I guess this would mean I have to delete "env_messages" and just have it return "next_prompt_messages".

Comment on lines +844 to +846
completion_token_ids=completion.token_ids,
completion_logprobs=completion.token_logprobs,
policy_version=completion.policy_version,

Contributor

It seems you could've put these in TokenEnvOutput, so that we can consolidate RolloutTurn and TokenEnvOutput.

If so, I would strongly advocate that, in this version.

Contributor Author

class TokenEnvOutput:
    # off by one
    next_prompt_token_ids: list[int] | None  # [num_prompt_tokens] or None
    next_prompt_messages: list[Message] | None = None  # [num_prompt_messages] or None

    # different
    status: RolloutStatus

    # Same
    parsed_completion_message: Message | None = None
    env_messages: list[Message] = field(default_factory=list)  # [num_env_messages]
    env_rewards: dict[str, float] = field(default_factory=dict)

class RolloutTurn:
    # off by one
    prompt_token_ids: list[int]  # [num_prompt_tokens]
    prompt_messages: list[Message] = field(default_factory=list)  # [num_prompt_messages]

    # different
    policy_version: int | None = None
    completion_token_ids: list[int]  # [num_completion_tokens]
    completion_logprobs: list[float]  # [num_completion_tokens]

    # Same
    parsed_completion_message: Message | None = None
    env_messages: list[Message] = field(default_factory=list)  # [num_env_messages]
    env_rewards: dict[str, float] = field(default_factory=dict)

RolloutTurn is aware of the whole rollout loop. There may be other generators (e.g. a teacher), more metrics, some crazy logic. TokenEnvOutput, by contrast, is concerned only with the env.

I think it would be too opinionated to consolidate them, and it would put unnecessary responsibility on the TokenEnv.

I do agree that there is some overlap. We could create some dataclasses to hold some of the fields that are common, e.g. Completion or tokens_with_logprobs can be an output of the generator and we do RolloutTurn.completion.

NextPrompt could hold next_prompt_token_ids and next_prompt_messages.

cc: @wwwjn in case you have an opinion as well

Contributor

sounds good to leave them separate right now

def make_env_group(
self,
*,
example: object,

Contributor

not strongly opinionated

Suggested change
example: object,
sample: object,

self._token_env_config = config.token_env

# TODO: revisit this abstraction: should it return a sample or a dataset or an iterator?
def sample_train_example(self) -> object:

Contributor

nit: use "training" instead of "train" to be consistent

Suggested change
def sample_train_example(self) -> object:
def get_training_sample(self) -> object:

Comment thread torchtitan/experiments/rl/grpo.py Outdated
completions, generation_metrics = self._get_rank_0_value(
self.generator.generate.call(tokenized_prompts).get()

rollouter: Rollouter = self._rollouter

Contributor

prefer to not have this indirection.

Comment thread torchtitan/experiments/rl/grpo.py Outdated
self,
*,
group_state: _RolloutGroupState,
rollouter: Rollouter,

Contributor

don't need to pass this in, you have self._rollouter

Contributor Author

for datamix we would need it, but all of this code in grpo will go under the rollout anyway, and in this case, it would just use self. Either way, for now, we have 1 rollout, so i agree

for group_state in valid_group_states:
    for sample_idx, env_init_output in enumerate(group_state.env_init_outputs):
        prompt_token_ids.append(env_init_output.next_prompt_token_ids or [])
        request_ids.append(_sample_id(group_state.group_id, sample_idx))

Contributor

A side note: I read somewhere else that the request_id needs to contain the DP rank ID; otherwise each DP rank will have the same set of request_ids, which will cause request ID collision errors in the vLLM engine. I think it's good practice to make the request ID globally unique.

We can put a TODO here, wdyt?

Contributor Author

I'm not sure I follow: the ID is unique. I guess I could add some random letters to make sure.

But would we send the same sample_id to different DP ranks? Shouldn't each rank get a different set of samples?

Contributor

The same prompt goes to the same DP rank for prefix cache reuse.
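A minimal sketch of the globally unique request ID idea discussed in this thread. The helper name `make_request_id` and the `dp<rank>-<group>-<sample>` layout are assumptions for illustration; the PR's `_sample_id` helper may use a different format, and how the DP rank is obtained is left out.

```python
def make_request_id(dp_rank: int, group_id: str, sample_idx: int) -> str:
    """Hypothetical helper: prefix the ID with the DP rank.

    Two ranks that build IDs from the same (group_id, sample_idx) pair
    then still produce different request IDs.
    """
    return f"dp{dp_rank}-{group_id}-{sample_idx}"


# Two DP ranks that happen to see the same group and sample index.
ids = [make_request_id(dp_rank, "g0", 0) for dp_rank in range(2)]
assert len(set(ids)) == len(ids)  # no collision across ranks
```

Whether the rank prefix is needed depends on whether the engine's request ID namespace is shared across DP ranks, which this thread leaves open.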

next_completion_offset = completion_offset + len(
    group_state.env_init_outputs
)
group_state.completions = completions[

Contributor

Is completions sorted?

Sorting and then slicing to split completions into groups is error-prone. Can we get the group_id from the request_id and then group by group_id?

@felipemello1 felipemello1 Jun 4, 2026

Contributor Author

They are sorted. I should fix it, but since this code will be deleted/refactored in the next few days, is it OK if I leave a TODO for now?
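One way the grouping suggested in this thread could look, as a sketch rather than the PR's code. `Completion` is a hypothetical stand-in record, and the `<group_id>/<sample_idx>` request ID layout is an assumption; the actual `_sample_id` format is not shown in the diff.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Completion:
    # Hypothetical stand-in for the generator's completion record.
    request_id: str
    text: str


def parse_request_id(request_id: str) -> Tuple[str, int]:
    """Split an assumed '<group_id>/<sample_idx>' ID into its parts."""
    group_id, sample_idx = request_id.rsplit("/", 1)
    return group_id, int(sample_idx)


def group_completions(completions: List[Completion]) -> Dict[str, List[Completion]]:
    """Group by the group_id encoded in each request ID, not by position."""
    grouped: Dict[str, List[Completion]] = defaultdict(list)
    for completion in completions:
        group_id, _ = parse_request_id(completion.request_id)
        grouped[group_id].append(completion)
    # Restore sample order inside each group, whatever the arrival order was.
    for members in grouped.values():
        members.sort(key=lambda c: parse_request_id(c.request_id)[1])
    return dict(grouped)
```

Because the group is recovered from the ID itself, the result does not depend on the generator returning completions in submission order.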

)

@sl.log_trace_span("_run_single_rollout")
async def _run_single_rollout(

Contributor

What's the consideration for putting _run_group_rollout and _run_single_rollout in the controller vs. in Rollouter?

If we put them in Rollouter, I would imagine the controller could simply call

self._rollouter.run_group_rollout(generator, sample_id, xxx)

Contributor Author

Yes, that's the direction. I thought about already removing "collect_rollouts" from grpo and putting it in the Rollouter, but I was going to leave it for the next refactor PR. Maybe we should land as is, to avoid another round of review that would be wasted anyway because of the refactor?

Now I feel like I should have fixed the generator first before doing this PR, haha.
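A rough sketch of the direction agreed on in this thread, where the rollout loop lives in the Rollouter and the controller only delegates. The class bodies, the `run_single` callable, and the `(group_id, group_size)` signature are all assumptions for illustration; the signature in the review comment takes a generator and other arguments that are not spelled out there.

```python
import asyncio
from typing import Awaitable, Callable, List


class Rollouter:
    """Hypothetical: owns the group and single rollout loops."""

    def __init__(self, run_single: Callable[[str], Awaitable[str]]) -> None:
        self._run_single = run_single

    async def run_group_rollout(self, group_id: str, group_size: int) -> List[str]:
        sample_ids = [f"{group_id}/{i}" for i in range(group_size)]
        # asyncio.gather returns results in the order the awaitables were passed.
        results = await asyncio.gather(*(self._run_single(s) for s in sample_ids))
        return list(results)


class Controller:
    """Hypothetical: keeps no rollout logic of its own."""

    def __init__(self, rollouter: Rollouter) -> None:
        self._rollouter = rollouter

    async def collect_rollouts(self, group_id: str, group_size: int) -> List[str]:
        return await self._rollouter.run_group_rollout(group_id, group_size)


async def fake_single_rollout(sample_id: str) -> str:
    # Placeholder for generation plus env stepping.
    await asyncio.sleep(0)
    return f"rollout:{sample_id}"
```

With this split, the controller would not need a `rollouter` argument threaded through its private methods, which is the indirection the earlier comments object to.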

@felipemello1

felipemello1 commented Jun 5, 2026

Contributor Author

All requests addressed

The only one I didn't address was MessageEnv returning "next_prompt_messages" for history-editing cases, like compacting. This is not trivial: we would need a postprocessing function in rollout_to_episode to branch messages if history diverges, which I will add in the multiturn PR.

Also, it might be easier in the single_rollout_loop to intercept the TokenEnvOutput, run the compaction service, and create a new TokenEnvOutput.

For now, I added a TODO.

cc: @tianyu-l @wwwjn
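The interception idea in the comment above could look roughly like this. It is a sketch under stated assumptions: `TokenEnvOutput` here is a stand-in dataclass with only the `next_prompt_token_ids` field seen in the diff, `maybe_compact` and `keep_last_four` are hypothetical names, and the real compaction service would work on messages rather than truncate tokens.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Optional


@dataclass(frozen=True)
class TokenEnvOutput:
    # Hypothetical stand-in: only next_prompt_token_ids comes from the PR diff.
    next_prompt_token_ids: Optional[List[int]]


def maybe_compact(
    output: TokenEnvOutput,
    compact: Callable[[List[int]], List[int]],
    max_prompt_tokens: int,
) -> TokenEnvOutput:
    """Return a new TokenEnvOutput with a compacted prompt when it is too long."""
    token_ids = output.next_prompt_token_ids
    if token_ids is None or len(token_ids) <= max_prompt_tokens:
        return output  # nothing to compact, pass the original through
    return replace(output, next_prompt_token_ids=compact(token_ids))


def keep_last_four(token_ids: List[int]) -> List[int]:
    # Placeholder for a real compaction service: keep the most recent tokens.
    return token_ids[-4:]
```

The original output is never mutated, so the rollout record can still hold the uncompacted history if rollout_to_episode later needs to branch on it.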

@tianyu-l tianyu-l left a comment

Contributor

Great work!!


Not sure if you missed #3453 (comment)

The benefit is that

Shouldn't block this PR. Happy to discuss more.

@felipemello1 felipemello1 merged commit f0c73ca into pytorch:main Jun 5, 2026
17 of 19 checks passed

Labels

ciflow/rl, ciflow/8gpu, CLA Signed

3 participants