feat(rl): gate-native async RL loop + flat Rollout contract #382
Draft · Hecate0821 wants to merge 4 commits into main
Conversation
Adds an async-pipelined RL training recipe (recipes/async_rl_loop.py)
that decouples rollout sampling from optimizer steps, plus the
supporting rollout package and a real GSM8K multi-turn example modeled
after AReaL's examples/multi_turn_math/.
User-facing rollout contract is per-sample (one trajectory per call),
matching AReaL's ``arun_episode`` and slime's ``generate``:
async def rollout_fn(row) -> RolloutSample | None: ...
The framework fans each row out to ``completions_per_prompt`` parallel
calls and joins them by row id via ``GroupAssembler`` before handing
the assembled ``PromptGroup`` to the trainer.
Major components
- recipes/async_rl_loop.py: async loop driving rollout fan-out, group
assembly, off-policy gating (row-granular staleness budget), and the
train step. Uses the new ``LossArgs`` Protocol from main's #412
(build_loss_fn(args), validate_loss_path(args)) and is fixed to
``loss_path="client"`` -- this recipe runs the loss closure in Python
via forward_backward_custom, no server-side builtin path.
- utils/rl/async_train.py: low-level async loop primitives (RowRequest,
run_async_rl_loop, _StalenessController).
- utils/rl/rollout/: rollout package -- types (RolloutSample, Rollout,
rollout_to_prompt_group), GroupAssembler, MessageTrajectoryAssembler
(TITO-style multi-turn), TrajectoryAssembler (token-native), service
Protocol + remote adapter, native trajectory analysis.
- utils/dataloader.py: cursor-based dataloader with epoch shuffle.
- utils/data.py::replicate_rows_for_epochs: deepcopy-per-epoch helper
so rollout fns mutating their input row don't poison later epochs.
- utils/timer.py::elapsed_timer: span-based timer for nested timing.
Examples
- examples/rl/multi_turn_message_in/: GSM8K multi-turn agent ported from
AReaL. Up to N turns; on a wrong boxed answer, append AReaL's
verbatim retry-prompt and let the model try again. Reward via
math_verify with numeric fallback. prepare_data.py downloads
openai/gsm8k from HuggingFace.
- examples/rl/single_turn_token_in/: token-in single-turn baseline.
- examples/rl/vanilla_sampler.py: shared deployment sampler helper
(renamed from sampler.py for clarity).
PromptGroup gains a per-sample ``prompt_lens`` field for heterogeneous
rollouts (multi-turn, tool branches) where each sample has a different
prefix length. ``combine_prompt_groups`` prefers it when set.
``_get_loss_mask`` reads ``loss_fn_inputs["weights"]`` first (new SDK
contract) with legacy fallback to ``"loss_mask"``.
SDK pin: fireworks-ai[training]>=1.2.0a65,<2 (includes the chunk
transient backoff/retry from a64+).
Tests: 596 unit tests pass. GSM8K reward, group assembler, async loop,
and rollout adapter all covered.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…lag, ConcurrencyConfig

Cleans up the async RL recipe surface following review.

- Delete utils/rl/rollout/concurrency.py (FixedRequestGate, RequestGate
  Protocol, DEFAULT_REQUEST_GATE_CONCURRENCY) and its unit tests. The
  per-HTTP gate was redundant: row-level scheduling in _StalenessController
  already enforces capacity (cfg.sample_max_concurrency) and staleness
  ((max_offpolicy + version + 1) * batch_size - inflight) via min(), so a
  third lower-level gate added complexity without buying anything (see the
  sketch after this commit message). Examples now use DeploymentSampler
  directly with no concurrency_controller wired in.
- Drop Config.provision_inference: deployment provisioning is implicit in
  RL (DeployConfig.deployment_id None vs set already encodes
  create-vs-reuse). No external caller flipped this to False. Removes 8
  sites of dead branching, including ``or ""`` fallbacks on
  inference_base_url and inference_model.
- Drop Config.concurrency: ConcurrencyConfig was imported, defaulted, and
  never read in the async recipe. sample_max_concurrency on Config is the
  actually-wired knob.
- Tighten tokenizer load: require deployment.tokenizer_model upfront and
  drop the conditional ``tokenizer = None; if cfg.deployment.tokenizer_model:``
  guard, since the tokenizer is mandatory.
- pyproject: bump the fireworks-ai[training] floor to 1.2.0a65 to match
  the SDK pin documented in the prior commit message.
- smoke imports: add the new single_turn_token_in/multi_turn_message_in
  rollout modules.

Tests: 636 unit tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
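For reference, a minimal sketch of the admission rule quoted in the first bullet above; the function name and the exact capacity accounting are one plausible reading, not _StalenessController's actual code:

```python
def rows_to_admit(
    sample_max_concurrency: int,  # cfg.sample_max_concurrency
    max_offpolicy: int,           # staleness budget, in weight versions
    version: int,                 # current weight version
    batch_size: int,
    inflight: int,                # rollout rows currently in flight
) -> int:
    # New rollout rows are admitted up to both the concurrency cap and the
    # staleness budget, combined via min() -- which is why a third
    # per-HTTP gate bought nothing.
    capacity = sample_max_concurrency - inflight
    staleness = (max_offpolicy + version + 1) * batch_size - inflight
    return max(0, min(capacity, staleness))


# At version 3 with a 1-version budget and batch size 8, 20 rows already
# in flight leave room for (1 + 3 + 1) * 8 - 20 = 20 more.
print(rows_to_admit(64, max_offpolicy=1, version=3, batch_size=8, inflight=20))
```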
…nc interval to 1

Brings the async RL recipe to AReaL parity for the inner-loop knob and
surfaces two new wandb metrics that were missing or dead.

- Add ``Config.ppo_n_minibatches`` (default 1) mirroring ``rl_loop.py``.
  The per-step path now snapshots ``old_policy_logprobs`` once and runs K
  ``forward_backward + optim_step`` minibatches against that snapshot, so
  the PPO ratio measures genuine inner-loop drift. DCP cadence stays in
  rollout-batch units, so ``dcp_save_interval`` is unchanged.
- Pin ``weight_sync_interval`` to a module-level constant
  ``_WEIGHT_SYNC_INTERVAL = 1``. Raising it trades rollout staleness for
  sync wall-time, which is almost never worth it in fully-async RL.
  ``WeightSyncConfig.weight_sync_interval`` is still honored by the sync
  recipes; the async recipe ignores it.
- Emit ``train/ppo_kl`` from ``run_loss_loop`` (mean of
  ``exp(diff) - diff - 1`` between current policy logprobs and
  ``old_policy_logprobs``; sketched below). Averaged across minibatches by
  the existing ``compute_step_metrics`` reducer.
- Populate ``perf/sample_wait_time`` and fix ``perf/wait_time_ratio`` for
  the async path: ``async_train.py`` tracks ``last_step_end`` so the
  trainer's per-step wait gap is recorded, and the recipe sets
  ``step_wall_time = wait + train`` before ``compute_step_metrics`` fires,
  so the ratio is wait/(wait+train), the overall step ratio.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
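The emitted ppo_kl estimator, sketched in plain PyTorch; the diff direction follows the wording above (current minus snapshot), and the tensors stand in for per-token logprobs:

```python
import torch


def ppo_kl(logprobs: torch.Tensor, old_policy_logprobs: torch.Tensor) -> torch.Tensor:
    # Mean of exp(diff) - diff - 1: non-negative, and exactly 0 when the
    # current policy matches the step-start snapshot.
    diff = logprobs - old_policy_logprobs
    return (torch.exp(diff) - diff - 1).mean()


old = torch.log(torch.tensor([0.50, 0.25, 0.25]))
new = torch.log(torch.tensor([0.60, 0.20, 0.20]))
print(ppo_kl(new, old))  # small positive value; 0.0 if new == old
```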
…versions, dynamic filter on multi_turn_message_in

Adds three CLI flags to the GSM8K multi-turn example so the test plan in
PR #382 is fully reproducible without source edits.

- ``--ppo-n-minibatches`` (default 1): inner PPO minibatch count.
- ``--max-head-offpolicy-versions`` (default 0): staleness budget.
- ``--filter-constant-reward``: drops prompt groups whose rewards are
  identical across all samples (the GRPO advantage is 0 there, so the
  optimizer step is a no-op). Standard async-RL hygiene; matches AReaL's
  ``adv filter`` in spirit. A sketch of the predicate follows below.

Also drops the ``WeightSyncConfig(weight_sync_interval=1)`` override,
since the async recipe pins the interval to 1 internally now.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
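A sketch of the --filter-constant-reward predicate described above:

```python
def keep_group(rewards: list[float]) -> bool:
    # Drop a prompt group whose rewards are identical across all samples:
    # the group-relative (GRPO) advantage is zero there, so the optimizer
    # step would be a no-op.
    return len(set(rewards)) > 1


assert keep_group([0.0, 1.0, 1.0, 0.0])       # mixed rewards: trainable
assert not keep_group([1.0, 1.0, 1.0, 1.0])   # all correct: dropped
assert not keep_group([0.0, 0.0, 0.0, 0.0])   # all wrong: dropped
```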
What this PR adds
An async RL training recipe (recipes/async_rl_loop.py) that decouples
rollout sampling from optimizer steps, plus the supporting rollout package
and a real GSM8K multi-turn example modeled after AReaL's
examples/multi_turn_math/.

User-facing rollout contract (per-sample, matches AReaL/slime)

    async def rollout_fn(row) -> RolloutSample | None: ...
The framework fans each row out to completions_per_prompt parallel calls
and joins them by row id via GroupAssembler before handing the assembled
PromptGroup to the trainer. AReaL's arun_episode and slime's generate use
the same shape.
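For concreteness, here is a minimal self-contained sketch of a rollout function under this contract; the RolloutSample fields and the stub sampler are illustrative stand-ins, not the package's exact schema:

```python
import asyncio
import random
from dataclasses import dataclass


# Illustrative stand-in for utils/rl/rollout's RolloutSample; the real
# type's fields may differ.
@dataclass
class RolloutSample:
    prompt: str
    completion: str
    reward: float


async def stub_generate(prompt: str) -> str:
    """Stands in for a DeploymentSampler call to the inference deployment."""
    await asyncio.sleep(0)
    return random.choice(["\\boxed{42}", "\\boxed{7}"])


async def rollout_fn(row: dict) -> RolloutSample | None:
    """One trajectory per call; returning None drops the sample."""
    completion = await stub_generate(row["question"])
    reward = 1.0 if row["answer"] in completion else 0.0
    return RolloutSample(row["question"], completion, reward)


async def main() -> None:
    row = {"question": "6 * 7 = ?", "answer": "42"}
    # Mimic the framework's fan-out with completions_per_prompt = 4;
    # GroupAssembler would then join these by row id into a PromptGroup.
    samples = await asyncio.gather(*(rollout_fn(dict(row)) for _ in range(4)))
    print([s.reward for s in samples if s is not None])


asyncio.run(main())
```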
Major components

- recipes/async_rl_loop.py — async loop driving rollout fan-out, group
  assembly, off-policy gating (row-granular staleness budget), the train
  step, and the inner PPO loop. Uses main's LossArgs Protocol
  (build_loss_fn(args), validate_loss_path(args) from #412, refactor(losses):
  collapse build_loss_fn 9-kwarg call site to LossArgs Protocol) and is
  fixed to loss_path="client" — this recipe runs the loss closure in
  Python via forward_backward_custom, no server-side builtin path.
- utils/rl/async_train.py — low-level async primitives (RowRequest,
  run_async_rl_loop, _StalenessController).
- utils/rl/rollout/ — rollout package: RolloutSample, Rollout,
  rollout_to_prompt_group, GroupAssembler, MessageTrajectoryAssembler
  (TITO-style multi-turn), TrajectoryAssembler (token-native), service
  Protocol + remote adapter, native trajectory analysis.
- utils/dataloader.py — cursor-based dataloader with epoch shuffle.
- utils/data.py::replicate_rows_for_epochs — deepcopy-per-epoch helper so
  rollout fns mutating their input row don't poison later epochs (sketch
  below).
- utils/timer.py::elapsed_timer — span-based timer for nested timing.
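The deepcopy-per-epoch idea behind replicate_rows_for_epochs, as a minimal sketch (the real helper's signature may differ):

```python
import copy


def replicate_rows_for_epochs(rows: list[dict], epochs: int) -> list[dict]:
    # Independent deep copy of every row per epoch, so a rollout fn that
    # mutates its input (e.g. appends chat messages) can't poison the
    # same row in a later epoch.
    return [copy.deepcopy(row) for _ in range(epochs) for row in rows]


rows = [{"messages": [{"role": "user", "content": "2 + 2 = ?"}]}]
replicated = replicate_rows_for_epochs(rows, epochs=2)
replicated[0]["messages"].append({"role": "assistant", "content": "4"})
assert len(replicated[1]["messages"]) == 1  # epoch 2's copy is untouched
```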
AReaL-parity knobs

- Config.ppo_n_minibatches (default 1, mirrors rl_loop.py) — each rollout
  batch snapshots old_policy_logprobs once and runs K
  forward_backward + optim_step minibatches against that snapshot. K=1
  reproduces legacy 1:1 behavior; K>1 makes the PPO ratio measure genuine
  inner-loop drift and the clip do real work (toy sketch below). Collapses
  AReaL's gradient_accumulation_steps and ppo_n_minibatches into one knob.
- Config.max_head_offpolicy_versions ≡ AReaL's max_head_offpolicyness
  (staleness budget, row-granular).
- weight_sync_interval is pinned to 1 inside the recipe (constant
  _WEIGHT_SYNC_INTERVAL). It remains configurable in WeightSyncConfig for
  the sync recipes; the async recipe ignores it, because raising it just
  trades rollout staleness for sync wall-time, which is almost never
  worth it.
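A self-contained toy in plain PyTorch showing why K > 1 gives the clip real work: every minibatch computes its ratio against the single snapshot taken at step start, so after the first optimizer step the ratio drifts from 1. Names and shapes are illustrative, not the recipe's code:

```python
import torch

logits = torch.randn(8, 5, requires_grad=True)   # 8 tokens, vocab of 5
actions = torch.randint(0, 5, (8,))
advantages = torch.randn(8)
opt = torch.optim.SGD([logits], lr=0.1)


def current_logprobs() -> torch.Tensor:
    return torch.log_softmax(logits, dim=-1)[torch.arange(8), actions]


old_policy_logprobs = current_logprobs().detach()   # snapshot once per batch
for _ in range(2):                                  # ppo_n_minibatches = 2
    ratio = torch.exp(current_logprobs() - old_policy_logprobs)
    clipped = torch.clamp(ratio, 0.8, 1.2)
    loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Minibatch 1 sees ratio == 1 everywhere; minibatch 2 sees real drift,
    # so the clip can actually bind.
```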
New WandB metrics

- train/ppo_kl — intra-step KL between the current policy and the
  step-start old_policy_logprobs snapshot, averaged across the K inner
  minibatches. Slime/AReaL/OpenRLHF naming. ~0 when ppo_n_minibatches=1;
  >0 when the inner loop drifts.
- perf/sample_wait_time — per-step seconds the trainer waited for the
  next batch to fill (was hardcoded 0.0).
- perf/wait_time_ratio — wait / (wait + train), the overall step ratio.
  With ppo_n_minibatches > 1 the denominator covers all K minibatches
  because they run sequentially before train_step returns (worked numbers
  below).
- perf/overlap_ratio — 1 - wait_time_ratio.
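Worked numbers for the timing metrics as defined above:

```python
# Per-step timing composition on the async path (toy values, in seconds).
sample_wait_time = 1.5            # trainer waited for the batch to fill
train_time = 4.5                  # all K minibatches, run sequentially
step_wall_time = sample_wait_time + train_time   # set before metrics fire

wait_time_ratio = sample_wait_time / step_wall_time   # 0.25
overlap_ratio = 1.0 - wait_time_ratio                 # 0.75
print(wait_time_ratio, overlap_ratio)
```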
Examples

- examples/rl/multi_turn_message_in/ — GSM8K multi-turn agent ported from
  AReaL. Up to N turns; on a wrong boxed answer, append AReaL's verbatim
  retry-prompt and let the model try again. Reward via math_verify with a
  numeric fallback (sketch below). prepare_data.py downloads openai/gsm8k
  from HuggingFace.
- examples/rl/single_turn_token_in/ — token-in single-turn baseline.
- examples/rl/vanilla_sampler.py — shared deployment sampler helper.
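A minimal sketch of the reward shape, assuming the math_verify package's parse/verify API; the example's actual extraction and numeric fallback may differ in detail:

```python
import re

from math_verify import parse, verify


def gsm8k_reward(completion: str, gold_answer: str) -> float:
    """1.0 if the boxed answer matches the gold answer, else 0.0.
    Symbolic verification first, then a plain numeric fallback."""
    try:
        if verify(parse(f"${gold_answer}$"), parse(completion)):
            return 1.0
    except Exception:
        pass
    # Numeric fallback: compare the last number in the completion.
    nums = re.findall(r"-?\d+\.?\d*", completion.replace(",", ""))
    return 1.0 if nums and nums[-1] == gold_answer else 0.0


print(gsm8k_reward("... so she earns \\boxed{18} dollars.", "18"))  # 1.0
```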
Loss subsystem changes

- PromptGroup gains a per-sample prompt_lens: List[int] | None field for
  heterogeneous rollouts (multi-turn, tool branches) where each sample
  has a different prefix length. combine_prompt_groups prefers it when
  set.
- _get_loss_mask reads loss_fn_inputs["weights"] first (the new SDK
  contract) with a legacy fallback to "loss_mask" (sketch below).
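The mask read order, as a sketch (the real helper operates on the SDK's typed loss inputs rather than a bare dict):

```python
def _get_loss_mask(loss_fn_inputs: dict):
    # Prefer the new SDK key "weights"; fall back to the legacy "loss_mask".
    mask = loss_fn_inputs.get("weights")
    return mask if mask is not None else loss_fn_inputs.get("loss_mask")


assert _get_loss_mask({"weights": [1, 1, 0]}) == [1, 1, 0]
assert _get_loss_mask({"loss_mask": [0, 1]}) == [0, 1]  # legacy fallback
```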
SDK pin

fireworks-ai[training]>=1.2.0a65,<2 — includes the chunk-transient
backoff/retry from a64+.
Tests

Unit tests

~600 unit tests pass on this branch. New coverage for this PR:

- tests/unit/test_async_rl_train.py — full async loop behavior (staleness
  controller, batch assembly, dynamic filter, weight sync cadence,
  ppo_n_minibatches path).
- tests/unit/test_group_assembler.py — row fan-in semantics.
- tests/unit/test_gsm8k_reward.py — boxed-answer extraction + math_verify.
- tests/unit/test_rollout_{types,assembler,message,helpers,trace}.py —
  rollout package.
- tests/unit/test_cursor_dataloader.py, tests/unit/test_data_utils.py.
Integration test plan (GSM8K multi-turn)

End-to-end verification on the GSM8K multi-turn example in two phases,
run sequentially. Phase 2 only proceeds after Phase 1 finishes cleanly.
Both phases share the same cohort and group shape:

- 200 rows (--max-rows 200)
- 8 completions per prompt (--completions-per-prompt 8)
- 8 prompt groups per step (--prompt-groups-per-step 8)
- --filter-constant-reward (drops zero-advantage groups so the optimizer
  doesn't no-op on flat-reward batches)

Setup (once):
Phase 1 — ppo_n_minibatches=1, 1-version off-policy

WandB run should show:

- train/ppo_kl ≈ 0 (no inner-loop drift when K=1)
- async/version_offset_max ≤ 1
- perf/sample_wait_time and perf/wait_time_ratio populated and finite
- rollout/reward trending upward across steps
- train/mean_kl finite, no NaN/inf

Phase 1 success criteria: the run finishes without error, reward improves,
and the wait-ratio metric is non-zero (proving the new instrumentation
works).
Phase 2 — ppo_n_minibatches=2, 2-version off-policy

Only after Phase 1 succeeds. WandB run should show:

- train/ppo_kl > 0 (the second minibatch sees real drift from the
  snapshot)
- train/ppo_clip_frac > 0 (PPO clipping fires)
- async/version_offset_max ≤ 2 (staleness budget honored)
- perf/wait_time_ratio ≤ Phase 1's (more concurrency headroom from the
  larger budget)
- rollout/reward trending upward

Phase 2 success criteria: the run finishes without error, ppo_kl > 0
(proving the inner loop is actually running), and the version offset is
bounded by the budget.