From 58121de81b17ad79a1fc2149b6a747b7e9e644ba Mon Sep 17 00:00:00 2001 From: Trey Date: Thu, 29 Jan 2026 03:37:49 -0800 Subject: [PATCH 1/2] Implement 99 --- bot2/pyproject.toml | 2 + bot2/src/bot/service/__init__.py | 8 + bot2/src/bot/service/__main__.py | 66 +++ bot2/src/bot/service/bot_service.py | 119 ++++++ bot2/tests/unit/service/test_bot_service.py | 447 ++++++++++++++++++++ bot2/uv.lock | 90 +++- 6 files changed, 714 insertions(+), 18 deletions(-) create mode 100644 bot2/src/bot/service/__main__.py create mode 100644 bot2/src/bot/service/bot_service.py create mode 100644 bot2/tests/unit/service/test_bot_service.py diff --git a/bot2/pyproject.toml b/bot2/pyproject.toml index c882d2d..da84457 100644 --- a/bot2/pyproject.toml +++ b/bot2/pyproject.toml @@ -17,6 +17,8 @@ dependencies = [ "tensorboard>=2.14.0", "scipy>=1.17.0", "matplotlib>=3.8.0", + "fastapi>=0.115.0", + "uvicorn>=0.34.0", ] [build-system] diff --git a/bot2/src/bot/service/__init__.py b/bot2/src/bot/service/__init__.py index 4a42b18..7780721 100644 --- a/bot2/src/bot/service/__init__.py +++ b/bot2/src/bot/service/__init__.py @@ -7,6 +7,11 @@ SpawnBotRequest, SpawnBotResponse, ) +from bot.service.bot_service import ( + app, + get_bot_manager, + set_bot_manager, +) from bot.service.websocket_bot import ( BotRunnerProtocol, WebSocketBotClient, @@ -22,4 +27,7 @@ "BotManager", "SpawnBotRequest", "SpawnBotResponse", + "app", + "get_bot_manager", + "set_bot_manager", ] diff --git a/bot2/src/bot/service/__main__.py b/bot2/src/bot/service/__main__.py new file mode 100644 index 0000000..0f68dbb --- /dev/null +++ b/bot2/src/bot/service/__main__.py @@ -0,0 +1,66 @@ +"""Entry point for running the Bot Service. + +Usage: + uv run python -m bot.service + +Environment variables: + BOT_SERVICE_PORT: Port to listen on (default: 8080) + BOT_SERVICE_HOST: Host to bind (default: 0.0.0.0) + GAME_SERVER_HTTP_URL: Game server HTTP URL (default: http://localhost:4000) + GAME_SERVER_WS_URL: Game server WebSocket URL (default: ws://localhost:4000/ws) + MODEL_REGISTRY_PATH: Path to model registry (optional) + DEFAULT_DEVICE: Device for inference (default: cpu) +""" + +import logging +import os + +import uvicorn + +from bot.service.bot_manager import BotManager +from bot.service.bot_service import app, set_bot_manager +from bot.training.registry import ModelRegistry + + +def main() -> None: + """Initialize and run the Bot Service.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger(__name__) + + # Load configuration from environment + port = int(os.environ.get("BOT_SERVICE_PORT", "8080")) + host = os.environ.get("BOT_SERVICE_HOST", "0.0.0.0") + http_url = os.environ.get("GAME_SERVER_HTTP_URL", "http://localhost:4000") + ws_url = os.environ.get("GAME_SERVER_WS_URL", "ws://localhost:4000/ws") + registry_path = os.environ.get("MODEL_REGISTRY_PATH") + device = os.environ.get("DEFAULT_DEVICE", "cpu") + + # Initialize ModelRegistry if path is provided + registry: ModelRegistry | None = None + if registry_path: + logger.info(f"Loading model registry from {registry_path}") + registry = ModelRegistry(registry_path) + else: + logger.warning("MODEL_REGISTRY_PATH not set - neural network bots unavailable") + + # Initialize BotManager + manager = BotManager( + registry=registry, + http_url=http_url, + ws_url=ws_url, + default_device=device, + ) + set_bot_manager(manager) + + logger.info(f"Starting Bot Service on {host}:{port}") + logger.info(f"Game server: HTTP={http_url}, WS={ws_url}") + + # Run uvicorn server + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + main() diff --git a/bot2/src/bot/service/bot_service.py b/bot2/src/bot/service/bot_service.py new file mode 100644 index 0000000..822e208 --- /dev/null +++ b/bot2/src/bot/service/bot_service.py @@ -0,0 +1,119 @@ +"""FastAPI Bot Service for managing game bots. + +This module provides the REST API layer for the Bot Service, exposing endpoints +for spawning bots, managing active bots, listing available models, and health checks. +""" + +import logging +from contextlib import asynccontextmanager +from typing import Annotated + +from fastapi import Depends, FastAPI, HTTPException, status + +from bot.service.bot_manager import ( + BotInfo, + BotManager, + SpawnBotRequest, + SpawnBotResponse, +) +from bot.training.registry import ModelMetadata + +logger = logging.getLogger(__name__) + +# Global bot manager instance (initialized in __main__.py) +_bot_manager: BotManager | None = None + + +def get_bot_manager() -> BotManager: + """Dependency that provides the BotManager instance.""" + if _bot_manager is None: + raise RuntimeError("BotManager not initialized") + return _bot_manager + + +def set_bot_manager(manager: BotManager) -> None: + """Set the global BotManager instance (called from __main__.py).""" + global _bot_manager + _bot_manager = manager + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler for startup/shutdown events.""" + yield + # Shutdown: cleanup bot manager + if _bot_manager is not None: + logger.info("Shutting down bot manager...") + await _bot_manager.shutdown() + + +app = FastAPI( + title="Bot Service", + description="Service for spawning and managing game bots", + version="0.1.0", + lifespan=lifespan, +) + + +@app.get("/health") +def health_check() -> dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy"} + + +@app.get("/bots/models", response_model=list[ModelMetadata]) +def list_models( + manager: Annotated[BotManager, Depends(get_bot_manager)], +) -> list[ModelMetadata]: + """List available trained models.""" + return manager.list_models() + + +@app.post("/bots/spawn", response_model=SpawnBotResponse) +async def spawn_bot( + request: SpawnBotRequest, + manager: Annotated[BotManager, Depends(get_bot_manager)], +) -> SpawnBotResponse: + """Spawn a bot to join a game room. + + Returns immediately with bot_id. Bot connects in background. + """ + return await manager.spawn_bot(request) + + +@app.delete("/bots/{bot_id}") +async def delete_bot( + bot_id: str, + manager: Annotated[BotManager, Depends(get_bot_manager)], +) -> dict[str, bool]: + """Remove a bot from its game.""" + success = await manager.destroy_bot(bot_id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Bot {bot_id} not found", + ) + return {"success": True} + + +@app.get("/bots", response_model=list[BotInfo]) +def list_bots( + manager: Annotated[BotManager, Depends(get_bot_manager)], +) -> list[BotInfo]: + """List all active bots.""" + return manager.list_bots() + + +@app.get("/bots/{bot_id}", response_model=BotInfo) +def get_bot( + bot_id: str, + manager: Annotated[BotManager, Depends(get_bot_manager)], +) -> BotInfo: + """Get information about a specific bot.""" + bot = manager.get_bot(bot_id) + if bot is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Bot {bot_id} not found", + ) + return bot diff --git a/bot2/tests/unit/service/test_bot_service.py b/bot2/tests/unit/service/test_bot_service.py new file mode 100644 index 0000000..db0db60 --- /dev/null +++ b/bot2/tests/unit/service/test_bot_service.py @@ -0,0 +1,447 @@ +"""Unit tests for FastAPI Bot Service. + +Tests cover: +- POST /bots/spawn endpoint +- DELETE /bots/{bot_id} endpoint +- GET /bots endpoint +- GET /bots/{bot_id} endpoint +- GET /bots/models endpoint +- GET /health endpoint +- Route order (models vs bot_id) +- BotManager not initialized error +""" + +from collections.abc import Generator +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi.testclient import TestClient + +from bot.service.bot_manager import BotInfo, SpawnBotResponse +from bot.service.bot_service import app, get_bot_manager, set_bot_manager +from bot.training.registry import ( + ModelMetadata, + NetworkArchitecture, + TrainingMetrics, +) + + +@pytest.fixture +def mock_manager() -> MagicMock: + """Create a mock BotManager for testing.""" + manager = MagicMock() + manager.spawn_bot = AsyncMock() + manager.destroy_bot = AsyncMock() + manager.get_bot = MagicMock() + manager.list_bots = MagicMock() + manager.list_models = MagicMock() + manager.shutdown = AsyncMock() + return manager + + +@pytest.fixture +def client(mock_manager: MagicMock) -> Generator[TestClient, None, None]: + """Create a TestClient with mocked BotManager.""" + app.dependency_overrides[get_bot_manager] = lambda: mock_manager + with TestClient(app) as test_client: + yield test_client + app.dependency_overrides.clear() + + +@pytest.fixture +def sample_bot_info() -> BotInfo: + """Create a sample BotInfo for testing.""" + return BotInfo( + bot_id="bot_abc123", + bot_type="rule_based", + model_id=None, + player_name="TestBot", + room_code="ABC123", + is_connected=True, + ) + + +@pytest.fixture +def sample_model_metadata() -> ModelMetadata: + """Create a sample ModelMetadata for testing.""" + return ModelMetadata( + model_id="ppo_gen_005", + generation=5, + created_at=datetime(2024, 1, 1), + training_duration_seconds=3600.0, + training_metrics=TrainingMetrics( + total_episodes=1000, + total_timesteps=100000, + average_reward=50.0, + average_episode_length=100.0, + win_rate=0.5, + average_kills=2.0, + average_deaths=2.0, + kills_deaths_ratio=1.0, + ), + architecture=NetworkArchitecture( + observation_size=128, + action_size=32, + ), + checkpoint_path="test.pth", + ) + + +class TestHealthEndpoint: + """Tests for GET /health endpoint.""" + + def test_health_check_returns_healthy(self, client: TestClient) -> None: + """GET /health returns healthy status.""" + response = client.get("/health") + + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +class TestSpawnBotEndpoint: + """Tests for POST /bots/spawn endpoint.""" + + def test_spawn_bot_success( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """POST /bots/spawn returns bot_id on success.""" + mock_manager.spawn_bot.return_value = SpawnBotResponse( + success=True, bot_id="bot_123" + ) + + response = client.post( + "/bots/spawn", + json={ + "room_code": "ABC123", + "room_password": "", + "bot_config": {"bot_type": "rule_based", "player_name": "TestBot"}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["bot_id"] == "bot_123" + assert data["error"] is None + + def test_spawn_bot_with_password( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """POST /bots/spawn accepts room password.""" + mock_manager.spawn_bot.return_value = SpawnBotResponse( + success=True, bot_id="bot_456" + ) + + response = client.post( + "/bots/spawn", + json={ + "room_code": "XYZ789", + "room_password": "secret123", + "bot_config": {"bot_type": "rule_based"}, + }, + ) + + assert response.status_code == 200 + assert response.json()["success"] is True + + # Verify password was passed to manager + call_args = mock_manager.spawn_bot.call_args + request = call_args[0][0] + assert request.room_password == "secret123" + + def test_spawn_neural_network_bot( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """POST /bots/spawn works with neural network bot.""" + mock_manager.spawn_bot.return_value = SpawnBotResponse( + success=True, bot_id="bot_nn_001" + ) + + response = client.post( + "/bots/spawn", + json={ + "room_code": "ABC123", + "bot_config": { + "bot_type": "neural_network", + "model_id": "ppo_gen_005", + "player_name": "NeuralBot", + }, + }, + ) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["bot_id"] == "bot_nn_001" + + def test_spawn_bot_failure_returns_error( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """POST /bots/spawn returns error message on failure.""" + mock_manager.spawn_bot.return_value = SpawnBotResponse( + success=False, error="Model not found" + ) + + response = client.post( + "/bots/spawn", + json={ + "room_code": "ABC123", + "bot_config": { + "bot_type": "neural_network", + "model_id": "invalid_model", + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["bot_id"] is None + assert data["error"] == "Model not found" + + def test_spawn_bot_invalid_request_returns_422( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """POST /bots/spawn with invalid request returns 422.""" + response = client.post( + "/bots/spawn", + json={ + "room_code": "ABC123", + # Missing required bot_config + }, + ) + + assert response.status_code == 422 + + def test_spawn_bot_invalid_bot_type_returns_422( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """POST /bots/spawn with invalid bot_type returns 422.""" + response = client.post( + "/bots/spawn", + json={ + "room_code": "ABC123", + "bot_config": {"bot_type": "invalid_type"}, + }, + ) + + assert response.status_code == 422 + + +class TestDeleteBotEndpoint: + """Tests for DELETE /bots/{bot_id} endpoint.""" + + def test_delete_bot_success( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """DELETE /bots/{bot_id} returns success for existing bot.""" + mock_manager.destroy_bot.return_value = True + + response = client.delete("/bots/bot_abc123") + + assert response.status_code == 200 + assert response.json() == {"success": True} + mock_manager.destroy_bot.assert_called_once_with("bot_abc123") + + def test_delete_bot_not_found( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """DELETE /bots/{bot_id} returns 404 for unknown bot.""" + mock_manager.destroy_bot.return_value = False + + response = client.delete("/bots/unknown_bot") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +class TestListBotsEndpoint: + """Tests for GET /bots endpoint.""" + + def test_list_bots_returns_all( + self, + client: TestClient, + mock_manager: MagicMock, + sample_bot_info: BotInfo, + ) -> None: + """GET /bots returns list of active bots.""" + bot2 = BotInfo( + bot_id="bot_def456", + bot_type="neural_network", + model_id="ppo_gen_003", + player_name="NeuralBot", + room_code="XYZ789", + is_connected=False, + ) + mock_manager.list_bots.return_value = [sample_bot_info, bot2] + + response = client.get("/bots") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["bot_id"] == "bot_abc123" + assert data[1]["bot_id"] == "bot_def456" + + def test_list_bots_empty(self, client: TestClient, mock_manager: MagicMock) -> None: + """GET /bots returns empty list when no bots.""" + mock_manager.list_bots.return_value = [] + + response = client.get("/bots") + + assert response.status_code == 200 + assert response.json() == [] + + +class TestGetBotEndpoint: + """Tests for GET /bots/{bot_id} endpoint.""" + + def test_get_bot_success( + self, + client: TestClient, + mock_manager: MagicMock, + sample_bot_info: BotInfo, + ) -> None: + """GET /bots/{bot_id} returns BotInfo for existing bot.""" + mock_manager.get_bot.return_value = sample_bot_info + + response = client.get("/bots/bot_abc123") + + assert response.status_code == 200 + data = response.json() + assert data["bot_id"] == "bot_abc123" + assert data["bot_type"] == "rule_based" + assert data["player_name"] == "TestBot" + assert data["room_code"] == "ABC123" + assert data["is_connected"] is True + mock_manager.get_bot.assert_called_once_with("bot_abc123") + + def test_get_bot_not_found( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """GET /bots/{bot_id} returns 404 for unknown bot.""" + mock_manager.get_bot.return_value = None + + response = client.get("/bots/unknown_bot") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +class TestListModelsEndpoint: + """Tests for GET /bots/models endpoint.""" + + def test_list_models_returns_models( + self, + client: TestClient, + mock_manager: MagicMock, + sample_model_metadata: ModelMetadata, + ) -> None: + """GET /bots/models returns list of available models.""" + mock_manager.list_models.return_value = [sample_model_metadata] + + response = client.get("/bots/models") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["model_id"] == "ppo_gen_005" + assert data[0]["generation"] == 5 + + def test_list_models_empty( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """GET /bots/models returns empty list when no models.""" + mock_manager.list_models.return_value = [] + + response = client.get("/bots/models") + + assert response.status_code == 200 + assert response.json() == [] + + +class TestRouteOrder: + """Tests to verify correct route order (models vs bot_id).""" + + def test_models_route_not_interpreted_as_bot_id( + self, + client: TestClient, + mock_manager: MagicMock, + ) -> None: + """GET /bots/models is not interpreted as GET /bots/{bot_id='models'}.""" + mock_manager.list_models.return_value = [] + mock_manager.get_bot.return_value = None + + response = client.get("/bots/models") + + # Should call list_models, not get_bot + mock_manager.list_models.assert_called_once() + mock_manager.get_bot.assert_not_called() + assert response.status_code == 200 + + +class TestBotManagerNotInitialized: + """Tests for BotManager not initialized error.""" + + def test_spawn_bot_manager_not_initialized(self) -> None: + """POST /bots/spawn raises RuntimeError when BotManager not initialized.""" + # Clear the global manager + set_bot_manager(None) # type: ignore[arg-type] + app.dependency_overrides.clear() + + with TestClient(app, raise_server_exceptions=False) as client: + response = client.post( + "/bots/spawn", + json={ + "room_code": "ABC123", + "bot_config": {"bot_type": "rule_based"}, + }, + ) + + assert response.status_code == 500 + + def test_list_bots_manager_not_initialized(self) -> None: + """GET /bots raises RuntimeError when BotManager not initialized.""" + set_bot_manager(None) # type: ignore[arg-type] + app.dependency_overrides.clear() + + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/bots") + + assert response.status_code == 500 + + def test_health_works_without_manager(self) -> None: + """GET /health works even when BotManager not initialized.""" + set_bot_manager(None) # type: ignore[arg-type] + app.dependency_overrides.clear() + + with TestClient(app) as client: + response = client.get("/health") + + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +class TestEdgeCases: + """Tests for edge cases and special scenarios.""" + + def test_long_bot_id(self, client: TestClient, mock_manager: MagicMock) -> None: + """Handles very long bot_id values.""" + long_id = "bot_" + "x" * 1000 + mock_manager.get_bot.return_value = None + + response = client.get(f"/bots/{long_id}") + + assert response.status_code == 404 + mock_manager.get_bot.assert_called_once_with(long_id) + + def test_special_characters_in_bot_id( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """Handles special characters in bot_id.""" + mock_manager.get_bot.return_value = None + + response = client.get("/bots/bot_abc-123_456") + + assert response.status_code == 404 + mock_manager.get_bot.assert_called_once_with("bot_abc-123_456") diff --git a/bot2/uv.lock b/bot2/uv.lock index fe723b3..1d60a9c 100644 --- a/bot2/uv.lock +++ b/bot2/uv.lock @@ -15,6 +15,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -42,6 +51,7 @@ name = "bot" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "fastapi" }, { name = "gymnasium" }, { name = "httpx" }, { name = "matplotlib" }, @@ -53,6 +63,7 @@ dependencies = [ { name = "tensorboard" }, { name = "torch" }, { name = "typer" }, + { name = "uvicorn" }, { name = "websockets" }, ] @@ -67,6 +78,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "fastapi", specifier = ">=0.115.0" }, { name = "gymnasium", specifier = ">=0.29.0" }, { name = "httpx", specifier = ">=0.25.0" }, { name = "matplotlib", specifier = ">=3.8.0" }, @@ -78,6 +90,7 @@ requires-dist = [ { name = "tensorboard", specifier = ">=2.14.0" }, { name = "torch", specifier = ">=2.7.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "typer", specifier = ">=0.9.0" }, + { name = "uvicorn", specifier = ">=0.34.0" }, { name = "websockets", specifier = ">=13.0" }, ] @@ -229,6 +242,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/2c/ffc08c54c05cdce6fbed2aeebc46348dbe180c6d2c541c7af7ba0aa5f5f8/Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae", size = 2511, upload-time = "2023-02-27T18:28:39.447Z" }, ] +[[package]] +name = "fastapi" +version = "0.128.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" }, +] + [[package]] name = "filelock" version = "3.20.1" @@ -1440,6 +1468,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "starlette" +version = "0.50.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/b8/73a0e6a6e079a9d9cfa64113d771e421640b6f679a52eeb9b32f72d871a1/starlette-0.50.0.tar.gz", hash = "sha256:a2a17b22203254bcbc2e1f926d2d55f3f9497f769416b3190768befe598fa3ca", size = 2646985, upload-time = "2025-11-01T15:25:27.516Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033, upload-time = "2025-11-01T15:25:25.461Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -1512,24 +1553,24 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:cf4ad82430824a80a9f398e29369524ed26c152cf00c2c12002e5400b35e260d" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2a1da940f0757621d098c9755f7504d791a72a40920ec85a4fd98b20253fca4e" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:633005a3700e81b5be0df2a7d3c1d48aced23ed927653797a3bd2b144a3aeeb6" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1176f250311fa95cc3bca8077af323e0d73ea385ba266e096af82e7e2b91f256" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7cb4018f4ce68b61fd3ef87dc1c4ca520731c7b5b200e360ad47b612d7844063" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:3a01f0b64c10a82d444d9fd06b3e8c567b1158b76b2764b8f51bfd8f535064b0" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:0b80b7555dcd0a75b7b06016991f01281a0bb078cf28fa2d1dfb949fad2fbd07" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:63381a109a569b280ed3319da89d3afe5cf9ab5c879936382a212affb5c90552" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:ad9183864acdd99fc5143d7ca9d3d2e7ddfc9a9600ff43217825d4e5e9855ccc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2314521c74d76e513c53bb72c0ce3511ef0295ff657a432790df6c207e5d7962" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4454a4faca31af81566e3a4208f10f20b8a6d9cfe42791b0ca7ff134326468fc" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:24420e430e77136f7079354134b34e7ba9d87e539f5ac84c33b08e5c13412ebe" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:32c036296c557f19a1537ce981c40533650097114e1720a321a39a3b08d9df56" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:7788d3d03d939cf00f93ac0da5ab520846f66411e339cfbf519a806e8facf519" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:7bcd40cbffac475b478d6ce812f03da84e9a4894956efb89c3b7bcca5dbd4f91" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:e88c78e5b08ae9303aa15da43b68b44287ecbec16d898d9fad6998832fe626a5" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7d8769bdf3200ca16a92f14df404c3370171ac3732996528a8973d753eac562f" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.9.1%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:0c784b600959ec70ee01cb23e8bc870a0e0475af30378ff5e39f4abed8b7c1cc" }, ] [[package]] @@ -1621,6 +1662,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, ] +[[package]] +name = "uvicorn" +version = "0.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/d1/8f3c683c9561a4e6689dd3b1d345c815f10f86acd044ee1fb9a4dcd0b8c5/uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea", size = 81761, upload-time = "2025-12-21T14:16:22.45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/d8/2083a1daa7439a66f3a48589a57d576aa117726762618f6bb09fe3798796/uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee", size = 68502, upload-time = "2025-12-21T14:16:21.041Z" }, +] + [[package]] name = "websockets" version = "15.0.1" From c06e9d782e99839f16a38baea2c3eaac7beb145f Mon Sep 17 00:00:00 2001 From: Trey Date: Sun, 1 Feb 2026 14:31:51 -0800 Subject: [PATCH 2/2] Address feedback --- bot2/src/bot/service/__main__.py | 9 +- bot2/tests/unit/service/test_bot_service.py | 230 +++++++++++++++++++- 2 files changed, 229 insertions(+), 10 deletions(-) diff --git a/bot2/src/bot/service/__main__.py b/bot2/src/bot/service/__main__.py index 0f68dbb..8d7d031 100644 --- a/bot2/src/bot/service/__main__.py +++ b/bot2/src/bot/service/__main__.py @@ -31,7 +31,14 @@ def main() -> None: logger = logging.getLogger(__name__) # Load configuration from environment - port = int(os.environ.get("BOT_SERVICE_PORT", "8080")) + port_str = os.environ.get("BOT_SERVICE_PORT", "8080") + try: + port = int(port_str) + except ValueError: + logger.error( + f"Invalid BOT_SERVICE_PORT value: '{port_str}'. Must be an integer." + ) + raise host = os.environ.get("BOT_SERVICE_HOST", "0.0.0.0") http_url = os.environ.get("GAME_SERVER_HTTP_URL", "http://localhost:4000") ws_url = os.environ.get("GAME_SERVER_WS_URL", "ws://localhost:4000/ws") diff --git a/bot2/tests/unit/service/test_bot_service.py b/bot2/tests/unit/service/test_bot_service.py index db0db60..16e3a72 100644 --- a/bot2/tests/unit/service/test_bot_service.py +++ b/bot2/tests/unit/service/test_bot_service.py @@ -383,12 +383,29 @@ def test_models_route_not_interpreted_as_bot_id( class TestBotManagerNotInitialized: """Tests for BotManager not initialized error.""" - def test_spawn_bot_manager_not_initialized(self) -> None: - """POST /bots/spawn raises RuntimeError when BotManager not initialized.""" - # Clear the global manager + @pytest.fixture(autouse=True) + def clear_manager_state(self) -> Generator[None, None, None]: + """Clear and restore manager state for each test.""" + # Import module to access internal state + import bot.service.bot_service as bot_service_module + + # Save original state + original_manager = bot_service_module._bot_manager + original_overrides = app.dependency_overrides.copy() + + # Clear for test set_bot_manager(None) # type: ignore[arg-type] app.dependency_overrides.clear() + yield + + # Restore original state + bot_service_module._bot_manager = original_manager + app.dependency_overrides.clear() + app.dependency_overrides.update(original_overrides) + + def test_spawn_bot_manager_not_initialized(self) -> None: + """POST /bots/spawn raises RuntimeError when BotManager not initialized.""" with TestClient(app, raise_server_exceptions=False) as client: response = client.post( "/bots/spawn", @@ -402,9 +419,6 @@ def test_spawn_bot_manager_not_initialized(self) -> None: def test_list_bots_manager_not_initialized(self) -> None: """GET /bots raises RuntimeError when BotManager not initialized.""" - set_bot_manager(None) # type: ignore[arg-type] - app.dependency_overrides.clear() - with TestClient(app, raise_server_exceptions=False) as client: response = client.get("/bots") @@ -412,9 +426,6 @@ def test_list_bots_manager_not_initialized(self) -> None: def test_health_works_without_manager(self) -> None: """GET /health works even when BotManager not initialized.""" - set_bot_manager(None) # type: ignore[arg-type] - app.dependency_overrides.clear() - with TestClient(app) as client: response = client.get("/health") @@ -445,3 +456,204 @@ def test_special_characters_in_bot_id( assert response.status_code == 404 mock_manager.get_bot.assert_called_once_with("bot_abc-123_456") + + def test_url_encoded_space_in_bot_id( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """Handles URL-encoded space in bot_id.""" + mock_manager.get_bot.return_value = None + + # URL-encoded space (%20) decodes to a space character + response = client.get("/bots/bot%20abc") + + assert response.status_code == 404 + mock_manager.get_bot.assert_called_once_with("bot abc") + + def test_url_encoded_slash_returns_not_found( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """URL-encoded slash in bot_id is treated as path separator and returns 404.""" + mock_manager.get_bot.return_value = None + + # URL-encoded forward slash (%2F) - Starlette/ASGI decodes this as a path separator + # which means /bots/bot%2F123 becomes /bots/bot/123 and doesn't match our route + response = client.get("/bots/bot%2F123") + + # This returns 404 because /bots/bot/123 doesn't match any route + assert response.status_code == 404 + # get_bot is NOT called because the route doesn't match /bots/{bot_id} + mock_manager.get_bot.assert_not_called() + + def test_path_traversal_attempt_returns_not_found( + self, client: TestClient, mock_manager: MagicMock + ) -> None: + """Path traversal attempts don't match route and return 404.""" + mock_manager.get_bot.return_value = None + + # URL-encoded path traversal - this doesn't match our routes + response = client.get("/bots/..%2Fetc%2Fpasswd") + + # This returns 404 because the decoded path doesn't match any route + assert response.status_code == 404 + # get_bot is NOT called because the route doesn't match + mock_manager.get_bot.assert_not_called() + + +class TestMainEntry: + """Tests for the CLI entry point (__main__.py).""" + + @pytest.fixture + def mock_env_defaults(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Clear all bot service environment variables to test defaults.""" + env_vars = [ + "BOT_SERVICE_PORT", + "BOT_SERVICE_HOST", + "GAME_SERVER_HTTP_URL", + "GAME_SERVER_WS_URL", + "MODEL_REGISTRY_PATH", + "DEFAULT_DEVICE", + ] + for var in env_vars: + monkeypatch.delenv(var, raising=False) + + def test_main_uses_default_environment_values( + self, mock_env_defaults: None, monkeypatch: pytest.MonkeyPatch + ) -> None: + """main() uses default values when env vars not set.""" + from bot.service import __main__ as main_module + + captured_args: dict = {} + + def mock_uvicorn_run(app, host: str, port: int) -> None: + captured_args["host"] = host + captured_args["port"] = port + + monkeypatch.setattr("bot.service.__main__.uvicorn.run", mock_uvicorn_run) + monkeypatch.setattr( + "bot.service.__main__.BotManager", + lambda **kwargs: MagicMock(), + ) + monkeypatch.setattr("bot.service.__main__.set_bot_manager", lambda m: None) + + main_module.main() + + assert captured_args["host"] == "0.0.0.0" + assert captured_args["port"] == 8080 + + def test_main_uses_custom_environment_values( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """main() uses custom values from environment variables.""" + from bot.service import __main__ as main_module + + monkeypatch.setenv("BOT_SERVICE_PORT", "9000") + monkeypatch.setenv("BOT_SERVICE_HOST", "127.0.0.1") + monkeypatch.setenv("GAME_SERVER_HTTP_URL", "http://game:5000") + monkeypatch.setenv("GAME_SERVER_WS_URL", "ws://game:5000/ws") + monkeypatch.setenv("DEFAULT_DEVICE", "cuda") + + captured_args: dict = {} + captured_manager_args: dict = {} + + def mock_uvicorn_run(app, host: str, port: int) -> None: + captured_args["host"] = host + captured_args["port"] = port + + def mock_bot_manager(**kwargs) -> MagicMock: + captured_manager_args.update(kwargs) + return MagicMock() + + monkeypatch.setattr("bot.service.__main__.uvicorn.run", mock_uvicorn_run) + monkeypatch.setattr("bot.service.__main__.BotManager", mock_bot_manager) + monkeypatch.setattr("bot.service.__main__.set_bot_manager", lambda m: None) + + main_module.main() + + assert captured_args["host"] == "127.0.0.1" + assert captured_args["port"] == 9000 + assert captured_manager_args["http_url"] == "http://game:5000" + assert captured_manager_args["ws_url"] == "ws://game:5000/ws" + assert captured_manager_args["default_device"] == "cuda" + + def test_main_initializes_model_registry_when_path_provided( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: str + ) -> None: + """main() initializes ModelRegistry when MODEL_REGISTRY_PATH is set.""" + from bot.service import __main__ as main_module + + monkeypatch.setenv("MODEL_REGISTRY_PATH", str(tmp_path)) + + mock_registry = MagicMock() + mock_registry_class = MagicMock(return_value=mock_registry) + captured_manager_args: dict = {} + + def mock_bot_manager(**kwargs) -> MagicMock: + captured_manager_args.update(kwargs) + return MagicMock() + + monkeypatch.setattr("bot.service.__main__.uvicorn.run", lambda *a, **kw: None) + monkeypatch.setattr("bot.service.__main__.ModelRegistry", mock_registry_class) + monkeypatch.setattr("bot.service.__main__.BotManager", mock_bot_manager) + monkeypatch.setattr("bot.service.__main__.set_bot_manager", lambda m: None) + + main_module.main() + + mock_registry_class.assert_called_once_with(str(tmp_path)) + assert captured_manager_args["registry"] is mock_registry + + def test_main_logs_warning_when_registry_path_not_set( + self, mock_env_defaults: None, monkeypatch: pytest.MonkeyPatch, caplog + ) -> None: + """main() logs warning when MODEL_REGISTRY_PATH is not set.""" + from bot.service import __main__ as main_module + + monkeypatch.setattr("bot.service.__main__.uvicorn.run", lambda *a, **kw: None) + monkeypatch.setattr( + "bot.service.__main__.BotManager", + lambda **kwargs: MagicMock(), + ) + monkeypatch.setattr("bot.service.__main__.set_bot_manager", lambda m: None) + + with caplog.at_level("WARNING"): + main_module.main() + + assert any( + "MODEL_REGISTRY_PATH not set" in record.message for record in caplog.records + ) + + def test_main_passes_none_registry_when_path_not_set( + self, mock_env_defaults: None, monkeypatch: pytest.MonkeyPatch + ) -> None: + """main() passes None registry when MODEL_REGISTRY_PATH is not set.""" + from bot.service import __main__ as main_module + + captured_manager_args: dict = {} + + def mock_bot_manager(**kwargs) -> MagicMock: + captured_manager_args.update(kwargs) + return MagicMock() + + monkeypatch.setattr("bot.service.__main__.uvicorn.run", lambda *a, **kw: None) + monkeypatch.setattr("bot.service.__main__.BotManager", mock_bot_manager) + monkeypatch.setattr("bot.service.__main__.set_bot_manager", lambda m: None) + + main_module.main() + + assert captured_manager_args["registry"] is None + + def test_main_invalid_port_raises_value_error( + self, monkeypatch: pytest.MonkeyPatch, caplog + ) -> None: + """main() raises ValueError for invalid port value and logs error.""" + from bot.service import __main__ as main_module + + monkeypatch.setenv("BOT_SERVICE_PORT", "invalid_port") + + with caplog.at_level("ERROR"): + with pytest.raises(ValueError, match="invalid literal"): + main_module.main() + + assert any( + "Invalid BOT_SERVICE_PORT value" in record.message + for record in caplog.records + )