Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/deepswe/run_deepswe_disagg_v5p_32.sh
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ python -m tunix.cli.grpo_main \
`# ── Agentic / multi-turn ─────────────────────────────────────────────` \
agentic_grpo_config.max_turns=20 \
agentic_grpo_config.per_turn_timeout_secs=300 \
agentic_grpo_config.context_ratio=2 \
agentic_grpo_config.max_concurrency=100 \
\
`# ── GRPO algorithm ───────────────────────────────────────────────────` \
Expand Down
3 changes: 1 addition & 2 deletions tests/cli/grpo_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,7 @@ def test_single_turn_kv_cache(self):
def test_multi_turn_kv_cache(self):
p = self._make_agentic_pipeline(max_turns=20, context_ratio=2)
cfg = p.create_rollout_config()
# max_prompt=256, max_response=512, 20 turns * ratio 2
self.assertEqual(cfg.kv_cache_size, 256 + 512 * 2 * 20)
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)

def test_standard_grpo_kv_cache(self):
extra = """
Expand Down
16 changes: 5 additions & 11 deletions tunix/cli/grpo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class GrpoPipeline(config.HyperParameters):
GRPOLearner. Additional config sections are recognised:

* ``agentic_grpo_config``: GRPOConfig fields (num_generations, beta, …)
plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``.
plus ``max_turns``, ``per_turn_timeout_secs``.
* role-specific ``*_model_config.mesh``: any role with an explicit mesh gets
its own device slice; omitted meshes share the actor mesh by default.
* role-specific ``same_mesh_as``: optional mesh sharing like
Expand Down Expand Up @@ -259,8 +259,8 @@ def create_rollout_config(
Standard mode: pass rollout_config fields through with kv_cache_size =
max_prompt_length + total_generation_steps + 256.

Agentic mode: same base, but multi-turn KV cache =
max_prompt + total_generation_steps * context_ratio * max_turns.
Agentic mode: same base and the same kv_cache_size calculation as standard mode (max_prompt_length + total_generation_steps + 256).

Engine-specific extras (sglang_jax_config, vllm_config) are also applied.
"""
rollout_cfg = self.config["rollout_config"]
Expand All @@ -281,12 +281,7 @@ def create_rollout_config(

if mode == "agentic_grpo":
agentic_cfg = self.config.get("agentic_grpo_config", {})
max_turns = agentic_cfg.get("max_turns", 1)
context_ratio = agentic_cfg.get("context_ratio", 1)
if max_turns > 1:
kv_cache_size = max_prompt + max_response * context_ratio * max_turns
else:
kv_cache_size = max_prompt + max_response + 256
kv_cache_size = max_prompt + max_response + 256
filtered["kv_cache_size"] = kv_cache_size
logging.info("kv_cache_size: %d", kv_cache_size)

Expand Down Expand Up @@ -659,7 +654,6 @@ def _create_agentic_grpo_config(self):
# Strip helper keys that are not GRPOConfig fields
valid = {f.name for f in dataclasses.fields(GRPOConfig)}
cfg.pop("max_turns", None)
cfg.pop("context_ratio", None)
return GRPOConfig(**{k: v for k, v in cfg.items() if k in valid})

def _create_chat_parser(self, tokenizer: Any) -> Any:
Expand Down Expand Up @@ -744,7 +738,7 @@ def _run(self, mode: str = "grpo"):
rl_cluster = self.create_rl_cluster(tokenizer)

if mode == "grpo":
from tunix.rl.grpo import grpo_learner
from tunix.rl.grpo import grpo_learner # pylint: disable=g-import-not-at-top

grpo_trainer = grpo_learner.GrpoLearner(
rl_cluster=rl_cluster,
Expand Down
Loading