diff --git a/docs/metrics.md b/docs/metrics.md
index d9e74a04c..b8917bb8b 100644
--- a/docs/metrics.md
+++ b/docs/metrics.md
@@ -99,7 +99,7 @@ options = metrics_logger.MetricsLoggerOptions(
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
```
-With the above, agentic_grpo_learner will by default start an async trajectory
+With the above, `agentic_learner` will by default start an async trajectory
logger that logs the trajectories (prompts, responses, etc.) to the
specified `log_dir`.
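A minimal sketch of the setup this docs/metrics.md passage describes: only `MetricsLoggerOptions`, `MetricsLogger`, and `log_dir` come from the snippet in the hunk above; any other fields or defaults would be assumptions about the `metrics_logger` API.

```python
# Sketch: point the logger at a log_dir so agentic_learner's async
# trajectory logger has somewhere to write prompts, responses, etc.
# Only log_dir is taken from the docs; treat everything else as illustrative.
from tunix.sft import metrics_logger

options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tunix/trajectories",
)
logger = metrics_logger.MetricsLogger(metrics_logger_options=options)
```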
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 136a3457d..bc98cd29c 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -323,7 +323,7 @@ out_data = sampler(
During reinforcement learning (RL) training, it is often useful to analyze the
generated trajectories (prompts, responses, rewards, etc.). Tunix provides an
`AsyncTrajectoryLogger` to log this data asynchronously to CSV files without
-blocking the training loop. It's enabled in agentic_grpo_learner by default, if
+blocking the training loop. It's enabled in `agentic_learner` by default if
you provide a log directory in your cluster configuration's training config.
```python
diff --git a/examples/agentic/gemma_grpo_demo_nb.py b/examples/agentic/gemma_grpo_demo_nb.py
index c33930958..2472c657c 100644
--- a/examples/agentic/gemma_grpo_demo_nb.py
+++ b/examples/agentic/gemma_grpo_demo_nb.py
@@ -63,7 +63,7 @@
from tunix.models.gemma import model as gemma_lib
from tunix.sft import utils
from tunix.utils import script_utils
- from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
+ from tunix.rl.agentic.agentic_learner import GRPOConfig, GRPOLearner
from flax import nnx
from tunix.cli.utils import model as model_utils
diff --git a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh
index 7feb5b3d9..59fe8ff27 100755
--- a/examples/deepscaler/run_deepscaler_disagg_v5p16.sh
+++ b/examples/deepscaler/run_deepscaler_disagg_v5p16.sh
@@ -128,19 +128,19 @@ python -m tunix.cli.grpo_main \
tokenizer_config.add_eos=false \
\
`# ── GRPO algorithm ───────────────────────────────────────────────────` \
- agentic_grpo_config.num_generations=8 \
- agentic_grpo_config.num_iterations=1 \
- agentic_grpo_config.beta=0.0 \
- agentic_grpo_config.epsilon=0.2 \
- agentic_grpo_config.epsilon_high=0.28 \
- agentic_grpo_config.system_prompt="" \
- agentic_grpo_config.max_concurrency=1024 \
- agentic_grpo_config.max_response_length="$max_response_length" \
- agentic_grpo_config.off_policy_steps=0 \
- agentic_grpo_config.loss_agg_mode="token-mean" \
- agentic_grpo_config.kl_loss_mode="low_var_kl" \
- agentic_grpo_config.max_turns=1 \
- agentic_grpo_config.context_ratio=1 \
+ agentic_config.num_generations=8 \
+ agentic_config.num_iterations=1 \
+ agentic_config.beta=0.0 \
+ agentic_config.epsilon=0.2 \
+ agentic_config.epsilon_high=0.28 \
+ agentic_config.system_prompt="" \
+ agentic_config.max_concurrency=1024 \
+ agentic_config.max_response_length="$max_response_length" \
+ agentic_config.off_policy_steps=0 \
+ agentic_config.loss_agg_mode="token-mean" \
+ agentic_config.kl_loss_mode="low_var_kl" \
+ agentic_config.max_turns=1 \
+ agentic_config.context_ratio=1 \
\
`# ── Optimizer ────────────────────────────────────────────────────────` \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
diff --git a/examples/deepscaler/train_deepscaler_nb.py b/examples/deepscaler/train_deepscaler_nb.py
index 8cc903077..6d0bdb1e9 100644
--- a/examples/deepscaler/train_deepscaler_nb.py
+++ b/examples/deepscaler/train_deepscaler_nb.py
@@ -60,7 +60,7 @@
from tunix.models.qwen2 import params as params_lib
from tunix.models.qwen2 import model as model_lib
from tunix.sft import metrics_logger
- from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
+ from tunix.rl.agentic.agentic_learner import GRPOConfig, GRPOLearner
from tunix.rl.agentic.agents import model_agent
from tunix.rl.agentic.environments import task_environment
from tunix.rl.agentic.trajectory import trajectory_collect_engine
diff --git a/examples/deepswe/run_deepswe_disagg_v5p_32.sh b/examples/deepswe/run_deepswe_disagg_v5p_32.sh
index 4eceb7ae2..bff865ef4 100755
--- a/examples/deepswe/run_deepswe_disagg_v5p_32.sh
+++ b/examples/deepswe/run_deepswe_disagg_v5p_32.sh
@@ -149,21 +149,21 @@ python -m tunix.cli.grpo_main \
kubernetes_config.node_selector_val="deepswe-cpu-pool" \
\
`# ── Agentic / multi-turn ─────────────────────────────────────────────` \
- agentic_grpo_config.max_turns=20 \
- agentic_grpo_config.per_turn_timeout_secs=300 \
- agentic_grpo_config.context_ratio=2 \
- agentic_grpo_config.max_concurrency=100 \
+ agentic_config.max_turns=20 \
+ agentic_config.per_turn_timeout_secs=300 \
+ agentic_config.context_ratio=2 \
+ agentic_config.max_concurrency=100 \
\
`# ── GRPO algorithm ───────────────────────────────────────────────────` \
- agentic_grpo_config.num_generations="$num_generations" \
- agentic_grpo_config.max_response_length="$max_response_length" \
- agentic_grpo_config.num_iterations=1 \
- agentic_grpo_config.beta=0.001 \
- agentic_grpo_config.epsilon=0.2 \
- agentic_grpo_config.epsilon_high=0.28 \
- agentic_grpo_config.off_policy_steps=0 \
- agentic_grpo_config.loss_agg_mode="seq-mean-token-mean" \
- agentic_grpo_config.kl_loss_mode="low_var_kl" \
+ agentic_config.num_generations="$num_generations" \
+ agentic_config.max_response_length="$max_response_length" \
+ agentic_config.num_iterations=1 \
+ agentic_config.beta=0.001 \
+ agentic_config.epsilon=0.2 \
+ agentic_config.epsilon_high=0.28 \
+ agentic_config.off_policy_steps=0 \
+ agentic_config.loss_agg_mode="seq-mean-token-mean" \
+ agentic_config.kl_loss_mode="low_var_kl" \
\
`# ── Optimizer ────────────────────────────────────────────────────────` \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
diff --git a/examples/deepswe/train_deepswe_nb.py b/examples/deepswe/train_deepswe_nb.py
index c2623578d..6c07ff450 100644
--- a/examples/deepswe/train_deepswe_nb.py
+++ b/examples/deepswe/train_deepswe_nb.py
@@ -171,7 +171,7 @@
from tunix.sft import metrics_logger
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
-from tunix.rl.agentic import agentic_grpo_learner
+from tunix.rl.agentic import agentic_learner
from tunix.rl.agentic.parser.chat_template_parser import parser as template_parser
from tunix.rl.agentic.rewards.reward_types import RewardOutput
from examples.deepswe.swe_agent import (
@@ -564,7 +564,7 @@ def transform(entry):
# ==========================================
# 11. Learner & Agent Setup
# ==========================================
-grpo_config = agentic_grpo_learner.GRPOConfig(
+grpo_config = agentic_learner.GRPOConfig(
num_generations=NUM_GENERATIONS,
num_iterations=NUM_ITERATIONS,
max_response_length=MAX_RESPONSE_LENGTH,
@@ -579,7 +579,7 @@ def transform(entry):
)
-agentic_grpo_learner = agentic_grpo_learner.GRPOLearner(
+agentic_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=None,
agent_class=SWEAgent,
@@ -652,7 +652,7 @@ def mixed_type_batch_fn(elements):
print("Starting training...")
-agentic_grpo_learner.train(train_dataset=train_dataset)
+agentic_learner.train(train_dataset=train_dataset)
# %%
diff --git a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
index dfc5e7c11..8c75a8e5e 100755
--- a/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
+++ b/examples/rl/grpo/gsm8k/run_qwen3_8b.sh
@@ -130,15 +130,15 @@ python -m tunix.cli.grpo_main \
tokenizer_config.add_eos=false \
\
`# -- GRPO algorithm ---------------------------------------------------` \
- agentic_grpo_config.num_generations="$num_generations" \
- agentic_grpo_config.num_iterations=1 \
- agentic_grpo_config.beta=0.08 \
- agentic_grpo_config.epsilon=0.2 \
-   agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using <think> ... </think> followed by <answer> ... </answer> with only the final numeric answer inside <answer>." \
- agentic_grpo_config.max_concurrency=128 \
- agentic_grpo_config.max_response_length=768 \
- agentic_grpo_config.max_turns=1 \
- agentic_grpo_config.context_ratio=1 \
+ agentic_config.num_generations="$num_generations" \
+ agentic_config.num_iterations=1 \
+ agentic_config.beta=0.08 \
+ agentic_config.epsilon=0.2 \
+   agentic_config.system_prompt="You are given a grade school math problem. Think step by step and respond using <think> ... </think> followed by <answer> ... </answer> with only the final numeric answer inside <answer>." \
+ agentic_config.max_concurrency=128 \
+ agentic_config.max_response_length=768 \
+ agentic_config.max_turns=1 \
+ agentic_config.context_ratio=1 \
\
`# -- Optimizer --------------------------------------------------------` \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py
index b68d98117..5cd44cccc 100644
--- a/tests/cli/grpo_main_test.py
+++ b/tests/cli/grpo_main_test.py
@@ -224,7 +224,7 @@ def test_agentic_data_module_receives_data_config_for_raw_dataset(self):
env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
@@ -277,7 +277,7 @@ def test_agentic_nullable_string_can_be_overridden_from_cli(self):
env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
@@ -321,7 +321,7 @@ def test_agentic_nullable_dict_can_be_overridden_from_cli(self):
env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
@@ -371,7 +371,7 @@ def test_agentic_nullable_string_can_be_overridden_from_env(self):
env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
@@ -417,7 +417,7 @@ def test_standard_grpo_dispatches_to_standard(self):
pipeline.run_grpo_trainer()
mock_run.assert_called_once_with(mode="grpo")
- def test_agentic_grpo_dispatches_to_agentic(self):
+ def test_agentic_dispatches_to_agentic(self):
extra = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
@@ -436,7 +436,7 @@ def test_agentic_grpo_dispatches_to_agentic(self):
env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
@@ -530,7 +530,7 @@ def _make_agentic_pipeline(self, max_turns, context_ratio):
env_class_path: null
env_kwargs: {{}}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
@@ -603,7 +603,7 @@ def _base_extra(self, agentic_overrides="", system_prompt='""'):
env_class_path: null
env_kwargs: {{}}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.001
@@ -623,26 +623,26 @@ def test_episode_timeout_computed(self):
p = _make_pipeline(
self._base_extra("max_turns: 20\n per_turn_timeout_secs: 300")
)
- algo = p._create_agentic_grpo_config()
+ algo = p._create_agentic_config()
self.assertEqual(algo.episode_timeout, 300 * 20)
def test_max_response_length_from_rollout(self):
p = _make_pipeline(self._base_extra("max_turns: 1"))
- algo = p._create_agentic_grpo_config()
+ algo = p._create_agentic_config()
# rollout_config.total_generation_steps = 512
self.assertEqual(algo.max_response_length, 512)
def test_num_generations_passed_through(self):
p = _make_pipeline(self._base_extra("max_turns: 1"))
- algo = p._create_agentic_grpo_config()
+ algo = p._create_agentic_config()
self.assertEqual(algo.num_generations, 2)
def test_cli_empty_system_prompt_stays_empty_string(self):
p = _make_pipeline_with_cli_args(
self._base_extra("max_turns: 1", system_prompt='"base"'),
- ['agentic_grpo_config.system_prompt=""'],
+ ['agentic_config.system_prompt=""'],
)
- self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "")
+ self.assertEqual(p.config["agentic_config"]["system_prompt"], "")
class SplitMeshConfigTest(absltest.TestCase):
@@ -665,7 +665,7 @@ def test_split_mesh_uses_explicit_role_meshes(self):
env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config:
+agentic_config:
num_generations: 2
num_iterations: 1
beta: 0.0
diff --git a/tests/rl/agentic/agentic_grpo_learner_test.py b/tests/rl/agentic/agentic_learner_test.py
similarity index 96%
rename from tests/rl/agentic/agentic_grpo_learner_test.py
rename to tests/rl/agentic/agentic_learner_test.py
index 571c37073..520fb790e 100644
--- a/tests/rl/agentic/agentic_grpo_learner_test.py
+++ b/tests/rl/agentic/agentic_learner_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for agentic_grpo_learner."""
+"""Tests for agentic_learner."""
import asyncio
import functools
@@ -42,7 +42,7 @@
from tunix.rl import common as rl_common
from tunix.rl import function_registry
from tunix.rl import rl_cluster as rl_cluster_lib
-from tunix.rl.agentic import agentic_grpo_learner
+from tunix.rl.agentic import agentic_learner
from tunix.rl.agentic.agents.agent_types import Action, Step
from tunix.rl.agentic.agents.base_agent import ConversationAgentBase
from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
@@ -55,7 +55,7 @@
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
Mesh = sharding.Mesh
-TrainingInputT = agentic_grpo_learner.TrainingInputT
+TrainingInputT = agentic_learner.TrainingInputT
def reward_fn_1(prompts, completions, **kwargs):
@@ -194,7 +194,7 @@ def assistant_token(self):
return "Assistant: "
-class _LearnerWithException(agentic_grpo_learner.GRPOLearner):
+class _LearnerWithException(agentic_learner.GRPOLearner):
def _batch_to_train_example(self, batch_results, mode):
raise ValueError("test exception in producer")
@@ -218,7 +218,7 @@ def setUp(self):
)
def test_iterator(self):
- class _MockTrainer(agentic_grpo_learner.GRPOLearner):
+ class _MockTrainer(agentic_learner.GRPOLearner):
def __init__(self, algo_config):
self.algo_config = algo_config
@@ -276,7 +276,7 @@ async def _orchestrator_producer(
yield group, [example]
i += 1
- algo_config = agentic_grpo_learner.GRPOConfig(
+ algo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=2,
)
@@ -305,11 +305,11 @@ def test_grpo_config_validation(self):
with self.assertRaisesRegex(
ValueError, "num_generations must be greater than 1"
):
- agentic_grpo_learner.GRPOConfig(num_generations=1)
+ agentic_learner.GRPOConfig(num_generations=1)
with self.assertRaisesRegex(
ValueError, "loss_algo should be either grpo or gspo-token"
):
- agentic_grpo_learner.GRPOConfig(loss_algo="invalid")
+ agentic_learner.GRPOConfig(loss_algo="invalid")
def test_num_iterations_greater_than_1(self):
vocab = test_common.MockVocab()
@@ -354,13 +354,13 @@ def test_num_iterations_greater_than_1(self):
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=2, # > 1
loss_algo="grpo",
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -411,7 +411,7 @@ def test_grpo_loss_fn(self, loss_algo):
(batch_size, seq_len), -0.1, dtype=jnp.float32
)
- train_example = agentic_grpo_learner.TrainExample(
+ train_example = agentic_learner.TrainExample(
prompt_ids=prompt_ids,
prompt_mask=prompt_ids > -1,
completion_ids=completion_ids,
@@ -436,7 +436,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
None,
)
- algo_config = agentic_grpo_learner.GRPOConfig(
+ algo_config = agentic_learner.GRPOConfig(
beta=0.1,
epsilon=0.2,
loss_algo=loss_algo,
@@ -508,7 +508,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
# Unmasked example
final_completion_mask = completion_mask
- train_example = agentic_grpo_learner.TrainExample(
+ train_example = agentic_learner.TrainExample(
prompt_ids=prompt_ids,
prompt_mask=prompt_ids > -1,
completion_ids=completion_ids,
@@ -518,7 +518,7 @@ def __call__(self, inputs, positions, cache, attention_mask):
old_per_token_logps=None,
)
- config = agentic_grpo_learner.GRPOConfig(
+ config = agentic_learner.GRPOConfig(
beta=0.1,
epsilon=0.2,
num_generations=2,
@@ -624,7 +624,7 @@ def mock_compute_rewards(prompts, completions, **kwargs):
tokenizer=tokenizer,
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
beta=0.1,
epsilon=0.2,
num_generations=2,
@@ -632,7 +632,7 @@ def mock_compute_rewards(prompts, completions, **kwargs):
max_response_length=10,
)
- learner = agentic_grpo_learner.GRPOLearner(
+ learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=None,
algo_config=grpo_config,
@@ -724,7 +724,7 @@ def __init__(
tokenizer=tokenizer,
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
beta=0.1,
epsilon=0.2,
num_generations=2,
@@ -733,7 +733,7 @@ def __init__(
max_response_length=10,
)
- learner = agentic_grpo_learner.GRPOLearner(
+ learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=None,
algo_config=grpo_config,
@@ -836,12 +836,12 @@ def create_learner(
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -938,12 +938,12 @@ def create_learner(
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn,
algo_config=grpo_config,
@@ -1045,12 +1045,12 @@ def create_learner(
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn,
algo_config=grpo_config,
@@ -1144,14 +1144,14 @@ def test_force_compute_kl(self, beta, force_compute_kl, expect_ref_logps):
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
beta=beta,
force_compute_kl=force_compute_kl,
max_response_length=10,
num_generations=2,
num_iterations=1,
)
- learner = agentic_grpo_learner.GRPOLearner(
+ learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -1233,7 +1233,7 @@ def test_exception_handling(self):
tokenizer=tokenizer,
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(max_response_length=10)
+ grpo_config = agentic_learner.GRPOConfig(max_response_length=10)
learner = _LearnerWithException(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
@@ -1310,13 +1310,13 @@ def test_grpo_learner(self, reward_fns, loss_algo, use_old_logprobs=False):
)
rl_cluster.with_external_metrics_logger(print)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
loss_algo=loss_algo,
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fns,
algo_config=grpo_config,
@@ -1466,14 +1466,14 @@ def test_on_off_policy_training(self, offpolicy_steps):
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
loss_algo="grpo",
off_policy_steps=offpolicy_steps,
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -1534,8 +1534,8 @@ def test_put_prompts_to_queue(self):
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(max_response_length=512)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_config = agentic_learner.GRPOConfig(max_response_length=512)
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -1600,13 +1600,13 @@ def test_trajectory_logging(self):
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
loss_algo="grpo",
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -1695,13 +1695,13 @@ def test_grpo_with_lora_model(self):
tokenizer=tokenizer,
cluster_config=cluster_config,
)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
max_response_length=10,
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -1819,14 +1819,14 @@ def update_from_model(self, response, **kwargs):
)
rl_cluster.with_external_metrics_logger(print)
- grpo_config = agentic_grpo_learner.GRPOConfig(
+ grpo_config = agentic_learner.GRPOConfig(
num_generations=2,
num_iterations=1,
loss_algo="grpo",
max_response_length=128,
max_concurrency=1, # so the output is deterministic.
)
- grpo_learner = agentic_grpo_learner.GRPOLearner(
+ grpo_learner = agentic_learner.GRPOLearner(
rl_cluster=rl_cluster,
reward_fns=reward_fn_1,
algo_config=grpo_config,
@@ -1901,7 +1901,7 @@ def _patch_process_results(
def test_compute_rloo_advantages(self):
rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- advantages = agentic_grpo_learner.compute_rloo_advantages(
+ advantages = agentic_learner.compute_rloo_advantages(
rewards, num_generations=3
)
expected_value = jnp.array([-1.5, 0.0, 1.5, -1.5, 0.0, 1.5])
@@ -1909,7 +1909,7 @@ def test_compute_rloo_advantages(self):
def test_compute_rloo_advantages_low_generations(self):
rewards = jnp.array([1.0, 2.0])
- advantages = agentic_grpo_learner.compute_rloo_advantages(
+ advantages = agentic_learner.compute_rloo_advantages(
rewards, num_generations=1
)
np.testing.assert_allclose(advantages, jnp.zeros_like(rewards))
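The two `compute_rloo_advantages` tests above pin down the leave-one-out baseline; here is the same arithmetic as a small NumPy sketch (the helper name comes from the hunk, everything else is illustrative):

```python
# RLOO: advantage_i = r_i - mean(rewards of the *other* generations in the group).
import numpy as np

rewards = np.array([1.0, 2.0, 3.0])  # one group, num_generations=3
loo_mean = (rewards.sum() - rewards) / (rewards.size - 1)
print(rewards - loo_mean)  # [-1.5  0.   1.5], the expected_value in the test
# With num_generations=1 there are no "other" generations; per the second
# test, the library returns zeros rather than dividing by zero.
```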
diff --git a/tunix/__init__.py b/tunix/__init__.py
index 7ddbffa7f..a91cd3149 100644
--- a/tunix/__init__.py
+++ b/tunix/__init__.py
@@ -33,8 +33,8 @@
from tunix.perf.export import PerfMetricsExport
from tunix.perf.metrics import PerfMetricsConfig
from tunix.perf.metrics import PerfSpanQuery
-from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig as AgenticGRPOConfig
-from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner as AgenticGRPOLearner
+from tunix.rl.agentic.agentic_learner import GRPOConfig as AgenticGRPOConfig
+from tunix.rl.agentic.agentic_learner import GRPOLearner as AgenticGRPOLearner
from tunix.rl.grpo.grpo_learner import GRPOConfig
from tunix.rl.grpo.grpo_learner import GrpoConfig
from tunix.rl.grpo.grpo_learner import GRPOLearner
diff --git a/tunix/cli/base_agentic_config.yaml b/tunix/cli/base_agentic_config.yaml
index dca76e0cd..a9637611e 100644
--- a/tunix/cli/base_agentic_config.yaml
+++ b/tunix/cli/base_agentic_config.yaml
@@ -231,7 +231,7 @@ env_class_path: null
env_kwargs: {}
kubernetes_config: null
-agentic_grpo_config: {}
+agentic_config: {}
############################# Reward Fns #############################
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 3491b885b..97a762b2e 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -69,7 +69,7 @@ class GrpoPipeline(config.HyperParameters):
``training_mode: "agentic_grpo"`` — multi-turn agentic GRPO using
GRPOLearner. Additional config sections are recognised:
- * ``agentic_grpo_config``: GRPOConfig fields (num_generations, beta, …)
+ * ``agentic_config``: GRPOConfig fields (num_generations, beta, …)
plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``.
* role-specific ``*_model_config.mesh``: any role with an explicit mesh gets
its own device slice; omitted meshes share the actor mesh by default.
@@ -280,7 +280,7 @@ def create_rollout_config(
max_response = rollout_cfg.get("total_generation_steps", 0)
if mode == "agentic_grpo":
- agentic_cfg = self.config.get("agentic_grpo_config", {})
+ agentic_cfg = self.config.get("agentic_config", {})
max_turns = agentic_cfg.get("max_turns", 1)
context_ratio = agentic_cfg.get("context_ratio", 1)
if max_turns > 1:
@@ -637,11 +637,11 @@ def _get_dataset(self, tokenizer):
# Agentic GRPO helpers
# ------------------------------------------------------------------
- def _create_agentic_grpo_config(self):
- """Build GRPOConfig (agentic) from the agentic_grpo_config YAML section."""
- from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig # pylint: disable=g-import-not-at-top
+ def _create_agentic_config(self):
+ """Build GRPOConfig (agentic) from the agentic_config YAML section."""
+ from tunix.rl.agentic.agentic_learner import GRPOConfig # pylint: disable=g-import-not-at-top
- cfg = dict(self.config.get("agentic_grpo_config", {}))
+ cfg = dict(self.config.get("agentic_config", {}))
# episode_timeout = per_turn_timeout_secs * max_turns when not explicit
if "episode_timeout" not in cfg:
@@ -758,8 +758,8 @@ def _run(self, mode: str = "grpo"):
if mode != "agentic_grpo":
raise ValueError(f"Unsupported training_mode {mode!r}")
- from tunix.rl.agentic.agentic_grpo_learner import GRPOLearner # pylint: disable=g-import-not-at-top
- algo_config = self._create_agentic_grpo_config()
+ from tunix.rl.agentic.agentic_learner import GRPOLearner # pylint: disable=g-import-not-at-top
+ algo_config = self._create_agentic_config()
reward_fns = (
self.obtain_reward_fn() if self.config.get("reward_functions") else None
diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_learner.py
similarity index 100%
rename from tunix/rl/agentic/agentic_grpo_learner.py
rename to tunix/rl/agentic/agentic_learner.py