From c1f912d2122affd4efc8f00ad50176a3b658412e Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Thu, 7 May 2026 10:21:26 -0700 Subject: [PATCH] Remove `context_ratio` from agentic GRPO config. PiperOrigin-RevId: 912029627 --- examples/deepswe/run_deepswe_disagg_v5p_32.sh | 1 - tests/cli/grpo_main_test.py | 3 +-- tunix/cli/grpo_main.py | 16 +++++----------- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/examples/deepswe/run_deepswe_disagg_v5p_32.sh b/examples/deepswe/run_deepswe_disagg_v5p_32.sh index 4eceb7ae2..5b73f063e 100755 --- a/examples/deepswe/run_deepswe_disagg_v5p_32.sh +++ b/examples/deepswe/run_deepswe_disagg_v5p_32.sh @@ -151,7 +151,6 @@ python -m tunix.cli.grpo_main \ `# ── Agentic / multi-turn ─────────────────────────────────────────────` \ agentic_grpo_config.max_turns=20 \ agentic_grpo_config.per_turn_timeout_secs=300 \ - agentic_grpo_config.context_ratio=2 \ agentic_grpo_config.max_concurrency=100 \ \ `# ── GRPO algorithm ───────────────────────────────────────────────────` \ diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py index b68d98117..89879adbb 100644 --- a/tests/cli/grpo_main_test.py +++ b/tests/cli/grpo_main_test.py @@ -557,8 +557,7 @@ def test_single_turn_kv_cache(self): def test_multi_turn_kv_cache(self): p = self._make_agentic_pipeline(max_turns=20, context_ratio=2) cfg = p.create_rollout_config() - # max_prompt=256, max_response=512, 20 turns * ratio 2 - self.assertEqual(cfg.kv_cache_size, 256 + 512 * 2 * 20) + self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256) def test_standard_grpo_kv_cache(self): extra = """ diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 3491b885b..e02053a8a 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -70,7 +70,7 @@ class GrpoPipeline(config.HyperParameters): GRPOLearner. Additional config sections are recognised: * ``agentic_grpo_config``: GRPOConfig fields (num_generations, beta, …) - plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``. + plus ``max_turns``, ``per_turn_timeout_secs``. * role-specific ``*_model_config.mesh``: any role with an explicit mesh gets its own device slice; omitted meshes share the actor mesh by default. * role-specific ``same_mesh_as``: optional mesh sharing like @@ -259,8 +259,8 @@ def create_rollout_config( Standard mode: pass rollout_config fields through with kv_cache_size = max_prompt_length + total_generation_steps + 256. - Agentic mode: same base, but multi-turn KV cache = - max_prompt + total_generation_steps * context_ratio * max_turns. + Agentic mode: same base. Same kv_cache_size calculation. + Engine-specific extras (sglang_jax_config, vllm_config) are also applied. """ rollout_cfg = self.config["rollout_config"] @@ -281,12 +281,7 @@ def create_rollout_config( if mode == "agentic_grpo": agentic_cfg = self.config.get("agentic_grpo_config", {}) - max_turns = agentic_cfg.get("max_turns", 1) - context_ratio = agentic_cfg.get("context_ratio", 1) - if max_turns > 1: - kv_cache_size = max_prompt + max_response * context_ratio * max_turns - else: - kv_cache_size = max_prompt + max_response + 256 + kv_cache_size = max_prompt + max_response + 256 filtered["kv_cache_size"] = kv_cache_size logging.info("kv_cache_size: %d", kv_cache_size) @@ -659,7 +654,6 @@ def _create_agentic_grpo_config(self): # Strip helper keys that are not GRPOConfig fields valid = {f.name for f in dataclasses.fields(GRPOConfig)} cfg.pop("max_turns", None) - cfg.pop("context_ratio", None) return GRPOConfig(**{k: v for k, v in cfg.items() if k in valid}) def _create_chat_parser(self, tokenizer: Any) -> Any: @@ -744,7 +738,7 @@ def _run(self, mode: str = "grpo"): rl_cluster = self.create_rl_cluster(tokenizer) if mode == "grpo": - from tunix.rl.grpo import grpo_learner + from tunix.rl.grpo import grpo_learner # pylint: disable=g-import-not-at-top grpo_trainer = grpo_learner.GrpoLearner( rl_cluster=rl_cluster,