Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 57 additions & 27 deletions src/aviary/dataset_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import uuid
from contextlib import contextmanager
from itertools import starmap
from typing import Generic, TypeVar
from typing import Any, Generic

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from aviary.env import Environment, TaskDataset
from aviary.env import TaskDataset, TEnvironment
from aviary.message import Message
from aviary.tools import (
MessagesAdapter,
Expand All @@ -21,7 +21,7 @@

try:
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader

missing_dependencies = False
Expand All @@ -36,11 +36,33 @@ class StartRequest(BaseModel):
task_idx: int | None = Field(
default=None,
description=(
"Index of the dataset to start. "
"If provided, will call TaskDataset.get_new_env_by_idx(); "
"otherwise, TaskDataset.get_new_env()."
"Optional index of the dataset to start. If provided, will call"
" TaskDataset.get_new_env_by_idx(); otherwise, TaskDataset.get_new_env()."
" Mutually exclusive with task_kwargs."
),
)
task_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Optional keyword arguments passed to TaskDataset.get_new_env_by_args()."
" Mutually exclusive with task_idx."
),
)

@model_validator(mode="after")
def _check_mutually_exclusive(self) -> "StartRequest":
if self.task_idx is not None and self.task_kwargs is not None:
raise ValueError(
"task_idx and task_kwargs are mutually exclusive; specify at most one."
)
return self

def make_env(self, dataset: TaskDataset[TEnvironment]) -> TEnvironment:
if self.task_kwargs is not None:
return dataset.get_new_env_by_args(**self.task_kwargs)
if self.task_idx is not None:
return dataset.get_new_env_by_idx(self.task_idx)
return dataset.get_new_env()


class EnvRequest(BaseModel):
Expand All @@ -60,17 +82,14 @@ class FlushRequest(BaseModel):
BIND_ALL_HOST = "0.0.0.0" # noqa: S104


# Not sure why, but mypy complains if we use the TEnvironment in aviary.env, so redefine here
TEnvironment = TypeVar("TEnvironment", bound=Environment)


class TaskDatasetServer(Generic[TEnvironment]):
def __init__(
self,
dataset: TaskDataset[TEnvironment],
host: str = BIND_ALL_HOST,
port: int = DEFAULT_SERVER_PORT,
api_key: str | None = None,
router: "APIRouter | None" = None,
):
if missing_dependencies:
raise ImportError(
Expand All @@ -83,13 +102,19 @@ def __init__(
self.port = port
self.api_key = api_key

self.app = FastAPI()

# env ID -> (env, last used timestamp)
self.envs: dict[str, tuple[TEnvironment, float]] = {}
self.lock = asyncio.Lock()

self.router = router if router is not None else APIRouter()
self._setup_routes()

if router is None: # Standalone mode: build a default FastAPI app
self.app: FastAPI | None = FastAPI()
self.app.include_router(self.router)
else: # Mounted mode: caller mounts self.router onto their own app
self.app = None

def _get_env(self, env_id: str) -> TEnvironment:
try:
env, _ = self.envs[env_id]
Expand Down Expand Up @@ -123,22 +148,17 @@ def verify_api_key(api_key: str | None = Security(api_key_header)):
status_code=403, detail="Invalid or missing API key"
)

@self.app.post("/start", dependencies=[Depends(verify_api_key)])
@self.router.post("/start", dependencies=[Depends(verify_api_key)])
async def start(req: StartRequest):
with handle_exc_as_http_exc():
if req.task_idx is None:
env = await asyncio.to_thread(self.dataset.get_new_env)
else:
env = await asyncio.to_thread(
self.dataset.get_new_env_by_idx, req.task_idx
)
env = await asyncio.to_thread(req.make_env, self.dataset)

async with self.lock:
env_id = str(uuid.uuid4())
self.envs[env_id] = (env, time.time())
return {"env_id": env_id}

@self.app.post("/reset", dependencies=[Depends(verify_api_key)])
@self.router.post("/reset", dependencies=[Depends(verify_api_key)])
async def reset(req: EnvRequest):
async with self.lock:
env = self._get_env(req.env_id)
Expand All @@ -152,7 +172,7 @@ async def reset(req: EnvRequest):
ToolsAdapter.dump_python(tools, exclude_none=True, by_alias=True),
)

@self.app.post("/step", dependencies=[Depends(verify_api_key)])
@self.router.post("/step", dependencies=[Depends(verify_api_key)])
async def step(req: StepRequest):
async with self.lock:
env = self._get_env(req.env_id)
Expand All @@ -166,7 +186,7 @@ async def step(req: StepRequest):
)
return obs_serialized, *reward_done_trunc

@self.app.post("/close", dependencies=[Depends(verify_api_key)])
@self.router.post("/close", dependencies=[Depends(verify_api_key)])
async def close(req: EnvRequest):
async with self.lock:
env = self._get_env(req.env_id)
Expand All @@ -179,7 +199,7 @@ async def close(req: EnvRequest):

return {"env_id": req.env_id}

@self.app.post("/close_old_envs", dependencies=[Depends(verify_api_key)])
@self.router.post("/close_old_envs", dependencies=[Depends(verify_api_key)])
async def close_old_envs(req: FlushRequest):
"""Endpoint to close environments that have not been used in a while.

Comment thread
jamesbraza marked this conversation as resolved.
Expand Down Expand Up @@ -212,7 +232,7 @@ async def close(env_id: str, env: TEnvironment) -> str | None:
"closed_env_ids": [env_id for env_id in closed if env_id is not None]
}

@self.app.get("/info", dependencies=[Depends(verify_api_key)])
@self.router.get("/info", dependencies=[Depends(verify_api_key)])
def info():
try:
dataset_len: int | None = len(self.dataset)
Expand All @@ -223,11 +243,21 @@ def info():
"running_env_ids": list(self.envs.keys()),
}

def start(self):
def start(self) -> None:
if self.app is None:
raise RuntimeError(
f"{type(self).__name__} was constructed with an external router; "
"mount self.router on your own FastAPI app and run uvicorn there."
)
uvicorn.run(self.app, host=self.host, port=self.port, log_level="debug")

async def astart(self):
async def astart(self) -> None:
"""Async equivalent of start()."""
if self.app is None:
raise RuntimeError(
f"{type(self).__name__} was constructed with an external router; "
"mount self.router on your own FastAPI app and run uvicorn there."
)
config = uvicorn.Config(
self.app, host=self.host, port=self.port, log_level="debug"
)
Expand Down
10 changes: 10 additions & 0 deletions src/aviary/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def get_new_env(self) -> TEnvironment:
f'"{self.__class__.__name__}" does not implement get_new_env'
)

def get_new_env_by_args(self, **kwargs) -> TEnvironment:
"""Get an env from arbitrary task kwargs.

Useful when the caller drives env creation from request payloads
rather than a default configuration or fixed environment index.
"""
raise NotImplementedError(
f'"{self.__class__.__name__}" does not implement get_new_env_by_args'
)

def iter_batches(
self, batch_size: int, shuffle: bool = False
) -> Iterator[list[TEnvironment]]:
Expand Down
100 changes: 96 additions & 4 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import time
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from typing import Any, ClassVar
from typing import Any, ClassVar, cast
from unittest import mock

import litellm
import numpy as np
import pytest
import pytest_asyncio
from fastapi import FastAPI
from fastapi import APIRouter, FastAPI
from httpx import ASGITransport, AsyncClient
from pydantic import BaseModel, ValidationError
from pytest_subtests import SubTests
Expand Down Expand Up @@ -682,7 +682,32 @@ async def _make_test_client(app: FastAPI) -> AsyncIterator[AsyncClient]:
@pytest_asyncio.fixture
async def server_async_client() -> AsyncIterator[AsyncClient]:
server = TaskDatasetServer[DummyEnv](dataset=TaskDataset.from_name("dummy"))
async with _make_test_client(app=server.app) as client:
async with _make_test_client(app=cast(FastAPI, server.app)) as client:
yield client


class StubArgsTaskDataset(TaskDataset[DummyEnv]):
"""Stub task dataset to exercise get_new_env_by_args."""

def get_new_env_by_args(self, *, task: str) -> DummyEnv: # type: ignore[override]
return DummyEnv(task=task)


@pytest_asyncio.fixture
async def args_async_client() -> AsyncIterator[AsyncClient]:
server = TaskDatasetServer[DummyEnv](dataset=StubArgsTaskDataset())
async with _make_test_client(app=cast(FastAPI, server.app)) as client:
yield client


@pytest_asyncio.fixture
async def mounted_async_client() -> AsyncIterator[AsyncClient]:
server = TaskDatasetServer[DummyEnv](
dataset=TaskDataset.from_name("dummy"), router=APIRouter(tags=["env"])
)
app = FastAPI()
app.include_router(server.router, prefix="/env")
async with _make_test_client(app=app) as client:
yield client


Expand Down Expand Up @@ -764,7 +789,7 @@ async def slow_close(*_) -> None:
# last_used=0 guarantees it's stale for any req.last_used >= 0
server.envs["stale"] = (stale_env, 0.0)

async with _make_test_client(app=server.app) as client:
async with _make_test_client(app=cast(FastAPI, server.app)) as client:
with mock.patch.object(DummyEnv, "close", slow_close):
# Kick off /close_old_envs; env.close() will await release_event
close_task = asyncio.create_task(
Expand Down Expand Up @@ -844,6 +869,73 @@ async def test_step_with_tool_response_message(
assert not done
assert not truncated

@pytest.mark.asyncio
async def test_start_raises_when_get_new_env_by_args_not_implemented(
self, server_async_client: AsyncClient
) -> None:
# Dummy dataset doesn't implement get_new_env_by_args, so sending
# task_kwargs should surface as a 500 via handle_exc_as_http_exc
response = await server_async_client.post(
"/start", json={"task_kwargs": {"task": "anything"}}
)
assert response.status_code == 500
assert "get_new_env_by_args" in response.json()["detail"]

@pytest.mark.asyncio
async def test_start_with_task_kwargs(self, args_async_client: AsyncClient) -> None:
start_resp = await args_async_client.post(
"/start", json={"task_kwargs": {"task": "five-word-story topic"}}
)
assert start_resp.status_code == 200
env_id = start_resp.json()["env_id"]

# Reset and confirm the task made it into the initial observation.
reset_resp = await args_async_client.post("/reset", json={"env_id": env_id})
assert reset_resp.status_code == 200
(obs,), tools = reset_resp.json()
assert "five-word-story topic" in obs["content"]
assert tools

@pytest.mark.asyncio
async def test_start_rejects_both_task_idx_and_task_kwargs(
self, args_async_client: AsyncClient
) -> None:
# Specifying both is ambiguous; server should reject at request validation
start_resp = await args_async_client.post(
"/start", json={"task_idx": 42, "task_kwargs": {"task": "kwargs-won"}}
)
assert start_resp.status_code == 422
assert "mutually exclusive" in start_resp.text

@pytest.mark.asyncio
async def test_start_reset_step_through_prefix(
self, mounted_async_client: AsyncClient
) -> None:
"""End-to-end smoke test for the mounted-router code path."""
start_resp = await mounted_async_client.post("/env/start", json={})
assert start_resp.status_code == 200, (
"Mounted router did not expose /env/start — route registration"
" against external APIRouter is broken"
)
env_id = start_resp.json()["env_id"]

reset_resp = await mounted_async_client.post(
"/env/reset", json={"env_id": env_id}
)
assert reset_resp.status_code == 200, (
"Mounted router cannot retrieve the env it just created"
)

action = ToolRequestMessage(
tool_calls=[
ToolCall.from_name("print_story", story="one two three four five")
]
)
step_resp = await mounted_async_client.post(
"/env/step", json={"env_id": env_id, "action": action.model_dump()}
)
assert step_resp.status_code == 200


class TestDefaultNoToolCallsResponse:
@pytest.mark.parametrize(
Expand Down
Loading