Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/base_stages/extract_subtopics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from autogen_core.models import ChatCompletionClient

from src.base_stages.task_dataclasses import SubTopic
from src.schemas.capability_schemas import Capability
from src.utils.base_generation_prompts import format_subtopic_prompt
from src.utils.model_client_utils import ModelCallMode, async_call_model

Expand All @@ -14,7 +15,7 @@


def extract_subtopics(
capability,
capability: Capability,
client: ChatCompletionClient,
min_subtopics: int = 3,
max_subtopics: int = 8,
Expand Down
3 changes: 2 additions & 1 deletion src/base_stages/find_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DIFFICULTY_LEVELS,
)
from src.base_stages.task_dataclasses import Combination, SubTopic
from src.schemas.capability_schemas import Capability
from src.utils.base_generation_prompts import format_combination_prompt
from src.utils.model_client_utils import ModelCallMode, async_call_model

Expand All @@ -18,7 +19,7 @@


def find_valid_combinations(
capability, subtopics: list[SubTopic], client: ChatCompletionClient
capability: Capability, subtopics: list[SubTopic], client: ChatCompletionClient
) -> list[Combination]:
"""Find valid combinations of Content, Difficulty, and Reasoning.

Expand Down
3 changes: 2 additions & 1 deletion src/base_stages/generate_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DIFFICULTY_LEVELS,
)
from src.base_stages.task_dataclasses import Blueprint, Combination
from src.schemas.capability_schemas import Capability
from src.utils.base_generation_prompts import format_blueprint_prompt
from src.utils.model_client_utils import ModelCallMode, async_call_model

Expand All @@ -18,7 +19,7 @@


def generate_blueprints(
capability,
capability: Capability,
combinations: list[Combination],
client: ChatCompletionClient,
) -> list[Blueprint]:
Expand Down
2 changes: 1 addition & 1 deletion src/base_stages/generate_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def generate_capabilities(
-------
List of generated Capability objects
"""
capabilities = []
capabilities: List[Capability] = []

# Calculate number of runs needed
num_runs = int(np.ceil(num_capabilities / num_capabilities_per_run))
Expand Down
8 changes: 5 additions & 3 deletions src/base_stages/generate_tasks_from_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autogen_core.models import ChatCompletionClient

from src.base_stages.task_dataclasses import Blueprint
from src.schemas.capability_schemas import Capability
from src.schemas.task_schemas import Task
from src.utils.base_generation_prompts import (
format_options_prompt,
Expand All @@ -19,7 +20,7 @@


def generate_tasks_from_blueprints(
capability,
capability: Capability,
blueprints: list[Blueprint],
client: ChatCompletionClient,
tasks_per_blueprint: int = 3,
Expand All @@ -45,7 +46,7 @@ def generate_tasks_from_blueprints(
"""
logger.info("Generating tasks from blueprints...")

all_tasks = []
all_tasks: List[Task] = []

for blueprint in blueprints:
logger.info(
Expand Down Expand Up @@ -137,7 +138,8 @@ def generate_tasks_from_blueprints(
tasks_for_blueprint = [
t
for t in all_tasks
if t.generation_metadata.get("blueprint_id") == blueprint.combination_id
if t.generation_metadata
and t.generation_metadata.get("blueprint_id") == blueprint.combination_id
]
logger.info(f" Generated {len(tasks_for_blueprint)} tasks for this blueprint")

Expand Down
5 changes: 4 additions & 1 deletion src/base_stages/stage0_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging
from pathlib import Path
from typing import Any, Dict, cast

from omegaconf import DictConfig, OmegaConf

Expand Down Expand Up @@ -56,7 +57,9 @@ def run_stage0(cfg: DictConfig) -> None:
domain=domain_name,
domain_id=domain_id,
pipeline_type=pipeline_type,
configuration=config_dict,
configuration=cast(Dict[str, Any], config_dict)
if isinstance(config_dict, dict)
else {},
)

metadata = PipelineMetadata(
Expand Down
4 changes: 3 additions & 1 deletion src/base_stages/stage2_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import math
from pathlib import Path
from typing import Optional

from omegaconf import DictConfig

Expand All @@ -28,7 +29,7 @@
def run_stage2(
cfg: DictConfig,
areas_tag: str,
capabilities_tag: str = None,
capabilities_tag: Optional[str] = None,
) -> str:
"""Stage 2: Generate capabilities, embed, and filter.

Expand Down Expand Up @@ -171,4 +172,5 @@ def run_stage2(
f"{capabilities_path}"
)

assert capabilities_tag is not None
return capabilities_tag
4 changes: 3 additions & 1 deletion src/base_stages/stage3_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
from pathlib import Path
from typing import Optional

from omegaconf import DictConfig

Expand All @@ -25,7 +26,7 @@
def run_stage3(
cfg: DictConfig,
capabilities_tag: str,
tasks_tag: str = None,
tasks_tag: Optional[str] = None,
) -> str:
"""Stage 3: Generate tasks for each capability.

Expand Down Expand Up @@ -159,4 +160,5 @@ def run_stage3(
# Continue with next capability instead of failing completely
continue

assert tasks_tag is not None
return tasks_tag
5 changes: 4 additions & 1 deletion src/base_stages/stage4_solutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
from pathlib import Path
from typing import Optional

from omegaconf import DictConfig

Expand All @@ -23,7 +24,7 @@
def run_stage4(
cfg: DictConfig,
tasks_tag: str,
solution_tag: str = None,
solution_tag: Optional[str] = None,
) -> str:
"""Stage 4: Generate solutions for tasks.

Expand Down Expand Up @@ -66,6 +67,7 @@ def run_stage4(

if not tasks_base_dir.exists():
logger.error(f"Tasks directory not found: {tasks_base_dir}")
assert solution_tag is not None
return solution_tag

area_dirs = [d for d in tasks_base_dir.iterdir() if d.is_dir()]
Expand Down Expand Up @@ -156,4 +158,5 @@ def run_stage4(
)
continue

assert solution_tag is not None
return solution_tag
5 changes: 4 additions & 1 deletion src/base_stages/stage5_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
from pathlib import Path
from typing import Optional

from omegaconf import DictConfig

Expand All @@ -23,7 +24,7 @@
def run_stage5(
cfg: DictConfig,
solution_tag: str,
validation_tag: str = None,
validation_tag: Optional[str] = None,
) -> str:
"""Stage 5: Validate generated task solutions.

Expand Down Expand Up @@ -65,6 +66,7 @@ def run_stage5(

if not solutions_base_dir.exists():
logger.error(f"Solutions directory not found: {solutions_base_dir}")
assert validation_tag is not None
return validation_tag

# Find all area directories
Expand Down Expand Up @@ -165,4 +167,5 @@ def run_stage5(
continue

logger.info(f"Stage 5 completed. Validation tag: {validation_tag}")
assert validation_tag is not None
return validation_tag
2 changes: 1 addition & 1 deletion src/base_stages/validate_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def validate_tasks(
"suggested_improvements": response.get(
"suggested_improvements", ""
),
**task_solution.generation_metadata,
**(task_solution.generation_metadata or {}),
},
)
validation_results.append(validation_result)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/capability_management_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,4 @@ def filter_schema_capabilities_by_embeddings(
)

filtered_capabilities = [capabilities[i] for i in remaining_indices]
return filtered_capabilities, remaining_indices
return filtered_capabilities, list(remaining_indices)
2 changes: 1 addition & 1 deletion src/utils/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def generate_and_set_capabilities_embeddings(
capabilities: List[Capability],
embedding_model_name: str,
embed_dimensions: int,
rep_string_order="and",
rep_string_order: str = "and",
) -> None:
"""Generate the capabilities embeddings using the OpenAI embedding model.

Expand Down
2 changes: 1 addition & 1 deletion src/utils/model_client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def async_call_model(

try:
response = await model_client.create(
messages=list(messages), # type: ignore[arg-type]
messages=list(messages),
**request_kwargs,
)
except retryable_exceptions as exc:
Expand Down
Loading