diff --git a/src/agentevals/api/app.py b/src/agentevals/api/app.py index 695ee56..7a8cf59 100644 --- a/src/agentevals/api/app.py +++ b/src/agentevals/api/app.py @@ -1,11 +1,14 @@ """FastAPI application for agentevals REST API.""" +from __future__ import annotations + import asyncio import json import logging import os from contextlib import asynccontextmanager from pathlib import Path +from typing import TYPE_CHECKING from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -17,6 +20,9 @@ from .debug_routes import debug_router from .routes import router +if TYPE_CHECKING: + from ..streaming.ws_server import StreamingTraceManager + try: from dotenv import load_dotenv @@ -27,107 +33,122 @@ pass -@asynccontextmanager -async def lifespan(app: FastAPI): - log_level_str = os.getenv("AGENTEVALS_LOG_LEVEL", "INFO").upper() - log_level = getattr(logging, log_level_str, logging.INFO) - logging.basicConfig( - level=log_level, - format="%(levelname)s:%(name)s:%(message)s", - force=True, +def _build_lifespan(): + @asynccontextmanager + async def lifespan(app: FastAPI): + log_level_str = os.getenv("AGENTEVALS_LOG_LEVEL", "INFO").upper() + log_level = getattr(logging, log_level_str, logging.INFO) + logging.basicConfig( + level=log_level, + format="%(levelname)s:%(name)s:%(message)s", + force=True, + ) + ae_logger = logging.getLogger("agentevals") + ae_logger.setLevel(log_level) + if log_buffer not in ae_logger.handlers: + log_buffer.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s")) + ae_logger.addHandler(log_buffer) + mgr = getattr(app.state, "trace_manager", None) + if mgr: + mgr.start_cleanup_task() + yield + if mgr: + await mgr.shutdown() + ae_logger.removeHandler(log_buffer) + + return lifespan + + +def create_app( + *, + trace_manager: StreamingTraceManager | None = None, + enable_streaming: bool = False, +) -> FastAPI: + """Create the main agentevals API app.""" + app = FastAPI( + title="agentevals API", + version=__version__, + description="REST API for evaluating agent traces using ADK's scoring framework", + lifespan=_build_lifespan(), ) - ae_logger = logging.getLogger("agentevals") - ae_logger.setLevel(log_level) - if log_buffer not in ae_logger.handlers: - log_buffer.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s")) - ae_logger.addHandler(log_buffer) - mgr = getattr(app.state, "trace_manager", None) - if mgr: - mgr.start_cleanup_task() - yield - if mgr: - await mgr.shutdown() - ae_logger.removeHandler(log_buffer) - - -app = FastAPI( - title="agentevals API", - version=__version__, - description="REST API for evaluating agent traces using ADK's scoring framework", - lifespan=lifespan, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:5173", "http://localhost:5174"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - expose_headers=["*"], -) - -app.include_router(router, prefix="/api") -app.include_router(debug_router, prefix="/api/debug") - -_live_mode = os.getenv("AGENTEVALS_LIVE") == "1" - -if _live_mode: - from fastapi import Request as _Request - from fastapi import WebSocket - from ..streaming.ws_server import StreamingTraceManager - from .streaming_routes import streaming_router - - app.include_router(streaming_router, prefix="/api/streaming") - app.state.trace_manager = StreamingTraceManager() - - @app.websocket("/ws/traces") - async def websocket_endpoint(websocket: WebSocket): - await websocket.app.state.trace_manager.handle_connection(websocket) - - @app.get("/stream/ui-updates") - async def ui_updates_stream(request: _Request): - mgr = request.app.state.trace_manager - - async def event_generator(): - queue = mgr.register_sse_client() - try: - while True: - event = await queue.get() - if event is None: - break - yield f"data: {json.dumps(event)}\n\n" - except asyncio.CancelledError: - pass - finally: - mgr.unregister_sse_client(queue) - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) + app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:5173", "http://localhost:5174"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], + ) + + app.include_router(router, prefix="/api") + app.include_router(debug_router, prefix="/api/debug") + + if trace_manager is not None: + app.state.trace_manager = trace_manager + + if enable_streaming: + if trace_manager is None: + raise ValueError("enable_streaming requires a trace_manager") + + from fastapi import Request as _Request + from fastapi import WebSocket + + from .streaming_routes import streaming_router + + app.include_router(streaming_router, prefix="/api/streaming") + + @app.websocket("/ws/traces") + async def websocket_endpoint(websocket: WebSocket): + await websocket.app.state.trace_manager.handle_connection(websocket) + + @app.get("/stream/ui-updates") + async def ui_updates_stream(request: _Request): + mgr = request.app.state.trace_manager + + async def event_generator(): + queue = mgr.register_sse_client() + try: + while True: + event = await queue.get() + if event is None: + break + yield f"data: {json.dumps(event)}\n\n" + except asyncio.CancelledError: + pass + finally: + mgr.unregister_sse_client(queue) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + static_dir = Path(__file__).parent.parent / "_static" + has_ui = static_dir.is_dir() and (static_dir / "index.html").exists() + + if has_ui and not os.getenv("AGENTEVALS_HEADLESS"): + from fastapi.responses import FileResponse + from fastapi.staticfiles import StaticFiles + app.mount("/assets", StaticFiles(directory=static_dir / "assets"), name="ui-assets") -_static_dir = Path(__file__).parent.parent / "_static" -_has_ui = _static_dir.is_dir() and (_static_dir / "index.html").exists() + @app.get("/") + async def root(): + return FileResponse(static_dir / "index.html") -if _has_ui and not os.getenv("AGENTEVALS_HEADLESS"): - from fastapi.responses import FileResponse - from fastapi.staticfiles import StaticFiles + @app.get("/{path:path}") + async def spa_fallback(path: str): + file_path = static_dir / path + if file_path.is_file(): + return FileResponse(file_path) + return FileResponse(static_dir / "index.html") - app.mount("/assets", StaticFiles(directory=_static_dir / "assets"), name="ui-assets") + return app - @app.get("/") - async def root(): - return FileResponse(_static_dir / "index.html") - @app.get("/{path:path}") - async def spa_fallback(path: str): - file_path = _static_dir / path - if file_path.is_file(): - return FileResponse(file_path) - return FileResponse(_static_dir / "index.html") +app = create_app() diff --git a/src/agentevals/api/dependencies.py b/src/agentevals/api/dependencies.py index c6b08a1..5f2566b 100644 --- a/src/agentevals/api/dependencies.py +++ b/src/agentevals/api/dependencies.py @@ -26,11 +26,3 @@ def require_trace_manager(request: Request) -> StreamingTraceManager: if mgr is None: raise HTTPException(status_code=503, detail="Live mode not enabled") return mgr - - -def require_trace_manager_from_app(app: Any) -> StreamingTraceManager: - """Return the StreamingTraceManager from app, raising RuntimeError if missing.""" - mgr = get_trace_manager_from_app(app) - if mgr is None: - raise RuntimeError("Live mode not enabled") - return mgr diff --git a/src/agentevals/api/otlp_app.py b/src/agentevals/api/otlp_app.py index d04fd1e..ab78ddf 100644 --- a/src/agentevals/api/otlp_app.py +++ b/src/agentevals/api/otlp_app.py @@ -1,25 +1,27 @@ """Minimal FastAPI app for the OTLP HTTP receiver on port 4318. -Shares the StreamingTraceManager with the main app (port 8001). Mounts only the /v1/traces and /v1/logs endpoints. """ -from contextlib import asynccontextmanager +from __future__ import annotations + +from typing import TYPE_CHECKING from fastapi import FastAPI from .otlp_routes import otlp_router +if TYPE_CHECKING: + from ..streaming.ws_server import StreamingTraceManager -@asynccontextmanager -async def lifespan(app: FastAPI): - from .app import app as main_app - mgr = getattr(main_app.state, "trace_manager", None) - if mgr: - app.state.trace_manager = mgr - yield +def create_otlp_app(*, trace_manager: StreamingTraceManager | None = None) -> FastAPI: + """Create the OTLP HTTP receiver app.""" + app = FastAPI(title="agentevals OTLP receiver") + if trace_manager is not None: + app.state.trace_manager = trace_manager + app.include_router(otlp_router) + return app -otlp_app = FastAPI(title="agentevals OTLP receiver", lifespan=lifespan) -otlp_app.include_router(otlp_router) +otlp_app = create_otlp_app() diff --git a/src/agentevals/cli.py b/src/agentevals/cli.py index 3fd7dbd..3964f03 100644 --- a/src/agentevals/cli.py +++ b/src/agentevals/cli.py @@ -534,27 +534,26 @@ async def _run_servers( otlp_grpc_port: int, *, mcp_port: int | None = None, - reload: bool = False, - reload_dirs: list[str] | None = None, log_level: str = "warning", ) -> None: """Start API, OTLP HTTP+gRPC receivers, and optional MCP (Streamable HTTP).""" import uvicorn + from .api.app import create_app + from .api.otlp_app import create_otlp_app + from .streaming.ws_server import StreamingTraceManager + shared_kwargs: dict = { "host": host, - "reload": reload, "log_level": log_level, } - if reload_dirs: - shared_kwargs["reload_dirs"] = reload_dirs - # TODO #99 Create the manager and pass it into the Server constructors instead of injecting it into the app state. + mgr = StreamingTraceManager() + main_app = create_app(trace_manager=mgr, enable_streaming=True) + otlp_app = create_otlp_app(trace_manager=mgr) - main_server = uvicorn.Server(uvicorn.Config("agentevals.api.app:app", port=port, **shared_kwargs)) - otlp_http_server = uvicorn.Server( - uvicorn.Config("agentevals.api.otlp_app:otlp_app", port=otlp_http_port, **shared_kwargs) - ) + main_server = uvicorn.Server(uvicorn.Config(main_app, port=port, **shared_kwargs)) + otlp_http_server = uvicorn.Server(uvicorn.Config(otlp_app, port=otlp_http_port, **shared_kwargs)) uvicorn_servers: list = [main_server, otlp_http_server] if mcp_port is not None: @@ -571,10 +570,6 @@ async def _run_servers( mcp_uvicorn = uvicorn.Server(uvicorn.Config(mcp_app, **mcp_kwargs)) uvicorn_servers.append(mcp_uvicorn) - from .api.app import app as main_app - from .api.dependencies import require_trace_manager_from_app - - mgr = require_trace_manager_from_app(main_app) otlp_grpc_server = create_otlp_grpc_server(host=host, port=otlp_grpc_port, manager=mgr) await otlp_grpc_server.start() @@ -703,8 +698,6 @@ def serve( click.echo("Waiting for agent connections...") click.echo() - src_path = Path(__file__).parent.parent - reload_dirs = [str(src_path)] asyncio.run( _run_servers( host, @@ -712,8 +705,6 @@ def serve( otlp_http_port, otlp_grpc_port, mcp_port=mcp_port, - reload=True, - reload_dirs=reload_dirs, log_level="info", ) ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 922841d..36e2667 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -101,18 +101,12 @@ def live_servers(): os.environ["AGENTEVALS_LIVE"] = "1" os.environ["AGENTEVALS_HEADLESS"] = "1" - import importlib + from agentevals.api.app import create_app + from agentevals.api.otlp_app import create_otlp_app - from agentevals.api import app as app_module - - importlib.reload(app_module) - - from agentevals.api.app import app - from agentevals.api.otlp_app import otlp_app - - mgr = getattr(app.state, "trace_manager", None) - if mgr: - otlp_app.state.trace_manager = mgr + mgr = StreamingTraceManager() + app = create_app(trace_manager=mgr, enable_streaming=True) + otlp_app = create_otlp_app(trace_manager=mgr) main_config = uvicorn.Config(app, host="127.0.0.1", port=main_port, log_level="warning") otlp_config = uvicorn.Config(otlp_app, host="127.0.0.1", port=otlp_http_port, log_level="warning") diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..abeb5e3 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,87 @@ +"""CLI startup wiring tests.""" + +from __future__ import annotations + +import sys +from types import ModuleType +from unittest.mock import AsyncMock + +import pytest + +from agentevals import cli + + +class _FakeGrpcServer: + def __init__(self): + self.started = False + + async def start(self) -> None: + self.started = True + + +class _FakeUvicornConfig: + def __init__(self, app, **kwargs): + self.app = app + self.kwargs = kwargs + + +class _FakeUvicornServer: + def __init__(self, config): + self.config = config + self.should_exit = False + self.force_exit = False + self.handle_exit = None + + async def serve(self) -> None: + return None + + +@pytest.mark.asyncio +async def test_run_servers_shares_one_trace_manager_across_live_servers(monkeypatch): + fake_grpc_server = _FakeGrpcServer() + fake_stop_grpc = AsyncMock() + created_servers: list[_FakeUvicornServer] = [] + captured: dict[str, object] = {} + + def fake_create_otlp_grpc_server(*, host, port, manager): + captured["host"] = host + captured["port"] = port + captured["manager"] = manager + return fake_grpc_server + + def fake_server_factory(config): + server = _FakeUvicornServer(config) + created_servers.append(server) + return server + + fake_uvicorn = ModuleType("uvicorn") + fake_uvicorn.Config = _FakeUvicornConfig + fake_uvicorn.Server = fake_server_factory + + monkeypatch.setitem(sys.modules, "uvicorn", fake_uvicorn) + monkeypatch.setattr(cli, "create_otlp_grpc_server", fake_create_otlp_grpc_server) + monkeypatch.setattr(cli, "stop_otlp_grpc_server", fake_stop_grpc) + + await cli._run_servers("127.0.0.1", 8001, 4318, 4317) + + assert len(created_servers) == 2 + main_app = created_servers[0].config.app + otlp_app = created_servers[1].config.app + manager = captured["manager"] + + assert captured["host"] == "127.0.0.1" + assert captured["port"] == 4317 + assert main_app.state.trace_manager is manager + assert otlp_app.state.trace_manager is manager + assert "reload" not in created_servers[0].config.kwargs + assert "reload_dirs" not in created_servers[0].config.kwargs + assert "reload" not in created_servers[1].config.kwargs + assert "reload_dirs" not in created_servers[1].config.kwargs + assert fake_grpc_server.started is True + assert created_servers[0].handle_exit is not None + assert created_servers[1].handle_exit is not None + assert any(route.path == "/ws/traces" for route in main_app.routes) + assert any(route.path == "/stream/ui-updates" for route in main_app.routes) + assert any(route.path == "/v1/traces" for route in otlp_app.routes) + assert any(route.path == "/v1/logs" for route in otlp_app.routes) + fake_stop_grpc.assert_awaited_once_with(fake_grpc_server) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 614be24..34fdec0 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -19,7 +19,6 @@ ) from agentevals.runner import MetricResult, RunResult, TraceResult - # --------------------------------------------------------------------------- # Helpers # ---------------------------------------------------------------------------