From 67fcde52554c3fc64d6ffbaada5693fbf0bafd40 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Thu, 30 Apr 2026 10:34:59 -0700 Subject: [PATCH] Refactor algo in one place Move all loss function, advantage estimator into algo_core. So both agentic rl and non-agentic rl share same algorithms. 1. Combine agentic grpo and grpo loss fn and advantage estimator. 2. Use np.array for group advantage for faster computation, details in cl/845477002 PiperOrigin-RevId: 908254319 --- examples/deepswe/train_deepswe_nb.py | 6 + tests/rl/agentic/agentic_grpo_learner_test.py | 9 +- tests/rl/function_registry_test.py | 2 +- tests/rl/grpo/dapo_learner_test.py | 6 +- tests/rl/grpo/drgrpo_learner_test.py | 3 +- tests/rl/grpo/grpo_learner_test.py | 5 +- tests/rl/ppo/ppo_helpers_test.py | 2 +- tunix/rl/agentic/agentic_grpo_learner.py | 218 +------ tunix/rl/agentic/agentic_rl_learner.py | 2 +- tunix/rl/algo_core.py | 564 ++++++++++++++++++ tunix/rl/function_registry.py | 14 + tunix/rl/grpo/dapo_learner.py | 3 +- tunix/rl/grpo/drgrpo_learner.py | 19 +- tunix/rl/grpo/grpo_learner.py | 146 +---- tunix/rl/ppo/ppo_learner.py | 164 +---- 15 files changed, 619 insertions(+), 544 deletions(-) create mode 100644 tunix/rl/algo_core.py diff --git a/examples/deepswe/train_deepswe_nb.py b/examples/deepswe/train_deepswe_nb.py index e847910fd..1b02bf23a 100644 --- a/examples/deepswe/train_deepswe_nb.py +++ b/examples/deepswe/train_deepswe_nb.py @@ -78,6 +78,12 @@ parser.add_argument("--epsilon", type=float, default=0.2) parser.add_argument("--epsilon_high", type=float, default=0.28) parser.add_argument("--off_policy_steps", type=int, default=0) +parser.add_argument( + "--advantage_estimator", + type=str, + default="rloo", + choices=["grpo", "rloo"], +) # Rollout Config parser.add_argument("--max_prompt_length", type=int, default=4096) diff --git a/tests/rl/agentic/agentic_grpo_learner_test.py b/tests/rl/agentic/agentic_grpo_learner_test.py index 6dcced541..cb577c21c 100644 --- a/tests/rl/agentic/agentic_grpo_learner_test.py +++ b/tests/rl/agentic/agentic_grpo_learner_test.py @@ -39,6 +39,7 @@ import optax import orbax.checkpoint as ocp from tunix.generate import tokenizer_adapter +from tunix.rl import algo_core from tunix.rl import common as rl_common from tunix.rl import function_registry from tunix.rl import rl_cluster as rl_cluster_lib @@ -1906,17 +1907,13 @@ def _patch_process_results( def test_compute_rloo_advantages(self): rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - advantages = agentic_grpo_learner.compute_rloo_advantages( - rewards, num_generations=3 - ) + advantages = algo_core.compute_rloo_advantages(rewards, num_generations=3) expected_value = jnp.array([-1.5, 0.0, 1.5, -1.5, 0.0, 1.5]) np.testing.assert_allclose(advantages, expected_value) def test_compute_rloo_advantages_low_generations(self): rewards = jnp.array([1.0, 2.0]) - advantages = agentic_grpo_learner.compute_rloo_advantages( - rewards, num_generations=1 - ) + advantages = algo_core.compute_rloo_advantages(rewards, num_generations=1) np.testing.assert_allclose(advantages, jnp.zeros_like(rewards)) diff --git a/tests/rl/function_registry_test.py b/tests/rl/function_registry_test.py index fac8af764..51bc3dc8a 100644 --- a/tests/rl/function_registry_test.py +++ b/tests/rl/function_registry_test.py @@ -52,7 +52,7 @@ def test_custom_categories_instance(self): def test_empty_categories_instance(self): # Test-specific instance for empty categories registry = function_registry.FunctionRegistry(allowed_categories=[]) - 
self.assertLen(registry.list_categories(), 3) + self.assertLen(registry.list_categories(), 4) @parameterized.named_parameters( dict( diff --git a/tests/rl/grpo/dapo_learner_test.py b/tests/rl/grpo/dapo_learner_test.py index 71667f690..cd6847e4d 100644 --- a/tests/rl/grpo/dapo_learner_test.py +++ b/tests/rl/grpo/dapo_learner_test.py @@ -89,9 +89,11 @@ def test_diff_loss(self): rngs=nnx.Rngs(0), ) - # Call DAPO loss function + # Call DAPO loss function (DAPO sets ref_per_token_logps to None as it doesn't fetch it) + dapo_train_example = self.create_train_example() + dapo_train_example.ref_per_token_logps = None dapo_loss, dapo_aux = dapo_loss_fn_impl( - model, train_example, dapo_config, pad_id, eos_id + model, dapo_train_example, dapo_config, pad_id, eos_id ) # Call GRPO loss function diff --git a/tests/rl/grpo/drgrpo_learner_test.py b/tests/rl/grpo/drgrpo_learner_test.py index 8146f216a..38565b101 100644 --- a/tests/rl/grpo/drgrpo_learner_test.py +++ b/tests/rl/grpo/drgrpo_learner_test.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp import numpy as np +from tunix.rl import algo_core # pylint: disable=unused-import from tunix.rl import function_registry as fr from tunix.rl.grpo import drgrpo_learner as drgrpo_lib from tunix.rl.grpo import grpo_learner as grpo_lib @@ -135,7 +136,7 @@ def test_compute_advantages(self): rewards = jnp.array( [[0.57450044, 0.09968603, 0.7419659, 0.8941783, 0.59656656, 0.45325184]] ) - advantages = drgrpo_lib.compute_advantages(rewards, num_generations=3) + advantages = algo_core.compute_drgrpo_advantages(rewards, num_generations=3) expected_array = jnp.array([ [0.10245, -0.372365, 0.269915, 0.246179, -0.051432, -0.194747], ]) diff --git a/tests/rl/grpo/grpo_learner_test.py b/tests/rl/grpo/grpo_learner_test.py index 2680a160e..811856d2e 100644 --- a/tests/rl/grpo/grpo_learner_test.py +++ b/tests/rl/grpo/grpo_learner_test.py @@ -33,6 +33,7 @@ import orbax.checkpoint as ocp from tunix.perf import trace as trace_lib from tunix.perf.experimental import tracer as perf_tracer_v2 +from tunix.rl import algo_core as grpo_core from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.grpo import grpo_learner as grpo_lib from tunix.rl.queue import data_queue as queue_lib @@ -1234,9 +1235,9 @@ def test_compute_advantages(self): rng = jax.random.PRNGKey(0) rewards = jax.random.uniform(rng, shape=(1, 6)) - advantages = grpo_lib.compute_advantages(rewards, num_generations=3) + advantages = grpo_core.compute_advantages(rewards, num_generations=3) expected_value = jnp.array( - [[0.307407, -1.117304, 0.809897, 1.094044, -0.22857, -0.865474]] + [[0.307498, -1.117636, 0.810138, 1.094526, -0.228671, -0.865855]] ) np.testing.assert_allclose(advantages, expected_value, rtol=1e-5, atol=1e-5) diff --git a/tests/rl/ppo/ppo_helpers_test.py b/tests/rl/ppo/ppo_helpers_test.py index 87e51b4f0..29ebb6ea1 100644 --- a/tests/rl/ppo/ppo_helpers_test.py +++ b/tests/rl/ppo/ppo_helpers_test.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp import numpy as np -from tunix.rl.ppo import ppo_helpers +from tunix.rl import algo_core as ppo_helpers def _ref_compute_gae_advantages( diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py index 1e635d9b9..9d49aa5c6 100644 --- a/tunix/rl/agentic/agentic_grpo_learner.py +++ b/tunix/rl/agentic/agentic_grpo_learner.py @@ -37,6 +37,7 @@ import jax import jax.numpy as jnp import numpy as np +from tunix.rl import algo_core # pylint: disable=unused-import from tunix.perf.experimental import constants as 
perf_constants from tunix.rl import common from tunix.rl import function_registry @@ -88,8 +89,8 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig): """ algo_variant: str = "agentic_grpo" - advantage_estimator: str = "agentic_grpo" - policy_loss_fn: str = "agentic_grpo" + advantage_estimator: str = "grpo" + policy_loss_fn: str = "grpo" loss_agg_mode: str = "sequence-mean-token-mean" loss_algo: ( str @@ -559,218 +560,5 @@ def _process_results( return [combined_batch] -@function_registry.register_policy_loss_fn("agentic_grpo") -def grpo_loss_fn( - model, - train_example, - algo_config, - pad_id, - eos_id, -): - """GRPO loss function. - - The loss aims to maximize the expected advantage of the chosen actions while - constraining the policy updates to stay within a certain range of the - reference policy. - - Args: - model: The policy model to be trained. - train_example: A `TrainExample` instance containing the processed input - data, including prompt IDs, completion IDs, masks, advantages, and - per-token log probabilities from the reference and policy models. - algo_config: The algorithm config. - pad_id: The pad ID from tokenizer. - eos_id: The eos ID from. - - Returns: - A tuple containing the loss and an aux dictionary. - """ - beta = algo_config.beta - epsilon = algo_config.epsilon - loss_algo = algo_config.loss_algo - epsilon_high = ( - algo_config.epsilon_high - if hasattr(algo_config, "epsilon_high") - else epsilon - ) - epsilon_c = ( - algo_config.epsilon_c - if hasattr(algo_config, "epsilon_c") - else 3.0 - ) - loss_aggregation_mode = algo_config.loss_agg_mode - - completion_ids, completion_mask = ( - train_example.completion_ids, - train_example.completion_mask, - ) - - # TODO(tsbao): split can be avoided with updated peft_trainer model handling. - graphdef, state = nnx.split(model) - per_token_logps, logits = common.compute_per_token_logps( - graphdef, - state, - prompt_tokens=train_example.prompt_ids, - completion_tokens=completion_ids, - pad_id=pad_id, - eos_id=eos_id, - completion_mask=completion_mask, - stop_gradient=False, - return_logits=True, - segment_ids=getattr(train_example, "segment_ids", None), - segment_positions=getattr(train_example, "segment_positions", None), - temperature=algo_config.temperature, - ) - per_token_logps = jnp.astype(per_token_logps, jnp.float32) - # TODO(tsbao): We should handle token level advantages. - advantages = jnp.astype(train_example.advantages, jnp.float32) - - if train_example.old_per_token_logps is None: - old_per_token_logps = jax.lax.stop_gradient(per_token_logps) - else: - old_per_token_logps = jnp.astype( - train_example.old_per_token_logps, jnp.float32 - ) - - seq_importance_ratio = per_token_logps - old_per_token_logps - # Record KL divergence before clipping. - ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask) - - seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0) - - # TODO(sizhi): Refactor this to a separate function. - if loss_algo == "gspo-token": - seq_importance_ratio = (seq_importance_ratio * completion_mask).sum( - axis=-1 - ) / jnp.clip(completion_mask.sum(-1), min=1) - seq_importance_ratio = ( - per_token_logps - - jax.lax.stop_gradient(per_token_logps) - + jnp.expand_dims(jax.lax.stop_gradient(seq_importance_ratio), axis=-1) - ) - seq_importance_ratio = jnp.clip(seq_importance_ratio, max=10.0) - - is_ratio = jnp.exp(seq_importance_ratio) - - # Advantages must be broadcast against seq_length. 
- # When sequence packing is used, advantages are already 2D [B, seq_length]. - # When unpacked, they are 1D [B]. - adv = advantages if advantages.ndim == 2 else jnp.expand_dims(advantages, 1) - - pg_loss_1 = -adv * is_ratio - pg_loss_2 = -adv * jnp.clip(is_ratio, 1 - epsilon, 1 + epsilon_high) - - per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32) - - clipped_fraction = ppo_helpers.masked_mean( - jnp.greater(pg_loss_2, pg_loss_1), completion_mask - ) - - # dual-clip ppo loss - pg_loss_3 = -epsilon_c * adv - - # pg_clipfrac_lower measures how often dual-clip ppo kicks in. - # It kicks in when the standard clipped loss is larger than pg_loss_3 - # for instances with negative advantages. - unreduced_pg_clipfrac_lower = ( - (per_token_loss > pg_loss_3) & (adv < 0.0) - ).astype(jnp.float32) - pg_clipfrac_lower = common.aggregate_loss( - unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode - ) - - pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss) - per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss) - loss = common.aggregate_loss( - per_token_loss, completion_mask, loss_aggregation_mode - ) - aux = { - "kl": 0.0, - "kl_loss": 0.0, - "pg_loss": loss, - "pg_clipfrac": clipped_fraction, - "ppo_kl": ppo_kl, - "pg_clipfrac_lower": pg_clipfrac_lower, - } - # We do not alwayscompute KL divergence (e.g. when beta is 0.0 unless - # force_compute_kl is True). - if train_example.ref_per_token_logps is not None: - kl = common.compute_kl_divergence( - per_token_logps, - train_example.ref_per_token_logps, - algo_config.kl_loss_mode, - ) - # Log mean KL. - aux["kl"] = jnp.astype( - (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1), - jnp.float32, - ) - kl_loss = common.aggregate_loss( - kl, completion_mask, loss_aggregation_mode - ) - aux["kl_loss"] = kl_loss - if beta is not None and beta != 0.0: - loss = loss + beta * kl_loss - - token_entropy = ppo_helpers.compute_entropy_from_logits(logits) - entropy_loss = common.aggregate_loss( - token_entropy, completion_mask, loss_aggregation_mode - ) - aux["entropy"] = entropy_loss - - return loss, aux - - -@function_registry.register_advantage_estimator("agentic_grpo") -def compute_advantages(rewards: jax.Array, num_generations: int) -> jax.Array: - """Compute group relative advantages. - - Args: - rewards: reward functions output. - num_generations: Number of generations. - - Returns: - Group relative advantages. - """ - rewards = jnp.astype(rewards, jnp.float32) - mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=-1) - std_grouped_rewards = rewards.reshape(-1, num_generations).std( - axis=-1, ddof=1 - ) - - mean_grouped_rewards = mean_grouped_rewards.repeat(num_generations) - std_grouped_rewards = std_grouped_rewards.repeat(num_generations) - return (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-6) - - -@function_registry.register_advantage_estimator("agentic_rloo") -def compute_rloo_advantages( - rewards: jax.Array, num_generations: int -) -> jax.Array: - """Compute RLOO (REINFORCE Leave-One-Out) advantages. - - RLOO computes a baseline for each completion by averaging the rewards of all - other completions to the same prompt. - - Args: - rewards: reward functions output. - num_generations: Number of generations. - - Returns: - RLOO advantages. - """ - if num_generations < 2: - # RLOO requires at least 2 samples to calculate a baseline. 
- return jnp.zeros_like(rewards) - - reshaped_rewards = rewards.reshape(-1, num_generations) - loo_mean = ( - reshaped_rewards.sum(axis=-1, keepdims=True) - reshaped_rewards - ) / (num_generations - 1) - rloo_advantages = reshaped_rewards - loo_mean - - return rloo_advantages.flatten() - - GrpoConfig = GRPOConfig GrpoLearner = GRPOLearner diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py index 54c8ebfa3..7ebd5074e 100644 --- a/tunix/rl/agentic/agentic_rl_learner.py +++ b/tunix/rl/agentic/agentic_rl_learner.py @@ -47,7 +47,7 @@ from tunix.rl.agentic.environments import base_environment from tunix.rl.agentic.environments import task_environment from tunix.rl.agentic.pipeline import rollout_orchestrator -from tunix.rl.agentic.rewards import reward +from tunix.rl.agentic.rewards import reward # pylint: disable=unused-import from tunix.rl.agentic.trajectory import trajectory_collect_engine from tunix.rl.queue import data_queue as queue_lib from tunix.sft import utils as sft_utils diff --git a/tunix/rl/algo_core.py b/tunix/rl/algo_core.py new file mode 100644 index 000000000..0b2a6f419 --- /dev/null +++ b/tunix/rl/algo_core.py @@ -0,0 +1,564 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Algorithm core implementations for RL and Agentic RL learners.""" + +import functools +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +from tunix.rl import common +from tunix.rl import function_registry + + +registry = function_registry.default_registry + +# ============================================================================== +# Utils +# ============================================================================== + +@registry.register("advantage_estimator", "gae") +@jax.jit +def compute_gae_advantages( + rewards: jax.Array, + values: jax.Array, + completion_mask: jax.Array, + gamma: float, + gae_lambda: float, +) -> tuple[jax.Array, jax.Array]: + """Compute advantages using Generalized Advantage Estimation (GAE). + + Computing GAE is a two-step process: + + First, compute the temporal difference (TF), `δ_t`, for each timestep `t`: + + ``` + δ_t = r_t + γ * V(s_{t+1}) - V(s_t) + ``` + + Then, compute the GAE advantage, `A_t`, by summing the discounted TD + residuals. It is calculated recursively, starting from the last timestep: + + ``` + A_t = δ_t + (γ * λ) * A_{t+1} + ``` + + where: + + - `A_t` is the GAE advantage at timestep `t`. + - `δ_t` is the temporal difference at timestep `t`. + - `γ` is the discount factor. + - `λ` is the GAE lambda parameter. + - `V(s_t)` is the value function at timestep `t`. + - `r_t` is the reward at timestep `t`. + + Args: + rewards: A 2D array of rewards for each step in the rollout. + values: A 2D array of value estimates from the critic for each step. + completion_mask: A 2D mask, which is 0 for padding tokens. + gamma: The discount factor, `γ`. + gae_lambda: The GAE lambda parameter, `λ`. 
+ + Returns: + A tuple of two 2D arrays - advantages and returns for each step. + """ + batch_size = values.shape[0] + + def gae_step(state_t_plus_1, xs): + # Unpack state and inputs. + gae_t_plus_1, next_values = state_t_plus_1 + rewards_t, values_t, mask_t = xs + + # Compute Temporal Difference (TD). + delta = rewards_t + gamma * next_values - values_t + # Compute GAE for this time step. + gae_t = delta + gamma * gae_lambda * gae_t_plus_1 + + # Skip values on non-completion tokens. + next_values = values_t * mask_t + (1 - mask_t) * next_values + gae_t = gae_t * mask_t + (1 - mask_t) * gae_t_plus_1 + + # New state to carry over comprises `gae_t` and `next_values`. Output for + # this step is `gae_t`. + return (gae_t, next_values), gae_t + + _, advantages_transposed = jax.lax.scan( + gae_step, + init=(jnp.zeros((batch_size,)), jnp.zeros((batch_size,))), + xs=( + jnp.transpose(jnp.array(rewards)), + jnp.transpose(jnp.array(values)), + jnp.transpose(jnp.array(completion_mask)), + ), + reverse=True, + ) + advantages = jnp.transpose(advantages_transposed) + returns = advantages + values + + # Normalise advantages. + advantages = masked_whiten(advantages, completion_mask) + return advantages, returns + + +@jax.jit +def masked_whiten( + x: jax.Array, + completion_mask: jax.Array, +) -> jax.Array: + """Normalize the input array.""" + x_mean = masked_mean(x, completion_mask) + x_var = masked_var( + x, + completion_mask, + x_mean, + ) + x = (x - x_mean) * jax.lax.rsqrt(x_var + 1e-8) + return x + + +@functools.partial(jax.jit, static_argnames=('axis',)) +def masked_mean( + x: jax.Array, mask: jax.Array, axis: int | None = None +) -> jax.Array: + """Compute the mean of a masked array.""" + cast_mask = mask.astype(x.dtype) + return jnp.sum(x * cast_mask, axis=axis) / ( + jnp.sum(cast_mask, axis=axis) + 1e-8 + ) + + +@jax.jit +def masked_var( + x: jax.Array, + mask: jax.Array, + mean: jax.Array | None = None, +) -> jax.Array: + """Compute the variance of a masked array.""" + cast_mask = mask.astype(x.dtype) + if mean is None: + mean = masked_mean(x, cast_mask) + + variance = masked_mean(jnp.square(x - mean), cast_mask) + + mask_sum = cast_mask.sum() + bessel_corr = mask_sum / (mask_sum - 1) + return variance * bessel_corr + + +def compute_entropy_from_logits(logits: jax.Array) -> jax.Array: + """Computes the entropy of a distribution given its logits. + + Args: + logits: Logits as returned by the model. Of shape `[batch_size, seq_len, + emb_dim]`. + + Returns: + A JAX array of shape `[batch_size, seq_len]`, containing the entropy values. 
+ """ + log_probs = jax.nn.log_softmax(logits, axis=-1) + probs = jax.nn.softmax(log_probs) + return -jnp.sum(probs * log_probs, axis=-1) + + +# ============================================================================== +# PPO Core +# ============================================================================== + + +@function_registry.register_policy_loss_fn("ppo") +def ppo_policy_loss_fn( + model, + train_example, + algo_config, + pad_id, + eos_id, + **kwargs, +): + """PPO policy loss function.""" + epsilon_low = algo_config.epsilon_low + epsilon_high = algo_config.epsilon_high + entropy_coef = algo_config.entropy_coef + + completion_ids = train_example.completion_ids + completion_mask = train_example.completion_mask + + graphdef, state = nnx.split(model) + per_token_logps, logits = common.compute_per_token_logps( + graphdef, + state, + prompt_tokens=train_example.prompt_ids, + completion_tokens=completion_ids, + pad_id=pad_id, + eos_id=eos_id, + completion_mask=completion_mask, + stop_gradient=False, + return_logits=True, + ) + + advantages = train_example.advantages + old_per_token_logps = train_example.old_per_token_logps + + seq_importance_ratio = jnp.exp(per_token_logps - old_per_token_logps) + + # Compute pg_clipfrac + pg_losses_1 = -seq_importance_ratio * advantages + pg_losses_2 = ( + -jnp.clip(seq_importance_ratio, 1 - epsilon_low, 1 + epsilon_high) + * advantages + ) + + # add dual clip logic + epsilon_c = getattr(algo_config, "epsilon_c", 3.0) + if epsilon_c is None: + epsilon_c = 3.0 + pg_loss_3 = -epsilon_c * advantages + + per_token_loss = jnp.maximum(pg_losses_1, pg_losses_2) + unreduced_pg_clipfrac_lower = ( + (per_token_loss > pg_loss_3) & (advantages < 0.0) + ).astype(jnp.float32) + pg_clipfrac_lower = masked_mean(unreduced_pg_clipfrac_lower, completion_mask) + + pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss) + pg_losses = jnp.where(advantages < 0.0, pg_loss_clipped_dual, per_token_loss) + + aux = { + "pg_clipfrac": masked_mean( + jnp.greater(pg_losses_2, pg_losses_1), completion_mask + ), + "pg_clipfrac_lower": pg_clipfrac_lower, + } + + policy_loss = masked_mean(pg_losses, completion_mask) + loss = policy_loss + + if entropy_coef is not None and entropy_coef != 0.0: + token_entropy = compute_entropy_from_logits(logits) + entropy_loss = masked_mean(token_entropy, completion_mask) + loss = loss - entropy_coef * entropy_loss + aux["loss/entropy"] = entropy_loss + + # kl penalty term logic as before + kl_coef = getattr(algo_config, "kl_coef", 0.0) + if kl_coef > 0.0 and train_example.ref_per_token_logps is not None: + kl = common.compute_kl_divergence( + per_token_logps, train_example.ref_per_token_logps, "kl" + ) + kl_loss = masked_mean(kl, completion_mask) + loss = loss + kl_coef * kl_loss + aux["kl"] = kl_loss + + return loss, aux + + +@function_registry.register_value_loss_fn("ppo") +def ppo_value_loss_fn( + model: nnx.Module, + train_example, + clip_range_value: float | None, + pad_id: int, + eos_id: int, +): + """Computes the value loss for PPO.""" + + prompt_ids, completion_ids, completion_mask = ( + train_example.prompt_ids, + train_example.completion_ids, + train_example.completion_mask, + ) + # ====== Loss ====== + values = train_example.old_values + returns = train_example.returns + + segment_ids = getattr(train_example, "segment_ids", None) + if segment_ids is not None: + # For packed sequences, prompt_ids is empty and completion_ids holds the full sequence. + # We predict values for token t using the model's output at t-1. 
+ logits_to_keep = completion_ids.shape[1] - 1 + else: + logits_to_keep = completion_ids.shape[1] + + # Get new values. + vpreds = common.compute_score( + model, + prompt_ids, + completion_ids, + pad_id, + eos_id, + stop_gradient=False, + segment_ids=segment_ids, + segment_positions=getattr(train_example, "segment_positions", None), + ) + vpreds = vpreds[:, -logits_to_keep - 1 : -1] + + if segment_ids is not None: + # Pad the first token's value with 0.0, since it has no preceding token to predict it. + vpreds = jnp.pad(vpreds, ((0, 0), (1, 0)), constant_values=0.0) + vpred_clipped = jnp.clip( + vpreds, values - clip_range_value, values + clip_range_value + ) + vf_losses1 = jnp.square(vpreds - returns) + vf_losses2 = jnp.square(vpred_clipped - returns) + + clipped_vf_losses = jnp.maximum(vf_losses1, vf_losses2) + # "token mean" style of normalisation. + vf_loss = 0.5 * masked_mean(clipped_vf_losses, completion_mask) + + aux = { + "vf_loss": vf_loss, + "vpred_mean": masked_mean(vpreds, completion_mask), + "vf_clipfrac": masked_mean( + jnp.greater(vf_losses2, vf_losses1), completion_mask + ), + "return_mean": masked_mean(returns, completion_mask), + } + + return vf_loss, aux + + +# ============================================================================== +# GRPO Core +# ============================================================================== + + +@function_registry.register_policy_loss_fn("grpo") +def grpo_loss_fn( + model, + train_example, + algo_config, + pad_id, + eos_id, + **kwargs, +): + """GRPO loss function. + + The loss aims to maximize the expected advantage of the chosen actions while + constraining the policy updates to stay within a certain range of the + reference policy. + + Args: + model: The policy model to be trained. + train_example: A `TrainExample` instance containing the processed input + data, including prompt IDs, completion IDs, masks, advantages, and + per-token log probabilities from the reference and policy models. + algo_config: The algorithm config. + pad_id: The pad ID from tokenizer. + eos_id: The eos ID from. + + Returns: + A tuple containing the loss and an aux dictionary. + """ + beta = algo_config.beta + epsilon = algo_config.epsilon + loss_algo = algo_config.loss_algo + epsilon_high = ( + algo_config.epsilon_high + if hasattr(algo_config, "epsilon_high") + else epsilon + ) + epsilon_c = ( + algo_config.epsilon_c if hasattr(algo_config, "epsilon_c") else 3.0 + ) + loss_aggregation_mode = algo_config.loss_agg_mode + + completion_ids, completion_mask = ( + train_example.completion_ids, + train_example.completion_mask, + ) + + # TODO(tsbao): split can be avoided with updated peft_trainer model handling. + graphdef, state = nnx.split(model) + per_token_logps, logits = common.compute_per_token_logps( + graphdef, + state, + prompt_tokens=train_example.prompt_ids, + completion_tokens=completion_ids, + pad_id=pad_id, + eos_id=eos_id, + completion_mask=completion_mask, + stop_gradient=False, + return_logits=True, + segment_ids=getattr(train_example, "segment_ids", None), + segment_positions=getattr(train_example, "segment_positions", None), + temperature=algo_config.temperature, + ) + per_token_logps = jnp.astype(per_token_logps, jnp.float32) + # TODO(tsbao): We should handle token level advantages. 
+ advantages = jnp.astype(train_example.advantages, jnp.float32) + + if train_example.old_per_token_logps is None: + old_per_token_logps = jax.lax.stop_gradient(per_token_logps) + else: + old_per_token_logps = jnp.astype( + train_example.old_per_token_logps, jnp.float32 + ) + + seq_importance_ratio = per_token_logps - old_per_token_logps + # Record KL divergence before clipping. + ppo_kl = masked_mean(-seq_importance_ratio, completion_mask) + + seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0) + + # TODO(sizhi): Refactor this to a separate function. + if loss_algo == "gspo-token": + seq_importance_ratio = (seq_importance_ratio * completion_mask).sum( + axis=-1 + ) / jnp.clip(completion_mask.sum(-1), min=1) + seq_importance_ratio = ( + per_token_logps + - jax.lax.stop_gradient(per_token_logps) + + jnp.expand_dims(jax.lax.stop_gradient(seq_importance_ratio), axis=-1) + ) + seq_importance_ratio = jnp.clip(seq_importance_ratio, max=10.0) + + is_ratio = jnp.exp(seq_importance_ratio) + + # Advantages must be broadcast against seq_length. + # When sequence packing is used, advantages are already 2D [B, seq_length]. + # When unpacked, they are 1D [B]. + adv = advantages if advantages.ndim == 2 else jnp.expand_dims(advantages, 1) + + pg_loss_1 = -adv * is_ratio + pg_loss_2 = -adv * jnp.clip(is_ratio, 1 - epsilon, 1 + epsilon_high) + + per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32) + + clipped_fraction = masked_mean( + jnp.greater(pg_loss_2, pg_loss_1), completion_mask + ) + + # dual-clip ppo loss + pg_loss_3 = -epsilon_c * adv + + # pg_clipfrac_lower measures how often dual-clip ppo kicks in. + # It kicks in when the standard clipped loss is larger than pg_loss_3 + # for instances with negative advantages. + unreduced_pg_clipfrac_lower = ( + (per_token_loss > pg_loss_3) & (adv < 0.0) + ).astype(jnp.float32) + pg_clipfrac_lower = common.aggregate_loss( + unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode + ) + + pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss) + per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss) + loss = common.aggregate_loss( + per_token_loss, completion_mask, loss_aggregation_mode + ) + aux = { + "kl": 0.0, + "kl_loss": 0.0, + "pg_loss": loss, + "pg_clipfrac": clipped_fraction, + "ppo_kl": ppo_kl, + "pg_clipfrac_lower": pg_clipfrac_lower, + } + # We do not always compute KL divergence (e.g. when beta is 0.0 unless + # force_compute_kl is True). + if train_example.ref_per_token_logps is not None: + kl = common.compute_kl_divergence( + per_token_logps, + train_example.ref_per_token_logps, + algo_config.kl_loss_mode, + ) + # Log mean KL. + aux["kl"] = jnp.astype( + (kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1), + jnp.float32, + ) + kl_loss = common.aggregate_loss(kl, completion_mask, loss_aggregation_mode) + aux["kl_loss"] = kl_loss + if beta is not None and beta != 0.0: + loss = loss + beta * kl_loss + + token_entropy = compute_entropy_from_logits(logits) + entropy_loss = common.aggregate_loss( + token_entropy, completion_mask, loss_aggregation_mode + ) + aux["entropy"] = entropy_loss + + return loss, aux + + +@function_registry.register_advantage_estimator("grpo") +def compute_advantages(rewards: np.ndarray, num_generations: int) -> np.ndarray: + """Compute group relative advantages. + + Args: + rewards: reward functions output. + num_generations: Number of generations. + + Returns: + Group relative advantages. 
+ """ + mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=-1) + std_grouped_rewards = rewards.reshape(-1, num_generations).std( + axis=-1, ddof=1 + ) + + mean_grouped_rewards = mean_grouped_rewards.repeat(num_generations) + std_grouped_rewards = std_grouped_rewards.repeat(num_generations) + return (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + +@function_registry.register_advantage_estimator("rloo") +def compute_rloo_advantages( + rewards: jax.Array, num_generations: int +) -> jax.Array: + """Compute RLOO (REINFORCE Leave-One-Out) advantages. + + RLOO computes a baseline for each completion by averaging the rewards of all + other completions to the same prompt. + + Args: + rewards: reward functions output. + num_generations: Number of generations. + + Returns: + RLOO advantages. + """ + if num_generations < 2: + # RLOO requires at least 2 samples to calculate a baseline. + return jnp.zeros_like(rewards) + + reshaped_rewards = rewards.reshape(-1, num_generations) + loo_mean = ( + reshaped_rewards.sum(axis=-1, keepdims=True) - reshaped_rewards + ) / (num_generations - 1) + rloo_advantages = reshaped_rewards - loo_mean + + return rloo_advantages.flatten() + + +# ============================================================================== +# DrGRPO Core +# ============================================================================== + + +@function_registry.register_advantage_estimator("drgrpo") +def compute_drgrpo_advantages( + rewards: jax.Array, num_generations: int +) -> jax.Array: + """Group relative advantages -- done right. + + Args: + rewards: reward functions output. + num_generations: Number of generations. + + Returns: + Group relative advantages. + """ + mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=1) + return rewards - mean_grouped_rewards.repeat(num_generations) diff --git a/tunix/rl/function_registry.py b/tunix/rl/function_registry.py index 6d4e7b866..1707ff803 100644 --- a/tunix/rl/function_registry.py +++ b/tunix/rl/function_registry.py @@ -18,6 +18,7 @@ from absl import logging _POLICY_LOSS_FN_CATEGORY = "policy_loss_fn" +_VALUE_LOSS_FN_CATEGORY = "value_loss_fn" _ADVANTAGE_ESTIMATOR_CATEGORY = "advantage_estimator" _REWARD_MANAGER_CATEGORY = "reward_manager" @@ -27,6 +28,7 @@ class FunctionRegistry: DEFAULT_ALLOWED_CATEGORIES: FrozenSet[str] = frozenset({ _POLICY_LOSS_FN_CATEGORY, + _VALUE_LOSS_FN_CATEGORY, _ADVANTAGE_ESTIMATOR_CATEGORY, _REWARD_MANAGER_CATEGORY, }) @@ -152,3 +154,15 @@ def register_reward_manager( ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Returns a decorator to register a reward manager function by name.""" return default_registry.register(_REWARD_MANAGER_CATEGORY, name) + + +def get_value_loss_fn(name: str) -> Callable[..., Any]: + """Returns the value loss function by name.""" + return default_registry.get(_VALUE_LOSS_FN_CATEGORY, name) + + +def register_value_loss_fn( + name: str, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Returns a decorator to register a value loss function by name.""" + return default_registry.register(_VALUE_LOSS_FN_CATEGORY, name) diff --git a/tunix/rl/grpo/dapo_learner.py b/tunix/rl/grpo/dapo_learner.py index 8b318d1ee..f42a6f510 100644 --- a/tunix/rl/grpo/dapo_learner.py +++ b/tunix/rl/grpo/dapo_learner.py @@ -24,7 +24,7 @@ MetricFn = rl_learner.MetricFn -@dataclasses.dataclass(slots=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class DAPOConfig(grpo_learner_lib.GRPOConfig): """Configuration for DAPO. 
@@ -73,6 +73,7 @@ class DAPOConfig(grpo_learner_lib.GRPOConfig): ) def __post_init__(self): + self.beta = None if self.epsilon_high < self.epsilon: raise ValueError("epsilon_high must be greater than or equal to epsilon.") diff --git a/tunix/rl/grpo/drgrpo_learner.py b/tunix/rl/grpo/drgrpo_learner.py index 3c025e2b7..dffc3bb43 100644 --- a/tunix/rl/grpo/drgrpo_learner.py +++ b/tunix/rl/grpo/drgrpo_learner.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Helper functions for GRPO Trainer.""" + import dataclasses import jax +from tunix.rl import algo_core from tunix.rl import function_registry from tunix.rl import rl_learner from tunix.rl.grpo import grpo_learner as grpo_learner_lib @@ -23,22 +25,7 @@ MetricFn = rl_learner.MetricFn -@function_registry.register_advantage_estimator("drgrpo") -def compute_advantages(rewards: jax.Array, num_generations: int) -> jax.Array: - """Group relative advantages -- done right. - - Args: - rewards: reward functions output. - num_generations: Number of generations. - - Returns: - Group relative advantages. - """ - mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=1) - return rewards - mean_grouped_rewards.repeat(num_generations) - - -@dataclasses.dataclass(slots=True, kw_only=True) +@dataclasses.dataclass(kw_only=True) class DrGRPOConfig(grpo_learner_lib.GRPOConfig): """Configuration for DrGRPO.""" diff --git a/tunix/rl/grpo/grpo_learner.py b/tunix/rl/grpo/grpo_learner.py index 5d6b64d0a..8da9887c2 100644 --- a/tunix/rl/grpo/grpo_learner.py +++ b/tunix/rl/grpo/grpo_learner.py @@ -24,6 +24,7 @@ import jax import jax.numpy as jnp import numpy as np +from tunix.rl import algo_core # pylint: disable=unused-import from tunix.generate import utils from tunix.perf.experimental import constants as perf_constants from tunix.rl import algorithm_config as algo_config_lib @@ -62,6 +63,8 @@ class GRPOConfig(algo_config_lib.AlgorithmConfig): beta: The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function. This term prevents policy updates from deviating too far from the reference model. A value of 0.0 means no KL penalty is applied. + kl_loss_mode: The divergence mode used for KL penalty estimation. Default: + `kl`. epsilon: Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, it ensures stable updates. epsilon_high: Epsilon value for upper bound clipping. @@ -88,6 +91,7 @@ class GRPOConfig(algo_config_lib.AlgorithmConfig): num_generations: int = 2 num_iterations: int = 1 beta: float = 0.04 + kl_loss_mode: str = "kl" epsilon: float = 0.2 def __post_init__(self): @@ -452,147 +456,5 @@ def train( # pylint: disable=useless-parent-delegation super().train(train_ds, eval_ds, skip_jit) -@function_registry.register_policy_loss_fn("grpo") -def grpo_loss_fn( - model, - train_example, - algo_config, - pad_id, - eos_id, -): - """GRPO loss function. - - The loss aims to maximize the expected advantage of the chosen actions while - constraining the policy updates to stay within a certain range of the - reference policy. - - Args: - model: The policy model to be trained. - train_example: A `TrainExample` instance containing the processed input - data, including prompt IDs, completion IDs, masks, advantages, and - per-token log probabilities from the reference and policy models. - algo_config: The algorithm config. - pad_id: The pad ID from tokenizer. - eos_id: The eos ID from. - - Returns: - A tuple containing the loss and an aux dictionary. 
- """ - beta = algo_config.beta - epsilon = algo_config.epsilon - loss_algo = algo_config.loss_algo - epsilon_high = ( - algo_config.epsilon_high - if hasattr(algo_config, "epsilon_high") - else epsilon - ) - loss_aggregation_mode = algo_config.loss_agg_mode - - completion_ids, completion_mask = ( - train_example.completion_ids, - train_example.completion_mask, - ) - - # TODO(yangmu): trace this part as "actor_inference_and_training". - # with perf_tracer.span("...", list(completion_ids.devices())): - graphdef, state = nnx.split(model) - per_token_logps = common.compute_per_token_logps( - graphdef, - state, - prompt_tokens=train_example.prompt_ids, - completion_tokens=completion_ids, - pad_id=pad_id, - eos_id=eos_id, - stop_gradient=False, - return_logits=False, - segment_ids=getattr(train_example, "segment_ids", None), - segment_positions=getattr(train_example, "segment_positions", None), - temperature=algo_config.temperature, - ) - advantages = train_example.advantages - - if train_example.old_per_token_logps is None: - old_per_token_logps = jax.lax.stop_gradient(per_token_logps) - else: - old_per_token_logps = train_example.old_per_token_logps - - seq_importance_ratio = per_token_logps - old_per_token_logps - # TODO(sizhi): Refactor this to a separate function. - if loss_algo == "gspo-token": - seq_importance_ratio = (seq_importance_ratio * completion_mask).sum( - axis=-1 - ) / jnp.clip(completion_mask.sum(-1), min=1) - seq_importance_ratio = ( - per_token_logps - - jax.lax.stop_gradient(per_token_logps) - + jnp.expand_dims(jax.lax.stop_gradient(seq_importance_ratio), axis=-1) - ) - seq_importance_ratio = jnp.clip(seq_importance_ratio, max=10.0) - - coef_1 = jnp.exp(seq_importance_ratio) - coef_2 = jnp.clip(coef_1, 1 - epsilon, 1 + epsilon_high) - - # Advantages must be broadcast against seq_length. - # When sequence packing is used, advantages are already 2D [B, seq_length]. - # When unpacked, they are 1D [B]. - adv = advantages if advantages.ndim == 2 else jnp.expand_dims(advantages, 1) - - # Compute pg_clipfrac - pg_losses_1 = -coef_1 * adv - pg_losses_2 = -coef_2 * adv - pg_clipfrac = jnp.sum( - (pg_losses_2 > pg_losses_1) * completion_mask - ) / jnp.clip(jnp.sum(completion_mask), min=1) - - # TODO(tsbao): We should handle token level advantages. - per_token_loss = -jnp.minimum( - coef_1 * adv, - coef_2 * adv, - ) - - # add KL penalty - mean_kl = 0.0 - if beta is not None and beta != 0.0: - kl = common.compute_kl_divergence( - per_token_logps, train_example.ref_per_token_logps - ) - per_token_loss = per_token_loss + beta * kl - mean_kl = (kl * completion_mask).sum() / jnp.clip( - completion_mask.sum(), min=1 - ) - - aux = { - "kl": mean_kl, - "pg_clipfrac": pg_clipfrac, - } - - loss = common.aggregate_loss( - per_token_loss, completion_mask, loss_aggregation_mode - ) - - return loss, aux - - -@function_registry.register_advantage_estimator("grpo") -def compute_advantages(rewards: np.ndarray, num_generations: int) -> np.ndarray: - """Compute group relative advantages. - - Args: - rewards: reward functions output. - num_generations: Number of generations. - - Returns: - Group relative advantages. 
- """ - mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=-1) - std_grouped_rewards = rewards.reshape(-1, num_generations).std( - axis=-1, ddof=1 - ) - - mean_grouped_rewards = mean_grouped_rewards.repeat(num_generations) - std_grouped_rewards = std_grouped_rewards.repeat(num_generations) - return (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) - - GrpoConfig = GRPOConfig GrpoLearner = GRPOLearner diff --git a/tunix/rl/ppo/ppo_learner.py b/tunix/rl/ppo/ppo_learner.py index 512d40675..d6b3663db 100644 --- a/tunix/rl/ppo/ppo_learner.py +++ b/tunix/rl/ppo/ppo_learner.py @@ -25,12 +25,13 @@ import jax.numpy as jnp import numpy as np from tunix.generate import utils +from tunix.rl import algo_core as ppo_helpers from tunix.rl import algorithm_config as algo_config_lib from tunix.rl import common from tunix.rl import function_registry from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl import rl_learner -from tunix.rl.ppo import ppo_helpers + TrainingInputT = rl_learner.TrainingInputT RewardFn = rl_learner.RewardFn @@ -80,6 +81,7 @@ class PPOConfig(algo_config_lib.AlgorithmConfig): algo_variant: str = "ppo" advantage_estimator: str = "gae" policy_loss_fn: str = "ppo" + value_loss_fn: str = "ppo" reward_manager: str = "sequence-level" num_iterations: int = 1 @@ -189,17 +191,17 @@ def __init__( self.rl_cluster.actor_trainer.with_gen_model_input_fn( lambda x: { "train_example": x, - "epsilon_low": self.algo_config.epsilon_low, - "epsilon_high": self.algo_config.epsilon_high, - "epsilon_c": self.algo_config.epsilon_c, - "entropy_coef": self.algo_config.entropy_coef, + "algo_config": self.algo_config, "pad_id": self.rl_cluster.rollout.pad_id(), "eos_id": self.rl_cluster.rollout.eos_id(), } ) # ===== Configure the critic (value) trainer ===== - self.rl_cluster.critic_trainer.with_loss_fn(ppo_value_loss_fn, has_aux=True) + value_loss_fn = registry.get( + "value_loss_fn", self.algo_config.value_loss_fn + ) + self.rl_cluster.critic_trainer.with_loss_fn(value_loss_fn, has_aux=True) self.rl_cluster.critic_trainer.with_gen_model_input_fn( lambda x: { "train_example": x, @@ -527,155 +529,5 @@ def train( # pylint: disable=useless-parent-delegation super().train(train_ds, eval_ds, skip_jit) -def ppo_value_loss_fn( - model: nnx.Module, - train_example: TrainExample, - clip_range_value: float | None, - pad_id: int, - eos_id: int, -): - """Computes the value loss for PPO.""" - - prompt_ids, completion_ids, completion_mask = ( - train_example.prompt_ids, - train_example.completion_ids, - train_example.completion_mask, - ) - # ====== Loss ====== - values = train_example.old_values - returns = train_example.returns - - segment_ids = getattr(train_example, "segment_ids", None) - if segment_ids is not None: - # For packed sequences, prompt_ids is empty and completion_ids holds the full sequence. - # We predict values for token t using the model's output at t-1. - logits_to_keep = completion_ids.shape[1] - 1 - else: - logits_to_keep = completion_ids.shape[1] - - # Get new values. - vpreds = common.compute_score( - model, - prompt_ids, - completion_ids, - pad_id, - eos_id, - stop_gradient=False, - segment_ids=segment_ids, - segment_positions=getattr(train_example, "segment_positions", None), - ) - vpreds = vpreds[:, -logits_to_keep - 1 : -1] - - if segment_ids is not None: - # Pad the first token's value with 0.0, since it has no preceding token to predict it. 
- vpreds = jnp.pad(vpreds, ((0, 0), (1, 0)), constant_values=0.0) - vpred_clipped = jnp.clip( - vpreds, values - clip_range_value, values + clip_range_value - ) - vf_losses1 = jnp.square(vpreds - returns) - vf_losses2 = jnp.square(vpred_clipped - returns) - - clipped_vf_losses = jnp.maximum(vf_losses1, vf_losses2) - # "token mean" style of normalisation. - vf_loss = ppo_helpers.masked_mean(clipped_vf_losses, completion_mask) - vf_loss = 0.5 * vf_loss - - aux = { - "vpred_mean": ppo_helpers.masked_mean(vpreds, completion_mask), - "vf_clipfrac": ppo_helpers.masked_mean( - (vf_losses2 > vf_losses1).astype(jnp.float32), completion_mask - ), - } - return vf_loss, aux - - -@registry.register("policy_loss_fn", "ppo") -def ppo_policy_loss_fn( - model: nnx.Module, - train_example: TrainExample, - epsilon_low: float, - epsilon_high: float, - epsilon_c: float | None, - entropy_coef: float | None, - pad_id: int, - eos_id: int, -): - """Computes the policy loss for PPO.""" - - prompt_ids, completion_ids, completion_mask = ( - train_example.prompt_ids, - train_example.completion_ids, - train_example.completion_mask, - ) - use_dual_clip_ppo = epsilon_c is not None - - # Get log probs. - graphdef, state = nnx.split(model) - per_token_logps, logits = common.compute_per_token_logps( - graphdef, - state, - prompt_tokens=prompt_ids, - completion_tokens=completion_ids, - pad_id=pad_id, - eos_id=eos_id, - stop_gradient=False, - return_logits=True, - segment_ids=getattr(train_example, "segment_ids", None), - segment_positions=getattr(train_example, "segment_positions", None), - ) - - advantages = train_example.advantages - - # Compute ratio. - old_per_token_logps = train_example.old_per_token_logps - ratio = jnp.exp(per_token_logps - old_per_token_logps) - ratio_clipped = jnp.clip(ratio, 1 - epsilon_low, 1 + epsilon_high) - - # Vanilla PPO loss - pg_losses_1 = -ratio * advantages - pg_losses_2 = -ratio_clipped * advantages - clip_pg_losses_1 = jnp.maximum(pg_losses_1, pg_losses_2) - - # Dual-clip PPO to avoid negative-advantage policy updates - pg_losses = clip_pg_losses_1 - if use_dual_clip_ppo: - pg_losses_3 = -epsilon_c * advantages - clip_pg_losses_2 = jnp.minimum(pg_losses_3, clip_pg_losses_1) - - pg_losses = jnp.where(advantages < 0.0, clip_pg_losses_2, clip_pg_losses_1) - - # For logging. - unreduced_pg_clipfrac_lower = ( - (clip_pg_losses_1 > pg_losses_3) & (advantages < 0.0) - ).astype(jnp.float32) - pg_clipfrac_lower = ppo_helpers.masked_mean( - unreduced_pg_clipfrac_lower, completion_mask - ) - - # Logging - aux = { - "pg_clipfrac": ppo_helpers.masked_mean( - (pg_losses_2 > pg_losses_1).astype(jnp.float32), completion_mask - ), - } - if use_dual_clip_ppo: - aux["pg_clipfrac_lower"] = pg_clipfrac_lower # pylint: disable=undefined-variable - - # "token mean" style of normalisation - policy_loss = ppo_helpers.masked_mean(pg_losses, completion_mask) - - # Compute entropy loss. - if entropy_coef is not None and entropy_coef > 0.0: - token_entropy = ppo_helpers.compute_entropy_from_logits(logits) - # "token mean" style of normalisation. - entropy_loss = ppo_helpers.masked_mean(token_entropy, completion_mask) - policy_loss -= entropy_coef * entropy_loss - - # Logging - aux["loss/entropy"] = entropy_loss - - return policy_loss, aux - - PpoConfig = PPOConfig PpoLearner = PPOLearner
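
Note on the consolidated advantage estimators: the snippet below is a minimal, standalone sketch (not part of the patch) that mirrors the "grpo" and "rloo" estimators now registered in tunix/rl/algo_core.py, using only jax.numpy and numpy. The names grpo_advantages and rloo_advantages are local to this sketch; the patch's equivalents are compute_advantages and compute_rloo_advantages, registered under "grpo" and "rloo" via function_registry. The RLOO expected values are the ones asserted in agentic_grpo_learner_test.py; the GRPO expected values follow directly from the group mean/std definition.

import jax.numpy as jnp
import numpy as np


def grpo_advantages(rewards, num_generations):
  # Group-relative advantages: (r - group mean) / (group std + eps),
  # using the sample (ddof=1) standard deviation, as in algo_core.
  grouped = rewards.reshape(-1, num_generations)
  mean = grouped.mean(axis=-1).repeat(num_generations)
  std = grouped.std(axis=-1, ddof=1).repeat(num_generations)
  return (rewards - mean) / (std + 1e-4)


def rloo_advantages(rewards, num_generations):
  # REINFORCE Leave-One-Out: the baseline for each sample is the mean
  # reward of the other samples generated for the same prompt.
  if num_generations < 2:
    return jnp.zeros_like(rewards)  # No baseline possible with one sample.
  grouped = rewards.reshape(-1, num_generations)
  loo_mean = (grouped.sum(axis=-1, keepdims=True) - grouped) / (
      num_generations - 1
  )
  return (grouped - loo_mean).flatten()


rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
# Matches the expectation in test_compute_rloo_advantages above.
np.testing.assert_allclose(
    rloo_advantages(rewards, num_generations=3),
    jnp.array([-1.5, 0.0, 1.5, -1.5, 0.0, 1.5]),
)
# Each group of 3 is standardized to roughly [-1, 0, 1].
np.testing.assert_allclose(
    grpo_advantages(rewards, num_generations=3),
    jnp.array([-1.0, 0.0, 1.0, -1.0, 0.0, 1.0]),
    atol=1e-3,
)

In the patched learners these functions are no longer called directly; they are resolved by name from the default registry (for example, the PPO learner now fetches its value loss via registry.get("value_loss_fn", self.algo_config.value_loss_fn)), which is what lets the agentic and non-agentic learners share a single implementation.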