feat(rl): gate-native async RL loop + flat Rollout contract #382

Draft
Hecate0821 wants to merge 4 commits into main from chengxi/rl-async-gate-v2

Conversation


Hecate0821 commented Apr 23, 2026

What this PR adds

An async RL training recipe (recipes/async_rl_loop.py) that decouples rollout sampling from optimizer steps, plus the supporting rollout package and a real GSM8K multi-turn example modeled after AReaL's examples/multi_turn_math/.

User-facing rollout contract (per-sample, matches AReaL/slime)

async def rollout_fn(row) -> RolloutSample | None: ...

The framework fans each row out to completions_per_prompt parallel calls and joins them by row id via GroupAssembler before handing the assembled PromptGroup to the trainer. AReaL's arun_episode and slime's generate use the same shape.
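
For concreteness, a minimal sketch of a conforming rollout function. query_model and score are hypothetical stand-ins (not part of this PR), and the RolloutSample constructor arguments are illustrative assumptions — the real fields live in utils/rl/rollout/:

from utils.rl.rollout import RolloutSample

async def query_model(prompt: str) -> str | None:
    """Hypothetical inference call (e.g. via a deployment sampler)."""
    ...

def score(gold: str, completion: str) -> float:
    """Hypothetical task reward."""
    ...

async def rollout_fn(row) -> RolloutSample | None:
    completion = await query_model(row["question"])  # one trajectory per call
    if completion is None:
        return None  # a None return drops this sample from its group
    # Field names below are assumed for illustration only.
    return RolloutSample(
        messages=[{"role": "user", "content": row["question"]},
                  {"role": "assistant", "content": completion}],
        reward=score(row["answer"], completion),
    )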

Major components

  • recipes/async_rl_loop.py — async loop driving rollout fan-out, group assembly, off-policy gating (row-granular staleness budget), the train step, and the inner PPO loop. Uses main's LossArgs Protocol (build_loss_fn(args), validate_loss_path(args), from #412) and is fixed to loss_path="client" — this recipe runs the loss closure in Python via forward_backward_custom; no server-side builtin path.
  • utils/rl/async_train.py — low-level async primitives (RowRequest, run_async_rl_loop, _StalenessController).
  • utils/rl/rollout/ — rollout package: RolloutSample, Rollout, rollout_to_prompt_group, GroupAssembler, MessageTrajectoryAssembler (TITO-style multi-turn), TrajectoryAssembler (token-native), service Protocol + remote adapter, native trajectory analysis.
  • utils/dataloader.py — cursor-based dataloader with epoch shuffle.
  • utils/data.py::replicate_rows_for_epochs — deepcopy-per-epoch helper so rollout fns mutating their input row don't poison later epochs (sketched after this list).
  • utils/timer.py::elapsed_timer — span-based timer for nested timing.
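
The deepcopy-per-epoch helper is small enough to sketch in full; this is a minimal reconstruction of the behavior described above, not the verbatim source:

import copy
from typing import Any, Dict, List

def replicate_rows_for_epochs(rows: List[Dict[str, Any]], num_epochs: int) -> List[Dict[str, Any]]:
    """Give every epoch its own deep copy, so a rollout fn that mutates
    its input row cannot poison later epochs."""
    out: List[Dict[str, Any]] = []
    for _ in range(num_epochs):
        out.extend(copy.deepcopy(row) for row in rows)
    return out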

AReaL-parity knobs

  • Config.ppo_n_minibatches (default 1, mirrors rl_loop.py) — each rollout batch snapshots old_policy_logprobs once and runs K forward_backward + optim_step minibatches against that snapshot (sketched after this list). K=1 reproduces legacy 1:1 behavior; K>1 makes the PPO ratio measure genuine inner-loop drift and the clip do real work. Collapses AReaL's gradient_accumulation_steps and ppo_n_minibatches into one knob.
  • Config.max_head_offpolicy_versions ≡ AReaL's max_head_offpolicyness (staleness budget, row-granular).
  • weight_sync_interval is pinned to 1 inside the recipe (constant _WEIGHT_SYNC_INTERVAL). Configurable in WeightSyncConfig for sync recipes; the async recipe ignores it because raising it just trades rollout staleness for sync wall-time, which is almost never worth it.
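
The inner-loop shape, as a hedged sketch: the trainer method names and the batch split below are assumptions for illustration, but the snapshot-once-then-K-minibatches structure is the behavior described above.

def ppo_step(trainer, batch, ppo_n_minibatches: int) -> None:
    # Snapshot once per rollout batch; every inner minibatch computes its
    # PPO ratio against this fixed reference (sliced per minibatch inside).
    old_policy_logprobs = trainer.compute_logprobs(batch)
    for minibatch in batch.split(ppo_n_minibatches):  # K inner minibatches
        trainer.forward_backward(minibatch, old_policy_logprobs=old_policy_logprobs)
        trainer.optim_step()  # K=1 reproduces the legacy 1:1 behavior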

New WandB metrics

  • train/ppo_kl — intra-step KL between the current policy and the step-start old_policy_logprobs snapshot, averaged across the K inner minibatches. Slime/AReaL/OpenRLHF naming. ~0 when ppo_n_minibatches=1; >0 when the inner loop drifts.
  • perf/sample_wait_time — per-step seconds the trainer waited for the next batch to fill (was hardcoded 0.0).
  • perf/wait_time_ratio — wait / (wait + train), the overall step ratio. With ppo_n_minibatches > 1 the denominator covers all K minibatches because they run sequentially before train_step returns.
  • perf/overlap_ratio — 1 - wait_time_ratio (both formulas sketched after this list).
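
A minimal sketch of the formulas, assuming per-token logprob tensors; the diff direction follows the standard k3 KL estimator and the recipe's exact convention is not restated here:

import torch

def ppo_kl(new_logprobs: torch.Tensor, old_logprobs: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(diff) - diff - 1 is >= 0 and ~0 when the policies match.
    diff = old_logprobs - new_logprobs
    return (torch.exp(diff) - diff - 1.0).mean()

def wait_metrics(wait_s: float, train_s: float) -> dict:
    ratio = wait_s / (wait_s + train_s)  # wait / (wait + train)
    return {
        "perf/sample_wait_time": wait_s,
        "perf/wait_time_ratio": ratio,
        "perf/overlap_ratio": 1.0 - ratio,
    }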

Examples

  • examples/rl/multi_turn_message_in/ — GSM8K multi-turn agent ported from AReaL. Up to N turns; on a wrong boxed answer, append AReaL's verbatim retry-prompt and let the model try again. Reward via math_verify with numeric fallback (sketched after this list). prepare_data.py downloads openai/gsm8k from HuggingFace.
  • examples/rl/single_turn_token_in/ — token-in single-turn baseline.
  • examples/rl/vanilla_sampler.py — shared deployment sampler helper (renamed from sampler.py for clarity).
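
The reward shape for the multi-turn example, as a hedged sketch: the real logic lives in examples/rl/multi_turn_message_in/; this only shows the math_verify-then-numeric-fallback order described above, and the regex fallback is an assumption:

import re

from math_verify import parse, verify

def gsm8k_reward(completion: str, gold_answer: str) -> float:
    try:
        if verify(parse(gold_answer), parse(completion)):
            return 1.0
    except Exception:
        pass  # fall through to the numeric comparison
    # Numeric fallback: compare the last number in the completion to the gold.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    try:
        return 1.0 if numbers and float(numbers[-1]) == float(gold_answer) else 0.0
    except ValueError:
        return 0.0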

Loss subsystem changes

  • PromptGroup gains a per-sample prompt_lens: List[int] | None field for heterogeneous rollouts (multi-turn, tool branches) where each sample has a different prefix length. combine_prompt_groups prefers it when set.
  • _get_loss_mask reads loss_fn_inputs["weights"] first (new SDK contract) with legacy fallback to "loss_mask"; sketched after this list.
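
A minimal reconstruction of that fallback order, assuming dict-shaped loss_fn_inputs (the real helper may differ in shape):

def _get_loss_mask(loss_fn_inputs: dict):
    weights = loss_fn_inputs.get("weights")  # new SDK contract, checked first
    if weights is not None:
        return weights
    return loss_fn_inputs.get("loss_mask")   # legacy fallback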

SDK pin

fireworks-ai[training]>=1.2.0a65,<2 — includes the chunk-transient backoff/retry from a64+.

Tests

Unit tests

~600 unit tests pass on this branch. New coverage for this PR:

  • tests/unit/test_async_rl_train.py — full async loop behavior (staleness controller, batch assembly, dynamic filter, weight sync cadence, ppo_n_minibatches path).
  • tests/unit/test_group_assembler.py — row fan-in semantics.
  • tests/unit/test_gsm8k_reward.py — boxed-answer extraction + math_verify.
  • tests/unit/test_rollout_{types,assembler,message,helpers,trace}.py — rollout package.
  • tests/unit/test_cursor_dataloader.py, tests/unit/test_data_utils.py.

Integration test plan (GSM8K multi-turn)

End-to-end verification on the GSM8K multi-turn example in two phases, run sequentially. Phase 2 only proceeds after Phase 1 finishes cleanly. Both phases share the same cohort and group shape:

  • 200 rows of GSM8K (--max-rows 200)
  • 8 samples per prompt (--completions-per-prompt 8)
  • 8 prompts per optimizer step (--prompt-groups-per-step 8)
  • Dynamic filter on (--filter-constant-reward drops zero-advantage groups so the optimizer doesn't no-op on flat-reward batches)
  • WandB enabled

Setup (once):

cd training/examples/rl/multi_turn_message_in
python prepare_data.py            # downloads openai/gsm8k -> train.jsonl
export FIREWORKS_API_KEY=<your-key>
export WANDB_ENTITY=<your-entity>
export WANDB_PROJECT=gsm8k-mt-async

Phase 1 — ppo_n_minibatches=1, 1-version off-policy:

python train.py \
  --base-model accounts/fireworks/models/qwen3-1p5b-instruct \
  --tokenizer-model Qwen/Qwen2.5-1.5B-Instruct \
  --max-rows 200 \
  --completions-per-prompt 8 \
  --prompt-groups-per-step 8 \
  --max-head-offpolicy-versions 1 \
  --ppo-n-minibatches 1 \
  --filter-constant-reward

WandB run should show:

  • train/ppo_kl ≈ 0 (no inner-loop drift when K=1)
  • async/version_offset_max ≤ 1
  • perf/sample_wait_time and perf/wait_time_ratio populated and finite
  • rollout/reward trending upward across steps
  • train/mean_kl finite, no NaN/inf
  • Run terminates cleanly after the dataset is exhausted; final checkpoint promotable

Phase 1 success criteria: run finishes without error, reward improves, wait-ratio metric is non-zero (proves the new instrumentation works).

Phase 2 — ppo_n_minibatches=2, 2-version off-policy:

Only after Phase 1 succeeds:

python train.py \
  --base-model accounts/fireworks/models/qwen3-1p5b-instruct \
  --tokenizer-model Qwen/Qwen2.5-1.5B-Instruct \
  --max-rows 200 \
  --completions-per-prompt 8 \
  --prompt-groups-per-step 8 \
  --max-head-offpolicy-versions 2 \
  --ppo-n-minibatches 2 \
  --filter-constant-reward

WandB run should show:

  • train/ppo_kl > 0 (the second minibatch sees real drift from the snapshot)
  • train/ppo_clip_frac > 0 (PPO clipping fires)
  • async/version_offset_max ≤ 2 (staleness budget honored)
  • perf/wait_time_ratio ≤ Phase 1 (more concurrency headroom from the larger budget)
  • Optimizer-step count is ~2× the accepted-batch count (2 inner steps per batch)
  • rollout/reward trending upward
  • DCP cadence unchanged vs. Phase 1 (saves still on rollout-batch granularity, not optim-step granularity)

Phase 2 success criteria: run finishes without error, ppo_kl > 0 (proves the inner loop is actually running), version offset bounded by the budget.

Hecate0821 and others added 3 commits May 1, 2026 14:46
…lag, ConcurrencyConfig

Cleans up the async RL recipe surface following review.

- Delete utils/rl/rollout/concurrency.py (FixedRequestGate, RequestGate
  Protocol, DEFAULT_REQUEST_GATE_CONCURRENCY) and its unit tests.  The
  per-HTTP gate was redundant: row-level scheduling in
  _StalenessController already enforces capacity (cfg.sample_max_concurrency)
  and staleness ((max_offpolicy + version + 1) * batch_size - inflight)
  via min() -- sketched after this list -- so a third lower-level gate
  added complexity without buying anything.  Examples now use
  DeploymentSampler directly with no concurrency_controller wired in.
- Drop Config.provision_inference: deployment provisioning is implicit
  in RL (DeployConfig.deployment_id None vs set already encodes
  create-vs-reuse).  No external caller flipped this to False.  Removes
  8 sites of dead branching including ``or ""`` fallbacks on
  inference_base_url and inference_model.
- Drop Config.concurrency: ConcurrencyConfig was imported, defaulted,
  and never read in the async recipe.  sample_max_concurrency on
  Config is the actually-wired knob.
- Tighten tokenizer load: require deployment.tokenizer_model upfront,
  drop the conditional ``tokenizer = None; if cfg.deployment.tokenizer_model:``
  guard since the tokenizer is mandatory.
- pyproject: bump fireworks-ai[training] floor to 1.2.0a65 to match
  the SDK pin documented in the prior commit message.
- smoke imports: add new single_turn_token_in/multi_turn_message_in
  rollout modules.
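
The admission math quoted in the first bullet, as a standalone sketch -- the function name and flat signature are illustrative; the real controller tracks this state internally:

def admissible_rows(sample_max_concurrency: int, max_offpolicy: int,
                    version: int, batch_size: int, inflight: int) -> int:
    # Capacity and staleness budgets combine via min(); the per-HTTP gate
    # deleted above duplicated this without adding anything.
    staleness_budget = (max_offpolicy + version + 1) * batch_size - inflight
    return max(0, min(sample_max_concurrency, staleness_budget))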

Tests: 636 unit tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…nc interval to 1

Brings the async RL recipe to AReaL parity for the inner-loop knob and surfaces
two new wandb metrics that were missing or dead.

- Add ``Config.ppo_n_minibatches`` (default 1) mirroring ``rl_loop.py``.  The
  per-step path now snapshots ``old_policy_logprobs`` once and runs K
  ``forward_backward + optim_step`` minibatches against that snapshot, so the
  PPO ratio measures genuine inner-loop drift.  DCP cadence stays in
  rollout-batch units so ``dcp_save_interval`` is unchanged.
- Pin ``weight_sync_interval`` to a module-level constant ``_WEIGHT_SYNC_INTERVAL = 1``.
  Raising it trades rollout staleness for sync wall-time, which is almost never
  worth it in fully-async RL.  ``WeightSyncConfig.weight_sync_interval`` is
  still honored by the sync recipes; the async recipe ignores it.
- Emit ``train/ppo_kl`` from ``run_loss_loop`` (mean of
  ``exp(diff) - diff - 1`` between current policy logprobs and
  ``old_policy_logprobs``).  Averaged across minibatches by the existing
  ``compute_step_metrics`` reducer.
- Populate ``perf/sample_wait_time`` and fix ``perf/wait_time_ratio`` for the
  async path: ``async_train.py`` tracks ``last_step_end`` so the trainer's
  per-step wait gap is recorded, and the recipe sets
  ``step_wall_time = wait + train`` before ``compute_step_metrics`` fires --
  so the ratio is wait/(wait+train), the overall step ratio.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…versions, dynamic filter on multi_turn_message_in

Adds three CLI flags to the GSM8K multi-turn example so the test plan in PR #382
is fully reproducible without source edits.

- ``--ppo-n-minibatches`` (default 1) -- inner PPO minibatch count.
- ``--max-head-offpolicy-versions`` (default 0) -- staleness budget.
- ``--filter-constant-reward`` -- drops prompt groups whose rewards are
  identical across all samples (GRPO advantage is 0 there, optimizer step
  is a no-op; the predicate is sketched after this list).  Standard
  async-RL hygiene; matches AReaL's ``adv filter`` in spirit.
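
A minimal sketch of the filter's predicate (the group's reward-list access is assumed for illustration):

def keep_group(rewards: list[float]) -> bool:
    # A group whose samples all share one reward has zero GRPO advantage,
    # so training on it would be a no-op.
    return len(set(rewards)) > 1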

Also drops the ``WeightSyncConfig(weight_sync_interval=1)`` override since
the async recipe pins the interval to 1 internally now.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>