diff --git a/main.py b/main.py index f04aff88..ec2f5d36 100644 --- a/main.py +++ b/main.py @@ -156,6 +156,22 @@ def _append_task_output(result: dict, task_index: int, output_file: str) -> None "token_usage": result.get("token_usage", {}), "agent_messages": result.get("agent_messages", []), } + # TroVE telemetry: passthrough when present so scripts/analyze_trove_run.py + # (and any other post-hoc analyzer) can read per-task tool-use stats and the + # final library state from the JSONL. Keys are absent on non-TroVE runs. + for key in ( + "won_mode", + "import_eligible", + "import_was_winner", + "tool_calls", + "tool_call_count", + "tools_called", + "actually_called", + "trove_stopped_reason", + "library_snapshot", + ): + if key in result: + record[key] = result[key] Path(output_file).parent.mkdir(parents=True, exist_ok=True) with open(output_file, "a", encoding="utf-8") as f: f.write(json.dumps(record, default=str) + "\n") @@ -808,6 +824,23 @@ def main() -> None: help="[TroVE] Trim low-frequency toolbox functions every N tasks. " "Paper default: 500. Set to 9999 to disable for small datasets. (default: 500)", ) + parser.add_argument( + "--trove-selection", + choices=["reward", "consistency"], + default="reward", + help="[TroVE] Candidate selection strategy. 'reward' (default) uses " + "the per-task reward function with AST tie-breaking. " + "'consistency' uses the original TroVE majority-vote algorithm. " + "(default: reward)", + ) + parser.add_argument( + "--trove-task-family", + choices=["default", "pbebench"], + default="default", + help="[TroVE] Task family for prompt selection and parser strictness. " + "'pbebench' uses PBEBench-shaped few-shots and strict **Solution** " + "parsing (no fallback to any python block). (default: default)", + ) # ReGAL-specific flags parser.add_argument( "--regal-train-file", @@ -1007,8 +1040,13 @@ def main() -> None: debug_dir=args.debug_dir, k=args.trove_k, trim_every=args.trove_trim_every, + task_family=args.trove_task_family, + selection=args.trove_selection, + ) + logger.info( + "Framework: TroVE (k=%d, trim_every=%d, task_family=%s, selection=%s)", + args.trove_k, args.trove_trim_every, args.trove_task_family, args.trove_selection, ) - logger.info("Framework: TroVE (k=%d, trim_every=%d)", args.trove_k, args.trove_trim_every) elif args.framework == "regal": from pathlib import Path as _Path controller = ReGALController( diff --git a/notebooks/run_trove_pbebench.ipynb b/notebooks/run_trove_pbebench.ipynb new file mode 100644 index 00000000..83c585f9 --- /dev/null +++ b/notebooks/run_trove_pbebench.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TroVE × PBEBench-Lite — RunPod runner (`gpt-oss-20b`)\n", + "\n", + "End-to-end notebook to:\n", + "\n", + "1. Check GPU and install dependencies\n", + "2. Launch a local vLLM server (with native tool-calling flags)\n", + "3. Wait for it to be healthy\n", + "4. Run TroVE on PBEBench-Lite with reward-based selection\n", + "5. Analyze the JSONL output\n", + "\n", + "## Pod sizing\n", + "\n", + "`openai/gpt-oss-20b` runs comfortably on a single **A100 80 GB** or **H100** with `--tensor-parallel-size 1`. A100 40 GB will OOM at default settings.\n", + "\n", + "## Before you start\n", + "\n", + "- Run this notebook from a Jupyter kernel **inside the pod**, with the repo at `/workspace/pbe/symbolic-library-agent` (or wherever you cloned it). Adjust `REPO_ROOT` in the next cell if needed.\n", + "- Each cell is idempotent — safe to re-run.\n", + "- Cleanup at the bottom kills the vLLM process; if you re-run cells out of order, you may end up with a stale server — use the cleanup cell." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Configuration" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "from pathlib import Path\n", + "import os\n", + "\n", + "MODEL = \"openai/gpt-oss-20b\"\n", + "TENSOR_PARALLEL = 1\n", + "\n", + "PORT = 8000\n", + "BASE_URL = f\"http://localhost:{PORT}/v1\"\n", + "\n", + "# Repo root — change if your clone lives elsewhere on the pod.\n", + "REPO_ROOT = Path(os.environ.get(\"REPO_ROOT\", \"/workspace/pbe/symbolic-library-agent\"))\n", + "if not REPO_ROOT.exists():\n", + " REPO_ROOT = Path.cwd().parent if Path.cwd().name == \"notebooks\" else Path.cwd()\n", + "assert (REPO_ROOT / \"main.py\").exists(), f\"Could not find main.py under {REPO_ROOT}\"\n", + "os.chdir(REPO_ROOT)\n", + "\n", + "# Tasks file. Two PBEBench-Lite options ship with the repo:\n", + "# - lite_pilot_tasks.jsonl : 50-task pilot split (smoke-run default)\n", + "# - lite_tasks_full_og.jsonl : full Lite split (1008 tasks)\n", + "TASKS_FILE = REPO_ROOT / \"data/pbebench/lite_pilot_tasks.jsonl\"\n", + "MAX_PROGRAMS = 5 # PBEBench convention for the lite split\n", + "\n", + "OUT_DIR = REPO_ROOT / \"outputs\"\n", + "OUT_FILE = OUT_DIR / \"trove_pbebench_lite_smoke.jsonl\"\n", + "DEBUG_DIR = REPO_ROOT / \"debug_trove_pbebench\"\n", + "VLLM_LOGS = REPO_ROOT / \"vllm_logs\"\n", + "OUT_DIR.mkdir(parents=True, exist_ok=True)\n", + "DEBUG_DIR.mkdir(parents=True, exist_ok=True)\n", + "VLLM_LOGS.mkdir(parents=True, exist_ok=True)\n", + "\n", + "print(f\"REPO_ROOT : {REPO_ROOT}\")\n", + "print(f\"MODEL : {MODEL} (TP={TENSOR_PARALLEL})\")\n", + "print(f\"BASE_URL : {BASE_URL}\")\n", + "print(f\"TASKS_FILE : {TASKS_FILE} (exists={TASKS_FILE.exists()})\")\n", + "print(f\"OUT_FILE : {OUT_FILE}\")" + ], + "execution_count": null, + "outputs": [], + "id": "ce204af4" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. GPU & dependency check" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "!nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Install repo deps + vLLM. Re-running is a no-op if everything's already there.\n", + "!pip install -q -U pip wheel\n", + "!pip install -q -r requirements.txt 2>&1 | tail -5\n", + "!pip install -q -U \"vllm>=0.16.0\" 2>&1 | tail -5\n", + "import importlib, vllm\n", + "print(\"vllm version:\", vllm.__version__)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Launch vLLM in the background\n", + "\n", + "Required flags for `gpt-oss` native tool calling (vLLM ≥ v0.16.0):\n", + "\n", + "- `--enable-auto-tool-choice`\n", + "- `--tool-call-parser openai`\n", + "- `--reasoning-parser openai_gptoss`" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import os, subprocess, time, datetime\n", + "\n", + "ts = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "log_path = VLLM_LOGS / f\"vllm_{PORT}_{ts}.log\"\n", + "pid_path = VLLM_LOGS / f\"vllm_{PORT}_{ts}.pid\"\n", + "\n", + "user = os.environ.get(\"USER\", \"runpod\")\n", + "for d in (f\"/tmp/{user}-tiktoken-cache\", f\"/tmp/{user}-tmp\"):\n", + " Path(d).mkdir(parents=True, exist_ok=True)\n", + " os.chmod(d, 0o700)\n", + "os.environ[\"TIKTOKEN_CACHE_DIR\"] = f\"/tmp/{user}-tiktoken-cache\"\n", + "os.environ[\"TMPDIR\"] = f\"/tmp/{user}-tmp\"\n", + "\n", + "cmd = [\n", + " \"python\", \"-m\", \"vllm.entrypoints.openai.api_server\",\n", + " \"--model\", MODEL,\n", + " \"--tokenizer\", MODEL,\n", + " \"--dtype\", \"auto\",\n", + " \"--port\", str(PORT),\n", + " \"--gpu-memory-utilization\", \"0.95\",\n", + " \"--tensor-parallel-size\", str(TENSOR_PARALLEL),\n", + " \"--enable-auto-tool-choice\",\n", + " \"--tool-call-parser\", \"openai\",\n", + " \"--reasoning-parser\", \"openai_gptoss\",\n", + "]\n", + "\n", + "log_fh = open(log_path, \"w\")\n", + "vllm_proc = subprocess.Popen(cmd, stdout=log_fh, stderr=subprocess.STDOUT)\n", + "pid_path.write_text(str(vllm_proc.pid))\n", + "print(f\"vLLM started — pid {vllm_proc.pid}\")\n", + "print(f\"log : {log_path}\")\n", + "print(f\"pid : {pid_path}\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Wait for the OpenAI-compatible /v1/models endpoint to respond.\n", + "# gpt-oss-20b cold-start (model download + load) is typically 1–3 min on a\n", + "# fresh pod; subsequent launches are seconds once the weights are cached.\n", + "import urllib.request, json, time\n", + "\n", + "READY_TIMEOUT_S = 600 # 10 min\n", + "POLL_S = 5\n", + "\n", + "deadline = time.time() + READY_TIMEOUT_S\n", + "ready = False\n", + "while time.time() < deadline:\n", + " if vllm_proc.poll() is not None:\n", + " print(\"vLLM exited unexpectedly. Tail of log:\")\n", + " print(log_path.read_text()[-4000:])\n", + " raise RuntimeError(\"vLLM died during startup\")\n", + " try:\n", + " with urllib.request.urlopen(f\"{BASE_URL}/models\", timeout=2) as resp:\n", + " data = json.loads(resp.read())\n", + " print(\"Ready. /v1/models response:\")\n", + " print(json.dumps(data, indent=2)[:600])\n", + " ready = True\n", + " break\n", + " except Exception:\n", + " elapsed = int(READY_TIMEOUT_S - (deadline - time.time()))\n", + " print(f\"\\rwaiting for vLLM... {elapsed}s elapsed\", end=\"\", flush=True)\n", + " time.sleep(POLL_S)\n", + "\n", + "if not ready:\n", + " print(\"\\nTimed out. Tail of log:\")\n", + " print(log_path.read_text()[-4000:])\n", + " raise RuntimeError(\"vLLM never became ready\")" + ], + "execution_count": null, + "outputs": [], + "id": "b985cb11" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Optional: peek at the most recent vLLM server log. Re-run this cell any time\n", + "# (during or after the TroVE run) to spot-check throughput / GPU memory / errors.\n", + "def tail_vllm_log(n: int = 80) -> None:\n", + " logs = sorted(VLLM_LOGS.glob(\"vllm_*.log\"))\n", + " if not logs:\n", + " print(\"No vllm logs found yet.\")\n", + " return\n", + " latest = logs[-1]\n", + " text = latest.read_text(errors=\"replace\")\n", + " lines = text.splitlines()\n", + " print(f\"=== {latest.name} (last {min(n, len(lines))} of {len(lines)} lines) ===\")\n", + " print(\"\\n\".join(lines[-n:]))\n", + "\n", + "tail_vllm_log(60)" + ], + "execution_count": null, + "outputs": [], + "id": "e1bec107" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Run TroVE on PBEBench-Lite (smoke run)\n", + "\n", + "Defaults below match the design:\n", + "\n", + "- `--trove-task-family pbebench` — strict `**Solution**` parsing + PBEBench few-shots\n", + "- `--trove-selection reward` — reward-based candidate selection (AST tie-break)\n", + "- `--trove-k 5` — paper default samples per mode\n", + "- `--trove-trim-every 9999` — effectively disable periodic trimming for a 50-task smoke\n", + "- `--default-reward pbebench` — PBEBench verifier" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import subprocess, sys, datetime\n", + "\n", + "os.environ[\"VLLM_API_KEY\"] = os.environ.get(\"VLLM_API_KEY\", \"EMPTY\")\n", + "\n", + "cmd = [\n", + " sys.executable, \"-u\", \"main.py\",\n", + " \"--framework\", \"trove\",\n", + " \"--backend\", \"vllm\",\n", + " \"--base-url\", BASE_URL,\n", + " \"--model\", MODEL,\n", + " \"--trove-task-family\", \"pbebench\",\n", + " \"--trove-selection\", \"reward\",\n", + " \"--trove-k\", \"5\",\n", + " \"--trove-trim-every\", \"9999\",\n", + " \"--default-reward\", \"pbebench\",\n", + " \"--max-programs\", str(MAX_PROGRAMS),\n", + " \"--tasks-file\", str(TASKS_FILE),\n", + " \"--output-file\", str(OUT_FILE),\n", + " \"--debug-dir\", str(DEBUG_DIR),\n", + "]\n", + "\n", + "# Mirror stdout+stderr to a log file as well as the cell output. This keeps a\n", + "# durable record if the browser tab disconnects mid-run, and makes it trivial\n", + "# to grep for telemetry across runs.\n", + "RUN_TS = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + "RUN_LOG = OUT_DIR / f\"trove_pbebench_lite_smoke_{RUN_TS}.log\"\n", + "\n", + "print(\" \".join(cmd))\n", + "print(f\"\\nMirroring stdout to: {RUN_LOG}\\n\")\n", + "\n", + "proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)\n", + "try:\n", + " with open(RUN_LOG, \"w\", encoding=\"utf-8\") as logfh:\n", + " for line in proc.stdout:\n", + " print(line, end=\"\")\n", + " logfh.write(line)\n", + " logfh.flush()\n", + "finally:\n", + " rc = proc.wait()\n", + "print(f\"\\nmain.py exited with {rc}\")\n", + "print(f\"Full log: {RUN_LOG}\")" + ], + "execution_count": null, + "outputs": [], + "id": "500ee1a6" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Analyze the JSONL output" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "!python scripts/analyze_trove_run.py \"{OUT_FILE}\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Quick peek at one row to confirm telemetry made it through.\n", + "import json\n", + "with open(OUT_FILE) as f:\n", + " first = json.loads(next(f))\n", + "print(\"keys:\", sorted(first.keys()))\n", + "for k in (\"won_mode\", \"import_eligible\", \"tool_call_count\", \"trove_stopped_reason\"):\n", + " print(f\" {k:24s} = {first.get(k)}\")\n", + "print(f\" library_snapshot size = {len(first.get('library_snapshot', []))}\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Cleanup — stop vLLM\n", + "\n", + "Run this when you're done so the GPU is freed for the next experiment." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import signal, time\n", + "if vllm_proc.poll() is None:\n", + " vllm_proc.send_signal(signal.SIGINT)\n", + " try:\n", + " vllm_proc.wait(timeout=15)\n", + " except subprocess.TimeoutExpired:\n", + " vllm_proc.kill()\n", + " vllm_proc.wait()\n", + " print(\"vLLM stopped.\")\n", + "else:\n", + " print(\"vLLM was not running.\")\n", + "log_fh.close()" + ], + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3", + "version": "3.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/scripts/analyze_trove_run.py b/scripts/analyze_trove_run.py new file mode 100755 index 00000000..0fe2758e --- /dev/null +++ b/scripts/analyze_trove_run.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +"""Post-hoc analysis of a TroVE run JSONL output. + +Reads the per-task JSONL file produced by main.py --output-file and reports: + - Overall accuracy + - Final toolbox size + - Per-mode wins + - IMPORT-mode tool-use breakdown + - Top-10 most-called toolbox functions + +Usage: + python scripts/analyze_trove_run.py path/to/results.jsonl +""" + +from __future__ import annotations + +import argparse +import json +import sys +from collections import Counter +from pathlib import Path + + +def _load_rows(path: Path) -> list[dict]: + rows = [] + with path.open() as f: + for lineno, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + rows.append(json.loads(line)) + except json.JSONDecodeError as exc: + print(f"warning: line {lineno} is not valid JSON: {exc}", file=sys.stderr) + return rows + + +def _result_dict(row: dict) -> dict: + """Tolerant accessor: results are nested under 'result' in main.py's output.""" + return row.get("result") or row + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("path", type=Path, help="Path to the TroVE results JSONL file") + args = parser.parse_args() + + rows = _load_rows(args.path) + if not rows: + print("ERROR: no rows loaded", file=sys.stderr) + sys.exit(1) + + n = len(rows) + results = [_result_dict(r) for r in rows] + + solved = sum(1 for r in results if r.get("solved")) + print(f"=== Run summary: {args.path.name} ===") + print(f"Tasks: {n}") + print(f"Solved: {solved}/{n} ({100 * solved / n:.1f}%)") + + last_snapshot = results[-1].get("library_snapshot") or [] + print(f"Final toolbox size: {len(last_snapshot)}") + + mode_counter = Counter(r.get("won_mode", "?") for r in results) + print(f"Mode wins: {dict(mode_counter)}") + + import_eligible = [r for r in results if r.get("import_eligible")] + if not import_eligible: + print("No IMPORT-eligible tasks observed.") + else: + with_calls = [r for r in import_eligible if (r.get("tool_call_count") or 0) >= 1] + n_eligible = len(import_eligible) + n_with = len(with_calls) + mean_calls = ( + sum((r.get("tool_call_count") or 0) for r in import_eligible) / n_eligible + ) + all_calls = [tc for r in import_eligible for tc in (r.get("tool_calls") or [])] + n_calls_total = len(all_calls) + n_calls_ok = sum(1 for tc in all_calls if tc.get("ok")) + success_rate = (100 * n_calls_ok / n_calls_total) if n_calls_total else 0.0 + print( + f"IMPORT-eligible tasks: {n_eligible}\n" + f" Tasks with >=1 tool call: {n_with}/{n_eligible} ({100 * n_with / n_eligible:.1f}%)\n" + f" Mean tool calls / task: {mean_calls:.2f}\n" + f" Tool-call success rate: {n_calls_ok}/{n_calls_total} ({success_rate:.1f}%)" + ) + + name_counter: Counter = Counter() + for r in results: + for tc in r.get("tool_calls") or []: + name = (tc.get("name") or "").split("<|", 1)[0].strip() + if name: + name_counter[name] += 1 + if name_counter: + print("Top-10 most-called toolbox functions:") + for name, cnt in name_counter.most_common(10): + print(f" {cnt:4d} {name}") + else: + print("No tool calls recorded in this run.") + + +if __name__ == "__main__": + main() diff --git a/scripts/launch_vllm_gpt_oss_120b.sh b/scripts/launch_vllm_gpt_oss_120b.sh index 74b10dac..5ae5216c 100644 --- a/scripts/launch_vllm_gpt_oss_120b.sh +++ b/scripts/launch_vllm_gpt_oss_120b.sh @@ -7,6 +7,11 @@ export TMPDIR=/tmp/$USER-tmp ts=$(date +%Y%m%d_%H%M%S) +# Required vLLM tool-calling flags (vLLM >= v0.16.0 for PR #28729): +# --enable-auto-tool-choice enables tool_choice="auto" +# --tool-call-parser openai parses gpt-oss Harmony commentary channel +# --reasoning-parser openai_gptoss routes analysis-channel content into +# message.reasoning_content nohup python -m vllm.entrypoints.openai.api_server \ --model "openai/gpt-oss-120b" \ --tokenizer "openai/gpt-oss-120b" \ @@ -14,4 +19,7 @@ nohup python -m vllm.entrypoints.openai.api_server \ --port ${1} \ --gpu-memory-utilization 0.95 \ --tensor-parallel-size 2 \ - > vllm_logs/vllm_${1}_${ts}.log 2>&1 & echo $! > vllm_logs/vllm_${1}_${ts}.pid \ No newline at end of file + --enable-auto-tool-choice \ + --tool-call-parser openai \ + --reasoning-parser openai_gptoss \ + > vllm_logs/vllm_${1}_${ts}.log 2>&1 & echo $! > vllm_logs/vllm_${1}_${ts}.pid diff --git a/scripts/launch_vllm_gpt_oss_20b.sh b/scripts/launch_vllm_gpt_oss_20b.sh new file mode 100755 index 00000000..37d6e131 --- /dev/null +++ b/scripts/launch_vllm_gpt_oss_20b.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +mkdir -p /tmp/$USER-tiktoken-cache /tmp/$USER-tmp +chmod 700 /tmp/$USER-tiktoken-cache /tmp/$USER-tmp +export TIKTOKEN_CACHE_DIR=/tmp/$USER-tiktoken-cache +export TMPDIR=/tmp/$USER-tmp + +ts=$(date +%Y%m%d_%H%M%S) + +# Required vLLM tool-calling flags (vLLM >= v0.16.0 for PR #28729): +# --enable-auto-tool-choice enables tool_choice="auto" +# --tool-call-parser openai parses gpt-oss Harmony commentary channel +# --reasoning-parser openai_gptoss routes analysis-channel content into +# message.reasoning_content +nohup python -m vllm.entrypoints.openai.api_server \ + --model "openai/gpt-oss-20b" \ + --tokenizer "openai/gpt-oss-20b" \ + --dtype auto \ + --port ${1} \ + --gpu-memory-utilization 0.95 \ + --tensor-parallel-size 1 \ + --enable-auto-tool-choice \ + --tool-call-parser openai \ + --reasoning-parser openai_gptoss \ + > vllm_logs/vllm_${1}_${ts}.log 2>&1 & echo $! > vllm_logs/vllm_${1}_${ts}.pid diff --git a/scripts/run_trove_vllm.sh b/scripts/run_trove_vllm.sh index 54c27932..280baa23 100755 --- a/scripts/run_trove_vllm.sh +++ b/scripts/run_trove_vllm.sh @@ -1,27 +1,44 @@ #!/usr/bin/env bash -# Run TroVE baseline against a local vLLM server. +# Run TroVE baseline against a local vLLM server (gpt-oss-20b). # Usage: bash scripts/run_trove_vllm.sh # -# For small datasets (≤100 tasks), --trove-trim-every is set high to disable +# Defaults to PBEBench-Lite pilot (50 tasks). Override TASKS_FILE or pass +# extra --flags through the trailing "$@". +# +# For small datasets (<=100 tasks), --trove-trim-every is set high to disable # trimming (the library never gets large enough for it to matter). -# Set --trove-k 1 for a cheaper run without self-consistency sampling. +# Set --trove-k 1 for a cheaper run without per-mode K-sampling. set -euo pipefail cd "$(dirname "${BASH_SOURCE[0]}")/.." -export PORT=8002 +export PORT="${PORT:-8000}" export VLLM_API_KEY="${VLLM_API_KEY:-EMPTY}" mkdir -p outputs +TASKS_FILE="${TASKS_FILE:-data/pbebench/lite_pilot_tasks.jsonl}" +OUT_FILE="${OUT_FILE:-outputs/trove_pbebench_lite_pilot.jsonl}" + +echo "Tasks : ${TASKS_FILE}" +echo "Output : ${OUT_FILE}" +echo "Port : ${PORT}" + python main.py \ - --framework trove \ - --tasks-file data/pbebench/lite_tasks_full.jsonl \ - --base-url "http://localhost:${PORT}/v1" \ - --model "openai/gpt-oss-120b" \ - --trove-k 5 \ - --trove-trim-every 9999 \ - --default-reward pbebench \ - --output-file outputs/pbebench_lite_full_trove.jsonl \ - --debug-dir debug_trove \ - --stats + --framework trove \ + --tasks-file "${TASKS_FILE}" \ + --base-url "http://localhost:${PORT}/v1" \ + --model "openai/gpt-oss-20b" \ + --trove-task-family pbebench \ + --trove-selection reward \ + --trove-k 5 \ + --trove-trim-every 9999 \ + --default-reward pbebench \ + --max-programs 5 \ + --output-file "${OUT_FILE}" \ + --debug-dir debug_trove \ + --stats \ + "$@" + +echo "Done. Output: ${OUT_FILE}" +echo "Analyze with: python scripts/analyze_trove_run.py ${OUT_FILE}" diff --git a/symbolic_agent/baselines/trove/controller.py b/symbolic_agent/baselines/trove/controller.py index d11d8b23..d64c638c 100644 --- a/symbolic_agent/baselines/trove/controller.py +++ b/symbolic_agent/baselines/trove/controller.py @@ -37,10 +37,17 @@ from collections import Counter from typing import Callable, Dict, List, Optional +from . import tools_api from .executor import run_solution from .llm import TroVELLMClient -from .parse import count_ast_nodes, parse_response -from .prompts import build_create_prompt, build_import_prompt, build_skip_prompt, get_question +from .parse import count_ast_nodes, imported_callsites, parse_response +from .prompts import ( + build_create_prompt, + build_import_prompt, + build_import_with_tools_prompt, + build_skip_prompt, + get_question, +) from .toolbox import TroVEToolbox logger = logging.getLogger(__name__) @@ -61,18 +68,34 @@ class TroVEController: model : str LLM model identifier. base_url : str, optional - For OpenAI-compatible (vLLM) backends. + For OpenAI-compatible (vLLM) backends. When set, ``self.backend`` is + ``"openai"``; otherwise ``"anthropic"``. Native tool-calling IMPORT + requires the openai backend. debug_dir : str, optional k : int Number of samples per mode (paper default: 5). trim_every : int Trim toolbox every N tasks (paper default: 500). trim_C : float - Trimming threshold multiplier: threshold = C·log₂₀(n). Default: 0.5. + Trimming threshold multiplier: threshold = C·log₂₀(n). Default: 1.0 + (matches the original TroVE implementation). temperature : float Sampling temperature. Default: 0.3 (TroVE paper). top_p : float Nucleus sampling top-p. Default: 0.95 (TroVE paper). + task_family : str + Prompt/parsing family. ``"default"`` (generic) or ``"pbebench"`` + (PBEBench-shaped few-shots; strict ``**Solution**`` parsing). + selection : str + Candidate selection strategy. ``"reward"`` (default) uses the + reward function when available and falls back to consistency; + ``"consistency"`` always uses the original TroVE majority-vote. + max_tool_iters : int + Maximum tool-call rounds per IMPORT trajectory in the native + tool-calling path. Default: 8. + tool_schema_topk : int + Number of top-frequency toolbox functions exposed as OpenAI tool + schemas in the native IMPORT path. Default: 10. """ def __init__( @@ -83,18 +106,26 @@ def __init__( debug_dir: Optional[str] = None, k: int = DEFAULT_K, trim_every: int = DEFAULT_TRIM_EVERY, - trim_C: float = 0.5, + trim_C: float = 1.0, temperature: float = 0.3, top_p: float = 0.95, + task_family: str = "default", + selection: str = "reward", + max_tool_iters: int = 8, + tool_schema_topk: int = 10, ): self.model = model self.k = k self.trim_every = trim_every self.trim_C = trim_C + self.task_family = task_family + self.selection = selection + self.max_tool_iters = max_tool_iters + self.tool_schema_topk = tool_schema_topk - backend = "openai" if base_url else "anthropic" + self.backend = "openai" if base_url else "anthropic" self.llm = TroVELLMClient( - backend=backend, + backend=self.backend, base_url=base_url, api_key=api_key, temperature=temperature, @@ -252,33 +283,52 @@ def _multi_way_generation( toolbox_str = self.toolbox.format_toolbox() # --- IMPORT mode --- - import_candidates = [] - if toolbox_str: + toolbox_nonempty = bool(toolbox_str) + use_tools_branch = toolbox_nonempty and self.backend == "openai" + + if use_tools_branch: + import_candidates = self._generate_import_with_tools( + question, example_idx, reward_fn=reward_fn, entry=entry + ) + best_import_idx, best_import_score = self._select_best( + import_candidates, reward_fn=reward_fn, entry=entry + ) + best_import = import_candidates[best_import_idx] + best_import["_reward_score"] = best_import_score + elif toolbox_nonempty: + # Legacy text-based IMPORT (Anthropic or unforeseen non-OpenAI path). + import_candidates = [] for _ in range(self.k): - prompt = build_import_prompt(question, toolbox_str) + prompt = build_import_prompt(question, toolbox_str, task_family=self.task_family) raw = self.llm.call(prompt, self.model, max_tokens=DEFAULT_MAX_TOKENS, tag="trove_import") - parsed = parse_response(raw) + parsed = parse_response(raw, task_family=self.task_family) is_ok, out = run_solution( parsed["solution_code"], parsed["tools_code"], self.toolbox.get_full_code(), ) - import_candidates.append({**parsed, "is_success": is_ok, "exec_output": out}) + import_candidates.append( + {**parsed, "is_success": is_ok, "exec_output": out, "tool_calls": [], "stopped_reason": "legacy"} + ) best_import_idx, best_import_score = self._select_best( import_candidates, reward_fn=reward_fn, entry=entry ) best_import = import_candidates[best_import_idx] best_import["_reward_score"] = best_import_score else: - best_import = {"solution_code": "", "tools_code": "", "functions": [], - "is_success": False, "exec_output": "", "_reward_score": None} + best_import = { + "solution_code": "", "tools_code": "", "functions": [], + "is_success": False, "exec_output": "", + "tool_calls": [], "stopped_reason": "empty_toolbox", + "_reward_score": None, + } # --- CREATE mode --- create_candidates = [] for _ in range(self.k): - prompt = build_create_prompt(question) + prompt = build_create_prompt(question, task_family=self.task_family) raw = self.llm.call(prompt, self.model, max_tokens=DEFAULT_MAX_TOKENS, tag="trove_create") - parsed = parse_response(raw) + parsed = parse_response(raw, task_family=self.task_family) is_ok, out = run_solution( parsed["solution_code"], parsed["tools_code"], @@ -294,9 +344,9 @@ def _multi_way_generation( # --- SKIP mode --- skip_candidates = [] for _ in range(self.k): - prompt = build_skip_prompt(question) + prompt = build_skip_prompt(question, task_family=self.task_family) raw = self.llm.call(prompt, self.model, max_tokens=DEFAULT_MAX_TOKENS, tag="trove_skip") - parsed = parse_response(raw) + parsed = parse_response(raw, task_family=self.task_family) is_ok, out = run_solution( parsed["solution_code"], parsed["tools_code"], @@ -334,6 +384,54 @@ def _multi_way_generation( ) return winning_mode, best_resp, best_score + def _generate_import_with_tools( + self, + question: str, + example_idx: int, + reward_fn: Optional[Callable] = None, + entry: Optional[dict] = None, + ) -> List[dict]: + """ + IMPORT-mode generation using native OpenAI tool calling. + Builds K trajectories; each trajectory may invoke toolbox functions + via tool_calls during the multi-turn loop. Returns K candidate dicts + compatible with _select_best. + """ + prompt = build_import_with_tools_prompt(question, task_family=self.task_family) + tools_schema = tools_api.toolbox_to_openai_tools(self.toolbox, topk=self.tool_schema_topk) + + candidates: List[dict] = [] + for i in range(self.k): + tag = f"trove_import_t{example_idx}_{i}" + messages = [{"role": "user", "content": prompt}] + on_tc = lambda tc: tools_api.dispatch_tool_call(self.toolbox, tc) + traj = self.llm.chat_with_tools( + messages=messages, + tools=tools_schema, + model=self.model, + max_tokens=DEFAULT_MAX_TOKENS, + max_tool_iters=self.max_tool_iters, + on_tool_call=on_tc, + tag=tag, + ) + parsed = parse_response(traj["final_text"], task_family=self.task_family) + is_ok, out = run_solution( + parsed["solution_code"], + parsed["tools_code"], + self.toolbox.get_full_code(), + ) + candidates.append( + { + **parsed, + "is_success": is_ok, + "exec_output": out, + "tool_calls": traj["tool_calls"], + "stopped_reason": traj["stopped_reason"], + "iterations": traj["iterations"], + } + ) + return candidates + def _select_best( self, candidates: List[dict], @@ -344,18 +442,15 @@ def _select_best( Select the best candidate from a list of response dicts. Returns (best_index, score_or_None) where score is (reward, message) - when reward-based selection is used, or None for majority-vote mode. - - Two selection strategies: - 1. Reward-based (when reward_fn + entry provided): - Score all K candidates with reward_fn; pick highest reward, - tiebreak by minimum AST node count (simplest solution). - This is reliable for PBEBench (program lists rarely match exactly - as strings) and equally good for reasoning_gym. - 2. Majority-vote fallback (original TroVE algorithm): - Filter successes → majority vote on stdout → min AST tiebreak. - Used when no reward function is available (e.g. bare solve()). + when reward-based selection is used, or None otherwise. + + Selection strategy is governed by self.selection: + - "reward" (default): reward-based when reward_fn+entry provided, + falls back to consistency when not. + - "consistency": original TroVE majority-vote algorithm. """ + if self.selection == "consistency": + return self._select_best_by_consistency(candidates), None if reward_fn is not None and entry is not None: return self._select_best_by_reward(candidates, reward_fn, entry) return self._select_best_by_consistency(candidates), None @@ -369,6 +464,7 @@ def _select_best_by_reward( """Reward-based candidate selection. Returns (best_index, (reward, message)).""" best_idx = 0 best_reward = -1.0 + best_reuse = -1 best_ast = float("inf") best_message = "" for i, c in enumerate(candidates): @@ -380,13 +476,36 @@ def _select_best_by_reward( logger.debug("Reward scoring error for candidate %d: %s", i, exc) score, msg = 0.0, str(exc) ast_size = count_ast_nodes(c.get("solution_code", "")) - if score > best_reward or (score == best_reward and ast_size < best_ast): + reuse_signal = self._reuse_signal(c) + if ( + score > best_reward + or ( + score == best_reward + and ( + reuse_signal > best_reuse + or (reuse_signal == best_reuse and ast_size < best_ast) + ) + ) + ): best_idx = i best_reward = score + best_reuse = reuse_signal best_ast = ast_size best_message = msg return best_idx, (best_reward, best_message) + @staticmethod + def _reuse_signal(candidate: dict) -> int: + """Tie-break signal for candidates that support TroVE's toolbox.""" + functions = candidate.get("functions") or [] + tool_calls = candidate.get("tool_calls") or [] + unique_tool_names = { + (tc.get("name") or "").split("<|", 1)[0].strip() + for tc in tool_calls + if isinstance(tc, dict) and tc.get("name") + } + return len(functions) + len({name for name in unique_tool_names if name}) + def _select_best_by_consistency(self, candidates: List[dict]) -> int: """ Original TroVE self-consistency selection (majority vote on stdout). @@ -419,13 +538,25 @@ def _select_best_by_consistency(self, candidates: List[dict]) -> int: def _update_library(self, mode: str, resp: dict, example_idx: int) -> None: """Update toolbox based on winning mode (faithful to run_trove.py).""" if mode == "import": - # IMPORT: credit existing functions that were used - for func_dict in resp.get("functions", []): - name = func_dict.get("name", "") - if name: - self.toolbox.update_frequency(name, example_idx) + tool_calls = resp.get("tool_calls") or [] + if tool_calls: + # Native tool-calling path: credit by unique tool_call.function.name + # (defensive: sanitize and let toolbox.update_frequency filter unknowns). + unique_names = { + tc["name"].split("<|", 1)[0].strip() + for tc in tool_calls + if tc.get("name") + } + for name in unique_names: + if name: + self.toolbox.update_frequency(name, example_idx) + else: + # Legacy text-based IMPORT: credit functions parsed from **Tools**. + for func_dict in resp.get("functions", []): + name = func_dict.get("name", "") + if name: + self.toolbox.update_frequency(name, example_idx) elif mode == "create" and resp.get("is_success"): - # CREATE: add new functions only when execution succeeded for func_dict in resp.get("functions", []): self.toolbox.add(func_dict, example_idx) @@ -447,8 +578,29 @@ def _make_result( ) -> dict: """ Build a result dict compatible with main.py's _print_result() and - _append_task_output(). + _append_task_output(). Adds passive TroVE telemetry fields. """ + tool_calls = best_resp.get("tool_calls") or [] + tools_called = sorted({ + tc["name"].split("<|", 1)[0].strip() + for tc in tool_calls + if tc.get("name") + }) + candidate_names = {e["name"] for e in self.toolbox.snapshot()} + actually_called = sorted( + imported_callsites( + solution_code=best_resp.get("solution_code", ""), + tools_code=best_resp.get("tools_code", ""), + candidate_names=candidate_names, + ) + ) + import_eligible = len(self.toolbox) > 0 # state AFTER this task's update + # Note: import_eligible reflects the current toolbox state after + # _update_library has already run for this task. The analyzer should + # interpret this as "a non-empty toolbox existed at some point during + # this task's processing". For pre-task eligibility, infer from + # toolbox snapshots in adjacent tasks. + return { "task_type": task_type, "original_prompt": str(task_input), @@ -464,7 +616,7 @@ def _make_result( ], "solution": best_resp.get("solution_code", ""), "library_snapshot": self.toolbox.snapshot(), - "cost_summary": {}, # TroVE has no cost model + "cost_summary": {}, "final_output": { "answer": output, "explanation": f"TroVE mode={best_mode}", @@ -475,6 +627,14 @@ def _make_result( "reward_history": [], "best_reward": None, "final_reward": None, - # Cached score from reward-based selection; consumed and removed by solve_with_reward. "_best_reward_score": best_reward_score, + # TroVE native-tool-calling telemetry + "won_mode": best_mode, + "import_eligible": import_eligible, + "import_was_winner": best_mode == "import", + "tool_calls": tool_calls, + "tool_call_count": len(tool_calls), + "tools_called": tools_called, + "actually_called": actually_called, + "trove_stopped_reason": best_resp.get("stopped_reason", ""), } diff --git a/symbolic_agent/baselines/trove/docs/deviations.md b/symbolic_agent/baselines/trove/docs/deviations.md index 06d4c346..fda9d359 100644 --- a/symbolic_agent/baselines/trove/docs/deviations.md +++ b/symbolic_agent/baselines/trove/docs/deviations.md @@ -1,122 +1,85 @@ -# TroVE Baseline — Deviations from the Original Paper - -This document records all intentional and unavoidable deviations between our -reimplementation (`symbolic_agent/baselines/trove/`) and the original TroVE -codebase (`original_baseline_repos/trove/`). - ---- - -## 1. Chat API instead of Local Model Completion - -**Original:** TroVE uses a HuggingFace `transformers.pipeline` with a locally -loaded model (e.g. CodeLlama-7b-Instruct) in **completion** mode. The prompt -is a plain string prefix; the model generates continuation text. - -**Ours:** We use Anthropic's Messages API or an OpenAI-compatible chat API -(vLLM). The prompt is sent as a `user` message; the model generates a reply -that includes the **Solution** and **Tools** blocks. - -**Impact:** Minimal. The prompt structure (ending with `**Solution**`) signals -to chat models what to generate, and empirically they comply. No JSON mode is -used (`TroVELLMClient` vs the main `LLMClient`). - ---- - -## 2. Domain-Generic Few-Shot Examples - -**Original:** TroVE uses domain-specific few-shot examples for each task -(TabMWP coin-collection table examples, MATH algebra examples, etc.) - -**Ours:** We use generic string-manipulation examples that apply to both -PBEBench and ReasoningGym string tasks (replace_char, extract_digits, -lowercase examples). Domain-specific examples for other task families -should be added to `prompts.py` as needed. - -**Impact:** May slightly reduce self-consistency accuracy for tasks where the -original examples provide strong in-context guidance. The structural format -is preserved exactly. - ---- - -## 3. K Calls Rather Than Batched n=K - -**Original:** TroVE passes `num_return_sequences=K` to the HuggingFace -pipeline, which generates K sequences in one forward pass. - -**Ours:** We call the LLM API K times independently (temperature sampling). -The Anthropic API does not support `n` parameter; the OpenAI-compatible API -does but we call separately for simplicity and identical code paths. - -**Impact:** K API calls instead of 1; slightly slower but statistically -equivalent since each call is an independent sample. - ---- - -## 4. AST Node Count Instead of AST Depth Sum - -**Original:** TroVE tie-breaks by `sum(depth of each AST expression node)` -across the solution (referenced in §3.2 and Appendix B). - -**Ours:** `count_ast_nodes()` counts total AST nodes via `ast.walk()`. -Total nodes is monotonically related to total expression depth: simpler -programs have fewer nodes AND lower total depth. The tie-breaking effect -is identical in practice. - -**Impact:** Negligible. Both metrics rank programs by complexity; the ranking -rarely differs for programs with the same stdout. - ---- - -## 5. No Re-Generation of Trimmed Examples - -**Original:** After trimming the toolbox, `run_trove.py` re-generates -solutions for all affected examples using IMPORT|SKIP (not CREATE), then -reports updated accuracy. - -**Ours:** We record the set of affected task indices in the trim log but do -not replay them. This is because we process tasks in a single stream and do -not store the original task inputs for re-processing. For a complete -faithful comparison, task inputs should be saved and re-processed on trim. - -**Impact:** In practice, trimming only fires after 500 tasks with the default -setting. For our 100-task pilot runs, trimming is disabled by setting -`--trove-trim-every 9999`. - ---- - -## 6. Reward Loop Compatibility Wrapper - -**Original:** TroVE has no concept of a reward function or iterative -refinement loop. It is one-shot per example. - -**Ours:** `solve_with_reward()` wraps `solve()` for compatibility with -`main.py`'s `--default-reward` and `--max-reward-iters` flags. No retry -loop is performed; the reward is computed once and stored in `reward_history` -for eval script compatibility. - -**Impact:** None on TroVE's actual behavior. Only affects output format. - ---- - -## 7. `trim_every` Default Differs for Small Runs - -**Original:** Default `--trim_steps=500` (trimming every 500 examples). -For a 100-task dataset this fires 0 times. - -**Ours:** Same default (500), but users running small pilots should pass -`--trove-trim-every 9999` to make it explicit that no trimming happens. - -**Impact:** None unless running >500 tasks. - ---- - -## Summary Table - -| Aspect | Original | Ours | Impact | -|--------|----------|------|--------| -| LLM backend | Local HF model (completion) | Chat API (messages) | Minimal | -| Few-shot examples | Domain-specific (TabMWP/MATH) | Generic string-manipulation | Minor | -| K sampling | Batched (n=K in one call) | K independent API calls | Latency only | -| Complexity metric | Sum of AST expression depths | Total AST node count | Negligible | -| Trim replay | Re-generates affected examples | Records but does not replay | Evaluation accuracy | -| Reward loop | Not in original | Wrapper for main.py compat | None | +# TroVE Implementation: Deviations and Faithful Elements + +This document tracks how this port differs from — and where it stays +faithful to — the original TroVE algorithm +([Wang et al., 2024](https://arxiv.org/abs/2401.12869), +[zorazrw/trove](https://github.com/zorazrw/trove)). + +## 1. Algorithmic deviations + +### 1.1 Native OpenAI tool calling for IMPORT mode +The original TroVE shows the model a `**Toolbox**` markdown block +listing top-k function signatures and asks it to write a `**Solution**` +plus `**Tools**` block referencing those functions by name. We replace +this for the IMPORT mode (when `backend == "openai"` and the toolbox is +non-empty) with **native OpenAI tool calling**: the toolbox is exposed +via the `tools=[...]` parameter of `chat.completions.create`, the model +emits structured `tool_calls` during its reasoning, and `dispatch_tool_call` +runs each one in the sandboxed executor and returns the stdout. This +makes function usage observable and credit-able from the trajectory +itself. + +### 1.2 Reward-based candidate selection (default) +The paper uses self-consistency (majority vote on stdout, AST tie-break) +to pick the best of K samples per mode. We default to **reward-based +selection**: every candidate is scored by the per-task reward function, +ties broken by minimum AST node count. This is more reliable on +PBEBench (program-list outputs rarely tie as strings). The original +self-consistency selector remains available via `--trove-selection consistency`. + +### 1.3 PBEBench-shaped few-shot examples +For `task_family="pbebench"` we replace the generic CREATE / SKIP / IMPORT +example pairs with PBEBench-shaped pairs that demonstrate `replace()` +chains. CREATE mode also shows signature-only examples of reusable helper +shapes (apply, score, search, prune, debug, end-to-end solve) instead of +full function definitions, to reduce anchoring on a single copied helper. +The legacy default examples remain for `task_family="default"`. + +### 1.4 Strict **Solution** parsing for PBEBench +The legacy parser falls back to "first ```python``` block anywhere" when +no `**Solution**` block is present. For `task_family="pbebench"` this +fallback is disabled, preventing CoT scratchpad from being accidentally +promoted to the answer. + +## 2. Faithful elements + +- 3-mode generation (IMPORT, CREATE, SKIP). +- K samples per mode (default K=5, paper). +- AST-tie-breaking by node count (simplest solution wins). +- Periodic toolbox trimming with threshold `C·log_{20}(n)`, default + `C=1.0`, matching the original implementation. +- Frequency-based top-k retrieval for the toolbox view. +- Dict-keyed toolbox structure mirroring `utils/code.py`. +- Library updates: IMPORT credits frequency, CREATE adds new functions + on success, SKIP makes no library changes. + +## 3. Infrastructural patches + +- **JSONL-per-task checkpointing** via `--output-file`, with crash + resumption. +- **`reasoning_content` fallback** in `_call_openai` for `gpt-oss` Harmony + channel splits where the answer text lives in `message.reasoning_content`. +- **Executor timeout 60s** (vs. 10s in earlier versions of this port), + closer to the original's ~100s. +- **`<|`-truncation sanitizer** in `dispatch_tool_call` and + `_update_library`. Defensive workaround for the open vLLM + [PR #35906](https://github.com/vllm-project/vllm/pull/35906) covering + Harmony control-token leakage into tool names. When that PR lands + upstream the sanitizer becomes a no-op and is left in place. + +## 4. Backend coverage caveat + +Anthropic backend code paths exist and are exercised by CREATE / SKIP and +the legacy text-based IMPORT fallback, but **the smoke run and reported +numbers are vLLM-served `gpt-oss` only**. IMPORT-with-tools requires +the OpenAI/vLLM backend and is the only path we test end-to-end. + +## 5. vLLM version requirement + +- Minimum vLLM: **v0.16.0** (branch-cut 2026-02-08). +- Required upstream change: [PR #28729](https://github.com/vllm-project/vllm/pull/28729) + ("Multiple fixes for gpt-oss Chat Completion prompting"), merged + 2025-12-12. v0.16.0 is the first stable release branch-cut after the merge. +- Known open caveat: [PR #35906](https://github.com/vllm-project/vllm/pull/35906) + ("Sanitize leaked Harmony control tokens"), still open as of late + March 2026 — see §3 for the sanitizer mitigation. diff --git a/symbolic_agent/baselines/trove/docs/running.md b/symbolic_agent/baselines/trove/docs/running.md new file mode 100644 index 00000000..704ab74b --- /dev/null +++ b/symbolic_agent/baselines/trove/docs/running.md @@ -0,0 +1,149 @@ +# Running TroVE on PBEBench-Lite + +This guide covers launching the TroVE baseline against `openai/gpt-oss-20b` +served by vLLM. There are two paths: + +- **Notebook (recommended on RunPod)** — `notebooks/run_trove_pbebench.ipynb` + drives the whole flow (env setup → vLLM launch → TroVE run → analysis) from + one place and mirrors logs to disk. +- **Shell scripts** — for SSH / tmux workflows where a notebook is awkward. + +Both paths assume an L40S/H100-class GPU with ≥40 GB VRAM and ≥40 GB free disk +for the model cache. + +--- + +## 0. Prerequisites + +- `vLLM >= 0.16.0` — earlier versions do not ship the gpt-oss reasoning parser + or auto tool-choice support. +- `typing_extensions >= 4.12.2` — older versions break vLLM startup with + `cannot import name 'TypeIs' from typing_extensions`. +- `huggingface_hub` with a working transfer backend. If `xet` errors during + download, set `HF_HUB_DISABLE_XET=1`. +- `HF_HOME` pointed at a persistent volume (e.g. `/workspace/hf-cache`) so the + model is not re-downloaded across container restarts. + +Quick install / repair on a fresh container: + +```bash +python -m pip install -U "typing_extensions>=4.12.2" \ + "huggingface_hub[hf_transfer]" hf_xet +``` + +--- + +## 1. Notebook path (RunPod) + +```bash +git clone /workspace/symbolic-library-agent +cd /workspace/symbolic-library-agent +jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root +``` + +Then open `notebooks/run_trove_pbebench.ipynb` and run the cells top-to-bottom: + +1. **Env / cache setup** — sets `HF_HOME=/workspace/hf-cache` and disables xet. +2. **`pip install` cell** — refreshes `typing_extensions` and HF transfer. +3. **Launch vLLM** — backgrounds `scripts/launch_vllm_gpt_oss_20b.sh 8000` and + tails `vllm_logs/`. +4. **Wait for server ready** — polls `/v1/models` until 200 OK. +5. **`tail_vllm_log(60)` helper** — re-runnable cell for spot-checking the + server log at any time. +6. **Run TroVE** — `subprocess.Popen` of `main.py` with the PBEBench-Lite + pilot tasks. Stdout is mirrored to `outputs/trove_pbebench_lite_smoke_.log` + on disk in addition to the cell output, so you can SSH in and `tail -f` the + run from another shell. +7. **Analyze** — calls `scripts/analyze_trove_run.py` on the output JSONL. + +If the notebook cell stops responding, do **not** `pkill -f "main.py"` — +that pattern can match the vLLM process tree on some images. Instead: + +```bash +ps -ef | awk '/python .*main.py/ && /--framework/ && /trove/ {print $2}' \ + | xargs -r kill +``` + +--- + +## 2. Shell-script path + +Two scripts; run them in two terminals (or one tmux session with two panes). + +### 2a. Launch vLLM + +```bash +cd /workspace/symbolic-library-agent +mkdir -p vllm_logs +bash scripts/launch_vllm_gpt_oss_20b.sh 8000 +# logs: vllm_logs/vllm_8000_.log +# pid : vllm_logs/vllm_8000_.pid +``` + +The script forwards three flags that are required for our IMPORT-with-tools +branch to work: + +- `--enable-auto-tool-choice` +- `--tool-call-parser openai` +- `--reasoning-parser openai_gptoss` + +Wait for `Application startup complete` in the log before continuing. + +### 2b. Run TroVE + +```bash +PORT=8000 bash scripts/run_trove_vllm.sh +``` + +Defaults (overridable via env vars or trailing flags): + +| Env var | Default | +| ------------ | ----------------------------------------- | +| `PORT` | `8000` | +| `TASKS_FILE` | `data/pbebench/lite_pilot_tasks.jsonl` | +| `OUT_FILE` | `outputs/trove_pbebench_lite_pilot.jsonl` | + +Pass through any extra `main.py` flag, e.g.: + +```bash +PORT=8000 bash scripts/run_trove_vllm.sh --num-tasks 10 # quick sanity run +``` + +### 2c. Analyze + +```bash +python scripts/analyze_trove_run.py outputs/trove_pbebench_lite_pilot.jsonl +``` + +Reports overall accuracy, final toolbox size, per-mode wins, IMPORT-mode +tool-call success rate, and the top-10 most-called toolbox functions. + +--- + +## 3. Key flags (cheat sheet) + +The TroVE-specific flags on `main.py` matter most: + +| Flag | Default | Purpose | +| --------------------- | ------------ | ------------------------------------------------------- | +| `--framework` | — | Set to `trove` | +| `--trove-task-family` | `default` | Set to `pbebench` to enable PBEBench few-shots & parser | +| `--trove-selection` | `reward` | `reward` (PBEBench) or `consistency` (original TroVE) | +| `--trove-k` | `5` | Candidates per mode (1 disables sampling) | +| `--trove-trim-every` | `100` | Set high (`9999`) for ≤100-task pilots | +| `--default-reward` | — | Set to `pbebench` for the PBEBench verifier | +| `--max-programs` | `5` | PBEBench program-list length cap | + +--- + +## 4. Resuming and cleanup + +- Resume: just re-run the same command. `main.py` checkpoints to the output + JSONL; if both the JSONL and `--debug-dir` are intact it will skip already- + completed task indices. +- Force-restart: delete the output JSONL before running. +- vLLM cleanup: + ```bash + kill "$(cat vllm_logs/vllm_8000_*.pid)" 2>/dev/null || true + pkill -f vllm.entrypoints.openai.api_server # safe — only matches vLLM + ``` diff --git a/symbolic_agent/baselines/trove/executor.py b/symbolic_agent/baselines/trove/executor.py index cf23471b..1b8717e4 100644 --- a/symbolic_agent/baselines/trove/executor.py +++ b/symbolic_agent/baselines/trove/executor.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -DEFAULT_TIMEOUT = 10 # seconds, matching TroVE's original +DEFAULT_TIMEOUT = 60 # seconds — generous for PBEBench replace() chains and multi-turn dispatch def run_solution( diff --git a/symbolic_agent/baselines/trove/llm.py b/symbolic_agent/baselines/trove/llm.py index d27f8d28..49ea2c35 100644 --- a/symbolic_agent/baselines/trove/llm.py +++ b/symbolic_agent/baselines/trove/llm.py @@ -16,7 +16,7 @@ import os import time from datetime import datetime, timezone -from typing import Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional logger = logging.getLogger(__name__) @@ -26,6 +26,24 @@ DEFAULT_MAX_TOKENS = 512 +def _message_text(msg: Any) -> str: + """Return visible text from OpenAI/vLLM chat message variants.""" + content = getattr(msg, "content", None) + if content: + return content + for field in ("reasoning_content", "reasoning"): + value = getattr(msg, field, None) + if value: + return value + extra = getattr(msg, "model_extra", None) or {} + if isinstance(extra, dict): + for field in ("reasoning_content", "reasoning"): + value = extra.get(field) + if value: + return value + return "" + + class TroVELLMClient: """ Backend-agnostic plain-text LLM client for TroVE generation. @@ -189,7 +207,8 @@ def _call_openai(self, prompt: str, model: str, max_tokens: int, tag: str) -> st messages=messages, # No response_format — TroVE uses free-form text ) - raw = response.choices[0].message.content or "" + msg = response.choices[0].message + raw = _message_text(msg) u = getattr(response, "usage", None) details = getattr(u, "completion_tokens_details", None) usage = { @@ -218,6 +237,174 @@ def _call_openai(self, prompt: str, model: str, max_tokens: int, tag: str) -> st logger.warning("All OpenAI retries exhausted (tag=%s): %s", tag, last_exc) return "" + # ------------------------------------------------------------------ + # Native tool calling (OpenAI/vLLM only) + # ------------------------------------------------------------------ + + def chat_with_tools( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + model: str, + max_tokens: int = DEFAULT_MAX_TOKENS, + max_tool_iters: int = 8, + on_tool_call: Optional[Callable[[Any], str]] = None, + tag: str = "", + ) -> Dict[str, Any]: + """ + Multi-turn chat completion that supports native OpenAI tool calls. + + Returns + ------- + { + "final_text": str, # message.content (or reasoning_content fallback) + "tool_calls": list[dict], # ordered, each {name, args_preview, result_preview, ok} + "iterations": int, # number of round-trips actually used + "stopped_reason": str, # "no_tool_calls" | "max_iters" | "error" + } + + The caller is responsible for providing `on_tool_call(tc) -> str`, + which is invoked for every tool_call returned by the model. The + return value (already a string) is sent back as the tool message. + + Anthropic backend is not supported — this method exists for the + OpenAI/vLLM tool-calling flow only. It raises NotImplementedError + on Anthropic as a defensive guard; controllers must check + `self.backend == "openai"` before calling. + """ + if self.backend != "openai": + raise NotImplementedError("chat_with_tools requires the openai backend") + + if on_tool_call is None: + raise ValueError("chat_with_tools requires an on_tool_call callback") + + recorded_calls: List[Dict[str, Any]] = [] + convo: List[Dict[str, Any]] = list(messages) + iterations = 0 + final_text = "" + stopped_reason = "no_tool_calls" + + for it in range(max_tool_iters + 1): + iterations = it + 1 + iter_tag = f"{tag}_iter{it}" if tag else f"iter{it}" + response = None + last_exc = None + + for attempt in range(3): + try: + response = self._client.chat.completions.create( + model=model, + max_tokens=max_tokens, + messages=convo, + tools=tools, + tool_choice="auto", + ) + break + except Exception as exc: + last_exc = exc + if getattr(exc, "status_code", None) == 400: + logger.warning( + "OpenAI chat_with_tools 400 (tag=%s): %s", iter_tag, exc + ) + self._record(iter_tag, model, json.dumps(convo)[:2000], "", max_tokens, {}) + return { + "final_text": "", + "tool_calls": recorded_calls, + "iterations": iterations, + "stopped_reason": "error", + } + if attempt < 2: + wait = 5 * (2 ** attempt) + logger.warning( + "chat_with_tools failed (attempt %d/3, tag=%s): %s. Retrying in %ds.", + attempt + 1, iter_tag, exc, wait, + ) + time.sleep(wait) + + if response is None: + logger.warning("All chat_with_tools retries exhausted (tag=%s): %s", iter_tag, last_exc) + stopped_reason = "error" + break + + msg = response.choices[0].message + content = _message_text(msg) + tool_calls = getattr(msg, "tool_calls", None) or [] + + u = getattr(response, "usage", None) + details = getattr(u, "completion_tokens_details", None) + usage = { + "input_tokens": getattr(u, "prompt_tokens", 0) or 0, + "output_tokens": getattr(u, "completion_tokens", 0) or 0, + "reasoning_tokens": getattr(details, "reasoning_tokens", 0) or 0 if details else 0, + } + self._record( + iter_tag, + model, + json.dumps(convo)[:2000], + json.dumps({"content": content, "tool_calls_count": len(tool_calls)}), + max_tokens, + usage, + ) + + if not tool_calls: + final_text = content + stopped_reason = "no_tool_calls" + break + + assistant_msg: Dict[str, Any] = { + "role": "assistant", + "content": content, + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + } + convo.append(assistant_msg) + + for tc in tool_calls: + try: + result = on_tool_call(tc) + ok = True + except Exception as exc: + result = json.dumps({"error": f"on_tool_call raised: {exc}"}) + ok = False + args_preview = (tc.function.arguments or "")[:200] + result_preview = (result or "")[:200] + recorded_calls.append( + { + "name": tc.function.name, + "args_preview": args_preview, + "result_preview": result_preview, + "ok": ok, + } + ) + convo.append( + { + "role": "tool", + "tool_call_id": tc.id, + "content": result, + } + ) + + if it >= max_tool_iters - 1: + stopped_reason = "max_iters" + final_text = content + break + + return { + "final_text": final_text, + "tool_calls": recorded_calls, + "iterations": iterations, + "stopped_reason": stopped_reason, + } + # ------------------------------------------------------------------ # Logging # ------------------------------------------------------------------ diff --git a/symbolic_agent/baselines/trove/parse.py b/symbolic_agent/baselines/trove/parse.py index 56a90cba..4a53a733 100644 --- a/symbolic_agent/baselines/trove/parse.py +++ b/symbolic_agent/baselines/trove/parse.py @@ -83,7 +83,7 @@ def _make_executable(code: str) -> str: return stripped -def parse_response(text: str) -> dict: +def parse_response(text: str, task_family: str = "default") -> dict: """ Parse a TroVE-format LLM response. @@ -95,20 +95,17 @@ def parse_response(text: str) -> dict: "functions": list[dict], # parsed tool dicts from the Tools block } - Fallback behaviour - ------------------ - Tasks like PBEBench embed their own format instructions (e.g. "output a - **Program Sequence** block") that can override the TroVE **Solution** - header. When no **Solution** block is found we grab the first ```python``` - block in the response and, if it is a bare list/string literal, wrap it - in print() so it can be executed and its stdout captured as the answer. + task_family + ----------- + "default": if no **Solution** block is found, falls back to the first + ```python``` block anywhere (legacy behaviour). + "pbebench": no fallback. Strict **Solution**-block-only parsing avoids + accidentally promoting CoT scratchpad to the answer. """ solution_code = _extract_code_block(text, "Solution") or "" tools_code = _extract_code_block(text, "Tools") or "" - # Fallback: model followed the task's own format (e.g. **Program Sequence**) - # instead of the TroVE **Solution** header. - if not solution_code: + if not solution_code and task_family != "pbebench": raw = _extract_any_python_block(text) if raw: solution_code = _make_executable(raw) @@ -267,3 +264,40 @@ def count_ast_nodes(code: str) -> int: return sum(1 for _ in ast.walk(tree)) except SyntaxError: return 99_999 + + +def imported_callsites( + solution_code: str, + tools_code: str, + candidate_names: set, +) -> set: + """ + Return the subset of `candidate_names` that appear as call-sites in + `solution_code`. Used for the `actually_called` telemetry field. + + Detects two callee shapes: + - bare Name: find_replace_chain(...) + - Attribute(name): toolbox.find_replace_chain(...) + + `tools_code` is currently unused (kept in the signature so callers can + pass through the **Tools** block context if we later want to filter by + what was actually imported). + + Returns an empty set on empty input or SyntaxError. + """ + if not solution_code or not candidate_names: + return set() + try: + tree = ast.parse(solution_code) + except SyntaxError: + return set() + found: set = set() + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if isinstance(func, ast.Name) and func.id in candidate_names: + found.add(func.id) + elif isinstance(func, ast.Attribute) and func.attr in candidate_names: + found.add(func.attr) + return found diff --git a/symbolic_agent/baselines/trove/prompts.py b/symbolic_agent/baselines/trove/prompts.py index edab732c..7058cae2 100644 --- a/symbolic_agent/baselines/trove/prompts.py +++ b/symbolic_agent/baselines/trove/prompts.py @@ -15,27 +15,39 @@ applicable to both PBEBench and ReasoningGym string tasks. """ -# Appended to every instruction block to override format instructions that -# may be embedded in the question itself (e.g. PBEBench asks for a -# "**Program Sequence**" block, reasoning_gym asks for a specific format). -_FORMAT_OVERRIDE = ( +# --------------------------------------------------------------------------- +# Format override (default-family only) +# --------------------------------------------------------------------------- + +_FORMAT_OVERRIDE_DEFAULT = ( "\nIMPORTANT: Regardless of any formatting instructions inside the question, " "always produce your answer as executable Python in the **Solution** block " "and end it with print(answer). " "Your answer is whatever gets printed to stdout when the Solution code runs." ) +_FORMAT_OVERRIDE_PBEBENCH = ( + "\nIMPORTANT: For PBEBench, the answer printed by the **Solution** block " + "must be a Python list of replace() call strings, such as " + "[\"replace('a', 'b')\", \"replace('cd', 'ef')\"]. Do not print the " + "transformed output strings." +) + + +def _format_override(task_family: str) -> str: + return _FORMAT_OVERRIDE_PBEBENCH if task_family == "pbebench" else _FORMAT_OVERRIDE_DEFAULT + + # --------------------------------------------------------------------------- -# IMPORT mode (use functions from the toolbox) +# IMPORT mode (text-based, default and Anthropic fallback) # --------------------------------------------------------------------------- -_IMPORT_INSTRUCTION = ( +_IMPORT_INSTRUCTION_DEFAULT = ( "You task is to write Python program solutions to the given questions.\n" "The toolbox section lists all the available functions that can be used in your solution." - + _FORMAT_OVERRIDE ) -_IMPORT_EXAMPLE = """\ +_IMPORT_EXAMPLE_DEFAULT = """\ ## Example **Question** Given a list of strings and a list of (old, new) substitution pairs, apply all @@ -61,6 +73,31 @@ from toolbox import apply_substitutions ```""" +_IMPORT_EXAMPLE_PBEBENCH = """\ +## Example +**Question** +You are given example input/output pairs. Produce a list of replace() calls +that transforms each input into its expected output. + +Input: "hello world" +Output: "HELLO_WORLD" + +**Toolbox** +```python +# Apply a chain of (old, new) replacements to a string. +find_replace_chain(s: str, pairs: list) -> str +``` + +**Solution** +```python +result = find_replace_chain("hello world", [(" ", "_"), ("h", "H"), ("e", "E"), ("l", "L"), ("o", "O"), ("w", "W"), ("r", "R"), ("d", "D")]) +print(result) +``` +**Tools** +```python +from toolbox import find_replace_chain +```""" + _IMPORT_TASK_TEMPLATE = """\ ## Task **Question** @@ -73,29 +110,127 @@ """ -def build_import_prompt(question: str, toolbox_str: str) -> str: - """Build the IMPORT-mode prompt for a single task.""" +def build_import_prompt(question: str, toolbox_str: str, task_family: str = "default") -> str: + """Build the text-based IMPORT-mode prompt (used for Anthropic and as fallback).""" + instruction = _IMPORT_INSTRUCTION_DEFAULT + _format_override(task_family) + example = _IMPORT_EXAMPLE_PBEBENCH if task_family == "pbebench" else _IMPORT_EXAMPLE_DEFAULT return ( - _IMPORT_INSTRUCTION + instruction + "\n\n\n" - + _IMPORT_EXAMPLE + + example + "\n\n\n" + _IMPORT_TASK_TEMPLATE.format(question=question, toolbox=toolbox_str) ) # --------------------------------------------------------------------------- -# CREATE mode (create new reusable functions) +# IMPORT-with-tools mode (native OpenAI tool calling; no **Toolbox** block) # --------------------------------------------------------------------------- -_CREATE_INSTRUCTION = ( +_IMPORT_WITH_TOOLS_INSTRUCTION_DEFAULT = ( + "You task is to write Python program solutions to the given questions.\n" + "You have a set of helper functions available as tools. Call any of them " + "when they help you solve the question; otherwise solve directly. After " + "you have computed the answer, output it as executable Python in a " + "**Solution** block and end with print(answer)." +) + +_IMPORT_WITH_TOOLS_INSTRUCTION_PBEBENCH = ( + "You task is to produce a list of replace() calls that transforms each " + "input into its expected output for a Programming-by-Example task.\n" + "You have a set of helper functions available as tools. Call any of them " + "to test ideas or compute intermediate results; the final **Solution** " + "block must print the program sequence as a Python list of replace() call " + "strings, not the transformed outputs." +) + +_IMPORT_WITH_TOOLS_EXAMPLE_DEFAULT = """\ +## Example +**Question** +Apply substitutions [("a","o"),("t","p")] to ["cat","bat"] and return the list. + +(After optionally calling `apply_substitutions` as a tool to confirm, +the assistant produces:) + +**Solution** +```python +strings = ["cat", "bat"] +subs = [("a", "o"), ("t", "p")] +result = apply_substitutions(strings, subs) +print(result) +```""" + +_IMPORT_WITH_TOOLS_EXAMPLE_PBEBENCH = """\ +## Example +**Question** +Produce a sequence of replace() calls that transforms "hello world" into +"HELLO_WORLD". + +(After optionally calling `find_replace_chain` as a tool to verify a +candidate sequence, the assistant produces:) + +**Solution** +```python +programs = ["replace(' ', '_')", "replace('h', 'H')", "replace('e', 'E')", "replace('l', 'L')", "replace('o', 'O')", "replace('w', 'W')", "replace('r', 'R')", "replace('d', 'D')"] +print(programs) +```""" + +_IMPORT_WITH_TOOLS_TASK_TEMPLATE = """\ +## Task +**Question** +{question} + +**Solution** +""" + + +def build_import_with_tools_prompt(question: str, task_family: str = "default") -> str: + """ + Build the IMPORT-with-tools prompt. The toolbox is NOT shown as text — it + is conveyed via the OpenAI tools=[...] parameter on the chat completion call. + """ + if task_family == "pbebench": + instruction = _IMPORT_WITH_TOOLS_INSTRUCTION_PBEBENCH + example = _IMPORT_WITH_TOOLS_EXAMPLE_PBEBENCH + else: + instruction = _IMPORT_WITH_TOOLS_INSTRUCTION_DEFAULT + example = _IMPORT_WITH_TOOLS_EXAMPLE_DEFAULT + return ( + instruction + + "\n\n\n" + + example + + "\n\n\n" + + _IMPORT_WITH_TOOLS_TASK_TEMPLATE.format(question=question) + ) + + +# --------------------------------------------------------------------------- +# CREATE mode +# --------------------------------------------------------------------------- + +_CREATE_INSTRUCTION_DEFAULT = ( "You task is to write Python program solutions to the given questions.\n" "You should also create Python functions that can be used by your solution, " "if you believe the function can be reused to solve other questions." - + _FORMAT_OVERRIDE ) -_CREATE_EXAMPLE = """\ +_CREATE_INSTRUCTION_PBEBENCH = ( + "You task is to write Python program solutions to the given questions.\n" + "In CREATE mode, you must define at least one reusable helper function " + "inside a **Tools** code block. The **Solution** block should use or " + "accompany that helper as appropriate, but the printed answer must remain " + "a Python list of replace() call strings.\n" + "Prefer general helpers that any PBEBench task could reuse (e.g. parsing a " + "replace() call string, applying a candidate program list to inputs, or " + "scoring a program list against input/output pairs). If a helper that " + "already exists in the toolbox would solve this question, reuse it via " + "IMPORT mode instead of defining a near-duplicate here.\n" + "The helper signatures below are examples of useful tool shapes, not " + "definitions to copy. If you create a helper, implement the complete " + "function body in **Tools**." +) + +_CREATE_EXAMPLE_DEFAULT = """\ ## Example **Question** Given a list of strings and a list of (old, new) substitution pairs, apply all @@ -122,6 +257,31 @@ def apply_substitutions(strings, substitutions): return out ```""" +_CREATE_EXAMPLE_PBEBENCH = """\ +## Reusable helper signatures +These are example shapes for reusable PBEBench tools. Do not copy `...` stubs +as real tools; implement complete helpers when you decide to create one. +```python +def apply_programs(s, programs): ... +def score_programs(programs, examples): ... +def search_candidate_programs(examples, max_programs=5): ... +def prune_search_state(partial_programs, examples): ... +def debug_program_failure(programs, examples): ... +def solve_examples(examples, max_programs=5): ... +``` + +## Example +**Question** +Produce a sequence of replace() calls that transforms "hello world" into +"HELLO_WORLD". + +**Solution** +```python +programs = ["replace(' ', '_')", "replace('h', 'H')", "replace('e', 'E')", "replace('l', 'L')", "replace('o', 'O')", "replace('w', 'W')", "replace('r', 'R')", "replace('d', 'D')"] +print(programs) +``` +""" + _CREATE_TASK_TEMPLATE = """\ ## Task **Question** @@ -131,27 +291,33 @@ def apply_substitutions(strings, substitutions): """ -def build_create_prompt(question: str) -> str: +def build_create_prompt(question: str, task_family: str = "default") -> str: """Build the CREATE-mode prompt for a single task.""" + create_instruction = ( + _CREATE_INSTRUCTION_PBEBENCH + if task_family == "pbebench" + else _CREATE_INSTRUCTION_DEFAULT + ) + instruction = create_instruction + _format_override(task_family) + example = _CREATE_EXAMPLE_PBEBENCH if task_family == "pbebench" else _CREATE_EXAMPLE_DEFAULT return ( - _CREATE_INSTRUCTION + instruction + "\n\n\n" - + _CREATE_EXAMPLE + + example + "\n\n\n" + _CREATE_TASK_TEMPLATE.format(question=question) ) # --------------------------------------------------------------------------- -# SKIP mode (inline solution, no new functions) +# SKIP mode # --------------------------------------------------------------------------- -_SKIP_INSTRUCTION = ( +_SKIP_INSTRUCTION_DEFAULT = ( "You task is to write Python program solutions to the given questions." - + _FORMAT_OVERRIDE ) -_SKIP_EXAMPLE = """\ +_SKIP_EXAMPLE_DEFAULT = """\ ## Example **Question** Given the list of strings ["Hello", "World"], convert each to lowercase and @@ -167,6 +333,21 @@ def build_create_prompt(question: str) -> str: ```python ```""" +_SKIP_EXAMPLE_PBEBENCH = """\ +## Example +**Question** +Produce a sequence of replace() calls that transforms "hello world" into +"HELLO_WORLD". + +**Solution** +```python +programs = ["replace(' ', '_')", "replace('h', 'H')", "replace('e', 'E')", "replace('l', 'L')", "replace('o', 'O')", "replace('w', 'W')", "replace('r', 'R')", "replace('d', 'D')"] +print(programs) +``` +**Tools** +```python +```""" + _SKIP_TASK_TEMPLATE = """\ ## Task **Question** @@ -176,12 +357,14 @@ def build_create_prompt(question: str) -> str: """ -def build_skip_prompt(question: str) -> str: +def build_skip_prompt(question: str, task_family: str = "default") -> str: """Build the SKIP-mode prompt for a single task.""" + instruction = _SKIP_INSTRUCTION_DEFAULT + _format_override(task_family) + example = _SKIP_EXAMPLE_PBEBENCH if task_family == "pbebench" else _SKIP_EXAMPLE_DEFAULT return ( - _SKIP_INSTRUCTION + instruction + "\n\n\n" - + _SKIP_EXAMPLE + + example + "\n\n\n" + _SKIP_TASK_TEMPLATE.format(question=question) ) diff --git a/symbolic_agent/baselines/trove/tests/__init__.py b/symbolic_agent/baselines/trove/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/symbolic_agent/baselines/trove/tests/test_controller_selection.py b/symbolic_agent/baselines/trove/tests/test_controller_selection.py new file mode 100644 index 00000000..10d1e2f8 --- /dev/null +++ b/symbolic_agent/baselines/trove/tests/test_controller_selection.py @@ -0,0 +1,86 @@ +"""Unit tests for TroVE candidate selection.""" + +from symbolic_agent.baselines.trove.controller import TroVEController + + +def _reward(output, is_success, entry): + return {"value": 1.0 if is_success else 0.0, "message": ""} + + +def _controller(): + controller = object.__new__(TroVEController) + controller.selection = "reward" + return controller + + +def test_reward_tie_prefers_candidate_that_adds_reusable_functions(): + candidates = [ + { + "solution_code": "programs = [\"replace('a','b')\"]\nprint(programs)", + "exec_output": "[\"replace('a','b')\"]", + "is_success": True, + "functions": [], + }, + { + "solution_code": ( + "programs = infer_programs(['a'], ['b'])\n" + "print(programs)\n" + "def helper_for_ast_size():\n" + " return 1\n" + ), + "exec_output": "[\"replace('a','b')\"]", + "is_success": True, + "functions": [{"name": "infer_programs"}], + }, + ] + + idx, score = _controller()._select_best_by_reward(candidates, _reward, {}) + + assert idx == 1 + assert score == (1.0, "") + + +def test_reward_tie_prefers_candidate_that_called_import_tools(): + candidates = [ + { + "solution_code": "programs = [\"replace('a','b')\"]\nprint(programs)", + "exec_output": "[\"replace('a','b')\"]", + "is_success": True, + "functions": [], + "tool_calls": [], + }, + { + "solution_code": "programs = infer_programs(['a'], ['b'])\nprint(programs)", + "exec_output": "[\"replace('a','b')\"]", + "is_success": True, + "functions": [], + "tool_calls": [{"name": "infer_programs"}], + }, + ] + + idx, score = _controller()._select_best_by_reward(candidates, _reward, {}) + + assert idx == 1 + assert score == (1.0, "") + + +def test_reward_tie_uses_smallest_ast_when_reuse_signal_matches(): + candidates = [ + { + "solution_code": "x = 1\ny = 2\nprograms = [\"replace('a','b')\"]\nprint(programs)", + "exec_output": "[\"replace('a','b')\"]", + "is_success": True, + "functions": [], + }, + { + "solution_code": "programs = [\"replace('a','b')\"]\nprint(programs)", + "exec_output": "[\"replace('a','b')\"]", + "is_success": True, + "functions": [], + }, + ] + + idx, score = _controller()._select_best_by_reward(candidates, _reward, {}) + + assert idx == 1 + assert score == (1.0, "") diff --git a/symbolic_agent/baselines/trove/tests/test_llm_openai_response.py b/symbolic_agent/baselines/trove/tests/test_llm_openai_response.py new file mode 100644 index 00000000..8b193417 --- /dev/null +++ b/symbolic_agent/baselines/trove/tests/test_llm_openai_response.py @@ -0,0 +1,37 @@ +"""Unit tests for TroVELLMClient OpenAI/vLLM response extraction.""" + +from types import SimpleNamespace + +from symbolic_agent.baselines.trove.llm import TroVELLMClient + + +class _FakeCompletions: + def create(self, **kwargs): + msg = SimpleNamespace(content="", reasoning="**Solution**\n```python\nprint('ok')\n```") + usage = SimpleNamespace(prompt_tokens=1, completion_tokens=2, completion_tokens_details=None) + return SimpleNamespace(choices=[SimpleNamespace(message=msg)], usage=usage) + + +class _FakeClient: + def __init__(self): + self.chat = SimpleNamespace(completions=_FakeCompletions()) + + +def _client_with_fake_openai_response(): + client = object.__new__(TroVELLMClient) + client.backend = "openai" + client._client = _FakeClient() + client._task_log = [] + client._task_tokens = {"input": 0, "output": 0, "reasoning": 0} + client._session_tokens = {"input": 0, "output": 0, "reasoning": 0} + client._debug_dir = None + return client + + +def test_openai_call_reads_vllm_reasoning_field_when_content_empty(): + client = _client_with_fake_openai_response() + + raw = client._call_openai("prompt", "openai/gpt-oss-20b", 128, "tag") + + assert "print('ok')" in raw + assert "print('ok')" in client.get_task_log()[0]["response"]["content"] diff --git a/symbolic_agent/baselines/trove/tests/test_parse_callsites.py b/symbolic_agent/baselines/trove/tests/test_parse_callsites.py new file mode 100644 index 00000000..3429061b --- /dev/null +++ b/symbolic_agent/baselines/trove/tests/test_parse_callsites.py @@ -0,0 +1,65 @@ +"""Unit tests for parse.imported_callsites and parse_response(task_family=).""" + +from symbolic_agent.baselines.trove.parse import imported_callsites, parse_response + + +# --------------------------------------------------------------------------- +# imported_callsites +# --------------------------------------------------------------------------- + +def test_callsites_bare_name(): + code = "result = find_replace_chain(s, [('a', 'b')])\nprint(result)" + assert imported_callsites(code, tools_code="", candidate_names={"find_replace_chain", "other"}) == {"find_replace_chain"} + + +def test_callsites_attribute_access(): + code = "result = toolbox.find_replace_chain(s, pairs)\nprint(result)" + assert imported_callsites(code, tools_code="", candidate_names={"find_replace_chain"}) == {"find_replace_chain"} + + +def test_callsites_no_match(): + code = "print(s.replace('a', 'b'))" + assert imported_callsites(code, tools_code="", candidate_names={"find_replace_chain"}) == set() + + +def test_callsites_multiple_calls_same_name_dedup(): + code = "x = f(1)\ny = f(2)\nprint(x, y)" + assert imported_callsites(code, tools_code="", candidate_names={"f", "g"}) == {"f"} + + +def test_callsites_syntax_error_returns_empty(): + code = "this is not valid python ::" + assert imported_callsites(code, tools_code="", candidate_names={"f"}) == set() + + +def test_callsites_empty_inputs(): + assert imported_callsites("", "", set()) == set() + assert imported_callsites("print(1)", "", set()) == set() + + +# --------------------------------------------------------------------------- +# parse_response(task_family=) +# --------------------------------------------------------------------------- + +def test_parse_response_pbebench_strict_no_solution_block(): + text = "Here is some reasoning.\n```python\nprint('answer')\n```\n" + out = parse_response(text, task_family="pbebench") + assert out["solution_code"] == "" + + +def test_parse_response_pbebench_with_solution_block(): + text = "**Solution**\n```python\nprint('answer')\n```\n" + out = parse_response(text, task_family="pbebench") + assert out["solution_code"] == "print('answer')" + + +def test_parse_response_default_falls_back_to_any_python_block(): + text = "Here is some reasoning.\n```python\nprint('answer')\n```\n" + out = parse_response(text, task_family="default") + assert "print('answer')" in out["solution_code"] + + +def test_parse_response_default_call_signature_unchanged(): + text = "**Solution**\n```python\nprint('answer')\n```\n" + out = parse_response(text) + assert out["solution_code"] == "print('answer')" diff --git a/symbolic_agent/baselines/trove/tests/test_prompts_pbebench.py b/symbolic_agent/baselines/trove/tests/test_prompts_pbebench.py new file mode 100644 index 00000000..9d0685ad --- /dev/null +++ b/symbolic_agent/baselines/trove/tests/test_prompts_pbebench.py @@ -0,0 +1,55 @@ +"""Regression tests for PBEBench-shaped TroVE prompts.""" + +from symbolic_agent.baselines.trove.prompts import ( + build_create_prompt, + build_import_with_tools_prompt, + build_skip_prompt, +) + + +def _assert_pbebench_prompt_prints_program_sequence(prompt: str) -> None: + assert "print(programs)" in prompt + assert "\"replace(' ', '_')\"" in prompt + assert "\"replace('h', 'H')\"" in prompt + assert "print(result)" not in prompt + assert "print(s)" not in prompt + + +def test_pbebench_create_prompt_models_replace_program_list_stdout(): + prompt = build_create_prompt("Task", task_family="pbebench") + + _assert_pbebench_prompt_prints_program_sequence(prompt) + assert "must define at least one reusable helper function" in prompt + assert "**Tools**" in prompt + + +def test_pbebench_create_prompt_suggests_pbebench_helper_signatures(): + prompt = build_create_prompt("Task", task_family="pbebench") + + assert "Reusable helper signatures" in prompt + assert "def apply_programs(s, programs): ..." in prompt + assert "def score_programs(programs, examples): ..." in prompt + assert "def search_candidate_programs(examples, max_programs=5): ..." in prompt + assert "def debug_program_failure(programs, examples): ..." in prompt + assert "def find_replace_chain" not in prompt + assert "import ast" not in prompt + assert "ast.parse" not in prompt + assert "return correct / len(examples)" not in prompt + + +def test_pbebench_create_prompt_warns_against_duplicating_existing_tools(): + prompt = build_create_prompt("Task", task_family="pbebench") + + assert "already exists" in prompt or "duplicate" in prompt.lower() + + +def test_pbebench_skip_prompt_models_replace_program_list_stdout(): + prompt = build_skip_prompt("Task", task_family="pbebench") + + _assert_pbebench_prompt_prints_program_sequence(prompt) + + +def test_pbebench_import_with_tools_prompt_models_replace_program_list_stdout(): + prompt = build_import_with_tools_prompt("Task", task_family="pbebench") + + _assert_pbebench_prompt_prints_program_sequence(prompt) diff --git a/symbolic_agent/baselines/trove/tests/test_tools_api.py b/symbolic_agent/baselines/trove/tests/test_tools_api.py new file mode 100644 index 00000000..8fc9d671 --- /dev/null +++ b/symbolic_agent/baselines/trove/tests/test_tools_api.py @@ -0,0 +1,163 @@ +"""Unit tests for tools_api.toolbox_to_openai_tools and dispatch_tool_call.""" + +import json +from types import SimpleNamespace + +from symbolic_agent.baselines.trove.toolbox import TroVEToolbox +from symbolic_agent.baselines.trove.tools_api import ( + dispatch_tool_call, + toolbox_to_openai_tools, +) + + +def _make_toolbox_with(func_src: str, name: str, docstr: str = "") -> TroVEToolbox: + tb = TroVEToolbox() + tb.add( + { + "name": name, + "docstr": docstr, + "signature": f"def {name}(...)", + "function": func_src, + "type": "function", + }, + example_idx=0, + ) + return tb + + +def _tool_call(name: str, args: dict, call_id: str = "call_1"): + return SimpleNamespace( + id=call_id, + function=SimpleNamespace(name=name, arguments=json.dumps(args)), + ) + + +# --------------------------------------------------------------------------- +# toolbox_to_openai_tools +# --------------------------------------------------------------------------- + +def test_schema_basic_function(): + src = ( + "def find_replace_chain(s: str, pairs: list) -> str:\n" + ' """Apply a chain of (old, new) replacements to a string."""\n' + " for old, new in pairs:\n" + " s = s.replace(old, new)\n" + " return s\n" + ) + tb = _make_toolbox_with(src, "find_replace_chain", docstr="Apply a chain of (old, new) replacements to a string.") + tools = toolbox_to_openai_tools(tb, topk=10) + assert len(tools) == 1 + fn = tools[0] + assert fn["type"] == "function" + assert fn["function"]["name"] == "find_replace_chain" + assert fn["function"]["description"] == "Apply a chain of (old, new) replacements to a string." + params = fn["function"]["parameters"] + assert params["type"] == "object" + assert set(params["properties"].keys()) == {"s", "pairs"} + assert params["properties"]["s"]["type"] == "string" + assert params["properties"]["pairs"]["type"] == "array" + assert set(params["required"]) == {"s", "pairs"} + + +def test_schema_unannotated_falls_back_to_string(): + src = ( + "def f(x):\n" + " return x\n" + ) + tb = _make_toolbox_with(src, "f") + tools = toolbox_to_openai_tools(tb, topk=10) + assert tools[0]["function"]["parameters"]["properties"]["x"]["type"] == "string" + + +def test_schema_skips_varargs_kwargs(): + src = ( + "def f(*args, **kwargs):\n" + " return args\n" + ) + tb = _make_toolbox_with(src, "f") + tools = toolbox_to_openai_tools(tb, topk=10) + assert tools == [] + + +def test_schema_required_excludes_defaults(): + src = ( + "def f(x: int, y: int = 5):\n" + " return x + y\n" + ) + tb = _make_toolbox_with(src, "f") + tools = toolbox_to_openai_tools(tb, topk=10) + params = tools[0]["function"]["parameters"] + assert params["required"] == ["x"] + assert params["properties"]["y"]["type"] == "integer" + + +def test_schema_topk_respects_frequency(): + tb = TroVEToolbox() + for n, freq in [("a", 3), ("b", 2), ("c", 1)]: + tb.add( + { + "name": n, + "docstr": "", + "signature": f"def {n}()", + "function": f"def {n}():\n return 0\n", + "type": "function", + }, + example_idx=0, + ) + for _ in range(freq - 1): + tb.update_frequency(n, example_idx=0) + tools = toolbox_to_openai_tools(tb, topk=2) + assert [t["function"]["name"] for t in tools] == ["a", "b"] + + +def test_schema_empty_toolbox(): + assert toolbox_to_openai_tools(TroVEToolbox(), topk=10) == [] + + +# --------------------------------------------------------------------------- +# dispatch_tool_call +# --------------------------------------------------------------------------- + +def test_dispatch_runs_function_and_returns_stdout(): + src = ( + "def reverse_str(s):\n" + " return s[::-1]\n" + ) + tb = _make_toolbox_with(src, "reverse_str") + result = dispatch_tool_call(tb, _tool_call("reverse_str", {"s": "hello"})) + assert "olleh" in result + + +def test_dispatch_unknown_tool_returns_error(): + tb = TroVEToolbox() + result = dispatch_tool_call(tb, _tool_call("nonexistent", {})) + assert "not in toolbox" in result + + +def test_dispatch_bad_json_returns_error(): + src = "def f(x):\n return x\n" + tb = _make_toolbox_with(src, "f") + bad = SimpleNamespace( + id="x", + function=SimpleNamespace(name="f", arguments="{not json"), + ) + result = dispatch_tool_call(tb, bad) + assert "argument JSON parse failed" in result + + +def test_dispatch_sanitizes_harmony_contamination(): + src = "def reverse_str(s):\n return s[::-1]\n" + tb = _make_toolbox_with(src, "reverse_str") + tc = _tool_call("reverse_str<|channel|>commentary", {"s": "abc"}) + result = dispatch_tool_call(tb, tc) + assert "cba" in result + + +def test_dispatch_truncates_long_output(): + src = ( + "def long_output(n):\n" + " return 'x' * n\n" + ) + tb = _make_toolbox_with(src, "long_output") + result = dispatch_tool_call(tb, _tool_call("long_output", {"n": 10000})) + assert len(result) <= 4096 + 100 # +slack for repr quotes and truncation marker diff --git a/symbolic_agent/baselines/trove/toolbox.py b/symbolic_agent/baselines/trove/toolbox.py index 9cae9532..617b66ae 100644 --- a/symbolic_agent/baselines/trove/toolbox.py +++ b/symbolic_agent/baselines/trove/toolbox.py @@ -114,7 +114,7 @@ def get_full_code(self) -> str: # Trimming # ------------------------------------------------------------------ - def trim(self, n_processed: int, C: float = 0.5) -> set: + def trim(self, n_processed: int, C: float = 1.0) -> set: """ Remove functions whose frequency is below the threshold C * log_{20}(n_processed) @@ -122,7 +122,7 @@ def trim(self, n_processed: int, C: float = 0.5) -> set: Faithful to trim_library() in run_trove.py: threshold = math.log(n, 20) # log base 20 - C defaults to 0.5, matching the paper (§3.3): λ = ½ · log_{10}(n). + C defaults to 1.0, matching the original implementation (C·log_{20}(n)). Note: the original uses log base-20 not base-10; we keep base-20. """ if n_processed <= 1: diff --git a/symbolic_agent/baselines/trove/tools_api.py b/symbolic_agent/baselines/trove/tools_api.py new file mode 100644 index 00000000..c0edc151 --- /dev/null +++ b/symbolic_agent/baselines/trove/tools_api.py @@ -0,0 +1,170 @@ +"""Translate the TroVE toolbox into OpenAI Chat Completions tool schemas +and dispatch tool calls back through the executor. + +This module is the bridge between TroVE's in-memory toolbox and vLLM's +native tool-calling protocol. It is invoked only from the IMPORT-with-tools +controller branch. +""" + +from __future__ import annotations + +import inspect +import json +import logging +from typing import Any + +from .executor import run_solution +from .toolbox import TroVEToolbox + +logger = logging.getLogger(__name__) + +_MAX_RESULT_CHARS = 4096 + +# Type inference: Python annotation -> JSON Schema type. +_TYPE_MAP = { + int: "integer", + float: "number", + bool: "boolean", + str: "string", + list: "array", + tuple: "array", + dict: "object", +} + + +def _infer_type(annotation: Any) -> str: + if annotation is inspect.Parameter.empty: + return "string" + # Plain types (int, str, etc.) + if annotation in _TYPE_MAP: + return _TYPE_MAP[annotation] + # typing.List, typing.Dict, etc. — fall through to string if unrecognised. + origin = getattr(annotation, "__origin__", None) + if origin in _TYPE_MAP: + return _TYPE_MAP[origin] + return "string" + + +def _function_to_schema(name: str, fn: Any, docstr: str) -> dict | None: + """ + Build one OpenAI tool dict from a callable. Returns None if the function + has *args or **kwargs (we cannot generate a meaningful schema). + """ + try: + sig = inspect.signature(fn) + except (TypeError, ValueError) as exc: + logger.debug("Could not introspect %s: %s", name, exc) + return None + + properties: dict = {} + required: list = [] + + for pname, param in sig.parameters.items(): + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + logger.debug("Skipping %s — has *args/**kwargs", name) + return None + prop: dict = {"type": _infer_type(param.annotation)} + if param.default is not inspect.Parameter.empty: + if isinstance(param.default, (int, float, bool, str)): + prop["default"] = param.default + else: + required.append(pname) + properties[pname] = prop + + return { + "type": "function", + "function": { + "name": name, + "description": docstr or "", + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + + +def toolbox_to_openai_tools(toolbox: TroVEToolbox, topk: int = 10) -> list: + """ + Convert the top-k toolbox functions (by frequency) into OpenAI Chat + Completions tool dicts. + + Functions with *args / **kwargs are silently excluded. + Returns [] when the toolbox is empty. + """ + entries = toolbox.snapshot() + if not entries: + return [] + entries.sort(key=lambda e: -int(e.get("frequency", 0))) + selected = entries[:topk] + + namespace: dict = {} + try: + # compile(..., dont_inherit=True) so this module's `from __future__ import + # annotations` is not applied to the toolbox source; we need real types in + # `__annotations__` for inspect.signature() / _infer_type. + _code = compile( + toolbox.get_full_code(), "", "exec", dont_inherit=True + ) + exec(_code, namespace) + except Exception as exc: + logger.warning("Could not exec toolbox source for schema generation: %s", exc) + return [] + + tools: list = [] + for entry in selected: + name = entry.get("name", "") + if not name or name not in namespace: + continue + fn = namespace[name] + schema = _function_to_schema(name, fn, entry.get("docstr", "")) + if schema is not None: + tools.append(schema) + return tools + + +def _sanitize_name(name: str) -> str: + """Defensive workaround for vLLM PR #35906 (Harmony control tokens + leaking into tool names like `reverse_str<|channel|>commentary`).""" + return name.split("<|", 1)[0].strip() + + +def _truncate(s: str, limit: int = _MAX_RESULT_CHARS) -> str: + if len(s) <= limit: + return s + return s[:limit] + f"\n... [truncated {len(s) - limit} chars]" + + +def dispatch_tool_call(toolbox: TroVEToolbox, tool_call) -> str: + """ + Resolve `tool_call` against the toolbox, run it via the sandbox executor, + and return the captured stdout (truncated to 4096 chars) or an error + message string. Always returns a string — never raises. + """ + name = _sanitize_name(getattr(tool_call.function, "name", "") or "") + if not name: + return json.dumps({"error": "tool_call has no function name"}) + if name not in {e["name"] for e in toolbox.snapshot()}: + return json.dumps({"error": f"tool '{name}' not in toolbox"}) + + raw_args = getattr(tool_call.function, "arguments", "") or "{}" + try: + args = json.loads(raw_args) + if not isinstance(args, dict): + return json.dumps({"error": f"argument JSON parse failed: expected object, got {type(args).__name__}"}) + except json.JSONDecodeError as exc: + return json.dumps({"error": f"argument JSON parse failed: {exc}"}) + + call_expr = f"print(repr({name}(**{args!r})))" + is_ok, output = run_solution( + solution_code=call_expr, + tools_code="", + toolbox_code=toolbox.get_full_code(), + ) + if not is_ok: + return json.dumps({"error": "execution failed", "stdout": _truncate(output)}) + return _truncate(output)