diff --git a/docs/metrics.md b/docs/metrics.md
index d9e74a04c..b8917bb8b 100644
--- a/docs/metrics.md
+++ b/docs/metrics.md
@@ -99,7 +99,7 @@ options = metrics_logger.MetricsLoggerOptions(
 logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
 ```

-With the above, agentic_grpo_learner will by default start an async trajectory
+With the above, agentic_learner will by default start an async trajectory
 logger which logs the trajectories including prompts, responses, etc. to the
 specified `log_dir`.

diff --git a/docs/quickstart.md b/docs/quickstart.md
index 136a3457d..bc98cd29c 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -323,7 +323,7 @@ out_data = sampler(
 During reinforcement learning (RL) training, it is often useful to analyze the
 generated trajectories (prompts, responses, rewards, etc.). Tunix provides an
 `AsyncTrajectoryLogger` to log this data asynchronously to CSV files without
-blocking the training loop. It's enabled in agentic_grpo_learner by default, if
+blocking the training loop. It's enabled in agentic_learner by default, if
 you provide a log directory in your cluster configuration training config.

 ```python
diff --git a/examples/agentic/gemma_grpo_demo_nb.py b/examples/agentic/gemma_grpo_demo_nb.py
index c33930958..2472c657c 100644
--- a/examples/agentic/gemma_grpo_demo_nb.py
+++ b/examples/agentic/gemma_grpo_demo_nb.py
@@ -63,7 +63,7 @@
   from tunix.models.gemma import model as gemma_lib
   from tunix.sft import utils
   from tunix.utils import script_utils
-  from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
+  from tunix.rl.agentic.agentic_learner import GRPOConfig, GRPOLearner
   from flax import nnx
   from tunix.cli.utils import model as model_utils

diff --git a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh
index 7feb5b3d9..59fe8ff27 100755
--- a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh
+++ b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh
@@ -128,19 +128,19 @@ python -m tunix.cli.grpo_main \
   tokenizer_config.add_eos=false \
   \
   `# ── GRPO algorithm ───────────────────────────────────────────────────` \
-  agentic_grpo_config.num_generations=8 \
-  agentic_grpo_config.num_iterations=1 \
-  agentic_grpo_config.beta=0.0 \
-  agentic_grpo_config.epsilon=0.2 \
-  agentic_grpo_config.epsilon_high=0.28 \
-  agentic_grpo_config.system_prompt="" \
-  agentic_grpo_config.max_concurrency=1024 \
-  agentic_grpo_config.max_response_length="$max_response_length" \
-  agentic_grpo_config.off_policy_steps=0 \
-  agentic_grpo_config.loss_agg_mode="token-mean" \
-  agentic_grpo_config.kl_loss_mode="low_var_kl" \
-  agentic_grpo_config.max_turns=1 \
-  agentic_grpo_config.context_ratio=1 \
+  agentic_config.num_generations=8 \
+  agentic_config.num_iterations=1 \
+  agentic_config.beta=0.0 \
+  agentic_config.epsilon=0.2 \
+  agentic_config.epsilon_high=0.28 \
+  agentic_config.system_prompt="" \
+  agentic_config.max_concurrency=1024 \
+  agentic_config.max_response_length="$max_response_length" \
+  agentic_config.off_policy_steps=0 \
+  agentic_config.loss_agg_mode="token-mean" \
+  agentic_config.kl_loss_mode="low_var_kl" \
+  agentic_config.max_turns=1 \
+  agentic_config.context_ratio=1 \
   \
   `# ── Optimizer ────────────────────────────────────────────────────────` \
   rl_training_config.actor_optimizer_config.opt_type="adamw" \
diff --git a/examples/deepscaler/train_deepscaler_nb.py b/examples/deepscaler/train_deepscaler_nb.py
index 8cc903077..6d0bdb1e9 100644
--- a/examples/deepscaler/train_deepscaler_nb.py
+++ b/examples/deepscaler/train_deepscaler_nb.py
@@ -60,7 +60,7 @@
   from tunix.models.qwen2 import params as params_lib
   from tunix.models.qwen2 import model as model_lib
   from tunix.sft import metrics_logger
-  from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
+  from tunix.rl.agentic.agentic_learner import GRPOConfig, GRPOLearner
   from tunix.rl.agentic.agents import model_agent
   from tunix.rl.agentic.environments import task_environment
   from tunix.rl.agentic.trajectory import trajectory_collect_engine
diff --git a/examples/deepswe/run_deepswe_disagg_v5p_32.sh b/examples/deepswe/run_deepswe_disagg_v5p_32.sh
index 4eceb7ae2..bff865ef4 100755
--- a/examples/deepswe/run_deepswe_disagg_v5p_32.sh
+++ b/examples/deepswe/run_deepswe_disagg_v5p_32.sh
@@ -149,21 +149,21 @@ python -m tunix.cli.grpo_main \
   kubernetes_config.node_selector_val="deepswe-cpu-pool" \
   \
   `# ── Agentic / multi-turn ─────────────────────────────────────────────` \
-  agentic_grpo_config.max_turns=20 \
-  agentic_grpo_config.per_turn_timeout_secs=300 \
-  agentic_grpo_config.context_ratio=2 \
-  agentic_grpo_config.max_concurrency=100 \
+  agentic_config.max_turns=20 \
+  agentic_config.per_turn_timeout_secs=300 \
+  agentic_config.context_ratio=2 \
+  agentic_config.max_concurrency=100 \
   \
   `# ── GRPO algorithm ───────────────────────────────────────────────────` \
-  agentic_grpo_config.num_generations="$num_generations" \
-  agentic_grpo_config.max_response_length="$max_response_length" \
-  agentic_grpo_config.num_iterations=1 \
-  agentic_grpo_config.beta=0.001 \
-  agentic_grpo_config.epsilon=0.2 \
-  agentic_grpo_config.epsilon_high=0.28 \
-  agentic_grpo_config.off_policy_steps=0 \
-  agentic_grpo_config.loss_agg_mode="seq-mean-token-mean" \
-  agentic_grpo_config.kl_loss_mode="low_var_kl" \
+  agentic_config.num_generations="$num_generations" \
+  agentic_config.max_response_length="$max_response_length" \
+  agentic_config.num_iterations=1 \
+  agentic_config.beta=0.001 \
+  agentic_config.epsilon=0.2 \
+  agentic_config.epsilon_high=0.28 \
+  agentic_config.off_policy_steps=0 \
+  agentic_config.loss_agg_mode="seq-mean-token-mean" \
+  agentic_config.kl_loss_mode="low_var_kl" \
   \
   `# ── Optimizer ────────────────────────────────────────────────────────` \
   rl_training_config.actor_optimizer_config.opt_type="adamw" \
diff --git a/examples/deepswe/train_deepswe_nb.py b/examples/deepswe/train_deepswe_nb.py
index c2623578d..6c07ff450 100644
--- a/examples/deepswe/train_deepswe_nb.py
+++ b/examples/deepswe/train_deepswe_nb.py
@@ -171,7 +171,7 @@
 from tunix.sft import metrics_logger
 from tunix.rl import rl_cluster as rl_cluster_lib
 from tunix.rl.rollout import base_rollout
-from tunix.rl.agentic import agentic_grpo_learner
+from tunix.rl.agentic import agentic_learner
 from tunix.rl.agentic.parser.chat_template_parser import parser as template_parser
 from tunix.rl.agentic.rewards.reward_types import RewardOutput
 from examples.deepswe.swe_agent import (
@@ -564,7 +564,7 @@ def transform(entry):
 # ==========================================
 # 11. Learner & Agent Setup
 # ==========================================
-grpo_config = agentic_grpo_learner.GRPOConfig(
+grpo_config = agentic_learner.GRPOConfig(
     num_generations=NUM_GENERATIONS,
     num_iterations=NUM_ITERATIONS,
     max_response_length=MAX_RESPONSE_LENGTH,
@@ -579,7 +579,7 @@ def transform(entry):
 )


-agentic_grpo_learner = agentic_grpo_learner.GRPOLearner(
+agentic_learner = agentic_learner.GRPOLearner(
     rl_cluster=rl_cluster,
     reward_fns=None,
     agent_class=SWEAgent,
@@ -652,7 +652,7 @@ def mixed_type_batch_fn(elements):
 print("Starting training...")


-agentic_grpo_learner.train(train_dataset=train_dataset)
+agentic_learner.train(train_dataset=train_dataset)


 # %%
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
index dfc5e7c11..8c75a8e5e 100755
--- a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
@@ -130,15 +130,15 @@ python -m tunix.cli.grpo_main \
   tokenizer_config.add_eos=false \
   \
   `# -- GRPO algorithm ---------------------------------------------------` \
-  agentic_grpo_config.num_generations="$num_generations" \
-  agentic_grpo_config.num_iterations=1 \
-  agentic_grpo_config.beta=0.08 \
-  agentic_grpo_config.epsilon=0.2 \
-  agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \
-  agentic_grpo_config.max_concurrency=128 \
-  agentic_grpo_config.max_response_length=768 \
-  agentic_grpo_config.max_turns=1 \
-  agentic_grpo_config.context_ratio=1 \
+  agentic_config.num_generations="$num_generations" \
+  agentic_config.num_iterations=1 \
+  agentic_config.beta=0.08 \
+  agentic_config.epsilon=0.2 \
+  agentic_config.system_prompt="You are given a grade school math problem. Think step by step and respond using ... followed by ... with only the final numeric answer inside ." \
+  agentic_config.max_concurrency=128 \
+  agentic_config.max_response_length=768 \
+  agentic_config.max_turns=1 \
+  agentic_config.context_ratio=1 \
   \
   `# -- Optimizer --------------------------------------------------------` \
   rl_training_config.actor_optimizer_config.opt_type="adamw" \
diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py
index b68d98117..5cd44cccc 100644
--- a/tests/cli/grpo_main_test.py
+++ b/tests/cli/grpo_main_test.py
@@ -224,7 +224,7 @@ def test_agentic_data_module_receives_data_config_for_raw_dataset(self):
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
@@ -277,7 +277,7 @@ def test_agentic_nullable_string_can_be_overridden_from_cli(self):
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
@@ -321,7 +321,7 @@ def test_agentic_nullable_dict_can_be_overridden_from_cli(self):
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
@@ -371,7 +371,7 @@ def test_agentic_nullable_string_can_be_overridden_from_env(self):
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
@@ -417,7 +417,7 @@ def test_standard_grpo_dispatches_to_standard(self):
     pipeline.run_grpo_trainer()
     mock_run.assert_called_once_with(mode="grpo")

-  def test_agentic_grpo_dispatches_to_agentic(self):
+  def test_agentic_dispatches_to_agentic(self):
     extra = """
 training_mode: "agentic_grpo"
 data_module: "tunix.cli.recipes.deepscaler_data"
@@ -436,7 +436,7 @@ def test_agentic_dispatches_to_agentic(self):
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
@@ -530,7 +530,7 @@ def _make_agentic_pipeline(self, max_turns, context_ratio):
 env_class_path: null
 env_kwargs: {{}}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
@@ -603,7 +603,7 @@ def _base_extra(self, agentic_overrides="", system_prompt='""'):
 env_class_path: null
 env_kwargs: {{}}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.001
@@ -623,26 +623,26 @@ def test_episode_timeout_computed(self):
     p = _make_pipeline(
         self._base_extra("max_turns: 20\n  per_turn_timeout_secs: 300")
     )
-    algo = p._create_agentic_grpo_config()
+    algo = p._create_agentic_config()
     self.assertEqual(algo.episode_timeout, 300 * 20)

   def test_max_response_length_from_rollout(self):
     p = _make_pipeline(self._base_extra("max_turns: 1"))
-    algo = p._create_agentic_grpo_config()
+    algo = p._create_agentic_config()
     # rollout_config.total_generation_steps = 512
     self.assertEqual(algo.max_response_length, 512)

   def test_num_generations_passed_through(self):
     p = _make_pipeline(self._base_extra("max_turns: 1"))
-    algo = p._create_agentic_grpo_config()
+    algo = p._create_agentic_config()
     self.assertEqual(algo.num_generations, 2)

   def test_cli_empty_system_prompt_stays_empty_string(self):
     p = _make_pipeline_with_cli_args(
         self._base_extra("max_turns: 1", system_prompt='"base"'),
-        ['agentic_grpo_config.system_prompt=""'],
+        ['agentic_config.system_prompt=""'],
     )
-    self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "")
+    self.assertEqual(p.config["agentic_config"]["system_prompt"], "")


 class SplitMeshConfigTest(absltest.TestCase):
@@ -665,7 +665,7 @@ def test_split_mesh_uses_explicit_role_meshes(self):
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
   num_generations: 2
   num_iterations: 1
   beta: 0.0
diff --git a/tests/rl/agentic/agentic_grpo_learner_test.py b/tests/rl/agentic/agentic_learner_test.py
similarity index 96%
rename from tests/rl/agentic/agentic_grpo_learner_test.py
rename to tests/rl/agentic/agentic_learner_test.py
index 571c37073..520fb790e 100644
--- a/tests/rl/agentic/agentic_grpo_learner_test.py
+++ b/tests/rl/agentic/agentic_learner_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Tests for agentic_grpo_learner."""
+"""Tests for agentic_learner."""

 import asyncio
 import functools
@@ -42,7 +42,7 @@
 from tunix.rl import common as rl_common
 from tunix.rl import function_registry
 from tunix.rl import rl_cluster as rl_cluster_lib
-from tunix.rl.agentic import agentic_grpo_learner
+from tunix.rl.agentic import agentic_learner
 from tunix.rl.agentic.agents.agent_types import Action, Step
 from tunix.rl.agentic.agents.base_agent import ConversationAgentBase
 from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
@@ -55,7 +55,7 @@
 os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

 Mesh = sharding.Mesh
-TrainingInputT = agentic_grpo_learner.TrainingInputT
+TrainingInputT = agentic_learner.TrainingInputT


 def reward_fn_1(prompts, completions, **kwargs):
@@ -194,7 +194,7 @@ def assistant_token(self):
     return "Assistant: "


-class _LearnerWithException(agentic_grpo_learner.GRPOLearner):
+class _LearnerWithException(agentic_learner.GRPOLearner):

   def _batch_to_train_example(self, batch_results, mode):
     raise ValueError("test exception in producer")
@@ -218,7 +218,7 @@ def setUp(self):
     )

   def test_iterator(self):
-    class _MockTrainer(agentic_grpo_learner.GRPOLearner):
+    class _MockTrainer(agentic_learner.GRPOLearner):

       def __init__(self, algo_config):
         self.algo_config = algo_config
@@ -276,7 +276,7 @@ async def _orchestrator_producer(
           yield group, [example]
           i += 1

-    algo_config = agentic_grpo_learner.GRPOConfig(
+    algo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=2,
     )
@@ -305,11 +305,11 @@ def test_grpo_config_validation(self):
     with self.assertRaisesRegex(
         ValueError, "num_generations must be greater than 1"
     ):
-      agentic_grpo_learner.GRPOConfig(num_generations=1)
+      agentic_learner.GRPOConfig(num_generations=1)
     with self.assertRaisesRegex(
         ValueError, "loss_algo should be either grpo or gspo-token"
     ):
-      agentic_grpo_learner.GRPOConfig(loss_algo="invalid")
+      agentic_learner.GRPOConfig(loss_algo="invalid")

   def test_num_iterations_greater_than_1(self):
     vocab = test_common.MockVocab()
@@ -354,13 +354,13 @@ def test_num_iterations_greater_than_1(self):
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=2,  # > 1
         loss_algo="grpo",
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -411,7 +411,7 @@ def test_grpo_loss_fn(self, loss_algo):
         (batch_size, seq_len), -0.1, dtype=jnp.float32
     )

-    train_example = agentic_grpo_learner.TrainExample(
+    train_example = agentic_learner.TrainExample(
         prompt_ids=prompt_ids,
         prompt_mask=prompt_ids > -1,
         completion_ids=completion_ids,
@@ -436,7 +436,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
         None,
     )

-    algo_config = agentic_grpo_learner.GRPOConfig(
+    algo_config = agentic_learner.GRPOConfig(
         beta=0.1,
         epsilon=0.2,
         loss_algo=loss_algo,
@@ -508,7 +508,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
     # Unmasked example
     final_completion_mask = completion_mask

-    train_example = agentic_grpo_learner.TrainExample(
+    train_example = agentic_learner.TrainExample(
         prompt_ids=prompt_ids,
         prompt_mask=prompt_ids > -1,
         completion_ids=completion_ids,
@@ -518,7 +518,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
         old_per_token_logps=None,
     )

-    config = agentic_grpo_learner.GRPOConfig(
+    config = agentic_learner.GRPOConfig(
         beta=0.1,
         epsilon=0.2,
         num_generations=2,
@@ -624,7 +624,7 @@ def mock_compute_rewards(prompts, completions, **kwargs):
         tokenizer=tokenizer,
         cluster_config=cluster_config,
     )
-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         beta=0.1,
         epsilon=0.2,
         num_generations=2,
@@ -632,7 +632,7 @@ def mock_compute_rewards(prompts, completions, **kwargs):
         num_iterations=1,
         max_response_length=10,
     )
-    learner = agentic_grpo_learner.GRPOLearner(
+    learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=None,
         algo_config=grpo_config,
@@ -724,7 +724,7 @@ def __init__(
         tokenizer=tokenizer,
         cluster_config=cluster_config,
     )
-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         beta=0.1,
         epsilon=0.2,
         num_generations=2,
@@ -733,7 +733,7 @@ def __init__(
         max_response_length=10,
     )

-    learner = agentic_grpo_learner.GRPOLearner(
+    learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=None,
         algo_config=grpo_config,
@@ -836,12 +836,12 @@ def create_learner(
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -938,12 +938,12 @@ def create_learner(
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn,
         algo_config=grpo_config,
@@ -1045,12 +1045,12 @@ def create_learner(
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn,
         algo_config=grpo_config,
@@ -1144,14 +1144,14 @@ def test_force_compute_kl(self, beta, force_compute_kl, expect_ref_logps):
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         beta=beta,
         force_compute_kl=force_compute_kl,
         max_response_length=10,
         num_generations=2,
         num_iterations=1,
     )
-    learner = agentic_grpo_learner.GRPOLearner(
+    learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -1233,7 +1233,7 @@ def test_exception_handling(self):
         tokenizer=tokenizer,
         cluster_config=cluster_config,
     )
-    grpo_config = agentic_grpo_learner.GRPOConfig(max_response_length=10)
+    grpo_config = agentic_learner.GRPOConfig(max_response_length=10)
     learner = _LearnerWithException(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
@@ -1310,13 +1310,13 @@ def test_grpo_learner(self, reward_fns, loss_algo, use_old_logprobs=False):
     )
     rl_cluster.with_external_metrics_logger(print)

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         loss_algo=loss_algo,
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fns,
         algo_config=grpo_config,
@@ -1466,14 +1466,14 @@ def test_on_off_policy_training(self, offpolicy_steps):
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         loss_algo="grpo",
         off_policy_steps=offpolicy_steps,
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -1534,8 +1534,8 @@ def test_put_prompts_to_queue(self):
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(max_response_length=512)
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_config = agentic_learner.GRPOConfig(max_response_length=512)
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -1600,13 +1600,13 @@ def test_trajectory_logging(self):
         cluster_config=cluster_config,
     )

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         loss_algo="grpo",
         max_response_length=10,
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -1695,13 +1695,13 @@ def test_grpo_with_lora_model(self):
         tokenizer=tokenizer,
         cluster_config=cluster_config,
     )
-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         max_response_length=10,
     )

-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -1819,14 +1819,14 @@ def update_from_model(self, response, **kwargs):
     )
     rl_cluster.with_external_metrics_logger(print)

-    grpo_config = agentic_grpo_learner.GRPOConfig(
+    grpo_config = agentic_learner.GRPOConfig(
         num_generations=2,
         num_iterations=1,
         loss_algo="grpo",
         max_response_length=128,
         max_concurrency=1,  # so the output is deterministic.
     )
-    grpo_learner = agentic_grpo_learner.GRPOLearner(
+    grpo_learner = agentic_learner.GRPOLearner(
         rl_cluster=rl_cluster,
         reward_fns=reward_fn_1,
         algo_config=grpo_config,
@@ -1901,7 +1901,7 @@ def _patch_process_results(

   def test_compute_rloo_advantages(self):
     rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-    advantages = agentic_grpo_learner.compute_rloo_advantages(
+    advantages = agentic_learner.compute_rloo_advantages(
         rewards, num_generations=3
     )
     expected_value = jnp.array([-1.5, 0.0, 1.5, -1.5, 0.0, 1.5])
@@ -1909,7 +1909,7 @@ def test_compute_rloo_advantages(self):

   def test_compute_rloo_advantages_low_generations(self):
     rewards = jnp.array([1.0, 2.0])
-    advantages = agentic_grpo_learner.compute_rloo_advantages(
+    advantages = agentic_learner.compute_rloo_advantages(
         rewards, num_generations=1
     )
     np.testing.assert_allclose(advantages, jnp.zeros_like(rewards))
diff --git a/tunix/__init__.py b/tunix/__init__.py
index 7ddbffa7f..a91cd3149 100644
--- a/tunix/__init__.py
+++ b/tunix/__init__.py
@@ -33,8 +33,8 @@
 from tunix.perf.export import PerfMetricsExport
 from tunix.perf.metrics import PerfMetricsConfig
 from tunix.perf.metrics import PerfSpanQuery
-from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig as AgenticGRPOConfig
-from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner as AgenticGRPOLearner
+from tunix.rl.agentic.agentic_learner import GRPOConfig as AgenticGRPOConfig
+from tunix.rl.agentic.agentic_learner import GRPOLearner as AgenticGRPOLearner
 from tunix.rl.grpo.grpo_learner import GRPOConfig
 from tunix.rl.grpo.grpo_learner import GrpoConfig
 from tunix.rl.grpo.grpo_learner import GRPOLearner
diff --git a/tunix/cli/base_agentic_config.yaml b/tunix/cli/base_agentic_config.yaml
index dca76e0cd..a9637611e 100644
--- a/tunix/cli/base_agentic_config.yaml
+++ b/tunix/cli/base_agentic_config.yaml
@@ -231,7 +231,7 @@
 env_class_path: null
 env_kwargs: {}
 kubernetes_config: null
-agentic_grpo_config: {}
+agentic_config: {}


 ############################# Reward Fns #############################
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 3491b885b..97a762b2e 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -69,7 +69,7 @@ class GrpoPipeline(config.HyperParameters):

   ``training_mode: "agentic_grpo"`` — multi-turn agentic GRPO using
   GRPOLearner. Additional config sections are recognised:
-  * ``agentic_grpo_config``: GRPOConfig fields (num_generations, beta, …)
+  * ``agentic_config``: GRPOConfig fields (num_generations, beta, …)
     plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``.
   * role-specific ``*_model_config.mesh``: any role with an explicit mesh
     gets its own device slice; omitted meshes share the actor mesh by default.
@@ -280,7 +280,7 @@ def create_rollout_config(
     max_response = rollout_cfg.get("total_generation_steps", 0)

     if mode == "agentic_grpo":
-      agentic_cfg = self.config.get("agentic_grpo_config", {})
+      agentic_cfg = self.config.get("agentic_config", {})
       max_turns = agentic_cfg.get("max_turns", 1)
       context_ratio = agentic_cfg.get("context_ratio", 1)
       if max_turns > 1:
@@ -637,11 +637,11 @@ def _get_dataset(self, tokenizer):
   # ------------------------------------------------------------------
   # Agentic GRPO helpers
   # ------------------------------------------------------------------
-  def _create_agentic_grpo_config(self):
-    """Build GRPOConfig (agentic) from the agentic_grpo_config YAML section."""
-    from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig  # pylint: disable=g-import-not-at-top
+  def _create_agentic_config(self):
+    """Build GRPOConfig (agentic) from the agentic_config YAML section."""
+    from tunix.rl.agentic.agentic_learner import GRPOConfig  # pylint: disable=g-import-not-at-top

-    cfg = dict(self.config.get("agentic_grpo_config", {}))
+    cfg = dict(self.config.get("agentic_config", {}))

     # episode_timeout = per_turn_timeout_secs * max_turns when not explicit
     if "episode_timeout" not in cfg:
@@ -758,8 +758,8 @@ def _run(self, mode: str = "grpo"):
     if mode != "agentic_grpo":
       raise ValueError(f"Unsupported training_mode {mode!r}")

-    from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner  # pylint: disable=g-import-not-at-top
-    algo_config = self._create_agentic_grpo_config()
+    from tunix.rl.agentic.agentic_learner import GRPOLearner  # pylint: disable=g-import-not-at-top
+    algo_config = self._create_agentic_config()

     reward_fns = (
         self.obtain_reward_fn() if self.config.get("reward_functions") else None
    )
diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_learner.py
similarity index 100%
rename from tunix/rl/agentic/agentic_grpo_learner.py
rename to tunix/rl/agentic/agentic_learner.py
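
For downstream users the rename is purely mechanical: the module `tunix.rl.agentic.agentic_grpo_learner` becomes `tunix.rl.agentic.agentic_learner`, and `GRPOConfig`/`GRPOLearner` keep their existing signatures. A minimal before/after sketch, with argument values borrowed from the tests in this patch:

# Before this patch:
#   from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
# After this patch (only the module path changes):
from tunix.rl.agentic.agentic_learner import GRPOConfig, GRPOLearner

# Construction is unchanged; these values mirror agentic_learner_test.py.
grpo_config = GRPOConfig(
    num_generations=2,
    num_iterations=1,
    max_response_length=512,
)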
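
The CLI surface follows the same pattern: the YAML section `agentic_grpo_config:` becomes `agentic_config:`, and the renamed `_create_agentic_config()` in tunix/cli/grpo_main.py still derives `episode_timeout` as `per_turn_timeout_secs * max_turns` when it is not set explicitly (grpo_main_test.py asserts 300 * 20 above). A standalone sketch of that derivation follows; `derive_episode_timeout` and its fallback values are hypothetical, only the multiplication mirrors the patched code:

def derive_episode_timeout(agentic_cfg: dict) -> int:
    # Mirrors the default in GrpoPipeline._create_agentic_config:
    # episode_timeout = per_turn_timeout_secs * max_turns when not explicit.
    if "episode_timeout" in agentic_cfg:
        return agentic_cfg["episode_timeout"]
    # The fallbacks (300 secs, 1 turn) are illustrative assumptions.
    return agentic_cfg.get("per_turn_timeout_secs", 300) * agentic_cfg.get("max_turns", 1)

# Matches the expectation in test_episode_timeout_computed:
assert derive_episode_timeout({"max_turns": 20, "per_turn_timeout_secs": 300}) == 6000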