diff --git a/src/base_stages/extract_subtopics.py b/src/base_stages/extract_subtopics.py index eb88deb..3212dd1 100644 --- a/src/base_stages/extract_subtopics.py +++ b/src/base_stages/extract_subtopics.py @@ -6,6 +6,7 @@ from autogen_core.models import ChatCompletionClient from src.base_stages.task_dataclasses import SubTopic +from src.schemas.capability_schemas import Capability from src.utils.base_generation_prompts import format_subtopic_prompt from src.utils.model_client_utils import ModelCallMode, async_call_model @@ -14,7 +15,7 @@ def extract_subtopics( - capability, + capability: Capability, client: ChatCompletionClient, min_subtopics: int = 3, max_subtopics: int = 8, diff --git a/src/base_stages/find_combinations.py b/src/base_stages/find_combinations.py index b90c817..a72d5c8 100644 --- a/src/base_stages/find_combinations.py +++ b/src/base_stages/find_combinations.py @@ -10,6 +10,7 @@ DIFFICULTY_LEVELS, ) from src.base_stages.task_dataclasses import Combination, SubTopic +from src.schemas.capability_schemas import Capability from src.utils.base_generation_prompts import format_combination_prompt from src.utils.model_client_utils import ModelCallMode, async_call_model @@ -18,7 +19,7 @@ def find_valid_combinations( - capability, subtopics: list[SubTopic], client: ChatCompletionClient + capability: Capability, subtopics: list[SubTopic], client: ChatCompletionClient ) -> list[Combination]: """Find valid combinations of Content, Difficulty, and Reasoning. diff --git a/src/base_stages/generate_blueprints.py b/src/base_stages/generate_blueprints.py index 0acf98b..bb180c4 100644 --- a/src/base_stages/generate_blueprints.py +++ b/src/base_stages/generate_blueprints.py @@ -10,6 +10,7 @@ DIFFICULTY_LEVELS, ) from src.base_stages.task_dataclasses import Blueprint, Combination +from src.schemas.capability_schemas import Capability from src.utils.base_generation_prompts import format_blueprint_prompt from src.utils.model_client_utils import ModelCallMode, async_call_model @@ -18,7 +19,7 @@ def generate_blueprints( - capability, + capability: Capability, combinations: list[Combination], client: ChatCompletionClient, ) -> list[Blueprint]: diff --git a/src/base_stages/generate_capabilities.py b/src/base_stages/generate_capabilities.py index 6521cc3..b2b2a49 100644 --- a/src/base_stages/generate_capabilities.py +++ b/src/base_stages/generate_capabilities.py @@ -37,7 +37,7 @@ def generate_capabilities( ------- List of generated Capability objects """ - capabilities = [] + capabilities: List[Capability] = [] # Calculate number of runs needed num_runs = int(np.ceil(num_capabilities / num_capabilities_per_run)) diff --git a/src/base_stages/generate_tasks_from_blueprints.py b/src/base_stages/generate_tasks_from_blueprints.py index c0361a2..81218ac 100644 --- a/src/base_stages/generate_tasks_from_blueprints.py +++ b/src/base_stages/generate_tasks_from_blueprints.py @@ -7,6 +7,7 @@ from autogen_core.models import ChatCompletionClient from src.base_stages.task_dataclasses import Blueprint +from src.schemas.capability_schemas import Capability from src.schemas.task_schemas import Task from src.utils.base_generation_prompts import ( format_options_prompt, @@ -19,7 +20,7 @@ def generate_tasks_from_blueprints( - capability, + capability: Capability, blueprints: list[Blueprint], client: ChatCompletionClient, tasks_per_blueprint: int = 3, @@ -45,7 +46,7 @@ def generate_tasks_from_blueprints( """ logger.info("Generating tasks from blueprints...") - all_tasks = [] + all_tasks: List[Task] = [] for blueprint in blueprints: logger.info( @@ -137,7 +138,8 @@ def generate_tasks_from_blueprints( tasks_for_blueprint = [ t for t in all_tasks - if t.generation_metadata.get("blueprint_id") == blueprint.combination_id + if t.generation_metadata + and t.generation_metadata.get("blueprint_id") == blueprint.combination_id ] logger.info(f" Generated {len(tasks_for_blueprint)} tasks for this blueprint") diff --git a/src/base_stages/stage0_setup.py b/src/base_stages/stage0_setup.py index 56ba00d..28066fc 100644 --- a/src/base_stages/stage0_setup.py +++ b/src/base_stages/stage0_setup.py @@ -5,6 +5,7 @@ import logging from pathlib import Path +from typing import Any, Dict, cast from omegaconf import DictConfig, OmegaConf @@ -56,7 +57,9 @@ def run_stage0(cfg: DictConfig) -> None: domain=domain_name, domain_id=domain_id, pipeline_type=pipeline_type, - configuration=config_dict, + configuration=cast(Dict[str, Any], config_dict) + if isinstance(config_dict, dict) + else {}, ) metadata = PipelineMetadata( diff --git a/src/base_stages/stage2_capabilities.py b/src/base_stages/stage2_capabilities.py index f7dd1d0..6bb6259 100644 --- a/src/base_stages/stage2_capabilities.py +++ b/src/base_stages/stage2_capabilities.py @@ -7,6 +7,7 @@ import logging import math from pathlib import Path +from typing import Optional from omegaconf import DictConfig @@ -28,7 +29,7 @@ def run_stage2( cfg: DictConfig, areas_tag: str, - capabilities_tag: str = None, + capabilities_tag: Optional[str] = None, ) -> str: """Stage 2: Generate capabilities, embed, and filter. @@ -171,4 +172,5 @@ def run_stage2( f"{capabilities_path}" ) + assert capabilities_tag is not None return capabilities_tag diff --git a/src/base_stages/stage3_tasks.py b/src/base_stages/stage3_tasks.py index 4e042c5..7d12dba 100644 --- a/src/base_stages/stage3_tasks.py +++ b/src/base_stages/stage3_tasks.py @@ -6,6 +6,7 @@ import logging from pathlib import Path +from typing import Optional from omegaconf import DictConfig @@ -25,7 +26,7 @@ def run_stage3( cfg: DictConfig, capabilities_tag: str, - tasks_tag: str = None, + tasks_tag: Optional[str] = None, ) -> str: """Stage 3: Generate tasks for each capability. @@ -159,4 +160,5 @@ def run_stage3( # Continue with next capability instead of failing completely continue + assert tasks_tag is not None return tasks_tag diff --git a/src/base_stages/stage4_solutions.py b/src/base_stages/stage4_solutions.py index 1cc80ba..8c0b438 100644 --- a/src/base_stages/stage4_solutions.py +++ b/src/base_stages/stage4_solutions.py @@ -6,6 +6,7 @@ import logging from pathlib import Path +from typing import Optional from omegaconf import DictConfig @@ -23,7 +24,7 @@ def run_stage4( cfg: DictConfig, tasks_tag: str, - solution_tag: str = None, + solution_tag: Optional[str] = None, ) -> str: """Stage 4: Generate solutions for tasks. @@ -66,6 +67,7 @@ def run_stage4( if not tasks_base_dir.exists(): logger.error(f"Tasks directory not found: {tasks_base_dir}") + assert solution_tag is not None return solution_tag area_dirs = [d for d in tasks_base_dir.iterdir() if d.is_dir()] @@ -156,4 +158,5 @@ def run_stage4( ) continue + assert solution_tag is not None return solution_tag diff --git a/src/base_stages/stage5_validation.py b/src/base_stages/stage5_validation.py index 7b33715..d56972e 100644 --- a/src/base_stages/stage5_validation.py +++ b/src/base_stages/stage5_validation.py @@ -6,6 +6,7 @@ import logging from pathlib import Path +from typing import Optional from omegaconf import DictConfig @@ -23,7 +24,7 @@ def run_stage5( cfg: DictConfig, solution_tag: str, - validation_tag: str = None, + validation_tag: Optional[str] = None, ) -> str: """Stage 5: Validate generated task solutions. @@ -65,6 +66,7 @@ def run_stage5( if not solutions_base_dir.exists(): logger.error(f"Solutions directory not found: {solutions_base_dir}") + assert validation_tag is not None return validation_tag # Find all area directories @@ -165,4 +167,5 @@ def run_stage5( continue logger.info(f"Stage 5 completed. Validation tag: {validation_tag}") + assert validation_tag is not None return validation_tag diff --git a/src/base_stages/validate_tasks.py b/src/base_stages/validate_tasks.py index e932ea4..8805555 100644 --- a/src/base_stages/validate_tasks.py +++ b/src/base_stages/validate_tasks.py @@ -114,7 +114,7 @@ def validate_tasks( "suggested_improvements": response.get( "suggested_improvements", "" ), - **task_solution.generation_metadata, + **(task_solution.generation_metadata or {}), }, ) validation_results.append(validation_result) diff --git a/src/utils/capability_management_utils.py b/src/utils/capability_management_utils.py index d693cab..1664f54 100644 --- a/src/utils/capability_management_utils.py +++ b/src/utils/capability_management_utils.py @@ -227,4 +227,4 @@ def filter_schema_capabilities_by_embeddings( ) filtered_capabilities = [capabilities[i] for i in remaining_indices] - return filtered_capabilities, remaining_indices + return filtered_capabilities, list(remaining_indices) diff --git a/src/utils/embedding_utils.py b/src/utils/embedding_utils.py index d70764f..f3afe9d 100644 --- a/src/utils/embedding_utils.py +++ b/src/utils/embedding_utils.py @@ -117,7 +117,7 @@ def generate_and_set_capabilities_embeddings( capabilities: List[Capability], embedding_model_name: str, embed_dimensions: int, - rep_string_order="and", + rep_string_order: str = "and", ) -> None: """Generate the capabilities embeddings using the OpenAI embedding model. diff --git a/src/utils/model_client_utils.py b/src/utils/model_client_utils.py index a905e96..be4c41f 100644 --- a/src/utils/model_client_utils.py +++ b/src/utils/model_client_utils.py @@ -168,7 +168,7 @@ async def async_call_model( try: response = await model_client.create( - messages=list(messages), # type: ignore[arg-type] + messages=list(messages), **request_kwargs, ) except retryable_exceptions as exc: