From 5507e82ac766f3f7e4e893e84402c40e3d25f449 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 15:50:26 +0200 Subject: [PATCH 01/10] Removed generate_rollout_completions --- docs/source/openenv.md | 86 +- examples/notebooks/README.md | 2 +- ...rpo_functiongemma_browsergym_openenv.ipynb | 2363 ++++------------- examples/scripts/openenv/browsergym.py | 6 +- examples/scripts/openenv/browsergym_llm.py | 10 +- trl/__init__.py | 2 + trl/chat_template_utils.py | 40 +- trl/chat_templates/functiongemma.jinja | 279 ++ trl/experimental/openenv/__init__.py | 18 - trl/experimental/openenv/utils.py | 214 -- 10 files changed, 849 insertions(+), 2171 deletions(-) create mode 100644 trl/chat_templates/functiongemma.jinja delete mode 100644 trl/experimental/openenv/__init__.py delete mode 100644 trl/experimental/openenv/utils.py diff --git a/docs/source/openenv.md b/docs/source/openenv.md index 9df3e5c8a89..18f38465d24 100644 --- a/docs/source/openenv.md +++ b/docs/source/openenv.md @@ -9,7 +9,7 @@ This guide covers **how to integrate OpenEnv with TRL**. For more on OpenEnv its ## When to use environments -[`GRPOTrainer`] can be used to train agents. For agentic tasks, it supports two modes: **tools**, where the model can call external functions but each call is stateless and independent, and **environments**, which maintain state across turns, enabling genuine multi-turn interaction where the agent's actions shape future observations. Use environments when continuity matters — for example, navigating a game, browsing a web page, or any task where what the agent sees next depends on what it did before. +[`GRPOTrainer`] can be used to train agents. For agentic tasks, it supports two modes: **tools**, where the model can call external functions but each call is stateless and independent, and **environments**, which maintain state across turns, enabling genuine multi-turn interaction where the agent's actions shape future observations. Use environments when continuity matters: for example, navigating a game, browsing a web page, or any task where what the agent sees next depends on what it did before. ## Installation @@ -24,6 +24,9 @@ pip install "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordl # Catch (OpenSpiel) environment pip install "openenv-openspiel-env @ git+https://huggingface.co/spaces/openenv/openspiel_env" + +# BrowserGym environment +pip install "openenv-browsergym @ git+https://huggingface.co/spaces/sergiopaniego/browsergym_env" ``` This installs the **environment client** (e.g., `EchoEnv`) that communicates with the remote environment server via WebSocket, along with the action/observation models and all required dependencies (including `openenv-core`). @@ -561,6 +564,15 @@ The best way to explore the current catalog of maintained environments is by vis To create your own environment, check out the guide on [Building Your Own Environment with OpenEnv](https://meta-pytorch.org/OpenEnv/auto_getting_started/plot_03_building_environments.html). Environments are tightly integrated with the Hub, so you can push new environments for the community to reuse. +## `environment_factory` vs `rollout_func` + +`environment_factory` is the only supported approach for environment-based training in TRL. You define an environment class with tool methods, and the trainer handles generation, tool-call parsing, and the multi-turn loop automatically. + +`rollout_func` is an experimental API that predates `environment_factory`. It is no longer recommended and will be removed in a future version. If you have existing scripts that use `rollout_func`, migrate them to `environment_factory`. + +> [!WARNING] +> `rollout_func` emits a deprecation warning at runtime and may be removed without prior notice. Do not use it for new projects. + ## Server concurrency When using `environment_factory`, the trainer creates N environment instances (one per generation), each opening a WebSocket connection to the server. By default, OpenEnv servers allow only 1 concurrent session, which will cause failures during training. @@ -585,75 +597,3 @@ app = create_app( > [!TIP] > `max_concurrent_envs` should be ≥ `generation_batch_size` (which defaults to `per_device_train_batch_size × gradient_accumulation_steps`). For example, with `gradient_accumulation_steps=64` and batch size 1, you need at least 64 concurrent sessions. -## `environment_factory` vs `rollout_func` - -[`GRPOTrainer`] supports two approaches for environment-based training: - -- **`environment_factory`** (recommended): You define an environment class with tool methods, and the trainer handles generation, tool-call parsing, and the multi-turn loop automatically. This is the approach used throughout this guide. -- **`rollout_func`**: You write the entire generation and environment interaction loop yourself. This gives full control over how completions are produced, how tools are executed, and how rewards are computed. - -Use `rollout_func` when `environment_factory` doesn't fit your use case. For example, **external agent servers** where an external server owns the generation loop and manages its own agent-environment interaction protocol. - -### Migrating from `rollout_func` to `environment_factory` - -If you have existing `rollout_func` code and want to migrate, here's the mapping: - -| `rollout_func` pattern | `environment_factory` equivalent | -|------------------------|----------------------------------| -| Manual generation loop | Handled automatically by the trainer | -| `generate_rollout_completions()` | Not needed, trainer generates internally | -| `env.step(Action(...))` in rollout | Wrap in a tool method on the environment class | -| Reward via `kwargs["env_reward"]` | Reward via `environments` parameter | -| `env_mask` construction | Automatic, trainer builds `tool_mask` | -| Token concatenation | Automatic, trainer manages token sequences | - -**Before** (`rollout_func`): - -```python -def rollout_func(prompts, trainer): - outputs = generate_rollout_completions(trainer, prompts) - env_rewards = [] - for out in outputs: - text = tokenizer.decode(out["completion_ids"], skip_special_tokens=True) - result = client.step(EchoAction(message=text)) - env_rewards.append(result.reward) - return { - "prompt_ids": [out["prompt_ids"] for out in outputs], - "completion_ids": [out["completion_ids"] for out in outputs], - "logprobs": [out["logprobs"] for out in outputs], - "env_reward": env_rewards, - } - -trainer = GRPOTrainer(..., rollout_func=rollout_func) -``` - -**After** (`environment_factory`): - -```python -class EchoToolEnv: - def __init__(self): - self.env = EchoEnv(base_url=url) - self.reward = 0.0 - - def reset(self, **kwargs) -> str | None: - self.reward = 0.0 - return None - - def echo(self, message: str) -> str: - """Echo the message back. - - Args: - message: The message to echo - - Returns: - The echoed message. - """ - result = self.env.step(EchoAction(message=message)) - self.reward = result.observation.reward - return result.observation.echoed_message - -def reward_func(environments, **kwargs): - return [env.reward for env in environments] - -trainer = GRPOTrainer(..., environment_factory=EchoToolEnv, reward_funcs=reward_func) -``` diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md index f744d417ec4..e6c8812e099 100644 --- a/examples/notebooks/README.md +++ b/examples/notebooks/README.md @@ -17,7 +17,7 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t ## OpenEnv Notebooks -These notebooks demonstrate GRPO training with [OpenEnv](https://github.com/meta-pytorch/OpenEnv) environments using `environment_factory`. The BrowserGym notebook uses the lower-level `rollout_func` API instead. +These notebooks demonstrate GRPO training with [OpenEnv](https://github.com/meta-pytorch/OpenEnv) environments using `environment_factory`. | Notebook | Description | Open in Colab | | --- | --- | --- | diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index e19df898d54..b4bb396280c 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -1,1914 +1,565 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "lSR2nwdJg962" - }, - "source": [ - "# Fine-Tune FunctionGemma using Hugging Face TRL and OpenEnv\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb)\n", - "\n", - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n", - "\n", - "This guide describes the process of fine-tuning [FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) by Google DeepMind in the [BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) environment provided by OpenEnv, using Hugging Face TRL. The steps covered include:\n", - "\n", - "* What is GRPO and OpenEnv\n", - "* Setup dependencies for training\n", - "* Initialize the OpenEnv's BrowserGym environment\n", - "* Create rollout function with helpers\n", - "* Define the reward functions\n", - "* Load the custom dataset\n", - "* Fine tune using TRL and the GRPOTrainer\n", - "* Load the fine-tuned model and run inference\n", - "\n", - "> Note: The guide is designed to run on Google Colaboratory with access to an NVIDIA A100 GPU (40GB) using FunctionGemma. The workflow can be adapted to other GPU configurations, models, or environments." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "duXYuR6Cu_na" - }, - "source": [ - "## What is GRPO and OpenEnv\n", - "\n", - "Group Relative Policy Optimization ([GRPO](https://huggingface.co/papers/2402.03300)) is a post-training method widely used for efficiently fine-tuning large language models. GRPO leverages reward functions to guide learning, enabling models to optimize task-specific behaviors without retraining the entire network.\n", - "\n", - "[OpenEnv](https://meta-pytorch.org/OpenEnv) provides a standard interface for interacting with agentic execution environments using simple Gymnasium-style APIs, such as `step()`, `reset()`, and `state()`. These APIs facilitate reinforcement learning training loops by allowing models to interact with environments in a structured manner. OpenEnv also offers tools for environment creators to build isolated, secure, and deployable environments that can be shared via common protocols like HTTP or packaged in Docker.\n", - "\n", - "The combination of GRPO and OpenEnv enables efficient fine-tuning of models in controlled, interactive tasks while minimizing resource requirements." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cpSAQkzKmv50" - }, - "source": [ - "## Setup dependencies for training\n", - "\n", - "Install the required libraries, including Hugging Face TRL for fine-tuning and OpenEnv for reinforcement learning environments." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "c-2drnj5BP56" - }, - "outputs": [], - "source": [ - "!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/openenv/browsergym_env liger-kernel trackio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Inxeq6ZGpRno" - }, - "source": [ - "A valid Hugging Face token is required to save the fine-tuned model. In Google Colab, the token can be securely accessed through Colab secrets. Otherwise, it can be provided directly in the login method. Ensure the token has write permissions to allow uploading the model to the Hugging Face Hub during training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "C4q5UVu3BP57" - }, - "outputs": [], - "source": [ - "from google.colab import userdata\n", - "from huggingface_hub import login\n", - "\n", - "# Login into Hugging Face Hub\n", - "hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab\n", - "login(hf_token)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O3kr38TGm_hb" - }, - "source": [ - "## Initialize the OpenEnv's BrowserGym environment\n", - "\n", - "External environments can guide the fine-tuning of LLMs for function calling by providing interactive feedback that enhances performance on task-specific behaviors.\n", - "\n", - "[BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) is a unified framework for web-based agent tasks, offering multiple benchmarks through a Gymnasium-compatible API. It enables training on simple synthetic tasks with [MiniWoB++](https://github.com/Farama-Foundation/miniwob-plusplus) and evaluation on more complex, realistic tasks with [WebArena](https://github.com/web-arena-x/webarena), [VisualWebArena](https://github.com/web-arena-x/visualwebarena), or [WorkArena](https://github.com/ServiceNow/WorkArena). This setup supports iterative training and assessment of web agents without requiring extensive infrastructure.\n", - "\n", - "BrowserGym supports both LLM and VLM training by providing visual information, including screenshots and DOM data, which can be utilized depending on the model type. This guide focuses on a simple web-based task called *\"click-test\"*, which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. For this example, the remote environment [openenv/browsergym_env](https://huggingface.co/spaces/openenv/browsergym_env) will be used.\n", - "\n", - "> Note: Hosted environments on the Hub currently have limited concurrency. For higher reliability or parallel runs, duplicating the Space to your own account is strongly recommended." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "clDs-WQlBP57" - }, - "outputs": [], - "source": [ - "from browsergym_env import BrowserGymEnv\n", - "space_url = \"https://openenv-browsergym-env.hf.space\"\n", - "\n", - "client = BrowserGymEnv(base_url=space_url)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EqfDavDQnD_5" - }, - "source": [ - "## Create rollout function with helpers\n", - "\n", - "The rollout function defines how the agent interacts with the environment during GRPO training. It generates model outputs, collects feedback in the form of rewards, and returns the information required for optimization.\n", - "\n", - "In this setup:\n", - "- The function is invoked automatically by the GRPOTrainer (introduced later), which orchestrates the training loop and handles policy updates.\n", - "- It uses the trainer's `generate_rollout_completions()` method for efficient output generation. This leverages vLLM, a high-performance inference engine for large language models, and is integrated within TRL to streamline rollout generation and reward collection during fine-tuning.\n", - "- Each rollout represents a complete interaction loop, where the model acts, receives feedback from the environment, and updates based on reward signals.\n", - "\n", - "Rewards capture various aspects of the agent's performance. Helper functions, such as `rollout_once`, manage individual episodes, keeping the main `rollout_func` clean, modular, and reusable.\n", - "\n", - "This modular structure allows GRPO to efficiently sample, evaluate, and refine the model's behavior through reinforcement learning.\n", - "\n", - "Before executing rollouts, a `system prompt` is defined to instruct the model on how to interact with the environment. This prompt specifies the available BrowserGym actions (such as `click`, `fill`, `send_keys`, and `scroll`), describes the page structure, and enforces that the model responds with exactly one action per step. It ensures consistent and structured interactions, guiding the model to complete tasks effectively without providing extra explanations or multiple actions." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ItCXS6H0BP58" - }, - "outputs": [], - "source": [ - "# @title System prompt (click to expand)\n", - "SYSTEM_PROMPT = \"\"\"You control a web browser through BrowserGym actions.\n", - "You must complete the given web task by interacting with the page.\n", - "\n", - "Available actions:\n", - "- noop() - Do nothing\n", - "- click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "- fill(bid, text) - Fill input field with text\n", - "- send_keys(text) - Send keyboard input\n", - "- scroll(direction) - Scroll up/down\n", - "\n", - "The page structure shows elements as: [bid] element_type 'element_text'\n", - "For example: [13] button 'Click Me!' means bid='13'\n", - "\n", - "Reply with exactly ONE action on a single line, e.g.:\n", - "click('13')\n", - "fill('42', 'hello world')\n", - "noop()\n", - "\n", - "Do not include explanations or multiple actions.\"\"\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Vi1rFey39GUl" - }, - "source": [ - "The `rollout_func` orchestrates the interaction between the model and the remote BrowserGym environment. For each prompt in the batch, it executes a complete episode using the `rollout_once` function, collecting model outputs and rewards for GRPO optimization.\n", - "\n", - "The parameter `max_steps` defines the maximum number of steps the model can take within a single episode. This limits the length of the interaction loop, ensuring that episodes terminate even if the task is not completed, and helps maintain efficient training.\n", - "\n", - "During each episode, the function tracks prompt and completion IDs, log probabilities, and both step-wise and final rewards, returning them in a structured format for the trainer to perform policy updates." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CgHd5CFBBP58" - }, - "outputs": [], - "source": [ - "from trl import GRPOTrainer\n", - "\n", - "max_steps=10\n", - "\n", - "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n", - " episode_prompt_ids: list[list[int]] = []\n", - " episode_completion_ids: list[list[int]] = []\n", - " episode_logprobs: list[list[float]] = []\n", - " completion_rewards: list[float] = []\n", - "\n", - " print(f\"\\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)\")\n", - "\n", - " for i, prompt_text in enumerate(prompts):\n", - " print(f\"[DEBUG] Processing prompt {i + 1}/{len(prompts)}\")\n", - " episode = rollout_once(\n", - " trainer=trainer,\n", - " env=client,\n", - " tokenizer=trainer.processing_class,\n", - " dataset_prompt=prompt_text,\n", - " max_steps=max_steps,\n", - " )\n", - " episode_prompt_ids.append(episode[\"prompt_ids\"])\n", - " episode_completion_ids.append(episode[\"completion_ids\"])\n", - " episode_logprobs.append(episode[\"logprobs\"])\n", - " completion_rewards.append(episode[\"completion_reward\"])\n", - "\n", - " return {\n", - " \"prompt_ids\": episode_prompt_ids,\n", - " \"completion_ids\": episode_completion_ids,\n", - " \"logprobs\": episode_logprobs,\n", - " \"completion_reward\": completion_rewards,\n", - " }" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ioUHdIxr9ZQO" - }, - "source": [ - "### Define `rollout_once`\n", - "\n", - "The `rollout_once` function runs one complete interaction loop between the model and the BrowserGym environment using the trainer's generation method. \n", - "It executes a single episode, from generating an action to receiving feedback and computing rewards.\n", - "\n", - "Here's the step-by-step breakdown:\n", - "\n", - "1. Environment reset: Start a new BrowserGym session and initialize the observation.\n", - "2. Prompt construction: Combine the system prompt, environment observation (text-only via the accessibility tree), and any relevant errors or state information to form the model input.\n", - "3. Generation: Use `trl.experimental.openenv.generate_rollout_completions()` to produce the model's action efficiently with vLLM.\n", - "4. Action parsing and execution: Interpret the model's output and execute the corresponding BrowserGym action (e.g., `click`, `fill`, `scroll`).\n", - "5. Reward calculation: Track step-wise rewards provided by the environment and compute completion rewards based on task success or failure.\n", - "6. Return structured rollout data: Includes prompt/completion IDs, log probabilities, step rewards, and the final reward for the episode.\n", - "\n", - "This modular design allows each episode to be processed independently while providing rich feedback for the GRPO training loop, supporting both task completion and intermediate reward shaping." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "y8Ml47SYBP58" - }, - "outputs": [], - "source": [ - "from trl.experimental.openenv import generate_rollout_completions\n", - "from browsergym_env import BrowserGymAction\n", - "from transformers import AutoTokenizer\n", - "\n", - "def rollout_once(\n", - " trainer: GRPOTrainer,\n", - " env: BrowserGymEnv,\n", - " tokenizer: AutoTokenizer,\n", - " dataset_prompt: str,\n", - " max_steps: int,\n", - ") -> dict[str, list]:\n", - " \"\"\"Run one episode and collect training data (text-only, no screenshots).\"\"\"\n", - " result = env.reset()\n", - " observation = result.observation\n", - "\n", - " prompt_ids: list[int] = []\n", - " completion_ids: list[int] = []\n", - " logprobs: list[float] = []\n", - " step_rewards: list[float] = []\n", - " completion_rewards: list[float] = []\n", - "\n", - " for step_num in range(max_steps):\n", - " if result.done:\n", - " break\n", - "\n", - " # Create prompt from observation (text-only using accessibility tree)\n", - " goal = observation.goal or dataset_prompt\n", - " axtree = observation.axtree_txt or \"\"\n", - " error = observation.error if observation.last_action_error else \"\"\n", - "\n", - " user_prompt = make_user_prompt(goal, step_num, axtree, error)\n", - " messages = [\n", - " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", - " {\"role\": \"user\", \"content\": user_prompt},\n", - " ]\n", - " prompt_text = tokenizer.apply_chat_template(\n", - " messages,\n", - " add_generation_prompt=True,\n", - " tokenize=False,\n", - " )\n", - "\n", - " # Generate action with vLLM\n", - " rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n", - " prompt_ids.extend(rollout_outputs[\"prompt_ids\"])\n", - " completion_ids.extend(rollout_outputs[\"completion_ids\"])\n", - " logprobs.extend(rollout_outputs[\"logprobs\"])\n", - "\n", - " completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n", - " rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n", - " )\n", - "\n", - " # Parse and execute action\n", - " action_str = parse_action(completion_text)\n", - "\n", - " print(f\"Step {step_num + 1}: {action_str}\")\n", - "\n", - " # Take action in environment\n", - " result = env.step(BrowserGymAction(action_str=action_str))\n", - " observation = result.observation\n", - "\n", - " # Track rewards\n", - " step_reward = float(result.reward or 0.0)\n", - " step_rewards.append(step_reward)\n", - "\n", - " # Reward shaping: success is most important\n", - " if result.done and step_reward > 0:\n", - " completion_rewards.append(1.0) # Task completed successfully\n", - " elif result.done and step_reward == 0:\n", - " completion_rewards.append(0.0) # Task failed\n", - " else:\n", - " completion_rewards.append(step_reward) # Intermediate reward\n", - "\n", - " # Final reward is based on task completion\n", - " final_reward = completion_rewards[-1] if completion_rewards else 0.0\n", - "\n", - " return {\n", - " \"prompt_ids\": prompt_ids,\n", - " \"completion_ids\": completion_ids,\n", - " \"logprobs\": logprobs,\n", - " \"step_rewards\": step_rewards,\n", - " \"completion_reward\": final_reward,\n", - " }" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MDJKMQ__8qzj" - }, - "source": [ - "### Helper functions\n", - "\n", - "Supporting utilities used in `rollout_once`:\n", - "\n", - "- `make_user_prompt`: builds the user prompt combining the base text and previous game messages.\n", - "- `parse_action`: parses BrowserGym action from model response" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GG4ba41PBP58" - }, - "outputs": [], - "source": [ - "# @title Helpers (click to expand)\n", - "def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = \"\") -> str:\n", - " \"\"\"Create user prompt from observation.\"\"\"\n", - " prompt_parts = [f\"Step {step_num + 1}\"]\n", - "\n", - " if goal:\n", - " prompt_parts.append(f\"Goal: {goal}\")\n", - "\n", - " if error:\n", - " prompt_parts.append(f\"Previous action error: {error}\")\n", - "\n", - " # Include accessibility tree (truncated for context)\n", - " if axtree:\n", - " max_len = 2000\n", - " axtree_truncated = axtree[:max_len] + \"...\" if len(axtree) > max_len else axtree\n", - " prompt_parts.append(f\"Page structure:\\n{axtree_truncated}\")\n", - "\n", - " prompt_parts.append(\"What action do you take?\")\n", - "\n", - " return \"\\n\\n\".join(prompt_parts)\n", - "\n", - "\n", - "def parse_action(response_text: str) -> str:\n", - " \"\"\"Parse BrowserGym action from model response.\"\"\"\n", - " # Extract first line that looks like an action\n", - " for line in response_text.strip().split(\"\\n\"):\n", - " line = line.strip()\n", - " if \"(\" in line and \")\" in line:\n", - " return line\n", - "\n", - " # Fallback to noop if no valid action found\n", - " return \"noop()\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Oek3JhcWnKhw" - }, - "source": [ - "## Define the reward functions\n", - "\n", - "Reward functions quantify the model's performance in the environment and guide the GRPO optimization process.\n", - "\n", - "In this setup, the `reward_completion` function assigns rewards based on task completion. It extracts the final reward for each episode, which indicates whether the agent successfully completed the task. If no reward information is available, it defaults to zero.\n", - "\n", - "This modular approach allows additional reward functions to be added easily, enabling more granular feedback such as intermediate progress, efficiency, or correctness of actions, depending on the task requirements." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WxkXaz5aBP59" - }, - "outputs": [], - "source": [ - "def reward_completion(completions: list[str], **kwargs) -> list[float]:\n", - " \"\"\"Reward for task completion.\"\"\"\n", - " rewards = kwargs.get(\"completion_reward\") if kwargs else None\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "66ZsrLplm07U" - }, - "source": [ - "## Load the custom dataset\n", - "\n", - "The dataset is constructed with repeated prompts to control the total number of training episodes.\n", - "\n", - "Each entry in the dataset triggers a single rollout episode during training. The `dataset_prompt` provides the initial instruction to the model at the start of each episode, ensuring consistent guidance for task execution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UX6jUjxaBP59" - }, - "outputs": [], - "source": [ - "from datasets import Dataset\n", - "\n", - "dataset_prompt = \"Complete the web task successfully.\"\n", - "dataset_size = 1000\n", - "\n", - "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "lSR2nwdJg962" + }, + "source": "# Fine-Tune FunctionGemma using Hugging Face TRL and OpenEnv\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb)\n\n![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n\nThis guide describes the process of fine-tuning [FunctionGemma](https://huggingface.co/google/functiongemma-270m-it) by Google DeepMind in the [BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) environment provided by OpenEnv, using Hugging Face TRL. The steps covered include:\n\n* What is GRPO and OpenEnv\n* Setup dependencies for training\n* Initialize the OpenEnv's BrowserGym environment\n* Define the reward functions\n* Load the custom dataset\n* Fine tune using TRL and the GRPOTrainer\n* Load the fine-tuned model and run inference\n\n> Note: The guide is designed to run on Google Colaboratory with access to an NVIDIA A100 GPU (40GB) using FunctionGemma. The workflow can be adapted to other GPU configurations, models, or environments." + }, + { + "cell_type": "markdown", + "metadata": { + "id": "duXYuR6Cu_na" + }, + "source": [ + "## What is GRPO and OpenEnv\n", + "\n", + "Group Relative Policy Optimization ([GRPO](https://huggingface.co/papers/2402.03300)) is a post-training method widely used for efficiently fine-tuning large language models. GRPO leverages reward functions to guide learning, enabling models to optimize task-specific behaviors without retraining the entire network.\n", + "\n", + "[OpenEnv](https://meta-pytorch.org/OpenEnv) provides a standard interface for interacting with agentic execution environments using simple Gymnasium-style APIs, such as `step()`, `reset()`, and `state()`. These APIs facilitate reinforcement learning training loops by allowing models to interact with environments in a structured manner. OpenEnv also offers tools for environment creators to build isolated, secure, and deployable environments that can be shared via common protocols like HTTP or packaged in Docker.\n", + "\n", + "The combination of GRPO and OpenEnv enables efficient fine-tuning of models in controlled, interactive tasks while minimizing resource requirements." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cpSAQkzKmv50" + }, + "source": [ + "## Setup dependencies for training\n", + "\n", + "Install the required libraries, including Hugging Face TRL for fine-tuning and OpenEnv for reinforcement learning environments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c-2drnj5BP56" + }, + "outputs": [], + "source": [ + "!pip install -qU trl\n", + "!pip install -qU jmespath\n", + "!pip install -qU git+https://huggingface.co/spaces/sergiopaniego/browsergym_env\n", + "!pip install -qU trackio\n", + "!pip install -qU transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Inxeq6ZGpRno" + }, + "source": [ + "A valid Hugging Face token is required to save the fine-tuned model. In Google Colab, the token can be securely accessed through Colab secrets. Otherwise, it can be provided directly in the login method. Ensure the token has write permissions to allow uploading the model to the Hugging Face Hub during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C4q5UVu3BP57" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", + "from huggingface_hub import login\n", + "\n", + "# Login into Hugging Face Hub\n", + "hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab\n", + "login(hf_token)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O3kr38TGm_hb" + }, + "source": [ + "## Initialize the OpenEnv's BrowserGym environment\n", + "\n", + "External environments can guide the fine-tuning of LLMs for function calling by providing interactive feedback that enhances performance on task-specific behaviors.\n", + "\n", + "[BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) is a unified framework for web-based agent tasks, offering multiple benchmarks through a Gymnasium-compatible API. It enables training on simple synthetic tasks with [MiniWoB++](https://github.com/Farama-Foundation/miniwob-plusplus) and evaluation on more complex, realistic tasks with [WebArena](https://github.com/web-arena-x/webarena), [VisualWebArena](https://github.com/web-arena-x/visualwebarena), or [WorkArena](https://github.com/ServiceNow/WorkArena). This setup supports iterative training and assessment of web agents without requiring extensive infrastructure.\n", + "\n", + "BrowserGym supports both LLM and VLM training by providing visual information, including screenshots and DOM data, which can be utilized depending on the model type. This guide focuses on a simple web-based task called *\"click-test\"*, which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. For this example, the remote environment [sergiopaniego/browsergym_env](https://huggingface.co/spaces/sergiopaniego/browsergym_env) will be used.\n", + "\n", + "> Note: Hosted environments on the Hub currently have limited concurrency. For higher reliability or parallel runs, duplicating the Space to your own account is strongly recommended." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "clDs-WQlBP57" + }, + "outputs": [], + "source": [ + "from browsergym_env import BrowserGymEnv\n", + "space_url = \"https://sergiopaniego-browsergym-env.hf.space\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EqfDavDQnD_5" + }, + "source": [ + "## Define BrowserGym environment factory\n", + "\n", + "The `environment_factory` defines how the model interacts with BrowserGym during GRPO training. Instead of a manual generation loop, the model makes structured tool calls (`click`, `fill`, `send_keys`, …) and the trainer handles generation, tool-call parsing, and the multi-turn loop automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ItCXS6H0BP58" + }, + "outputs": [], + "source": [ + "# @title System prompt (click to expand)\nSYSTEM_PROMPT = \"\"\"You control a web browser to complete tasks using tool calls.\n\nPage elements show their numeric bid: [bid:N] type 'text'.\nExample: [bid:7] link 'Home' means bid='7' (use the NUMBER, never the text label).\n\nCall exactly one tool per step. When the task is complete, call noop.\"\"\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Vi1rFey39GUl" + }, + "source": [ + "`BrowserGymFunctionGemmaEnv` wraps the BrowserGym client and exposes each browser action as a tool method. The trainer creates one environment instance per rollout, calls `reset()` to initialize the episode and provide the initial observation, and then lets the model interact via tool calls for up to `max_steps` steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CgHd5CFBBP58" + }, + "outputs": [], + "source": [ + "import re\nfrom browsergym_env import BrowserGymAction\n\nmax_steps = 10\n\n\nclass BrowserGymFunctionGemmaEnv:\n def __init__(self):\n self.client = BrowserGymEnv(base_url=space_url).sync()\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n\n def _ensure_large_max_size(self):\n \"\"\"Patch the WebSocket connection to allow messages up to 100 MB.\n\n Some accessibility trees exceed the default 1 MB limit, which causes\n the frame to be dropped silently and the observation to be empty.\n \"\"\"\n self.client.connect()\n ws = self.client._ws\n if ws is not None and hasattr(ws, \"protocol\"):\n proto = ws.protocol\n # websockets <16: max_size; websockets >=16: max_message_size\n attr = \"max_size\" if hasattr(proto, \"max_size\") else \"max_message_size\"\n if getattr(proto, attr) == 2**20:\n setattr(proto, attr, 100 * 1024 * 1024)\n\n def reset(self, **kwargs) -> str:\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n self._ensure_large_max_size()\n result = self.client.reset()\n self._done = result.done\n return self._format_observation(result.observation)\n\n def click(self, bid: str) -> str:\n \"\"\"Click an element on the page.\n\n Args:\n bid: The BrowserGym ID of the element to click.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"click({repr(bid)})\")\n\n def fill(self, bid: str, text: str) -> str:\n \"\"\"Fill an input field with text.\n\n Args:\n bid: The BrowserGym ID of the input field.\n text: The text to type into the field.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"fill({repr(bid)}, {repr(text)})\")\n\n def send_keys(self, text: str) -> str:\n \"\"\"Send keyboard input to the page.\n\n Args:\n text: The keyboard input to send.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"send_keys({repr(text)})\")\n\n def scroll(self, direction: str) -> str:\n \"\"\"Scroll the page.\n\n Args:\n direction: Direction to scroll, either 'up' or 'down'.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"scroll({repr(direction)})\")\n\n def noop(self) -> str:\n \"\"\"Do nothing and observe the current page state.\n\n Returns:\n The current page observation.\n \"\"\"\n return self._do_action(\"noop()\")\n\n def _do_action(self, action_str: str) -> str:\n if self._done:\n raise ValueError(\"Episode is done.\")\n self._step_count += 1\n result = self.client.step(BrowserGymAction(action_str=action_str))\n step_reward = float(result.reward or 0.0)\n self._done = result.done\n if self._done and step_reward > 0:\n self.reward = 1.0\n elif self._done:\n self.reward = 0.0\n else:\n self.reward = step_reward\n if self._step_count >= max_steps:\n self._done = True\n return self._format_observation(result.observation)\n\n def _format_observation(self, observation) -> str:\n parts = []\n if observation.goal:\n parts.append(f\"Goal: {observation.goal}\")\n if observation.last_action_error and observation.error:\n parts.append(f\"Error: {observation.error}\")\n if observation.axtree_txt:\n axtree = observation.axtree_txt\n axtree = re.sub(r'\\[(\\d+)\\]', r'[bid:\\1]', axtree) # [13] → [bid:13] so model uses the number as bid\n if len(axtree) > 2000:\n axtree = axtree[:2000] + \"...\"\n parts.append(f\"Page structure:\\n{axtree}\")\n parts.append(f\"Step {self._step_count + 1}/{max_steps}: call a tool to act.\")\n return \"\\n\\n\".join(parts) if parts else \"No observation available.\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Oek3JhcWnKhw" + }, + "source": "## Define the reward functions\n\nReward functions quantify the model's performance in the environment and guide the GRPO optimization process.\n\n`reward_completion` assigns 1.0 when the agent successfully completes the task, 0.0 otherwise." + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WxkXaz5aBP59" + }, + "outputs": [], + "source": "def reward_completion(environments, **kwargs) -> list[float]:\n \"\"\"Reward for task completion.\"\"\"\n return [env.reward for env in environments]" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "66ZsrLplm07U" + }, + "source": [ + "## Load the custom dataset\n", + "\n", + "The dataset provides the initial prompt for each episode. Each entry triggers one rollout during training. The system prompt instructs the model to use the available browser tool calls." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UX6jUjxaBP59" + }, + "outputs": [], + "source": [ + "from datasets import Dataset\n", + "\n", + "dataset_prompt = \"Complete the web task successfully.\"\n", + "dataset_size = 1000\n", + "\n", + "dataset = Dataset.from_dict(\n", + " {\"prompt\": [\n", + " [\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": dataset_prompt},\n", + " ]\n", + " ] * dataset_size}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TZ34a1h-BP59" + }, + "outputs": [], + "source": [ + "from trl import GRPOConfig, GRPOTrainer\n\noutput_dir = \"browsergym-grpo-functiongemma-270m-it\"\n\ngrpo_config = GRPOConfig(\n # num_train_epochs=1, # Number of times to iterate over the full dataset (use for full training runs)\n max_steps=100, # Number of dataset passes (for shorter runs/testing). For full trainings, use `num_train_epochs` instead\n learning_rate=5e-6, # Learning rate for the optimizer\n warmup_steps=10, # Number of steps to linearly increase learning rate at the start of training\n max_grad_norm=0.1, # Clip gradients to prevent explosion\n beta=0.01, # KL penalty against reference policy\n\n per_device_train_batch_size=1, # Number of samples per device per step\n num_generations=4, # Number of completions to generate per prompt\n generation_batch_size=4, # Must be divisible by num_generations\n max_completion_length=128, # Maximum tokens per model generation\n max_tool_calling_iterations=max_steps, # Cap tool-call rounds per episode to match environment's max_steps\n\n use_liger_kernel=True, # Liger kernel optimizations for faster training\n\n output_dir=str(output_dir), # Directory where checkpoints, logs, and outputs will be saved\n logging_steps=1, # Log metrics every N steps\n log_completions=True, # Print completions in the log to inspect what the model generates\n num_completions_to_print=1, # Show only 1 completion per step in the log\n report_to=\"trackio\", # Logging/reporting platform\n trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n push_to_hub=True, # Optionally push trained model to Hugging Face Hub\n\n model_init_kwargs={\"attn_implementation\": \"eager\"}, # Gemma3 requires eager attention (sdpa causes numerical instability)\n)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a1taGmD--0Y4" + }, + "source": [ + "The next step is to initialize the GRPOTrainer, which manages the reinforcement learning loop.\n", + "\n", + "It receives the model name, reward functions, environment factory, and dataset. The trainer initializes the model and tokenizer from the model name, coordinates multi-turn interactions between the model and BrowserGym via tool calls, applies reward signals, and updates the policy.\n", + "\n", + "Calling `trainer.train()` starts fine-tuning.\n", + "\n", + "> Note: The training pipeline uses approximately 10.6 GB of GPU VRAM and can be adapted to different hardware configurations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "En43o4NZBP59" + }, + "outputs": [], + "source": [ + "model_name = \"google/functiongemma-270m-it\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\nfrom trl import add_response_schema\n\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nadd_response_schema(tokenizer) # Sets tokenizer.response_schema for FunctionGemma's tool-call format\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e1PrBB7gBP59", + "outputId": "61740a89-228c-4b3c-8e59-b4a3eb972c03" + }, + "outputs": [], + "source": "trainer = GRPOTrainer(\n model=model_name,\n processing_class=tokenizer,\n reward_funcs=[reward_completion],\n train_dataset=dataset,\n args=grpo_config,\n environment_factory=BrowserGymFunctionGemmaEnv,\n)" + }, + { + "cell_type": "code", + "source": "trainer_stats = trainer.train()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BZj4IG9ZBAix" + }, + "source": [ + "In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "244ced1920694dbaae9bf98065b4f01d", + "e3769ae107554c9ba38c1e491b15bf4e", + "6d5b8bff73474faeb1d1b438fb4e8cec", + "9f952f8eb63b42e4b38711737da5461e", + "bd12780895064467b5be14e2ec3df114", + "d1261c1083a74dca877e6eece6395d73", + "999744cacd6a4fb08a1d4977ce2f06fd", + "faa5e0fb4ee244689c0f9eef9902acf7", + "6403bed2cd984ba18f74f416748c64e4", + "38be017369524e2eb22050e7a0a18ec5", + "b0720a4a2df948308011d4d87a288426", + "889ca2520f4d446daf2e6ed16ce11d2e" + ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "-mvka-96m3I7" + "id": "9oOBgEWeBP59", + "outputId": "76bef375-fc6b-4fdd-a296-549a9b109b11" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "244ced1920694dbaae9bf98065b4f01d", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "## Fine-tune using TRL and the GRPOTrainer\n", - "\n", - "The next step is to define the GRPOConfig, which sets all key training parameters.\n", - "\n", - "This configuration determines how the model interacts with vLLM, handles memory and computation, and records training metrics and logs for monitoring the fine-tuning process." + "text/plain": [ + "Processing Files (0 / 0) : | | 0.00B / 0.00B " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TZ34a1h-BP59" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e3769ae107554c9ba38c1e491b15bf4e", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "from trl import GRPOConfig\n", - "output_dir = \"browsergym-grpo-functiongemma-270m-it\"\n", - "\n", - "grpo_config = GRPOConfig(\n", - " # num_train_epochs=1, # Number of times to iterate over the full dataset (use for full training runs)\n", - " max_steps=100, # Number of dataset passes (for shorter runs/testing). For full trainings, use `num_train_epochs` instead\n", - " learning_rate=5e-6, # Learning rate for the optimizer\n", - " warmup_steps=10, # Number of steps to linearly increase learning rate at the start of training\n", - "\n", - " per_device_train_batch_size=1, # Number of samples per device per step\n", - " num_generations=4, # Number of completions to generate per prompt\n", - " generation_batch_size=4, # Batch size used during generation (must be divisible by num_generations)\n", - " max_completion_length=32, # Maximum length of generated completions\n", - "\n", - " use_vllm=True, # Use vLLM engine for fast inference\n", - " vllm_mode=\"colocate\", # vLLM mode: \"colocate\" runs generation on the same GPU as training\n", - " vllm_gpu_memory_utilization=0.1, # Fraction of GPU memory allocated to vLLM\n", - "\n", - " output_dir=str(output_dir), # Directory where checkpoints, logs, and outputs will be saved\n", - " logging_steps=1, # Log metrics every N steps\n", - " report_to=\"trackio\", # Logging/reporting platform (e.g., \"trackio\")\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " push_to_hub=True, # Optionally push trained model to Hugging Face Hub\n", - "\n", - " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", - ")\n" + "text/plain": [ + "New Data Upload : | | 0.00B / 0.00B " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "a1taGmD--0Y4" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6d5b8bff73474faeb1d1b438fb4e8cec", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "The next step is to initialize the GRPOTrainer, which manages the complete reinforcement learning loop.\n", - "\n", - "It receives the model name, reward functions, rollout function, and dataset defined earlier. From the model name, the trainer automatically initializes the model and tokenizer. It then coordinates interactions between the model and the environment, applies the defined reward signals, and updates the policy during training.\n", - "\n", - "Finally, calling `trainer.train()` starts the fine-tuning process, enabling the model to progressively improve its performance through iterative interaction and reinforcement learning.\n", - "\n", - "> Note: The training pipeline uses approximately 10.6 GB of GPU VRAM and can be adapted to different hardware configurations." + "text/plain": [ + " ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "En43o4NZBP59" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9f952f8eb63b42e4b38711737da5461e", + "version_major": 2, + "version_minor": 0 }, - "outputs": [], - "source": [ - "model_name = \"google/functiongemma-270m-it\"" + "text/plain": [ + " ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "047d386e54704add95edd4beace781d7" - ] - }, - "id": "k8-SvqJcBP59", - "outputId": "6a4d9276-fc91-4217-d3a2-51a18d222338" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bd12780895064467b5be14e2ec3df114", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipython-input-3830121904.py:1: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n", - " trainer = GRPOTrainer(\n", - "The model is already on multiple devices. Skipping the move to device specified in `args`.\n", - "`torch_dtype` is deprecated! Use `dtype` instead!\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "047d386e54704add95edd4beace781d7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1765969078\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: noop()\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: Click 'click(bid) - Click element with BrowserGym ID (the number in brackets\n", - "Step 8: I will use the action `click()` to click the button.\n", - "Step 9: noop()\n", - "Step 10: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: noop()\n", - "Step 2: noop()\n", - "Step 3: Clicks ('13')\n", - "Step 4: I will click 'Click Me!' using action 'click(bid)' on page 'Click Test Task' using a bid of '13'.\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: I will use the 'click(bid)' action.\n", - "Step 2: mouse_click(bid)\n", - "Step 3: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 4: Add action 'click(bid)' to Step 4.\n", - "Step 5: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 9: noop()\n", - "Step 10: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: noop()\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: Click('13')\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:liger_kernel.transformers.model.gemma3:It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`.\n", - "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py:282: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n", - " warnings.warn(\n", - "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7095: UserWarning: \n", - "Online softmax is disabled on the fly since Inductor decides to\n", - "split the reduction. Cut an issue to PyTorch if this is an\n", - "important use case and you want to speed it up with online\n", - "softmax.\n", - "\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [100/100 35:02, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
10.000000
20.000000
30.000000
40.000000
50.000000
60.000000
70.000000
80.000000
9-0.877900
101965.894400
11-0.830900
1210.616100
130.000000
140.000000
150.000000
160.000000
172.320100
181.887500
19-0.691600
20-0.764400
210.000000
220.000000
230.000000
240.000000
250.000000
260.000000
270.000000
280.000000
290.000000
300.000000
310.000000
320.000000
330.000000
340.000000
350.000000
360.000000
370.000000
380.000000
390.000000
400.000000
410.000000
420.000000
430.000000
440.000000
450.000000
460.000000
470.000000
480.000000
490.000000
500.000000
510.000000
520.000000
530.000000
540.000000
550.000000
560.000000
570.000000
580.000000
590.000000
600.000000
610.000000
620.000000
630.000000
640.000000
650.000000
660.000000
670.000000
680.000000
690.000000
700.000000
710.000000
720.000000
730.000000
740.000000
750.000000
760.000000
770.000000
780.000000
790.000000
800.000000
810.000000
820.000000
830.000000
840.000000
850.000000
860.000000
870.000000
880.000000
890.000000
900.000000
910.000000
920.000000
930.000000
940.000000
950.000000
960.000000
970.000000
980.000000
990.000000
1000.000000

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: Clicks ('13')\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 10: noop()\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: noop()\n", - "Step 2: I will use action: click(bid) to click the button.\n", - "Step 3: Yes, I can handle this. I will use the `click()` action to click the button.\n", - "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 9: noop()\n", - "Step 10: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 8: noop()\n", - "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 10: Pass the button ID ('Click Me!') to the action \"click('bid')\".\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: noop()\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: I will click the button by emitting `click(bid)` and `fill(bid, text)` simultaneously.\n", - "Step 6: noop()\n", - "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: - Noop()\n", - "Step 2: noop()\n", - "Step 3: -noop()\n", - "Step 4: noop()\n", - "Step 5: Click('13')\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: noop()\n", - "Step 2: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: Complete action: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: I will use the action 'click('bid') to click the button.\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: I call action Click (bid) on the page.\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: Oops()\n", - "Step 2: noop()\n", - "Step 3: fill(bid, text)\n", - "Step 4: noop()\n", - "Step 5: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: def click_button_on_page():\n", - "Step 2: noop()\n", - "Step 3: click(bid)\n", - "Step 4: Click('13')\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: noop()\n", - "Step 2: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 3: noop()\n", - "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 5: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 6: I will click the button 'Click Me!' by using the action `click(bid)` and emitting a bid of 13.\n", - "Step 7: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: noop()\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: `click(bid)` - No action\n", - "Step 2: - Noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: noop()\n", - "Step 10: I will click the button 'Click Me!' using the action 'click(bid)'.\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: noop()\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: Complete action: click(bid)\n", - "Step 10: noop()\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: noop()\n", - "Step 2: I will perform action 1: click('13') to complete the action.\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: noop()\n", - "Step 2: noop()\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: Click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 8: noop()\n", - "Step 9: Click ('13')\n", - "Step 10: Add action 'fill(bid, text) - Send keyboard input' to perform the click.\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: noop()\n", - "Step 2: Click('click(bid) - Bid')\n", - "Step 3: noop()\n", - "Step 4: noop()\n", - "Step 5: noop()\n", - "Step 6: noop()\n", - "Step 7: noop()\n", - "Step 8: noop()\n", - "Step 9: click(bid) - Click element with BrowserGym ID (the number in brackets)\n", - "Step 10: noop()\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "\n", - "[DEBUG] rollout_func called with 4 prompts (LLM mode, text-only)\n", - "[DEBUG] Processing prompt 1/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 2/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 3/4\n", - "Step 1: click('13')\n", - "[DEBUG] Processing prompt 4/4\n", - "Step 1: click('13')\n", - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" + "text/plain": [ + " ...270m-it/model.safetensors: 4%|3 | 41.9MB / 1.07GB " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "BZj4IG9ZBAix" - }, - "source": [ - "In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials." - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "No files have been modified since last commit. Skipping to prevent empty commit.\n", + "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "244ced1920694dbaae9bf98065b4f01d", - "e3769ae107554c9ba38c1e491b15bf4e", - "6d5b8bff73474faeb1d1b438fb4e8cec", - "9f952f8eb63b42e4b38711737da5461e", - "bd12780895064467b5be14e2ec3df114", - "d1261c1083a74dca877e6eece6395d73", - "999744cacd6a4fb08a1d4977ce2f06fd", - "faa5e0fb4ee244689c0f9eef9902acf7", - "6403bed2cd984ba18f74f416748c64e4", - "38be017369524e2eb22050e7a0a18ec5", - "b0720a4a2df948308011d4d87a288426", - "889ca2520f4d446daf2e6ed16ce11d2e" - ] - }, - "id": "9oOBgEWeBP59", - "outputId": "76bef375-fc6b-4fdd-a296-549a9b109b11" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "999744cacd6a4fb08a1d4977ce2f06fd", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "244ced1920694dbaae9bf98065b4f01d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing Files (0 / 0) : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e3769ae107554c9ba38c1e491b15bf4e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "New Data Upload : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6d5b8bff73474faeb1d1b438fb4e8cec", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9f952f8eb63b42e4b38711737da5461e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bd12780895064467b5be14e2ec3df114", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...ma-270m-it/tokenizer.json: 100%|##########| 33.4MB / 33.4MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d1261c1083a74dca877e6eece6395d73", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...270m-it/model.safetensors: 4%|3 | 41.9MB / 1.07GB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No files have been modified since last commit. Skipping to prevent empty commit.\n", - "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "999744cacd6a4fb08a1d4977ce2f06fd", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing Files (0 / 0) : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "faa5e0fb4ee244689c0f9eef9902acf7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "New Data Upload : | | 0.00B / 0.00B " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6403bed2cd984ba18f74f416748c64e4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...270m-it/training_args.bin: 100%|##########| 7.57kB / 7.57kB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "38be017369524e2eb22050e7a0a18ec5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...a-270m-it/tokenizer.model: 100%|##########| 4.69MB / 4.69MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b0720a4a2df948308011d4d87a288426", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...270m-it/model.safetensors: 3%|3 | 33.5MB / 1.07GB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "889ca2520f4d446daf2e6ed16ce11d2e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " ...ma-270m-it/tokenizer.json: 100%|##########| 33.4MB / 33.4MB " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No files have been modified since last commit. Skipping to prevent empty commit.\n", - "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n" - ] - }, - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it/commit/a17de133c28ca7fddfcb2694c32f2791de5ddbe6', commit_message='End of training', commit_description='', oid='a17de133c28ca7fddfcb2694c32f2791de5ddbe6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/browsergym-grpo-functiongemma-270m-it'), pr_revision=None, pr_num=None)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub()" + "text/plain": [ + "Processing Files (0 / 0) : | | 0.00B / 0.00B " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "talmc8b7nPXJ" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "faa5e0fb4ee244689c0f9eef9902acf7", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "## Load the Fine-Tuned Model and Run Inference\n", - "\n", - "The fine-tuned model is loaded to perform inference and evaluate its behavior on the target task. \n", - "In this case, the model is tested within the BrowserGym environment using OpenEnv, focusing on the *click* task from the MiniWoB++ benchmark, which is included among the available BrowserGym tasks." + "text/plain": [ + "New Data Upload : | | 0.00B / 0.00B " ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "c3879b716f37442a87d51b8414fe8c48" - ] - }, - "id": "iIDiaGVlBP5-", - "outputId": "4dc0e365-e89f-40ba-b391-74c7efdc932d" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6403bed2cd984ba18f74f416748c64e4", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c3879b716f37442a87d51b8414fe8c48", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "model.safetensors: 0%| | 0.00/1.07G [00:00 None: class BrowserGymVLMEnv: def __init__(self): - self.client = BrowserGymEnv(base_url=space_url) + self.client = BrowserGymEnv(base_url=space_url).sync() self.reward = 0.0 self.done = False self._step_count = 0 diff --git a/examples/scripts/openenv/browsergym_llm.py b/examples/scripts/openenv/browsergym_llm.py index ae68f98e578..0d158873e7e 100644 --- a/examples/scripts/openenv/browsergym_llm.py +++ b/examples/scripts/openenv/browsergym_llm.py @@ -17,7 +17,7 @@ # "trl[vllm,peft]", # "trackio", # "kernels", -# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env", +# "openenv-browsergym @ git+https://huggingface.co/spaces/sergiopaniego/browsergym_env", # ] # /// @@ -29,13 +29,13 @@ The environment runs on a Hugging Face Space by default. -Setup (Option A - Install from HF Space, recommended): +Setup: ```sh -uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env +uv pip install git+https://huggingface.co/spaces/sergiopaniego/browsergym_env ``` -Setup (Option B - Clone OpenEnv repo, for development): +Setup (for development, clone the repo): ```sh git clone https://github.com/meta-pytorch/OpenEnv.git @@ -285,7 +285,7 @@ def main() -> None: class BrowserGymLLMEnv: def __init__(self): - self.client = BrowserGymEnv(base_url=space_url) + self.client = BrowserGymEnv(base_url=space_url).sync() self.reward = 0.0 self._done = False self._step_count = 0 diff --git a/trl/__init__.py b/trl/__init__.py index 4947fc16c56..7c1480efee1 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -29,6 +29,7 @@ "chat_template_utils": [ "add_response_schema", "clone_chat_template", + "functiongemma_schema", "get_training_chat_template", "supports_tool_calling", ], @@ -77,6 +78,7 @@ from .chat_template_utils import ( add_response_schema, clone_chat_template, + functiongemma_schema, get_training_chat_template, supports_tool_calling, ) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 435c080267b..a690cddbd6b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -307,6 +307,42 @@ def clone_chat_template( } +functiongemma_chat_template = (_CHAT_TEMPLATES_DIR / "functiongemma.jinja").read_text(encoding="utf-8") + +functiongemma_schema = { + # FunctionGemma tool-call format: call:name{key:value} + # Both "system" and "developer" as first message role are rendered identically by the template. + "x-regex": r"^(?P(?:(?!)[\s\S])*?)(?P(?:[\s\S]+?\s*)+)?$", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "tool_calls": { + "type": "array", + "x-regex-iterator": r"(call:\w+\{[\s\S]*?\})", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "x-regex": r"call:(?P\w+)\{(?P[\s\S]*?)\}$", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "x-regex-key-value": r"(?P\w+):(?P[^<]*)", + "default": {}, + "additionalProperties": {"type": "string"}, + }, + }, + }, + }, + }, + }, + }, +} + cohere_chat_template = (_CHAT_TEMPLATES_DIR / "cohere.jinja").read_text(encoding="utf-8") cohere2_chat_template = (_CHAT_TEMPLATES_DIR / "cohere2.jinja").read_text(encoding="utf-8") @@ -392,7 +428,9 @@ def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT: tokenizer = processing_class.tokenizer else: tokenizer = processing_class - if chat_template == glm4moe_chat_template: + if chat_template == functiongemma_chat_template: + tokenizer.response_schema = functiongemma_schema + elif chat_template == glm4moe_chat_template: tokenizer.response_schema = glm4moe_schema elif chat_template == gptoss_chat_template: tokenizer.response_schema = gptoss_schema diff --git a/trl/chat_templates/functiongemma.jinja b/trl/chat_templates/functiongemma.jinja new file mode 100644 index 00000000000..16294794d96 --- /dev/null +++ b/trl/chat_templates/functiongemma.jinja @@ -0,0 +1,279 @@ +{%- macro format_parameters(properties, required) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- if key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{- key }}:{description:{{ value['description'] }} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + ,enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'OBJECT' -%} + ,properties:{ + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {{- format_parameters(value['properties'], value['required'] | default([])) -}} + {%- elif value is mapping -%} + {{- format_parameters(value, value['required'] | default([])) -}} + {%- endif -%} + } + {%- if value['required'] -%} + ,required:[ + {%- for item in value['required'] | default([]) -%} + {{- item -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + ,items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + {{- req_item -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + ,type:{{ value['type'] | upper }}} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{% macro format_function_declaration(tool_data) -%} +declaration:{{- tool_data['function']['name'] -}} +{description:{{- tool_data['function']['description'] -}} +{%- set params = tool_data['function']['parameters'] -%} +{%- if params -%} + ,parameters:{ + {%- if params['properties'] -%} + properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + {{- item -}} + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:{{- params['type'] | upper -}}} + {%- endif -%} +{%- endif -%} +} +{%- endmacro -%} +{% macro format_argument(argument, escape_keys=True) -%} +{%- if argument is string -%} + {{- '' + argument + '' -}} +{%- elif argument is boolean -%} + {%- if argument -%} + {{- 'true' -}} + {%- else -%} + {{- 'false' -}} + {%- endif -%} +{%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '' + key + '' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} +{%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} +{%- else -%} + {{- argument -}} +{%- endif -%} +{%- endmacro -%} +{{ bos_token }} +{%- set ns = namespace(prev_message_type=None) -%} +{#- Tool Declarations -#} +{%- set loop_messages = messages -%} +{%- if tools or messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%} + {{- 'developer\n' -}} + {%- if messages[0]['role'] == 'system' or messages[0]['role'] == 'developer' -%} + {%- if messages[0]['content'] is string -%} + {{- messages[0]['content'] | trim -}} + {%- elif messages[0]['content'] is sequence -%} + {%- for item in messages[0]['content'] -%} + {%- if item['type'] == 'text' -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + {%- if tools -%} + {%- for tool in tools %} + {{- '' -}} + {{- format_function_declaration(tool) | trim }} + {{- '' -}} + {%- endfor %} + {%- endif -%} + {{- '\n' }} +{%- endif %} +{#- Loop through messages. -#} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'assistant') -%} + {#- Rename "assistant" to "model". -#} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {%- if role != 'tool' -%} + {%- if ns.prev_message_type != 'tool_response' -%} + {{- '' + role + '\n' }} + {%- endif -%} + {%- set ns.prev_message_type = None -%} + {%- if 'content' in message and message['content'] is not none -%} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type in user/assistant message") }} + {%- endif -%} + {%- set ns.prev_message_type = 'content' -%} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls'] is iterable -%} + {#- Tool Calls -#} + {%- for tool_call in message['tool_calls'] -%} + {% set function = tool_call['function'] %} + {{- 'call:' + function['name'] + '{' -}} + {%- if 'arguments' in function -%} + {%- if function['arguments'] is mapping -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {# This handles string-JSON, just in case #} + {{ function['arguments'] }} + {%- endif %} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- if loop.last -%} + {{ '' }} + {%- endif -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + {%- else -%} + {#- Tool Responses -#} + {%- if 'content' in message and message['content'] -%} + {%- if message['content'] is mapping -%} + {%- if 'name' in message['content'] and 'response' in message['content'] -%} + {{ 'response:' + message['content']['name'] | trim + '{' }} + {%- set response_ns = namespace(found_first=false) -%} + {%- for key, value in message['content']['response'] | dictsort -%} + {%- if response_ns.found_first %},{% endif -%} + {%- set response_ns.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif 'name' in message -%} + {{ 'response:' + message['name'] | trim + '{' }} + {%- set response_ns = namespace(found_first=false) -%} + {%- for key, value in message['content'] | dictsort -%} + {%- if response_ns.found_first %},{% endif -%} + {%- set response_ns.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }} + {%- endif -%} + {%- elif message['content'] is string -%} + {%- if 'name' in message -%} + {{ 'response:' + message['name'] | trim + '{value:' + format_argument(message['content'], escape_keys=False) + '}' }} + {%- else -%} + {{ raise_exception("Invalid tool response: 'name' must be provided.") }} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item is mapping -%} + {%- if 'name' in item and 'response' in item -%} + {{ 'response:' + item['name'] | trim + '{' }} + {%- set response_ns = namespace(found_first=false) -%} + {%- for key, value in item['response'] | dictsort -%} + {%- if response_ns.found_first %},{% endif -%} + {%- set response_ns.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif 'name' in message -%} + {{ 'response:' + message['name'] | trim + '{' }} + {%- set response_ns = namespace(found_first=false) -%} + {%- for key, value in item | dictsort -%} + {%- if response_ns.found_first %},{% endif -%} + {%- set response_ns.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{ raise_exception("Invalid tool response mapping: must contain 'name' and 'response' keys, or 'name' must be in the message.") }} + {%- endif -%} + {%- else -%} + {{ raise_exception("Invalid tool response message: multiple responses must all be mappings") }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type in tool message: must be mapping, sequence of mappings, or string.") }} + {%- endif -%} + {%- endif -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + {%- if ns.prev_message_type not in ['tool_call', 'tool_response'] -%} + {{ '\n' }} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' -%} + {{- 'model\n' -}} + {%- endif -%} +{%- endif -%} diff --git a/trl/experimental/openenv/__init__.py b/trl/experimental/openenv/__init__.py deleted file mode 100644 index 4325e17f284..00000000000 --- a/trl/experimental/openenv/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .utils import generate_rollout_completions - - -__all__ = ["generate_rollout_completions"] diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py deleted file mode 100644 index f8f4d573854..00000000000 --- a/trl/experimental/openenv/utils.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from ...data_utils import is_conversational -from ...extras.profiling import profiling_context -from ...import_utils import is_vllm_available - - -if is_vllm_available(): - from vllm import SamplingParams - from vllm.sampling_params import StructuredOutputsParams - - -def _build_base_generation_kwargs( - trainer, - overrides: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Build base generation kwargs common to both colocate and server modes.""" - generation_kwargs: dict[str, Any] = { - "n": 1, - "temperature": trainer.temperature, - "top_k": trainer.top_k, - "min_p": 0.0 if trainer.min_p is None else trainer.min_p, - "max_tokens": trainer.max_completion_length, - } - if trainer.repetition_penalty is not None: - generation_kwargs["repetition_penalty"] = trainer.repetition_penalty - if trainer.top_p is not None: - generation_kwargs["top_p"] = trainer.top_p - - if trainer.args.generation_kwargs is not None: - generation_kwargs.update(trainer.args.generation_kwargs) - - if overrides is not None: - generation_kwargs.update(overrides) - - generation_kwargs = {key: value for key, value in generation_kwargs.items() if value is not None} - - if generation_kwargs.get("n", 1) != 1: - raise ValueError("generate_rollout_completions expects n=1.") - - return generation_kwargs - - -def _build_colocate_sampling_params( - trainer, - overrides: dict[str, Any] | None = None, - *, - logprobs: bool = True, -) -> "SamplingParams": - """Build SamplingParams for colocate mode.""" - generation_kwargs = _build_base_generation_kwargs(trainer, overrides) - - # Add colocate-specific parameters - if trainer.vllm_generation.structured_outputs_regex: - generation_kwargs["structured_outputs"] = StructuredOutputsParams( - regex=trainer.vllm_generation.structured_outputs_regex - ) - if logprobs: - generation_kwargs["logprobs"] = 0 - - return SamplingParams(**generation_kwargs) - - -def _build_server_generation_kwargs( - trainer, - overrides: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Build generation kwargs for server mode.""" - return _build_base_generation_kwargs(trainer, overrides) - - -def generate_rollout_completions( - trainer, - prompts: list[str], - *, - generation_overrides: dict[str, Any] | None = None, - as_chat: bool | None = None, -) -> list[dict[str, Any]]: - """ - Generate completions for custom rollouts when vLLM is running in colocate or server mode. - - Returns one result per prompt, containing prompt and completion token ids along with per-token log probabilities - and the generated text. - """ - - if not prompts: - return [] - - if not trainer.use_vllm: - raise RuntimeError("Custom rollouts require vLLM to call generate_rollout_completions.") - - if trainer.vllm_mode == "server": - return _generate_rollout_completions_server(trainer, prompts, generation_overrides, as_chat) - elif trainer.vllm_mode == "colocate": - return _generate_rollout_completions_colocate(trainer, prompts, generation_overrides, as_chat) - else: - raise ValueError(f"vllm_mode must be 'server' or 'colocate', got '{trainer.vllm_mode}'") - - -def _generate_rollout_completions_server( - trainer, - prompts: list[str], - generation_overrides: dict[str, Any] | None = None, - as_chat: bool | None = None, -) -> list[dict[str, Any]]: - """Generate completions using vLLM server mode.""" - generation_kwargs = _build_server_generation_kwargs(trainer, generation_overrides) - - if as_chat is None: - as_chat = prompts and is_conversational({"prompt": prompts[0]}) - - with profiling_context(trainer, "vLLM.generate_rollout_server"): - if as_chat: - # Prompts are raw message dicts; use .chat() so the vLLM server applies the chat template - output = trainer.vllm_generation.vllm_client.chat( - messages=prompts, - **generation_kwargs, - chat_template_kwargs=trainer.chat_template_kwargs, - tools=trainer.tools or None, - chat_template=trainer.chat_template, - ) - else: - output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) - - # Format results to match colocate output format - results: list[dict[str, Any]] = [] - for i in range(len(prompts)): - results.append( - { - "prompt_ids": output["prompt_ids"][i], - "completion_ids": list(output["completion_ids"][i]), - "logprobs": list(output["logprobs"][i]), - "text": trainer.processing_class.decode(output["completion_ids"][i], skip_special_tokens=True), - } - ) - - return results - - -def _generate_rollout_completions_colocate( - trainer, - prompts: list[str], - generation_overrides: dict[str, Any] | None = None, - as_chat: bool | None = None, -) -> list[dict[str, Any]]: - """Generate completions using vLLM colocate mode.""" - sampling_params = _build_colocate_sampling_params(trainer, generation_overrides) - prompts_for_generation = prompts - original_size = len(prompts) - - if trainer.vllm_tensor_parallel_size > 1: - gathered_prompts = [None for _ in range(trainer.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts, group=trainer.vllm_generation.tp_group) - prompts_for_generation = [prompt for group_prompts in gathered_prompts for prompt in group_prompts] - - if as_chat is None: - as_chat = prompts_for_generation and is_conversational({"prompt": prompts_for_generation[0]}) - - if trainer.args.vllm_enable_sleep_mode: - trainer.vllm_generation.llm.wake_up(tags=["kv_cache"]) - # Work around for https://github.com/vllm-project/vllm/issues/29341 - trainer.vllm_generation.llm.collective_rpc("reload_weights") - - with profiling_context(trainer, "vLLM.generate_rollout"): - if as_chat: - vllm_outputs = trainer.vllm_generation.llm.chat( - prompts_for_generation, sampling_params=sampling_params, use_tqdm=False - ) - else: - vllm_outputs = trainer.vllm_generation.llm.generate( - prompts_for_generation, sampling_params=sampling_params, use_tqdm=False - ) - - results: list[dict[str, Any]] = [] - for request in vllm_outputs: - if not request.outputs: - results.append({"prompt_ids": request.prompt_token_ids, "completion_ids": [], "logprobs": [], "text": ""}) - continue - sequence = request.outputs[0] - logprobs = [next(iter(token_logprob.values())).logprob for token_logprob in sequence.logprobs] - results.append( - { - "prompt_ids": request.prompt_token_ids, - "completion_ids": sequence.token_ids, - "logprobs": logprobs, - "text": sequence.text, - } - ) - - if trainer.vllm_tensor_parallel_size > 1: - local_rank_in_group = torch.distributed.get_rank(group=trainer.vllm_generation.tp_group) - tp_slice = slice(local_rank_in_group * original_size, (local_rank_in_group + 1) * original_size) - results = results[tp_slice] - - if trainer.args.vllm_enable_sleep_mode: - trainer.vllm_generation.llm.sleep(level=2) - - return results From 4d66446189c23d5b25c7faa785cf770b2bfde0d7 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 16:22:54 +0200 Subject: [PATCH 02/10] update --- docs/source/openenv.md | 4 +-- ...rpo_functiongemma_browsergym_openenv.ipynb | 29 ++++--------------- examples/scripts/openenv/browsergym.py | 4 +-- examples/scripts/openenv/browsergym_llm.py | 4 +-- trl/__init__.py | 2 -- trl/chat_template_utils.py | 2 -- 6 files changed, 12 insertions(+), 33 deletions(-) diff --git a/docs/source/openenv.md b/docs/source/openenv.md index 18f38465d24..f51331d8987 100644 --- a/docs/source/openenv.md +++ b/docs/source/openenv.md @@ -26,7 +26,7 @@ pip install "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordl pip install "openenv-openspiel-env @ git+https://huggingface.co/spaces/openenv/openspiel_env" # BrowserGym environment -pip install "openenv-browsergym @ git+https://huggingface.co/spaces/sergiopaniego/browsergym_env" +pip install "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env" ``` This installs the **environment client** (e.g., `EchoEnv`) that communicates with the remote environment server via WebSocket, along with the action/observation models and all required dependencies (including `openenv-core`). @@ -571,7 +571,7 @@ To create your own environment, check out the guide on [Building Your Own Enviro `rollout_func` is an experimental API that predates `environment_factory`. It is no longer recommended and will be removed in a future version. If you have existing scripts that use `rollout_func`, migrate them to `environment_factory`. > [!WARNING] -> `rollout_func` emits a deprecation warning at runtime and may be removed without prior notice. Do not use it for new projects. +> `rollout_func` emits an experimental-feature warning at runtime and may be removed without prior notice. Do not use it for new projects. ## Server concurrency diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index b4bb396280c..26ef2ef3ed4 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -43,7 +43,7 @@ "source": [ "!pip install -qU trl\n", "!pip install -qU jmespath\n", - "!pip install -qU git+https://huggingface.co/spaces/sergiopaniego/browsergym_env\n", + "!pip install -qU git+https://huggingface.co/spaces/openenv/browsergym_env\n", "!pip install -qU trackio\n", "!pip install -qU transformers" ] @@ -85,7 +85,7 @@ "\n", "[BrowserGym](https://meta-pytorch.org/OpenEnv/environments/browsergym/) is a unified framework for web-based agent tasks, offering multiple benchmarks through a Gymnasium-compatible API. It enables training on simple synthetic tasks with [MiniWoB++](https://github.com/Farama-Foundation/miniwob-plusplus) and evaluation on more complex, realistic tasks with [WebArena](https://github.com/web-arena-x/webarena), [VisualWebArena](https://github.com/web-arena-x/visualwebarena), or [WorkArena](https://github.com/ServiceNow/WorkArena). This setup supports iterative training and assessment of web agents without requiring extensive infrastructure.\n", "\n", - "BrowserGym supports both LLM and VLM training by providing visual information, including screenshots and DOM data, which can be utilized depending on the model type. This guide focuses on a simple web-based task called *\"click-test\"*, which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. For this example, the remote environment [sergiopaniego/browsergym_env](https://huggingface.co/spaces/sergiopaniego/browsergym_env) will be used.\n", + "BrowserGym supports both LLM and VLM training by providing visual information, including screenshots and DOM data, which can be utilized depending on the model type. This guide focuses on a simple web-based task called *\"click-test\"*, which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. For this example, the remote environment [openenv/browsergym_env](https://huggingface.co/spaces/openenv/browsergym_env) will be used.\n", "\n", "> Note: Hosted environments on the Hub currently have limited concurrency. For higher reliability or parallel runs, duplicating the Space to your own account is strongly recommended." ] @@ -99,7 +99,7 @@ "outputs": [], "source": [ "from browsergym_env import BrowserGymEnv\n", - "space_url = \"https://sergiopaniego-browsergym-env.hf.space\"" + "space_url = \"https://openenv-browsergym-env.hf.space\"" ] }, { @@ -140,9 +140,7 @@ "id": "CgHd5CFBBP58" }, "outputs": [], - "source": [ - "import re\nfrom browsergym_env import BrowserGymAction\n\nmax_steps = 10\n\n\nclass BrowserGymFunctionGemmaEnv:\n def __init__(self):\n self.client = BrowserGymEnv(base_url=space_url).sync()\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n\n def _ensure_large_max_size(self):\n \"\"\"Patch the WebSocket connection to allow messages up to 100 MB.\n\n Some accessibility trees exceed the default 1 MB limit, which causes\n the frame to be dropped silently and the observation to be empty.\n \"\"\"\n self.client.connect()\n ws = self.client._ws\n if ws is not None and hasattr(ws, \"protocol\"):\n proto = ws.protocol\n # websockets <16: max_size; websockets >=16: max_message_size\n attr = \"max_size\" if hasattr(proto, \"max_size\") else \"max_message_size\"\n if getattr(proto, attr) == 2**20:\n setattr(proto, attr, 100 * 1024 * 1024)\n\n def reset(self, **kwargs) -> str:\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n self._ensure_large_max_size()\n result = self.client.reset()\n self._done = result.done\n return self._format_observation(result.observation)\n\n def click(self, bid: str) -> str:\n \"\"\"Click an element on the page.\n\n Args:\n bid: The BrowserGym ID of the element to click.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"click({repr(bid)})\")\n\n def fill(self, bid: str, text: str) -> str:\n \"\"\"Fill an input field with text.\n\n Args:\n bid: The BrowserGym ID of the input field.\n text: The text to type into the field.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"fill({repr(bid)}, {repr(text)})\")\n\n def send_keys(self, text: str) -> str:\n \"\"\"Send keyboard input to the page.\n\n Args:\n text: The keyboard input to send.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"send_keys({repr(text)})\")\n\n def scroll(self, direction: str) -> str:\n \"\"\"Scroll the page.\n\n Args:\n direction: Direction to scroll, either 'up' or 'down'.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"scroll({repr(direction)})\")\n\n def noop(self) -> str:\n \"\"\"Do nothing and observe the current page state.\n\n Returns:\n The current page observation.\n \"\"\"\n return self._do_action(\"noop()\")\n\n def _do_action(self, action_str: str) -> str:\n if self._done:\n raise ValueError(\"Episode is done.\")\n self._step_count += 1\n result = self.client.step(BrowserGymAction(action_str=action_str))\n step_reward = float(result.reward or 0.0)\n self._done = result.done\n if self._done and step_reward > 0:\n self.reward = 1.0\n elif self._done:\n self.reward = 0.0\n else:\n self.reward = step_reward\n if self._step_count >= max_steps:\n self._done = True\n return self._format_observation(result.observation)\n\n def _format_observation(self, observation) -> str:\n parts = []\n if observation.goal:\n parts.append(f\"Goal: {observation.goal}\")\n if observation.last_action_error and observation.error:\n parts.append(f\"Error: {observation.error}\")\n if observation.axtree_txt:\n axtree = observation.axtree_txt\n axtree = re.sub(r'\\[(\\d+)\\]', r'[bid:\\1]', axtree) # [13] → [bid:13] so model uses the number as bid\n if len(axtree) > 2000:\n axtree = axtree[:2000] + \"...\"\n parts.append(f\"Page structure:\\n{axtree}\")\n parts.append(f\"Step {self._step_count + 1}/{max_steps}: call a tool to act.\")\n return \"\\n\\n\".join(parts) if parts else \"No observation available.\"\n" - ] + "source": "import re\nfrom browsergym_env import BrowserGymAction\n\nmax_steps = 10\n\n\nclass BrowserGymFunctionGemmaEnv:\n def __init__(self):\n self.client = BrowserGymEnv(base_url=space_url).sync()\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n\n def _ensure_large_max_size(self):\n \"\"\"Patch the WebSocket connection to allow messages up to 100 MB.\n\n Some accessibility trees exceed the default 1 MB limit, which causes\n the frame to be dropped silently and the observation to be empty.\n \"\"\"\n self.client.connect()\n ws = self.client._ws\n if ws is not None and hasattr(ws, \"protocol\"):\n proto = ws.protocol\n # websockets <16: max_size; websockets >=16: max_message_size\n attr = \"max_size\" if hasattr(proto, \"max_size\") else \"max_message_size\"\n if getattr(proto, attr) == 2**20:\n setattr(proto, attr, 100 * 1024 * 1024)\n\n def reset(self, **kwargs) -> str:\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n self._ensure_large_max_size()\n result = self.client.reset()\n self._done = result.done\n return self._format_observation(result.observation)\n\n def click(self, bid: str) -> str:\n \"\"\"Click an element on the page.\n\n Args:\n bid: The BrowserGym ID of the element to click.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"click({repr(bid)})\")\n\n def fill(self, bid: str, text: str) -> str:\n \"\"\"Fill an input field with text.\n\n Args:\n bid: The BrowserGym ID of the input field.\n text: The text to type into the field.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"fill({repr(bid)}, {repr(text)})\")\n\n def send_keys(self, text: str) -> str:\n \"\"\"Send keyboard input to the page.\n\n Args:\n text: The keyboard input to send.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"send_keys({repr(text)})\")\n\n def scroll(self, direction: str) -> str:\n \"\"\"Scroll the page.\n\n Args:\n direction: Direction to scroll, either 'up' or 'down'.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"scroll({repr(direction)})\")\n\n def noop(self) -> str:\n \"\"\"Do nothing and observe the current page state.\n\n Returns:\n The current page observation.\n \"\"\"\n return self._do_action(\"noop()\")\n\n def _do_action(self, action_str: str) -> str:\n if self._done:\n raise ValueError(\"Episode is done.\")\n self._step_count += 1\n result = self.client.step(BrowserGymAction(action_str=action_str))\n step_reward = float(result.reward or 0.0)\n self._done = result.done\n if self._done and step_reward > 0:\n self.reward = 1.0\n elif self._done:\n self.reward = 0.0\n else:\n self.reward = step_reward\n if self._step_count >= max_steps:\n self._done = True\n return self._format_observation(result.observation)\n\n def _format_observation(self, observation) -> str:\n parts = []\n if observation.goal:\n parts.append(f\"Goal: {observation.goal}\")\n if observation.last_action_error and observation.error:\n parts.append(f\"Error: {observation.error}\")\n if observation.axtree_txt:\n axtree = observation.axtree_txt\n axtree = re.sub(r'\\[(\\d+)\\]', r'[bid:\\1]', axtree)\n if len(axtree) > 2000:\n axtree = axtree[:2000] + \"...\"\n parts.append(f\"Page structure:\\n{axtree}\")\n parts.append(f\"Step {self._step_count + 1}/{max_steps}: call a tool to act.\")\n return \"\\n\\n\".join(parts) if parts else \"No observation available.\"\n" }, { "cell_type": "markdown", @@ -507,21 +505,7 @@ "metadata": { "id": "talmc8b7nPXJ" }, - "source": [ - "## Load the Fine-Tuned Model and Run Inference\n", - "\n", - "The fine-tuned model is loaded to perform inference and evaluate its behavior on the target task. \n", - "In this case, the model is tested within the BrowserGym environment using OpenEnv, focusing on the *click* task from the MiniWoB++ benchmark, which is included among the available BrowserGym tasks." - ] - }, - { - "cell_type": "markdown", - "execution_count": null, - "metadata": { - "id": "doAEIf5IBP5-" - }, - "outputs": [], - "source": "With the fine-tuned model loaded, testing can be conducted on the BrowserGym environment. The `test_browsergym` function runs the model for one episode, parsing each FunctionGemma tool call and executing it in the environment step by step." + "source": "## Load the Fine-Tuned Model and Run Inference\n\nThe fine-tuned model is loaded to perform inference and evaluate its behavior on the target task. \nIn this case, the model is tested within the BrowserGym environment using OpenEnv, focusing on the *click* task from the MiniWoB++ benchmark, which is included among the available BrowserGym tasks.\n\nWith the fine-tuned model loaded, testing can be conducted on the BrowserGym environment. The `test_browsergym` function runs the model for one episode, parsing each FunctionGemma tool call and executing it in the environment step by step." }, { "cell_type": "code", @@ -534,7 +518,6 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": { "id": "Z77wlVb6BP5-", "outputId": "ed4ad094-1529-4cc7-8274-2782784efe2d" @@ -562,4 +545,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index c95e028fd02..956d0f55944 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -17,7 +17,7 @@ # "trl[vllm,peft]", # "trackio", # "kernels", -# "openenv-browsergym @ git+https://huggingface.co/spaces/sergiopaniego/browsergym_env", +# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env", # ] # /// @@ -30,7 +30,7 @@ Setup: ```sh -pip install "openenv-browsergym @ git+https://huggingface.co/spaces/sergiopaniego/browsergym_env" +pip install "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env" ``` Usage: diff --git a/examples/scripts/openenv/browsergym_llm.py b/examples/scripts/openenv/browsergym_llm.py index 0d158873e7e..60f0e3a7b49 100644 --- a/examples/scripts/openenv/browsergym_llm.py +++ b/examples/scripts/openenv/browsergym_llm.py @@ -17,7 +17,7 @@ # "trl[vllm,peft]", # "trackio", # "kernels", -# "openenv-browsergym @ git+https://huggingface.co/spaces/sergiopaniego/browsergym_env", +# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env", # ] # /// @@ -32,7 +32,7 @@ Setup: ```sh -uv pip install git+https://huggingface.co/spaces/sergiopaniego/browsergym_env +uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env ``` Setup (for development, clone the repo): diff --git a/trl/__init__.py b/trl/__init__.py index 7c1480efee1..4947fc16c56 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -29,7 +29,6 @@ "chat_template_utils": [ "add_response_schema", "clone_chat_template", - "functiongemma_schema", "get_training_chat_template", "supports_tool_calling", ], @@ -78,7 +77,6 @@ from .chat_template_utils import ( add_response_schema, clone_chat_template, - functiongemma_schema, get_training_chat_template, supports_tool_calling, ) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index a690cddbd6b..1a9fa977bcb 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -310,8 +310,6 @@ def clone_chat_template( functiongemma_chat_template = (_CHAT_TEMPLATES_DIR / "functiongemma.jinja").read_text(encoding="utf-8") functiongemma_schema = { - # FunctionGemma tool-call format: call:name{key:value} - # Both "system" and "developer" as first message role are rendered identically by the template. "x-regex": r"^(?P(?:(?!)[\s\S])*?)(?P(?:[\s\S]+?\s*)+)?$", "type": "object", "properties": { From a6d5cdf1127ade58a96894dd568cc5250c796bc1 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 16:24:45 +0200 Subject: [PATCH 03/10] update notebook --- .../grpo_functiongemma_browsergym_openenv.ipynb | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index 26ef2ef3ed4..1a39cc18630 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -514,7 +514,9 @@ }, "source": [ "from transformers import parse_response\n\n\ndef test_browsergym(model, tokenizer):\n \"\"\"Run the fine-tuned model on BrowserGym for one episode.\"\"\"\n env = BrowserGymFunctionGemmaEnv()\n tools = [env.click, env.fill, env.send_keys, env.scroll, env.noop]\n\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": env.reset()},\n ]\n\n for step in range(max_steps):\n if env._done:\n break\n\n prompt = tokenizer.apply_chat_template(\n messages,\n tools=tools,\n add_generation_prompt=True,\n tokenize=False,\n )\n inputs = tokenizer([prompt], return_tensors=\"pt\").to(model.device)\n output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)\n new_ids = output_ids[0][inputs.input_ids.shape[1]:]\n\n response = parse_response(tokenizer, new_ids)\n tool_calls = response.get(\"tool_calls\") or []\n\n if not tool_calls:\n print(f\" Step {step + 1}: no tool call generated\")\n break\n\n tc = tool_calls[0]\n func_name = tc[\"function\"][\"name\"]\n arguments = tc[\"function\"][\"arguments\"] or {}\n\n print(f\" Step {step + 1}: {func_name}({\", \".join(f'{k}={v!r}' for k, v in arguments.items())})\")\n\n tool_fn = getattr(env, func_name, None)\n if tool_fn is None:\n print(f\" Unknown function: {func_name}\")\n break\n\n try:\n obs = tool_fn(**arguments)\n messages.append(response)\n messages.append({\"role\": \"tool\", \"name\": func_name, \"content\": obs})\n except Exception as e:\n print(f\" Error: {e}\")\n break\n\n print(f\"Reward: {env.reward}\")\n" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -522,7 +524,6 @@ "id": "Z77wlVb6BP5-", "outputId": "ed4ad094-1529-4cc7-8274-2782784efe2d" }, - "outputs": [], "source": "Finally, `test_browsergym` is called to evaluate the fine-tuned model on one episode of the BrowserGym click task." }, { @@ -530,7 +531,9 @@ "metadata": { "id": "wHydP-ZVCcYK" }, - "source": "test_browsergym(fine_tuned_model, tokenizer)" + "source": "test_browsergym(fine_tuned_model, tokenizer)", + "outputs": [], + "execution_count": null } ], "metadata": { From 704cd0171dbf4b7baffb51b112c9fbd0afa7f056 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 16:32:11 +0200 Subject: [PATCH 04/10] updated notebook --- .../grpo_functiongemma_browsergym_openenv.ipynb | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index 1a39cc18630..26157c5c0d6 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -229,15 +229,6 @@ "model_name = \"google/functiongemma-270m-it\"" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer\nfrom trl import add_response_schema\n\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nadd_response_schema(tokenizer) # Sets tokenizer.response_schema for FunctionGemma's tool-call format\n" - ] - }, { "cell_type": "code", "execution_count": null, @@ -246,7 +237,7 @@ "outputId": "61740a89-228c-4b3c-8e59-b4a3eb972c03" }, "outputs": [], - "source": "trainer = GRPOTrainer(\n model=model_name,\n processing_class=tokenizer,\n reward_funcs=[reward_completion],\n train_dataset=dataset,\n args=grpo_config,\n environment_factory=BrowserGymFunctionGemmaEnv,\n)" + "source": "trainer = GRPOTrainer(\n model=model_name,\n reward_funcs=[reward_completion],\n train_dataset=dataset,\n args=grpo_config,\n environment_factory=BrowserGymFunctionGemmaEnv,\n)" }, { "cell_type": "code", @@ -496,9 +487,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM\n\nfine_tuned_model = AutoModelForCausalLM.from_pretrained(output_dir).to(\"cuda\")\n" - ] + "source": "from transformers import AutoModelForCausalLM, AutoTokenizer\nfrom trl import add_response_schema\n\nfine_tuned_model = AutoModelForCausalLM.from_pretrained(output_dir).to(\"cuda\")\ntokenizer = AutoTokenizer.from_pretrained(output_dir)\nadd_response_schema(tokenizer) # response_schema is not persisted to disk; re-apply before inference\n" }, { "cell_type": "markdown", From fb20b4d57d547d781d1c2b208e8b304b09a83bfc Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 16:34:46 +0200 Subject: [PATCH 05/10] udpated notebook --- .../notebooks/grpo_functiongemma_browsergym_openenv.ipynb | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index 26157c5c0d6..63fd7995715 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -192,6 +192,11 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Fine-tune using TRL and the GRPOTrainer\n\nThe next step is to define the GRPOConfig, which sets all key training parameters.\n\nThis configuration determines how the model interacts with the environment, handles memory and computation, and records training metrics and logs for monitoring the fine-tuning process." + }, { "cell_type": "code", "execution_count": null, From 030cf2e046af14cc6cf7ed5fc188fafb1a386785 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 16:41:43 +0200 Subject: [PATCH 06/10] nit --- docs/source/example_overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index ee458ac6c21..42d046a4187 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -42,7 +42,7 @@ These notebooks are easier to run and are designed for quick experimentation wit ### OpenEnv Notebooks -These notebooks demonstrate how to train models with [OpenEnv](openenv) environments using [`GRPOTrainer`]'s `environment_factory`. The BrowserGym notebook uses the lower-level `rollout_func` API instead. See the [OpenEnv Integration](openenv) guide for more details. +These notebooks demonstrate how to train models with [OpenEnv](openenv) environments using [`GRPOTrainer`]'s `environment_factory`. See the [OpenEnv Integration](openenv) guide for more details. | Notebook | Description | Open in Colab | |----------|-------------|---------------| From f5e219e27a73d08462040fc531c61610c466d2f2 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 16:43:49 +0200 Subject: [PATCH 07/10] cursor feedbackg --- .../notebooks/grpo_functiongemma_browsergym_openenv.ipynb | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index 63fd7995715..67137f5b8a6 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -506,9 +506,7 @@ "metadata": { "id": "9QvGD8f8CQx1" }, - "source": [ - "from transformers import parse_response\n\n\ndef test_browsergym(model, tokenizer):\n \"\"\"Run the fine-tuned model on BrowserGym for one episode.\"\"\"\n env = BrowserGymFunctionGemmaEnv()\n tools = [env.click, env.fill, env.send_keys, env.scroll, env.noop]\n\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": env.reset()},\n ]\n\n for step in range(max_steps):\n if env._done:\n break\n\n prompt = tokenizer.apply_chat_template(\n messages,\n tools=tools,\n add_generation_prompt=True,\n tokenize=False,\n )\n inputs = tokenizer([prompt], return_tensors=\"pt\").to(model.device)\n output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)\n new_ids = output_ids[0][inputs.input_ids.shape[1]:]\n\n response = parse_response(tokenizer, new_ids)\n tool_calls = response.get(\"tool_calls\") or []\n\n if not tool_calls:\n print(f\" Step {step + 1}: no tool call generated\")\n break\n\n tc = tool_calls[0]\n func_name = tc[\"function\"][\"name\"]\n arguments = tc[\"function\"][\"arguments\"] or {}\n\n print(f\" Step {step + 1}: {func_name}({\", \".join(f'{k}={v!r}' for k, v in arguments.items())})\")\n\n tool_fn = getattr(env, func_name, None)\n if tool_fn is None:\n print(f\" Unknown function: {func_name}\")\n break\n\n try:\n obs = tool_fn(**arguments)\n messages.append(response)\n messages.append({\"role\": \"tool\", \"name\": func_name, \"content\": obs})\n except Exception as e:\n print(f\" Error: {e}\")\n break\n\n print(f\"Reward: {env.reward}\")\n" - ], + "source": "def test_browsergym(model, tokenizer):\n \"\"\"Run the fine-tuned model on BrowserGym for one episode.\"\"\"\n env = BrowserGymFunctionGemmaEnv()\n tools = [env.click, env.fill, env.send_keys, env.scroll, env.noop]\n\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": env.reset()},\n ]\n\n for step in range(max_steps):\n if env._done:\n break\n\n prompt = tokenizer.apply_chat_template(\n messages,\n tools=tools,\n add_generation_prompt=True,\n tokenize=False,\n )\n inputs = tokenizer([prompt], return_tensors=\"pt\").to(model.device)\n output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)\n new_ids = output_ids[0][inputs.input_ids.shape[1]:]\n\n response = tokenizer.parse_response(new_ids)\n tool_calls = response.get(\"tool_calls\") or []\n\n if not tool_calls:\n print(f\" Step {step + 1}: no tool call generated\")\n break\n\n tc = tool_calls[0]\n func_name = tc[\"function\"][\"name\"]\n arguments = tc[\"function\"][\"arguments\"] or {}\n\n print(f\" Step {step + 1}: {func_name}({\", \".join(f'{k}={v!r}' for k, v in arguments.items())})\")\n\n tool_fn = getattr(env, func_name, None)\n if tool_fn is None:\n print(f\" Unknown function: {func_name}\")\n break\n\n try:\n obs = tool_fn(**arguments)\n messages.append(response)\n messages.append({\"role\": \"tool\", \"name\": func_name, \"content\": obs})\n except Exception as e:\n print(f\" Error: {e}\")\n break\n\n print(f\"Reward: {env.reward}\")\n", "outputs": [], "execution_count": null }, From 3ace55632f43b986dfc37ac5899f88b4a9361ae4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 17:35:35 +0200 Subject: [PATCH 08/10] added changes in PR#5568 --- examples/scripts/openenv/browsergym.py | 50 ++++++++++++++++++---- examples/scripts/openenv/browsergym_llm.py | 19 +++++--- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index 956d0f55944..3590c82936f 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -67,11 +67,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--space-url", default="https://openenv-browsergym-env.hf.space") parser.add_argument("--dataset-prompt", default="Complete the web task successfully.") parser.add_argument("--dataset-size", type=int, default=1000) - parser.add_argument("--max-steps", type=int, default=10) + parser.add_argument("--max-steps", type=int, default=10, help="Max steps per episode.") parser.add_argument("--max-completion-length", type=int, default=1024) parser.add_argument("--image-size", type=int, default=512, help="Resize screenshots to this size. 0 to disable.") parser.add_argument("--num-generations", type=int, default=4) - parser.add_argument("--gradient-accumulation-steps", type=int, default=32) + parser.add_argument("--gradient-accumulation-steps", type=int, default=1) parser.add_argument("--learning-rate", type=float, default=5e-6) parser.add_argument("--num-epochs", type=int, default=1) parser.add_argument("--logging-steps", type=int, default=1) @@ -108,6 +108,11 @@ def reward_completion(completions, environments, **kwargs) -> list[float]: return [env.reward for env in environments] +def reward_efficiency(completions, environments, **kwargs) -> list[float]: + """Penalize extra tool calls beyond the first: -0.1 per extra call.""" + return [-0.1 * max(0, env._step_count - 1) for env in environments] + + def main() -> None: args = parse_args() @@ -134,10 +139,39 @@ def __init__(self): self.done = False self._step_count = 0 + def _ensure_large_max_size(self): + """Raise WebSocket max message size for large observations (screenshots + axtree). + + openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library + defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent. + """ + import websockets + + self.client.connect() + ws = self.client._ws + if ws is not None and ws.protocol is not None: + proto = ws.protocol + # websockets renamed max_size to max_message_size in version 16 + if int(websockets.__version__.split(".")[0]) >= 16: + if proto.max_message_size == 2**20: + proto.max_message_size = 100 * 1024 * 1024 + else: + if proto.max_size == 2**20: + proto.max_size = 100 * 1024 * 1024 + + @staticmethod + def _normalize_bid(bid) -> str: + """Normalize bid to a plain string (handles int, '[13]', or '13' formats).""" + bid = str(bid).strip() + if bid.startswith("[") and bid.endswith("]"): + bid = bid[1:-1].strip() + return bid + def reset(self, **kwargs) -> str | None: self.reward = 0.0 self.done = False self._step_count = 0 + self._ensure_large_max_size() result = self.client.reset() self.done = result.done return self._format_observation(result.observation) @@ -151,7 +185,7 @@ def click(self, bid: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"click('{bid}')") + return self._do_action(f"click({self._normalize_bid(bid)!r})") def fill(self, bid: str, text: str) -> list: """Fill an input field with text. @@ -163,7 +197,7 @@ def fill(self, bid: str, text: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"fill('{bid}', '{text}')") + return self._do_action(f"fill({self._normalize_bid(bid)!r}, {text!r})") def send_keys(self, text: str) -> list: """Send keyboard input to the page. @@ -174,7 +208,7 @@ def send_keys(self, text: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"send_keys('{text}')") + return self._do_action(f"send_keys({text!r})") def scroll(self, direction: str) -> list: """Scroll the page. @@ -185,7 +219,7 @@ def scroll(self, direction: str) -> list: Returns: The updated page observation with screenshot. """ - return self._do_action(f"scroll('{direction}')") + return self._do_action(f"scroll({direction!r})") def noop(self) -> list: """Do nothing and observe the current page state. @@ -197,7 +231,7 @@ def noop(self) -> list: def _do_action(self, action_str: str) -> list: if self.done: - raise ValueError("Episode is done.") + return "Episode is done. No further actions needed." self._step_count += 1 result = self.client.step(BrowserGymAction(action_str=action_str)) @@ -263,7 +297,7 @@ def _format_observation_multimodal(self, observation) -> list: trainer = GRPOTrainer( model=args.model_id, - reward_funcs=reward_completion, + reward_funcs=[reward_completion, reward_efficiency], train_dataset=dataset, args=GRPOConfig( use_vllm=args.use_vllm, diff --git a/examples/scripts/openenv/browsergym_llm.py b/examples/scripts/openenv/browsergym_llm.py index 60f0e3a7b49..6d4d9dafd0c 100644 --- a/examples/scripts/openenv/browsergym_llm.py +++ b/examples/scripts/openenv/browsergym_llm.py @@ -152,7 +152,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--gradient-accumulation-steps", type=int, - default=32, + default=1, help="Gradient accumulation steps for GRPO training.", ) parser.add_argument( @@ -296,14 +296,19 @@ def _ensure_large_max_size(self): openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent. """ + import websockets + self.client.connect() ws = self.client._ws - if ws is not None and hasattr(ws, "protocol"): + if ws is not None and ws.protocol is not None: proto = ws.protocol - # websockets <16: max_size; websockets >=16: max_message_size - attr = "max_size" if hasattr(proto, "max_size") else "max_message_size" - if getattr(proto, attr) == 2**20: - setattr(proto, attr, 100 * 1024 * 1024) + # websockets renamed max_size to max_message_size in version 16 + if int(websockets.__version__.split(".")[0]) >= 16: + if proto.max_message_size == 2**20: + proto.max_message_size = 100 * 1024 * 1024 + else: + if proto.max_size == 2**20: + proto.max_size = 100 * 1024 * 1024 def reset(self, **kwargs) -> str: self.reward = 0.0 @@ -369,7 +374,7 @@ def noop(self) -> str: def _do_action(self, action_str: str) -> str: if self._done: - raise ValueError("Episode is done.") + return "Episode is done. No further actions needed." self._step_count += 1 result = self.client.step(BrowserGymAction(action_str=action_str)) From 187407bc790f9536640f6357024a083520581b76 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 17:42:24 +0200 Subject: [PATCH 09/10] update based on cursor --- examples/scripts/openenv/browsergym.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index 3590c82936f..138330bdcfb 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -231,7 +231,7 @@ def noop(self) -> list: def _do_action(self, action_str: str) -> list: if self.done: - return "Episode is done. No further actions needed." + return [{"type": "text", "text": "Episode is done. No further actions needed."}] self._step_count += 1 result = self.client.step(BrowserGymAction(action_str=action_str)) From 789cc101fe7c12c5bdad1a62b2df276584c4fb80 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 29 May 2026 18:07:05 +0200 Subject: [PATCH 10/10] cursor review --- examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index 67137f5b8a6..a573022e1b0 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -140,7 +140,7 @@ "id": "CgHd5CFBBP58" }, "outputs": [], - "source": "import re\nfrom browsergym_env import BrowserGymAction\n\nmax_steps = 10\n\n\nclass BrowserGymFunctionGemmaEnv:\n def __init__(self):\n self.client = BrowserGymEnv(base_url=space_url).sync()\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n\n def _ensure_large_max_size(self):\n \"\"\"Patch the WebSocket connection to allow messages up to 100 MB.\n\n Some accessibility trees exceed the default 1 MB limit, which causes\n the frame to be dropped silently and the observation to be empty.\n \"\"\"\n self.client.connect()\n ws = self.client._ws\n if ws is not None and hasattr(ws, \"protocol\"):\n proto = ws.protocol\n # websockets <16: max_size; websockets >=16: max_message_size\n attr = \"max_size\" if hasattr(proto, \"max_size\") else \"max_message_size\"\n if getattr(proto, attr) == 2**20:\n setattr(proto, attr, 100 * 1024 * 1024)\n\n def reset(self, **kwargs) -> str:\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n self._ensure_large_max_size()\n result = self.client.reset()\n self._done = result.done\n return self._format_observation(result.observation)\n\n def click(self, bid: str) -> str:\n \"\"\"Click an element on the page.\n\n Args:\n bid: The BrowserGym ID of the element to click.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"click({repr(bid)})\")\n\n def fill(self, bid: str, text: str) -> str:\n \"\"\"Fill an input field with text.\n\n Args:\n bid: The BrowserGym ID of the input field.\n text: The text to type into the field.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"fill({repr(bid)}, {repr(text)})\")\n\n def send_keys(self, text: str) -> str:\n \"\"\"Send keyboard input to the page.\n\n Args:\n text: The keyboard input to send.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"send_keys({repr(text)})\")\n\n def scroll(self, direction: str) -> str:\n \"\"\"Scroll the page.\n\n Args:\n direction: Direction to scroll, either 'up' or 'down'.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"scroll({repr(direction)})\")\n\n def noop(self) -> str:\n \"\"\"Do nothing and observe the current page state.\n\n Returns:\n The current page observation.\n \"\"\"\n return self._do_action(\"noop()\")\n\n def _do_action(self, action_str: str) -> str:\n if self._done:\n raise ValueError(\"Episode is done.\")\n self._step_count += 1\n result = self.client.step(BrowserGymAction(action_str=action_str))\n step_reward = float(result.reward or 0.0)\n self._done = result.done\n if self._done and step_reward > 0:\n self.reward = 1.0\n elif self._done:\n self.reward = 0.0\n else:\n self.reward = step_reward\n if self._step_count >= max_steps:\n self._done = True\n return self._format_observation(result.observation)\n\n def _format_observation(self, observation) -> str:\n parts = []\n if observation.goal:\n parts.append(f\"Goal: {observation.goal}\")\n if observation.last_action_error and observation.error:\n parts.append(f\"Error: {observation.error}\")\n if observation.axtree_txt:\n axtree = observation.axtree_txt\n axtree = re.sub(r'\\[(\\d+)\\]', r'[bid:\\1]', axtree)\n if len(axtree) > 2000:\n axtree = axtree[:2000] + \"...\"\n parts.append(f\"Page structure:\\n{axtree}\")\n parts.append(f\"Step {self._step_count + 1}/{max_steps}: call a tool to act.\")\n return \"\\n\\n\".join(parts) if parts else \"No observation available.\"\n" + "source": "import re\nfrom browsergym_env import BrowserGymAction\n\nmax_steps = 10\n\n\nclass BrowserGymFunctionGemmaEnv:\n def __init__(self):\n self.client = BrowserGymEnv(base_url=space_url).sync()\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n\n def _ensure_large_max_size(self):\n \"\"\"Raise WebSocket max message size for large observations (accessibility trees).\n\n openenv-core<=0.2.1 does not pass max_size to ws_connect, so the websockets library\n defaults to 1MB. We force a connection and patch it to 100MB before any messages are sent.\n \"\"\"\n import websockets\n\n self.client.connect()\n ws = self.client._ws\n if ws is not None and ws.protocol is not None:\n proto = ws.protocol\n # websockets renamed max_size to max_message_size in version 16\n if int(websockets.__version__.split(\".\")[0]) >= 16:\n if proto.max_message_size == 2**20:\n proto.max_message_size = 100 * 1024 * 1024\n else:\n if proto.max_size == 2**20:\n proto.max_size = 100 * 1024 * 1024\n\n def reset(self, **kwargs) -> str:\n self.reward = 0.0\n self._done = False\n self._step_count = 0\n self._ensure_large_max_size()\n result = self.client.reset()\n self._done = result.done\n return self._format_observation(result.observation)\n\n def click(self, bid: str) -> str:\n \"\"\"Click an element on the page.\n\n Args:\n bid: The BrowserGym ID of the element to click.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"click({repr(bid)})\")\n\n def fill(self, bid: str, text: str) -> str:\n \"\"\"Fill an input field with text.\n\n Args:\n bid: The BrowserGym ID of the input field.\n text: The text to type into the field.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"fill({repr(bid)}, {repr(text)})\")\n\n def send_keys(self, text: str) -> str:\n \"\"\"Send keyboard input to the page.\n\n Args:\n text: The keyboard input to send.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"send_keys({repr(text)})\")\n\n def scroll(self, direction: str) -> str:\n \"\"\"Scroll the page.\n\n Args:\n direction: Direction to scroll, either 'up' or 'down'.\n\n Returns:\n The updated page observation.\n \"\"\"\n return self._do_action(f\"scroll({repr(direction)})\")\n\n def noop(self) -> str:\n \"\"\"Do nothing and observe the current page state.\n\n Returns:\n The current page observation.\n \"\"\"\n return self._do_action(\"noop()\")\n\n def _do_action(self, action_str: str) -> str:\n if self._done:\n return \"Episode is done. No further actions needed.\"\n self._step_count += 1\n result = self.client.step(BrowserGymAction(action_str=action_str))\n step_reward = float(result.reward or 0.0)\n self._done = result.done\n if self._done and step_reward > 0:\n self.reward = 1.0\n elif self._done:\n self.reward = 0.0\n else:\n self.reward = step_reward\n if self._step_count >= max_steps:\n self._done = True\n return self._format_observation(result.observation)\n\n def _format_observation(self, observation) -> str:\n parts = []\n if observation.goal:\n parts.append(f\"Goal: {observation.goal}\")\n if observation.last_action_error and observation.error:\n parts.append(f\"Error: {observation.error}\")\n if observation.axtree_txt:\n axtree = observation.axtree_txt\n axtree = re.sub(r'\\[(\\d+)\\]', r'[bid:\\1]', axtree)\n if len(axtree) > 2000:\n axtree = axtree[:2000] + \"...\"\n parts.append(f\"Page structure:\\n{axtree}\")\n parts.append(f\"Step {self._step_count + 1}/{max_steps}: call a tool to act.\")\n return \"\\n\\n\".join(parts) if parts else \"No observation available.\"\n" }, { "cell_type": "markdown",