6 changes: 6 additions & 0 deletions examples/deepswe/train_deepswe_nb.py
@@ -78,6 +78,12 @@
parser.add_argument("--epsilon", type=float, default=0.2)
parser.add_argument("--epsilon_high", type=float, default=0.28)
parser.add_argument("--off_policy_steps", type=int, default=0)
parser.add_argument(
"--advantage_estimator",
type=str,
default="rloo",
choices=["grpo", "rloo"],
)

# Rollout Config
parser.add_argument("--max_prompt_length", type=int, default=4096)
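The new `--advantage_estimator` flag defaults to `"rloo"`, matching the RLOO estimator this PR consolidates into `algo_core`. A minimal sketch of how the parsed value would be threaded into the learner config (the `GRPOConfig(...)` wiring here is illustrative; the script's actual plumbing is not shown in this diff):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--advantage_estimator",
    type=str,
    default="rloo",
    choices=["grpo", "rloo"],
)
args = parser.parse_args(["--advantage_estimator", "grpo"])

# Hypothetical wiring: the config stores the estimator by name, and the
# learner resolves it through the function registry at runtime.
# config = GRPOConfig(advantage_estimator=args.advantage_estimator)
print(args.advantage_estimator)  # -> "grpo"
```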
9 changes: 3 additions & 6 deletions tests/rl/agentic/agentic_grpo_learner_test.py
@@ -39,6 +39,7 @@
import optax
import orbax.checkpoint as ocp
from tunix.generate import tokenizer_adapter
from tunix.rl import algo_core
from tunix.rl import common as rl_common
from tunix.rl import function_registry
from tunix.rl import rl_cluster as rl_cluster_lib
@@ -1906,17 +1907,13 @@ def _patch_process_results(

def test_compute_rloo_advantages(self):
rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
advantages = agentic_grpo_learner.compute_rloo_advantages(
rewards, num_generations=3
)
advantages = algo_core.compute_rloo_advantages(rewards, num_generations=3)
expected_value = jnp.array([-1.5, 0.0, 1.5, -1.5, 0.0, 1.5])
np.testing.assert_allclose(advantages, expected_value)

def test_compute_rloo_advantages_low_generations(self):
rewards = jnp.array([1.0, 2.0])
advantages = agentic_grpo_learner.compute_rloo_advantages(
rewards, num_generations=1
)
advantages = algo_core.compute_rloo_advantages(rewards, num_generations=1)
np.testing.assert_allclose(advantages, jnp.zeros_like(rewards))


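The expected values follow directly from the leave-one-out baseline: in the first group `[1, 2, 3]`, each sample's baseline is the mean of the other two, `[2.5, 2.0, 1.5]`, so the advantages are `[-1.5, 0.0, 1.5]`. A self-contained sketch mirroring the `compute_rloo_advantages` implementation shown in the deleted block further down this diff:

```python
import jax.numpy as jnp

def rloo_advantages(rewards: jnp.ndarray, num_generations: int) -> jnp.ndarray:
  """Reward minus the mean of the other rewards in the same group."""
  if num_generations < 2:
    return jnp.zeros_like(rewards)  # no baseline from a single sample
  grouped = rewards.reshape(-1, num_generations)
  loo_mean = (
      grouped.sum(axis=-1, keepdims=True) - grouped
  ) / (num_generations - 1)
  return (grouped - loo_mean).flatten()

print(rloo_advantages(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), 3))
# -> [-1.5  0.  1.5 -1.5  0.  1.5], matching the test's expected_value.
```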
2 changes: 1 addition & 1 deletion tests/rl/function_registry_test.py
@@ -52,7 +52,7 @@ def test_custom_categories_instance(self):
def test_empty_categories_instance(self):
# Test-specific instance for empty categories
registry = function_registry.FunctionRegistry(allowed_categories=[])
self.assertLen(registry.list_categories(), 3)
self.assertLen(registry.list_categories(), 4)

@parameterized.named_parameters(
dict(
6 changes: 4 additions & 2 deletions tests/rl/grpo/dapo_learner_test.py
@@ -89,9 +89,11 @@ def test_diff_loss(self):
rngs=nnx.Rngs(0),
)

# Call DAPO loss function
# Call DAPO loss function (DAPO sets ref_per_token_logps to None as it doesn't fetch it)
dapo_train_example = self.create_train_example()
dapo_train_example.ref_per_token_logps = None
dapo_loss, dapo_aux = dapo_loss_fn_impl(
model, train_example, dapo_config, pad_id, eos_id
model, dapo_train_example, dapo_config, pad_id, eos_id
)

# Call GRPO loss function
3 changes: 2 additions & 1 deletion tests/rl/grpo/drgrpo_learner_test.py
@@ -19,6 +19,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl import algo_core # pylint: disable=unused-import
from tunix.rl import function_registry as fr
from tunix.rl.grpo import drgrpo_learner as drgrpo_lib
from tunix.rl.grpo import grpo_learner as grpo_lib
@@ -135,7 +136,7 @@ def test_compute_advantages(self):
rewards = jnp.array(
[[0.57450044, 0.09968603, 0.7419659, 0.8941783, 0.59656656, 0.45325184]]
)
advantages = drgrpo_lib.compute_advantages(rewards, num_generations=3)
advantages = algo_core.compute_drgrpo_advantages(rewards, num_generations=3)
expected_array = jnp.array([
[0.10245, -0.372365, 0.269915, 0.246179, -0.051432, -0.194747],
])
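The rename to `algo_core.compute_drgrpo_advantages` does not change the Dr. GRPO recipe the expected values encode: the baseline is the group mean only, with no division by the group standard deviation. Checking the first group of the test by hand:

```python
import jax.numpy as jnp

group = jnp.array([0.57450044, 0.09968603, 0.7419659])
print(group - group.mean())  # -> [ 0.10245  -0.372365  0.269915]
# Matches expected_array: reward minus group mean, with no std
# normalization (unlike plain GRPO's (r - mean) / (std + 1e-6)).
```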
5 changes: 3 additions & 2 deletions tests/rl/grpo/grpo_learner_test.py
@@ -33,6 +33,7 @@
import orbax.checkpoint as ocp
from tunix.perf import trace as trace_lib
from tunix.perf.experimental import tracer as perf_tracer_v2
from tunix.rl import algo_core as grpo_core
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo import grpo_learner as grpo_lib
from tunix.rl.queue import data_queue as queue_lib
@@ -1234,9 +1235,9 @@ def test_compute_advantages(self):

rng = jax.random.PRNGKey(0)
rewards = jax.random.uniform(rng, shape=(1, 6))
advantages = grpo_lib.compute_advantages(rewards, num_generations=3)
advantages = grpo_core.compute_advantages(rewards, num_generations=3)
expected_value = jnp.array(
[[0.307407, -1.117304, 0.809897, 1.094044, -0.22857, -0.865474]]
[[0.307498, -1.117636, 0.810138, 1.094526, -0.228671, -0.865855]]
)
np.testing.assert_allclose(advantages, expected_value, rtol=1e-5, atol=1e-5)

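The slight drift in the expected constants (e.g. `0.307407` → `0.307498`) is consistent with a purely numerical change from relocating the math into `algo_core`; the formula itself, visible in the deleted `compute_advantages` further down, is unchanged. A sketch of that formula on a deterministic input (my own example; the test uses random rewards):

```python
import jax.numpy as jnp

rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(-1, 3)
mean = rewards.mean(axis=-1, keepdims=True)        # [[2.], [5.]]
std = rewards.std(axis=-1, ddof=1, keepdims=True)  # [[1.], [1.]]
print((rewards - mean) / (std + 1e-6))
# -> approximately [[-1. 0. 1.] [-1. 0. 1.]]: each group is centered and
#    scaled by its own sample standard deviation.
```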
2 changes: 1 addition & 1 deletion tests/rl/ppo/ppo_helpers_test.py
@@ -18,7 +18,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl.ppo import ppo_helpers
from tunix.rl import algo_core as ppo_helpers


def _ref_compute_gae_advantages(
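The aliased import keeps the test body untouched while exercising the GAE code now living in `algo_core`. For orientation, the reference helper `_ref_compute_gae_advantages` presumably implements the standard GAE recurrence `A_t = delta_t + gamma*lam*A_{t+1}` with `delta_t = r_t + gamma*V_{t+1} - V_t`; a minimal version of that recurrence (my own sketch, zero bootstrap at the end, not tunix's API):

```python
import numpy as np

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
  """Reference GAE: backward scan of TD deltas."""
  T = len(rewards)
  adv = np.zeros(T, dtype=np.float32)
  last = 0.0
  for t in reversed(range(T)):
    next_value = values[t + 1] if t + 1 < T else 0.0  # bootstrap with 0
    delta = rewards[t] + gamma * next_value - values[t]
    last = delta + gamma * lam * last
    adv[t] = last
  return adv

print(gae_advantages(np.array([1.0, 0.0, 1.0]), np.array([0.5, 0.4, 0.3])))
```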
218 changes: 3 additions & 215 deletions tunix/rl/agentic/agentic_grpo_learner.py
@@ -37,6 +37,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl import algo_core # pylint: disable=unused-import
from tunix.perf.experimental import constants as perf_constants
from tunix.rl import common
from tunix.rl import function_registry
@@ -88,8 +89,8 @@ class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
"""

algo_variant: str = "agentic_grpo"
advantage_estimator: str = "agentic_grpo"
policy_loss_fn: str = "agentic_grpo"
advantage_estimator: str = "grpo"
policy_loss_fn: str = "grpo"
loss_agg_mode: str = "sequence-mean-token-mean"
loss_algo: (
str
@@ -559,218 +560,5 @@ def _process_results(
return [combined_batch]


@function_registry.register_policy_loss_fn("agentic_grpo")
def grpo_loss_fn(
model,
train_example,
algo_config,
pad_id,
eos_id,
):
"""GRPO loss function.

The loss aims to maximize the expected advantage of the chosen actions while
constraining the policy updates to stay within a certain range of the
reference policy.

Args:
model: The policy model to be trained.
train_example: A `TrainExample` instance containing the processed input
data, including prompt IDs, completion IDs, masks, advantages, and
per-token log probabilities from the reference and policy models.
algo_config: The algorithm config.
pad_id: The pad ID from the tokenizer.
eos_id: The eos ID from the tokenizer.

Returns:
A tuple containing the loss and an aux dictionary.
"""
beta = algo_config.beta
epsilon = algo_config.epsilon
loss_algo = algo_config.loss_algo
epsilon_high = (
algo_config.epsilon_high
if hasattr(algo_config, "epsilon_high")
else epsilon
)
epsilon_c = (
algo_config.epsilon_c
if hasattr(algo_config, "epsilon_c")
else 3.0
)
loss_aggregation_mode = algo_config.loss_agg_mode

completion_ids, completion_mask = (
train_example.completion_ids,
train_example.completion_mask,
)

# TODO(tsbao): split can be avoided with updated peft_trainer model handling.
graphdef, state = nnx.split(model)
per_token_logps, logits = common.compute_per_token_logps(
graphdef,
state,
prompt_tokens=train_example.prompt_ids,
completion_tokens=completion_ids,
pad_id=pad_id,
eos_id=eos_id,
completion_mask=completion_mask,
stop_gradient=False,
return_logits=True,
segment_ids=getattr(train_example, "segment_ids", None),
segment_positions=getattr(train_example, "segment_positions", None),
temperature=algo_config.temperature,
)
per_token_logps = jnp.astype(per_token_logps, jnp.float32)
# TODO(tsbao): We should handle token level advantages.
advantages = jnp.astype(train_example.advantages, jnp.float32)

if train_example.old_per_token_logps is None:
old_per_token_logps = jax.lax.stop_gradient(per_token_logps)
else:
old_per_token_logps = jnp.astype(
train_example.old_per_token_logps, jnp.float32
)

seq_importance_ratio = per_token_logps - old_per_token_logps
# Record KL divergence before clipping.
ppo_kl = ppo_helpers.masked_mean(-seq_importance_ratio, completion_mask)

seq_importance_ratio = jnp.clip(seq_importance_ratio, max=20.0, min=-20.0)

# TODO(sizhi): Refactor this to a separate function.
if loss_algo == "gspo-token":
seq_importance_ratio = (seq_importance_ratio * completion_mask).sum(
axis=-1
) / jnp.clip(completion_mask.sum(-1), min=1)
seq_importance_ratio = (
per_token_logps
- jax.lax.stop_gradient(per_token_logps)
+ jnp.expand_dims(jax.lax.stop_gradient(seq_importance_ratio), axis=-1)
)
seq_importance_ratio = jnp.clip(seq_importance_ratio, max=10.0)

is_ratio = jnp.exp(seq_importance_ratio)

# Advantages must be broadcast against seq_length.
# When sequence packing is used, advantages are already 2D [B, seq_length].
# When unpacked, they are 1D [B].
adv = advantages if advantages.ndim == 2 else jnp.expand_dims(advantages, 1)

pg_loss_1 = -adv * is_ratio
pg_loss_2 = -adv * jnp.clip(is_ratio, 1 - epsilon, 1 + epsilon_high)

per_token_loss = jnp.maximum(pg_loss_1, pg_loss_2).astype(jnp.float32)

clipped_fraction = ppo_helpers.masked_mean(
jnp.greater(pg_loss_2, pg_loss_1), completion_mask
)

# dual-clip ppo loss
pg_loss_3 = -epsilon_c * adv

# pg_clipfrac_lower measures how often dual-clip ppo kicks in.
# It kicks in when the standard clipped loss is larger than pg_loss_3
# for instances with negative advantages.
unreduced_pg_clipfrac_lower = (
(per_token_loss > pg_loss_3) & (adv < 0.0)
).astype(jnp.float32)
pg_clipfrac_lower = common.aggregate_loss(
unreduced_pg_clipfrac_lower, completion_mask, loss_aggregation_mode
)

pg_loss_clipped_dual = jnp.minimum(pg_loss_3, per_token_loss)
per_token_loss = jnp.where(adv < 0.0, pg_loss_clipped_dual, per_token_loss)
loss = common.aggregate_loss(
per_token_loss, completion_mask, loss_aggregation_mode
)
aux = {
"kl": 0.0,
"kl_loss": 0.0,
"pg_loss": loss,
"pg_clipfrac": clipped_fraction,
"ppo_kl": ppo_kl,
"pg_clipfrac_lower": pg_clipfrac_lower,
}
# We do not always compute KL divergence (e.g. when beta is 0.0, unless
# force_compute_kl is True).
if train_example.ref_per_token_logps is not None:
kl = common.compute_kl_divergence(
per_token_logps,
train_example.ref_per_token_logps,
algo_config.kl_loss_mode,
)
# Log mean KL.
aux["kl"] = jnp.astype(
(kl * completion_mask).sum() / jnp.clip(completion_mask.sum(), min=1),
jnp.float32,
)
kl_loss = common.aggregate_loss(
kl, completion_mask, loss_aggregation_mode
)
aux["kl_loss"] = kl_loss
if beta is not None and beta != 0.0:
loss = loss + beta * kl_loss

token_entropy = ppo_helpers.compute_entropy_from_logits(logits)
entropy_loss = common.aggregate_loss(
token_entropy, completion_mask, loss_aggregation_mode
)
aux["entropy"] = entropy_loss

return loss, aux
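The dual-clip branch above only engages for negative advantages whose importance ratio has blown up. A quick numeric check of that logic, using the defaults from the training script (`epsilon=0.2`, `epsilon_high=0.28`) and the fallback `epsilon_c=3.0`:

```python
import jax.numpy as jnp

adv, ratio = -1.0, 5.0  # bad action whose probability grew 5x
eps, eps_high, eps_c = 0.2, 0.28, 3.0
pg_loss_1 = -adv * ratio                                   # 5.0
pg_loss_2 = -adv * jnp.clip(ratio, 1 - eps, 1 + eps_high)  # 1.28
per_token = jnp.maximum(pg_loss_1, pg_loss_2)              # 5.0
pg_loss_3 = -eps_c * adv                                   # 3.0
loss = jnp.where(adv < 0.0, jnp.minimum(pg_loss_3, per_token), per_token)
print(loss)  # 3.0: the dual clip caps the penalty at eps_c * |adv|
```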


@function_registry.register_advantage_estimator("agentic_grpo")
def compute_advantages(rewards: jax.Array, num_generations: int) -> jax.Array:
"""Compute group relative advantages.

Args:
rewards: reward functions output.
num_generations: Number of generations.

Returns:
Group relative advantages.
"""
rewards = jnp.astype(rewards, jnp.float32)
mean_grouped_rewards = rewards.reshape(-1, num_generations).mean(axis=-1)
std_grouped_rewards = rewards.reshape(-1, num_generations).std(
axis=-1, ddof=1
)

mean_grouped_rewards = mean_grouped_rewards.repeat(num_generations)
std_grouped_rewards = std_grouped_rewards.repeat(num_generations)
return (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-6)


@function_registry.register_advantage_estimator("agentic_rloo")
def compute_rloo_advantages(
rewards: jax.Array, num_generations: int
) -> jax.Array:
"""Compute RLOO (REINFORCE Leave-One-Out) advantages.

RLOO computes a baseline for each completion by averaging the rewards of all
other completions to the same prompt.

Args:
rewards: reward functions output.
num_generations: Number of generations.

Returns:
RLOO advantages.
"""
if num_generations < 2:
# RLOO requires at least 2 samples to calculate a baseline.
return jnp.zeros_like(rewards)

reshaped_rewards = rewards.reshape(-1, num_generations)
loo_mean = (
reshaped_rewards.sum(axis=-1, keepdims=True) - reshaped_rewards
) / (num_generations - 1)
rloo_advantages = reshaped_rewards - loo_mean

return rloo_advantages.flatten()


GrpoConfig = GRPOConfig
GrpoLearner = GRPOLearner
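The config defaults can change from "agentic_grpo" to "grpo" because estimators and loss functions are resolved by string key through the function registry rather than imported directly; the deleted `@function_registry.register_advantage_estimator(...)` decorators above show the registration half of that pattern. A toy, self-contained version of the idea (not the actual tunix `FunctionRegistry` API):

```python
from typing import Callable, Dict

import jax.numpy as jnp

# Toy registry illustrating the pattern only.
_ADVANTAGE_ESTIMATORS: Dict[str, Callable] = {}

def register_advantage_estimator(name: str):
  """Files a function under a string key so configs can refer to it by name."""
  def wrap(fn: Callable) -> Callable:
    _ADVANTAGE_ESTIMATORS[name] = fn
    return fn
  return wrap

@register_advantage_estimator("rloo")
def rloo(rewards: jnp.ndarray, num_generations: int) -> jnp.ndarray:
  grouped = rewards.reshape(-1, num_generations)
  loo_mean = (
      grouped.sum(axis=-1, keepdims=True) - grouped
  ) / (num_generations - 1)
  return (grouped - loo_mean).flatten()

# A config now only carries the string, e.g. advantage_estimator="rloo":
estimator = _ADVANTAGE_ESTIMATORS["rloo"]
print(estimator(jnp.arange(6.0), num_generations=3))
```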
2 changes: 1 addition & 1 deletion tunix/rl/agentic/agentic_rl_learner.py
@@ -47,7 +47,7 @@
from tunix.rl.agentic.environments import base_environment
from tunix.rl.agentic.environments import task_environment
from tunix.rl.agentic.pipeline import rollout_orchestrator
from tunix.rl.agentic.rewards import reward
from tunix.rl.agentic.rewards import reward # pylint: disable=unused-import
from tunix.rl.agentic.trajectory import trajectory_collect_engine
from tunix.rl.queue import data_queue as queue_lib
from tunix.sft import utils as sft_utils