Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions examples/prisoner/manager.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,61 @@
from typing import Any, Literal
import json
from typing import Any, Literal, Optional, Union

from dotenv import load_dotenv
from pydantic import BaseModel

from econagents import AgentRole
from econagents.core.events import Message
from econagents.core.manager.phase import TurnBasedPhaseManager
from econagents.core.state.game import GameState
from econagents.llm import ChatOpenAI

load_dotenv()


class PrisonerChoice(BaseModel):
"""Structured output the prisoner emits every round."""

gameId: int
type: Literal["choice"]
choice: Literal["COOPERATE", "DEFECT"]


class Prisoner(AgentRole):
"""Base class for prisoner agents in the Prisoner's Dilemma game."""

role = 1
name = "Prisoner"
llm = ChatOpenAI(model_name="gpt-5.4-mini")
llm = ChatOpenAI(model_name="gpt-4o-mini")
task_phases = ["decision"]
default_response_schema = PrisonerChoice

def parse_phase_llm_response(self, response: Union[str, BaseModel], state: GameState) -> dict:
if isinstance(response, PrisonerChoice):
choice = response.choice
else:
choice = json.loads(response)["choice"]
return {
"meta": {"type": "submit-choice", "component": {"type": "standard:coordination"}},
"payload": {"choice": choice},
}

class PDManager(TurnBasedPhaseManager):
"""
Manager for the Prisoner's Dilemma game.
Manages interactions between the server and agents.
"""

class PDManager(TurnBasedPhaseManager):
def __init__(self, game_id: int, auth_mechanism_kwargs: dict[str, Any]):
super().__init__(
auth_mechanism_kwargs=auth_mechanism_kwargs,
agent_role=Prisoner(),
)
self.game_id = game_id
self.register_phase_handler("introduction", self._handle_introduction)

async def _handle_introduction(self, phase: str, state: GameState) -> dict:
return {
"meta": {"type": "ready", "component": {"type": "standard:ready"}},
"payload": {},
}

def _extract_message_data(self, raw_message: str) -> Optional[Message]:
try:
msg = json.loads(raw_message)
event_type = (msg.get("meta") or {}).get("type", "")
data = msg.get("payload") or {}
except json.JSONDecodeError:
self.logger.error("Invalid JSON received.")
return None
return Message(message_type="event", event_type=event_type, data=data)
64 changes: 35 additions & 29 deletions examples/prisoner_personas/run_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

from econagents.core.game_runner import GameRunner, TurnBasedGameRunnerConfig
from econagents.personas import load_persona
from examples.prisoner.manager import PDManager
from examples.prisoner.state import PDGameState
from examples.prisoner_personas.manager import PDManager

logger = logging.getLogger("prisoners_dilemma_personas")

Expand All @@ -39,38 +39,36 @@
async def main(
game_id: int,
recovery_codes: list[str],
personas: list[str]
personas: list[str],
hostname: str,
port: int,
) -> None:
logger.info("Starting persona-driven Prisoner's Dilemma game")
load_dotenv()

login_payloads = [
{"type": "join", "gameId": game_id, "recovery": code}
for code in recovery_codes
]

config = TurnBasedGameRunnerConfig(
game_id=game_id,
logs_dir=Path(__file__).parent / "logs",
prompts_dir=Path(__file__).parent / "prompts",
log_level=logging.INFO,
hostname="localhost",
port=8765,
path="wss",
hostname=hostname,
port=port,
path="",
state_class=PDGameState,
phase_transition_event="round-started",
phase_identifier_key="round",
observability_provider="langsmith",
)

agents = [
PDManager(
agents = []
for i, (recovery_code, persona_id) in enumerate(zip(recovery_codes, personas)):
agent = PDManager(
game_id=game_id,
auth_mechanism_kwargs=payload,
persona=load_persona(persona_id),
auth_mechanism_kwargs={
"meta": {"type": "join"},
"payload": {"recovery": recovery_code},
},
)
for payload, persona_id in zip(login_payloads, personas)
]
if persona_id:
agent.agent_role.persona = load_persona(persona_id, user_dir=PERSONAS_DIR)
agents.append(agent)

runner = GameRunner(config=config, agents=agents)
await runner.run_game()
Expand All @@ -90,7 +88,7 @@ def parse_args() -> argparse.Namespace:
"--persona",
dest="personas",
action="append",
required=True,
default=[],
metavar="PERSONA_ID",
help="Persona id for an agent. Repeat once per agent, in agent order.",
)
Expand All @@ -102,6 +100,17 @@ def parse_args() -> argparse.Namespace:
metavar="CODE",
help="Recovery code for an agent. Repeat once per agent, in agent order.",
)
parser.add_argument(
"--hostname",
default="localhost",
help="Game server hostname.",
)
parser.add_argument(
"--port",
type=int,
default=3000,
help="Game server port.",
)
return parser.parse_args()


Expand All @@ -112,21 +121,18 @@ def parse_args() -> argparse.Namespace:
--game-id 1 \
--persona conditional-cooperator \
--persona marcus-strategic-44 \
--persona tit-for-tat \
--recovery-code CODE1 \
--recovery-code CODE2 \
--recovery-code CODE3
--recovery-code CODE2
"""
args = parse_args()
if len(args.personas) != len(args.recovery_codes):
raise SystemExit(
"The number of --persona and --recovery-code arguments must match "
f"(got {len(args.personas)} personas and {len(args.recovery_codes)} codes)."
)
# Pad personas with empty strings so the list is always as long as recovery_codes.
personas = args.personas + [""] * (len(args.recovery_codes) - len(args.personas))
asyncio.run(
main(
game_id=args.game_id,
recovery_codes=args.recovery_codes,
personas=args.personas,
personas=personas,
hostname=args.hostname,
port=args.port,
)
)
Loading