diff --git a/src/area_generation/generator.py b/src/area_generation/generator.py index 0176dfaa..45b59f4f 100644 --- a/src/area_generation/generator.py +++ b/src/area_generation/generator.py @@ -18,7 +18,7 @@ from src.area_generation.messages import Domain from src.area_generation.moderator import AreaModerator from src.area_generation.scientist import AreaScientist -from src.utils.model_client_utils import get_model_client +from src.utils.model_client_utils import get_standard_model_client log = logging.getLogger("agentic_area_gen.generator") @@ -27,7 +27,7 @@ logging.getLogger(EVENT_LOGGER_NAME).setLevel(logging.WARNING) -async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> None: +async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse) -> None: """Generate areas using multi-agent debate system.""" domain_name = cfg.global_cfg.domain exp_id = cfg.exp_cfg.exp_id @@ -86,7 +86,7 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N runtime, "AreaScientistA", lambda: AreaScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_a.model_name, seed=cfg.agents.scientist_a.seed, ), @@ -99,7 +99,7 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N runtime, "AreaScientistB", lambda: AreaScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_b.model_name, seed=cfg.agents.scientist_b.seed, ), @@ -112,7 +112,7 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N runtime, "AreaModerator", lambda: AreaModerator( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.moderator.model_name, seed=cfg.agents.moderator.seed, ), diff --git a/src/area_generation/moderator.py b/src/area_generation/moderator.py index ba2bbbac..83410bc4 100644 --- a/src/area_generation/moderator.py +++ b/src/area_generation/moderator.py @@ -4,7 +4,7 @@ import logging import traceback from pathlib import Path -from typing import Dict, List +from typing import Any, Dict, List from autogen_core import ( DefaultTopicId, @@ -47,7 +47,7 @@ def __init__( num_final_areas: int, max_round: int, output_dir: Path, - langfuse_client: Langfuse = None, + langfuse_client: Langfuse, ) -> None: super().__init__("Area Moderator") self._model_client = model_client @@ -288,7 +288,7 @@ async def _merge_proposals( ) raise - async def _finalize_areas(self, final_areas: dict) -> None: + async def _finalize_areas(self, final_areas: Dict[str, Any]) -> None: """Save final areas to file.""" with self._langfuse_client.start_as_current_span( name="moderator_finalize_areas" diff --git a/src/area_generation/scientist.py b/src/area_generation/scientist.py index 899d2c62..1690546b 100644 --- a/src/area_generation/scientist.py +++ b/src/area_generation/scientist.py @@ -42,7 +42,7 @@ def __init__( self, model_client: ChatCompletionClient, scientist_id: str, - langfuse_client: Langfuse = None, + langfuse_client: Langfuse, ) -> None: super().__init__(f"Area Scientist {scientist_id}") self._model_client = model_client diff --git a/src/capability_generation/generator.py b/src/capability_generation/generator.py index b8ffc653..d0bbe4c8 100644 --- a/src/capability_generation/generator.py +++ b/src/capability_generation/generator.py @@ -21,7 +21,7 @@ from src.capability_generation.messages import Area from src.capability_generation.moderator import CapabilityModerator from src.capability_generation.scientist import CapabilityScientist -from src.utils.model_client_utils import get_model_client +from src.utils.model_client_utils import get_standard_model_client log = logging.getLogger("agentic_cap_gen.generator") @@ -58,7 +58,7 @@ async def generate_capabilities_for_area( runtime, "CapabilityScientistA", lambda: CapabilityScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_a.model_name, seed=cfg.agents.scientist_a.seed, ), @@ -71,7 +71,7 @@ async def generate_capabilities_for_area( runtime, "CapabilityScientistB", lambda: CapabilityScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_b.model_name, seed=cfg.agents.scientist_b.seed, ), @@ -84,7 +84,7 @@ async def generate_capabilities_for_area( runtime, "CapabilityModerator", lambda: CapabilityModerator( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.moderator.model_name, seed=cfg.agents.moderator.seed, ), diff --git a/src/capability_generation/moderator.py b/src/capability_generation/moderator.py index fad5aaab..6b7944a6 100644 --- a/src/capability_generation/moderator.py +++ b/src/capability_generation/moderator.py @@ -4,7 +4,7 @@ import logging import traceback from pathlib import Path -from typing import Dict, List +from typing import Any, Dict, List from autogen_core import ( DefaultTopicId, @@ -48,7 +48,7 @@ def __init__( max_round: int, output_dir: Path, domain: str, - langfuse_client: Langfuse = None, + langfuse_client: Langfuse, ) -> None: super().__init__("Capability Moderator") self._model_client = model_client @@ -322,7 +322,7 @@ async def _merge_proposals( raise async def _finalize_capabilities( - self, final_capabilities: dict, area_name: str + self, final_capabilities: Dict[str, Any], area_name: str ) -> None: """Save final capabilities to file.""" with self._langfuse_client.start_as_current_span( diff --git a/src/capability_generation/scientist.py b/src/capability_generation/scientist.py index cbe8c84e..dbb11d4a 100644 --- a/src/capability_generation/scientist.py +++ b/src/capability_generation/scientist.py @@ -42,7 +42,7 @@ def __init__( self, model_client: ChatCompletionClient, scientist_id: str, - langfuse_client: Langfuse = None, + langfuse_client: Langfuse, ) -> None: super().__init__(f"Capability Scientist {scientist_id}") self._scientist_id = scientist_id diff --git a/src/task_generation/generator.py b/src/task_generation/generator.py index 7ff3468c..455d3985 100644 --- a/src/task_generation/generator.py +++ b/src/task_generation/generator.py @@ -21,7 +21,7 @@ from src.task_generation.messages import Capability from src.task_generation.moderator import TaskModerator from src.task_generation.scientist import TaskScientist -from src.utils.model_client_utils import get_model_client +from src.utils.model_client_utils import get_standard_model_client log = logging.getLogger("agentic_task_gen.generator") @@ -60,7 +60,7 @@ async def generate_tasks_for_capability( runtime, "TaskScientistA", lambda: TaskScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_a.model_name, seed=cfg.agents.scientist_a.seed, ), @@ -74,7 +74,7 @@ async def generate_tasks_for_capability( runtime, "TaskScientistB", lambda: TaskScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_b.model_name, seed=cfg.agents.scientist_b.seed, ), @@ -89,7 +89,7 @@ async def generate_tasks_for_capability( runtime, "TaskModerator", lambda: TaskModerator( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.moderator.model_name, seed=cfg.agents.moderator.seed, ), diff --git a/src/task_solver/generator.py b/src/task_solver/generator.py index 0165c8b6..598ab531 100644 --- a/src/task_solver/generator.py +++ b/src/task_solver/generator.py @@ -20,7 +20,7 @@ from src.task_solver.messages import Task from src.task_solver.moderator import TaskSolverModerator from src.task_solver.scientist import TaskSolverScientist -from src.utils.model_client_utils import get_model_client +from src.utils.model_client_utils import get_standard_model_client log = logging.getLogger("task_solver.generator") @@ -61,7 +61,7 @@ async def solve_task( runtime, "TaskSolverModerator", lambda: TaskSolverModerator( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.moderator.model_name, seed=cfg.agents.moderator.get("seed"), ), @@ -77,7 +77,7 @@ async def solve_task( runtime, "TaskSolverScientistA", lambda: TaskSolverScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_a.model_name, seed=cfg.agents.scientist_a.get("seed"), ), @@ -90,7 +90,7 @@ async def solve_task( runtime, "TaskSolverScientistB", lambda: TaskSolverScientist( - model_client=get_model_client( + model_client=get_standard_model_client( model_name=cfg.agents.scientist_b.model_name, seed=cfg.agents.scientist_b.get("seed"), ),