diff --git a/tests/rl/agentic/agentic_rl_learner_test.py b/tests/rl/agentic/agentic_rl_learner_test.py
index a0aceca07..1d6fe4503 100644
--- a/tests/rl/agentic/agentic_rl_learner_test.py
+++ b/tests/rl/agentic/agentic_rl_learner_test.py
@@ -18,6 +18,7 @@

 from absl.testing import absltest
 from absl.testing import parameterized
+from tunix.rl import rl_cluster as rl_cluster_lib
 from tunix.rl.agentic import agentic_rl_learner
 from tunix.rl.rollout import base_rollout

@@ -135,6 +136,58 @@ def test_validate_rollout_config_vllm_missing_server_mode(self):
         algo_config=algo_config,
     )

+  def test_train_raises_value_error_on_indivisible_batch_size(self):
+    rl_cluster = mock.Mock()
+    rl_cluster.cluster_config = mock.Mock()
+    rl_cluster.cluster_config.rollout_engine = "generic"
+    rl_cluster.cluster_config.rollout_config = base_rollout.RolloutConfig(
+        max_prompt_length=32,
+        max_tokens_to_generate=10,
+        return_logprobs=True,
+    )
+
+    # Configure roles
+    rl_cluster.cluster_config.role_to_mesh = {
+        rl_cluster_lib.Role.ACTOR: mock.Mock(),
+        rl_cluster_lib.Role.ROLLOUT: mock.Mock(),
+    }
+
+    # Training config
+    training_config = mock.Mock()
+    training_config.mini_batch_size = 3
+    training_config.train_micro_batch_size = 3
+    training_config.rollout_micro_batch_size = 1
+    training_config.compute_logps_micro_batch_size = None
+    rl_cluster.cluster_config.training_config = training_config
+
+    rl_cluster.actor_trainer = mock.Mock()
+    rl_cluster.actor_trainer.restored_global_step.return_value = 0
+    rl_cluster.actor_trainer.iter_steps = 0
+
+    algo_config = agentic_rl_learner.AgenticRLConfig(
+        max_response_length=10,
+    )
+
+    # Patch is_sharing_weights so we don't have to mock JAX structures
+    with mock.patch.object(
+        agentic_rl_learner.rl_utils, "is_sharing_weights", return_value=False
+    ):
+      learner = DummyLearner(
+          rl_cluster=rl_cluster,
+          reward_fns=mock.Mock(),
+          algo_config=algo_config,
+      )
+
+    # Create a dummy dataset where full_batch_size is 4
+    # This should fail since mini_batch_size (3) does not divide
+    # full_batch_size (4)
+    dummy_dataset = [{"prompts": ["a", "b", "c", "d"]}]
+
+    with self.assertRaisesRegex(
+        ValueError, "full_batch_size.*must be a multiple of.*mini_batch_size"
+    ):
+      learner.train(train_dataset=dummy_dataset)
+

 if __name__ == "__main__":
   absltest.main()
diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py
index 1e635d9b9..04cd8ceb8 100644
--- a/tunix/rl/agentic/agentic_grpo_learner.py
+++ b/tunix/rl/agentic/agentic_grpo_learner.py
@@ -59,7 +59,7 @@
 TrainExample = agentic_rl_learner.TrainExample


-@dataclasses.dataclass(kw_only=True)
+@dataclasses.dataclass(slots=True, kw_only=True)
 class GRPOConfig(agentic_rl_learner.AgenticRLConfig):
   """Configuration for GRPO algorithm.

@@ -216,10 +216,6 @@ def __init__(
     else:
       logging.warning("Metrics log dir is None, skipping trajectory logging.")

-    self.algo_config.temperature = self.rl_cluster.get_rollout_config(
-        mode=rl_cluster_lib.Mode.TRAIN
-    ).temperature
-
     # Workaround to pass loss fn with algorithm flag
     policy_loss_fn = function_registry.get_policy_loss_fn(
         self.algo_config.policy_loss_fn
@@ -379,10 +375,6 @@ def _process_results(
                 dtype=old_logprobs.dtype,
             )[:max_response_length]
         )
-      else:
-        padded_old_logprobs.append(
-            np.zeros(max_response_length, dtype=np.float32)
-        )

     prompt_ids = jnp.asarray(padded_prompt_ids)
     prompt_mask = prompt_ids != pad_value
@@ -431,7 +423,7 @@ def _process_results(
           completion_tokens=completion_ids,
           pad_id=pad_value,
           eos_id=eos_value,
-          micro_batch_size=None,
+          micro_batch_size=self._compute_logps_micro_batch_size,
       )
       interval_v2.async_end([ref_per_token_logps])
     else:
@@ -512,11 +504,14 @@ def _process_results(
     }

     # Extract time metrics (env_time and reward_time)
-    for time_key in ["env_time", "reward_time"]:
-      prefix = f"trajectory/{time_key}"
+    for time_key, prefix in [
+        ("env_time", "generation/trajectory/env_time"),
+        ("reward_time", "generation/trajectory/reward_time"),
+    ]:
       time_dicts = [item.traj.get(time_key, {}) for item in trajectories]

-      # Safely gather all unique sub-keys (e.g., 'reset_latency') across all trajectories
+      # Safely gather all unique sub-keys (e.g., 'reset_latency') across all
+      # trajectories
       for sub_key in {k for d in time_dicts for k in d.keys()}:
         vals = [d.get(sub_key, 0.0) for d in time_dicts]
         metrics_to_log.update({
@@ -619,7 +614,6 @@ def grpo_loss_fn(
       return_logits=True,
       segment_ids=getattr(train_example, "segment_ids", None),
       segment_positions=getattr(train_example, "segment_positions", None),
-      temperature=algo_config.temperature,
   )
   per_token_logps = jnp.astype(per_token_logps, jnp.float32)
   # TODO(tsbao): We should handle token level advantages.
diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py
index 54c8ebfa3..8256d1331 100644
--- a/tunix/rl/agentic/agentic_rl_learner.py
+++ b/tunix/rl/agentic/agentic_rl_learner.py
@@ -386,10 +386,7 @@ def _create_agent_env_pair(
     return agent, env

   def _model_call(
-      self,
-      chat_lists: List[Dict[str, str]],
-      env: Any = None,
-      max_generation_steps: int | None = None,
+      self, chat_lists: List[Dict[str, str]], env: Any = None,
   ) -> base_rollout.RolloutOutput:
     """Calls model generation."""
     if env:
@@ -419,7 +416,6 @@ def _model_call(
         apply_chat_template=False if self.chat_parser else True,
         mode=rl_cluster_lib.Mode.TRAIN,
         trace_tags=tags,
-        max_generation_steps=max_generation_steps,
     )
     return result

@@ -431,7 +427,6 @@ def _build_orchestrator(self) -> rollout_orchestrator.RolloutOrchestrator:
         tokenizer=self.tokenizer,
         chat_parser=self.chat_parser,
         timeout=self.algo_config.episode_timeout,
-        max_response_length=self.algo_config.max_response_length,
         overlong_filter=self.algo_config.overlong_filter,
         filter_statuses=self.algo_config.filter_statuses,
         perf_v2=self.rl_cluster.perf_v2,
@@ -689,19 +684,38 @@ def train(
     train_micro_batch_size = (
         self._training_config.train_micro_batch_size or mini_batch_size
     )
-    # Rollout and compute_logps micro batch sizes have to be 1 since we only
-    # process inidividual prompts.
+    # Rollout micro batch size has to be 1 since we only process individual
+    # prompts.
     self._rollout_micro_batch_size = 1
-    self._compute_logps_micro_batch_size = 1
     for v, n in [
         (self._rollout_micro_batch_size, f"{self._rollout_micro_batch_size=}"),
-        (
-            self._compute_logps_micro_batch_size,
-            f"{self._compute_logps_micro_batch_size=}",
-        ),
         (mini_batch_size, f"{mini_batch_size=}"),
     ]:
-      rl_utils.check_divisibility(v, full_batch_size, n, f"{full_batch_size=}")
+      if v is not None:
+        rl_utils.check_divisibility(
+            v, full_batch_size, n, f"{full_batch_size=}"
+        )
+
+    if self._compute_logps_micro_batch_size is not None:
+      if (
+          self._compute_logps_micro_batch_size
+          > self.algo_config.num_generations
+      ):
+        logging.warning(
+            "compute_logps_micro_batch_size (%d) is larger than num_generations"
+            " (%d). In agentic workflows, compute_logps currently operates"
+            " within a single prompt's generation group. The effective micro"
+            " batch size will be capped at num_generations.",
+            self._compute_logps_micro_batch_size,
+            self.algo_config.num_generations,
+        )
+      else:
+        rl_utils.check_divisibility(
+            self._compute_logps_micro_batch_size,
+            self.algo_config.num_generations,
+            "compute_logps_micro_batch_size",
+            "num_generations",
+        )
     grad_acc_steps = self._training_config.get_with_default(
         "gradient_accumulation_steps", 1
     )
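Note for reviewers: the `train()` hunk above validates the per-batch sizes against `full_batch_size`, then handles `compute_logps_micro_batch_size` against `num_generations` (warn-and-cap when it exceeds the generation group, enforce divisibility otherwise). The snippet below is a minimal, self-contained sketch of that validation flow; `check_divisibility` here is a hypothetical stand-in for `tunix.rl.rl_utils.check_divisibility` (the diff only shows its call sites and that it raises a `ValueError` matching `full_batch_size.*must be a multiple of.*mini_batch_size`, so the exact signature and message are assumptions).

```python
import logging


def check_divisibility(divisor, total, divisor_name, total_name):
  # Hypothetical stand-in: argument order mirrors the call sites in the diff
  # (divisor, total, divisor_name, total_name); the real helper may differ.
  if total % divisor != 0:
    raise ValueError(f"{total_name} must be a multiple of {divisor_name}.")


def validate_micro_batch_sizes(
    full_batch_size,
    mini_batch_size,
    rollout_micro_batch_size,
    compute_logps_micro_batch_size,
    num_generations,
):
  """Sketch of the validation order introduced in train()."""
  # Rollout micro batch size and mini batch size must tile the full batch;
  # None values are skipped, matching the new `if v is not None` guard.
  for v, n in [
      (rollout_micro_batch_size, "rollout_micro_batch_size"),
      (mini_batch_size, "mini_batch_size"),
  ]:
    if v is not None:
      check_divisibility(v, full_batch_size, n, "full_batch_size")

  # compute_logps operates within a single prompt's generation group, so its
  # micro batch size is checked against num_generations instead.
  if compute_logps_micro_batch_size is not None:
    if compute_logps_micro_batch_size > num_generations:
      # Oversized request: effective micro batch size is capped, not an error.
      logging.warning(
          "compute_logps_micro_batch_size (%d) capped at num_generations (%d).",
          compute_logps_micro_batch_size,
          num_generations,
      )
    else:
      check_divisibility(
          compute_logps_micro_batch_size,
          num_generations,
          "compute_logps_micro_batch_size",
          "num_generations",
      )


validate_micro_batch_sizes(8, 4, 1, 2, 4)  # Passes: 4 | 8, 1 | 8, 2 | 4.

try:
  # Mirrors the new test: mini_batch_size=3 does not divide full_batch_size=4.
  validate_micro_batch_sizes(4, 3, 1, None, 4)
except ValueError as e:
  print(e)  # full_batch_size must be a multiple of mini_batch_size.
```

This also shows why the test constructs a four-prompt dataset with `mini_batch_size = 3`: the mini batch check is the first divisibility constraint hit, so the `ValueError` surfaces before any rollout work happens.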