tests/rl/agentic/agentic_rl_learner_test.py: 53 additions, 0 deletions
@@ -18,6 +18,7 @@

from absl.testing import absltest
from absl.testing import parameterized
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.agentic import agentic_rl_learner
from tunix.rl.rollout import base_rollout

@@ -135,6 +136,58 @@ def test_validate_rollout_config_vllm_missing_server_mode(self):
algo_config=algo_config,
)

def test_train_raises_value_error_on_indivisible_batch_size(self):
rl_cluster = mock.Mock()
rl_cluster.cluster_config = mock.Mock()
rl_cluster.cluster_config.rollout_engine = "generic"
rl_cluster.cluster_config.rollout_config = base_rollout.RolloutConfig(
max_prompt_length=32,
max_tokens_to_generate=10,
return_logprobs=True,
)

# Configure roles
rl_cluster.cluster_config.role_to_mesh = {
rl_cluster_lib.Role.ACTOR: mock.Mock(),
rl_cluster_lib.Role.ROLLOUT: mock.Mock(),
}

# Training config
training_config = mock.Mock()
training_config.mini_batch_size = 3
training_config.train_micro_batch_size = 3
training_config.rollout_micro_batch_size = 1
training_config.compute_logps_micro_batch_size = None
rl_cluster.cluster_config.training_config = training_config

rl_cluster.actor_trainer = mock.Mock()
rl_cluster.actor_trainer.restored_global_step.return_value = 0
rl_cluster.actor_trainer.iter_steps = 0

algo_config = agentic_rl_learner.AgenticRLConfig(
max_response_length=10,
)

# Patch is_sharing_weights so we don't have to mock JAX structures
with mock.patch.object(
agentic_rl_learner.rl_utils, "is_sharing_weights", return_value=False
):
learner = DummyLearner(
rl_cluster=rl_cluster,
reward_fns=mock.Mock(),
algo_config=algo_config,
)

# Create a dummy dataset where full_batch_size is 4
# This should fail since mini_batch_size (3) does not divide
# full_batch_size (4)
dummy_dataset = [{"prompts": ["a", "b", "c", "d"]}]

with self.assertRaisesRegex(
ValueError, "full_batch_size.*must be a multiple of.*mini_batch_size"
):
learner.train(train_dataset=dummy_dataset)


if __name__ == "__main__":
absltest.main()
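
Note on the test above: it asserts on the error raised by the batch-size validation in train(). The rl_utils.check_divisibility helper it exercises is not shown in this diff; the following is a minimal sketch consistent with the call sites and the asserted message (signature and wording inferred, not taken from the Tunix source):

def check_divisibility(divisor, total, divisor_name, total_name):
  # Hypothetical stand-in: raise if `total` is not a multiple of `divisor`.
  if total % divisor != 0:
    raise ValueError(
        f"{total_name} ({total}) must be a multiple of {divisor_name}"
        f" ({divisor})."
    )

# With the test's values, full_batch_size=4 and mini_batch_size=3:
# check_divisibility(3, 4, "mini_batch_size", "full_batch_size") raises
# ValueError, matching the assertRaisesRegex pattern above.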
tunix/rl/agentic/agentic_grpo_learner.py: 8 additions, 14 deletions
@@ -59,7 +59,7 @@
TrainExample = agentic_rl_learner.TrainExample


@dataclasses.dataclass(kw_only=True)
@dataclasses.dataclass(slots=True, kw_only=True)
class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
"""Configuration for GRPO algorithm.

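
Note: with slots=True, GRPOConfig instances reject attributes that are not declared dataclass fields. That is consistent with this diff also deleting the dynamic self.algo_config.temperature assignment in the next hunk (an inference from the diff; temperature may simply have become unused). A minimal demonstration of the slots behavior:

import dataclasses

@dataclasses.dataclass(slots=True, kw_only=True)
class Cfg:
  policy_loss_fn: str = "grpo"  # illustrative field, not the real config

cfg = Cfg()
cfg.policy_loss_fn = "ppo"  # declared field: assignment works
try:
  cfg.temperature = 0.7  # undeclared attribute: AttributeError under __slots__
except AttributeError as err:
  print(err)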
@@ -216,10 +216,6 @@ def __init__(
else:
logging.warning("Metrics log dir is None, skipping trajectory logging.")

self.algo_config.temperature = self.rl_cluster.get_rollout_config(
mode=rl_cluster_lib.Mode.TRAIN
).temperature

# Workaround to pass loss fn with algorithm flag
policy_loss_fn = function_registry.get_policy_loss_fn(
self.algo_config.policy_loss_fn
@@ -379,10 +375,6 @@ def _process_results(
dtype=old_logprobs.dtype,
)[:max_response_length]
)
else:
padded_old_logprobs.append(
np.zeros(max_response_length, dtype=np.float32)
)

prompt_ids = jnp.asarray(padded_prompt_ids)
prompt_mask = prompt_ids != pad_value
@@ -431,7 +423,7 @@ def _process_results(
completion_tokens=completion_ids,
pad_id=pad_value,
eos_id=eos_value,
micro_batch_size=None,
micro_batch_size=self._compute_logps_micro_batch_size,
)
interval_v2.async_end([ref_per_token_logps])
else:
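
The hunk above stops hardcoding micro_batch_size=None for the reference-model logprob pass and instead forwards the learner's configured value (see the _compute_logps_micro_batch_size handling in agentic_rl_learner.py below). As an illustration only, a micro-batched logprob computation typically chunks the batch like this; the real Tunix helper's signature may differ:

import jax.numpy as jnp

def chunked_per_token_logps(compute_fn, prompt_ids, completion_ids,
                            micro_batch_size=None):
  # None means one forward pass over the whole batch.
  if micro_batch_size is None:
    return compute_fn(prompt_ids, completion_ids)
  chunks = []
  for start in range(0, prompt_ids.shape[0], micro_batch_size):
    sl = slice(start, start + micro_batch_size)
    chunks.append(compute_fn(prompt_ids[sl], completion_ids[sl]))
  return jnp.concatenate(chunks, axis=0)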
@@ -512,11 +504,14 @@ def _process_results(
}

# Extract time metrics (env_time and reward_time)
for time_key in ["env_time", "reward_time"]:
prefix = f"trajectory/{time_key}"
for time_key, prefix in [
("env_time", "generation/trajectory/env_time"),
("reward_time", "generation/trajectory/reward_time"),
]:
time_dicts = [item.traj.get(time_key, {}) for item in trajectories]

# Safely gather all unique sub-keys (e.g., 'reset_latency') across all trajectories
# Safely gather all unique sub-keys (e.g., 'reset_latency') across all
# trajectories
for sub_key in {k for d in time_dicts for k in d.keys()}:
vals = [d.get(sub_key, 0.0) for d in time_dicts]
metrics_to_log.update({
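
The renamed metric keys above now live under a generation/trajectory/ prefix, and missing sub-keys default to 0.0 so every trajectory contributes a value. A small worked example of the sub-key union pattern:

time_dicts = [
    {"reset_latency": 0.10},
    {"reset_latency": 0.25, "step_latency": 0.40},
]
for sub_key in {k for d in time_dicts for k in d}:
  vals = [d.get(sub_key, 0.0) for d in time_dicts]
  # reset_latency -> [0.10, 0.25]; step_latency -> [0.0, 0.40]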
@@ -619,7 +614,6 @@ def grpo_loss_fn(
return_logits=True,
segment_ids=getattr(train_example, "segment_ids", None),
segment_positions=getattr(train_example, "segment_positions", None),
temperature=algo_config.temperature,
)
per_token_logps = jnp.astype(per_token_logps, jnp.float32)
# TODO(tsbao): We should handle token level advantages.
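
This hunk removes the explicit temperature argument from the per-token logprob call inside grpo_loss_fn, matching the deletion of the algo_config.temperature assignment earlier in this file. For background, a temperature parameter in this position usually scales logits before the log-softmax, as in this sketch (not Tunix code):

import jax
import jax.numpy as jnp

def per_token_logps_sketch(logits, target_ids, temperature=1.0):
  # temperature=1.0 is a no-op, so dropping the argument is equivalent to
  # always computing logprobs from unscaled logits.
  logps = jax.nn.log_softmax(logits / temperature, axis=-1)
  return jnp.take_along_axis(logps, target_ids[..., None], axis=-1)[..., 0]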
tunix/rl/agentic/agentic_rl_learner.py: 28 additions, 14 deletions
@@ -386,10 +386,7 @@ def _create_agent_env_pair(
return agent, env

def _model_call(
self,
chat_lists: List[Dict[str, str]],
env: Any = None,
max_generation_steps: int | None = None,
self, chat_lists: List[Dict[str, str]], env: Any = None,
) -> base_rollout.RolloutOutput:
"""Calls model generation."""
if env:
@@ -419,7 +416,6 @@ def _model_call(
apply_chat_template=False if self.chat_parser else True,
mode=rl_cluster_lib.Mode.TRAIN,
trace_tags=tags,
max_generation_steps=max_generation_steps,
)

return result
@@ -431,7 +427,6 @@ def _build_orchestrator(self) -> rollout_orchestrator.RolloutOrchestrator:
tokenizer=self.tokenizer,
chat_parser=self.chat_parser,
timeout=self.algo_config.episode_timeout,
max_response_length=self.algo_config.max_response_length,
overlong_filter=self.algo_config.overlong_filter,
filter_statuses=self.algo_config.filter_statuses,
perf_v2=self.rl_cluster.perf_v2,
@@ -689,19 +684,38 @@ def train(
train_micro_batch_size = (
self._training_config.train_micro_batch_size or mini_batch_size
)
# Rollout and compute_logps micro batch sizes have to be 1 since we only
# process inidividual prompts.
# Rollout micro batch size has to be 1 since we only process individual
# prompts.
self._rollout_micro_batch_size = 1
self._compute_logps_micro_batch_size = 1
for v, n in [
(self._rollout_micro_batch_size, f"{self._rollout_micro_batch_size=}"),
(
self._compute_logps_micro_batch_size,
f"{self._compute_logps_micro_batch_size=}",
),
(mini_batch_size, f"{mini_batch_size=}"),
]:
rl_utils.check_divisibility(v, full_batch_size, n, f"{full_batch_size=}")
if v is not None:
rl_utils.check_divisibility(
v, full_batch_size, n, f"{full_batch_size=}"
)

if self._compute_logps_micro_batch_size is not None:
if (
self._compute_logps_micro_batch_size
> self.algo_config.num_generations
):
logging.warning(
"compute_logps_micro_batch_size (%d) is larger than num_generations"
" (%d). In agentic workflows, compute_logps currently operates"
" within a single prompt's generation group. The effective micro"
" batch size will be capped at num_generations.",
self._compute_logps_micro_batch_size,
self.algo_config.num_generations,
)
else:
rl_utils.check_divisibility(
self._compute_logps_micro_batch_size,
self.algo_config.num_generations,
"compute_logps_micro_batch_size",
"num_generations",
)
grad_acc_steps = self._training_config.get_with_default(
"gradient_accumulation_steps", 1
)
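
Taken together, the new validation above treats compute_logps_micro_batch_size three ways: None skips the check, a value larger than num_generations only logs a warning (the effective micro batch is capped), and anything else must divide num_generations evenly. A compact, self-contained restatement of that decision tree (warning and error text paraphrased from this diff, not the library's exact strings):

def describe_logps_micro_batch(mbs, num_generations):
  if mbs is None:
    return "unset: no check performed"
  if mbs > num_generations:
    return f"warning: effective micro batch capped at {num_generations}"
  if num_generations % mbs != 0:
    raise ValueError(
        "num_generations must be a multiple of compute_logps_micro_batch_size"
    )
  return "ok"

assert describe_logps_micro_batch(None, 4).startswith("unset")
assert describe_logps_micro_batch(8, 4).startswith("warning")
assert describe_logps_micro_batch(2, 4) == "ok"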