diff --git a/examples/frozenlake/agent.py b/examples/frozenlake/agent.py
new file mode 100644
index 000000000..cda52acdd
--- /dev/null
+++ b/examples/frozenlake/agent.py
@@ -0,0 +1,210 @@
+import re
+from typing import Any, Optional
+
+from examples.frozenlake.env import FrozenLakeEnv
+from tunix.rl.agentic.agents import agent_types
+from tunix.rl.agentic.agents import base_agent
+
+# Prompting format inspired by the RAGEN project: https://github.com/RAGEN-AI/RAGEN
+SYSTEM_PROMPT: str = """You are walking on a frozen lake.
+
+FrozenLake Quick Guide
+Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.
+
+Symbols:
+_ Frozen | O Hole | G Goal | P Player
+
+Rules:
+1. Avoid falling into holes (O).
+2. Frozen tiles are slippery, you may move perpendicular to your intended direction.
+
+Valid Actions (separated by |):
+Up | Down | Left | Right
+
+Rewards:
+Fall into hole: 0
+Reach goal: +1.0
+
+You will be provided the current observation; please decide on the next Action.
+You should show your thought process and then put the final action in ``` ```.
+You should only output the NEXT ACTION at each iteration in the ``` ```. For example, if you want to move up, you should output ```Up```.
+You should plan ahead and aim to reach the goal in the minimum number of steps.
+You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it.
+
+Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right.
+"""
+
+MULTI_SHOT_SYSTEM_PROMPT: str = """You are a helpful assistant. You are walking on a frozen lake.
+
+FrozenLake Quick Guide
+Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.
+
+Symbols:
+_ Frozen | O Hole | G Goal | P Player
+
+Rules:
+1. Avoid falling into holes (O).
+2. Frozen tiles are slippery, you may move perpendicular to your intended direction.
+
+Valid Actions (separated by |):
+Up | Down | Left | Right
+
+Rewards:
+Fall into hole: 0
+Reach goal: +1.0
+
+You will be provided the current observation; please decide on the next Action.
+You should show your thought process and then put the final action in ``` ```.
+You should only output the NEXT ACTION at each iteration in the ``` ```. For example, if you want to move up, you should output ```Up```.
+You should plan ahead and aim to reach the goal in the minimum number of steps.
+
+Below are example interactions:
+Example1:
+User: Current Observation:
+P _ _ _ _
+O _ _ O _
+O _ O _ _
+O _ _ G _
+_ _ _ _ _
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+
+Assistant: P is now at the top left corner. It should reach G near the bottom right. I should move it closer. I can move right or down, but there is a hole directly below me and I cannot move diagonally. There is no hole to my right, so I can move right. Action: ```Right```
+
+Example2:
+User: Current Observation:
+_ _ _ _
+_ _ _ O
+_ O _ P
+O _ _ G
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+
+Assistant: P is now near G. G is directly below P, so I should move onto it. There is no hole below me, so I can move down. Action: ```Down```
+
+Example3:
+User: Current Observation:
+_ _ _ O _
+O _ P O _
+O _ O _ _
+O _ _ G _
+_ _ _ _ _
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+
+Assistant: G is below and to the right of P. I want to move closer, so I should move right or down. But there is a hole at each of those positions and I do not want to fall into holes. Up and left are both valid, but left leads to a shorter path to G. Action: ```Left```
+
+Example4:
+User: Current Observation:
+_ _ _ _
+_ _ _ O
+_ O _ O
+O G P _
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+
+Assistant: P is now near G, but the game has not finished and I should never output an invalid action. I need to recheck my understanding. P is not actually on G yet because they are not overlapping; it needs to reach G, which is to its left. Action: ```Left```
+
+Example5:
+User: Current Observation:
+_ _ _ O _
+O _ P _ _
+O _ O O O
+O _ O G _
+O _ _ _ _
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+
+Assistant: G is below and to the right of P. I can move left, right, or up. Moving right initially brings me closer, but I cannot reach G that way. Moving up or left still lets me reach G; up takes 9 steps in total while left takes 7. I need to move left. Action: ```Left```
+
+Now it is your turn, please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right.
+"""
+
+
+class FrozenLakeAgent(base_agent.ConversationAgentBase):
+
+  def __init__(
+      self,
+      system_prompt: Optional[str] = None,
+      use_multistep_prompt: bool = True,
+  ):
+    self.multistep_prompt = use_multistep_prompt
+    # Only fall back to the built-in prompts when no explicit prompt is given.
+    if system_prompt is None:
+      system_prompt = (
+          MULTI_SHOT_SYSTEM_PROMPT
+          if self.multistep_prompt
+          else SYSTEM_PROMPT
+      )
+    super().__init__(system_prompt=system_prompt)
+    self.last_observation = None
+
+  def _init_messages(self, system_prompt: str) -> None:
+    """Initialize conversation history with a system prompt.
+
+    Subclasses may override this to inject additional content (e.g., tool
+    documentation) into the initial system message.
+
+    Args:
+      system_prompt: The system prompt to use.
+    """
+    self._messages = [{"role": "system", "content": system_prompt or ""}]
+
+  def update_from_env(
+      self,
+      observation: Any,
+      reward: float,
+      done: bool,
+      info: dict[str, Any] | None = None,
+      **kwargs,
+  ) -> None:
+    # Base message for the user.
+    new_obs_str = "Current Observation: \n" + str(observation)
+    if not done:
+      new_obs_str += (
+          "\nYou have not achieved the goal, P has not reached G yet."
+          " Please give the next action."
+      )
+
+    # If the observation matches the previous step's, the last action was
+    # ineffective; remind the model of the expected response format. Compare
+    # and store the un-suffixed observation so repeated no-ops keep matching.
+    prompt_str = new_obs_str
+    if self.last_observation and self.last_observation == new_obs_str:
+      prompt_str += (
+          "\nYour last response is invalid. Your position didn't change at"
+          " all. You may need to recheck your thinking process, the action"
+          " outputted, and the format of the response. Remember, you should"
+          " only output the NEXT ACTION at each iteration in the ``` ```. For"
+          " example, if you want to move up, you should output ```Up```."
+      )
+    self.last_observation = new_obs_str
+
+    super().update_from_env(prompt_str, reward, done, info)
+    self.cur_step = agent_types.Step(observation=prompt_str)
+
+  def _observation_to_messages(
+      self, observation: Any, reward: float, done: bool, info: dict[str, Any]
+  ) -> None:
+    self._messages.append({"role": "user", "content": str(observation)})
+
+  def update_from_model(self, response: str, **kwargs) -> agent_types.Action:
+    direction_map = {"left": 1, "down": 2, "right": 3, "up": 4}
+
+    thought = response
+    action_str = str(FrozenLakeEnv.INVALID_ACTION)
+
+    # Parse the last ``` ``` block; everything before it is the thought.
+    matches = list(re.finditer(r"```(.*?)```", response, re.DOTALL))
+    if matches:
+      last_match = matches[-1]
+      thought = response[: last_match.start()].strip()
+      extracted_text = last_match.group(1).strip().lower()
+
+      if extracted_text in direction_map:
+        action_str = str(direction_map[extracted_text])
+      elif (
+          extracted_text.isdigit()
+          and int(extracted_text) in direction_map.values()
+      ):
+        action_str = str(int(extracted_text))
+
+    # Add assistant's response to conversation history.
+    self._messages.append({"role": "assistant", "content": response})
+
+    # Record complete step with conversation context and parsed action.
+    self._trajectory.steps.append(self.cur_step)
+    cur_step = self._trajectory.steps[-1]
+    cur_step.thought = thought
+    cur_step.action = action_str
+    cur_step.model_response = response
+
+    self.step += 1
+    return agent_types.Action(action=cur_step.action)
+
+  def reset(self) -> None:
+    super().reset()
+    self.last_observation = None
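The agent's whole contract with the model is the "last fenced block wins" convention above. For reference, here is a minimal standalone sketch of that same regex logic (`parse_action` is a hypothetical helper, not part of this diff); it only needs the stdlib:

```python
import re

DIRECTION_MAP = {"left": 1, "down": 2, "right": 3, "up": 4}


def parse_action(response: str) -> int:
  """Return the env action encoded in the last ``` ``` block, or 0 if invalid."""
  matches = list(re.finditer(r"```(.*?)```", response, re.DOTALL))
  if not matches:
    return 0  # FrozenLakeEnv.INVALID_ACTION
  text = matches[-1].group(1).strip().lower()
  return DIRECTION_MAP.get(text, 0)


assert parse_action("G is below me. Action: ```Down```") == 2
assert parse_action("no fenced action at all") == 0
```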
diff --git a/examples/frozenlake/data.py b/examples/frozenlake/data.py
new file mode 100644
index 000000000..3997c2343
--- /dev/null
+++ b/examples/frozenlake/data.py
@@ -0,0 +1,159 @@
+"""FrozenLake Dataset Generator.
+
+This script generates training and test datasets for the FrozenLake
+environment. Each dataset entry contains environment configuration parameters
+(seed, size, p) that can be used to create FrozenLake environment instances.
+
+The generated datasets are saved as Parquet files and can be used for training
+reinforcement learning agents on various FrozenLake configurations.
+
+Usage:
+    python examples/frozenlake/data.py --train_size 10000 --test_size 100
+
+The script generates:
+- Training dataset: Random FrozenLake configurations for agent training
+- Test dataset: Separate set of configurations for evaluation
+"""
+
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+
+
+DEFAULT_DIR = os.getcwd()
+
+
+def get_frozenlake_dict(seed: int, size: int, p: float) -> dict:
+  """Create a dictionary with FrozenLake environment configuration parameters.
+
+  Args:
+    seed: Random seed for environment generation
+    size: Grid size (size x size grid)
+    p: Probability that a tile is frozen (1 - p is the hole probability)
+
+  Returns:
+    Dictionary containing environment configuration with keys: env_name,
+    seed, size, p
+  """
+  return {
+      "env_name": "frozenlake",
+      "seed": int(seed),
+      "size": int(size),
+      "p": float(p),
+  }
+
+
+def generate_dataset_parameters(
+    size: int, random_seed: int = 42
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+  """Generate random parameters for FrozenLake environments.
+
+  This function creates diverse environment configurations by sampling:
+  - Random seeds for environment generation
+  - Grid sizes ranging from 2x2 to 9x9
+  - Frozen-tile probabilities between 0.6 and 0.85 (hole probability 0.15-0.4)
+
+  Args:
+    size: Number of environment configurations to generate
+    random_seed: Random seed for reproducible parameter generation
+
+  Returns:
+    Tuple of (seeds, sizes, p_values) numpy arrays
+  """
+  np.random.seed(random_seed)
+  seeds = np.random.randint(0, 100000, size=size)
+  sizes = np.random.randint(2, 10, size=size)  # Grid sizes from 2x2 to 9x9
+  p_values = np.random.uniform(
+      0.6, 0.85, size=size
+  )  # Frozen-tile probability between 0.6 and 0.85
+
+  return seeds, sizes, p_values
+
+
+def save_dataset(data: list[dict], filepath: str) -> None:
+  """Save dataset to Parquet file format.
+
+  Converts the list of environment configuration dictionaries to a pandas
+  DataFrame and saves it as a Parquet file for efficient storage and loading.
+
+  Args:
+    data: List of environment configuration dictionaries
+    filepath: Full path where to save the Parquet file
+  """
+  df = pd.DataFrame(data)
+  df.to_parquet(filepath)
+  print(f"Saved {len(data)} entries to {filepath}")
+
+
+def main():
+  """Main function to generate and save FrozenLake datasets.
+
+  Parses command line arguments, generates training and test datasets with
+  different random seeds for diversity, and saves them as Parquet files.
+  """
+  parser = argparse.ArgumentParser(
+      description=(
+          "Generate FrozenLake environment configuration datasets for"
+          " training and testing."
+      )
+  )
+  parser.add_argument(
+      "--local_dir",
+      default=os.path.join(DEFAULT_DIR, "data/frozenlake"),
+      help="Local directory to save the datasets",
+  )
+  parser.add_argument(
+      "--hdfs_dir",
+      default=None,
+      help="HDFS directory to copy datasets to (optional; currently unused)",
+  )
+  parser.add_argument(
+      "--train_size",
+      type=int,
+      default=10000,
+      help=(
+          "Number of training environment configurations to generate"
+          " (default: 10000)"
+      ),
+  )
+  parser.add_argument(
+      "--test_size",
+      type=int,
+      default=100,
+      help=(
+          "Number of test environment configurations to generate (default:"
+          " 100)"
+      ),
+  )
+
+  args = parser.parse_args()
+
+  # Create local directory
+  local_dir = os.path.expanduser(args.local_dir)
+  print(f"Using local directory: {local_dir}")
+  os.makedirs(local_dir, exist_ok=True)
+
+  # Generate training dataset parameters
+  train_seeds, train_sizes, train_ps = generate_dataset_parameters(
+      args.train_size, random_seed=42
+  )
+  train_data = [
+      get_frozenlake_dict(seed, train_sizes[idx], train_ps[idx])
+      for idx, seed in enumerate(train_seeds)
+  ]
+
+  # Generate test dataset parameters (different random seed for diversity)
+  test_seeds, test_sizes, test_ps = generate_dataset_parameters(
+      args.test_size, random_seed=123
+  )
+  test_data = [
+      get_frozenlake_dict(seed, test_sizes[idx], test_ps[idx])
+      for idx, seed in enumerate(test_seeds)
+  ]
+
+  # Save datasets as Parquet files
+  save_dataset(train_data, os.path.join(local_dir, "train.parquet"))
+  save_dataset(test_data, os.path.join(local_dir, "test.parquet"))
+
+
+if __name__ == "__main__":
+  main()
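A quick sanity check of the generator, hedged on the assumption that the module is importable from the repository root as `examples.frozenlake.data`:

```python
from examples.frozenlake.data import (
    generate_dataset_parameters,
    get_frozenlake_dict,
)

seeds, sizes, p_values = generate_dataset_parameters(size=3, random_seed=0)
for seed, grid_size, p in zip(seeds, sizes, p_values):
  entry = get_frozenlake_dict(seed, grid_size, p)
  # Sampled values stay within the documented ranges.
  assert 2 <= entry["size"] <= 9 and 0.6 <= entry["p"] <= 0.85
  print(entry)  # e.g. {'env_name': 'frozenlake', 'seed': ..., 'size': ..., 'p': ...}
```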
diff --git a/examples/frozenlake/env.py b/examples/frozenlake/env.py
new file mode 100644
index 000000000..c51136b04
--- /dev/null
+++ b/examples/frozenlake/env.py
@@ -0,0 +1,339 @@
+"""DISCLAIMER:
+
+This implementation is based on the Gymnasium FrozenLake environment and the
+RAGEN project:
+- Gymnasium:
+  https://gymnasium.farama.org/environments/toy_text/frozen_lake/
+- RAGEN:
+  https://github.com/RAGEN-AI/RAGEN/blob/main/ragen/env/frozen_lake/env.py
+
+Some components have been modified or extended for custom use in this project.
+"""
+
+import copy
+from typing import Any, Dict
+
+import gymnasium as gym
+from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv as GymFrozenLakeEnv
+from gymnasium.utils import seeding
+import numpy as np
+from tunix.rl.agentic.environments.base_environment import BaseTaskEnv, EnvStepResult
+
+# Default step budget. FrozenLakeEnv.__init__ may override this module-level
+# value, which is also read by is_valid() when validating generated maps.
+MAX_STEPS: int = 5
+
+
+def is_valid(board: list[list[str]], max_size: int) -> bool:
+  """DFS check that a path from start (S) to goal (G) exists within MAX_STEPS."""
+  frontier, discovered = [], set()
+  # Find the start point.
+  start_r, start_c = np.where(np.array(board) == "S")
+  frontier.append((start_r[0], start_c[0], 0))  # (row, col, steps)
+  # DFS to check if there is a path from start to goal.
+  while frontier:
+    r, c, steps = frontier.pop()
+    if steps > MAX_STEPS:
+      continue
+
+    if (r, c) not in discovered:
+      discovered.add((r, c))
+      directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
+      for x, y in directions:
+        r_new = r + x
+        c_new = c + y
+        if r_new < 0 or r_new >= max_size or c_new < 0 or c_new >= max_size:
+          continue
+        if board[r_new][c_new] == "G":
+          return True
+        if board[r_new][c_new] != "H":
+          frontier.append((r_new, c_new, steps + 1))
+  return False
+
+
+def generate_random_map(
+    size: int = 8, p: float = 0.8, seed: int = 0
+) -> tuple[list[str], tuple[int, int]]:
+  """Generates a random valid map (one that has a path from start to goal).
+
+  Args:
+    size: size of each side of the grid
+    p: probability that a tile is frozen
+    seed: seed to ensure the generation of reproducible maps
+
+  Returns:
+    A random valid map and the goal position as (row, col)
+  """
+  valid = False
+  board: list[list[str]] = []  # initialize to make pyright happy
+
+  np_random, _ = seeding.np_random(seed)
+
+  # Generate random start and goal points.
+  goal_r, goal_c = -1, -1
+  while not valid:
+    p = min(1, p)
+    board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p]).tolist()
+    while True:
+      start_r = int(np_random.integers(0, size))
+      start_c = int(np_random.integers(0, size))
+      goal_r = int(np_random.integers(0, size))
+      goal_c = int(np_random.integers(0, size))
+      # Ensure start and goal are different positions.
+      if (start_r, start_c) != (goal_r, goal_c):
+        break
+    board[start_r][start_c] = "S"
+    board[goal_r][goal_c] = "G"
+    valid = is_valid(board, size)
+  return ["".join(x) for x in board], (goal_r, goal_c)
+
+
+def get_goal_position(random_map):
+  positions = np.argwhere(random_map == b"G")
+  if positions.size == 0:
+    return None  # G not found
+  return tuple(positions[0])  # returns (row, col)
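To see what the generator produces, a small hedged check (assuming the helpers above are importable as `examples.frozenlake.env`):

```python
from examples.frozenlake.env import generate_random_map, is_valid

rows, goal = generate_random_map(size=4, p=0.8, seed=7)
print("\n".join(rows))  # four strings over {S, F, H, G}, e.g. "SFFH"
print("goal at", goal)  # (row, col) of "G"
# Every map returned by generate_random_map already passed the DFS check.
assert is_valid([list(r) for r in rows], 4)
```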
+
+
+class FrozenLakeEnv(BaseTaskEnv, GymFrozenLakeEnv):
+  """Inherits from gymnasium.envs.toy_text.frozen_lake.FrozenLakeEnv.
+
+  ## Description
+  The game starts with the player at a random location on the frozen lake
+  grid, with the goal located at another random location.
+
+  ## Action Space
+  The action is a single integer indicating which direction to move the
+  player.
+  NOTE: the action space differs from
+  gymnasium.envs.toy_text.frozen_lake.FrozenLakeEnv; valid moves start from 1
+  and 0 is reserved for invalid actions:
+  - 0: Still
+  - 1: Left
+  - 2: Down
+  - 3: Right
+  - 4: Up
+
+  ## Starting State
+  The episode starts with the player at a random location.
+
+  ## Rewards
+  NOTE: an invalid action leaves the state unchanged and yields reward 0.
+  Reward schedule:
+  - Reach goal: +1
+  - Reach hole: 0
+  - Reach frozen: 0
+
+  ## Arguments
+  `is_slippery`: if action is left and is_slippery is True, then:
+  - P(move left)=1/3
+  - P(move up)=1/3
+  - P(move down)=1/3
+
+  ## Example
+  P _ _ _
+  _ _ _ O
+  O _ O _
+  O _ _ G
+  """
+
+  # Map gym state to an integer.
+  MAP_LOOKUP = {
+      b"P": 0,
+      b"F": 1,
+      b"H": 2,
+      b"G": 3,
+  }
+
+  # Rules to transform the state into a rendered text observation.
+  GRID_LOOKUP = {
+      0: " P \t",  # player
+      1: " _ \t",  # frozen
+      2: " O \t",  # hole
+      3: " G \t",  # goal
+      4: " X \t",  # player fell into hole
+      5: " √ \t",  # player on goal
+  }
+
+  ACTION_LOOKUP = {
+      0: "None",
+      1: "Left",
+      2: "Down",
+      3: "Right",
+      4: "Up",
+  }
+
+  INVALID_ACTION = 0
+  PENALTY_FOR_INVALID = -1
+
+  def __init__(
+      self,
+      entry: dict[str, Any],
+      group_id: int | None = None,
+      pair_index: int | None = None,
+      max_steps: int = 5,
+      **kwargs,
+  ):
+    # is_valid() reads the module-level MAX_STEPS when validating generated
+    # maps, so keep it in sync with this instance's step budget.
+    global MAX_STEPS
+    MAX_STEPS = max_steps
+
+    desc = kwargs.pop("desc", None)
+    is_slippery = kwargs.pop("is_slippery", False)
+    # Entries read from Parquet arrive as numpy scalars; int()/float() also
+    # accept plain Python numbers.
+    self.seed = int(entry["seed"]) if "seed" in entry else 42
+    self.size = int(entry["size"]) if "size" in entry else 8
+    self.p = float(entry["p"]) if "p" in entry else 0.8
+
+    if desc is None:
+      random_map, goal_position = generate_random_map(
+          size=self.size, p=self.p, seed=self.seed
+      )
+    else:
+      random_map = np.asarray(copy.deepcopy(desc), dtype="c")
+      goal_position = get_goal_position(random_map)
+
+    self.goal_position = goal_position
+
+    BaseTaskEnv.__init__(self, max_steps=MAX_STEPS)
+    GymFrozenLakeEnv.__init__(self, desc=random_map[:], is_slippery=is_slippery)
+
+    self.ACTION_SPACE = gym.spaces.Discrete(4, start=1)
+
+    self.map_kwargs = {
+        "size": self.size,
+        "p": self.p,
+    }
+    self.env_kwargs = {
+        "is_slippery": is_slippery,
+        "desc": copy.deepcopy(desc),
+        "seed": self.seed,
+    }
+    # Map the 1-based agent actions onto gymnasium's 0-based actions.
+    self.action_map = {
+        1: 0,
+        2: 1,
+        3: 2,
+        4: 3,
+    }
+
+    self.reward = 0
+    self._valid_actions = []
+
+    if not hasattr(self, "extra_kwargs"):
+      self.extra_kwargs = {}
+    self.extra_kwargs["group_id"] = group_id
+    self.extra_kwargs["pair_index"] = pair_index
+
+  def _get_player_position(self):
+    return (self.s // self.ncol, self.s % self.ncol)
+
+  def _initial_observation(self) -> Any:
+    GymFrozenLakeEnv.reset(self, seed=self.seed)
+    self.reward = 0
+    self._valid_actions = []
+    init_observation = self.render(mode="tiny_rgb_array")
+    return init_observation
+
+  def finished(self):
+    """Check whether the player is on the goal (G) or in a hole (H)."""
+    player_pos = self._get_player_position()
+    return self.desc[player_pos] in b"GH"
+
+  def success(self):
+    """Check whether the agent has reached the goal (G)."""
+    player_pos = self._get_player_position()
+    return self.desc[player_pos] in b"G"
+
+  def step(self, action: Any) -> tuple[Any, float, bool, Dict[str, Any]]:  # type: ignore[signature-mismatch]
+    return BaseTaskEnv.step(self, action)
+ """ + if self.success(): + return EnvStepResult( + observation=self.render(), + reward=1.0, + done=True, + info={"action_is_effective": False}, + ) + + if not action: + action = self.INVALID_ACTION + action = int(action) + + assert isinstance(action, int), "Action must be an integer" + assert not self.success(), "Agent has already reached the goal or hole" + + if action == self.INVALID_ACTION: + return EnvStepResult( + observation=self.render(), + reward=0.0, + done=False, + info={"action_is_effective": False}, + ) + + prev_player_position = int(self.s) + + player_pos, reward, done, _, prob = GymFrozenLakeEnv.step( + self, int(self.action_map[action]) + ) + + obs = self.render() + return EnvStepResult( + observation=obs, + reward=float(reward), + done=done, + info={"action_is_effective": prev_player_position != int(player_pos)}, + ) + + def render(self, mode="tiny_rgb_array"): + assert mode in ["tiny_rgb_array", "list", "state", "rgb_array", "ansi"] + if mode in ["rgb_array", "ansi"]: + prev_render_mode = self.render_mode + self.render_mode = mode + obs = GymFrozenLakeEnv.render(self) + self.render_mode = prev_render_mode + return obs + room_state = copy.deepcopy(self.desc) + + # replace the position of start 'S' with 'F' + position_S = np.where(room_state == b"S") + room_state[position_S] = b"F" + + # replace the position of the player with 'P' + position_P = self._get_player_position() + room_state[position_P] = b"P" + + if mode == "state": + # transform 'S', 'F', 'H', 'G' to numpy integer array + room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state) + # add player in hole or player on goal + if self.desc[position_P] == b"H": + room_state[position_P] = 4 + elif self.desc[position_P] == b"G": + room_state[position_P] = 5 + return room_state + + room_state = self.render(mode="state").tolist() + + if mode == "list": + lookup = lambda cell: self.GRID_LOOKUP.get(cell, "?").strip("\t").strip() + return [" ".join(lookup(cell) for cell in row) for row in room_state] + + if mode == "tiny_rgb_array": + lookup = lambda cell: self.GRID_LOOKUP.get(cell, "?") + result = "\n".join( + "".join(lookup(cell) for cell in row) for row in room_state + ) + # result += f"Player Position is at ({position_P[0]}, {position_P[1]}), Goal Position is at ({self.goal_postion[0]}, {self.goal_postion[1]})" + return result + + def reset(self): # type: ignore[signature-mismatch] + BaseTaskEnv.reset(self) + GymFrozenLakeEnv.reset(self, seed=self.seed) + return self.render(mode="tiny_rgb_array"), {} + + @classmethod + def from_dict(cls, env_info: dict) -> "FrozenLakeEnv": + return cls( + entry=env_info, + max_turns=env_info.get("max_turns", MAX_STEPS), + is_slippery=env_info.get("is_slippery", False), + ) diff --git a/examples/frozenlake/train_frozenlake.py b/examples/frozenlake/train_frozenlake.py new file mode 100644 index 000000000..66c6c6c8f --- /dev/null +++ b/examples/frozenlake/train_frozenlake.py @@ -0,0 +1,606 @@ +"""Script to train FrozenLake with GRPO on Gemma4.""" + +import contextlib +import datetime +import logging +import math +import os +import sys +from typing import List + +from absl import logging as absl_logging +from flax import nnx +import grain +import jax +from jax import numpy as jnp +import numpy as np +import optax +from orbax import checkpoint as ocp +import qwix + +# ====== Logging Configuration ====== +# 1. Force absl to use python logging +absl_logging.use_python_logging() + +# 2. 
diff --git a/examples/frozenlake/train_frozenlake.py b/examples/frozenlake/train_frozenlake.py
new file mode 100644
index 000000000..66c6c6c8f
--- /dev/null
+++ b/examples/frozenlake/train_frozenlake.py
@@ -0,0 +1,606 @@
+"""Script to train FrozenLake with GRPO on Gemma4."""
+
+import contextlib
+import datetime
+import logging
+import math
+import os
+import sys
+
+from absl import logging as absl_logging
+from flax import nnx
+import grain
+import jax
+from jax import numpy as jnp
+from jax.experimental import mesh_utils
+import numpy as np
+import optax
+from orbax import checkpoint as ocp
+import qwix
+
+# ====== Logging Configuration ======
+# 1. Force absl to use python logging
+absl_logging.use_python_logging()
+
+# 2. Configure the root logger
+logging.basicConfig(
+    stream=sys.stdout,
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    force=True,
+)
+
+# 3. Explicitly set levels for relevant loggers
+logging.getLogger().setLevel(logging.INFO)
+logging.getLogger("absl").setLevel(logging.INFO)
+
+# 4. Set absl verbosity
+absl_logging.set_verbosity(absl_logging.INFO)
+absl_logging.set_stderrthreshold("info")
+
+print("Logging configured at INFO level.")
+
+try:
+  from etils import ecolab
+
+  cm = ecolab.adhoc(
+      source=ecolab.FROM_NOTEBOOK_OR_HEAD,
+      reload="tunix",
+      behavior="preferred",
+      cell_autoreload=True,
+  )
+except ImportError:
+  cm = contextlib.nullcontext()
+
+with cm:
+  from tunix.models.gemma4 import params_safetensors as params_lib
+  from tunix.models.gemma4 import model as model_lib
+  from tunix.sft import metrics_logger
+  from tunix.rl.agentic.agentic_grpo_learner import GRPOConfig, GRPOLearner
+  from tunix.rl.agentic.parser.chat_template_parser import parser
+  from tunix.rl import rl_cluster as rl_cluster_lib
+  from tunix.rl.rollout import base_rollout
+  from tunix.sft import utils as sft_utils
+  from tunix.utils import compat
+  from tunix.rl import reshard
+  from tunix.cli.utils import data as data_lib
+  from tunix import PerfMetricsConfig
+  from tunix.perf.experimental.export import PerfMetricsExport
+  from examples.frozenlake.agent import FrozenLakeAgent
+  from examples.frozenlake.env import FrozenLakeEnv
+
+try:
+  import pathwaysutils
+
+  pathwaysutils.initialize()
+except ImportError:
+  pass
+
+print("jax devices: ", jax.devices())
+
+# %%
+import argparse
+
+arg_parser = argparse.ArgumentParser(description="Train FrozenLake parameters")
+arg_parser.add_argument("--batch_size", type=int, default=64)
+arg_parser.add_argument("--mini_batch_size", type=int, default=64)
+arg_parser.add_argument("--learning_rate", type=float, default=1e-6)
+arg_parser.add_argument("--b1", type=float, default=0.9)
+arg_parser.add_argument("--b2", type=float, default=0.99)
+arg_parser.add_argument("--weight_decay", type=float, default=0.01)
+arg_parser.add_argument("--num_batches", type=int, default=150)
+arg_parser.add_argument("--num_generations", type=int, default=8)
+arg_parser.add_argument("--beta", type=float, default=0.0)
+arg_parser.add_argument("--epsilon", type=float, default=0.2)
+arg_parser.add_argument("--epsilon_high", type=float, default=0.28)
+arg_parser.add_argument("--max_prompt_length", type=int, default=2048)
+arg_parser.add_argument("--max_response_length", type=int, default=4096)
+arg_parser.add_argument("--temperature", type=float, default=0.7)
+arg_parser.add_argument("--top_p", type=float, default=0.95)
+arg_parser.add_argument("--top_k", type=int, default=None)
+arg_parser.add_argument("--max_concurrency", type=int, default=64)
+# NOTE: `type=bool` would treat any non-empty string as True, so use a flag.
+arg_parser.add_argument("--shuffle_data", action="store_true")
+arg_parser.add_argument("--seed", type=int, default=42)
+arg_parser.add_argument(
+    "--loss_agg_mode", type=str, default="sequence-mean-token-mean"
+)
+arg_parser.add_argument("--kl_loss_mode", type=str, default="low_var_kl")
+args, _ = arg_parser.parse_known_args()
+
+# ====== Data ======
+TRAIN_FRACTION = 1.0
+
+# ====== Reproducibility ======
+SEED = args.seed
+
+# ====== LoRA ======
+RANK = 64
+ALPHA = 64.0
+TRAIN_WITH_LORA = False
+
+# ====== Sharding ======
+ROLLOUT_MESH = [(1, 4), ("fsdp", "tp")]
+TRAINER_MESH = [(4, 4), ("fsdp", "tp")]
+REFERENCE_MESH = [(1, 4), ("fsdp", "tp")]
+
+# ====== GRPO ======
+# === Generation during GRPO training ===
+MAX_PROMPT_LENGTH = args.max_prompt_length
+MAX_RESPONSE_LENGTH = args.max_response_length
+# Important to keep a high-ish temperature for varied, diverse responses
+# during training.
+TEMPERATURE = args.temperature
+TOP_P = args.top_p
+TOP_K = args.top_k
+# The number of responses the policy generates for each prompt within a
+# single training step. This corresponds to `G` in Algorithm 1 in the paper.
+# The "group" in GRPO comes from here.
+NUM_GENERATIONS = args.num_generations
+
+# Max number of sequences to be processed in parallel by vLLM.
+VLLM_MAX_NUM_SEQS = 64
+
+# Max number of tokens to be processed in parallel by vLLM.
+# Divide by 8 for on-policy training; for 1-step off-policy, divide by 4.
+VLLM_MAX_BATCHED_TOKENS = VLLM_MAX_NUM_SEQS * 10 * 1024 // 8
+
+# === Other GRPO configs ===
+# The number of iterations per batch (𝜇 in GRPO Algorithm 1).
+NUM_ITERATIONS = 1
+# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
+# Important to keep a high enough value for this; otherwise, the KL divergence
+# can increase unchecked.
+BETA = args.beta
+# Epsilon value for clipping (𝜀 in the GRPO loss in the paper). Similar to
+# PPO, for stable updates.
+EPSILON = args.epsilon
+EPSILON_HIGH = args.epsilon_high
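Since `NUM_GENERATIONS` and the clipping constants frame the GRPO objective, here is a minimal sketch of how group-relative advantages are typically computed from a reward group of size `G`. This is illustrative only, not Tunix's internal implementation:

```python
import numpy as np


def group_relative_advantages(rewards: np.ndarray, eps: float = 1e-6) -> np.ndarray:
  """rewards has shape (num_prompts, G); normalize within each prompt's group."""
  mean = rewards.mean(axis=1, keepdims=True)
  std = rewards.std(axis=1, keepdims=True)
  return (rewards - mean) / (std + eps)


# One prompt, G=4 rollouts: solved rollouts get positive advantages.
print(group_relative_advantages(np.array([[1.0, 0.0, 0.0, 1.0]])))
```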
+
+# ====== Training ======
+ENABLE_REMAT = True
+ENABLE_FLASH_ATTENTION = True
+ENABLE_MIX_PRECISION = True
+BATCH_SIZE = args.batch_size
+MINI_BATCH_SIZE = args.mini_batch_size
+NUM_BATCHES = args.num_batches
+# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
+# increased to a max of 330 (if batch size is 4).
+NUM_TEST_BATCHES = 50
+
+EVAL_EVERY_N_STEPS = 1000  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
+NUM_EPOCHS = 3  # can potentially train for more epochs
+
+# Number of training steps.
+MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
+
+# Max concurrency for parallel processing of trajectories.
+MAX_CONCURRENCY = args.max_concurrency
+
+# Max number of off-policy steps. Defaults to 0 for synchronous training.
+OFF_POLICY_STEPS = 0
+
+MODEL_DTYPE = jnp.bfloat16
+
+# === AdamW, warmup, cosine scheduler ===
+LEARNING_RATE = args.learning_rate
+B1 = args.b1  # Adam beta1
+B2 = args.b2  # Adam beta2
+WEIGHT_DECAY = args.weight_decay
+# == Cosine decay with warmup scheduler ==
+# Linearly increase the learning rate from 0 to LEARNING_RATE over the first
+# 10% of training steps, then decay it to 0 with a cosine schedule. Note the
+# optimizer below is built with a constant rate; a schedule sketch follows it.
+WARMUP_STEPS = int(0.1 * MAX_STEPS)
+# == Grad clipping ==
+# Gradient clipping to prevent large gradients. Found this important to keep
+# the KL divergence in check.
+MAX_GRAD_NORM = 0.3
+
+# ====== Checkpoint saving ======
+SAVE_INTERVAL_STEPS = 5
+MAX_TO_KEEP = 500
+DO_MEM_PROFILING = False
+
+# ====== Inference ======
+GENERATION_CONFIGS = {
+    # greedy search
+    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
+    # some randomness
+    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
+    # liberal
+    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
+}
+# ====== Rollout ======
+ROLLOUT_ENGINE = os.getenv(
+    "ROLLOUT_ENGINE", "vllm"
+)  # one of "vanilla", "vllm"
+
+
+trainer_devices = math.prod(TRAINER_MESH[0])
+rollout_devices = math.prod(ROLLOUT_MESH[0])
+reference_devices = math.prod(REFERENCE_MESH[0])
+
+if trainer_devices + rollout_devices + reference_devices > jax.device_count():
+  raise ValueError(
+      "The trainer, rollout, and reference meshes together require more"
+      " devices than are available."
+  )
+
+
+rollout_device_list = mesh_utils.create_device_mesh(
+    ROLLOUT_MESH[0], jax.devices()[:rollout_devices]
+)
+
+rollout_mesh = jax.sharding.Mesh(
+    rollout_device_list,
+    axis_names=ROLLOUT_MESH[1],
+    axis_types=(jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]),
+)
+print(f"{rollout_device_list=} {rollout_mesh.devices=}")
+reference_device_list = mesh_utils.create_device_mesh(
+    REFERENCE_MESH[0],
+    jax.devices()[rollout_devices : rollout_devices + reference_devices],
+)
+reference_mesh = jax.sharding.Mesh(
+    reference_device_list,
+    axis_names=REFERENCE_MESH[1],
+    axis_types=(jax.sharding.AxisType.Auto,) * len(REFERENCE_MESH[0]),
+)
+print(f"{reference_device_list=} {reference_mesh.devices=}")
+trainer_device_list = mesh_utils.create_device_mesh(
+    TRAINER_MESH[0], jax.devices()[-trainer_devices:]
+)
+trainer_mesh = jax.sharding.Mesh(
+    trainer_device_list,
+    axis_names=TRAINER_MESH[1],
+    axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]),
+)
+print(f"{trainer_device_list=} {trainer_mesh.devices=}")
+ ) + + +rollout_device_list = jax._src.mesh_utils.create_device_mesh( + ROLLOUT_MESH[0], jax.devices()[:rollout_devices] +) + +rollout_mesh = jax.sharding.Mesh( + rollout_device_list, + axis_names=ROLLOUT_MESH[1], + axis_types=(jax.sharding.AxisType.Auto,) * len(ROLLOUT_MESH[0]), +) +print(f"{rollout_device_list=} {rollout_mesh.devices=}") +reference_device_list = jax._src.mesh_utils.create_device_mesh( + REFERENCE_MESH[0], + jax.devices()[rollout_devices : rollout_devices + reference_devices], +) +reference_mesh = jax.sharding.Mesh( + reference_device_list, + axis_names=REFERENCE_MESH[1], + axis_types=(jax.sharding.AxisType.Auto,) * len(REFERENCE_MESH[0]), +) +print(f"{reference_device_list=} {reference_mesh.devices=}") +trainer_device_list = jax._src.mesh_utils.create_device_mesh( + TRAINER_MESH[0], jax.devices()[-trainer_devices:] +) +trainer_mesh = jax.sharding.Mesh( + trainer_device_list, + axis_names=TRAINER_MESH[1], + axis_types=(jax.sharding.AxisType.Auto,) * len(TRAINER_MESH[0]), +) +print(f"{trainer_device_list=} {trainer_mesh.devices=}") + +# %% +try: + from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile + file_open = gfile.Open + NOTEBOOK_ENV = "g3" +except Exception: + NOTEBOOK_ENV = "git" + from google.cloud import storage + import fsspec + file_open = fsspec.open + +if NOTEBOOK_ENV == "g3": + DATA_PATH_PREFIX = "/GOOGLE_INTERNAL_STOAGE_PATH/gg-d/home/qwix-dev/rl/data/" + MODEL_PATH_PREFIX = "/GOOGLE_INTERNAL_STOAGE_PATH/gg-d/home/qwix-dev/" + CKPT_DIR_PREFIX = "/GOOGLE_INTERNAL_STOAGE_PATH/gg-d/home/qwix-dev/" +else: + DATA_PATH_PREFIX = "gs://tunix/data/Frozenlake" + MODEL_PATH_PREFIX = "gs://tunix/models" + CKPT_DIR_PREFIX = "gs://tunix/rl/checkpoints" + +print("NOTEBOOK_ENV: ", NOTEBOOK_ENV) +now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") +CKPT_DIR = os.path.join(CKPT_DIR_PREFIX, f"frozenlake/{now_str}") + +MODEL_VERSION = "google/gemma-4-31B-it" +MODEL_PATH = os.path.join(MODEL_PATH_PREFIX, "gemma-4/gemma-4-31B-it") +# %% +show_hbm_usage = sft_utils.show_hbm_usage + +# %% +import pandas as pd +import datasets as datasets_lib +import transformers + +Dataset = datasets_lib.Dataset +AutoTokenizer = transformers.AutoTokenizer + + +TRAIN_DATA_PATH = os.path.join( + DATA_PATH_PREFIX, "train.parquet" +) +TEST_DATA_PATH = os.path.join( + DATA_PATH_PREFIX, "test.parquet" +) + + +def create_datasets( + train_ds_path: str = TRAIN_DATA_PATH, + test_ds_path: str = TEST_DATA_PATH, +): + with file_open(train_ds_path) as train_f, file_open( + test_ds_path, "rb" + ) as test_f: + train_df = pd.read_parquet(train_f) + test_df = pd.read_parquet(test_f) + + train_ds = Dataset.from_pandas(train_df) + test_ds = Dataset.from_pandas(test_df) + if args.shuffle_data: + train_ds = train_ds.shuffle(SEED) + test_ds = test_ds.shuffle(SEED) + + def process_item(item): + item["prompts"] = "" + return item + + train_ds = grain.MapDataset.source(train_ds).map(process_item) + test_ds = grain.MapDataset.source(test_ds).map(process_item) + return train_ds, test_ds + + +# %% + +tokenizer = AutoTokenizer.from_pretrained(MODEL_VERSION) + +chat_parser = parser.DefaultChatTemplateParser(tokenizer) + +# %% +train_dataset, test_dataset = create_datasets() +train_dataset, val_dataset = data_lib.post_init_dataset( + train_dataset, + tokenizer, + batch_size=BATCH_SIZE, + num_batches=NUM_BATCHES, + max_prompt_length=MAX_PROMPT_LENGTH, + fraction=TRAIN_FRACTION, + num_epochs=NUM_EPOCHS, +) + +test_dataset, _ = data_lib.post_init_dataset( + test_dataset, + tokenizer, + batch_size=BATCH_SIZE, + 
+
+# %%
+config = model_lib.ModelConfig.gemma4_31b()
+if ENABLE_REMAT:
+  config.remat_config = model_lib.RematConfig.BLOCK
+if ENABLE_FLASH_ATTENTION:
+  config.use_flash_attention = True
+  config.flash_attention_block_size = 256
+if ENABLE_MIX_PRECISION:
+  config.dtype = jnp.bfloat16
+
+gemma4_ref = params_lib.create_model_from_safe_tensors(
+    MODEL_PATH, config, reference_mesh, dtype=MODEL_DTYPE
+)
+
+# %%
+show_hbm_usage("after loading gemma4_ref")
+
+
+# %%
+def get_lora_model(base_model, model_mesh):
+  lora_provider = qwix.LoraProvider(
+      module_path=(
+          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
+          ".*attn_vec_einsum"
+      ),
+      rank=RANK,
+      alpha=ALPHA,
+  )
+
+  model_input = base_model.get_model_input()
+  lora_model = qwix.apply_lora_to_model(
+      base_model, lora_provider, **model_input
+  )
+
+  with compat.set_mesh(model_mesh):
+    state = nnx.state(lora_model)
+    pspecs = nnx.get_partition_spec(state)
+    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+    nnx.update(lora_model, sharded_state)
+
+  return lora_model
+
+
+# %%
+if TRAIN_WITH_LORA:
+  gemma4_actor = get_lora_model(gemma4_ref, trainer_mesh)
+else:
+  # gemma4_actor = params_lib.create_model_from_safe_tensors(
+  #     MODEL_PATH, config, trainer_mesh, dtype=MODEL_DTYPE
+  # )
+  graph, state = nnx.split(gemma4_ref)
+  trainer_shardings = jax.tree_util.tree_map(
+      lambda x: jax.sharding.NamedSharding(
+          trainer_mesh,
+          x,
+      ),
+      nnx.get_partition_spec(state),
+  )
+  gemma4_actor = nnx.merge(
+      graph, reshard.reshard_pytree(state, trainer_shardings)
+  )
+
+# %%
+show_hbm_usage("after loading gemma4_actor")
+
+
+# %%
+# Ckpt saving
+checkpointing_options = ocp.CheckpointManagerOptions(
+    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
+)
+
+# Metrics logger
+wandb_config = vars(args)
+wandb_config.update({
+    "WARMUP_STEPS": WARMUP_STEPS,
+    "num_steps": MAX_STEPS,
+    "rollout_engine": ROLLOUT_ENGINE,
+})
+metrics_logging_options = metrics_logger.MetricsLoggerOptions(
+    log_dir="gs://linchai-bucket-dev/tensorboard/grpo",
+    flush_every_n_steps=20,
+    backend_kwargs={"wandb": {"config": wandb_config}},
+)
+
+# %%
+# Optimizer, learning rate scheduler, gradient clipping
+optimizer = optax.adamw(
+    learning_rate=LEARNING_RATE,
+    b1=B1,
+    b2=B2,
+    weight_decay=WEIGHT_DECAY,
+)
+if MAX_GRAD_NORM is not None:
+  optimizer = optax.chain(
+      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
+      optimizer,
+  )
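The warmup/cosine comments earlier describe a schedule, but the optimizer above is built with a constant `LEARNING_RATE`. A hedged sketch of wiring the schedule in with optax, reusing the `WARMUP_STEPS` and `MAX_STEPS` defined earlier, should that behavior be desired:

```python
# Hypothetical alternative to the constant learning rate above.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,          # start from 0 ...
    peak_value=LEARNING_RATE,  # ... warm up linearly to the peak rate,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,   # then cosine-decay over the remaining steps
    end_value=0.0,
)
# optimizer = optax.adamw(
#     learning_rate=lr_schedule, b1=B1, b2=B2, weight_decay=WEIGHT_DECAY
# )
```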
"rollout_vllm_kwargs": { + "kv_cache_metrics": True, + "disable_log_stats": False, + "enable_prefix_caching": False, + "dtype": "bfloat16", + }, +} + +if ROLLOUT_ENGINE == "vllm": + rollout_engine_config = base_rollout.RolloutConfig( + **base_rollout_dict, **vllm_rollout_dict + ) +elif ROLLOUT_ENGINE == "vanilla": + rollout_engine_config = base_rollout.RolloutConfig(**base_rollout_dict) +else: + raise ValueError(f"Unsupported rollout engine: {ROLLOUT_ENGINE}") + +cluster_config = rl_cluster_lib.ClusterConfig( + role_to_mesh={ + rl_cluster_lib.Role.ACTOR: trainer_mesh, + rl_cluster_lib.Role.REFERENCE: reference_mesh, + rl_cluster_lib.Role.ROLLOUT: rollout_mesh, + }, + rollout_engine=ROLLOUT_ENGINE, + offload_to_cpu=False, + training_config=rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optimizer, + eval_every_n_steps=EVAL_EVERY_N_STEPS, + max_steps=MAX_STEPS, + mini_batch_size=MINI_BATCH_SIZE, + train_micro_batch_size=1, + # metrics logging + metrics_logging_options=metrics_logging_options, + # checkpoint saving + checkpoint_root_directory=CKPT_DIR, + checkpointing_options=checkpointing_options, + ), + rollout_config=rollout_engine_config, +) + +grpo_config = GRPOConfig( + num_generations=NUM_GENERATIONS, + num_iterations=NUM_ITERATIONS, + max_response_length=MAX_RESPONSE_LENGTH, + beta=BETA, + epsilon=EPSILON, + epsilon_high=EPSILON_HIGH, + system_prompt="", + max_concurrency=MAX_CONCURRENCY, + off_policy_steps=OFF_POLICY_STEPS, + loss_agg_mode=args.loss_agg_mode, + kl_loss_mode=args.kl_loss_mode, +) + +# Perf Metrics logging +perf_metrics_config = PerfMetricsConfig( + custom_export_fn_v2=PerfMetricsExport.from_cluster_config( + cluster_config=cluster_config, + trace_dir="/tmp/agentic_perf", + ).export_metrics +) + +# %% +# RL cluster +rl_cluster = rl_cluster_lib.RLCluster( + actor=gemma4_actor, + reference=gemma4_ref, + tokenizer=tokenizer, + cluster_config=cluster_config, + perf_config=perf_metrics_config, +) + +show_hbm_usage("after RLCluster creation") + + +# %% +def metric_fn(prompts, completions, rewards, advantages, **kwargs): + del prompts, completions, advantages, kwargs + solve_all = (rewards > 0.1).all() + solve_none = (rewards == 0).all() + solve_partial = (~solve_all) and (~solve_none) + solve_ratio = (rewards > 0.1).mean() + return { + "rewards/solve_all": ( + 1 if solve_all else 0, + np.mean, + ), + "rewards/solve_none": ( + 1 if solve_none else 0, + np.mean, + ), + "rewards/solve_partial": ( + 1 if solve_partial else 0, + np.mean, + ), + "rewards/solve_ratio": ( + solve_ratio, + np.mean, + ), + } + + +# GRPO Trainer +grpo_trainer = GRPOLearner( + rl_cluster=rl_cluster, + agent_class=FrozenLakeAgent, + agent_kwargs={}, + env_class=FrozenLakeEnv, + env_kwargs={"max_steps": "5"}, + algo_config=grpo_config, + chat_parser=chat_parser, + metric_fns=[metric_fn], +) +show_hbm_usage("after GRPOLearner creation") + +grpo_trainer.train(train_dataset) diff --git a/tunix/rl/agentic/agentic_grpo_learner.py b/tunix/rl/agentic/agentic_grpo_learner.py index 1e635d9b9..4b5f3877f 100644 --- a/tunix/rl/agentic/agentic_grpo_learner.py +++ b/tunix/rl/agentic/agentic_grpo_learner.py @@ -248,6 +248,7 @@ def __init__( "pg_loss": np.mean, "pg_clipfrac": np.mean, "ppo_kl": np.mean, + "kl_loss": np.mean, }) self.rl_cluster.actor_trainer.with_tqdm_metrics_to_display([ lambda: "kl"