class PrisonerChoice(BaseModel):
    """Structured output the prisoner emits: a single COOPERATE/DEFECT choice."""

    choice: Literal["COOPERATE", "DEFECT"]


class Prisoner(AgentRole):
    """Prisoner agent role for the Prisoner's Dilemma example."""

    role = 1
    name = "Prisoner"
    llm = ChatOpenAI(model_name="gpt-4o-mini")
    # NOTE(review): presumably the phases in which the agent is asked to act —
    # confirm against AgentRole.
    task_phases = ["decision"]
    default_response_schema = PrisonerChoice

    def parse_phase_llm_response(self, response: Union[str, BaseModel], state: GameState) -> dict:
        """Convert the LLM's answer into a ``submit-choice`` message.

        Args:
            response: Either an already-parsed ``PrisonerChoice`` or a JSON
                string containing a ``"choice"`` key.
            state: Current game state; not used by this implementation.

        Returns:
            A dict with a ``meta`` section (type ``submit-choice``) and a
            ``payload`` section carrying the validated choice.

        Raises:
            json.JSONDecodeError: If a string response is not valid JSON.
            KeyError: If the decoded JSON object has no ``"choice"`` key.
            ValueError: If the choice is not ``"COOPERATE"`` or ``"DEFECT"``
                (raised by pydantic as a validation error).
        """
        if isinstance(response, PrisonerChoice):
            choice = response.choice
        else:
            raw_choice = json.loads(response)["choice"]
            # Run the raw value through the schema so that anything other than
            # the two allowed literals is rejected here instead of being sent on.
            choice = PrisonerChoice(choice=raw_choice).choice
        return {
            "meta": {"type": "submit-choice", "component": {"type": "standard:coordination"}},
            "payload": {"choice": choice},
        }
class PDManager(TurnBasedPhaseManager):
    """Phase manager for one Prisoner's Dilemma agent.

    Wires a ``Prisoner`` role into a ``TurnBasedPhaseManager``, answers the
    ``introduction`` phase with a ready message, and unpacks incoming
    ``meta``/``payload`` JSON messages into ``Message`` objects.
    """

    def __init__(self, game_id: int, auth_mechanism_kwargs: dict[str, Any]):
        """Create the manager and register the introduction-phase handler.

        Args:
            game_id: Identifier of the game; stored on the instance.
            auth_mechanism_kwargs: Passed through unchanged to the base class.
        """
        super().__init__(
            auth_mechanism_kwargs=auth_mechanism_kwargs,
            agent_role=Prisoner(),
        )
        self.game_id = game_id
        self.register_phase_handler("introduction", self._handle_introduction)

    async def _handle_introduction(self, phase: str, state: GameState) -> dict:
        """Return the fixed ``ready`` message for the introduction phase.

        Neither ``phase`` nor ``state`` is inspected.
        """
        return {
            "meta": {"type": "ready", "component": {"type": "standard:ready"}},
            "payload": {},
        }

    def _extract_message_data(self, raw_message: str) -> Optional[Message]:
        """Decode a raw JSON message into a ``Message``.

        The event type is read from ``meta.type`` (default ``""``) and the data
        from ``payload`` (default ``{}``).

        Args:
            raw_message: JSON text of the incoming message.

        Returns:
            The parsed ``Message``, or ``None`` when the text is not valid JSON
            or does not decode to a JSON object.
        """
        # Keep the try body to the one call that can raise JSONDecodeError.
        try:
            msg = json.loads(raw_message)
        except json.JSONDecodeError:
            self.logger.error("Invalid JSON received.")
            return None

        # Valid JSON need not be an object (e.g. a list, string or number);
        # calling .get on those would raise AttributeError.
        if not isinstance(msg, dict):
            self.logger.error("Invalid message received: expected a JSON object.")
            return None

        meta = msg.get("meta")
        # A missing, null or non-mapping "meta" yields an empty event type.
        event_type = meta.get("type", "") if isinstance(meta, dict) else ""
        data = msg.get("payload") or {}
        return Message(message_type="event", event_type=event_type, data=data)
hostname=hostname, + port=port, + path="", state_class=PDGameState, - phase_transition_event="round-started", - phase_identifier_key="round", - observability_provider="langsmith", ) - agents = [ - PDManager( + agents = [] + for i, (recovery_code, persona_id) in enumerate(zip(recovery_codes, personas)): + agent = PDManager( game_id=game_id, - auth_mechanism_kwargs=payload, - persona=load_persona(persona_id), + auth_mechanism_kwargs={ + "meta": {"type": "join"}, + "payload": {"recovery": recovery_code}, + }, ) - for payload, persona_id in zip(login_payloads, personas) - ] + if persona_id: + agent.agent_role.persona = load_persona(persona_id, user_dir=PERSONAS_DIR) + agents.append(agent) runner = GameRunner(config=config, agents=agents) await runner.run_game() @@ -90,7 +88,7 @@ def parse_args() -> argparse.Namespace: "--persona", dest="personas", action="append", - required=True, + default=[], metavar="PERSONA_ID", help="Persona id for an agent. Repeat once per agent, in agent order.", ) @@ -102,6 +100,17 @@ def parse_args() -> argparse.Namespace: metavar="CODE", help="Recovery code for an agent. Repeat once per agent, in agent order.", ) + parser.add_argument( + "--hostname", + default="localhost", + help="Game server hostname.", + ) + parser.add_argument( + "--port", + type=int, + default=3000, + help="Game server port.", + ) return parser.parse_args() @@ -112,21 +121,18 @@ def parse_args() -> argparse.Namespace: --game-id 1 \ --persona conditional-cooperator \ --persona marcus-strategic-44 \ - --persona tit-for-tat \ --recovery-code CODE1 \ - --recovery-code CODE2 \ - --recovery-code CODE3 + --recovery-code CODE2 """ args = parse_args() - if len(args.personas) != len(args.recovery_codes): - raise SystemExit( - "The number of --persona and --recovery-code arguments must match " - f"(got {len(args.personas)} personas and {len(args.recovery_codes)} codes)." - ) + # Pad personas with empty strings so the list is always as long as recovery_codes. 
+ personas = args.personas + [""] * (len(args.recovery_codes) - len(args.personas)) asyncio.run( main( game_id=args.game_id, recovery_codes=args.recovery_codes, - personas=args.personas, + personas=personas, + hostname=args.hostname, + port=args.port, ) )