From 95c10338410d8bbf7209b9e8357ac59ecc5bd748 Mon Sep 17 00:00:00 2001 From: James Braza Date: Wed, 15 Apr 2026 12:11:44 -0700 Subject: [PATCH 1/3] Dropping extra TEnvironment in dataset_server.py --- src/aviary/dataset_server.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/aviary/dataset_server.py b/src/aviary/dataset_server.py index 67f4f4f4..9b26d585 100644 --- a/src/aviary/dataset_server.py +++ b/src/aviary/dataset_server.py @@ -6,11 +6,11 @@ import uuid from contextlib import contextmanager from itertools import starmap -from typing import Generic, TypeVar +from typing import Generic from pydantic import BaseModel, Field -from aviary.env import Environment, TaskDataset +from aviary.env import TaskDataset, TEnvironment from aviary.message import Message from aviary.tools import ( MessagesAdapter, @@ -60,10 +60,6 @@ class FlushRequest(BaseModel): BIND_ALL_HOST = "0.0.0.0" # noqa: S104 -# Not sure why, but mypy complains if we use the TEnvironment in aviary.env, so redefine here -TEnvironment = TypeVar("TEnvironment", bound=Environment) - - class TaskDatasetServer(Generic[TEnvironment]): def __init__( self, From bbb0875eabcf73daca4359e4ab6e7483d5643e8b Mon Sep 17 00:00:00 2001 From: James Braza Date: Wed, 15 Apr 2026 12:35:29 -0700 Subject: [PATCH 2/3] Support arbitrary task kwargs and external APIRouter in TaskDatasetServer Add TaskDataset.get_new_env_by_args(**kwargs) so subclasses can build envs from request-driven payloads rather than a fixed index. StartRequest picks it up via a new task_kwargs field (takes precedence over task_idx). TaskDatasetServer now accepts an optional APIRouter: routes are always attached to self.router, and in mounted mode the caller includes the router on their own FastAPI app (start()/astart() raise in that mode). Co-Authored-By: Claude Opus 4.6 --- src/aviary/dataset_server.py | 68 +++++++++++++++------- src/aviary/env.py | 10 ++++ tests/test_envs.py | 110 +++++++++++++++++++++++++++++++++-- 3 files changed, 163 insertions(+), 25 deletions(-) diff --git a/src/aviary/dataset_server.py b/src/aviary/dataset_server.py index 9b26d585..f5a92da0 100644 --- a/src/aviary/dataset_server.py +++ b/src/aviary/dataset_server.py @@ -6,7 +6,7 @@ import uuid from contextlib import contextmanager from itertools import starmap -from typing import Generic +from typing import Any, Generic from pydantic import BaseModel, Field @@ -21,7 +21,7 @@ try: import uvicorn - from fastapi import Depends, FastAPI, HTTPException, Security + from fastapi import APIRouter, Depends, FastAPI, HTTPException, Security from fastapi.security import APIKeyHeader missing_dependencies = False @@ -36,11 +36,25 @@ class StartRequest(BaseModel): task_idx: int | None = Field( default=None, description=( - "Index of the dataset to start. " - "If provided, will call TaskDataset.get_new_env_by_idx(); " - "otherwise, TaskDataset.get_new_env()." + "Optional index of the dataset to start. If provided (and task_kwargs is" + " left as default of None), will call TaskDataset.get_new_env_by_idx();" + " otherwise, TaskDataset.get_new_env()." ), ) + task_kwargs: dict[str, Any] | None = Field( + default=None, + description=( + "Optional keyword arguments passed to TaskDataset.get_new_env_by_args()." + " Takes precedence over task_idx when set." + ), + ) + + def make_env(self, dataset: TaskDataset[TEnvironment]) -> TEnvironment: + if self.task_kwargs is not None: + return dataset.get_new_env_by_args(**self.task_kwargs) + if self.task_idx is not None: + return dataset.get_new_env_by_idx(self.task_idx) + return dataset.get_new_env() class EnvRequest(BaseModel): @@ -67,6 +81,7 @@ def __init__( host: str = BIND_ALL_HOST, port: int = DEFAULT_SERVER_PORT, api_key: str | None = None, + router: "APIRouter | None" = None, ): if missing_dependencies: raise ImportError( @@ -79,13 +94,19 @@ def __init__( self.port = port self.api_key = api_key - self.app = FastAPI() - # env ID -> (env, last used timestamp) self.envs: dict[str, tuple[TEnvironment, float]] = {} self.lock = asyncio.Lock() + + self.router = router if router is not None else APIRouter() self._setup_routes() + if router is None: # Standalone mode: build a default FastAPI app + self.app: FastAPI | None = FastAPI() + self.app.include_router(self.router) + else: # Mounted mode: caller mounts self.router onto their own app + self.app = None + def _get_env(self, env_id: str) -> TEnvironment: try: env, _ = self.envs[env_id] @@ -119,22 +140,17 @@ def verify_api_key(api_key: str | None = Security(api_key_header)): status_code=403, detail="Invalid or missing API key" ) - @self.app.post("/start", dependencies=[Depends(verify_api_key)]) + @self.router.post("/start", dependencies=[Depends(verify_api_key)]) async def start(req: StartRequest): with handle_exc_as_http_exc(): - if req.task_idx is None: - env = await asyncio.to_thread(self.dataset.get_new_env) - else: - env = await asyncio.to_thread( - self.dataset.get_new_env_by_idx, req.task_idx - ) + env = await asyncio.to_thread(req.make_env, self.dataset) async with self.lock: env_id = str(uuid.uuid4()) self.envs[env_id] = (env, time.time()) return {"env_id": env_id} - @self.app.post("/reset", dependencies=[Depends(verify_api_key)]) + @self.router.post("/reset", dependencies=[Depends(verify_api_key)]) async def reset(req: EnvRequest): async with self.lock: env = self._get_env(req.env_id) @@ -148,7 +164,7 @@ async def reset(req: EnvRequest): ToolsAdapter.dump_python(tools, exclude_none=True, by_alias=True), ) - @self.app.post("/step", dependencies=[Depends(verify_api_key)]) + @self.router.post("/step", dependencies=[Depends(verify_api_key)]) async def step(req: StepRequest): async with self.lock: env = self._get_env(req.env_id) @@ -162,7 +178,7 @@ async def step(req: StepRequest): ) return obs_serialized, *reward_done_trunc - @self.app.post("/close", dependencies=[Depends(verify_api_key)]) + @self.router.post("/close", dependencies=[Depends(verify_api_key)]) async def close(req: EnvRequest): async with self.lock: env = self._get_env(req.env_id) @@ -175,7 +191,7 @@ async def close(req: EnvRequest): return {"env_id": req.env_id} - @self.app.post("/close_old_envs", dependencies=[Depends(verify_api_key)]) + @self.router.post("/close_old_envs", dependencies=[Depends(verify_api_key)]) async def close_old_envs(req: FlushRequest): """Endpoint to close environments that have not been used in a while. @@ -208,7 +224,7 @@ async def close(env_id: str, env: TEnvironment) -> str | None: "closed_env_ids": [env_id for env_id in closed if env_id is not None] } - @self.app.get("/info", dependencies=[Depends(verify_api_key)]) + @self.router.get("/info", dependencies=[Depends(verify_api_key)]) def info(): try: dataset_len: int | None = len(self.dataset) @@ -219,11 +235,21 @@ def info(): "running_env_ids": list(self.envs.keys()), } - def start(self): + def start(self) -> None: + if self.app is None: + raise RuntimeError( + f"{type(self).__name__} was constructed with an external router; " + "mount self.router on your own FastAPI app and run uvicorn there." + ) uvicorn.run(self.app, host=self.host, port=self.port, log_level="debug") - async def astart(self): + async def astart(self) -> None: """Async equivalent of start().""" + if self.app is None: + raise RuntimeError( + f"{type(self).__name__} was constructed with an external router; " + "mount self.router on your own FastAPI app and run uvicorn there." + ) config = uvicorn.Config( self.app, host=self.host, port=self.port, log_level="debug" ) diff --git a/src/aviary/env.py b/src/aviary/env.py index 4f2471fa..517c27c3 100644 --- a/src/aviary/env.py +++ b/src/aviary/env.py @@ -437,6 +437,16 @@ def get_new_env(self) -> TEnvironment: f'"{self.__class__.__name__}" does not implement get_new_env' ) + def get_new_env_by_args(self, **kwargs) -> TEnvironment: + """Get an env from arbitrary task kwargs. + + Useful when the caller drives env creation from request payloads + rather than a default configuration or fixed environment index. + """ + raise NotImplementedError( + f'"{self.__class__.__name__}" does not implement get_new_env_by_args' + ) + def iter_batches( self, batch_size: int, shuffle: bool = False ) -> Iterator[list[TEnvironment]]: diff --git a/tests/test_envs.py b/tests/test_envs.py index 5f1bb611..f6833708 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -6,14 +6,14 @@ import time from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from typing import Any, ClassVar +from typing import Any, ClassVar, cast from unittest import mock import litellm import numpy as np import pytest import pytest_asyncio -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from httpx import ASGITransport, AsyncClient from pydantic import BaseModel, ValidationError from pytest_subtests import SubTests @@ -682,7 +682,32 @@ async def _make_test_client(app: FastAPI) -> AsyncIterator[AsyncClient]: @pytest_asyncio.fixture async def server_async_client() -> AsyncIterator[AsyncClient]: server = TaskDatasetServer[DummyEnv](dataset=TaskDataset.from_name("dummy")) - async with _make_test_client(app=server.app) as client: + async with _make_test_client(app=cast(FastAPI, server.app)) as client: + yield client + + +class StubArgsTaskDataset(TaskDataset[DummyEnv]): + """Stub task dataset to exercise get_new_env_by_args.""" + + def get_new_env_by_args(self, *, task: str) -> DummyEnv: # type: ignore[override] + return DummyEnv(task=task) + + +@pytest_asyncio.fixture +async def args_async_client() -> AsyncIterator[AsyncClient]: + server = TaskDatasetServer[DummyEnv](dataset=StubArgsTaskDataset()) + async with _make_test_client(app=cast(FastAPI, server.app)) as client: + yield client + + +@pytest_asyncio.fixture +async def mounted_async_client() -> AsyncIterator[AsyncClient]: + server = TaskDatasetServer[DummyEnv]( + dataset=TaskDataset.from_name("dummy"), router=APIRouter(tags=["env"]) + ) + app = FastAPI() + app.include_router(server.router, prefix="/env") + async with _make_test_client(app=app) as client: yield client @@ -764,7 +789,7 @@ async def slow_close(*_) -> None: # last_used=0 guarantees it's stale for any req.last_used >= 0 server.envs["stale"] = (stale_env, 0.0) - async with _make_test_client(app=server.app) as client: + async with _make_test_client(app=cast(FastAPI, server.app)) as client: with mock.patch.object(DummyEnv, "close", slow_close): # Kick off /close_old_envs; env.close() will await release_event close_task = asyncio.create_task( @@ -844,6 +869,83 @@ async def test_step_with_tool_response_message( assert not done assert not truncated + @pytest.mark.asyncio + async def test_start_raises_when_get_new_env_by_args_not_implemented( + self, server_async_client: AsyncClient + ) -> None: + # Dummy dataset doesn't implement get_new_env_by_args, so sending + # task_kwargs should surface as a 500 via handle_exc_as_http_exc + response = await server_async_client.post( + "/start", json={"task_kwargs": {"task": "anything"}} + ) + assert response.status_code == 500 + assert "get_new_env_by_args" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_start_with_task_kwargs(self, args_async_client: AsyncClient) -> None: + start_resp = await args_async_client.post( + "/start", json={"task_kwargs": {"task": "five-word-story topic"}} + ) + assert start_resp.status_code == 200 + env_id = start_resp.json()["env_id"] + + # Reset and confirm the task made it into the initial observation. + reset_resp = await args_async_client.post("/reset", json={"env_id": env_id}) + assert reset_resp.status_code == 200 + (obs,), tools = reset_resp.json() + assert "five-word-story topic" in obs["content"] + assert tools + + @pytest.mark.asyncio + async def test_task_kwargs_takes_precedence_over_task_idx( + self, args_async_client: AsyncClient + ) -> None: + assert "get_new_env_by_idx" not in StubArgsTaskDataset.__dict__, ( + "Test expects no get_new_env_by_idx for assertions to make sense" + ) + # StubArgsTaskDataset has no get_new_env_by_idx, so if task_idx were used + # this would 500. The fact that it succeeds proves task_kwargs took precedence + start_resp = await args_async_client.post( + "/start", json={"task_idx": 42, "task_kwargs": {"task": "kwargs-won"}} + ) + assert start_resp.status_code == 200 + env_id = start_resp.json()["env_id"] + + # Reset and confirm the task made it into the initial observation. + reset_resp = await args_async_client.post("/reset", json={"env_id": env_id}) + (obs,), tools = reset_resp.json() + assert "kwargs-won" in obs["content"] + assert tools + + @pytest.mark.asyncio + async def test_start_reset_step_through_prefix( + self, mounted_async_client: AsyncClient + ) -> None: + """End-to-end smoke test for the mounted-router code path.""" + start_resp = await mounted_async_client.post("/env/start", json={}) + assert start_resp.status_code == 200, ( + "Mounted router did not expose /env/start — route registration" + " against external APIRouter is broken" + ) + env_id = start_resp.json()["env_id"] + + reset_resp = await mounted_async_client.post( + "/env/reset", json={"env_id": env_id} + ) + assert reset_resp.status_code == 200, ( + "Mounted router cannot retrieve the env it just created" + ) + + action = ToolRequestMessage( + tool_calls=[ + ToolCall.from_name("print_story", story="one two three four five") + ] + ) + step_resp = await mounted_async_client.post( + "/env/step", json={"env_id": env_id, "action": action.model_dump()} + ) + assert step_resp.status_code == 200 + class TestDefaultNoToolCallsResponse: @pytest.mark.parametrize( From 6b69eddecb65af64642c2a93b4ce71b40326e61f Mon Sep 17 00:00:00 2001 From: James Braza Date: Thu, 16 Apr 2026 11:56:52 -0700 Subject: [PATCH 3/3] Reject StartRequest with both task_idx and task_kwargs as 422 Co-Authored-By: Claude Opus 4 --- src/aviary/dataset_server.py | 18 +++++++++++++----- tests/test_envs.py | 18 ++++-------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/aviary/dataset_server.py b/src/aviary/dataset_server.py index f5a92da0..8a961099 100644 --- a/src/aviary/dataset_server.py +++ b/src/aviary/dataset_server.py @@ -8,7 +8,7 @@ from itertools import starmap from typing import Any, Generic -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from aviary.env import TaskDataset, TEnvironment from aviary.message import Message @@ -36,19 +36,27 @@ class StartRequest(BaseModel): task_idx: int | None = Field( default=None, description=( - "Optional index of the dataset to start. If provided (and task_kwargs is" - " left as default of None), will call TaskDataset.get_new_env_by_idx();" - " otherwise, TaskDataset.get_new_env()." + "Optional index of the dataset to start. If provided, will call" + " TaskDataset.get_new_env_by_idx(); otherwise, TaskDataset.get_new_env()." + " Mutually exclusive with task_kwargs." ), ) task_kwargs: dict[str, Any] | None = Field( default=None, description=( "Optional keyword arguments passed to TaskDataset.get_new_env_by_args()." - " Takes precedence over task_idx when set." + " Mutually exclusive with task_idx." ), ) + @model_validator(mode="after") + def _check_mutually_exclusive(self) -> "StartRequest": + if self.task_idx is not None and self.task_kwargs is not None: + raise ValueError( + "task_idx and task_kwargs are mutually exclusive; specify at most one." + ) + return self + def make_env(self, dataset: TaskDataset[TEnvironment]) -> TEnvironment: if self.task_kwargs is not None: return dataset.get_new_env_by_args(**self.task_kwargs) diff --git a/tests/test_envs.py b/tests/test_envs.py index f6833708..4ccc5b42 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -897,25 +897,15 @@ async def test_start_with_task_kwargs(self, args_async_client: AsyncClient) -> N assert tools @pytest.mark.asyncio - async def test_task_kwargs_takes_precedence_over_task_idx( + async def test_start_rejects_both_task_idx_and_task_kwargs( self, args_async_client: AsyncClient ) -> None: - assert "get_new_env_by_idx" not in StubArgsTaskDataset.__dict__, ( - "Test expects no get_new_env_by_idx for assertions to make sense" - ) - # StubArgsTaskDataset has no get_new_env_by_idx, so if task_idx were used - # this would 500. The fact that it succeeds proves task_kwargs took precedence + # Specifying both is ambiguous; server should reject at request validation start_resp = await args_async_client.post( "/start", json={"task_idx": 42, "task_kwargs": {"task": "kwargs-won"}} ) - assert start_resp.status_code == 200 - env_id = start_resp.json()["env_id"] - - # Reset and confirm the task made it into the initial observation. - reset_resp = await args_async_client.post("/reset", json={"env_id": env_id}) - (obs,), tools = reset_resp.json() - assert "kwargs-won" in obs["content"] - assert tools + assert start_resp.status_code == 422 + assert "mutually exclusive" in start_resp.text @pytest.mark.asyncio async def test_start_reset_step_through_prefix(