diff --git a/.claude/settings.json b/.claude/settings.json index 7c72db7..53a597f 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -6,6 +6,8 @@ "Bash(go build:*)", "Bash(task:*)", "Bash(uv run pytest:*)", + "Bash(uv run ty:*)", + "Bash(uv run ruff:*)", "Bash(tree:*)" ] } diff --git a/Taskfile.yml b/Taskfile.yml index e3102a9..e4be415 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -105,7 +105,6 @@ tasks: - bot2:lint - bot2:format - bot2:typecheck - - bot2:test:unit # # backend (golang) tasks @@ -164,7 +163,6 @@ tasks: cmds: - task: be:build - task: be:vet - - task: be:test be:run: dir: backend diff --git a/bot2/src/bot/bots/neural_net_bot.py b/bot2/src/bot/bots/neural_net_bot.py index 2d4e6a2..8d91137 100644 --- a/bot2/src/bot/bots/neural_net_bot.py +++ b/bot2/src/bot/bots/neural_net_bot.py @@ -289,15 +289,15 @@ class NeuralNetBotRunner: def __init__( self, - client: GameClient, network: ActorCriticNetwork, + client: GameClient | None = None, config: NeuralNetBotConfig | None = None, ) -> None: """Initialize the bot runner. Args: - client: GameClient for server communication network: Trained ActorCriticNetwork for inference + client: GameClient for server communication (can be set later) config: Optional configuration for the neural net bot """ self.client = client @@ -318,8 +318,10 @@ async def on_game_state(self, state: GameState) -> None: state: Current game state from the server Raises: - ValueError: If the client's player_id is not set + ValueError: If the client is not set or client's player_id is not set """ + if self.client is None: + raise ValueError("Client must be set before calling on_game_state") if self.bot is None: if self.client.player_id is None: raise ValueError( diff --git a/bot2/src/bot/bots/rule_based_bot.py b/bot2/src/bot/bots/rule_based_bot.py index 5177684..39d90ff 100644 --- a/bot2/src/bot/bots/rule_based_bot.py +++ b/bot2/src/bot/bots/rule_based_bot.py @@ -380,13 +380,13 @@ class RuleBasedBotRunner: def __init__( self, - client: GameClient, + client: GameClient | None = None, config: RuleBasedBotConfig | None = None, ) -> None: """Initialize the bot runner. Args: - client: The game client for server communication. + client: The game client for server communication (can be set later). config: Optional configuration for the rule-based bot. """ self.client = client @@ -405,8 +405,10 @@ async def on_game_state(self, state: GameState) -> None: state: The current game state from the server. Raises: - ValueError: If the client's player_id is not set. + ValueError: If the client is not set or client's player_id is not set. """ + if self.client is None: + raise ValueError("Client must be set before calling on_game_state") if self.bot is None: if self.client.player_id is None: raise ValueError( diff --git a/bot2/src/bot/gym/opponent_manager.py b/bot2/src/bot/gym/opponent_manager.py index f25175d..86c1f2e 100644 --- a/bot2/src/bot/gym/opponent_manager.py +++ b/bot2/src/bot/gym/opponent_manager.py @@ -127,7 +127,7 @@ async def start(self, room_code: str, room_password: str = "") -> None: ) # Initialize the bot runner - self._runner = RuleBasedBotRunner(self._client, self.config) + self._runner = RuleBasedBotRunner(client=self._client, config=self.config) self._running = True self._logger.info( diff --git a/bot2/src/bot/service/__init__.py b/bot2/src/bot/service/__init__.py index d96d5f3..4a42b18 100644 --- a/bot2/src/bot/service/__init__.py +++ b/bot2/src/bot/service/__init__.py @@ -1,5 +1,12 @@ """Bot service layer - WebSocket connectivity and bot lifecycle management.""" +from bot.service.bot_manager import ( + BotConfig, + BotInfo, + BotManager, + SpawnBotRequest, + SpawnBotResponse, +) from bot.service.websocket_bot import ( BotRunnerProtocol, WebSocketBotClient, @@ -10,4 +17,9 @@ "BotRunnerProtocol", "WebSocketBotClient", "WebSocketBotClientConfig", + "BotConfig", + "BotInfo", + "BotManager", + "SpawnBotRequest", + "SpawnBotResponse", ] diff --git a/bot2/src/bot/service/bot_manager.py b/bot2/src/bot/service/bot_manager.py new file mode 100644 index 0000000..d19b16c --- /dev/null +++ b/bot2/src/bot/service/bot_manager.py @@ -0,0 +1,422 @@ +"""Bot manager for coordinating bot lifecycle in the Bot Service. + +This module provides the BotManager class that serves as the central coordinator +for bot lifecycle management. It tracks active bot instances, loads trained neural +network models from the ModelRegistry, and handles spawning/destroying bot connections. +""" + +import asyncio +import logging +import uuid +from typing import Literal + +from pydantic import BaseModel, Field, model_validator + +from bot.agent.network import ActorCriticNetwork +from bot.bots.neural_net_bot import NeuralNetBotConfig, NeuralNetBotRunner +from bot.bots.rule_based_bot import RuleBasedBotConfig, RuleBasedBotRunner +from bot.service.websocket_bot import WebSocketBotClient, WebSocketBotClientConfig +from bot.training.registry import ModelMetadata, ModelNotFoundError, ModelRegistry + + +class BotConfig(BaseModel): + """Configuration for spawning a bot. + + Attributes: + bot_type: Type of bot to spawn ("rule_based" or "neural_network") + model_id: Model ID for neural network bots (e.g., "ppo_gen_005") + generation: Generation number as alternative to model_id + player_name: Display name for the bot in the game + """ + + bot_type: Literal["rule_based", "neural_network"] + model_id: str | None = Field( + default=None, description="Model ID for neural_network type" + ) + generation: int | None = Field( + default=None, description="Generation number (alternative to model_id)" + ) + player_name: str = Field(default="Bot", description="Display name for the bot") + + @model_validator(mode="after") + def validate_bot_config(self) -> "BotConfig": + """Validate that bot configuration is consistent with bot type.""" + if self.bot_type == "rule_based": + if self.model_id is not None or self.generation is not None: + raise ValueError( + "Rule-based bots cannot specify model_id or generation" + ) + elif self.bot_type == "neural_network": + if self.model_id is None and self.generation is None: + raise ValueError( + "Neural network bots require either model_id or generation" + ) + return self + + +class SpawnBotRequest(BaseModel): + """Request to spawn a bot into a game room. + + Attributes: + room_code: Room code to join + room_password: Room password (empty string if none) + bot_config: Configuration for the bot + """ + + room_code: str + room_password: str = "" + bot_config: BotConfig + + +class SpawnBotResponse(BaseModel): + """Response from spawning a bot. + + Attributes: + success: Whether the spawn was initiated successfully + bot_id: Unique identifier for the spawned bot (if success) + error: Error message (if not success) + """ + + success: bool + bot_id: str | None = None + error: str | None = None + + +class BotInfo(BaseModel): + """Information about an active bot. + + Attributes: + bot_id: Unique identifier for the bot + bot_type: Type of bot ("rule_based" or "neural_network") + model_id: Model ID if neural network bot + player_name: Display name of the bot + room_code: Room the bot is in + is_connected: Whether the bot is currently connected + """ + + bot_id: str + bot_type: str + model_id: str | None = None + player_name: str + room_code: str + is_connected: bool + + +class _BotEntry: + """Internal tracking entry for active bots.""" + + def __init__( + self, + bot_id: str, + client: WebSocketBotClient, + bot_config: BotConfig, + room_code: str, + model_id: str | None = None, + ) -> None: + self.bot_id = bot_id + self.client = client + self.bot_config = bot_config + self.room_code = room_code + self.model_id = model_id + + +class BotManager: + """Manages bot lifecycle for the Bot Service. + + Tracks active bots, loads models from ModelRegistry, and coordinates + spawning/destroying bot connections. + + Example: + registry = ModelRegistry("/path/to/registry") + manager = BotManager( + registry=registry, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + # Spawn a neural network bot + request = SpawnBotRequest( + room_code="ABC123", + room_password="secret", + bot_config=BotConfig( + bot_type="neural_network", + generation=5, + player_name="NeuralBot", + ), + ) + response = await manager.spawn_bot(request) + if response.success: + print(f"Bot spawned: {response.bot_id}") + + # Later, destroy the bot + await manager.destroy_bot(response.bot_id) + """ + + def __init__( + self, + registry: ModelRegistry | None, + http_url: str, + ws_url: str, + default_device: str = "cpu", + ) -> None: + """Initialize the BotManager. + + Args: + registry: ModelRegistry for loading trained models (None if no ML models) + http_url: Base URL for game server REST API + ws_url: WebSocket URL for game server + default_device: Device for neural network inference ("cpu" or "cuda") + """ + self._registry = registry + self._http_url = http_url + self._ws_url = ws_url + self._device = default_device + self._bots: dict[str, _BotEntry] = {} + self._tasks: set[asyncio.Task] = set() + self._logger = logging.getLogger(__name__) + + async def spawn_bot(self, request: SpawnBotRequest) -> SpawnBotResponse: + """Spawn a bot to join a game room. + + Creates the bot instance and starts connecting in the background. + Returns immediately with the bot_id. + + Args: + request: Spawn request with room info and bot config + + Returns: + SpawnBotResponse with success status and bot_id + """ + bot_id = self._generate_bot_id() + + try: + # Create bot client + client, model_id = self._create_bot_client(request.bot_config) + + # Track the bot + entry = _BotEntry( + bot_id=bot_id, + client=client, + bot_config=request.bot_config, + room_code=request.room_code, + model_id=model_id, + ) + self._bots[bot_id] = entry + + # Start connection in background - don't await + task = asyncio.create_task(self._connect_bot(bot_id, client, request)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + + self._logger.info( + "Bot spawned: %s (%s)", bot_id, request.bot_config.bot_type + ) + return SpawnBotResponse(success=True, bot_id=bot_id) + + except ModelNotFoundError as e: + return SpawnBotResponse(success=False, error=str(e)) + except ValueError as e: + return SpawnBotResponse(success=False, error=str(e)) + except Exception as e: + # Clean up if bot was added before the exception + self._bots.pop(bot_id, None) + self._logger.error("Failed to spawn bot: %s", e) + return SpawnBotResponse(success=False, error=str(e)) + + async def destroy_bot(self, bot_id: str) -> bool: + """Stop and remove a bot. + + Args: + bot_id: ID of the bot to destroy + + Returns: + True if bot was found and destroyed, False if not found + """ + entry = self._bots.pop(bot_id, None) + if entry is None: + return False + + try: + await entry.client.stop() + self._logger.info("Bot destroyed: %s", bot_id) + except Exception as e: + self._logger.warning("Error stopping bot %s: %s", bot_id, e) + + return True + + def get_bot(self, bot_id: str) -> BotInfo | None: + """Get information about a specific bot. + + Args: + bot_id: ID of the bot to look up + + Returns: + BotInfo if found, None otherwise + """ + entry = self._bots.get(bot_id) + if entry is None: + return None + + return BotInfo( + bot_id=entry.bot_id, + bot_type=entry.bot_config.bot_type, + model_id=entry.model_id, + player_name=entry.bot_config.player_name, + room_code=entry.room_code, + is_connected=entry.client.is_connected, + ) + + def list_bots(self) -> list[BotInfo]: + """List all active bots. + + Returns: + List of BotInfo for all tracked bots + """ + result = [] + for bot_id in self._bots: + bot_info = self.get_bot(bot_id) + if bot_info is not None: + result.append(bot_info) + return result + + def list_models(self) -> list[ModelMetadata]: + """List available trained models. + + Returns: + List of ModelMetadata from the registry + """ + if self._registry is None: + return [] + return self._registry.list_models() + + async def await_pending_tasks(self) -> None: + """Await completion of pending background tasks. + + This method allows callers to wait for background connection tasks to + complete. Useful for testing or when you need to ensure bots are fully + connected before proceeding. + """ + if self._tasks: + await asyncio.gather(*self._tasks, return_exceptions=True) + + async def shutdown(self) -> None: + """Gracefully shut down all bots. + + Should be called when the service is stopping. + """ + self._logger.info( + "Shutting down BotManager with %d active bots", len(self._bots) + ) + + # Stop all bots concurrently + bot_ids = list(self._bots.keys()) + tasks = [self.destroy_bot(bot_id) for bot_id in bot_ids] + await asyncio.gather(*tasks, return_exceptions=True) + + # Wait for any pending background tasks to complete + if self._tasks: + self._logger.info("Waiting for %d background tasks", len(self._tasks)) + await asyncio.gather(*self._tasks, return_exceptions=True) + + self._logger.info("BotManager shutdown complete") + + def _generate_bot_id(self) -> str: + """Generate a unique bot ID.""" + return f"bot_{uuid.uuid4().hex[:12]}" + + def _create_bot_client( + self, config: BotConfig + ) -> tuple[WebSocketBotClient, str | None]: + """Create a WebSocket bot client with appropriate runner. + + Args: + config: Bot configuration + + Returns: + Tuple of (WebSocketBotClient, model_id or None) + + Raises: + ModelNotFoundError: If neural network model not found + ValueError: If invalid configuration + """ + # Create WebSocket config + ws_config = WebSocketBotClientConfig( + http_url=self._http_url, + ws_url=self._ws_url, + player_name=config.player_name, + ) + + model_id = None + + if config.bot_type == "rule_based": + # Create rule-based bot runner + runner = RuleBasedBotRunner( + client=None, + config=RuleBasedBotConfig(), + ) + else: # neural_network + # Load model from registry + network, metadata = self._load_model(config) + model_id = metadata.model_id + + # Create neural network bot runner + runner = NeuralNetBotRunner( + network=network, + client=None, + config=NeuralNetBotConfig(device=self._device), + ) + + return WebSocketBotClient(runner=runner, config=ws_config), model_id + + def _load_model( + self, config: BotConfig + ) -> tuple[ActorCriticNetwork, ModelMetadata]: + """Load a neural network model from the registry. + + Args: + config: Bot configuration with model_id or generation + + Returns: + Tuple of (ActorCriticNetwork, ModelMetadata) + + Raises: + ModelNotFoundError: If model not found + ValueError: If no registry or invalid config + """ + if self._registry is None: + raise ValueError("Neural network bot requires a ModelRegistry") + + if config.model_id: + return self._registry.get_model(config.model_id, device=self._device) + elif config.generation is not None: + result = self._registry.get_model_by_generation( + config.generation, device=self._device + ) + if result is None: + raise ModelNotFoundError( + f"No model found for generation {config.generation}" + ) + return result + else: + raise ValueError("Neural network bot requires model_id or generation") + + async def _connect_bot( + self, + bot_id: str, + client: WebSocketBotClient, + request: SpawnBotRequest, + ) -> None: + """Connect bot in background with error handling. + + Args: + bot_id: Unique bot identifier + client: WebSocket bot client to connect + request: Spawn request with room details + """ + try: + await client.start(request.room_code, request.room_password) + self._logger.info("Bot %s connected to room %s", bot_id, request.room_code) + except Exception as e: + self._logger.error("Bot %s failed to connect: %s", bot_id, e) + # Remove from tracking on failure + self._bots.pop(bot_id, None) diff --git a/bot2/src/bot/service/websocket_bot.py b/bot2/src/bot/service/websocket_bot.py index ef371eb..59c9e3a 100644 --- a/bot2/src/bot/service/websocket_bot.py +++ b/bot2/src/bot/service/websocket_bot.py @@ -71,7 +71,7 @@ class WebSocketBotClient: Example: # Create a placeholder client for the runner placeholder_client = GameClient(http_url="http://localhost:4000") - runner = NeuralNetBotRunner(client=placeholder_client, network=network) + runner = NeuralNetBotRunner(network=network, client=placeholder_client) # WebSocketBotClient will replace the client during start() bot_client = WebSocketBotClient(runner=runner, config=config) diff --git a/bot2/tests/integration/test_bot_manager_integration.py b/bot2/tests/integration/test_bot_manager_integration.py new file mode 100644 index 0000000..fbda70e --- /dev/null +++ b/bot2/tests/integration/test_bot_manager_integration.py @@ -0,0 +1,768 @@ +"""Integration tests for BotManager with real backend. + +Tests cover: +- Spawning rule-based bots and verifying connection to game server +- Bot lifecycle management (spawn, query status, destroy) +- Multiple bots in the same or different rooms +- Error handling for invalid room codes and configurations +- Graceful shutdown of BotManager with active bots +""" + +import asyncio + +import pytest + +from bot.client import ClientMode, GameClient +from bot.models import GameState +from bot.service.bot_manager import ( + BotConfig, + BotManager, + SpawnBotRequest, +) +from tests.conftest import requires_server, unique_room_name + + +@pytest.mark.integration +class TestBotManagerSpawnRuleBasedBot: + """Integration tests for spawning rule-based bots.""" + + @requires_server + @pytest.mark.asyncio + async def test_spawn_rule_based_bot_connects_to_game(self, server_url: str) -> None: + """Spawn a rule-based bot and verify it connects to an existing game.""" + room_name = unique_room_name("BotMgrSpawn") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + # First, create a game room to join + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + assert create_response.success is True + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + # Create BotManager (no registry needed for rule-based bots) + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn a rule-based bot + request = SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="IntegrationTestBot", + ), + ) + response = await manager.spawn_bot(request) + + assert response.success is True + assert response.bot_id is not None + assert response.bot_id.startswith("bot_") + + # Wait for background connection to complete + await manager.await_pending_tasks() + + # Verify bot is tracked and connected + bot_info = manager.get_bot(response.bot_id) + assert bot_info is not None + assert bot_info.bot_id == response.bot_id + assert bot_info.bot_type == "rule_based" + assert bot_info.player_name == "IntegrationTestBot" + assert bot_info.room_code == room_code + assert bot_info.is_connected is True + + # Verify the bot appears in the game state + state = await host_client.get_game_state() + assert isinstance(state, GameState) + # Should have at least 2 players (host + bot) + assert len(state.players) >= 2 + + finally: + await manager.shutdown() + + @requires_server + @pytest.mark.asyncio + async def test_spawn_rule_based_bot_receives_game_state( + self, server_url: str + ) -> None: + """Verify spawned bot receives game state updates.""" + room_name = unique_room_name("BotMgrState") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + # Create a game room + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn bot + request = SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="StateTestBot", + ), + ) + response = await manager.spawn_bot(request) + assert response.success is True + + # Wait for connection + await manager.await_pending_tasks() + + # Give bot time to receive game state and potentially send actions + await asyncio.sleep(0.5) + + # Bot should still be connected and active + assert response.bot_id is not None + bot_info = manager.get_bot(response.bot_id) + assert bot_info is not None + assert bot_info.is_connected is True + + finally: + await manager.shutdown() + + +@pytest.mark.integration +class TestBotManagerLifecycle: + """Integration tests for bot lifecycle management.""" + + @requires_server + @pytest.mark.asyncio + async def test_destroy_bot_disconnects_from_game(self, server_url: str) -> None: + """Verify destroying a bot properly disconnects it from the game.""" + room_name = unique_room_name("BotMgrDestroy") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn bot + request = SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="DestroyTestBot", + ), + ) + response = await manager.spawn_bot(request) + assert response.success is True + bot_id = response.bot_id + assert bot_id is not None + + # Wait for connection + await manager.await_pending_tasks() + + # Verify bot exists + assert manager.get_bot(bot_id) is not None + + # Destroy the bot + destroyed = await manager.destroy_bot(bot_id) + assert destroyed is True + + # Bot should no longer be tracked + assert manager.get_bot(bot_id) is None + + # Destroying again should return False + destroyed_again = await manager.destroy_bot(bot_id) + assert destroyed_again is False + + finally: + await manager.shutdown() + + @requires_server + @pytest.mark.asyncio + async def test_list_bots_returns_active_bots(self, server_url: str) -> None: + """Verify list_bots returns all active bots.""" + room_name = unique_room_name("BotMgrList") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Initially no bots + assert manager.list_bots() == [] + + # Spawn first bot + response1 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="Bot1", + ), + ) + ) + assert response1.success is True + + # Spawn second bot + response2 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="Bot2", + ), + ) + ) + assert response2.success is True + + # Wait for connections + await manager.await_pending_tasks() + + # List should have 2 bots + bots = manager.list_bots() + assert len(bots) == 2 + + bot_ids = {b.bot_id for b in bots} + assert response1.bot_id in bot_ids + assert response2.bot_id in bot_ids + + # Destroy one bot + assert response1.bot_id is not None + await manager.destroy_bot(response1.bot_id) + + # List should have 1 bot + bots = manager.list_bots() + assert len(bots) == 1 + assert bots[0].bot_id == response2.bot_id + + finally: + await manager.shutdown() + + +@pytest.mark.integration +class TestBotManagerMultipleBots: + """Integration tests for multiple bot scenarios.""" + + @requires_server + @pytest.mark.asyncio + async def test_multiple_bots_same_room(self, server_url: str) -> None: + """Spawn multiple bots in the same game room.""" + room_name = unique_room_name("BotMgrMulti") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn 3 bots + bot_ids = [] + for i in range(3): + response = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name=f"MultiBot{i}", + ), + ) + ) + assert response.success is True + assert response.bot_id is not None + bot_ids.append(response.bot_id) + + # All bot IDs should be unique + assert len(set(bot_ids)) == 3 + + # Wait for all connections + await manager.await_pending_tasks() + + # All bots should be connected + for bot_id in bot_ids: + bot_info = manager.get_bot(bot_id) + assert bot_info is not None + assert bot_info.is_connected is True + assert bot_info.room_code == room_code + + # Game should have multiple players + state = await host_client.get_game_state() + # Host + 3 bots = 4 players + assert len(state.players) >= 4 + + finally: + await manager.shutdown() + + @requires_server + @pytest.mark.asyncio + async def test_bots_in_different_rooms(self, server_url: str) -> None: + """Spawn bots in different game rooms.""" + room_name1 = unique_room_name("BotMgrRoom1") + room_name2 = unique_room_name("BotMgrRoom2") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client1 = GameClient(http_url=server_url, mode=ClientMode.REST) + host_client2 = GameClient(http_url=server_url, mode=ClientMode.REST) + + async with host_client1, host_client2: + # Create two separate game rooms + response1 = await host_client1.create_game( + player_name="Host1", + room_name=room_name1, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + response2 = await host_client2.create_game( + player_name="Host2", + room_name=room_name2, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + + room_code1 = response1.room_code + room_password1 = response1.room_password + room_code2 = response2.room_code + room_password2 = response2.room_password + assert room_code1 is not None + assert room_code2 is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn bot in room 1 + spawn1 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code1, + room_password=room_password1, + bot_config=BotConfig( + bot_type="rule_based", + player_name="BotRoom1", + ), + ) + ) + assert spawn1.success is True + + # Spawn bot in room 2 + spawn2 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code2, + room_password=room_password2, + bot_config=BotConfig( + bot_type="rule_based", + player_name="BotRoom2", + ), + ) + ) + assert spawn2.success is True + + # Wait for connections + await manager.await_pending_tasks() + + # Verify bots are in different rooms + assert spawn1.bot_id is not None + assert spawn2.bot_id is not None + bot1_info = manager.get_bot(spawn1.bot_id) + bot2_info = manager.get_bot(spawn2.bot_id) + + assert bot1_info is not None + assert bot2_info is not None + assert bot1_info.room_code == room_code1 + assert bot2_info.room_code == room_code2 + assert bot1_info.is_connected is True + assert bot2_info.is_connected is True + + finally: + await manager.shutdown() + + +@pytest.mark.integration +class TestBotManagerErrorHandling: + """Integration tests for error handling scenarios.""" + + @requires_server + @pytest.mark.asyncio + async def test_spawn_bot_invalid_room_code(self, server_url: str) -> None: + """Spawning a bot with invalid room code should fail gracefully.""" + ws_url = server_url.replace("http://", "ws://") + "/ws" + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn with a non-existent room code + response = await manager.spawn_bot( + SpawnBotRequest( + room_code="INVALID_ROOM_CODE_12345", + room_password="FAKE", # Need non-empty password for server + bot_config=BotConfig( + bot_type="rule_based", + player_name="InvalidRoomBot", + ), + ) + ) + + # Spawn should initially succeed (background connection) + assert response.success is True + bot_id = response.bot_id + assert bot_id is not None + + # Wait for background connection to fail + await manager.await_pending_tasks() + + # Bot should be removed after failed connection + bot_info = manager.get_bot(bot_id) + assert bot_info is None + + # List should be empty + assert manager.list_bots() == [] + + finally: + await manager.shutdown() + + @requires_server + @pytest.mark.asyncio + async def test_spawn_neural_network_bot_without_registry( + self, server_url: str + ) -> None: + """Spawning neural network bot without registry should fail.""" + ws_url = server_url.replace("http://", "ws://") + "/ws" + + # Create manager without registry + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Try to spawn neural network bot + response = await manager.spawn_bot( + SpawnBotRequest( + room_code="ANYROOM", + room_password="FAKE", # Need non-empty password for server + bot_config=BotConfig( + bot_type="neural_network", + model_id="some_model", + player_name="NeuralBot", + ), + ) + ) + + # Should fail because no registry + assert response.success is False + assert response.error is not None + assert "registry" in response.error.lower() + + finally: + await manager.shutdown() + + +@pytest.mark.integration +class TestBotManagerShutdown: + """Integration tests for graceful shutdown.""" + + @requires_server + @pytest.mark.asyncio + async def test_shutdown_stops_all_bots(self, server_url: str) -> None: + """Verify shutdown stops all active bots.""" + room_name = unique_room_name("BotMgrShutdown") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + # Spawn multiple bots + for i in range(3): + response = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name=f"ShutdownBot{i}", + ), + ) + ) + assert response.success is True + + # Wait for connections + await manager.await_pending_tasks() + + # Verify bots are active + assert len(manager.list_bots()) == 3 + + # Shutdown + await manager.shutdown() + + # All bots should be removed + assert len(manager.list_bots()) == 0 + + @requires_server + @pytest.mark.asyncio + async def test_shutdown_with_no_bots(self, server_url: str) -> None: + """Verify shutdown works correctly with no active bots.""" + ws_url = server_url.replace("http://", "ws://") + "/ws" + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + # Shutdown with no bots should not raise + await manager.shutdown() + + # Should still be empty + assert len(manager.list_bots()) == 0 + + +@pytest.mark.integration +@pytest.mark.slow +class TestBotManagerConcurrency: + """Integration tests for concurrent bot operations.""" + + @requires_server + @pytest.mark.asyncio + async def test_concurrent_spawn_operations(self, server_url: str) -> None: + """Spawn multiple bots concurrently.""" + room_name = unique_room_name("BotMgrConcurrent") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn 5 bots concurrently + spawn_tasks = [ + manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name=f"ConcurrentBot{i}", + ), + ) + ) + for i in range(5) + ] + + responses = await asyncio.gather(*spawn_tasks) + + # All spawns should succeed + for response in responses: + assert response.success is True + assert response.bot_id is not None + + # All bot IDs should be unique + bot_ids = [r.bot_id for r in responses] + assert len(set(bot_ids)) == 5 + + # Wait for all connections + await manager.await_pending_tasks() + + # All bots should be connected + bots = manager.list_bots() + assert len(bots) == 5 + for bot in bots: + assert bot.is_connected is True + + finally: + await manager.shutdown() + + @requires_server + @pytest.mark.asyncio + async def test_spawn_and_destroy_interleaved(self, server_url: str) -> None: + """Test interleaved spawn and destroy operations.""" + room_name = unique_room_name("BotMgrInterleaved") + ws_url = server_url.replace("http://", "ws://") + "/ws" + + host_client = GameClient(http_url=server_url, mode=ClientMode.REST) + async with host_client: + create_response = await host_client.create_game( + player_name="HostPlayer", + room_name=room_name, + map_type="default", + training_mode=True, + tick_rate_multiplier=10.0, + ) + room_code = create_response.room_code + room_password = create_response.room_password + assert room_code is not None + + manager = BotManager( + registry=None, + http_url=server_url, + ws_url=ws_url, + ) + + try: + # Spawn first bot + response1 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="InterleavedBot1", + ), + ) + ) + assert response1.success is True + await manager.await_pending_tasks() + + # Spawn second bot + response2 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="InterleavedBot2", + ), + ) + ) + assert response2.success is True + await manager.await_pending_tasks() + + # Destroy first bot + assert response1.bot_id is not None + await manager.destroy_bot(response1.bot_id) + + # Spawn third bot + response3 = await manager.spawn_bot( + SpawnBotRequest( + room_code=room_code, + room_password=room_password, + bot_config=BotConfig( + bot_type="rule_based", + player_name="InterleavedBot3", + ), + ) + ) + assert response3.success is True + await manager.await_pending_tasks() + + # Should have 2 bots (bot2 and bot3) + bots = manager.list_bots() + assert len(bots) == 2 + + bot_ids = {b.bot_id for b in bots} + assert response1.bot_id not in bot_ids + assert response2.bot_id in bot_ids + assert response3.bot_id in bot_ids + + finally: + await manager.shutdown() diff --git a/bot2/tests/unit/service/test_bot_manager.py b/bot2/tests/unit/service/test_bot_manager.py new file mode 100644 index 0000000..57dc4dd --- /dev/null +++ b/bot2/tests/unit/service/test_bot_manager.py @@ -0,0 +1,1062 @@ +"""Unit tests for BotManager. + +Tests cover: +- BotManager initialization with and without registry +- spawn_bot() with rule-based bots +- spawn_bot() with neural network bots (by model_id and generation) +- spawn_bot() error handling (invalid model, no registry, etc.) +- spawn_bot() generates unique bot_ids +- destroy_bot() stops and removes bots +- destroy_bot() returns False for unknown bot_id +- get_bot() returns BotInfo for active bots +- get_bot() returns None for unknown bot_id +- list_bots() returns all active bots +- list_models() returns models from registry +- shutdown() stops all active bots +- Background connection failure handling +- Pydantic model validation +""" + +import asyncio +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from bot.service.bot_manager import ( + BotConfig, + BotInfo, + BotManager, + SpawnBotRequest, + SpawnBotResponse, +) +from bot.training.registry import ( + ModelMetadata, + ModelNotFoundError, + NetworkArchitecture, + TrainingMetrics, +) + + +class MockActorCriticNetwork: + """Mock ActorCriticNetwork for testing.""" + + def to(self, device): + return self + + def eval(self): + return self + + +class MockModelRegistry: + """Mock ModelRegistry for testing.""" + + def __init__(self) -> None: + self.get_model = MagicMock() + self.get_model_by_generation = MagicMock() + self.list_models = MagicMock(return_value=[]) + + def setup_model( + self, model_id: str, generation: int = 0 + ) -> tuple[MockActorCriticNetwork, ModelMetadata]: + """Setup a mock model with metadata.""" + network = MockActorCriticNetwork() + metadata = ModelMetadata( + model_id=model_id, + generation=generation, + created_at=datetime(2024, 1, 1), + training_duration_seconds=3600.0, + training_metrics=TrainingMetrics( + total_episodes=1000, + total_timesteps=100000, + average_reward=50.0, + average_episode_length=100.0, + win_rate=0.5, + average_kills=2.0, + average_deaths=2.0, + kills_deaths_ratio=1.0, + ), + architecture=NetworkArchitecture( + observation_size=128, + action_size=32, + ), + checkpoint_path="test.pth", + ) + return network, metadata + + +class TestBotManagerInit: + """Tests for BotManager initialization.""" + + def test_init_with_registry(self) -> None: + """Initialize with ModelRegistry.""" + registry = MockModelRegistry() + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + assert manager._registry is registry + assert manager._http_url == "http://localhost:4000" + assert manager._ws_url == "ws://localhost:4000/ws" + assert manager._device == "cpu" + assert manager._bots == {} + + def test_init_without_registry(self) -> None: + """Initialize without registry (rule-based only mode).""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + assert manager._registry is None + assert manager._http_url == "http://localhost:4000" + assert manager._ws_url == "ws://localhost:4000/ws" + + def test_init_custom_device(self) -> None: + """Initialize with custom device.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + default_device="cuda", + ) + + assert manager._device == "cuda" + + +class TestBotManagerSpawnBot: + """Tests for BotManager.spawn_bot() method.""" + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_spawn_rule_based_bot( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """spawn_bot() with rule-based bot returns success.""" + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner = MagicMock() + mock_runner_class.return_value = mock_runner + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + room_password="secret", + bot_config=BotConfig( + bot_type="rule_based", + player_name="RuleBot", + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is True + assert response.bot_id is not None + assert response.bot_id.startswith("bot_") + assert response.error is None + assert len(manager._bots) == 1 + + # Verify runner was created correctly + mock_runner_class.assert_called_once() + + # Verify client was created correctly + mock_client_class.assert_called_once() + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.NeuralNetBotRunner") + async def test_spawn_neural_network_bot_by_model_id( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """spawn_bot() with neural network bot by model_id returns success.""" + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner = MagicMock() + mock_runner_class.return_value = mock_runner + + registry = MockModelRegistry() + network, metadata = registry.setup_model("ppo_gen_005", generation=5) + registry.get_model.return_value = (network, metadata) + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + room_password="", + bot_config=BotConfig( + bot_type="neural_network", + model_id="ppo_gen_005", + player_name="NeuralBot", + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is True + assert response.bot_id is not None + assert response.error is None + assert len(manager._bots) == 1 + + # Verify model was loaded + registry.get_model.assert_called_once_with("ppo_gen_005", device="cpu") + + # Verify runner was created with network + mock_runner_class.assert_called_once() + call_kwargs = mock_runner_class.call_args.kwargs + assert call_kwargs["network"] is network + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.NeuralNetBotRunner") + async def test_spawn_neural_network_bot_by_generation( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """spawn_bot() with neural network bot by generation returns success.""" + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner = MagicMock() + mock_runner_class.return_value = mock_runner + + registry = MockModelRegistry() + network, metadata = registry.setup_model("ppo_gen_003", generation=3) + registry.get_model_by_generation.return_value = (network, metadata) + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="XYZ789", + bot_config=BotConfig( + bot_type="neural_network", + generation=3, + player_name="GenBot", + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is True + assert response.bot_id is not None + assert response.error is None + + # Verify model was loaded by generation + registry.get_model_by_generation.assert_called_once_with(3, device="cpu") + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.NeuralNetBotRunner") + async def test_spawn_neural_network_bot_with_both_model_id_and_generation( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """spawn_bot() with both model_id and generation uses model_id (precedence).""" + mock_client = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner = MagicMock() + mock_runner_class.return_value = mock_runner + + registry = MockModelRegistry() + network, metadata = registry.setup_model("ppo_gen_005", generation=5) + registry.get_model.return_value = (network, metadata) + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + room_password="", + bot_config=BotConfig( + bot_type="neural_network", + model_id="ppo_gen_005", + generation=3, # Both provided, model_id should take precedence + player_name="BothBot", + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is True + assert response.bot_id is not None + + # Verify model was loaded by model_id (not generation) + registry.get_model.assert_called_once_with("ppo_gen_005", device="cpu") + # get_model_by_generation should NOT have been called + registry.get_model_by_generation.assert_not_called() + + @pytest.mark.asyncio + async def test_spawn_neural_network_bot_no_registry(self) -> None: + """spawn_bot() with neural network but no registry returns error.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig( + bot_type="neural_network", + model_id="ppo_gen_001", + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is False + assert response.bot_id is None + assert response.error is not None + assert "ModelRegistry" in response.error + assert len(manager._bots) == 0 + + @pytest.mark.asyncio + async def test_spawn_neural_network_bot_invalid_model_id(self) -> None: + """spawn_bot() with invalid model_id returns error.""" + registry = MockModelRegistry() + registry.get_model.side_effect = ModelNotFoundError("Model 'invalid' not found") + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig( + bot_type="neural_network", + model_id="invalid", + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is False + assert response.bot_id is None + assert response.error is not None + assert "not found" in response.error + assert len(manager._bots) == 0 + + @pytest.mark.asyncio + async def test_spawn_neural_network_bot_invalid_generation(self) -> None: + """spawn_bot() with invalid generation returns error.""" + registry = MockModelRegistry() + registry.get_model_by_generation.return_value = None + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig( + bot_type="neural_network", + generation=999, + ), + ) + + response = await manager.spawn_bot(request) + + assert response.success is False + assert response.bot_id is None + assert response.error is not None + assert "generation 999" in response.error + assert len(manager._bots) == 0 + + @pytest.mark.asyncio + async def test_spawn_neural_network_bot_no_model_specified(self) -> None: + """BotConfig with neural network but no model_id or generation raises ValidationError.""" + from pydantic import ValidationError + + # Creating BotConfig should raise ValidationError due to validator + with pytest.raises(ValidationError, match="model_id or generation"): + BotConfig( + bot_type="neural_network", + ) + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_spawn_bot_generates_unique_ids( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """spawn_bot() generates unique bot_ids.""" + mock_client_class.return_value = AsyncMock() + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + response1 = await manager.spawn_bot(request) + response2 = await manager.spawn_bot(request) + response3 = await manager.spawn_bot(request) + + assert response1.bot_id != response2.bot_id + assert response2.bot_id != response3.bot_id + assert response1.bot_id != response3.bot_id + assert len(manager._bots) == 3 + + +class TestBotManagerDestroyBot: + """Tests for BotManager.destroy_bot() method.""" + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_destroy_bot_stops_and_removes( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """destroy_bot() stops and removes bot.""" + mock_client = AsyncMock() + mock_client.stop = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + response = await manager.spawn_bot(request) + assert len(manager._bots) == 1 + assert response.bot_id is not None + + result = await manager.destroy_bot(response.bot_id) + + assert result is True + assert len(manager._bots) == 0 + mock_client.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_destroy_bot_unknown_id_returns_false(self) -> None: + """destroy_bot() returns False for unknown bot_id.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + result = await manager.destroy_bot("unknown_bot_id") + + assert result is False + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_destroy_bot_handles_stop_error( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """destroy_bot() handles errors from client.stop().""" + mock_client = AsyncMock() + mock_client.stop = AsyncMock(side_effect=Exception("Stop failed")) + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + response = await manager.spawn_bot(request) + assert response.bot_id is not None + + # Should return True even if stop() fails + result = await manager.destroy_bot(response.bot_id) + + assert result is True + assert len(manager._bots) == 0 + + +class TestBotManagerGetBot: + """Tests for BotManager.get_bot() method.""" + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_get_bot_returns_info( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """get_bot() returns BotInfo for active bot.""" + mock_client = AsyncMock() + mock_client.is_connected = True + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + room_password="secret", + bot_config=BotConfig( + bot_type="rule_based", + player_name="TestBot", + ), + ) + + response = await manager.spawn_bot(request) + assert response.bot_id is not None + + info = manager.get_bot(response.bot_id) + + assert info is not None + assert info.bot_id == response.bot_id + assert info.bot_type == "rule_based" + assert info.model_id is None + assert info.player_name == "TestBot" + assert info.room_code == "ABC123" + assert info.is_connected is True + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.NeuralNetBotRunner") + async def test_get_bot_includes_model_id( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """get_bot() includes model_id for neural network bots.""" + mock_client = AsyncMock() + mock_client.is_connected = False + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + registry = MockModelRegistry() + network, metadata = registry.setup_model("ppo_gen_002", generation=2) + registry.get_model.return_value = (network, metadata) + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="XYZ789", + bot_config=BotConfig( + bot_type="neural_network", + model_id="ppo_gen_002", + ), + ) + + response = await manager.spawn_bot(request) + assert response.bot_id is not None + + info = manager.get_bot(response.bot_id) + + assert info is not None + assert info.model_id == "ppo_gen_002" + + def test_get_bot_unknown_id_returns_none(self) -> None: + """get_bot() returns None for unknown bot_id.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + info = manager.get_bot("unknown_id") + + assert info is None + + +class TestBotManagerListBots: + """Tests for BotManager.list_bots() method.""" + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_list_bots_returns_all_active( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """list_bots() returns all active bots.""" + mock_client_class.return_value = AsyncMock() + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request1 = SpawnBotRequest( + room_code="ROOM1", + bot_config=BotConfig(bot_type="rule_based", player_name="Bot1"), + ) + request2 = SpawnBotRequest( + room_code="ROOM2", + bot_config=BotConfig(bot_type="rule_based", player_name="Bot2"), + ) + + await manager.spawn_bot(request1) + await manager.spawn_bot(request2) + + bots = manager.list_bots() + + assert len(bots) == 2 + assert all(isinstance(bot, BotInfo) for bot in bots) + player_names = {bot.player_name for bot in bots} + assert player_names == {"Bot1", "Bot2"} + + def test_list_bots_returns_empty_when_no_bots(self) -> None: + """list_bots() returns empty list when no bots.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + bots = manager.list_bots() + + assert bots == [] + + +class TestBotManagerListModels: + """Tests for BotManager.list_models() method.""" + + def test_list_models_returns_from_registry(self) -> None: + """list_models() returns models from registry.""" + registry = MockModelRegistry() + _, metadata1 = registry.setup_model("ppo_gen_001", generation=1) + _, metadata2 = registry.setup_model("ppo_gen_002", generation=2) + registry.list_models.return_value = [metadata1, metadata2] + + manager = BotManager( + registry=registry, # type: ignore[arg-type] + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + models = manager.list_models() + + assert len(models) == 2 + assert models[0].model_id == "ppo_gen_001" + assert models[1].model_id == "ppo_gen_002" + + def test_list_models_returns_empty_when_no_registry(self) -> None: + """list_models() returns empty list when no registry.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + models = manager.list_models() + + assert models == [] + + +class TestBotManagerShutdown: + """Tests for BotManager.shutdown() method.""" + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_shutdown_stops_all_bots( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """shutdown() stops all active bots.""" + mock_client1 = AsyncMock() + mock_client2 = AsyncMock() + mock_client1.start = AsyncMock() + mock_client2.start = AsyncMock() + mock_client_class.side_effect = [mock_client1, mock_client2] + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + await manager.spawn_bot(request) + await manager.spawn_bot(request) + + assert len(manager._bots) == 2 + + await manager.shutdown() + + assert len(manager._bots) == 0 + mock_client1.stop.assert_called_once() + mock_client2.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_shutdown_with_no_bots(self) -> None: + """shutdown() is safe when no bots active.""" + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + await manager.shutdown() + + assert len(manager._bots) == 0 + + +class TestBotManagerBackgroundConnection: + """Tests for background bot connection handling.""" + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_background_connection_failure_removes_bot( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """Background connection failure removes bot from tracking.""" + mock_client = AsyncMock() + mock_client.start = AsyncMock(side_effect=Exception("Connection failed")) + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + response = await manager.spawn_bot(request) + + assert response.success is True + assert len(manager._bots) == 1 + + # Wait for background task to complete + await manager.await_pending_tasks() + + # Bot should be removed from tracking + assert len(manager._bots) == 0 + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_concurrent_spawn_bot_calls( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """Multiple concurrent spawn_bot() calls produce unique bot IDs.""" + mock_client = AsyncMock() + mock_client.start = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + # Spawn multiple bots concurrently + requests = [ + SpawnBotRequest( + room_code=f"ROOM{i}", + bot_config=BotConfig(bot_type="rule_based"), + ) + for i in range(5) + ] + + responses = await asyncio.gather(*[manager.spawn_bot(req) for req in requests]) + + # All should succeed + assert all(r.success for r in responses) + + # All bot IDs should be unique + bot_ids = [r.bot_id for r in responses] + assert len(bot_ids) == len(set(bot_ids)) + + # All bots should be tracked + assert len(manager._bots) == 5 + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_destroy_bot_twice_idempotent( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """destroy_bot() called twice for same bot_id is idempotent.""" + mock_client = AsyncMock() + mock_client.start = AsyncMock() + mock_client.stop = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + response = await manager.spawn_bot(request) + bot_id = response.bot_id + assert bot_id is not None + + # First destroy should succeed + result1 = await manager.destroy_bot(bot_id) + assert result1 is True + + # Second destroy should return False (not found) + result2 = await manager.destroy_bot(bot_id) + assert result2 is False + + # Bot should not be in tracking + assert len(manager._bots) == 0 + + @pytest.mark.asyncio + @patch("bot.service.bot_manager.WebSocketBotClient") + @patch("bot.service.bot_manager.RuleBasedBotRunner") + async def test_shutdown_with_pending_connections( + self, + mock_runner_class: MagicMock, + mock_client_class: MagicMock, + ) -> None: + """shutdown() waits for pending background connection tasks.""" + # Create a mock client that takes time to start + mock_client = AsyncMock() + connection_started = asyncio.Event() + + async def slow_start(room_code: str, room_password: str) -> None: + await asyncio.sleep(0.05) # Simulate slow connection + connection_started.set() + + mock_client.start = AsyncMock(side_effect=slow_start) + mock_client.stop = AsyncMock() + mock_client_class.return_value = mock_client + mock_runner_class.return_value = MagicMock() + + manager = BotManager( + registry=None, + http_url="http://localhost:4000", + ws_url="ws://localhost:4000/ws", + ) + + # Spawn a bot (connection starts in background) + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + await manager.spawn_bot(request) + + # Immediately call shutdown (while connection is pending) + assert not connection_started.is_set() + await manager.shutdown() + + # Connection should have completed during shutdown + assert connection_started.is_set() + + # All tasks should be complete + assert len(manager._tasks) == 0 + + +class TestPydanticModels: + """Tests for Pydantic model validation.""" + + def test_bot_config_rule_based(self) -> None: + """BotConfig validates rule_based type.""" + config = BotConfig(bot_type="rule_based") + + assert config.bot_type == "rule_based" + assert config.model_id is None + assert config.generation is None + assert config.player_name == "Bot" + + def test_bot_config_neural_network_with_model_id(self) -> None: + """BotConfig validates neural_network with model_id.""" + config = BotConfig( + bot_type="neural_network", + model_id="ppo_gen_005", + player_name="NeuralBot", + ) + + assert config.bot_type == "neural_network" + assert config.model_id == "ppo_gen_005" + assert config.generation is None + assert config.player_name == "NeuralBot" + + def test_bot_config_neural_network_with_generation(self) -> None: + """BotConfig validates neural_network with generation.""" + config = BotConfig( + bot_type="neural_network", + generation=10, + ) + + assert config.bot_type == "neural_network" + assert config.model_id is None + assert config.generation == 10 + + def test_bot_config_neural_network_with_both_model_id_and_generation(self) -> None: + """BotConfig accepts both model_id and generation (model_id takes precedence in _load_model).""" + config = BotConfig( + bot_type="neural_network", + model_id="ppo_gen_005", + generation=3, + ) + + # Both values should be stored + assert config.bot_type == "neural_network" + assert config.model_id == "ppo_gen_005" + assert config.generation == 3 + + def test_bot_config_invalid_type_raises(self) -> None: + """BotConfig raises error for invalid bot_type.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + BotConfig(bot_type="invalid_type") # type: ignore[arg-type] + + def test_bot_config_rule_based_with_model_id_raises(self) -> None: + """BotConfig raises error for rule_based with model_id.""" + from pydantic import ValidationError + + with pytest.raises( + ValidationError, + match="Rule-based bots cannot specify model_id or generation", + ): + BotConfig(bot_type="rule_based", model_id="ppo_gen_005") + + def test_bot_config_rule_based_with_generation_raises(self) -> None: + """BotConfig raises error for rule_based with generation.""" + from pydantic import ValidationError + + with pytest.raises( + ValidationError, + match="Rule-based bots cannot specify model_id or generation", + ): + BotConfig(bot_type="rule_based", generation=5) + + def test_bot_config_neural_network_without_model_or_gen_raises(self) -> None: + """BotConfig raises error for neural_network without model_id or generation.""" + from pydantic import ValidationError + + with pytest.raises( + ValidationError, + match="Neural network bots require either model_id or generation", + ): + BotConfig(bot_type="neural_network") + + def test_spawn_bot_request_defaults(self) -> None: + """SpawnBotRequest has correct defaults.""" + request = SpawnBotRequest( + room_code="ABC123", + bot_config=BotConfig(bot_type="rule_based"), + ) + + assert request.room_code == "ABC123" + assert request.room_password == "" + assert request.bot_config.bot_type == "rule_based" + + def test_spawn_bot_response_success(self) -> None: + """SpawnBotResponse success state.""" + response = SpawnBotResponse(success=True, bot_id="bot_123") + + assert response.success is True + assert response.bot_id == "bot_123" + assert response.error is None + + def test_spawn_bot_response_error(self) -> None: + """SpawnBotResponse error state.""" + response = SpawnBotResponse(success=False, error="Model not found") + + assert response.success is False + assert response.bot_id is None + assert response.error == "Model not found" + + def test_bot_info_serialization(self) -> None: + """BotInfo serialization.""" + info = BotInfo( + bot_id="bot_abc123", + bot_type="neural_network", + model_id="ppo_gen_005", + player_name="NeuralBot", + room_code="XYZ789", + is_connected=True, + ) + + data = info.model_dump() + + assert data["bot_id"] == "bot_abc123" + assert data["bot_type"] == "neural_network" + assert data["model_id"] == "ppo_gen_005" + assert data["player_name"] == "NeuralBot" + assert data["room_code"] == "XYZ789" + assert data["is_connected"] is True diff --git a/bot2/tests/unit/test_neural_net_bot.py b/bot2/tests/unit/test_neural_net_bot.py index 1d8faf9..cda49d4 100644 --- a/bot2/tests/unit/test_neural_net_bot.py +++ b/bot2/tests/unit/test_neural_net_bot.py @@ -323,8 +323,8 @@ def test_shoot_start_without_prior_aim(self, network: ActorCriticNetwork) -> Non # Verify aim position uses default direction (0.0 radians = right) # aim_x = 100.0 + 100.0 * cos(0.0) = 200.0 # aim_y = 200.0 + 100.0 * sin(0.0) = 200.0 - assert actions[0][2] == 200.0 # aim_x - assert actions[0][3] == 200.0 # aim_y + assert actions[0][2] == 200.0 # aim_x # type: ignore[index-out-of-bounds] + assert actions[0][3] == 200.0 # aim_y # type: ignore[index-out-of-bounds] class TestNeuralNetBotStateDeduplication: @@ -474,7 +474,7 @@ def test_initialization( self, mock_client: GameClient, network: ActorCriticNetwork ) -> None: """Test runner initializes correctly.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) assert runner.client is mock_client assert runner.network is network @@ -488,7 +488,7 @@ async def test_creates_bot_on_first_state( game_state: GameState, ) -> None: """Test runner creates bot on first game state.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) await runner.on_game_state(game_state) @@ -503,7 +503,7 @@ async def test_raises_if_no_player_id( mock_client = MagicMock(spec=GameClient) mock_client.player_id = None - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) with pytest.raises(ValueError, match="player_id must be set"): await runner.on_game_state(game_state) @@ -516,7 +516,7 @@ async def test_sends_keyboard_inputs( game_state: GameState, ) -> None: """Test runner sends keyboard inputs to client.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) await runner.on_game_state(game_state) @@ -531,7 +531,7 @@ async def test_deduplicates_keyboard_inputs( game_state: GameState, ) -> None: """Test runner only sends changed keyboard inputs.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) # First call - should send inputs await runner.on_game_state(game_state) @@ -555,7 +555,7 @@ async def test_deduplicates_mouse_inputs( game_state: GameState, ) -> None: """Test runner only sends changed mouse inputs.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) await runner.on_game_state(game_state) @@ -567,7 +567,7 @@ def test_reset_clears_state( self, mock_client: GameClient, network: ActorCriticNetwork ) -> None: """Test reset clears runner state.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) runner._previous_keyboard_actions = {"a": True} runner._previous_mouse_state = True runner._previous_aim_pos = (100.0, 200.0) @@ -586,7 +586,7 @@ async def test_reset_resets_bot( game_state: GameState, ) -> None: """Test reset also resets the bot.""" - runner = NeuralNetBotRunner(mock_client, network) + runner = NeuralNetBotRunner(network=network, client=mock_client) # Create bot await runner.on_game_state(game_state)