From 1ba11048071d863909495b3e93d90b0bdaff6b5e Mon Sep 17 00:00:00 2001 From: "joseph.marinier" Date: Tue, 21 Apr 2026 18:33:06 -0400 Subject: [PATCH 1/3] List docker commands explicitly to clarify container ports --- README.md | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index e6c1bb8..d0f118f 100644 --- a/README.md +++ b/README.md @@ -121,24 +121,16 @@ unzip gym_dbs.zip Each domain requires a running MCP server. Pull and start the Docker image for each domain: ```bash -docker pull shivakrishnareddyma225/enterpriseops-gym-mcp-:latest -docker run -d -p : shivakrishnareddyma225/enterpriseops-gym-mcp-:latest +docker run -d -p 8001:8005 shivakrishnareddyma225/enterpriseops-gym-mcp-csm:latest +docker run -d -p 8002:8005 shivakrishnareddyma225/enterpriseops-gym-mcp-teams:latest +docker run -d -p 8003:8003 shivakrishnareddyma225/enterpriseops-gym-mcp-calendar:latest +docker run -d -p 8004:8005 shivakrishnareddyma225/enterpriseops-gym-mcp-email:latest +docker run -d -p 8006:8005 shivakrishnareddyma225/enterpriseops-gym-mcp-itsm:latest +docker run -d -p 8008:8005 shivakrishnareddyma225/enterpriseops-gym-mcp-hr:latest +docker run -d -p 8009:8005 shivakrishnareddyma225/enterpriseops-gym-mcp-drive:latest ``` -Default ports: - -| Domain | MCP Server | Port | -|--------|-----------|------| -| `teams` | `gym-teams-mcp` | 8002 | -| `csm` | `sn-csm-server` | 8001 | -| `email` | `gym-email-mcp` | 8004 | -| `itsm` | `gym-itsm-mcp` | 8006 | -| `calendar` | `gym-calendar` | 8003 | -| `hr` | `sn-hr-internal` | 8008 | -| `drive` | `gym-google-drive-mcp` | 8009 | -| `` | N/A | 8005 | - -Update `conf/ray/domain_conf.json` if you use non-default ports. For `calendar` use 8003 as the container_port. +Update `conf/ray/domain_conf.json` if you use non-default host ports. For `calendar` use 8003 as the container port, and 8005 for the other domains. ### 2. LLM Config @@ -319,6 +311,7 @@ We release 60% of the benchmark samples in the public split. For completeness, w | Qwen3-30B (Think) | 21.3 | 5.0 | 53.7 | 8.7 | 18.0 | 8.8 | 26.6 | 11.4 | 17.0 | | Qwen3-235B (Inst.) | 29.5 | 4.0 | 41.8 | 10.7 | 23.0 | 14.7 | 31.2 | 19.3 | 19.6 | | Qwen3-4B (Think) | 23.0 | 3.0 | 37.3 | 5.8 | 4.9 | 7.8 | 23.4 | 15.9 | 13.6 | + --- ## 📚 Citation From 9944d7e7365513b2aa9c7957a1c42d01ad0afaeb Mon Sep 17 00:00:00 2001 From: "joseph.marinier" Date: Wed, 22 Apr 2026 11:28:42 -0400 Subject: [PATCH 2/3] Integrate a Prime Intellect Verifiers environment - Add enterpriseops_gym_env.py with load_environment() entry point that wraps the benchmark as a Verifiers ToolEnv - Delegates verification to the existing VerifierEngine (all 3 types: database_state, tool_execution, response_check) - Add prime-intellect optional dependency group for verifiers - Make package installable (replace package=false with build-system) - Add unit tests (36 tests, no Docker/API keys required) --- .prime/.env-metadata.json | 7 + README.md | 57 +++ enterpriseops_gym_env.py | 457 ++++++++++++++++++++++++ pyproject.toml | 19 +- tests/test_enterpriseops_gym_env.py | 535 ++++++++++++++++++++++++++++ 5 files changed, 1070 insertions(+), 5 deletions(-) create mode 100644 .prime/.env-metadata.json create mode 100644 enterpriseops_gym_env.py create mode 100644 tests/test_enterpriseops_gym_env.py diff --git a/.prime/.env-metadata.json b/.prime/.env-metadata.json new file mode 100644 index 0000000..e94da12 --- /dev/null +++ b/.prime/.env-metadata.json @@ -0,0 +1,7 @@ +{ + "environment_id": "qmn9n710aw681nbvu45m0p9f", + "owner": "joseph-marinier", + "name": "enterpriseops-gym-env", + "pushed_at": "2026-04-22T12:40:37.767323", + "wheel_sha256": "192f69f5254f2e181f34e29242df4ea228914225fdc4de8a5147819ce7390455" +} diff --git a/README.md b/README.md index d0f118f..7b782b2 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ Unlike static datasets, tasks run against live MCP servers and are evaluated by - [🔧 Prerequisites](#-prerequisites) - [🚀 Running the Benchmark](#-running-the-benchmark) - [📊 Scoring](#-scoring) +- [🌐 Prime Intellect Environment](#-prime-intellect-environment) - [🏆 Leaderboard](#-leaderboard) - [📚 Citation](#-citation) @@ -266,6 +267,62 @@ Output: --- +## 🌐 Prime Intellect Environment + +EnterpriseOps-Gym is published on [Prime Intellect's Environment Hub](https://app.primeintellect.ai/dashboard/environments) as a [Verifiers](https://github.com/PrimeIntellect-ai/verifiers) environment. Install it from the hub and evaluate locally. + +### Install from the Environment Hub + +```bash +prime env install joseph-marinier/enterpriseops-gym-env +``` + +Or install locally from the repo: + +```bash +uv sync --extra prime-intellect +``` + +### Usage + +```python +import verifiers as vf + +# Via Verifiers discovery (after prime env install): +env = vf.load_environment("enterpriseops-gym-env", gym_dbs_path="./gym_dbs", domains=["teams"]) + +# Or import directly: +from enterpriseops_gym_env import load_environment +env = load_environment(gym_dbs_path="./gym_dbs", mode="oracle", domains=["teams"]) + +# Evaluate +client = vf.ClientConfig( + client_type="openai_chat_completions", + api_key_var="OPENAI_API_KEY", + api_base_url="https://api.openai.com/v1", +) +results = env.evaluate_sync(client=client, model="gpt-4.1") +``` + +### Configuration + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `server_urls` | localhost standard ports | MCP server name → URL mapping | +| `gym_dbs_path` | `"gym_dbs"` | Path to extracted SQL seed files | +| `hf_dataset` | `ServiceNow-AI/EnterpriseOps-Gym` | HuggingFace dataset | +| `mode` | `"oracle"` | Tool-set mode | +| `domains` | All 8 domains | Which domains to include | +| `max_turns` | `50` | Max agent turns per task | +| `llm_client` | `None` | `LLMClient` instance for `response_check` verifiers | + +### Limitations + +- **Local evaluation only** — MCP servers run as Docker containers that must be started before evaluation. Prime Intellect's hosted evaluation (`prime eval run`) is not supported since it cannot access local Docker containers. Use `env.evaluate_sync()` locally instead. +- **Single-worker** — concurrent rollouts are not yet supported. Each task uses a different `selected_tools` subset, applied by mutating shared state on the environment instance. The constructor enforces `max_workers=1`. To lift this, per-task tool definitions would need to flow through the rollout state rather than the shared instance. + +--- + ## 🏆 Leaderboard Task success rate (%) on Oracle mode on the full benchmark. A task passes only if **all** verification conditions are met. diff --git a/enterpriseops_gym_env.py b/enterpriseops_gym_env.py new file mode 100644 index 0000000..1162317 --- /dev/null +++ b/enterpriseops_gym_env.py @@ -0,0 +1,457 @@ +"""EnterpriseOps-Gym adapter for the Prime Intellect Verifiers framework. + +Wraps ServiceNow's EnterpriseOps-Gym benchmark (1,150 enterprise agent tasks across +8 domains, verified via SQL state checks) as a ``vf.ToolEnv`` so it can be hosted on +Prime Intellect's Environment Hub. +""" + +import asyncio +import concurrent.futures +import json +import logging +import os +from typing import Any, cast + +from datasets import Dataset, load_dataset + +import verifiers as vf +from verifiers.types import Messages, State, Tool, ToolMessage + +from benchmark.llm_client import LLMClient +from benchmark.mcp_client import MCPClient, create_database_from_file, delete_database +from benchmark.models import VerifierConfig +from benchmark.verifier import VerifierEngine + +logger = logging.getLogger(__name__) + +# MCP server names (as referenced in the HuggingFace dataset's gym_servers_config) +# mapped to the default localhost ports from EnterpriseOps-Gym's Docker setup. +# IMPORTANT: Most containers listen internally on port 8005 (calendar on 8003). +# Use e.g. `docker run -p 8002:8005 ...teams:latest` — see README for full commands. +DEFAULT_SERVER_URLS: dict[str, str] = { + "sn-csm-server": "http://localhost:8001", + "gym-teams-mcp": "http://localhost:8002", + "gym-calendar": "http://localhost:8003", + "gym-email-mcp": "http://localhost:8004", + "gym-itsm-mcp": "http://localhost:8006", + "sn-hr-internal": "http://localhost:8008", + "gym-google-drive-mcp": "http://localhost:8009", +} + +ALL_DOMAINS = ["teams", "csm", "calendar", "email", "itsm", "hr", "drive", "hybrid"] + + +# -- Helpers ------------------------------------------------------------------ + + +def _run_sync(coro: Any) -> Any: + """Run an async coroutine from a synchronous context. + + Uses ``asyncio.run()`` when no event loop is running, otherwise falls back + to executing in a thread (e.g. inside Jupyter). + """ + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + + +def _noop(**_kw: Any) -> None: # noqa: D401 + """Placeholder callable for ``tool_map``; never invoked because ``call_tool`` is overridden.""" + + +def _parse_info(raw: str | dict[str, Any]) -> dict[str, Any]: + """Deserialize the info field, which is stored as a JSON string in the dataset.""" + if isinstance(raw, str): + return json.loads(raw) + return raw + + +def _resolve_sql_path(seed_database_file: str, gym_dbs_path: str) -> str | None: + """Try several candidate locations for a SQL seed file.""" + candidates = [ + seed_database_file, + os.path.join(gym_dbs_path, seed_database_file), + os.path.join(gym_dbs_path, os.path.basename(seed_database_file)), + ] + for path in candidates: + if os.path.isfile(path): + return path + return None + + +def _mcp_content_to_str(result: Any) -> str: + """Flatten an MCP tool result into a plain string for the conversation.""" + if isinstance(result, str): + return result + if isinstance(result, dict): + content = result.get("content", []) + if isinstance(content, list): + parts = [ + item["text"] if isinstance(item, dict) and item.get("type") == "text" else str(item) + for item in content + ] + if parts: + return "\n".join(parts) + return str(content) if content else json.dumps(result) + return str(result) + + +# -- Environment -------------------------------------------------------------- + + +class EOpsGymEnv(vf.ToolEnv): + """EnterpriseOps-Gym as a Verifiers ``ToolEnv``. + + Connects to pre-running MCP Docker servers at init, discovers all tools, + then per rollout: seeds fresh databases, scopes the tool set to the task's + ``selected_tools``, runs the agent loop, scores via the benchmark's + ``VerifierEngine``, and cleans up databases. + + Note: + Concurrent rollouts are not supported. Each task has a different ``selected_tools`` + subset, and the current implementation applies this by mutating ``self.tool_defs`` + on the shared environment instance. With ``max_workers > 1``, concurrent rollouts + would see each other's tool sets. This is enforced by clamping ``max_workers=1`` + in the constructor. See the README for ideas on lifting this limitation. + """ + + def __init__( + self, + server_urls: dict[str, str], + gym_dbs_path: str, + max_turns: int = 50, + llm_client: LLMClient | None = None, + **kwargs: Any, + ): + if kwargs.get("max_workers", 1) != 1: + logger.warning( + "EOpsGymEnv does not support concurrent rollouts (per-task tool_defs mutation). " + "Forcing max_workers=1." + ) + kwargs["max_workers"] = 1 + + self.server_urls = server_urls + self.gym_dbs_path = gym_dbs_path + self.llm_client = llm_client + self.clients: dict[str, MCPClient] = {} + self.tool_to_server: dict[str, str] = {} + self._all_tool_defs: list[Tool] = [] + + self._connect_and_discover() + + super().__init__(tools=[], max_turns=max_turns, **kwargs) + self.tool_defs = list(self._all_tool_defs) + self.tool_map = {t.name: _noop for t in self._all_tool_defs} + + # -- Init helpers --------------------------------------------------------- + + def _connect_and_discover(self) -> None: + """Connect to every configured MCP server and merge their tool catalogues.""" + for name, url in self.server_urls.items(): + client = MCPClient(base_url=url) + if not _run_sync(client.connect()): + logger.warning("Could not connect to %s at %s — skipping", name, url) + continue + self.clients[name] = client + + for schema in _run_sync(client.list_tools()): + tool_name = schema["name"] + if tool_name in self.tool_to_server: + logger.debug("Duplicate tool '%s' — keeping first occurrence", tool_name) + continue + self.tool_to_server[tool_name] = name + self._all_tool_defs.append( + Tool( + name=tool_name, + description=schema.get("description", ""), + parameters=schema.get("inputSchema", {"type": "object", "properties": {}}), + ) + ) + + logger.info( + "EOpsGymEnv: discovered %d tools across %d servers", len(self._all_tool_defs), len(self.clients) + ) + + # -- Per-rollout lifecycle ------------------------------------------------ + + async def setup_state(self, state: State) -> State: + """Seed databases and scope tools for the current task.""" + state = await super().setup_state(state) + + # Retrieve per-task metadata from the dataset row + info = _parse_info(state["input"]["info"]) + + gym_configs: list[dict] = info.get("gym_servers_config", []) + selected: list[str] = info.get("selected_tools", []) + + # Seed a database for each MCP server this task uses + db_ids: dict[str, str] = {} + url_map: dict[str, str] = {} + + for cfg in gym_configs: + srv = cfg["mcp_server_name"] + url = self.server_urls.get(srv, cfg.get("mcp_server_url", "")) + url_map[srv] = url + + seed_file = cfg.get("seed_database_file", "") + if not seed_file: + continue + + sql_path = _resolve_sql_path(seed_file, self.gym_dbs_path) + if not sql_path: + logger.warning("SQL seed file not found for %s: %s", srv, seed_file) + continue + + db_id = create_database_from_file(url, sql_path) + if not db_id: + logger.warning("Failed to create database for %s", srv) + continue + db_ids[srv] = db_id + if srv in self.clients: + self.clients[srv].database_id = db_id + + state["database_ids"] = db_ids + state["server_url_map"] = url_map + + # Restrict visible tools to this task's selected set + if selected: + self.tool_defs = [t for t in self._all_tool_defs if t.name in selected] + else: + self.tool_defs = list(self._all_tool_defs) + self.tool_map = {t.name: _noop for t in self.tool_defs} + + return state + + async def call_tool( + self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs: Any + ) -> ToolMessage: + """Route a tool call to the owning MCP server via HTTP JSON-RPC.""" + server_name = self.tool_to_server.get(tool_name) + if not server_name: + return cast( + ToolMessage, + {"role": "tool", "content": f"Error: unknown tool '{tool_name}'", "tool_call_id": tool_call_id}, + ) + + client = self.clients.get(server_name) + if not client: + return cast( + ToolMessage, + {"role": "tool", "content": f"Error: server '{server_name}' not connected", "tool_call_id": tool_call_id}, + ) + + result = await client.call_tool(tool_name, tool_args) + if result.get("success"): + content = _mcp_content_to_str(result.get("result", "")) + else: + content = f"Error: {result.get('error', 'unknown')}" + + return cast(ToolMessage, {"role": "tool", "content": content, "tool_call_id": tool_call_id}) + + # NOTE: Database cleanup is intentionally NOT done via @vf.cleanup because that hook + # runs inside rollout(), BEFORE the rubric scores the state. The rubric needs the + # databases alive to run SQL verifiers. Instead, cleanup is done via rubric.cleanup() + # which runs AFTER scoring. See _build_rubric(). + + @vf.teardown + async def _teardown(self) -> None: + """Release MCP clients on environment shutdown.""" + self.clients.clear() + + +# -- Dataset builder ---------------------------------------------------------- + + +def _build_dataset(hf_repo: str, mode: str, domains: list[str]) -> Dataset: + """Load the HuggingFace dataset and reshape into Verifiers format. + + Each row becomes: + prompt: [{role: system, ...}, {role: user, ...}] + answer: "" (verification is SQL-based, not text-based) + info: {task_id, domain, selected_tools, verifiers, gym_servers_config, ...} + """ + json_fields = {"gym_servers_config", "verifiers"} + rows: list[dict[str, Any]] = [] + + for domain in domains: + logger.info("Loading HF dataset: %s config=%s split=%s", hf_repo, mode, domain) + ds = load_dataset(hf_repo, mode, split=domain) + for raw in ds: + info: dict[str, Any] = {} + for k, v in raw.items(): + if k in ("system_prompt", "user_prompt"): + continue + if k in json_fields and isinstance(v, str): + info[k] = json.loads(v) + else: + info[k] = v + + rows.append({ + "prompt": [ + {"role": "system", "content": raw["system_prompt"]}, + {"role": "user", "content": raw["user_prompt"]}, + ], + "answer": "", + "info": json.dumps(info), # serialized to avoid Arrow type-inference conflicts + }) + + logger.info("Dataset built: %d tasks across %s", len(rows), domains) + return Dataset.from_dict({ + "prompt": [r["prompt"] for r in rows], + "answer": [r["answer"] for r in rows], + "info": [r["info"] for r in rows], + }) + + +# -- Rubric builder ----------------------------------------------------------- + + +def _collect_tool_calls(state: State) -> list[str]: + """Extract the list of tool names called during the rollout from the trajectory.""" + tool_names: list[str] = [] + for step in state["trajectory"]: + for msg in step["completion"]: + if msg.role != "assistant": + continue + for tc in msg.tool_calls or []: + if tc.name: + tool_names.append(tc.name) + return tool_names + + +def _build_rubric(server_urls: dict[str, str], llm_client: LLMClient | None = None) -> vf.Rubric: + """Create a rubric that scores rollouts using the benchmark's ``VerifierEngine``. + + Delegates all verification logic (database_state, tool_execution, response_check) + to ``benchmark.verifier.VerifierEngine``, avoiding duplicated verification code. + + Args: + server_urls: MCP server name to base URL mapping (for constructing per-verifier clients). + llm_client: Optional ``benchmark.llm_client.LLMClient`` for ``response_check`` verifiers. + """ + + async def verification(completion: Messages, answer: str, state: State, info: str) -> float: + """Fraction of verifiers that pass, delegated to VerifierEngine.""" + info = _parse_info(info) + verifier_configs: list[dict] = info.get("verifiers", []) + if not verifier_configs: + return 0.0 + + db_ids: dict[str, str] = state.get("database_ids", {}) + url_map: dict[str, str] = state.get("server_url_map", {}) + + # Build MCP clients dict for VerifierEngine (keyed by gym_name) + mcp_clients: dict[str, MCPClient] = {} + for gym_name in {v.get("gym_name") for v in verifier_configs if v.get("gym_name")}: + base_url = url_map.get(gym_name, server_urls.get(gym_name, "")) + db_id = db_ids.get(gym_name, "") + if base_url: + mcp_clients[gym_name] = MCPClient(base_url=base_url, database_id=db_id) + + # Skip response_check verifiers if no llm_client is configured + runnable = [] + for v_cfg in verifier_configs: + if v_cfg.get("verifier_type") == "response_check" and llm_client is None: + logger.warning( + "Skipping response_check verifier (no llm_client configured). " + "Pass llm_client to load_environment() to enable." + ) + continue + runnable.append(v_cfg) + + if not runnable: + return 0.0 + + engine = VerifierEngine(mcp_clients, llm_client) + + # Build model_response for tool_execution and response_check verifiers + tools_called = _collect_tool_calls(state) + model_response = { + "content": completion and completion[-1].content or "", + "tool_calls": [{"name": t, "args": {}} for t in tools_called], + } + + passed = 0 + total = len(runnable) + + for v_cfg in runnable: + verifier = VerifierConfig(**v_cfg) + db_id = db_ids.get(verifier.gym_name, "") + try: + result = await engine.execute_verifier(verifier, model_response, db_id, gym_name=verifier.gym_name) + if result.get("passed"): + passed += 1 + except Exception: + logger.debug("Verifier failed", exc_info=True) + + return passed / total + + async def all_pass(completion: Messages, answer: str, state: State, info: str) -> float: + """Binary metric: 1.0 only if every verifier passes.""" + info = _parse_info(info) + if not info.get("verifiers"): + return 0.0 + score = await verification(completion, answer, state, info) + return 1.0 if score == 1.0 else 0.0 + + async def cleanup_databases(state: State) -> None: + """Delete databases created for this rollout. Runs after scoring.""" + url_map: dict[str, str] = state.get("server_url_map", {}) + for srv, db_id in state.get("database_ids", {}).items(): + url = url_map.get(srv, server_urls.get(srv, "")) + if url: + delete_database(url, db_id) + + rubric = vf.Rubric(funcs=[verification], weights=[1.0]) + rubric.add_metric(all_pass) + rubric._cleanup_handlers.append(cleanup_databases) + return rubric + + +# -- Entry point -------------------------------------------------------------- + + +def load_environment( + server_urls: dict[str, str] | None = None, + gym_dbs_path: str = "gym_dbs", + hf_dataset: str = "ServiceNow-AI/EnterpriseOps-Gym", + mode: str = "oracle", + domains: list[str] | None = None, + max_turns: int = 50, + llm_client: LLMClient | None = None, + **kwargs: Any, +) -> vf.Environment: + """Load EnterpriseOps-Gym as a Verifiers environment. + + Args: + server_urls: MCP server name to base URL mapping. + Defaults to localhost with the standard ports from the EnterpriseOps-Gym Docker setup. + gym_dbs_path: Path to directory with extracted SQL seed files (from ``gym_dbs.zip``). + hf_dataset: HuggingFace dataset repo ID. + mode: Tool-set mode (``oracle``, ``plus_5_tools``, ``plus_10_tools``, ``plus_15_tools``). + domains: Which domains to include. Defaults to all 8 (7 single-domain + hybrid). + max_turns: Maximum agent turns per task. + llm_client: Optional ``benchmark.llm_client.LLMClient`` instance for ``response_check`` + verifiers. If not provided, ``response_check`` verifiers will fail gracefully. + + Returns: + A configured ``EOpsGymEnv`` ready for evaluation. + """ + urls = server_urls or dict(DEFAULT_SERVER_URLS) + doms = domains or list(ALL_DOMAINS) + + dataset = _build_dataset(hf_dataset, mode, doms) + rubric = _build_rubric(urls, llm_client) + + return EOpsGymEnv( + server_urls=urls, + gym_dbs_path=gym_dbs_path, + max_turns=max_turns, + llm_client=llm_client, + dataset=dataset, + rubric=rubric, + **kwargs, + ) diff --git a/pyproject.toml b/pyproject.toml index 8dc1908..9704362 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] -name = "EnterpriseOps-Gym" -version = "0.1.0" +name = "enterpriseops-gym-env" +version = "0.1.1" description = "Benchmarking framework for evaluating LLMs on enterprise agentic tasks." requires-python = ">=3.11" dependencies = [ @@ -42,7 +42,10 @@ qwq = [ "langchain-qwq>=0.1.0", ] all = [ - "EnterpriseOps-Gym[anthropic,openai,google,deepseek,qwq]", + "enterpriseops-gym-env[anthropic,openai,google,deepseek,qwq]", +] +prime-intellect = [ + "verifiers>=0.1.12", ] [dependency-groups] @@ -50,7 +53,13 @@ dev = [ "ipykernel>=6.0.0", "ipython>=8.0.0", "matplotlib>=3.7.0", + "pytest>=7.0.0", ] -[tool.uv] -package = false +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["benchmark", "orchestrators", "utils"] +force-include = {"enterpriseops_gym_env.py" = "enterpriseops_gym_env.py"} diff --git a/tests/test_enterpriseops_gym_env.py b/tests/test_enterpriseops_gym_env.py new file mode 100644 index 0000000..9dad782 --- /dev/null +++ b/tests/test_enterpriseops_gym_env.py @@ -0,0 +1,535 @@ +"""Tests for the Prime Intellect Verifiers environment adapter. + +These tests run without Docker or API keys by mocking the MCP and verification layers. +""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from verifiers.types import AssistantMessage, State, ToolCall, ToolMessage as VFToolMessage, TrajectoryStep + +from enterpriseops_gym_env import ( + ALL_DOMAINS, + DEFAULT_SERVER_URLS, + EOpsGymEnv, + _build_dataset, + _build_rubric, + _collect_tool_calls, + _mcp_content_to_str, + _noop, + _parse_info, + _resolve_sql_path, +) + + +# -- Helpers ------------------------------------------------------------------ + + +class TestParseInfo: + def test_json_string(self): + result = _parse_info('{"key": "value"}') + assert result == {"key": "value"} + + def test_dict_passthrough(self): + d = {"key": "value"} + assert _parse_info(d) is d + + def test_nested_json(self): + data = {"verifiers": [{"type": "database_state"}], "selected_tools": ["a", "b"]} + result = _parse_info(json.dumps(data)) + assert result == data + + +class TestResolveSqlPath: + def test_absolute_path(self, tmp_path): + sql = tmp_path / "test.sql" + sql.write_text("CREATE TABLE t;") + assert _resolve_sql_path(str(sql), "/nonexistent") == str(sql) + + def test_relative_to_gym_dbs(self, tmp_path): + sql = tmp_path / "domain" / "db.sql" + sql.parent.mkdir() + sql.write_text("CREATE TABLE t;") + assert _resolve_sql_path("domain/db.sql", str(tmp_path)) == str(sql) + + def test_basename_fallback(self, tmp_path): + sql = tmp_path / "db.sql" + sql.write_text("CREATE TABLE t;") + assert _resolve_sql_path("some/nested/path/db.sql", str(tmp_path)) == str(sql) + + def test_not_found(self): + assert _resolve_sql_path("nonexistent.sql", "/nonexistent") is None + + +class TestMcpContentToStr: + def test_string_passthrough(self): + assert _mcp_content_to_str("hello") == "hello" + + def test_text_content_list(self): + result = {"content": [{"type": "text", "text": "line1"}, {"type": "text", "text": "line2"}]} + assert _mcp_content_to_str(result) == "line1\nline2" + + def test_empty_dict(self): + assert _mcp_content_to_str({}) == "{}" + + def test_non_text_content(self): + result = {"content": [{"type": "image", "url": "http://example.com"}]} + assert "image" in _mcp_content_to_str(result) + + def test_other_type(self): + assert _mcp_content_to_str(42) == "42" + assert _mcp_content_to_str(None) == "None" + + +# -- _collect_tool_calls ------------------------------------------------------ + + +class TestCollectToolCalls: + def _make_state(self, steps: list) -> State: + return State(trajectory=steps, database_ids={}, server_url_map={}) + + def test_extracts_tool_names(self): + step = TrajectoryStep( + completion=[ + AssistantMessage(role="assistant", content=None, tool_calls=[ + ToolCall(id="tc1", name="list_users", arguments="{}"), + ToolCall(id="tc2", name="create_chat", arguments="{}"), + ]), + VFToolMessage(role="tool", content="result1", tool_call_id="tc1"), + VFToolMessage(role="tool", content="result2", tool_call_id="tc2"), + ], + prompt=[], response=None, tokens=None, + reward=None, advantage=None, is_truncated=False, trajectory_id="t1", extras={}, + ) + assert _collect_tool_calls(self._make_state([step])) == ["list_users", "create_chat"] + + def test_skips_non_assistant_messages(self): + step = TrajectoryStep( + completion=[VFToolMessage(role="tool", content="result", tool_call_id="tc1")], + prompt=[], response=None, tokens=None, + reward=None, advantage=None, is_truncated=False, trajectory_id="t1", extras={}, + ) + assert _collect_tool_calls(self._make_state([step])) == [] + + def test_handles_none_tool_calls(self): + step = TrajectoryStep( + completion=[AssistantMessage(role="assistant", content="done", tool_calls=None)], + prompt=[], response=None, tokens=None, + reward=None, advantage=None, is_truncated=False, trajectory_id="t1", extras={}, + ) + assert _collect_tool_calls(self._make_state([step])) == [] + + def test_multiple_steps(self): + make_step = lambda names: TrajectoryStep( + completion=[AssistantMessage( + role="assistant", content=None, + tool_calls=[ToolCall(id=f"tc_{n}", name=n, arguments="{}") for n in names], + )], + prompt=[], response=None, tokens=None, + reward=None, advantage=None, is_truncated=False, trajectory_id="t1", extras={}, + ) + state = self._make_state([make_step(["a", "b"]), make_step(["c"])]) + assert _collect_tool_calls(state) == ["a", "b", "c"] + + def test_empty_trajectory(self): + assert _collect_tool_calls(self._make_state([])) == [] + + +# -- EOpsGymEnv --------------------------------------------------------------- + + +class TestEOpsGymEnvInit: + @patch("enterpriseops_gym_env.MCPClient") + def test_max_workers_forced_to_one(self, mock_mcp_cls): + mock_mcp_cls.return_value = AsyncMock() + mock_mcp_cls.return_value.connect.return_value = False + import verifiers as vf + env = EOpsGymEnv( + server_urls={}, + gym_dbs_path="gym_dbs", + max_workers=4, + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + assert env.max_workers == 1 + + @patch("enterpriseops_gym_env.MCPClient") + def test_connects_and_discovers_tools(self, mock_mcp_cls): + mock_client = AsyncMock() + mock_client.connect.return_value = True + mock_client.list_tools.return_value = [ + {"name": "tool_a", "description": "A", "inputSchema": {"type": "object", "properties": {}}}, + {"name": "tool_b", "description": "B"}, + ] + mock_mcp_cls.return_value = mock_client + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"server1": "http://localhost:9999"}, + gym_dbs_path="gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + assert len(env._all_tool_defs) == 2 + assert env._all_tool_defs[0].name == "tool_a" + assert env._all_tool_defs[1].name == "tool_b" + assert env.tool_to_server == {"tool_a": "server1", "tool_b": "server1"} + assert "server1" in env.clients + + @patch("enterpriseops_gym_env.MCPClient") + def test_skips_unreachable_servers(self, mock_mcp_cls): + mock_client = AsyncMock() + mock_client.connect.return_value = False + mock_mcp_cls.return_value = mock_client + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"dead_server": "http://localhost:9999"}, + gym_dbs_path="gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + assert len(env.clients) == 0 + assert len(env._all_tool_defs) == 0 + + @patch("enterpriseops_gym_env.MCPClient") + def test_deduplicates_tools_across_servers(self, mock_mcp_cls): + tool_schema = [{"name": "shared_tool", "description": "shared"}] + + def make_client(): + c = AsyncMock() + c.connect.return_value = True + c.list_tools.return_value = tool_schema + return c + + mock_mcp_cls.side_effect = [make_client(), make_client()] + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"server1": "http://localhost:8001", "server2": "http://localhost:8002"}, + gym_dbs_path="gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + assert len(env._all_tool_defs) == 1 + assert env.tool_to_server["shared_tool"] == "server1" + + +class TestCallTool: + @patch("enterpriseops_gym_env.MCPClient") + def test_routes_to_correct_server(self, mock_mcp_cls): + mock_client = AsyncMock() + mock_client.connect.return_value = True + mock_client.list_tools.return_value = [{"name": "my_tool", "description": ""}] + mock_client.call_tool = AsyncMock(return_value={ + "success": True, + "result": {"content": [{"type": "text", "text": "tool result"}]}, + }) + mock_mcp_cls.return_value = mock_client + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"srv": "http://localhost:9999"}, + gym_dbs_path="gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + result = asyncio.run(env.call_tool("my_tool", {"arg": "val"}, "call-1")) + assert result["content"] == "tool result" + assert result["tool_call_id"] == "call-1" + mock_client.call_tool.assert_called_once_with("my_tool", {"arg": "val"}) + + @patch("enterpriseops_gym_env.MCPClient") + def test_unknown_tool(self, mock_mcp_cls): + mock_mcp_cls.return_value = AsyncMock() + mock_mcp_cls.return_value.connect.return_value = False + + import verifiers as vf + env = EOpsGymEnv( + server_urls={}, + gym_dbs_path="gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + result = asyncio.run(env.call_tool("nonexistent", {}, "call-1")) + assert "Error" in result["content"] + assert "nonexistent" in result["content"] + + @patch("enterpriseops_gym_env.MCPClient") + def test_tool_call_failure(self, mock_mcp_cls): + mock_client = AsyncMock() + mock_client.connect.return_value = True + mock_client.list_tools.return_value = [{"name": "failing_tool", "description": ""}] + mock_client.call_tool = AsyncMock(return_value={"success": False, "error": "server error"}) + mock_mcp_cls.return_value = mock_client + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"srv": "http://localhost:9999"}, + gym_dbs_path="gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + result = asyncio.run(env.call_tool("failing_tool", {}, "call-1")) + assert "server error" in result["content"] + + +# -- Rubric / Verification --------------------------------------------------- + + +class TestBuildRubric: + def test_response_check_skipped_without_llm_client(self): + rubric = _build_rubric(DEFAULT_SERVER_URLS, llm_client=None) + verify_fn = rubric.funcs[0] + + info = json.dumps({ + "verifiers": [{"verifier_type": "response_check", "validation_config": {}}], + }) + state = State(database_ids={}, server_url_map={}, trajectory=[]) + + score = asyncio.run(verify_fn(completion=[], answer="", state=state, info=info)) + assert score == 0.0 + + def test_tool_execution_verifier(self): + rubric = _build_rubric(DEFAULT_SERVER_URLS, llm_client=None) + verify_fn = rubric.funcs[0] + + info = json.dumps({ + "verifiers": [{ + "verifier_type": "tool_execution", + "validation_config": {"selected_tools": ["list_users"], "minimum_tool_calls": 1}, + }], + }) + + step = TrajectoryStep( + completion=[AssistantMessage( + role="assistant", content=None, + tool_calls=[ToolCall(id="tc1", name="list_users", arguments="{}")], + )], + prompt=[], response=None, tokens=None, + reward=None, advantage=None, is_truncated=False, trajectory_id="t1", extras={}, + ) + state = State(database_ids={}, server_url_map={}, trajectory=[step]) + + with patch("enterpriseops_gym_env.VerifierEngine") as mock_engine_cls: + mock_engine = MagicMock() + mock_engine.execute_verifier = AsyncMock(return_value={"passed": True}) + mock_engine_cls.return_value = mock_engine + + score = asyncio.run(verify_fn(completion=[], answer="", state=state, info=info)) + + assert score == 1.0 + + def test_tool_execution_verifier_fails(self): + rubric = _build_rubric(DEFAULT_SERVER_URLS, llm_client=None) + verify_fn = rubric.funcs[0] + + info = json.dumps({ + "verifiers": [{ + "verifier_type": "tool_execution", + "validation_config": {"selected_tools": ["list_users"], "minimum_tool_calls": 1}, + }], + }) + state = State(database_ids={}, server_url_map={}, trajectory=[]) + + with patch("enterpriseops_gym_env.VerifierEngine") as mock_engine_cls: + mock_engine = MagicMock() + mock_engine.execute_verifier = AsyncMock(return_value={"passed": False}) + mock_engine_cls.return_value = mock_engine + + score = asyncio.run(verify_fn(completion=[], answer="", state=state, info=info)) + + assert score == 0.0 + + def test_empty_verifiers(self): + rubric = _build_rubric(DEFAULT_SERVER_URLS) + verify_fn = rubric.funcs[0] + + info = json.dumps({"verifiers": []}) + state = State(database_ids={}, server_url_map={}, trajectory=[]) + + score = asyncio.run(verify_fn(completion=[], answer="", state=state, info=info)) + assert score == 0.0 + + def test_all_pass_metric(self): + rubric = _build_rubric(DEFAULT_SERVER_URLS) + all_pass_fn = rubric.funcs[1] # add_metric appends to funcs with weight=0 + + info = json.dumps({ + "verifiers": [{ + "verifier_type": "tool_execution", + "validation_config": {"selected_tools": ["a"], "minimum_tool_calls": 1}, + }], + }) + + step = TrajectoryStep( + completion=[AssistantMessage( + role="assistant", content=None, + tool_calls=[ToolCall(id="tc1", name="a", arguments="{}")], + )], + prompt=[], response=None, tokens=None, + reward=None, advantage=None, is_truncated=False, trajectory_id="t1", extras={}, + ) + state = State(database_ids={}, server_url_map={}, trajectory=[step]) + + with patch("enterpriseops_gym_env.VerifierEngine") as mock_engine_cls: + mock_engine = MagicMock() + mock_engine.execute_verifier = AsyncMock(return_value={"passed": True}) + mock_engine_cls.return_value = mock_engine + + score = asyncio.run(all_pass_fn(completion=[], answer="", state=state, info=info)) + + assert score == 1.0 + + def test_cleanup_registered_on_rubric_not_env(self): + """Database cleanup must run AFTER scoring (via rubric.cleanup), not during + rollout (via @vf.cleanup), because verifiers need the database alive.""" + rubric = _build_rubric({"srv": "http://localhost:8002"}) + assert len(rubric._cleanup_handlers) == 1, "Expected 1 cleanup handler on rubric" + + # Verify EOpsGymEnv has NO @vf.cleanup methods + import inspect + import verifiers as vf + cleanup_methods = [ + name for name, method in inspect.getmembers(EOpsGymEnv, predicate=inspect.isfunction) + if hasattr(method, "__vf_cleanup__") or (hasattr(method, "__wrapped__") and "cleanup" in name) + ] + assert cleanup_methods == [], f"EOpsGymEnv should not have @vf.cleanup methods, found: {cleanup_methods}" + + @patch("enterpriseops_gym_env.delete_database") + def test_cleanup_deletes_databases(self, mock_delete): + rubric = _build_rubric({"srv": "http://localhost:8002"}) + cleanup_fn = rubric._cleanup_handlers[0] + + state = State( + database_ids={"srv": "db_123"}, + server_url_map={"srv": "http://localhost:8002"}, + trajectory=[], + ) + asyncio.run(cleanup_fn(state)) + mock_delete.assert_called_once_with("http://localhost:8002", "db_123") + + def test_mixed_verifiers_partial_pass(self): + rubric = _build_rubric(DEFAULT_SERVER_URLS) + verify_fn = rubric.funcs[0] + + info = json.dumps({ + "verifiers": [ + {"verifier_type": "database_state", "validation_config": {"query": "SELECT 1"}, "gym_name": "srv"}, + {"verifier_type": "database_state", "validation_config": {"query": "SELECT 2"}, "gym_name": "srv"}, + ], + }) + state = State(database_ids={"srv": "db1"}, server_url_map={"srv": "http://localhost:8002"}, trajectory=[]) + + with patch("enterpriseops_gym_env.VerifierEngine") as mock_engine_cls: + mock_engine = MagicMock() + mock_engine.execute_verifier = AsyncMock(side_effect=[ + {"passed": True}, {"passed": False}, + ]) + mock_engine_cls.return_value = mock_engine + + score = asyncio.run(verify_fn(completion=[], answer="", state=state, info=info)) + + assert score == 0.5 + + +# -- Setup state -------------------------------------------------------------- + + +class TestSetupState: + @patch("enterpriseops_gym_env.create_database_from_file") + @patch("enterpriseops_gym_env.MCPClient") + def test_seeds_database_and_filters_tools(self, mock_mcp_cls, mock_create_db): + mock_client = AsyncMock() + mock_client.connect.return_value = True + mock_client.list_tools.return_value = [ + {"name": "tool_a", "description": "A"}, + {"name": "tool_b", "description": "B"}, + {"name": "tool_c", "description": "C"}, + ] + mock_mcp_cls.return_value = mock_client + mock_create_db.return_value = "db_123" + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"srv": "http://localhost:8002"}, + gym_dbs_path="/fake/gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + assert len(env._all_tool_defs) == 3 + + info = json.dumps({ + "gym_servers_config": [{"mcp_server_name": "srv", "mcp_server_url": "http://localhost:8002"}], + "selected_tools": ["tool_a", "tool_c"], + }) + state = State(input={"info": info, "prompt": [], "answer": ""}) + + async def run(): + return await env.setup_state(state) + + result = asyncio.run(run()) + + # No seed file, so no DB created + assert result["database_ids"] == {} + assert result["server_url_map"] == {"srv": "http://localhost:8002"} + # Tools filtered to selected + assert len(env.tool_defs) == 2 + assert {t.name for t in env.tool_defs} == {"tool_a", "tool_c"} + + @patch("enterpriseops_gym_env._resolve_sql_path", return_value="/fake/db.sql") + @patch("enterpriseops_gym_env.create_database_from_file") + @patch("enterpriseops_gym_env.MCPClient") + def test_handles_create_db_failure(self, mock_mcp_cls, mock_create_db, mock_resolve): + mock_client = AsyncMock() + mock_client.connect.return_value = True + mock_client.list_tools.return_value = [] + mock_mcp_cls.return_value = mock_client + mock_create_db.return_value = None # failure + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"srv": "http://localhost:8002"}, + gym_dbs_path="/fake/gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + info = json.dumps({ + "gym_servers_config": [{ + "mcp_server_name": "srv", + "mcp_server_url": "http://localhost:8002", + "seed_database_file": "db.sql", + }], + "selected_tools": [], + }) + state = State(input={"info": info, "prompt": [], "answer": ""}) + + result = asyncio.run(env.setup_state(state)) + + assert result["database_ids"] == {} # None was not stored + + +# -- Constants ---------------------------------------------------------------- + + +class TestConstants: + def test_all_domains_includes_hybrid(self): + assert "hybrid" in ALL_DOMAINS + assert len(ALL_DOMAINS) == 8 + + def test_default_server_urls_match_hf_data(self): + expected_servers = { + "sn-csm-server", "gym-teams-mcp", "gym-calendar", + "gym-email-mcp", "gym-itsm-mcp", "sn-hr-internal", "gym-google-drive-mcp", + } + assert set(DEFAULT_SERVER_URLS.keys()) == expected_servers From 6a6b2a3478968a75e4f3785c18c44959792a96fe Mon Sep 17 00:00:00 2001 From: "joseph.marinier" Date: Wed, 22 Apr 2026 15:45:19 -0400 Subject: [PATCH 3/3] Fix tool filtering and remove single-worker limitation Set `state["tool_defs"]` instead of `self.tool_defs` because `get_model_response()` reads from `state`. Since we don't mutate shared state anymore, concurrent rollouts are safe. --- README.md | 1 - enterpriseops_gym_env.py | 23 ++---------- tests/test_enterpriseops_gym_env.py | 55 ++++++++++++++++++----------- 3 files changed, 36 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 7b782b2..5a1b223 100644 --- a/README.md +++ b/README.md @@ -319,7 +319,6 @@ results = env.evaluate_sync(client=client, model="gpt-4.1") ### Limitations - **Local evaluation only** — MCP servers run as Docker containers that must be started before evaluation. Prime Intellect's hosted evaluation (`prime eval run`) is not supported since it cannot access local Docker containers. Use `env.evaluate_sync()` locally instead. -- **Single-worker** — concurrent rollouts are not yet supported. Each task uses a different `selected_tools` subset, applied by mutating shared state on the environment instance. The constructor enforces `max_workers=1`. To lift this, per-task tool definitions would need to flow through the rollout state rather than the shared instance. --- diff --git a/enterpriseops_gym_env.py b/enterpriseops_gym_env.py index 1162317..72dbc3a 100644 --- a/enterpriseops_gym_env.py +++ b/enterpriseops_gym_env.py @@ -58,9 +58,6 @@ def _run_sync(coro: Any) -> Any: return pool.submit(asyncio.run, coro).result() -def _noop(**_kw: Any) -> None: # noqa: D401 - """Placeholder callable for ``tool_map``; never invoked because ``call_tool`` is overridden.""" - def _parse_info(raw: str | dict[str, Any]) -> dict[str, Any]: """Deserialize the info field, which is stored as a JSON string in the dataset.""" @@ -109,13 +106,6 @@ class EOpsGymEnv(vf.ToolEnv): then per rollout: seeds fresh databases, scopes the tool set to the task's ``selected_tools``, runs the agent loop, scores via the benchmark's ``VerifierEngine``, and cleans up databases. - - Note: - Concurrent rollouts are not supported. Each task has a different ``selected_tools`` - subset, and the current implementation applies this by mutating ``self.tool_defs`` - on the shared environment instance. With ``max_workers > 1``, concurrent rollouts - would see each other's tool sets. This is enforced by clamping ``max_workers=1`` - in the constructor. See the README for ideas on lifting this limitation. """ def __init__( @@ -126,13 +116,6 @@ def __init__( llm_client: LLMClient | None = None, **kwargs: Any, ): - if kwargs.get("max_workers", 1) != 1: - logger.warning( - "EOpsGymEnv does not support concurrent rollouts (per-task tool_defs mutation). " - "Forcing max_workers=1." - ) - kwargs["max_workers"] = 1 - self.server_urls = server_urls self.gym_dbs_path = gym_dbs_path self.llm_client = llm_client @@ -144,7 +127,6 @@ def __init__( super().__init__(tools=[], max_turns=max_turns, **kwargs) self.tool_defs = list(self._all_tool_defs) - self.tool_map = {t.name: _noop for t in self._all_tool_defs} # -- Init helpers --------------------------------------------------------- @@ -218,10 +200,9 @@ async def setup_state(self, state: State) -> State: # Restrict visible tools to this task's selected set if selected: - self.tool_defs = [t for t in self._all_tool_defs if t.name in selected] + state["tool_defs"] = [t for t in self._all_tool_defs if t.name in selected] else: - self.tool_defs = list(self._all_tool_defs) - self.tool_map = {t.name: _noop for t in self.tool_defs} + state["tool_defs"] = list(self._all_tool_defs) return state diff --git a/tests/test_enterpriseops_gym_env.py b/tests/test_enterpriseops_gym_env.py index 9dad782..1caf92f 100644 --- a/tests/test_enterpriseops_gym_env.py +++ b/tests/test_enterpriseops_gym_env.py @@ -7,18 +7,15 @@ import json from unittest.mock import AsyncMock, MagicMock, patch -import pytest from verifiers.types import AssistantMessage, State, ToolCall, ToolMessage as VFToolMessage, TrajectoryStep from enterpriseops_gym_env import ( ALL_DOMAINS, DEFAULT_SERVER_URLS, EOpsGymEnv, - _build_dataset, _build_rubric, _collect_tool_calls, _mcp_content_to_str, - _noop, _parse_info, _resolve_sql_path, ) @@ -141,20 +138,6 @@ def test_empty_trajectory(self): class TestEOpsGymEnvInit: - @patch("enterpriseops_gym_env.MCPClient") - def test_max_workers_forced_to_one(self, mock_mcp_cls): - mock_mcp_cls.return_value = AsyncMock() - mock_mcp_cls.return_value.connect.return_value = False - import verifiers as vf - env = EOpsGymEnv( - server_urls={}, - gym_dbs_path="gym_dbs", - max_workers=4, - dataset=vf.load_example_dataset("gsm8k"), - rubric=vf.Rubric(funcs=[]), - ) - assert env.max_workers == 1 - @patch("enterpriseops_gym_env.MCPClient") def test_connects_and_discovers_tools(self, mock_mcp_cls): mock_client = AsyncMock() @@ -472,7 +455,7 @@ def test_seeds_database_and_filters_tools(self, mock_mcp_cls, mock_create_db): "gym_servers_config": [{"mcp_server_name": "srv", "mcp_server_url": "http://localhost:8002"}], "selected_tools": ["tool_a", "tool_c"], }) - state = State(input={"info": info, "prompt": [], "answer": ""}) + state = State(input={"info": info, "prompt": [], "answer": ""}, tool_defs=list(env._all_tool_defs)) async def run(): return await env.setup_state(state) @@ -482,9 +465,39 @@ async def run(): # No seed file, so no DB created assert result["database_ids"] == {} assert result["server_url_map"] == {"srv": "http://localhost:8002"} - # Tools filtered to selected - assert len(env.tool_defs) == 2 - assert {t.name for t in env.tool_defs} == {"tool_a", "tool_c"} + # Tools filtered to selected (set on state, not self) + assert len(result["tool_defs"]) == 2 + assert {t.name for t in result["tool_defs"]} == {"tool_a", "tool_c"} + # self._all_tool_defs must not be mutated (concurrent safety) + assert len(env._all_tool_defs) == 3 + + @patch("enterpriseops_gym_env.create_database_from_file") + @patch("enterpriseops_gym_env.MCPClient") + def test_empty_selected_tools_gives_all(self, mock_mcp_cls, mock_create_db): + mock_client = AsyncMock() + mock_client.connect.return_value = True + mock_client.list_tools.return_value = [ + {"name": "tool_a", "description": "A"}, + {"name": "tool_b", "description": "B"}, + ] + mock_mcp_cls.return_value = mock_client + + import verifiers as vf + env = EOpsGymEnv( + server_urls={"srv": "http://localhost:8002"}, + gym_dbs_path="/fake/gym_dbs", + dataset=vf.load_example_dataset("gsm8k"), + rubric=vf.Rubric(funcs=[]), + ) + + info = json.dumps({ + "gym_servers_config": [{"mcp_server_name": "srv", "mcp_server_url": "http://localhost:8002"}], + "selected_tools": [], + }) + state = State(input={"info": info, "prompt": [], "answer": ""}, tool_defs=list(env._all_tool_defs)) + + result = asyncio.run(env.setup_state(state)) + assert len(result["tool_defs"]) == 2 @patch("enterpriseops_gym_env._resolve_sql_path", return_value="/fake/db.sql") @patch("enterpriseops_gym_env.create_database_from_file")