Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/area_generation/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from src.area_generation.messages import Domain
from src.area_generation.moderator import AreaModerator
from src.area_generation.scientist import AreaScientist
from src.utils.model_client_utils import get_model_client
from src.utils.model_client_utils import get_standard_model_client


log = logging.getLogger("agentic_area_gen.generator")
Expand All @@ -27,7 +27,7 @@
logging.getLogger(EVENT_LOGGER_NAME).setLevel(logging.WARNING)


async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> None:
async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse) -> None:
"""Generate areas using multi-agent debate system."""
domain_name = cfg.global_cfg.domain
exp_id = cfg.exp_cfg.exp_id
Expand Down Expand Up @@ -86,7 +86,7 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
runtime,
"AreaScientistA",
lambda: AreaScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_a.model_name,
seed=cfg.agents.scientist_a.seed,
),
Expand All @@ -99,7 +99,7 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
runtime,
"AreaScientistB",
lambda: AreaScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_b.model_name,
seed=cfg.agents.scientist_b.seed,
),
Expand All @@ -112,7 +112,7 @@ async def generate_areas(cfg: DictConfig, langfuse_client: Langfuse = None) -> N
runtime,
"AreaModerator",
lambda: AreaModerator(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.moderator.model_name,
seed=cfg.agents.moderator.seed,
),
Expand Down
6 changes: 3 additions & 3 deletions src/area_generation/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import traceback
from pathlib import Path
from typing import Dict, List
from typing import Any, Dict, List

from autogen_core import (
DefaultTopicId,
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
num_final_areas: int,
max_round: int,
output_dir: Path,
langfuse_client: Langfuse = None,
langfuse_client: Langfuse,
) -> None:
super().__init__("Area Moderator")
self._model_client = model_client
Expand Down Expand Up @@ -288,7 +288,7 @@ async def _merge_proposals(
)
raise

async def _finalize_areas(self, final_areas: dict) -> None:
async def _finalize_areas(self, final_areas: Dict[str, Any]) -> None:
"""Save final areas to file."""
with self._langfuse_client.start_as_current_span(
name="moderator_finalize_areas"
Expand Down
2 changes: 1 addition & 1 deletion src/area_generation/scientist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self,
model_client: ChatCompletionClient,
scientist_id: str,
langfuse_client: Langfuse = None,
langfuse_client: Langfuse,
) -> None:
super().__init__(f"Area Scientist {scientist_id}")
self._model_client = model_client
Expand Down
8 changes: 4 additions & 4 deletions src/capability_generation/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from src.capability_generation.messages import Area
from src.capability_generation.moderator import CapabilityModerator
from src.capability_generation.scientist import CapabilityScientist
from src.utils.model_client_utils import get_model_client
from src.utils.model_client_utils import get_standard_model_client


log = logging.getLogger("agentic_cap_gen.generator")
Expand Down Expand Up @@ -58,7 +58,7 @@ async def generate_capabilities_for_area(
runtime,
"CapabilityScientistA",
lambda: CapabilityScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_a.model_name,
seed=cfg.agents.scientist_a.seed,
),
Expand All @@ -71,7 +71,7 @@ async def generate_capabilities_for_area(
runtime,
"CapabilityScientistB",
lambda: CapabilityScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_b.model_name,
seed=cfg.agents.scientist_b.seed,
),
Expand All @@ -84,7 +84,7 @@ async def generate_capabilities_for_area(
runtime,
"CapabilityModerator",
lambda: CapabilityModerator(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.moderator.model_name,
seed=cfg.agents.moderator.seed,
),
Expand Down
6 changes: 3 additions & 3 deletions src/capability_generation/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import traceback
from pathlib import Path
from typing import Dict, List
from typing import Any, Dict, List

from autogen_core import (
DefaultTopicId,
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(
max_round: int,
output_dir: Path,
domain: str,
langfuse_client: Langfuse = None,
langfuse_client: Langfuse,
) -> None:
super().__init__("Capability Moderator")
self._model_client = model_client
Expand Down Expand Up @@ -322,7 +322,7 @@ async def _merge_proposals(
raise

async def _finalize_capabilities(
self, final_capabilities: dict, area_name: str
self, final_capabilities: Dict[str, Any], area_name: str
) -> None:
"""Save final capabilities to file."""
with self._langfuse_client.start_as_current_span(
Expand Down
2 changes: 1 addition & 1 deletion src/capability_generation/scientist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self,
model_client: ChatCompletionClient,
scientist_id: str,
langfuse_client: Langfuse = None,
langfuse_client: Langfuse,
) -> None:
super().__init__(f"Capability Scientist {scientist_id}")
self._scientist_id = scientist_id
Expand Down
8 changes: 4 additions & 4 deletions src/task_generation/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from src.task_generation.messages import Capability
from src.task_generation.moderator import TaskModerator
from src.task_generation.scientist import TaskScientist
from src.utils.model_client_utils import get_model_client
from src.utils.model_client_utils import get_standard_model_client


log = logging.getLogger("agentic_task_gen.generator")
Expand Down Expand Up @@ -60,7 +60,7 @@ async def generate_tasks_for_capability(
runtime,
"TaskScientistA",
lambda: TaskScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_a.model_name,
seed=cfg.agents.scientist_a.seed,
),
Expand All @@ -74,7 +74,7 @@ async def generate_tasks_for_capability(
runtime,
"TaskScientistB",
lambda: TaskScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_b.model_name,
seed=cfg.agents.scientist_b.seed,
),
Expand All @@ -89,7 +89,7 @@ async def generate_tasks_for_capability(
runtime,
"TaskModerator",
lambda: TaskModerator(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.moderator.model_name,
seed=cfg.agents.moderator.seed,
),
Expand Down
8 changes: 4 additions & 4 deletions src/task_solver/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from src.task_solver.messages import Task
from src.task_solver.moderator import TaskSolverModerator
from src.task_solver.scientist import TaskSolverScientist
from src.utils.model_client_utils import get_model_client
from src.utils.model_client_utils import get_standard_model_client


log = logging.getLogger("task_solver.generator")
Expand Down Expand Up @@ -61,7 +61,7 @@ async def solve_task(
runtime,
"TaskSolverModerator",
lambda: TaskSolverModerator(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.moderator.model_name,
seed=cfg.agents.moderator.get("seed"),
),
Expand All @@ -77,7 +77,7 @@ async def solve_task(
runtime,
"TaskSolverScientistA",
lambda: TaskSolverScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_a.model_name,
seed=cfg.agents.scientist_a.get("seed"),
),
Expand All @@ -90,7 +90,7 @@ async def solve_task(
runtime,
"TaskSolverScientistB",
lambda: TaskSolverScientist(
model_client=get_model_client(
model_client=get_standard_model_client(
model_name=cfg.agents.scientist_b.model_name,
seed=cfg.agents.scientist_b.get("seed"),
),
Expand Down
Loading