diff --git a/pyproject.toml b/pyproject.toml index f957922..46b0e92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ readme = "README.md" dependencies = [ "jsonschema==4.25.1", "openai==2.9.0", + "google-genai>=1.59.0", "pydantic==2.12.5", "sqlalchemy==2.0.44", "sqlmodel==0.0.27", diff --git a/src/extrai/core/__init__.py b/src/extrai/core/__init__.py index 31a0286..944f030 100644 --- a/src/extrai/core/__init__.py +++ b/src/extrai/core/__init__.py @@ -5,32 +5,31 @@ database writing, LLM interaction, and workflow orchestration. """ +from .analytics_collector import WorkflowAnalyticsCollector +from .conflict_resolvers import ( + SimilarityClusterResolver, + default_conflict_resolver, + prefer_most_common_resolver, +) from .errors import ( - WorkflowError, - LLMInteractionError, + ConfigurationError, ConsensusProcessError, HydrationError, + LLMAPICallError, LLMConfigurationError, + LLMInteractionError, LLMOutputParseError, LLMOutputValidationError, - LLMAPICallError, - ConfigurationError, + WorkflowError, ) - -from .analytics_collector import WorkflowAnalyticsCollector +from .example_json_generator import ExampleJSONGenerator from .json_consensus import JSONConsensus -from .prompt_builder import generate_system_prompt, generate_user_prompt_for_docs from .model_registry import ModelRegistry -from .schema_inspector import SchemaInspector +from .prompt_builder import generate_system_prompt, generate_user_prompt_for_docs from .result_processor import ResultProcessor, SQLAlchemyHydrator, persist_objects -from .workflow_orchestrator import WorkflowOrchestrator +from .schema_inspector import SchemaInspector from .sqlmodel_generator import SQLModelCodeGenerator -from .example_json_generator import ExampleJSONGenerator -from .conflict_resolvers import ( - SimilarityClusterResolver, - default_conflict_resolver, - prefer_most_common_resolver, -) +from .workflow_orchestrator import WorkflowOrchestrator __all__ = [ # Errors diff --git 
a/src/extrai/core/analytics_collector.py b/src/extrai/core/analytics_collector.py index 751d9f7..b487819 100644 --- a/src/extrai/core/analytics_collector.py +++ b/src/extrai/core/analytics_collector.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional, Dict, Any +from typing import Any class WorkflowAnalyticsCollector: @@ -7,7 +7,7 @@ class WorkflowAnalyticsCollector: Collects analytics data throughout the LLM workflow. """ - def __init__(self, logger: Optional[logging.Logger] = None): + def __init__(self, logger: logging.Logger | None = None): self.logger = logger or logging.getLogger(self.__class__.__name__) if not logger: self.logger.setLevel(logging.WARNING) @@ -19,14 +19,14 @@ def __init__(self, logger: Optional[logging.Logger] = None): self._hydrated_objects_successes: int = 0 self._hydration_failures: int = 0 # Stores a list of dictionaries, each dictionary being the analytics_details from a consensus run - self._consensus_run_details_list: List[Dict[str, Any]] = [] - self._custom_events: List[Dict[str, Any]] = [] - self._workflow_errors: List[Dict[str, Any]] = [] - self._llm_output_validations_errors_details: List[Dict[str, Any]] = [] + self._consensus_run_details_list: list[dict[str, Any]] = [] + self._custom_events: list[dict[str, Any]] = [] + self._workflow_errors: list[dict[str, Any]] = [] + self._llm_output_validations_errors_details: list[dict[str, Any]] = [] self._total_llm_cost: float = 0.0 self._total_input_tokens: int = 0 self._total_output_tokens: int = 0 - self._llm_cost_details: List[Dict[str, Any]] = [] + self._llm_cost_details: list[dict[str, Any]] = [] def record_llm_usage( self, @@ -34,12 +34,13 @@ def record_llm_usage( output_tokens: int, model: str, cost: float = 0.0, - details: Optional[Dict[str, Any]] = None, + details: dict[str, Any] | None = None, ): """Records the token usage and optional cost of an LLM call.""" self._total_input_tokens += input_tokens self._total_output_tokens += output_tokens - self._total_llm_cost 
+= cost + if cost is not None: + self._total_llm_cost += cost usage_details = { "model": model, @@ -76,7 +77,7 @@ def record_hydration_failure(self): """Increments the count of hydration failures.""" self._hydration_failures += 1 - def record_consensus_run_details(self, consensus_analytics_details: Dict[str, Any]): + def record_consensus_run_details(self, consensus_analytics_details: dict[str, Any]): """ Records detailed analytics from a single consensus calculation. @@ -224,7 +225,7 @@ def average_paths_omitted_ratio(self) -> float: else 0.0 ) - def get_report(self) -> Dict[str, Any]: + def get_report(self) -> dict[str, Any]: """ Returns a dictionary summarizing all collected analytics. """ @@ -282,7 +283,7 @@ def get_report(self) -> Dict[str, Any]: return report def record_custom_event( - self, event_name: str, details: Optional[Dict[str, Any]] = None + self, event_name: str, details: dict[str, Any] | None = None ): """Records a generic custom event.""" event_record = {"event_name": event_name} @@ -293,9 +294,9 @@ def record_custom_event( def record_workflow_error( self, error_type: str, - context: Optional[str] = None, - message: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, + context: str | None = None, + message: str | None = None, + details: dict[str, Any] | None = None, ): """Records a generic workflow error.""" error_record = {"error_type": error_type} diff --git a/src/extrai/core/base_llm_client.py b/src/extrai/core/base_llm_client.py index 3de5e78..95d021f 100644 --- a/src/extrai/core/base_llm_client.py +++ b/src/extrai/core/base_llm_client.py @@ -1,57 +1,60 @@ +import asyncio import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Type -import asyncio +from collections.abc import Callable +from contextvars import ContextVar +from enum import Enum +from typing import Any + from sqlmodel import SQLModel +from extrai.core.analytics_collector import WorkflowAnalyticsCollector from 
extrai.core.errors import ( + LLMAPICallError, LLMOutputParseError, LLMOutputValidationError, - LLMAPICallError, LLMRevisionGenerationError, ) from extrai.utils.llm_output_processing import ( process_and_validate_llm_output, process_and_validate_raw_json, ) -from extrai.core.analytics_collector import ( - WorkflowAnalyticsCollector, -) + +revision_context: ContextVar[str] = ContextVar("revision_context", default="") + + +class ResponseMode(Enum): + """Defines the format of the LLM response.""" + + TEXT = "text" + STRUCTURED = "structured" + + +class ProviderBatchStatus(Enum): + """Standardized batch job status across providers.""" + + PROCESSING = "processing" + PENDING = "pending" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" class BaseLLMClient(ABC): """ Abstract base class for LLM clients. - This class provides a common structure for interacting with various LLM providers. - It handles the generic logic for generating multiple JSON revisions, including - retries and validation, while delegating the actual LLM API call to subclasses. - - Attributes: - api_key (str): The API key for authenticating with the LLM service. - model_name (str): The specific model identifier to be used for generation. - base_url (Optional[str]): Base URL for the API, if applicable. - temperature (Optional[float]): The sampling temperature for generation. + Handles LLM API calls with retry logic, validation, and concurrent revision generation. """ def __init__( self, api_key: str, model_name: str, - base_url: Optional[str] = None, - temperature: Optional[float] = 0.7, - logger: Optional[logging.Logger] = None, + base_url: str | None = None, + temperature: float | None = 0.7, + logger: logging.Logger | None = None, ): - """ - Initializes the BaseLLMClient. - - Args: - api_key: The API key for the LLM service. - model_name: The model identifier. - base_url: Optional base URL for the LLM API. - temperature: Optional sampling temperature. 
- logger: An optional logger instance. If not provided, a default logger is created. - """ self.api_key = api_key self.model_name = model_name self.base_url = base_url @@ -65,311 +68,219 @@ async def _execute_llm_call( self, system_prompt: str, user_prompt: str, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - ) -> str: + response_mode: ResponseMode = ResponseMode.TEXT, + response_model: type[Any] | None = None, + analytics_collector: WorkflowAnalyticsCollector | None = None, + **kwargs: Any, + ) -> Any: """ - Makes the actual API call to the LLM and returns the raw string content. - - This method must be implemented by concrete subclasses to interact with - their specific LLM provider's API. + Makes the actual API call to the LLM. Args: system_prompt: The system prompt for the LLM. user_prompt: The user prompt for the LLM. + response_mode: Whether to return raw text or structured output. + response_model: The Pydantic/SQLModel class for structured responses. analytics_collector: Optional analytics collector for tracking costs. + **kwargs: Additional provider-specific arguments. Returns: - The raw string content from the LLM response. Should return an empty - string if the LLM response was empty but did not constitute an API error. + - TEXT mode: Raw string content from the LLM + - STRUCTURED mode: Instance of response_model Raises: - LLMAPICallError: If the underlying API call fails. - Exception: For other unexpected errors during the API call. + LLMAPICallError: If the API call fails. + NotImplementedError: If the provider doesn't support the requested mode. """ ... 
- async def generate_structured( + async def _generate_single_revision( self, system_prompt: str, user_prompt: str, - response_model: Type[Any], - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - **kwargs: Any, - ) -> Any: + max_attempts: int, + validation_fn: Callable[[Any], dict[str, Any]] | None, + revision_index: int, + response_mode: ResponseMode = ResponseMode.TEXT, + response_model: type[Any] | None = None, + analytics_collector: WorkflowAnalyticsCollector | None = None, + ) -> dict[str, Any]: """ - Generates a structured output directly from the LLM. - This defaults to raising NotImplementedError for providers that don't support it. + Generates a single revision with retry logic. Args: - system_prompt: The system prompt. - user_prompt: The user prompt. - response_model: The Pydantic model class to parse the response into. - analytics_collector: Optional analytics collector. - **kwargs: Additional arguments. - - Returns: - An instance of response_model. + validation_fn: Optional function to validate/transform the response. + If None (structured mode), response is used as-is. """ - raise NotImplementedError( - "Structured generation is not supported by this provider." - ) - - async def _attempt_single_generation_and_validation( - self, - *, - system_prompt: str, - user_prompt: str, - validation_callable: Callable[[str, str], Dict[str, Any]], - revision_info_for_error: str, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - ) -> Dict[str, Any]: - """ - Performs one LLM call and one validation attempt. 
- """ - raw_response_content = await self._execute_llm_call( - system_prompt=system_prompt, - user_prompt=user_prompt, - analytics_collector=analytics_collector, - ) - - if not raw_response_content: - raise ValueError(f"{revision_info_for_error}: LLM returned empty content.") - - self.logger.debug(f"received {raw_response_content} from the llm") + last_error: Exception | None = None - validated_data = validation_callable( - raw_response_content, revision_info_for_error - ) - return validated_data - - async def _generate_one_revision_with_retries( - self, - *, - system_prompt: str, - user_prompt: str, - max_attempts: int, - validation_callable: Callable[[str, str], Dict[str, Any]], - analytics_collector: Optional[WorkflowAnalyticsCollector], - revision_index: int, - ) -> Dict[str, Any]: - """ - Manages the retry loop for generating a single valid revision. - """ - last_error: Optional[Exception] = None for attempt in range(max_attempts): - revision_info_for_error = ( - f"Revision {revision_index + 1}, Attempt {attempt + 1}" - ) + revision_info = f"Revision {revision_index + 1}, Attempt {attempt + 1}" + token = revision_context.set(revision_info) try: - validated_data = await self._attempt_single_generation_and_validation( + # Execute LLM call + response = await self._execute_llm_call( system_prompt=system_prompt, user_prompt=user_prompt, - validation_callable=validation_callable, - revision_info_for_error=revision_info_for_error, + response_mode=response_mode, + response_model=response_model, analytics_collector=analytics_collector, ) + + # Validate/transform if needed + if validation_fn: + if not response: + raise ValueError("LLM returned empty content") + result = validation_fn(response) + else: + # Structured mode - convert model to dict + if hasattr(response, "model_dump"): + result = response.model_dump() + elif hasattr(response, "dict"): + result = response.dict() + else: + result = response + if analytics_collector: 
analytics_collector.record_llm_api_call_success() - self.logger.debug( - f"{revision_info_for_error}: Successfully generated and validated." - ) - return validated_data - except (LLMOutputParseError, LLMOutputValidationError, ValueError) as e: - self.logger.warning( - f"{revision_info_for_error}: Validation or parsing error: {e}" - ) - last_error = e - except LLMAPICallError as e: - self.logger.warning(f"{revision_info_for_error}: API call error: {e}") - last_error = e - except Exception as e: - self.logger.warning( - f"{revision_info_for_error}: Unexpected error: {type(e).__name__} - {e}" - ) - last_error = Exception( - f"{revision_info_for_error}: Unexpected error: {type(e).__name__} - {e}" - ) - if attempt + 1 < max_attempts: - delay_multiplier = 2 if isinstance(last_error, LLMAPICallError) else 1 - delay = 0.5 * (attempt + 1) * delay_multiplier - self.logger.info( - f"{revision_info_for_error}: Retrying in {delay:.2f} seconds..." - ) - await asyncio.sleep(delay) - - if last_error: - if analytics_collector: - if isinstance(last_error, LLMAPICallError): - analytics_collector.record_llm_api_call_failure() - elif isinstance(last_error, LLMOutputParseError): - analytics_collector.record_llm_output_parse_error() - elif isinstance(last_error, LLMOutputValidationError): - analytics_collector.record_llm_output_validation_error() - raise last_error - - # This line should be unreachable, but linters might complain. 
- raise RuntimeError("Revision generation failed without a recorded error.") - - async def _generate_all_revisions( + self.logger.debug(f"{revision_info}: Success") + return result + + except Exception as e: + last_error = e + self.logger.warning(f"{revision_info}: {type(e).__name__} - {e}") + + # Record analytics for final attempt + if attempt + 1 == max_attempts and analytics_collector: + if isinstance(e, LLMAPICallError): + analytics_collector.record_llm_api_call_failure() + elif isinstance(e, LLMOutputParseError): + analytics_collector.record_llm_output_parse_error() + elif isinstance(e, LLMOutputValidationError): + analytics_collector.record_llm_output_validation_error() + + # Retry with backoff + if attempt + 1 < max_attempts: + delay = ( + 0.5 + * (attempt + 1) + * (2 if isinstance(e, LLMAPICallError) else 1) + ) + self.logger.info(f"Retrying in {delay:.2f}s...") + await asyncio.sleep(delay) + finally: + revision_context.reset(token) + + raise last_error or RuntimeError("Generation failed without recorded error") + + async def generate_revisions( self, - *, system_prompt: str, user_prompt: str, num_revisions: int, - max_validation_retries_per_revision: int, - validation_callable: Callable[[str, str], Any], - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - ) -> List[Any]: + max_attempts_per_revision: int = 3, + validation_fn: Callable[[Any], dict[str, Any]] | None = None, + response_mode: ResponseMode = ResponseMode.TEXT, + response_model: type[Any] | None = None, + analytics_collector: WorkflowAnalyticsCollector | None = None, + ) -> list[dict[str, Any]]: """ - Orchestrates the generation of all revisions concurrently. - """ - if max_validation_retries_per_revision < 1: - actual_attempts_per_revision = 1 - else: - actual_attempts_per_revision = max_validation_retries_per_revision + Generates multiple revisions concurrently. + Args: + validation_fn: Optional validation/transformation function. 
+ Required for TEXT mode, not used for STRUCTURED mode. + """ tasks = [ - self._generate_one_revision_with_retries( + self._generate_single_revision( system_prompt=system_prompt, user_prompt=user_prompt, - max_attempts=actual_attempts_per_revision, - validation_callable=validation_callable, - analytics_collector=analytics_collector, + max_attempts=max(1, max_attempts_per_revision), + validation_fn=validation_fn, revision_index=i, + response_mode=response_mode, + response_model=response_model, + analytics_collector=analytics_collector, ) for i in range(num_revisions) ] results = await asyncio.gather(*tasks, return_exceptions=True) - successful_revisions = [] + # Separate successes from failures + successful = [] failures = [] - for i, res in enumerate(results): - if isinstance(res, Exception): - failures.append(res) - self.logger.error(f"Revision {i + 1} failed: {res}") - elif isinstance(res, list): - successful_revisions.extend(res) + for i, result in enumerate(results): + if isinstance(result, Exception): + failures.append(result) + self.logger.error(f"Revision {i + 1} failed: {result}") + elif isinstance(result, list): + successful.extend(result) else: - successful_revisions.append(res) + successful.append(result) - num_successful = len(successful_revisions) - num_failures = len(failures) self.logger.info( - f"Revision generation summary: {num_successful} successful, {num_failures} failed." + f"Generated {len(successful)}/{num_revisions} revisions successfully" ) - self.logger.debug(f"Generated objects : {successful_revisions}") - if failures: - # If all revisions failed, raise an aggregate error. - # If some succeeded, this error could be logged or handled differently. - if not successful_revisions: - self.logger.error("All LLM revisions failed.") - raise LLMRevisionGenerationError( - "All LLM revisions failed.", failures=failures - ) - self.logger.warning( - f"Partial failure in revision generation: {num_failures} revision(s) failed." 
+ if not successful and num_revisions > 0: + raise LLMRevisionGenerationError( + "All LLM revisions failed.", failures=failures ) - return successful_revisions - - async def create_batch_job( - self, - requests: List[Dict[str, Any]], - endpoint: str = "/v1/chat/completions", - completion_window: str = "24h", - metadata: Optional[Dict[str, str]] = None, - ) -> Any: - """ - Creates a batch job for processing multiple requests. - - Args: - requests: List of request bodies. Each request should be a dictionary - representing the body of a single API call (e.g. chat completion). - Each request MUST have a 'custom_id' field for identification. - endpoint: The API endpoint to target (default: /v1/chat/completions). - completion_window: The time window for completion (default: 24h). - metadata: Optional metadata to attach to the batch. - - Returns: - The created batch job object. - - Raises: - NotImplementedError: If the provider does not support batch processing. - """ - raise NotImplementedError("Batch processing is not supported by this provider.") - - async def retrieve_batch_job(self, batch_id: str) -> Any: - """ - Retrieves the status and details of a batch job. - """ - raise NotImplementedError("Batch processing is not supported by this provider.") - - async def list_batch_jobs( - self, limit: int = 20, after: Optional[str] = None - ) -> Any: - """ - Lists batch jobs. - """ - raise NotImplementedError("Batch processing is not supported by this provider.") - - async def cancel_batch_job(self, batch_id: str) -> Any: - """ - Cancels a batch job. - """ - raise NotImplementedError("Batch processing is not supported by this provider.") - - async def retrieve_batch_results(self, file_id: str) -> str: - """ - Retrieves the content of a batch output file. 
- """ - raise NotImplementedError("Batch processing is not supported by this provider.") - - def extract_content_from_batch_response( - self, response: Dict[str, Any] - ) -> Optional[str]: - """ - Extracts the text content from a single item in a batch response file. + if failures: + self.logger.warning(f"{len(failures)} revision(s) failed") - Args: - response: A dictionary representing a single line/item from the batch output. + return successful - Returns: - The extracted content string, or None if extraction failed. - """ - raise NotImplementedError("Batch processing is not supported by this provider.") + # ========================================================================= + # HIGH-LEVEL CONVENIENCE METHODS + # ========================================================================= async def generate_json_revisions( self, system_prompt: str, user_prompt: str, num_revisions: int, - model_schema_map: Dict[str, Type[SQLModel]], - max_validation_retries_per_revision: int, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - ) -> List[Dict[str, Any]]: + model_schema_map: dict[str, type[SQLModel]], + max_validation_retries_per_revision: int = 3, + use_structured_output: bool = False, + analytics_collector: WorkflowAnalyticsCollector | None = None, + ) -> list[dict[str, Any]]: """ - Generates multiple JSON output revisions from the LLM, validating against a SQLModel. + Generates JSON revisions validated against SQLModel schemas. 
""" + response_mode = ResponseMode.TEXT + response_model = None + validation_fn = None - def validation_callable( - content: str, revision_info: str - ) -> List[Dict[str, Any]]: - return process_and_validate_llm_output( - raw_llm_content=content, - model_schema_map=model_schema_map, - revision_info_for_error=revision_info, - analytics_collector=analytics_collector, - ) + if use_structured_output and len(model_schema_map) == 1: + response_mode = ResponseMode.STRUCTURED + response_model = list(model_schema_map.values())[0] + else: + if use_structured_output: + self.logger.warning( + "Structured output with multiple schemas not supported. Using text mode." + ) - return await self._generate_all_revisions( + def validation_fn(content: str) -> list[dict[str, Any]]: + revision_info = revision_context.get() + return process_and_validate_llm_output( + raw_llm_content=content, + model_schema_map=model_schema_map, + revision_info_for_error=revision_info, + analytics_collector=analytics_collector, + ) + + return await self.generate_revisions( system_prompt=system_prompt, user_prompt=user_prompt, num_revisions=num_revisions, - max_validation_retries_per_revision=max_validation_retries_per_revision, - validation_callable=validation_callable, + max_attempts_per_revision=max_validation_retries_per_revision, + validation_fn=validation_fn, + response_mode=response_mode, + response_model=response_model, analytics_collector=analytics_collector, ) @@ -378,16 +289,17 @@ async def generate_and_validate_raw_json_output( system_prompt: str, user_prompt: str, num_revisions: int, - max_validation_retries_per_revision: int, - target_json_schema: Optional[Dict[str, Any]] = None, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, + max_validation_retries_per_revision: int = 3, + target_json_schema: dict[str, Any] | None = None, attempt_unwrap: bool = True, - ) -> List[Dict[str, Any]]: + analytics_collector: WorkflowAnalyticsCollector | None = None, + ) -> list[dict[str, Any]]: 
""" - Generates multiple JSON output revisions, validating against a raw JSON schema. + Generates JSON revisions validated against a raw JSON schema. """ - def validation_callable(content: str, revision_info: str) -> Dict[str, Any]: + def validation_fn(content: str) -> dict[str, Any]: + revision_info = revision_context.get() return process_and_validate_raw_json( raw_llm_content=content, revision_info_for_error=revision_info, @@ -395,11 +307,63 @@ def validation_callable(content: str, revision_info: str) -> Dict[str, Any]: attempt_unwrap=attempt_unwrap, ) - return await self._generate_all_revisions( + return await self.generate_revisions( system_prompt=system_prompt, user_prompt=user_prompt, num_revisions=num_revisions, - max_validation_retries_per_revision=max_validation_retries_per_revision, - validation_callable=validation_callable, + max_attempts_per_revision=max_validation_retries_per_revision, + validation_fn=validation_fn, analytics_collector=analytics_collector, ) + + # ========================================================================= + # BATCH PROCESSING (Optional - Provider-Specific) + # ========================================================================= + + async def create_batch_job( + self, + requests: list[dict[str, Any]], + endpoint: str = "/v1/chat/completions", + completion_window: str = "24h", + metadata: dict[str, str] | None = None, + response_model: type[Any] | None = None, + ) -> Any: + """Creates a batch job. Override in subclass if supported.""" + raise NotImplementedError("Batch processing is not supported by this provider") + + async def get_batch_status(self, batch_id: str) -> "ProviderBatchStatus": + """Retrieves a standardized batch job status. Override in subclass if supported.""" + raise NotImplementedError("Batch processing is not supported by this provider") + + async def list_batch_jobs(self, limit: int = 20, after: str | None = None) -> Any: + """Lists batch jobs. 
Override in subclass if supported.""" + raise NotImplementedError("Batch processing is not supported by this provider") + + async def cancel_batch_job(self, batch_id: str) -> Any: + """Cancels a batch job. Override in subclass if supported.""" + raise NotImplementedError("Batch processing is not supported by this provider") + + async def retrieve_batch_results(self, file_id: str) -> str: + """Retrieves batch results. Override in subclass if supported.""" + raise NotImplementedError("Batch processing is not supported by this provider") + + def extract_content_from_batch_response( + self, response: dict[str, Any] + ) -> str | None: + """Extracts content from batch response. Override in subclass if supported.""" + raise NotImplementedError("Batch processing is not supported by this provider") + + def prepare_request( + self, + system_prompt: str, + user_prompt: str, + json_schema: Any | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Prepares a request dictionary for batch processing. + Override in subclasses to provide provider-specific formatting. 
+ """ + raise NotImplementedError( + "Batch request preparation is not supported by this provider" + ) diff --git a/src/extrai/core/batch/__init__.py b/src/extrai/core/batch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/extrai/core/batch/batch_pipeline.py b/src/extrai/core/batch/batch_pipeline.py new file mode 100644 index 0000000..9c1f41d --- /dev/null +++ b/src/extrai/core/batch/batch_pipeline.py @@ -0,0 +1,205 @@ +import asyncio +import logging +from typing import Any, Union + +from sqlalchemy.orm import Session +from sqlmodel import SQLModel + +from extrai.core.analytics_collector import WorkflowAnalyticsCollector +from extrai.core.base_llm_client import BaseLLMClient +from extrai.core.batch_models import BatchJobStatus, BatchProcessResult +from extrai.core.client_rotator import ClientRotator +from extrai.core.entity_counter import EntityCounter +from extrai.core.extraction_config import ExtractionConfig +from extrai.core.extraction_context_preparer import ExtractionContextPreparer +from extrai.core.extraction_request_factory import ExtractionRequestFactory +from extrai.core.model_registry import ModelRegistry +from extrai.core.model_wrapper_builder import ModelWrapperBuilder +from extrai.core.prompt_builder import PromptBuilder +from extrai.core.shared.consensus_runner import ConsensusRunner +from extrai.core.shared.hierarchical_coordinator import HierarchicalCoordinator + +from .batch_processor import BatchProcessor +from .batch_result_retriever import BatchResultRetriever +from .batch_status_checker import BatchStatusChecker +from .batch_submitter import BatchSubmitter + + +class BatchPipeline: + """Manages batch extraction workflows.""" + + def __init__( + self, + model_registry: ModelRegistry, + llm_client: Union["BaseLLMClient", list["BaseLLMClient"]], + config: ExtractionConfig, + analytics_collector: WorkflowAnalyticsCollector, + logger: logging.Logger, + counting_llm_client: BaseLLMClient | None = None, + ): + 
self.model_registry = model_registry + self.config = config + self.analytics_collector = analytics_collector + self.logger = logger + + self.client_rotator = ClientRotator(llm_client) + self.prompt_builder = PromptBuilder(model_registry, logger=logger) + c_client = counting_llm_client or llm_client + if isinstance(c_client, list): + c_client = c_client[0] + + self.entity_counter = EntityCounter( + model_registry, c_client, config, analytics_collector, logger=logger + ) + self.context_preparer = ExtractionContextPreparer( + model_registry, + analytics_collector, + config.max_validation_retries_per_revision, + logger=logger, + ) + self.model_wrapper_builder = ModelWrapperBuilder() + self.consensus_runner = ConsensusRunner(config, analytics_collector, logger) + self.request_factory = ExtractionRequestFactory( + model_registry, + self.prompt_builder, + self.model_wrapper_builder, + logger=logger, + ) + self.hierarchical_coordinator = HierarchicalCoordinator(model_registry, logger) + + # Instantiate components + self.submitter = BatchSubmitter( + model_registry, + self.client_rotator, + config, + self.entity_counter, + self.context_preparer, + self.request_factory, + logger, + ) + self.status_checker = BatchStatusChecker( + self.client_rotator, self.entity_counter, logger + ) + self.retriever = BatchResultRetriever( + model_registry, logger, analytics_collector + ) + self.processor = BatchProcessor( + model_registry, + config, + analytics_collector, + self.client_rotator, + self.entity_counter, + self.submitter, + self.status_checker, + self.retriever, + self.consensus_runner, + self.hierarchical_coordinator, + logger, + ) + + async def submit_batch( + self, + db_session: Session, + input_strings: list[str], + extraction_example_json: str = "", + extraction_example_object: SQLModel | list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + 
custom_context: str | list[str] = "", + count_entities: bool = False, + custom_counting_context: str | list[str] = "", + ) -> str: + """Submits a batch job and returns root_batch_id.""" + return await self.submitter.submit_batch( + db_session, + input_strings, + extraction_example_json, + extraction_example_object, + custom_extraction_process, + custom_extraction_guidelines, + custom_final_checklist, + custom_context, + count_entities, + custom_counting_context, + ) + + async def create_continuation_batch( + self, + db_session: Session, + original_batch_id: str, + new_config_dict: dict[str, Any], + start_from_step_index: int, + ) -> str: + """ + Creates a new batch cycle continuing from a previous batch's state. + Copies completed steps up to start_from_step_index into the new batch. + """ + return await self.submitter.create_continuation_batch( + db_session, original_batch_id, new_config_dict, start_from_step_index + ) + + async def get_status( + self, root_batch_id: str, db_session: Session + ) -> BatchJobStatus: + return await self.status_checker.get_status(root_batch_id, db_session) + + async def process_batch( + self, root_batch_id: str, db_session: Session + ) -> BatchProcessResult: + return await self.processor.process_batch(root_batch_id, db_session) + + async def monitor_batch_job( + self, root_batch_id: str, db_session: Session, poll_interval: int = 60 + ) -> BatchProcessResult: + """ + Polls the batch job status until it reaches a terminal state. + """ + self.logger.info(f"Monitoring batch job {root_batch_id}...") + + while True: + status = await self.get_status(root_batch_id, db_session) + self.logger.info(f"Batch Status: {status}") + + if status in [ + BatchJobStatus.READY_TO_PROCESS, + BatchJobStatus.COUNTING_READY_TO_PROCESS, + ]: + self.logger.info("Batch ready! 
Processing...") + result = await self.process_batch(root_batch_id, db_session) + + if result.status == BatchJobStatus.COMPLETED: + self.logger.info("Batch workflow completed successfully.") + if result.hydrated_objects: + self.processor.result_processor.persist( + result.hydrated_objects, db_session + ) + return result + + # Other non-terminal statuses mean we should continue polling + elif result.status not in [ + BatchJobStatus.COMPLETED, + BatchJobStatus.FAILED, + BatchJobStatus.CANCELLED, + ]: + self.logger.info( + f"Intermediate step processed (new status: {result.status}). Continuing to monitor..." + ) + else: + # Processing returned a terminal status + self.logger.error( + f"Batch processing failed with terminal status: {result.status} - {result.message}" + ) + return result + + elif status in [ + BatchJobStatus.COMPLETED, + BatchJobStatus.FAILED, + BatchJobStatus.CANCELLED, + ]: + self.logger.info(f"Batch job reached terminal state: {status}") + return await self.process_batch(root_batch_id, db_session) + + # Any other status is an in-progress state, so we wait. 
class BatchProcessor:
    """Drives a batch job through its post-submission lifecycle.

    Once the provider reports a batch as ready, this class retrieves the raw
    results, runs consensus over the revisions, records completed workflow
    steps, and either submits the next phase (counting or extraction) or
    finalizes the whole job.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        config: ExtractionConfig,
        analytics_collector: WorkflowAnalyticsCollector,
        client_rotator: ClientRotator,
        entity_counter: EntityCounter,
        submitter: BatchSubmitter,
        status_checker: BatchStatusChecker,
        retriever: BatchResultRetriever,
        consensus_runner: ConsensusRunner,
        hierarchical_coordinator: HierarchicalCoordinator,
        logger: logging.Logger,
    ):
        self.model_registry = model_registry
        self.config = config
        self.analytics_collector = analytics_collector
        self.client_rotator = client_rotator
        self.entity_counter = entity_counter
        self.submitter = submitter
        self.status_checker = status_checker
        self.retriever = retriever
        self.consensus_runner = consensus_runner
        self.hierarchical_coordinator = hierarchical_coordinator
        self.logger = logger

        # Hydration/persistence of final results is delegated to ResultProcessor.
        self.result_processor = ResultProcessor(
            self.model_registry, self.analytics_collector, self.logger
        )

    def _get_absolute_step_index(
        self, model_index: int, phase: str, count_entities: bool
    ) -> int:
        """Calculates the absolute workflow step index based on model index and phase.

        Without counting each model owns one step. With counting enabled every
        model owns two consecutive steps: an even counting step followed by an
        odd extraction step (model 0 -> steps 0/1, model 1 -> steps 2/3, ...).
        """
        if not count_entities:
            return model_index
        counting_step = model_index * 2
        return counting_step if phase == "counting" else counting_step + 1

    async def process_batch(
        self, root_batch_id: str, db_session: Session
    ) -> BatchProcessResult:
        """Dispatch a batch to the handler matching its current status."""
        status = await self.status_checker.get_status(root_batch_id, db_session)
        context = db_session.get(BatchJobContext, root_batch_id)

        # 1. Already completed with stored results -> re-finalize.
        if status == BatchJobStatus.COMPLETED and context.results:
            return await self._finalize_completion(context, db_session)

        # 2. Counting phase completed -> submit extraction.
        if status == BatchJobStatus.COUNTING_READY_TO_PROCESS:
            return await self._process_counting_completion(context, db_session)

        # 3. Extraction phase ready -> process results.
        if status == BatchJobStatus.READY_TO_PROCESS:
            return await self._process_extraction_completion(context, db_session)

        return BatchProcessResult(
            status=status, message="Batch not ready for processing"
        )

    async def _process_counting_completion(
        self, context: BatchJobContext, db_session: Session
    ) -> BatchProcessResult:
        """Parse counting results, run consensus, then submit extraction."""
        try:
            # Counting always runs on the dedicated counting client.
            client = self.entity_counter.llm_client
            results_content = await client.retrieve_batch_results(
                context.current_batch_id
            )

            # Debug log the raw result format for provider-format diagnosis.
            self.logger.debug(
                f"Counting results type: {type(results_content)}, "
                f"first 200 chars: {str(results_content)[:200]}"
            )

            # Determine the model names the counting output is validated against.
            if context.config.hierarchical:
                current_idx = context.config.current_model_index
                if 0 <= current_idx < len(self.model_registry.models):
                    target_model_names = [
                        self.model_registry.models[current_idx].__name__
                    ]
                else:
                    target_model_names = self.model_registry.get_all_model_names()
            else:
                target_model_names = self.model_registry.get_all_model_names()

            import json

            revisions = []

            # Normalize the provider payload to a list of per-revision entries.
            # Handles both a JSONL string and an already-split list.
            raw_lines = []
            if isinstance(results_content, str):
                raw_lines = [
                    line.strip()
                    for line in results_content.strip().split('\n')
                    if line.strip()
                ]
            elif isinstance(results_content, list):
                raw_lines = results_content
            else:
                raise ValueError(
                    f"Unexpected results content type: {type(results_content)}"
                )

            if not raw_lines:
                raise ValueError("Empty results content")

            # Each entry becomes one counting revision.
            for entry in raw_lines:
                try:
                    payload = json.loads(entry) if isinstance(entry, str) else entry

                    # OpenAI batch wrapper: {"response": {"body": {...}}}
                    if "response" in payload and "body" in payload.get("response", {}):
                        body = payload["response"]["body"]

                        if self.analytics_collector:
                            track_usage_from_response(
                                payload,
                                client,
                                self.analytics_collector,
                                context.current_batch_id,
                                extra_details={"phase": "counting"},
                            )

                        if "choices" in body and body["choices"]:
                            content = body["choices"][0]["message"]["content"]
                            revisions.append(json.loads(content))
                    else:
                        # Maybe it's directly the JSON string or dict.
                        if isinstance(payload, str):
                            revisions.append(json.loads(payload))
                        elif isinstance(payload, dict):
                            revisions.append(payload)

                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse counting result as JSON: {e}")
                    continue

            self.logger.debug(f"Parsed {len(revisions)} counting revisions")

            # Recreate the original prompts so consensus can fall back to them.
            from extrai.core.prompts.counting import (
                generate_entity_counting_system_prompt,
                generate_entity_counting_user_prompt,
            )

            schema_json = self.model_registry.get_schema_for_models(target_model_names)
            system_prompt = generate_entity_counting_system_prompt(
                target_model_names,
                schema_json,
                context.config.custom_counting_context,
            )
            user_prompt = generate_entity_counting_user_prompt(context.input_strings)
            target_json_schema = (
                self.entity_counter.get_counting_model(
                    target_model_names
                ).model_json_schema()
                if self.config.use_structured_output
                else None
            )

            consensus_items = (
                await self.entity_counter.counting_consensus.achieve_consensus(
                    revisions=revisions,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    target_json_schema=target_json_schema,
                )
            )

            # Drop any hallucinated models not among the expected targets.
            entity_descriptions = [
                item
                for item in consensus_items
                if item.get("model") in target_model_names
            ]

            self.logger.debug(f"Extracted entity descriptions: {entity_descriptions}")

            context.config = context.config.evolve(
                expected_entity_descriptions=entity_descriptions
            )

            # Persist a completed counting step at its absolute index.
            abs_index = self._get_absolute_step_index(
                context.config.current_model_index,
                "counting",
                context.config.count_entities,
            )

            step = BatchJobStep(
                batch_id=context.root_batch_id,
                step_index=abs_index,
                status=BatchJobStatus.COMPLETED,
                result=entity_descriptions,
                metadata_json={"phase": "counting"},
            )
            db_session.add(step)
            db_session.add(context)
            db_session.commit()

            # Counting done -> kick off extraction for the same model index.
            await self.submitter._submit_extraction_phase(
                context, db_session, step_index=context.config.current_model_index
            )

            return BatchProcessResult(
                status=BatchJobStatus.SUBMITTED,
                message="Counting complete, extraction submitted.",
            )
        except Exception as e:
            self.logger.error(
                f"Error processing counting completion: {e}", exc_info=True
            )
            context.status = BatchJobStatus.FAILED
            db_session.add(context)
            db_session.commit()
            return BatchProcessResult(
                status=BatchJobStatus.FAILED, message=f"Processing failed: {e}"
            )

    async def _process_extraction_completion(
        self, context: BatchJobContext, db_session: Session
    ) -> BatchProcessResult:
        """Validate extraction results, run consensus, and advance the workflow."""
        try:
            client = self.client_rotator.get_next_client()
            results, validation_errors = (
                await self.retriever.retrieve_and_validate_results(context, client)
            )

            # Any invalid revision triggers a partial retry of just the failures.
            if validation_errors:
                return await self._handle_batch_retry(
                    context, db_session, results, validation_errors
                )

            # Re-run consensus including partial results stored by earlier retries.
            all_revisions = [r["revisions"] for r in results]
            all_revisions.extend(context.config.partial_results)
            all_revisions = normalize_json_revisions(all_revisions)

            processed_results = self.consensus_runner.run(all_revisions)

            return await self._process_hierarchical_step(
                context, db_session, processed_results
            )

        except Exception as e:
            self.logger.error(
                f"Error processing batch completion for {context.root_batch_id}: {e}",
                exc_info=True,
            )
            context.status = BatchJobStatus.FAILED
            db_session.add(context)
            db_session.commit()
            return BatchProcessResult(
                status=BatchJobStatus.FAILED, message=f"Processing failed: {e}"
            )

    async def _process_hierarchical_step(
        self,
        context: BatchJobContext,
        db_session: Session,
        processed_results: list[dict],
    ):
        """Record a completed extraction step; finalize or submit the next phase."""
        abs_index = self._get_absolute_step_index(
            context.config.current_model_index,
            "extraction",
            context.config.count_entities,
        )

        step = BatchJobStep(
            batch_id=context.root_batch_id,
            step_index=abs_index,
            status=BatchJobStatus.COMPLETED,
            result=processed_results,
        )
        db_session.add(step)
        db_session.commit()

        # Non-hierarchical runs are single-step; otherwise ask the coordinator.
        if not context.config.hierarchical:
            final_step = True
        else:
            final_step = self.hierarchical_coordinator.is_final_step(
                context.config.current_model_index
            )

        if final_step:
            return await self._finalize_completion(context, db_session)

        next_step_index = self.hierarchical_coordinator.next_index(
            context.config.current_model_index
        )

        if context.config.count_entities:
            await self.submitter._submit_counting_phase(
                context, db_session, step_index=next_step_index
            )
            return BatchProcessResult(
                status=BatchJobStatus.COUNTING_SUBMITTED,
                message=f"Step {context.config.current_model_index} complete, counting for step {next_step_index} submitted.",
            )
        else:
            await self.submitter._submit_extraction_phase(
                context, db_session, step_index=next_step_index
            )
            return BatchProcessResult(
                status=BatchJobStatus.SUBMITTED,
                message=f"Step {context.config.current_model_index} complete, extraction for step {next_step_index} submitted.",
            )

    async def _handle_batch_retry(
        self,
        context: BatchJobContext,
        db_session: Session,
        validated_results: list[dict],
        validation_errors: list[dict],
    ) -> BatchProcessResult:
        """Stash valid revisions and resubmit only the failed ones."""
        if not validated_results:
            # Nothing usable survived validation -> the job cannot proceed.
            context.status = BatchJobStatus.FAILED
            db_session.add(context)
            db_session.commit()
            return BatchProcessResult(
                status=BatchJobStatus.FAILED,
                message="All revisions failed validation, cannot retry.",
                errors=validation_errors,
            )

        # Accumulate the valid partial results on the job config.
        partial_results = [r["revisions"] for r in validated_results]
        context.config = context.config.evolve(
            partial_results=context.config.partial_results + partial_results
        )
        db_session.add(context)
        db_session.commit()

        # Resubmit with exactly as many revisions as failed.
        num_to_retry = len(validation_errors)
        await self.submitter._submit_extraction_phase(
            context,
            db_session,
            step_index=context.config.current_model_index,
            num_revisions=num_to_retry,
        )

        return BatchProcessResult(
            status=BatchJobStatus.SUBMITTED,
            message=f"Partial success. Retrying {num_to_retry} failed revisions.",
            errors=validation_errors,
        )

    async def _finalize_completion(
        self, context: BatchJobContext, db_session: Session
    ) -> BatchProcessResult:
        """Merge all extraction step results, hydrate, persist, and close the job."""
        completed_steps = db_session.exec(
            select(BatchJobStep)
            .where(BatchJobStep.batch_id == context.root_batch_id)
            .where(BatchJobStep.status == BatchJobStatus.COMPLETED)
            .order_by(BatchJobStep.step_index)
        ).all()

        # Flatten extraction step results; counting steps are bookkeeping only.
        merged_results = [
            item
            for s in completed_steps
            if not s.metadata_json or s.metadata_json.get("phase") != "counting"
            for item in s.result
        ]

        hydrated = self.result_processor.hydrate(merged_results, db_session)

        # Persist objects (this links FKs and commits).
        self.result_processor.persist(hydrated, db_session)

        context.results = [p.model_dump(mode='json') for p in hydrated]
        context.status = BatchJobStatus.COMPLETED
        context.updated_at = datetime.now(UTC)
        db_session.add(context)
        db_session.commit()

        return BatchProcessResult(
            status=BatchJobStatus.COMPLETED,
            message="Batch processing complete.",
            hydrated_objects=hydrated,
            original_pk_map=self.result_processor.original_pk_map,
        )
class BatchResultRetriever:
    """Fetches provider batch results and validates each revision.

    Splits the raw payload into per-revision entries, unwraps provider batch
    envelopes where supported, and runs schema validation, separating good
    revisions from failures so the caller can retry only the failures.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        logger: logging.Logger,
        analytics_collector: WorkflowAnalyticsCollector | None = None,
    ):
        self.model_registry = model_registry
        self.logger = logger
        self.analytics_collector = analytics_collector

    async def retrieve_and_validate_results(
        self, context: BatchJobContext, client: BaseLLMClient
    ) -> tuple[list[dict], list[dict]]:
        """Return (validated_results, validation_errors) for the current batch."""
        results_content = await client.retrieve_batch_results(context.current_batch_id)

        # DEBUG: Log context info for diagnosis
        self.logger.debug(
            f"[BatchResultRetriever] current_model_index={context.config.current_model_index}, "
            f"hierarchical={context.config.hierarchical}, root_model={self.model_registry.root_model.__name__}"
        )

        # Pick the default model type for validation. In hierarchical mode the
        # current step's model applies; otherwise fall back to the root model.
        current_model_index = context.config.current_model_index
        if context.config.hierarchical and 0 <= current_model_index < len(
            self.model_registry.models
        ):
            default_model_type = self.model_registry.models[
                current_model_index
            ].__name__
            self.logger.debug(
                f"[BatchResultRetriever] Using hierarchical model type: {default_model_type}"
            )
        else:
            default_model_type = self.model_registry.root_model.__name__
            self.logger.debug(
                f"[BatchResultRetriever] Using root model type: {default_model_type}"
            )

        # Normalize to a list of entries: JSONL string, list, or single object.
        if isinstance(results_content, str):
            result_lines = [
                line.strip()
                for line in results_content.strip().split('\n')
                if line.strip()
            ]
        else:
            result_lines = (
                results_content
                if isinstance(results_content, list)
                else [results_content]
            )

        validated_results = []
        validation_errors = []
        model_schema = self.model_registry.model_map

        for res in result_lines:
            try:
                # Batch responses may be wrapped like:
                # {"id": "...", "response": {"body": {"choices": [{"message": {"content": "..."}}]}}}
                raw_content = res
                if isinstance(res, str):
                    try:
                        parsed = json.loads(res)
                        try:
                            if self.analytics_collector:
                                track_usage_from_response(
                                    parsed,
                                    client,
                                    self.analytics_collector,
                                    context.current_batch_id,
                                )

                            extracted = client.extract_content_from_batch_response(
                                parsed
                            )
                            if extracted:
                                raw_content = extracted
                        except NotImplementedError:
                            # Client does not support extraction, proceed with parsed or raw
                            pass
                    except json.JSONDecodeError:
                        pass  # Use raw string as-is

                processed = process_and_validate_llm_output(
                    raw_content,
                    model_schema,
                    self.logger,
                    default_model_type=default_model_type,
                )
                validated_results.append({"revisions": processed})
            except Exception as e:
                self.logger.warning(
                    f"Validation failed for a batch result: {e}", exc_info=True
                )
                validation_errors.append({"original": res, "error": str(e)})

        return validated_results, validation_errors
class BatchStatusChecker:
    """Maps provider-side batch status onto the local workflow status.

    Counting-phase jobs are polled via the counting client; extraction jobs
    via the rotating extraction clients. Local status transitions are written
    back to the BatchJobContext row.
    """

    def __init__(
        self,
        client_rotator: ClientRotator,
        entity_counter: EntityCounter,
        logger: logging.Logger,
    ):
        self.client_rotator = client_rotator
        self.entity_counter = entity_counter
        self.logger = logger

    async def get_status(
        self, root_batch_id: str, db_session: Session
    ) -> BatchJobStatus:
        """Refresh and return the workflow status for a batch job."""
        context = db_session.get(BatchJobContext, root_batch_id)
        if not context:
            raise ValueError(f"Batch job {root_batch_id} not found")

        # States that require no provider round-trip.
        terminal_states = [
            BatchJobStatus.COMPLETED,
            BatchJobStatus.FAILED,
            BatchJobStatus.CANCELLED,
            BatchJobStatus.READY_TO_PROCESS,
            BatchJobStatus.COUNTING_READY_TO_PROCESS,
        ]
        if context.status in terminal_states:
            return context.status

        try:
            # The phase decides both the polling client and the status mapping.
            in_counting_phase = context.status in [
                BatchJobStatus.COUNTING_SUBMITTED,
                BatchJobStatus.COUNTING_PROCESSING,
            ]
            if in_counting_phase:
                client = self.entity_counter.llm_client
            else:
                client = self.client_rotator.get_next_client()

            provider_status = await client.get_batch_status(context.current_batch_id)

            if provider_status == ProviderBatchStatus.COMPLETED:
                new_status = (
                    BatchJobStatus.COUNTING_READY_TO_PROCESS
                    if in_counting_phase
                    else BatchJobStatus.READY_TO_PROCESS
                )
            elif provider_status == ProviderBatchStatus.FAILED:
                new_status = BatchJobStatus.FAILED
            elif provider_status == ProviderBatchStatus.CANCELLED:
                new_status = BatchJobStatus.CANCELLED
            elif provider_status == ProviderBatchStatus.PENDING:
                new_status = (
                    BatchJobStatus.COUNTING_SUBMITTED
                    if in_counting_phase
                    else BatchJobStatus.SUBMITTED
                )
            else:  # PROCESSING
                new_status = (
                    BatchJobStatus.COUNTING_PROCESSING
                    if in_counting_phase
                    else BatchJobStatus.PROCESSING
                )

            if new_status != context.status:
                context.status = new_status
                context.updated_at = datetime.now(UTC)
                db_session.add(context)
                db_session.commit()

        except Exception as e:
            # Best-effort poll: log and fall through to the last known status.
            self.logger.error(f"Failed to check batch status: {e}", exc_info=True)

        return context.status
client_rotator + self.config = config + self.entity_counter = entity_counter + self.context_preparer = context_preparer + self.request_factory = request_factory + self.logger = logger + + def _get_absolute_step_index(self, model_index: int, phase: str, count_entities: bool) -> int: + """Calculates the absolute workflow step index based on model index and phase.""" + if not count_entities: + return model_index + + # If counting enabled: + # Model 0: Count (0), Extract (1) + # Model 1: Count (2), Extract (3) + base = model_index * 2 + return base if phase == "counting" else base + 1 + + async def submit_batch( + self, + db_session: Session, + input_strings: list[str], + extraction_example_json: str = "", + extraction_example_object: SQLModel | list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + custom_context: str | list[str] = "", + count_entities: bool = False, + custom_counting_context: str | list[str] = "", + ) -> str: + """Submits a batch job and returns root_batch_id.""" + if not input_strings: + raise ValueError("input_strings cannot be empty") + + # Prepare example + example_json = await self.context_preparer.prepare_example( + extraction_example_json, + extraction_example_object, + self.client_rotator.get_next_client, + ) + + root_batch_id = str(uuid.uuid4()) + + # Initialize configuration + config_data = BatchJobConfig( + extraction_example_json=example_json, + custom_extraction_process=custom_extraction_process, + custom_extraction_guidelines=custom_extraction_guidelines, + custom_final_checklist=custom_final_checklist, + custom_context=custom_context, + count_entities=count_entities, + custom_counting_context=custom_counting_context, + schema_json=self.model_registry.llm_schema_json, + ) + + if self.config.use_hierarchical_extraction: + config_data = config_data.evolve(hierarchical=True, current_model_index=0) + + context = 
BatchJobContext( + root_batch_id=root_batch_id, + current_batch_id="pending", + status=BatchJobStatus.SUBMITTED, + input_strings=input_strings, + config=config_data, + ) + db_session.add(context) + db_session.commit() + + if count_entities: + await self._submit_counting_phase(context, db_session) + else: + await self._submit_extraction_phase(context, db_session, step_index=0) + + self.logger.info(f"Batch workflow initiated: {root_batch_id}") + return root_batch_id + + async def create_continuation_batch( + self, + db_session: Session, + original_batch_id: str, + new_config_dict: dict[str, Any], + start_from_step_index: int, + ) -> str: + """ + Creates a new batch cycle continuing from a previous batch's state. + Copies completed steps up to start_from_step_index into the new batch. + """ + old_context = db_session.get(BatchJobContext, original_batch_id) + if not old_context: + raise ValueError("Old batch not found") + + new_batch_id = str(uuid.uuid4()) + + new_config = BatchJobConfig(**new_config_dict) + + # Ensure new config has required fields + if self.config.use_hierarchical_extraction and not new_config.hierarchical: + new_config = new_config.evolve( + hierarchical=True, current_model_index=start_from_step_index + ) + + if old_context.config.expected_entity_descriptions is not None: + new_config = new_config.evolve( + expected_entity_descriptions=old_context.config.expected_entity_descriptions + ) + + new_context = BatchJobContext( + root_batch_id=new_batch_id, + current_batch_id="pending", + status=BatchJobStatus.SUBMITTED, + input_strings=old_context.input_strings, + config=new_config, + ) + db_session.add(new_context) + db_session.commit() + + # Copy valid steps from old batch + if start_from_step_index > 0: + old_steps = db_session.exec( + select(BatchJobStep) + .where(BatchJobStep.batch_id == original_batch_id) + .where(BatchJobStep.status == BatchJobStatus.COMPLETED) + ).all() + + for step in old_steps: + effective_index = step.step_index + + if 
effective_index < start_from_step_index: + new_step = BatchJobStep( + batch_id=new_batch_id, + step_index=step.step_index, + status=step.status, + result=step.result, + metadata_json=step.metadata_json, + ) + db_session.add(new_step) + + db_session.commit() + + self.logger.info( + f"Created continuation batch {new_batch_id} from {original_batch_id}, starting at step {start_from_step_index}" + ) + + # Determine starting phase and normalize step index + target_step_index = start_from_step_index + is_counting_phase = False + + if new_context.config.count_entities: + if new_context.config.hierarchical: + # Interleaved: Even=Count, Odd=Extract. Model = step // 2 + target_step_index = start_from_step_index // 2 + is_counting_phase = (start_from_step_index % 2 == 0) + else: + # Non-hierarchical: 0=Count, 1=Extract + target_step_index = 0 + is_counting_phase = (start_from_step_index == 0) + + if is_counting_phase: + await self._submit_counting_phase( + new_context, db_session, step_index=target_step_index + ) + else: + await self._submit_extraction_phase( + new_context, db_session, step_index=target_step_index + ) + + return new_batch_id + + async def _submit_counting_phase( + self, + context: BatchJobContext, + db_session: Session, + step_index: int | None = None, + ): + if context.config.hierarchical and step_index is not None: + context.config = context.config.evolve(current_model_index=step_index) + db_session.add(context) + db_session.commit() + + input_strings = context.input_strings + + # Determine which models to count + current_step_index = step_index + if context.config.hierarchical: + idx = ( + step_index + if step_index is not None + else context.config.current_model_index + ) + current_step_index = idx + if 0 <= idx < len(self.model_registry.models): + model_names = [self.model_registry.models[idx].__name__] + else: + model_names = self.model_registry.get_all_model_names() + else: + model_names = self.model_registry.get_all_model_names() + + # Retrieve 
previous entities from completed steps + previous_entities = [] + if ( + context.config.hierarchical + and current_step_index is not None + and current_step_index > 0 + ): + step_threshold = self._get_absolute_step_index( + current_step_index, "counting", context.config.count_entities + ) + steps = db_session.exec( + select(BatchJobStep) + .where(BatchJobStep.batch_id == context.root_batch_id) + .where(BatchJobStep.step_index < step_threshold) + .where(BatchJobStep.status == BatchJobStatus.COMPLETED) + .order_by(BatchJobStep.step_index) + ).all() + for s in steps: + # Exclude counting phases from previous entities + if s.metadata_json and s.metadata_json.get("phase") == "counting": + continue + previous_entities.extend(s.result) + + custom_counting_context = context.config.custom_counting_context + examples = context.config.extraction_example_json + + total_steps = ( + len(self.model_registry.models) + if context.config.hierarchical + else 1 + ) + resolved_context = resolve_step_param( + custom_counting_context, + current_step_index if current_step_index is not None else 0, + total_steps, + ) + + system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts( + input_strings, + model_names, + resolved_context, + previous_entities=previous_entities if previous_entities else None, + examples=examples, + ) + + client = self.entity_counter.llm_client + requests = self._create_batch_requests( + system_prompt, user_prompt, num_revisions=self.config.num_counting_revisions, override_client=client + ) + + response_model = None + if self.config.use_structured_output: + response_model = self.entity_counter.get_counting_model(model_names) + + batch_job = await client.create_batch_job( + requests, response_model=response_model + ) + if hasattr(batch_job, "id"): + provider_batch_id = batch_job.id + elif hasattr(batch_job, "name"): + provider_batch_id = batch_job.name + else: + provider_batch_id = str(batch_job) + + context.current_batch_id = provider_batch_id + 
context.status = BatchJobStatus.COUNTING_SUBMITTED + context.updated_at = datetime.now(UTC) + db_session.add(context) + db_session.commit() + + self.logger.info( + f"Submitted counting batch for the models {model_names} for {context.root_batch_id}: {provider_batch_id}" + ) + + async def _submit_extraction_phase( + self, + context: BatchJobContext, + db_session: Session, + step_index: int = 0, + num_revisions: int | None = None, + ): + # Update current index in config if hierarchical + if context.config.hierarchical: + context.config = context.config.evolve(current_model_index=step_index) + db_session.add(context) + db_session.commit() + + # Retrieve previous entities from completed steps + previous_entities = [] + if context.config.hierarchical and step_index > 0: + step_threshold = self._get_absolute_step_index( + step_index, "extraction", context.config.count_entities + ) + steps = db_session.exec( + select(BatchJobStep) + .where(BatchJobStep.batch_id == context.root_batch_id) + .where(BatchJobStep.step_index < step_threshold) + .where(BatchJobStep.status == BatchJobStatus.COMPLETED) + .order_by(BatchJobStep.step_index) + ).all() + for s in steps: + # Exclude counting phases from previous entities + if s.metadata_json and s.metadata_json.get("phase") == "counting": + continue + previous_entities.extend(s.result) + + # Prepare request + total_steps = ( + len(self.model_registry.models) + if context.config.hierarchical + else 1 + ) + + request = self.request_factory.prepare_request( + input_strings=context.input_strings, + config=self.config, + extraction_example_json=context.config.extraction_example_json, + custom_extraction_process=resolve_step_param( + context.config.custom_extraction_process, + step_index, + total_steps, + ), + custom_extraction_guidelines=resolve_step_param( + context.config.custom_extraction_guidelines, + step_index, + total_steps, + ), + custom_final_checklist=resolve_step_param( + context.config.custom_final_checklist, + step_index, + 
total_steps, + ), + custom_context=resolve_step_param( + context.config.custom_context, step_index, total_steps + ), + expected_entity_descriptions=context.config.expected_entity_descriptions, + previous_entities=previous_entities if previous_entities else None, + hierarchical_model_index=step_index + if context.config.hierarchical + else None, + ) + + requests = self._create_batch_requests( + request.system_prompt, + request.user_prompt, + request.json_schema, + num_revisions=num_revisions, + ) + + client = self.client_rotator.get_next_client() + batch_job = await client.create_batch_job( + requests, response_model=request.response_model + ) + if hasattr(batch_job, "id"): + provider_batch_id = batch_job.id + elif hasattr(batch_job, "name"): + provider_batch_id = batch_job.name + else: + provider_batch_id = str(batch_job) + + context.current_batch_id = provider_batch_id + context.status = BatchJobStatus.SUBMITTED + context.updated_at = datetime.now(UTC) + db_session.add(context) + db_session.commit() + + phase_name = ( + f"step {step_index}" if context.config.hierarchical else "extraction" + ) + self.logger.info( + f"Submitted extraction batch ({phase_name}) for {context.root_batch_id}: {provider_batch_id}" + ) + + def _create_batch_requests( + self, + system_prompt: str, + user_prompt: str, + json_schema: str | None = None, + num_revisions: int | None = None, + override_client: BaseLLMClient | None = None, + ) -> list[dict]: + """Create batch requests in OpenAI batch format.""" + num_revs = num_revisions or self.config.num_llm_revisions + client = override_client or self.client_rotator.get_next_client() + + requests = [] + for _ in range(num_revs): + req = client.prepare_request( + system_prompt=system_prompt, + user_prompt=user_prompt, + json_schema=json_schema, + ) + # Wrap request in OpenAI batch format + batch_req = { + "custom_id": str(uuid.uuid4()), + "method": "POST", + "url": "/v1/chat/completions", + "body": req, + } + requests.append(batch_req) + return 
requests diff --git a/src/extrai/core/batch_models.py b/src/extrai/core/batch_models.py index 93e6525..3ac65f8 100644 --- a/src/extrai/core/batch_models.py +++ b/src/extrai/core/batch_models.py @@ -1,8 +1,36 @@ +from dataclasses import asdict, is_dataclass +from datetime import UTC, datetime from enum import Enum -from typing import List, Optional, Any, Dict -from datetime import datetime, timezone -from sqlmodel import SQLModel, Field, Relationship -from sqlalchemy import JSON +from typing import Any + +from sqlalchemy import JSON, Column +from sqlalchemy.types import TypeDecorator +from sqlmodel import Field, Relationship, SQLModel + +from enferno.extensions import db + +from .config.batch_job_config import BatchJobConfig + + +class DataClassJSON(TypeDecorator): + """Custom SQLAlchemy type for dataclasses stored as JSON""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value: Any | None, dialect) -> dict | None: + if value is None: + return None + if is_dataclass(value): + return asdict(value) + return value + + def process_result_value( + self, value: Any | None, dialect + ) -> BatchJobConfig | None: + if value is None: + return None + return BatchJobConfig(**value) class BatchJobStatus(str, Enum): @@ -23,37 +51,42 @@ class BatchJobContext(SQLModel, table=True): """ Stores the state of a batch job managed by the WorkflowOrchestrator. 
""" + metadata = db.metadata root_batch_id: str = Field(primary_key=True) current_batch_id: str = Field(index=True) # Provider's batch ID status: BatchJobStatus = Field(default=BatchJobStatus.SUBMITTED) - input_strings: List[str] = Field(default_factory=list, sa_type=JSON) - config: Dict[str, Any] = Field(default_factory=dict, sa_type=JSON) + input_strings: list[str] = Field(default_factory=list, sa_type=JSON) + config: BatchJobConfig = Field( + default_factory=BatchJobConfig, sa_column=Column(DataClassJSON) + ) # Store results when completed - results: Optional[List[Any]] = Field(default=None, sa_type=JSON) + results: list[Any] | None = Field(default=None, sa_type=JSON) # Tracking retries retry_count: int = Field(default=0) - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) - updated_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) # Error tracking - last_error: Optional[str] = None + last_error: str | None = None - steps: List["BatchJobStep"] = Relationship(back_populates="batch") + steps: list["BatchJobStep"] = Relationship(back_populates="batch") class BatchJobStep(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) + metadata = db.metadata + + id: int | None = Field(default=None, primary_key=True) batch_id: str = Field(foreign_key="batchjobcontext.root_batch_id") step_index: int status: BatchJobStatus = Field(default=BatchJobStatus.COMPLETED) - result: List[Any] = Field(default_factory=list, sa_type=JSON) - metadata_json: Dict[str, Any] = Field(default_factory=dict, sa_type=JSON) - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + result: list[Any] = Field(default_factory=list, sa_type=JSON) + metadata_json: dict[str, Any] = Field(default_factory=dict, sa_type=JSON) + created_at: 
datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) batch: BatchJobContext = Relationship(back_populates="steps") @@ -64,7 +97,7 @@ class BatchProcessResult(SQLModel): """ status: BatchJobStatus - hydrated_objects: Optional[List[Any]] = None - original_pk_map: Optional[Dict[Any, Any]] = Field(default=None, exclude=True) - retry_batch_id: Optional[str] = None - message: Optional[str] = None + hydrated_objects: list[Any] | None = None + original_pk_map: dict[Any, Any] | None = Field(default=None, exclude=True) + retry_batch_id: str | None = None + message: str | None = None diff --git a/src/extrai/core/batch_pipeline.py b/src/extrai/core/batch_pipeline.py deleted file mode 100644 index 1e9c1bd..0000000 --- a/src/extrai/core/batch_pipeline.py +++ /dev/null @@ -1,754 +0,0 @@ -import json -import uuid -import logging -from datetime import datetime, timezone -from typing import List, Dict, Any, Optional, Union -from sqlalchemy.orm import Session -from sqlmodel import SQLModel, select - -from extrai.core.base_llm_client import BaseLLMClient -from .client_rotator import ClientRotator -from .extraction_context_preparer import ExtractionContextPreparer -from .model_registry import ModelRegistry -from .extraction_config import ExtractionConfig -from .prompt_builder import PromptBuilder -from .entity_counter import EntityCounter -from .analytics_collector import WorkflowAnalyticsCollector -from .batch_models import ( - BatchJobContext, - BatchJobStatus, - BatchProcessResult, - BatchJobStep, -) -from .model_wrapper_builder import ModelWrapperBuilder -from extrai.utils.llm_output_processing import process_and_validate_llm_output -from extrai.utils.alignment_utils import normalize_json_revisions -from .json_consensus import JSONConsensus -from .extraction_request_factory import ExtractionRequestFactory - - -class BatchPipeline: - """Manages batch extraction workflows.""" - - def __init__( - self, - model_registry: ModelRegistry, - llm_client: Union["BaseLLMClient", 
List["BaseLLMClient"]], - config: ExtractionConfig, - analytics_collector: WorkflowAnalyticsCollector, - logger: logging.Logger, - counting_llm_client: Optional[BaseLLMClient] = None, - ): - self.model_registry = model_registry - self.config = config - self.analytics_collector = analytics_collector - self.logger = logger - - self.client_rotator = ClientRotator(llm_client) - self.prompt_builder = PromptBuilder(model_registry, logger=logger) - c_client = counting_llm_client or llm_client - if isinstance(c_client, list): - c_client = c_client[0] - - self.entity_counter = EntityCounter( - model_registry, c_client, config, analytics_collector, logger=logger - ) - self.context_preparer = ExtractionContextPreparer( - model_registry, - analytics_collector, - config.max_validation_retries_per_revision, - logger=logger, - ) - self.model_wrapper_builder = ModelWrapperBuilder() - self.consensus = JSONConsensus( - consensus_threshold=config.consensus_threshold, - conflict_resolver=config.conflict_resolver, - logger=logger, - ) - self.request_factory = ExtractionRequestFactory( - model_registry, - self.prompt_builder, - self.model_wrapper_builder, - logger=logger, - ) - - async def submit_batch( - self, - db_session: Session, - input_strings: List[str], - extraction_example_json: str = "", - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None, - custom_extraction_process: str = "", - custom_extraction_guidelines: str = "", - custom_final_checklist: str = "", - custom_context: str = "", - count_entities: bool = False, - custom_counting_context: str = "", - ) -> str: - """Submits a batch job and returns root_batch_id.""" - if not input_strings: - raise ValueError("input_strings cannot be empty") - - # Prepare example - example_json = await self.context_preparer.prepare_example( - extraction_example_json, - extraction_example_object, - self.client_rotator.get_next_client, - ) - - root_batch_id = str(uuid.uuid4()) - - # Initialize configuration - config_data 
= { - "extraction_example_json": example_json, - "custom_extraction_process": custom_extraction_process, - "custom_extraction_guidelines": custom_extraction_guidelines, - "custom_final_checklist": custom_final_checklist, - "custom_context": custom_context, - "count_entities": count_entities, - "custom_counting_context": custom_counting_context, - "schema_json": self.model_registry.llm_schema_json, - } - - if self.config.use_hierarchical_extraction: - config_data.update({"hierarchical": True, "current_model_index": 0}) - - context = BatchJobContext( - root_batch_id=root_batch_id, - current_batch_id="pending", - status=BatchJobStatus.SUBMITTED, - input_strings=input_strings, - config=config_data, - ) - db_session.add(context) - db_session.commit() - - if count_entities: - await self._submit_counting_phase(context, db_session) - else: - await self._submit_extraction_phase(context, db_session, step_index=0) - - self.logger.info(f"Batch workflow initiated: {root_batch_id}") - return root_batch_id - - async def create_continuation_batch( - self, - db_session: Session, - original_batch_id: str, - new_config: Dict[str, Any], - start_from_step_index: int, - ) -> str: - """ - Creates a new batch cycle continuing from a previous batch's state. - Copies completed steps up to start_from_step_index into the new batch. 
- """ - old_context = db_session.get(BatchJobContext, original_batch_id) - if not old_context: - raise ValueError("Old batch not found") - - new_batch_id = str(uuid.uuid4()) - - # Ensure new config has required fields - if self.config.use_hierarchical_extraction and "hierarchical" not in new_config: - new_config["hierarchical"] = True - new_config["current_model_index"] = start_from_step_index - - if "expected_entity_descriptions" in old_context.config: - new_config["expected_entity_descriptions"] = old_context.config[ - "expected_entity_descriptions" - ] - - new_context = BatchJobContext( - root_batch_id=new_batch_id, - current_batch_id="pending", - status=BatchJobStatus.SUBMITTED, - input_strings=old_context.input_strings, - config=new_config, - ) - db_session.add(new_context) - db_session.commit() - - # Copy valid steps from old batch - if start_from_step_index > 0: - old_steps = db_session.exec( - select(BatchJobStep) - .where(BatchJobStep.batch_id == original_batch_id) - .where(BatchJobStep.step_index < start_from_step_index) - .where(BatchJobStep.status == BatchJobStatus.COMPLETED) - ).all() - - for step in old_steps: - new_step = BatchJobStep( - batch_id=new_batch_id, - step_index=step.step_index, - status=step.status, - result=step.result, - metadata_json=step.metadata_json, - ) - db_session.add(new_step) - - db_session.commit() - - self.logger.info( - f"Created continuation batch {new_batch_id} from {original_batch_id}, starting at step {start_from_step_index}" - ) - - # Determine starting phase - # If counting is enabled, we start with counting phase for the starting step - if new_config.get("count_entities"): - step_idx = start_from_step_index if new_config.get("hierarchical") else 0 - await self._submit_counting_phase( - new_context, db_session, step_index=step_idx - ) - elif new_config.get("hierarchical"): - await self._submit_extraction_phase( - new_context, db_session, step_index=start_from_step_index - ) - else: - await 
self._submit_extraction_phase(new_context, db_session, step_index=0) - - return new_batch_id - - async def _submit_counting_phase( - self, - context: BatchJobContext, - db_session: Session, - step_index: Optional[int] = None, - ): - input_strings = context.input_strings - - # Determine which models to count - if context.config.get("hierarchical"): - idx = ( - step_index - if step_index is not None - else context.config.get("current_model_index", 0) - ) - if 0 <= idx < len(self.model_registry.models): - model_names = [self.model_registry.models[idx].__name__] - else: - model_names = self.model_registry.get_all_model_names() - else: - model_names = self.model_registry.get_all_model_names() - - custom_counting_context = context.config.get("custom_counting_context", "") - - system_prompt, user_prompt = self.entity_counter.prepare_counting_prompts( - input_strings, model_names, custom_counting_context - ) - - client = self.entity_counter.llm_client - requests = self._create_batch_requests( - system_prompt, user_prompt, num_revisions=1, override_client=client - ) - - batch_job = await client.create_batch_job(requests) - provider_batch_id = batch_job.id if hasattr(batch_job, "id") else str(batch_job) - - context.current_batch_id = provider_batch_id - context.status = BatchJobStatus.COUNTING_SUBMITTED - context.updated_at = datetime.now(timezone.utc) - db_session.add(context) - db_session.commit() - - self.logger.info( - f"Submitted counting batch for {context.root_batch_id}: {provider_batch_id}" - ) - - async def _submit_extraction_phase( - self, context: BatchJobContext, db_session: Session, step_index: int = 0 - ): - # Update current index in config if hierarchical - if context.config.get("hierarchical"): - # Update the config dictionary properly - new_config = context.config.copy() - new_config["current_model_index"] = step_index - context.config = new_config - db_session.add(context) - db_session.commit() - - # Retrieve previous entities from completed steps - 
previous_entities = [] - if context.config.get("hierarchical") and step_index > 0: - steps = db_session.exec( - select(BatchJobStep) - .where(BatchJobStep.batch_id == context.root_batch_id) - .where(BatchJobStep.step_index < step_index) - .where(BatchJobStep.status == BatchJobStatus.COMPLETED) - .order_by(BatchJobStep.step_index) - ).all() - for s in steps: - previous_entities.extend(s.result) - - # Prepare request - request = self.request_factory.prepare_request( - input_strings=context.input_strings, - config=self.config, - extraction_example_json=context.config.get("extraction_example_json", ""), - custom_extraction_process=context.config.get( - "custom_extraction_process", "" - ), - custom_extraction_guidelines=context.config.get( - "custom_extraction_guidelines", "" - ), - custom_final_checklist=context.config.get("custom_final_checklist", ""), - custom_context=context.config.get("custom_context", ""), - expected_entity_descriptions=context.config.get( - "expected_entity_descriptions" - ), - previous_entities=previous_entities if previous_entities else None, - hierarchical_model_index=step_index - if context.config.get("hierarchical") - else None, - ) - - requests = self._create_batch_requests( - request.system_prompt, request.user_prompt, request.json_schema - ) - - client = self.client_rotator.get_next_client() - batch_job = await client.create_batch_job(requests) - provider_batch_id = batch_job.id if hasattr(batch_job, "id") else str(batch_job) - - context.current_batch_id = provider_batch_id - context.status = BatchJobStatus.SUBMITTED - context.updated_at = datetime.now(timezone.utc) - db_session.add(context) - db_session.commit() - - phase_name = ( - f"step {step_index}" if context.config.get("hierarchical") else "extraction" - ) - self.logger.info( - f"Submitted extraction batch ({phase_name}) for {context.root_batch_id}: {provider_batch_id}" - ) - - async def get_status( - self, root_batch_id: str, db_session: Session - ) -> BatchJobStatus: - context = 
db_session.get(BatchJobContext, root_batch_id) - if not context: - raise ValueError(f"Batch job {root_batch_id} not found") - - terminal_states = [ - BatchJobStatus.COMPLETED, - BatchJobStatus.FAILED, - BatchJobStatus.CANCELLED, - BatchJobStatus.READY_TO_PROCESS, - BatchJobStatus.COUNTING_READY_TO_PROCESS, - ] - if context.status in terminal_states: - return context.status - - try: - # Determine client based on phase - if context.status in [ - BatchJobStatus.COUNTING_SUBMITTED, - BatchJobStatus.COUNTING_PROCESSING, - ]: - client = self.entity_counter.llm_client - else: - client = self.client_rotator.get_next_client() - - batch_job = await client.retrieve_batch_job(context.current_batch_id) - new_provider_status = self._map_provider_status(batch_job.status) - - new_status = context.status - # Map provider status based on current internal phase - if context.status in [ - BatchJobStatus.COUNTING_SUBMITTED, - BatchJobStatus.COUNTING_PROCESSING, - ]: - if new_provider_status == BatchJobStatus.READY_TO_PROCESS: - new_status = BatchJobStatus.COUNTING_READY_TO_PROCESS - elif new_provider_status == BatchJobStatus.FAILED: - new_status = BatchJobStatus.FAILED - elif new_provider_status == BatchJobStatus.CANCELLED: - new_status = BatchJobStatus.CANCELLED - elif new_provider_status == BatchJobStatus.PROCESSING: - new_status = BatchJobStatus.COUNTING_PROCESSING - else: - new_status = new_provider_status - - if new_status != context.status: - context.status = new_status - context.updated_at = datetime.now(timezone.utc) - db_session.add(context) - db_session.commit() - - except Exception as e: - self.logger.error(f"Failed to check batch status: {e}", exc_info=True) - - return context.status - - async def process_batch( - self, root_batch_id: str, db_session: Session - ) -> BatchProcessResult: - status = await self.get_status(root_batch_id, db_session) - context = db_session.get(BatchJobContext, root_batch_id) - - # 1. 
Already Completed - if status == BatchJobStatus.COMPLETED and context.results: - return await self._finalize_completion(context, db_session) - - # 2. Counting Phase Completed -> Submit Extraction - if status == BatchJobStatus.COUNTING_READY_TO_PROCESS: - return await self._process_counting_completion(context, db_session) - - # 3. Extraction Phase Ready -> Process Results - if status == BatchJobStatus.READY_TO_PROCESS: - return await self._process_extraction_completion(context, db_session) - - return BatchProcessResult( - status=status, message="Batch not ready for processing" - ) - - async def _process_counting_completion( - self, context: BatchJobContext, db_session: Session - ) -> BatchProcessResult: - try: - # Use counting client - client = self.entity_counter.llm_client - results_content = await client.retrieve_batch_results( - context.current_batch_id - ) - - # Determine expected models for validation - if context.config.get("hierarchical"): - current_idx = context.config.get("current_model_index", 0) - if 0 <= current_idx < len(self.model_registry.models): - target_model_names = [ - self.model_registry.models[current_idx].__name__ - ] - else: - target_model_names = self.model_registry.get_all_model_names() - else: - target_model_names = self.model_registry.get_all_model_names() - - # Parse descriptions - descriptions = [] - - for line in results_content.strip().split("\n"): - if not line.strip(): - continue - try: - item = json.loads(line) - content = client.extract_content_from_batch_response(item) - if content: - raw_json = json.loads(content) - if isinstance(raw_json, list) and raw_json: - raw_json = raw_json[0] - if isinstance(raw_json, dict): - validated_counts = self.entity_counter.validate_counts( - raw_json, target_model_names - ) - for model_name, descs in validated_counts.items(): - for desc in descs: - descriptions.append(f"[{model_name}] {desc}") - except Exception as e: - self.logger.warning(f"Failed to parse counting result: {e}") - - # Update 
config with descriptions - new_config = context.config.copy() - new_config["expected_entity_descriptions"] = descriptions - context.config = new_config - - # Proceed to Extraction - next_step = ( - context.config.get("current_model_index", 0) - if context.config.get("hierarchical") - else 0 - ) - await self._submit_extraction_phase( - context, db_session, step_index=next_step - ) - - return BatchProcessResult( - status=BatchJobStatus.PROCESSING, - message="Transitioned from counting to extraction", - retry_batch_id=context.root_batch_id, - ) - - except Exception as e: - self.logger.error(f"Failed to process counting results: {e}", exc_info=True) - return BatchProcessResult( - status=BatchJobStatus.FAILED, message=f"Counting failed: {e}" - ) - - async def _process_extraction_completion( - self, context: BatchJobContext, db_session: Session - ) -> BatchProcessResult: - try: - results = await self._retrieve_and_validate_results(context) - - if results: - consensus_output, details = self.consensus.get_consensus(results) - if details: - self.analytics_collector.record_consensus_run_details(details) - - processed = self._process_consensus_output(consensus_output) - - if context.config.get("hierarchical"): - return await self._process_hierarchical_step( - context, processed, db_session - ) - - # Finalize non-hierarchical - context.results = processed - context.status = BatchJobStatus.COMPLETED - context.updated_at = datetime.now(timezone.utc) - db_session.add(context) - db_session.commit() - - return await self._finalize_completion(context, db_session) - - # If no valid results, maybe retry? 
- return await self._handle_batch_retry( - context, context.root_batch_id, db_session - ) - - except Exception as e: - self.logger.error(f"Batch processing failed: {e}", exc_info=True) - return BatchProcessResult(status=BatchJobStatus.FAILED, message=str(e)) - - async def _process_hierarchical_step( - self, - context: BatchJobContext, - processed_results: List[Dict], - db_session: Session, - ) -> BatchProcessResult: - current_index = context.config.get("current_model_index", 0) - - # Save step result to DB - step = BatchJobStep( - batch_id=context.root_batch_id, - step_index=current_index, - status=BatchJobStatus.COMPLETED, - result=processed_results, - metadata_json={"timestamp": datetime.now(timezone.utc).isoformat()}, - ) - db_session.add(step) - - # Advance index - next_index = current_index + 1 - - # Update config - new_config = context.config.copy() - new_config["current_model_index"] = next_index - context.config = new_config - db_session.add(context) - db_session.commit() - - if next_index >= len(self.model_registry.models): - # All steps done - aggregate results for final hydration - all_steps = db_session.exec( - select(BatchJobStep) - .where(BatchJobStep.batch_id == context.root_batch_id) - .order_by(BatchJobStep.step_index) - ).all() - - final_results = [] - for s in all_steps: - final_results.extend(s.result) - - context.results = final_results - context.status = BatchJobStatus.COMPLETED - context.updated_at = datetime.now(timezone.utc) - db_session.add(context) - db_session.commit() - return await self._finalize_completion(context, db_session) - - # Submit next step (counting or extraction) - model_name = self.model_registry.models[next_index].__name__ - - if context.config.get("count_entities"): - await self._submit_counting_phase( - context, db_session, step_index=next_index - ) - return BatchProcessResult( - status=BatchJobStatus.COUNTING_PROCESSING, - message=f"Submitted counting step for {model_name}", - retry_batch_id=context.root_batch_id, - ) 
- else: - await self._submit_extraction_phase( - context, db_session, step_index=next_index - ) - return BatchProcessResult( - status=BatchJobStatus.PROCESSING, - message=f"Submitted hierarchical step for {model_name}", - retry_batch_id=context.root_batch_id, - ) - - async def _finalize_completion( - self, context: BatchJobContext, db_session: Session - ) -> BatchProcessResult: - from .result_processor import ResultProcessor - - processor = ResultProcessor( - self.model_registry, self.analytics_collector, self.logger - ) - - # Determine default model type for hydration - default_model_type = None - if self.config.use_structured_output: - default_model_type = self.model_registry.root_model.__name__ - - hydrated = processor.hydrate( - context.results, db_session, default_model_type=default_model_type - ) - return BatchProcessResult( - status=BatchJobStatus.COMPLETED, - hydrated_objects=hydrated, - original_pk_map=processor.original_pk_map, - ) - - async def _retrieve_and_validate_results( - self, context: BatchJobContext - ) -> List[List[Dict]]: - client = self.client_rotator.get_next_client() - results_content = await client.retrieve_batch_results(context.current_batch_id) - - # Determine default model type clearly - default_model_type = None - if self.config.use_structured_output: - if context.config.get("hierarchical"): - current_idx = context.config.get("current_model_index", 0) - if 0 <= current_idx < len(self.model_registry.models): - default_model_type = self.model_registry.models[ - current_idx - ].__name__ - else: - default_model_type = self.model_registry.root_model.__name__ - - valid_revisions = [] - for line in results_content.strip().split("\n"): - if not line.strip(): - continue - - try: - item = json.loads(line) - content = client.extract_content_from_batch_response(item) - - if content: - validated = process_and_validate_llm_output( - raw_llm_content=content, - model_schema_map=self.model_registry.model_map, - revision_info_for_error="batch_revision", 
- analytics_collector=self.analytics_collector, - default_model_type=default_model_type, - ) - if validated: - valid_revisions.append(validated) - except Exception as e: - self.logger.warning(f"Failed to validate batch revision: {e}") - - return normalize_json_revisions(valid_revisions) if valid_revisions else [] - - async def _handle_batch_retry( - self, context: BatchJobContext, root_batch_id: str, db_session: Session - ): - max_retries = self.config.max_validation_retries_per_revision - - if context.retry_count < max_retries: - context.retry_count += 1 - self.logger.info( - f"Retrying batch {root_batch_id} ({context.retry_count}/{max_retries})" - ) - - # Resubmit current step - if "counting" in context.status.value: - await self._submit_counting_phase(context, db_session) - else: - current_idx = context.config.get("current_model_index", 0) - await self._submit_extraction_phase( - context, db_session, step_index=current_idx - ) - - return BatchProcessResult( - status=BatchJobStatus.PROCESSING, - message="Retry submitted", - retry_batch_id=root_batch_id, - ) - - context.status = BatchJobStatus.FAILED - context.last_error = "Max retries exceeded" - context.updated_at = datetime.now(timezone.utc) - db_session.add(context) - db_session.commit() - - return BatchProcessResult( - status=BatchJobStatus.FAILED, message="Max retries exceeded" - ) - - def _create_batch_requests( - self, - system_prompt: str, - user_prompt: str, - json_schema: Optional[Dict] = None, - num_revisions: Optional[int] = None, - override_client: Optional[BaseLLMClient] = None, - ) -> List[Dict]: - requests = [] - client = override_client or self.client_rotator.current_client - revisions = ( - num_revisions - if num_revisions is not None - else self.config.num_llm_revisions - ) - - if self.config.use_structured_output and json_schema: - self.logger.info("Using structured output for batch requests") - - for i in range(revisions): - body = { - "messages": [ - {"role": "system", "content": 
system_prompt}, - {"role": "user", "content": user_prompt}, - ], - "temperature": client.temperature, - } - if hasattr(client, "model_name"): - body["model"] = client.model_name - - if self.config.use_structured_output and json_schema: - body["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": "extraction_response", - "schema": json_schema, - "strict": True, - }, - } - elif self.model_registry.llm_schema_json: - body["response_format"] = {"type": "json_object"} - - requests.append({"custom_id": f"rev-{i}", "body": body}) - return requests - - def _map_provider_status(self, provider_status) -> BatchJobStatus: - status_str = str(provider_status).lower() - - if "complete" in status_str or "succeeded" in status_str: - return BatchJobStatus.READY_TO_PROCESS - elif "fail" in status_str: - return BatchJobStatus.FAILED - elif "cancel" in status_str: - return BatchJobStatus.CANCELLED - elif ( - "process" in status_str or "active" in status_str or "running" in status_str - ): - return BatchJobStatus.PROCESSING - - return BatchJobStatus.SUBMITTED - - def _process_consensus_output(self, output) -> List[Dict[str, Any]]: - if output is None: - return [] - if isinstance(output, list): - return output - if isinstance(output, dict): - if "results" in output and isinstance(output["results"], list): - return output["results"] - return [output] - return [] diff --git a/src/extrai/core/client_rotator.py b/src/extrai/core/client_rotator.py index 48298e6..2b1687a 100644 --- a/src/extrai/core/client_rotator.py +++ b/src/extrai/core/client_rotator.py @@ -1,4 +1,3 @@ -from typing import List, Union from .base_llm_client import BaseLLMClient @@ -7,7 +6,7 @@ class ClientRotator: Manages rotation through a list of LLM clients. 
""" - def __init__(self, clients: Union[BaseLLMClient, List[BaseLLMClient]]): + def __init__(self, clients: BaseLLMClient | list[BaseLLMClient]): self.clients = clients if isinstance(clients, list) else [clients] if not self.clients: raise ValueError("At least one client must be provided") diff --git a/src/extrai/core/code_generation/python_builder.py b/src/extrai/core/code_generation/python_builder.py index 51eb89f..ed22a06 100644 --- a/src/extrai/core/code_generation/python_builder.py +++ b/src/extrai/core/code_generation/python_builder.py @@ -1,15 +1,15 @@ import keyword -from typing import Any, Dict, Set, List +from typing import Any class ImportManager: """Manages imports for the generated code, handling consolidation.""" def __init__(self): - self.typing_imports: Set[str] = set() - self.sqlmodel_imports: Set[str] = {"SQLModel"} - self.module_imports: Set[str] = set() - self.custom_imports: Set[str] = set() + self.typing_imports: set[str] = set() + self.sqlmodel_imports: set[str] = {"SQLModel"} + self.module_imports: set[str] = set() + self.custom_imports: set[str] = set() def add_import_for_type(self, type_str: str): if "datetime." 
in type_str: @@ -27,7 +27,7 @@ def add_import_for_type(self, type_str: str): if "Any" in type_str: self.typing_imports.add("Any") - def add_custom_imports(self, imports: List[str]): + def add_custom_imports(self, imports: list[str]): for imp in imports: self.custom_imports.add(imp.strip()) @@ -85,12 +85,12 @@ def render(self) -> str: class FieldGenerator: """Generates the code for a single field in a SQLModel.""" - def __init__(self, field_info: Dict[str, Any], import_manager: ImportManager): + def __init__(self, field_info: dict[str, Any], import_manager: ImportManager): self.field_info = field_info self.imports = import_manager self.field_name_original = self.field_info["name"] self.field_name_python = self.field_name_original - self.args_map: Dict[str, str] = {} + self.args_map: dict[str, str] = {} def _handle_keyword_name(self): if keyword.iskeyword(self.field_name_original): @@ -231,7 +231,7 @@ def __init__( import_manager: ImportManager, description: str, table_name: str, - base_classes: List[str], + base_classes: list[str], is_table_model: bool, ): self.model_name = model_name @@ -240,7 +240,7 @@ def __init__( self.table_name = table_name self.base_classes = base_classes self.is_table_model = is_table_model - self.fields_code: List[str] = [] + self.fields_code: list[str] = [] def add_field(self, field_code: str): self.fields_code.append(field_code) @@ -274,7 +274,7 @@ def render_class_definition(self) -> str: class PythonModelBuilder: """Facade for generating Python code for SQLModels from description dictionaries.""" - def generate_model_code(self, model_descriptions: List[Dict[str, Any]]) -> str: + def generate_model_code(self, model_descriptions: list[dict[str, Any]]) -> str: """ Generates Python code for the provided model descriptions. 
diff --git a/src/extrai/core/config/__init__.py b/src/extrai/core/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/extrai/core/config/batch_job_config.py b/src/extrai/core/config/batch_job_config.py new file mode 100644 index 0000000..1e303a2 --- /dev/null +++ b/src/extrai/core/config/batch_job_config.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass, field, replace + + +@dataclass +class BatchJobConfig: + extraction_example_json: str = "" + custom_extraction_process: str | list[str] = "" + custom_extraction_guidelines: str | list[str] = "" + custom_final_checklist: str | list[str] = "" + custom_context: str | list[str] = "" + count_entities: bool = False + custom_counting_context: str | list[str] = "" + schema_json: str = "" + # batch-specific runtime state + hierarchical: bool = False + current_model_index: int = 0 + expected_entity_descriptions: list[dict] | None = None + partial_results: list[dict] = field(default_factory=list) + + def evolve(self, **changes) -> "BatchJobConfig": + return replace(self, **changes) diff --git a/src/extrai/core/conflict_resolvers.py b/src/extrai/core/conflict_resolvers.py index 3bdf486..58dfaf5 100644 --- a/src/extrai/core/conflict_resolvers.py +++ b/src/extrai/core/conflict_resolvers.py @@ -1,18 +1,20 @@ # extrai/core/conflict_resolvers.py from collections import Counter -from typing import List, Optional, Callable, Dict, Any -from extrai.utils.flattening_utils import Path, JSONValue +from collections.abc import Callable from difflib import SequenceMatcher +from typing import Any + +from extrai.utils.flattening_utils import JSONValue, Path # Define conflict resolution strategies ConflictResolutionStrategy = Callable[ - [Path, List[JSONValue], Optional[List[float]]], Optional[JSONValue] + [Path, list[JSONValue], list[float] | None], JSONValue | None ] def default_conflict_resolver( - path: Path, values: List[JSONValue], weights: Optional[List[float]] = None -) -> Optional[JSONValue]: + path: 
Path, values: list[JSONValue], weights: list[float] | None = None +) -> JSONValue | None: """ Default conflict resolution: if no consensus, omit the field. """ @@ -20,8 +22,8 @@ def default_conflict_resolver( def prefer_most_common_resolver( - _path: Path, values: List[JSONValue], weights: Optional[List[float]] = None -) -> Optional[JSONValue]: + _path: Path, values: list[JSONValue], weights: list[float] | None = None +) -> JSONValue | None: """ Conflict resolution: prefer the most common value. If weights are provided, prefers the value with the highest total weight. @@ -31,7 +33,7 @@ def prefer_most_common_resolver( if weights and len(weights) == len(values): # Weighted voting - weighted_counts: Dict[Any, float] = {} + weighted_counts: dict[Any, float] = {} # We need to handle unhashable types (like dicts/lists) if they appear in values # But JSONValue can be complex. Typically conflict resolution is on leaves (primitives). # Flattening utils usually produce primitives at leaves, but lists can be values if not recursed? 
@@ -88,8 +90,8 @@ def __init__( self.scorer = scorer def __call__( - self, path: Path, values: List[JSONValue], weights: Optional[List[float]] = None - ) -> Optional[JSONValue]: + self, path: Path, values: list[JSONValue], weights: list[float] | None = None + ) -> JSONValue | None: if not values: return None diff --git a/src/extrai/core/cost_calculator.py b/src/extrai/core/cost_calculator.py new file mode 100644 index 0000000..77d2e43 --- /dev/null +++ b/src/extrai/core/cost_calculator.py @@ -0,0 +1,101 @@ +# extrai/core/cost_calculator.py +import typing +from dataclasses import dataclass + +from .analytics_collector import WorkflowAnalyticsCollector +from .pricing_updater import load_pricing_data, update_prices_if_stale + +if typing.TYPE_CHECKING: + from .base_llm_client import BaseLLMClient + +@dataclass +class ModelCosts: + """A simple dataclass to store cost per million tokens for a model.""" + + input_cost_per_million: float + output_cost_per_million: float + input_cached_cost_per_million: float | None = None + + +# Costs are per million tokens +update_prices_if_stale() +pricing_data = load_pricing_data() + +MODEL_COSTS = {} +if pricing_data: + for item in pricing_data.get("prices", []): + MODEL_COSTS[item["id"]] = ModelCosts( + input_cost_per_million=item["input"], + output_cost_per_million=item["output"], + input_cached_cost_per_million=item["input_cached"] if item.get("input_cached") else None, + ) + + +def calculate_cost( + model_name: str, input_tokens: int, output_tokens: int, is_batch: bool = False +) -> float | None: + """ + Calculates the cost of a single LLM call. + + Args: + model_name: The name of the model used. + input_tokens: The number of input tokens. + output_tokens: The number of output tokens. + is_batch: If True, uses the cached input cost if available. + + Returns: + The calculated cost, or None if the model is not found. 
+ """ + if model_name not in MODEL_COSTS: + return None + + costs = MODEL_COSTS[model_name] + + if is_batch and costs.input_cached_cost_per_million is not None: + input_cost = (input_tokens / 1_000_000) * costs.input_cached_cost_per_million + else: + input_cost = (input_tokens / 1_000_000) * costs.input_cost_per_million + + output_cost = (output_tokens / 1_000_000) * costs.output_cost_per_million + return input_cost + output_cost + + +def track_usage_from_response( + response_dict: dict, + client: "BaseLLMClient", + analytics_collector: WorkflowAnalyticsCollector, + batch_id: str, + extra_details: dict | None = None, +) -> None: + """ + Extracts usage from a batch response wrapper and records it via analytics collector. + """ + # Check for OpenAI-style batch response structure: {"response": {"body": {"usage": {...}}}} + if ( + isinstance(response_dict, dict) + and "response" in response_dict + and "body" in response_dict["response"] + and "usage" in response_dict["response"]["body"] + ): + usage = response_dict["response"]["body"]["usage"] + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + + cost = calculate_cost( + client.model_name, + input_tokens, + output_tokens, + is_batch=True, + ) + + details = {"batch_id": batch_id} + if extra_details: + details.update(extra_details) + + analytics_collector.record_llm_usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + model=client.model_name, + cost=cost, + details=details, + ) diff --git a/src/extrai/core/counting_consensus.py b/src/extrai/core/counting_consensus.py new file mode 100644 index 0000000..8cfb761 --- /dev/null +++ b/src/extrai/core/counting_consensus.py @@ -0,0 +1,152 @@ +import json +import logging +from typing import Any + +from .extraction_config import ExtractionConfig +from ..utils.alignment_utils import align_entity_arrays, calculate_similarity + + +class CountingConsensus: + """ + Implements a multi-revision consensus step specifically 
for the counting phase. + Utilizes Levenshtein distance to evaluate similarity among returned string arrays, + and falls back to a "resolver LLM" if there is too much discrepancy. + """ + + def __init__( + self, + config: ExtractionConfig, + llm_client, + logger: logging.Logger | None = None, + ): + self.config = config + self.llm_client = llm_client + self.logger = logger or logging.getLogger(__name__) + + async def achieve_consensus( + self, + revisions: list[dict[str, Any]], + system_prompt: str, + user_prompt: str, + target_json_schema: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """ + Attempts to reach consensus across multiple counting revisions. + If consensus fails, triggers a fallback LLM call to merge. + + Args: + revisions: A list where each item is a parsed JSON response (dict) containing `counted_entities`. + system_prompt: The original system prompt used for counting. + user_prompt: The original user prompt used for counting. + target_json_schema: The JSON schema to enforce on the fallback call (if using structured output). + + Returns: + The final consensus list of counted entities. 
+ """ + if not revisions: + return [] + + if len(revisions) == 1: + return revisions[0].get("counted_entities", []) + + # Extract just the entity lists + entity_lists = [rev.get("counted_entities", []) for rev in revisions] + + # Ensure all lists actually exist and are lists + entity_lists = [lst if isinstance(lst, list) else [] for lst in entity_lists] + + # Step 2a: Length Verification + lengths = [len(lst) for lst in entity_lists] + all_same_length = all(l == lengths[0] for l in lengths) + + consensus_reached = False + best_list_idx = 0 + + if all_same_length and lengths[0] > 0: + # Step 2b: Levenshtein Distance Comparison + # Align arrays using the longest as reference (since lengths are equal, the first is fine) + aligned_arrays = align_entity_arrays( + entity_lists, truncate_to_min_length=False + ) + + reference_array = aligned_arrays[0] + avg_similarities = [] + + for i in range(1, len(aligned_arrays)): + current_array = aligned_arrays[i] + sim_sum = 0.0 + + for j in range(len(reference_array)): + sim = calculate_similarity(reference_array[j], current_array[j]) + sim_sum += sim + + avg_sim = sim_sum / len(reference_array) if len(reference_array) > 0 else 1.0 + avg_similarities.append(avg_sim) + + # If the average similarity across all matched pairs exceeds a threshold, consensus reached. + # We can pick the reference list since it's most similar. + overall_avg_sim = sum(avg_similarities) / len(avg_similarities) if avg_similarities else 1.0 + + if overall_avg_sim >= self.config.counting_levenshtein_threshold: + consensus_reached = True + best_list_idx = 0 + self.logger.info(f"Counting consensus reached with average similarity {overall_avg_sim:.2f}") + + elif all_same_length and lengths[0] == 0: + # All returned empty lists + return [] + + if consensus_reached: + return entity_lists[best_list_idx] + + # Step 2c: Discrepancy & Fallback (LLM Resolution) + self.logger.warning("Counting consensus failed. 
Triggering Merger LLM Call.") + + from extrai.core.prompts.counting import generate_entity_counting_system_prompt + + # We need to recreate the system prompt but with conflicting_revisions injected. + # However, we only have the raw `system_prompt` string. + # Actually, if we're inside the LLM call, we can append the revisions manually + # to the existing system prompt. + + revisions_json = json.dumps(revisions, indent=2) + merge_instructions = f""" + +# MERGE REQUIRED: +Previous extraction attempts returned conflicting results. Here are the conflicting revisions: +{revisions_json} + +Your task is to cross-reference these previous attempts with the text and provide the final, comprehensive, and correct list of entities, resolving any discrepancies. +""" + new_system_prompt = system_prompt + merge_instructions + + # Ensure we use a client (could be a rotator) + client = self.llm_client + if isinstance(client, list): + client = client[0] + + try: + merged_result = await client.generate_and_validate_raw_json_output( + system_prompt=new_system_prompt, + user_prompt=user_prompt, + target_json_schema=target_json_schema, + num_revisions=1, + max_validation_retries_per_revision=self.config.max_validation_retries_per_revision, + attempt_unwrap=False, + ) + + if isinstance(merged_result, list) and merged_result: + merged_result = merged_result[0] + + if isinstance(merged_result, dict) and "counted_entities" in merged_result: + return merged_result.get("counted_entities", []) + + except Exception as e: + self.logger.error(f"Fallback merger LLM call failed: {e}") + # Fallback: return the longest list + max_idx = lengths.index(max(lengths)) + return entity_lists[max_idx] + + # Ultimate fallback + max_idx = lengths.index(max(lengths)) + return entity_lists[max_idx] diff --git a/src/extrai/core/entity_counter.py b/src/extrai/core/entity_counter.py index e69f9a4..e0ab80a 100644 --- a/src/extrai/core/entity_counter.py +++ b/src/extrai/core/entity_counter.py @@ -1,15 +1,28 @@ import 
logging -from typing import List, Dict, Any, Optional -from pydantic import create_model +from typing import Any -from .model_registry import ModelRegistry +from pydantic import BaseModel, Field + +from .counting_consensus import CountingConsensus from .extraction_config import ExtractionConfig +from .model_registry import ModelRegistry from .prompt_builder import ( generate_entity_counting_system_prompt, generate_entity_counting_user_prompt, ) +class CountedEntity(BaseModel): + model: str + temp_id: str + related_ids: list[str] = Field(default_factory=list) + description: str + + +class EntityCountResult(BaseModel): + counted_entities: list[CountedEntity] = Field(default_factory=list) + + class EntityCounter: """Counts entities in input documents using LLM.""" @@ -26,13 +39,19 @@ def __init__( self.config = config self.analytics_collector = analytics_collector self.logger = logger + self.counting_consensus = CountingConsensus( + config=self.config, + llm_client=self.llm_client, + logger=self.logger, + ) def prepare_counting_prompts( self, - input_strings: List[str], - model_names: List[str], + input_strings: list[str], + model_names: list[str], custom_counting_context: str = "", - previous_entities: Optional[List[Dict[str, Any]]] = None, + previous_entities: list[dict[str, Any]] | None = None, + examples: str = "", ): """Prepares prompts for batch counting.""" # Generate schema for models @@ -40,73 +59,85 @@ def prepare_counting_prompts( # Build prompts system_prompt = generate_entity_counting_system_prompt( - model_names, schema_json, custom_counting_context, previous_entities + model_names, + schema_json, + custom_counting_context, + previous_entities, + examples, ) user_prompt = generate_entity_counting_user_prompt(input_strings) return system_prompt, user_prompt def validate_counts( - self, raw_counts: Dict[str, Any], model_names: List[str] - ) -> Dict[str, List[str]]: - """Validates raw counting results against dynamic model.""" - fields = {name: (List[str], 
...) for name in model_names} - EntityCountModel = create_model("EntityCountModel", **fields) + self, raw_counts: dict[str, Any], model_names: list[str] + ) -> dict[str, list[str]]: + """Validates raw counting results against static model.""" try: - validated = EntityCountModel(**raw_counts) + validated = EntityCountResult(**raw_counts) return validated.model_dump() except Exception as e: self.logger.warning(f"Count validation failed: {e}") return {} + def get_counting_model(self, model_names: list[str]): + """Creates a Pydantic model for entity counting.""" + return EntityCountResult + async def count_entities( self, - input_strings: List[str], - model_names: List[str], + input_strings: list[str], + model_names: list[str], custom_counting_context: str = "", - previous_entities: Optional[List[Dict[str, Any]]] = None, - ) -> Dict[str, List[str]]: - """Performs entity counting for specified models.""" + previous_entities: list[dict[str, Any]] | None = None, + examples: str = "", + ) -> list[dict[str, Any]]: + """Performs entity counting for specified models using consensus.""" self.logger.info(f"Counting entities for: {model_names}") system_prompt, user_prompt = self.prepare_counting_prompts( - input_strings, model_names, custom_counting_context, previous_entities + input_strings, + model_names, + custom_counting_context, + previous_entities, + examples, ) - # Create validation model - fields = {name: (List[str], ...) 
for name in model_names} - EntityCountModel = create_model("EntityCountModel", **fields) + target_json_schema = EntityCountResult.model_json_schema() if self.config.use_structured_output else None - # Call LLM - try: - # Get next client (assuming llm_client is list or has rotation) - if isinstance(self.llm_client, list): - client = self.llm_client[0] - else: - client = self.llm_client + client = self.llm_client + if isinstance(client, list): + client = client[0] - result = await client.generate_and_validate_raw_json_output( + try: + # Execute multiple revisions natively + revisions = await client.generate_and_validate_raw_json_output( system_prompt=system_prompt, user_prompt=user_prompt, - target_json_schema=None, - num_revisions=1, + target_json_schema=target_json_schema, + num_revisions=self.config.num_counting_revisions, max_validation_retries_per_revision=self.config.max_validation_retries_per_revision, attempt_unwrap=False, ) - # Process result - if isinstance(result, list) and result: - result = result[0] - - if isinstance(result, dict): - validated = EntityCountModel(**result) - counts = validated.model_dump() - self.logger.info(f"Entity counts: {counts}") - return counts + # Revisions should be a list of dictionaries if successful + if not isinstance(revisions, list): + if isinstance(revisions, dict): + revisions = [revisions] + else: + self.logger.warning("Entity counting returned invalid result format") + return [] + + # Achieve consensus + consensus_result = await self.counting_consensus.achieve_consensus( + revisions=revisions, + system_prompt=system_prompt, + user_prompt=user_prompt, + target_json_schema=target_json_schema, + ) - self.logger.warning("Entity counting returned invalid result") - return {} + return consensus_result except Exception as e: self.logger.error(f"Entity counting failed: {e}") - return {} + return [] diff --git a/src/extrai/core/errors.py b/src/extrai/core/errors.py index 12b6c9f..c197eed 100644 --- a/src/extrai/core/errors.py 
+++ b/src/extrai/core/errors.py @@ -5,7 +5,7 @@ to provide a single point of reference for error types. """ -from typing import Any, Dict, Optional +from typing import Any from pydantic import ValidationError @@ -70,8 +70,8 @@ class LLMOutputParseError(LLMClientError): def __init__( self, message: str, - raw_content: Optional[str] = None, - original_exception: Optional[Exception] = None, + raw_content: str | None = None, + original_exception: Exception | None = None, ): super().__init__(message) self.raw_content = raw_content @@ -94,8 +94,8 @@ class LLMOutputValidationError(LLMClientError): def __init__( self, message: str, - parsed_json: Optional[Dict[str, Any]] = None, - validation_error: Optional[Any] = None, + parsed_json: dict[str, Any] | None = None, + validation_error: Any | None = None, ): # PydanticValidationError type hint can be 'Any' for simplicity here or more specific if PydanticValidationError is imported super().__init__(message) self.parsed_json = parsed_json @@ -177,6 +177,6 @@ def __init__( class ExampleGenerationError(Exception): """Custom exception for errors during example JSON generation.""" - def __init__(self, message: str, original_exception: Optional[Exception] = None): + def __init__(self, message: str, original_exception: Exception | None = None): super().__init__(message) self.original_exception = original_exception diff --git a/src/extrai/core/example_json_generator.py b/src/extrai/core/example_json_generator.py index 1ef0de6..2bb38b2 100644 --- a/src/extrai/core/example_json_generator.py +++ b/src/extrai/core/example_json_generator.py @@ -1,22 +1,24 @@ import json import logging -from typing import Optional, Dict, Any, Type +from typing import Any from sqlmodel import SQLModel -from extrai.core.base_llm_client import BaseLLMClient -from extrai.core.prompt_builder import ( - generate_prompt_for_example_json_generation, -) + from extrai.core.analytics_collector import ( WorkflowAnalyticsCollector, ) +from extrai.core.base_llm_client 
import BaseLLMClient from extrai.core.errors import ( + ConfigurationError, ExampleGenerationError, LLMAPICallError, LLMOutputParseError, LLMOutputValidationError, - ConfigurationError, ) +from extrai.core.prompt_builder import ( + generate_prompt_for_example_json_generation, +) + from .schema_inspector import SchemaInspector @@ -29,10 +31,10 @@ class ExampleJSONGenerator: def __init__( self, llm_client: BaseLLMClient, - output_model: Type[SQLModel], - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, + output_model: type[SQLModel], + analytics_collector: WorkflowAnalyticsCollector | None = None, max_validation_retries_per_revision: int = 1, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, ): self.logger = logger or logging.getLogger(self.__class__.__name__) if not logger: @@ -75,7 +77,7 @@ def __init__( # The schema for basic validation by the LLM client needs to match the new # expected output format: `{"entities": [...]}`. - self.target_json_schema_dict: Dict[str, Any] = { + self.target_json_schema_dict: dict[str, Any] = { "type": "object", "properties": { "entities": {"type": "array", "items": {"type": "object"}} diff --git a/src/extrai/core/extraction_config.py b/src/extrai/core/extraction_config.py index eb52c43..af8e0c9 100644 --- a/src/extrai/core/extraction_config.py +++ b/src/extrai/core/extraction_config.py @@ -1,7 +1,7 @@ # extrai/core/extraction_config.py +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable, Optional @dataclass @@ -9,9 +9,11 @@ class ExtractionConfig: """Configuration for extraction workflows.""" num_llm_revisions: int = 3 + num_counting_revisions: int = 3 max_validation_retries_per_revision: int = 2 consensus_threshold: float = 0.51 - conflict_resolver: Optional[Callable] = None + counting_levenshtein_threshold: float = 0.85 + conflict_resolver: Callable | None = None use_hierarchical_extraction: bool = False use_structured_output: bool = 
False diff --git a/src/extrai/core/extraction_context_preparer.py b/src/extrai/core/extraction_context_preparer.py index 1f49274..ac70308 100644 --- a/src/extrai/core/extraction_context_preparer.py +++ b/src/extrai/core/extraction_context_preparer.py @@ -1,14 +1,16 @@ import json import logging -from typing import List, Optional, Union, Callable +from collections.abc import Callable + from sqlmodel import SQLModel -from .model_registry import ModelRegistry -from .example_json_generator import ExampleJSONGenerator, ExampleGenerationError +from extrai.utils.serialization_utils import serialize_sqlmodel_with_relationships + from .analytics_collector import WorkflowAnalyticsCollector -from .errors import WorkflowError from .base_llm_client import BaseLLMClient -from extrai.utils.serialization_utils import serialize_sqlmodel_with_relationships +from .errors import WorkflowError +from .example_json_generator import ExampleGenerationError, ExampleJSONGenerator +from .model_registry import ModelRegistry class ExtractionContextPreparer: @@ -31,7 +33,7 @@ def __init__( async def prepare_example( self, extraction_example_json: str, - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]], + extraction_example_object: SQLModel | list[SQLModel] | None, client_provider: Callable[[], BaseLLMClient], ) -> str: """ @@ -56,7 +58,7 @@ async def prepare_example( self.logger.info("No example provided, auto-generating...") return await self._auto_generate_example(client_provider) - def _serialize_example_object(self, obj: Union[SQLModel, List[SQLModel]]) -> str: + def _serialize_example_object(self, obj: SQLModel | list[SQLModel]) -> str: """Serializes SQLModel object(s) to JSON.""" objects = obj if isinstance(obj, list) else [obj] serialized = [] diff --git a/src/extrai/core/extraction_pipeline.py b/src/extrai/core/extraction_pipeline.py index 0e92139..5090a70 100644 --- a/src/extrai/core/extraction_pipeline.py +++ b/src/extrai/core/extraction_pipeline.py @@ -1,19 +1,22 
@@ import logging -from typing import List, Dict, Any, Optional, Union +from typing import Any, Union + from sqlmodel import SQLModel from extrai.core.base_llm_client import BaseLLMClient +from extrai.utils.serialization_utils import resolve_step_param + +from .analytics_collector import WorkflowAnalyticsCollector from .client_rotator import ClientRotator -from .extraction_context_preparer import ExtractionContextPreparer -from .model_registry import ModelRegistry -from .extraction_config import ExtractionConfig -from .prompt_builder import PromptBuilder from .entity_counter import EntityCounter -from .llm_runner import LLMRunner +from .extraction_config import ExtractionConfig +from .extraction_context_preparer import ExtractionContextPreparer +from .extraction_request_factory import ExtractionRequestFactory from .hierarchical_extractor import HierarchicalExtractor -from .analytics_collector import WorkflowAnalyticsCollector +from .llm_runner import LLMRunner +from .model_registry import ModelRegistry from .model_wrapper_builder import ModelWrapperBuilder -from .extraction_request_factory import ExtractionRequestFactory +from .prompt_builder import PromptBuilder class ExtractionPipeline: @@ -33,11 +36,11 @@ class ExtractionPipeline: def __init__( self, model_registry: ModelRegistry, - llm_client: Union["BaseLLMClient", List["BaseLLMClient"]], + llm_client: Union["BaseLLMClient", list["BaseLLMClient"]], config: ExtractionConfig, analytics_collector: WorkflowAnalyticsCollector, logger: logging.Logger, - counting_llm_client: Optional[BaseLLMClient] = None, + counting_llm_client: BaseLLMClient | None = None, ): """ Initialize the extraction pipeline. 
@@ -102,16 +105,16 @@ def __init__( async def extract( self, - input_strings: List[str], + input_strings: list[str], extraction_example_json: str = "", - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None, - custom_extraction_process: str = "", - custom_extraction_guidelines: str = "", - custom_final_checklist: str = "", - custom_context: str = "", + extraction_example_object: SQLModel | list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + custom_context: str | list[str] = "", count_entities: bool = False, - custom_counting_context: str = "", - ) -> List[Dict[str, Any]]: + custom_counting_context: str | list[str] = "", + ) -> list[dict[str, Any]]: """ Executes extraction and returns consensus JSON. @@ -145,12 +148,24 @@ async def extract( # Step 2: Count entities if requested # Note: For hierarchical extraction, counting is handled per-model within the extractor expected_entity_descriptions = None + skip_extraction = False if count_entities and not self.config.use_hierarchical_extraction: expected_entity_descriptions = await self._count_entities( - input_strings, custom_counting_context + input_strings, custom_counting_context, examples=example_json ) if expected_entity_descriptions is not None: self.logger.info(f"Entity count: {len(expected_entity_descriptions)}") + # Check if descriptions are empty - if so, skip extraction + if len(expected_entity_descriptions) == 0: + self.logger.info( + "Skipping extraction - no entity descriptions found from counting" + ) + skip_extraction = True + else: + # Counting failed, but we'll continue with extraction without descriptions + self.logger.warning( + "Entity counting failed or returned None, proceeding with extraction without descriptions" + ) # Step 3: Run extraction (hierarchical or standard) if self.config.use_hierarchical_extraction: @@ -186,60 +201,72 @@ async def 
extract( f"Using {'structured' if self.config.use_structured_output else 'standard'} extraction mode" ) - request = self.request_factory.prepare_request( - input_strings=input_strings, - config=self.config, - extraction_example_json=example_json, - custom_extraction_process=custom_extraction_process, - custom_extraction_guidelines=custom_extraction_guidelines, - custom_final_checklist=custom_final_checklist, - custom_context=custom_context, - expected_entity_descriptions=expected_entity_descriptions, - ) - - self.logger.debug( - f"System prompt length: {len(request.system_prompt)} chars" - ) - self.logger.debug(f"User prompt length: {len(request.user_prompt)} chars") - - if request.response_model: - results = await self.llm_runner.run_structured_extraction_cycle( - system_prompt=request.system_prompt, - user_prompt=request.user_prompt, - response_model=request.response_model, + # Check if we should skip extraction due to empty descriptions from counting + if skip_extraction: + self.logger.info( + "Skipping extraction - no entity descriptions found from counting" ) + results = [] else: - results = await self.llm_runner.run_extraction_cycle( - system_prompt=request.system_prompt, user_prompt=request.user_prompt + request = self.request_factory.prepare_request( + input_strings=input_strings, + config=self.config, + extraction_example_json=example_json, + custom_extraction_process=resolve_step_param(custom_extraction_process), + custom_extraction_guidelines=resolve_step_param( + custom_extraction_guidelines + ), + custom_final_checklist=resolve_step_param(custom_final_checklist), + custom_context=resolve_step_param(custom_context), + expected_entity_descriptions=expected_entity_descriptions, ) + self.logger.debug( + f"System prompt length: {len(request.system_prompt)} chars" + ) + self.logger.debug(f"User prompt length: {len(request.user_prompt)} chars") + + if request.response_model: + results = await self.llm_runner.run_structured_extraction_cycle( + 
system_prompt=request.system_prompt, + user_prompt=request.user_prompt, + response_model=request.response_model, + ) + else: + results = await self.llm_runner.run_extraction_cycle( + system_prompt=request.system_prompt, user_prompt=request.user_prompt + ) + self.logger.info(f"Extraction completed. Found {len(results)} entities.") return results async def _count_entities( - self, input_strings: List[str], custom_counting_context: str = "" - ) -> Optional[List[str]]: + self, + input_strings: list[str], + custom_counting_context: str | list[str] = "", + examples: str = "", + ) -> list[dict] | None: """ Counts entities in the input documents. Args: input_strings: Documents to analyze custom_counting_context: Custom context for counting phase + examples: Optional examples to guide the counting phase Returns: - List of descriptions of all model entities, or None if counting fails + List of descriptions (dicts) of all model entities, or None if counting fails """ all_model_names = self.model_registry.get_all_model_names() try: counts = await self.entity_counter.count_entities( - input_strings, all_model_names, custom_counting_context + input_strings, + all_model_names, + resolve_step_param(custom_counting_context), + examples=examples, ) - flat_descriptions = [] - for model_name, descriptions in counts.items(): - for desc in descriptions: - flat_descriptions.append(f"[{model_name}] {desc}") - return flat_descriptions + return counts except Exception as e: self.logger.warning(f"Entity counting failed: {e}") return None diff --git a/src/extrai/core/extraction_request_factory.py b/src/extrai/core/extraction_request_factory.py index 1433411..ba28703 100644 --- a/src/extrai/core/extraction_request_factory.py +++ b/src/extrai/core/extraction_request_factory.py @@ -1,19 +1,19 @@ import logging -from typing import List, Dict, Any, Optional, NamedTuple +from typing import Any, NamedTuple +from extrai.core.extraction_config import ExtractionConfig from extrai.core.model_registry 
import ModelRegistry -from extrai.core.prompt_builder import PromptBuilder from extrai.core.model_wrapper_builder import ModelWrapperBuilder -from extrai.core.extraction_config import ExtractionConfig +from extrai.core.prompt_builder import PromptBuilder from extrai.utils.serialization_utils import make_json_serializable class ExtractionRequest(NamedTuple): system_prompt: str user_prompt: str - json_schema: Optional[Dict[str, Any]] - model_name: Optional[str] - response_model: Optional[Any] = None + json_schema: dict[str, Any] | None + model_name: str | None + response_model: Any | None = None class ExtractionRequestFactory: @@ -27,7 +27,7 @@ def __init__( model_registry: ModelRegistry, prompt_builder: PromptBuilder, model_wrapper_builder: ModelWrapperBuilder, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, ): self.model_registry = model_registry self.prompt_builder = prompt_builder @@ -36,16 +36,16 @@ def __init__( def prepare_request( self, - input_strings: List[str], + input_strings: list[str], config: ExtractionConfig, extraction_example_json: str = "", custom_extraction_process: str = "", custom_extraction_guidelines: str = "", custom_final_checklist: str = "", custom_context: str = "", - expected_entity_descriptions: Optional[List[str]] = None, - previous_entities: Optional[List[Dict[str, Any]]] = None, - hierarchical_model_index: Optional[int] = None, + expected_entity_descriptions: list[dict] | None = None, + previous_entities: list[dict[str, Any]] | None = None, + hierarchical_model_index: int | None = None, ) -> ExtractionRequest: """ Prepares the extraction request based on the configuration and current state. 
diff --git a/src/extrai/core/hierarchical_extractor.py b/src/extrai/core/hierarchical_extractor.py index 62f0234..ef6e679 100644 --- a/src/extrai/core/hierarchical_extractor.py +++ b/src/extrai/core/hierarchical_extractor.py @@ -1,14 +1,19 @@ import logging -from typing import List, Dict, Any, Tuple, Optional +from typing import Any + +from extrai.core.extraction_config import ExtractionConfig +from extrai.utils.serialization_utils import ( + make_json_serializable, + resolve_step_param, +) -from .model_registry import ModelRegistry -from .prompt_builder import PromptBuilder from .entity_counter import EntityCounter +from .extraction_request_factory import ExtractionRequestFactory from .llm_runner import LLMRunner +from .model_registry import ModelRegistry from .model_wrapper_builder import ModelWrapperBuilder -from .extraction_request_factory import ExtractionRequestFactory -from extrai.core.extraction_config import ExtractionConfig -from extrai.utils.serialization_utils import make_json_serializable +from .prompt_builder import PromptBuilder +from .shared.hierarchical_coordinator import HierarchicalCoordinator class HierarchicalExtractor: @@ -29,7 +34,7 @@ def __init__( request_factory: ExtractionRequestFactory, model_wrapper_builder: ModelWrapperBuilder = None, use_structured_output: bool = False, - config: Optional[ExtractionConfig] = None, + config: ExtractionConfig | None = None, ): self.model_registry = model_registry self.prompt_builder = prompt_builder @@ -40,76 +45,160 @@ def __init__( self.model_wrapper_builder = model_wrapper_builder self.use_structured_output = use_structured_output self.config = config + self.coordinator = HierarchicalCoordinator(model_registry, logger) + + def _has_valid_descriptions( + self, descriptions: list[dict] | None + ) -> bool: + """ + Check if descriptions list is valid (not None and contains non-empty dicts). 
+ + Args: + descriptions: List of description dicts from entity counting + + Returns: + True if descriptions exist and have at least one valid dict + """ + if descriptions is None: + return False + if not isinstance(descriptions, list): + return False + if len(descriptions) == 0: + return False + + for item in descriptions: + if isinstance(item, dict) and "description" in item: + if item["description"] and item["description"].strip(): + return True + return False async def extract( self, - input_strings: List[str], + input_strings: list[str], extraction_example_json: str, - custom_extraction_process: str, - custom_extraction_guidelines: str, - custom_final_checklist: str, - custom_context: str, + custom_extraction_process: str | list[str], + custom_extraction_guidelines: str | list[str], + custom_final_checklist: str | list[str], + custom_context: str | list[str], count_entities: bool, - custom_counting_context: str = "", - ) -> List[Dict[str, Any]]: + custom_counting_context: str | list[str] = "", + ) -> list[dict[str, Any]]: """Executes hierarchical extraction.""" self.logger.info("Starting hierarchical extraction...") - models = self.model_registry.models - results_store: Dict[Tuple[str, str], Dict[str, Any]] = {} + models = self.coordinator.get_models() + num_models = len(models) + total_steps = num_models * 2 if count_entities else num_models + results_store: dict[tuple[str, str], dict[str, Any]] = {} + current_step = 0 for i, model_class in enumerate(models): model_name = model_class.__name__ - self.logger.info(f"Processing model: {model_name}") + + # Resolve parameters for this step using model index `i` + counting_context_for_model = resolve_step_param( + custom_counting_context, i, num_models + ) + extraction_process_for_model = resolve_step_param( + custom_extraction_process, i, num_models + ) + guidelines_for_model = resolve_step_param( + custom_extraction_guidelines, i, num_models + ) + checklist_for_model = resolve_step_param( + custom_final_checklist, 
i, num_models + ) + context_for_model = resolve_step_param(custom_context, i, num_models) # Count entities if needed expected_entity_descriptions = None if count_entities: + current_step += 1 + self.logger.info( + f"Step {current_step}/{total_steps}: Counting entities for {model_name}" + ) # Prepare previous entities for context previous_entities = None if results_store: previous_entities = make_json_serializable( - list(results_store.values()) + self.coordinator.collect_previous_entities( + list(results_store.values()) + ) ) counts = await self.entity_counter.count_entities( input_strings, [model_name], - custom_counting_context, + counting_context_for_model, previous_entities=previous_entities, + examples=extraction_example_json, ) - expected_entity_descriptions = counts.get(model_name) - - if not self.config: - raise ValueError( - "ExtractionConfig is required for HierarchicalExtractor" + + # Filter counts just for this model + expected_entity_descriptions = [ + item for item in counts if item.get("model") == model_name + ] + + # DEBUG: Log counting results to diagnose empty descriptions issue + self.logger.debug( + f"DEBUG: Counting results for {model_name}: " + f"expected_entity_descriptions={expected_entity_descriptions}, " + f"type={type(expected_entity_descriptions)}, " + f"count={len(expected_entity_descriptions) if expected_entity_descriptions else 0}" ) - request = self.request_factory.prepare_request( - input_strings=input_strings, - config=self.config, - extraction_example_json=extraction_example_json, - custom_extraction_process=custom_extraction_process, - custom_extraction_guidelines=custom_extraction_guidelines, - custom_final_checklist=custom_final_checklist, - custom_context=custom_context, - expected_entity_descriptions=expected_entity_descriptions, - previous_entities=list(results_store.values()) - if results_store - else None, - hierarchical_model_index=i, + # Check if we should skip extraction due to empty descriptions from counting + 
should_skip_extraction = ( + count_entities + and not self._has_valid_descriptions(expected_entity_descriptions) ) - if self.use_structured_output: - entities = await self.llm_runner.run_structured_extraction_cycle( - system_prompt=request.system_prompt, - user_prompt=request.user_prompt, - response_model=request.response_model, + if should_skip_extraction: + self.logger.info( + f"Step {current_step + 1}/{total_steps}: Skipping extraction for {model_name} " + f"- no valid entity descriptions found from counting" ) + # Skip the extraction step but still increment the step counter + current_step += 1 + # Set entities to empty list to indicate nothing was extracted + entities = [] else: - entities = await self.llm_runner.run_extraction_cycle( - system_prompt=request.system_prompt, user_prompt=request.user_prompt + if not self.config: + raise ValueError( + "ExtractionConfig is required for HierarchicalExtractor" + ) + + current_step += 1 + self.logger.info( + f"Step {current_step}/{total_steps}: Extracting entities for {model_name}" ) + request = self.request_factory.prepare_request( + input_strings=input_strings, + config=self.config, + extraction_example_json=extraction_example_json, + custom_extraction_process=extraction_process_for_model, + custom_extraction_guidelines=guidelines_for_model, + custom_final_checklist=checklist_for_model, + custom_context=context_for_model, + expected_entity_descriptions=expected_entity_descriptions, + previous_entities=self.coordinator.collect_previous_entities( + list(results_store.values()) + ) + if results_store + else None, + hierarchical_model_index=i, + ) + + if self.use_structured_output: + entities = await self.llm_runner.run_structured_extraction_cycle( + system_prompt=request.system_prompt, + user_prompt=request.user_prompt, + response_model=request.response_model, + ) + else: + entities = await self.llm_runner.run_extraction_cycle( + system_prompt=request.system_prompt, user_prompt=request.user_prompt + ) # Store results 
for idx, entity in enumerate(entities): @@ -127,3 +216,4 @@ async def extract( ) return list(results_store.values()) + diff --git a/src/extrai/core/json_consensus.py b/src/extrai/core/json_consensus.py index 294a861..892e8c9 100644 --- a/src/extrai/core/json_consensus.py +++ b/src/extrai/core/json_consensus.py @@ -2,28 +2,29 @@ import logging import math from collections import Counter -from typing import List, Dict, Any, Optional, Union, Tuple -from extrai.utils.flattening_utils import ( - flatten_json, - unflatten_json, - Path, - JSONValue, - JSONObject, - JSONArray, - FlattenedJSON, -) +from typing import Any + from extrai.core.conflict_resolvers import ( ConflictResolutionStrategy, default_conflict_resolver, - prefer_most_common_resolver, levenshtein_similarity, + prefer_most_common_resolver, +) +from extrai.utils.flattening_utils import ( + FlattenedJSON, + JSONArray, + JSONObject, + JSONValue, + Path, + flatten_json, + unflatten_json, ) # Sentinel value to indicate that no consensus was reached for a path. _NO_CONSENSUS = object() # Define a type for a list of JSON revisions -JSONRevisions = List[Union[JSONObject, JSONArray]] +JSONRevisions = list[JSONObject | JSONArray] class JSONConsensus: @@ -35,10 +36,9 @@ class JSONConsensus: def __init__( self, consensus_threshold: float = 0.5, - conflict_resolver: Optional[ - ConflictResolutionStrategy - ] = default_conflict_resolver, - logger: Optional[logging.Logger] = None, + conflict_resolver: ConflictResolutionStrategy + | None = default_conflict_resolver, + logger: logging.Logger | None = None, ): """ Initializes the JSONConsensus processor. 
@@ -63,7 +63,7 @@ def __init__( def get_consensus( self, revisions: JSONRevisions - ) -> Tuple[Union[JSONObject, JSONArray, JSONValue, None], Dict[str, Any]]: + ) -> tuple[JSONObject | JSONArray | JSONValue | None, dict[str, Any]]: num_revisions = len(revisions) analytics = self._initialize_analytics(num_revisions) @@ -96,7 +96,7 @@ def get_consensus( return final_consensus_object, analytics - def _initialize_analytics(self, num_revisions: int) -> Dict[str, Any]: + def _initialize_analytics(self, num_revisions: int) -> dict[str, Any]: return { "revisions_processed": num_revisions, "unique_paths_considered": 0, @@ -108,7 +108,7 @@ def _initialize_analytics(self, num_revisions: int) -> Dict[str, Any]: "average_string_similarity": 0.0, # Average Levenshtein ratio (1.0 = identical) } - def _calculate_revision_weights(self, revisions: JSONRevisions) -> List[float]: + def _calculate_revision_weights(self, revisions: JSONRevisions) -> list[float]: """ Calculates weights for each revision based on its similarity to other revisions. Revisions that are similar to others get higher weights (centrality). @@ -163,11 +163,11 @@ def _calculate_revision_weights(self, revisions: JSONRevisions) -> List[float]: def _aggregate_paths( self, revisions: JSONRevisions - ) -> Dict[Path, List[Tuple[JSONValue, int]]]: + ) -> dict[Path, list[tuple[JSONValue, int]]]: """ Aggregates values for each path, preserving the source revision index. 
""" - path_to_values: Dict[Path, List[Tuple[JSONValue, int]]] = {} + path_to_values: dict[Path, list[tuple[JSONValue, int]]] = {} flattened_revisions = [flatten_json(rev) for rev in revisions] for idx, flat_rev in enumerate(flattened_revisions): for path, value in flat_rev.items(): @@ -176,10 +176,10 @@ def _aggregate_paths( def _build_consensus_json( self, - path_to_values: Dict[Path, List[Tuple[JSONValue, int]]], + path_to_values: dict[Path, list[tuple[JSONValue, int]]], num_revisions: int, - analytics: Dict[str, Any], - revision_weights: List[float], + analytics: dict[str, Any], + revision_weights: list[float], ) -> FlattenedJSON: consensus_flat_json: FlattenedJSON = {} @@ -266,10 +266,10 @@ def _build_consensus_json( def _get_consensus_for_path( self, path: Path, - values: List[JSONValue], - weights: List[float], + values: list[JSONValue], + weights: list[float], num_revisions: int, - ) -> Union[JSONValue, object]: + ) -> JSONValue | object: # Use weighted voting if weights provided most_common_candidate = prefer_most_common_resolver(path, values, weights) @@ -304,7 +304,7 @@ def _get_consensus_for_path( def _build_final_object( self, consensus_flat_json: FlattenedJSON, revisions: JSONRevisions - ) -> Union[JSONObject, JSONArray, JSONValue, None]: + ) -> JSONObject | JSONArray | JSONValue | None: if not consensus_flat_json and revisions: return [] if isinstance(revisions[0], list) else {} return unflatten_json(consensus_flat_json) diff --git a/src/extrai/core/llm_runner.py b/src/extrai/core/llm_runner.py index d26510e..a988149 100644 --- a/src/extrai/core/llm_runner.py +++ b/src/extrai/core/llm_runner.py @@ -1,84 +1,53 @@ # extrai/core/llm_runner.py -import logging import asyncio -from typing import List, Dict, Any, Union +import logging +from typing import Any -from .model_registry import ModelRegistry -from .extraction_config import ExtractionConfig -from .json_consensus import JSONConsensus, default_conflict_resolver from .analytics_collector import 
WorkflowAnalyticsCollector -from .base_llm_client import BaseLLMClient +from .base_llm_client import BaseLLMClient, ResponseMode from .errors import ( - LLMInteractionError, - ConsensusProcessError, + LLMAPICallError, LLMConfigurationError, + LLMInteractionError, LLMOutputParseError, LLMOutputValidationError, - LLMAPICallError, ) -from extrai.utils.alignment_utils import normalize_json_revisions +from .extraction_config import ExtractionConfig +from .model_registry import ModelRegistry +from .shared.consensus_runner import ConsensusRunner class LLMRunner: """ Manages LLM client rotation and extraction cycles. - - Responsibilities: - - Rotate through multiple LLM clients for load balancing - - Execute parallel LLM calls for multiple revisions - - Run consensus mechanism on results - - Handle LLM-related errors gracefully - - This class abstracts away the complexity of managing multiple LLM - clients and coordinating their outputs through consensus. """ def __init__( self, model_registry: ModelRegistry, - llm_client: Union[BaseLLMClient, List[BaseLLMClient]], + llm_client: BaseLLMClient | list[BaseLLMClient], config: ExtractionConfig, analytics_collector: WorkflowAnalyticsCollector, logger: logging.Logger, ): - """ - Initialize the LLM runner. 
- - Args: - model_registry: Registry of SQLModel schemas - llm_client: Single client or list of LLM clients - config: Extraction configuration - analytics_collector: Collector for tracking metrics - logger: Logger instance - - Raises: - ValueError: If llm_client list is empty or contains invalid clients - """ self.model_registry = model_registry self.config = config self.analytics_collector = analytics_collector self.logger = logger - - # Setup clients with validation self.clients = self._setup_clients(llm_client) self.client_index = 0 - - # Setup consensus mechanism - self.consensus = JSONConsensus( - consensus_threshold=config.consensus_threshold, - conflict_resolver=config.conflict_resolver or default_conflict_resolver, - logger=logger, + self.consensus_runner = ConsensusRunner( + config, analytics_collector, logger ) - self.logger.info( f"LLMRunner initialized with {len(self.clients)} client(s), " f"{config.num_llm_revisions} revisions per cycle" ) def _setup_clients( - self, llm_client: Union[BaseLLMClient, List[BaseLLMClient]] - ) -> List[BaseLLMClient]: + self, llm_client: BaseLLMClient | list[BaseLLMClient] + ) -> list[BaseLLMClient]: """ Validates and normalizes LLM client input. @@ -129,7 +98,7 @@ def get_next_client(self) -> BaseLLMClient: async def run_extraction_cycle( self, system_prompt: str, user_prompt: str - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Runs a complete extraction cycle. 
@@ -157,15 +126,10 @@ async def run_extraction_cycle( # Step 1: Generate revisions in parallel revisions = await self._generate_revisions(system_prompt, user_prompt) - self.logger.debug(f"Generated {len(revisions)} revisions before normalization") - - # Step 2: Normalize for consensus (handles array ordering) - revisions = normalize_json_revisions(revisions) - - self.logger.debug(f"Normalized to {len(revisions)} revisions for consensus") + self.logger.debug(f"Generated {len(revisions)} revisions before consensus") # Step 3: Run consensus - results = self._run_consensus(revisions) + results = self.consensus_runner.run(revisions) self.logger.info(f"Extraction cycle completed with {len(results)} entities") @@ -176,7 +140,7 @@ async def run_structured_extraction_cycle( system_prompt: str, user_prompt: str, response_model: Any, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Runs a structured extraction cycle using response_model directly. """ @@ -184,18 +148,21 @@ async def run_structured_extraction_cycle( f"Starting structured extraction cycle with {self.config.num_llm_revisions} revisions" ) + async def generate_single_revision(client_instance: BaseLLMClient) -> Any: + """Helper to generate a single revision and extract the result.""" + results = await client_instance.generate_revisions( + system_prompt=system_prompt, + user_prompt=user_prompt, + num_revisions=1, + response_mode=ResponseMode.STRUCTURED, + response_model=response_model, + ) + return results[0] + tasks = [] for i in range(self.config.num_llm_revisions): client = self.get_next_client() - tasks.append( - asyncio.create_task( - client.generate_structured( - system_prompt=system_prompt, - user_prompt=user_prompt, - response_model=response_model, - ) - ) - ) + tasks.append(asyncio.create_task(generate_single_revision(client))) try: results = await asyncio.gather(*tasks) @@ -214,24 +181,21 @@ async def run_structured_extraction_cycle( self.logger.warning(f"Result {type(result)} is not a 
Pydantic model.") # Extract the list of entities if present - normalized_revisions = [] + processed_revisions = [] for rev in revisions: if "entities" in rev and isinstance(rev["entities"], list): - normalized_revisions.append(rev["entities"]) + processed_revisions.append(rev["entities"]) else: - normalized_revisions.append(rev) - - # Step 2: Normalize - normalized_revisions = normalize_json_revisions(normalized_revisions) + processed_revisions.append(rev) # Step 3: Consensus - final_results = self._run_consensus(normalized_revisions) + final_results = self.consensus_runner.run(processed_revisions) return final_results async def _generate_revisions( self, system_prompt: str, user_prompt: str - ) -> List[Any]: + ) -> list[Any]: """ Generates multiple LLM revisions in parallel. @@ -302,93 +266,6 @@ async def _generate_revisions( f"An unexpected error occurred during LLM interaction: {e}" ) from e - def _run_consensus(self, revisions: List[Any]) -> List[Dict[str, Any]]: - """ - Runs consensus mechanism on revisions. 
- - Args: - revisions: List of normalized revision outputs - - Returns: - List of consensus entity dictionaries - - Raises: - ConsensusProcessError: If consensus fails - """ - try: - self.logger.debug(f"Running consensus on {len(revisions)} revisions") - - # Run consensus - consensus_output, details = self.consensus.get_consensus(revisions) - - # Record analytics if available - if details: - self.analytics_collector.record_consensus_run_details(details) - self.logger.debug(f"Consensus details: {details}") - - # Process and normalize output - processed = self._process_consensus_output(consensus_output) - - self.logger.debug(f"Consensus produced {len(processed)} entities") - - return processed - - except ConsensusProcessError: - # Re-raise consensus errors as-is - raise - - except Exception as e: - # Wrap unexpected errors - self.logger.error(f"Consensus processing failed: {e}") - raise ConsensusProcessError( - f"Failed during JSON consensus processing: {e}" - ) from e - - def _process_consensus_output(self, consensus_output: Any) -> List[Dict[str, Any]]: - """ - Normalizes consensus output to list format. 
- - The consensus mechanism can return various formats: - - None (no consensus reached) - - List of dicts (standard format) - - Dict with 'results' key - - Single dict (wrap in list) - - Args: - consensus_output: Raw output from consensus mechanism - - Returns: - Normalized list of entity dictionaries - - Raises: - ConsensusProcessError: If output format is unexpected - """ - # Handle None - if consensus_output is None: - self.logger.warning("Consensus returned None, returning empty list") - return [] - - # Handle list (standard format) - if isinstance(consensus_output, list): - return consensus_output - - # Handle dict - if isinstance(consensus_output, dict): - # Check for 'results' key (wrapped format) - if "results" in consensus_output and isinstance( - consensus_output["results"], list - ): - return consensus_output["results"] - - # Single entity dict, wrap in list - return [consensus_output] - - # Unexpected type - raise ConsensusProcessError( - f"Unexpected consensus output type: {type(consensus_output)}. " - f"Expected None, list, or dict." - ) - def get_client_count(self) -> int: """ Returns the number of LLM clients in rotation. diff --git a/src/extrai/core/model_registry.py b/src/extrai/core/model_registry.py index 06eca47..7151bdf 100644 --- a/src/extrai/core/model_registry.py +++ b/src/extrai/core/model_registry.py @@ -2,7 +2,7 @@ import json import logging -from typing import Type, List, Optional + from sqlmodel import SQLModel from .errors import ConfigurationError @@ -21,7 +21,7 @@ class ModelRegistry: """ def __init__( - self, root_model: Type[SQLModel], logger: Optional[logging.Logger] = None + self, root_model: type[SQLModel], logger: logging.Logger | None = None ): """ Initialize the model registry. 
@@ -56,7 +56,7 @@ def __init__( f"{', '.join(self.model_map.keys())}" ) - def _discover_models(self, root_model: Type[SQLModel]) -> List[Type[SQLModel]]: + def _discover_models(self, root_model: type[SQLModel]) -> list[type[SQLModel]]: """ Discovers all SQLModel classes from root. @@ -106,7 +106,7 @@ def _generate_llm_schema(self) -> str: except Exception as e: raise ConfigurationError(f"Failed to generate LLM schema: {e}") from e - def get_schema_for_models(self, model_names: List[str]) -> str: + def get_schema_for_models(self, model_names: list[str]) -> str: """ Generates schema JSON for specific models. @@ -135,7 +135,7 @@ def get_schema_for_models(self, model_names: List[str]) -> str: self.logger.error(f"Failed to generate schema for {model_names}: {e}") return self.llm_schema_json - def get_model_by_name(self, name: str) -> Optional[Type[SQLModel]]: + def get_model_by_name(self, name: str) -> type[SQLModel] | None: """ Retrieves a model class by name. @@ -147,7 +147,7 @@ def get_model_by_name(self, name: str) -> Optional[Type[SQLModel]]: """ return self.model_map.get(name) - def get_all_model_names(self) -> List[str]: + def get_all_model_names(self) -> list[str]: """ Returns list of all discovered model names. 
diff --git a/src/extrai/core/model_wrapper_builder.py b/src/extrai/core/model_wrapper_builder.py index e7b1393..8cf7e7e 100644 --- a/src/extrai/core/model_wrapper_builder.py +++ b/src/extrai/core/model_wrapper_builder.py @@ -1,8 +1,9 @@ -from typing import Type, List, Optional, Any, Dict -from pydantic import BaseModel, create_model, Field -from sqlmodel import SQLModel +from typing import Any, Optional + +from pydantic import BaseModel, Field, create_model from sqlalchemy import inspect from sqlalchemy.orm import RelationshipProperty +from sqlmodel import SQLModel class ModelWrapperBuilder: @@ -13,11 +14,11 @@ class ModelWrapperBuilder: """ def __init__(self): - self._generated_models: Dict[Type[SQLModel], Type[BaseModel]] = {} + self._generated_models: dict[type[SQLModel], type[BaseModel]] = {} def generate_wrapper_model( - self, root_sqlmodel: Type[SQLModel], include_relationships: bool = True - ) -> Type[BaseModel]: + self, root_sqlmodel: type[SQLModel], include_relationships: bool = True + ) -> type[BaseModel]: """ Generates a Pydantic wrapper model for the given root SQLModel. This wrapper creates a hierarchy of Pydantic models where relationships @@ -42,7 +43,7 @@ def generate_wrapper_model( wrapper_model = create_model( wrapper_name, entities=( - List[pydantic_model], + list[pydantic_model], Field( description=f"List of extracted {root_sqlmodel.__name__} entities." 
), @@ -97,8 +98,8 @@ def _enrich_field_description(self, field_info: Any) -> Any: return new_field_info def _create_pydantic_model_recursive( - self, sql_model: Type[SQLModel], include_relationships: bool = True - ) -> Type[BaseModel]: + self, sql_model: type[SQLModel], include_relationships: bool = True + ) -> type[BaseModel]: if sql_model in self._generated_models: return self._generated_models[sql_model] @@ -135,7 +136,7 @@ def _create_pydantic_model_recursive( if rel.uselist: # List[NestedModel] - field_type = List[nested_model] + field_type = list[nested_model] field_desc = f"List of {target_model.__name__} items." else: # NestedModel (Optional?) diff --git a/src/extrai/core/pricing_updater.py b/src/extrai/core/pricing_updater.py new file mode 100644 index 0000000..4cfb1a5 --- /dev/null +++ b/src/extrai/core/pricing_updater.py @@ -0,0 +1,83 @@ +# extrai/core/pricing_updater.py +import json +import os +from datetime import datetime, timedelta + +import requests +from jsonschema import ValidationError, validate + +PRICING_URL = "https://www.llm-prices.com/current-v1.json" +CACHE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "model_prices.json") +CACHE_EXPIRATION = timedelta(days=1) + +PRICING_SCHEMA = { + "type": "object", + "properties": { + "updated_at": {"type": "string", "format": "date"}, + "prices": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "vendor": {"type": "string"}, + "name": {"type": "string"}, + "input": {"type": "number"}, + "output": {"type": "number"}, + "input_cached": {"type": ["number", "null"]}, + }, + "required": ["id", "vendor", "name", "input", "output"], + }, + }, + }, + "required": ["updated_at", "prices"], +} + + +def fetch_pricing_data(): + """Fetches pricing data from the remote URL and validates it.""" + response = requests.get(PRICING_URL) + response.raise_for_status() + data = response.json() + + try: + validate(instance=data, schema=PRICING_SCHEMA) + except 
ValidationError as e: + # Handle validation error, e.g., log it or raise an exception + print(f"Pricing data validation error: {e}") + return None + + return data + + +def save_pricing_data(data): + """Saves pricing data to the local cache.""" + if data is None: + return + + with open(CACHE_FILE, "w") as f: + json.dump(data, f, indent=2) + + +def load_pricing_data(): + """Loads pricing data from the local cache.""" + if not os.path.exists(CACHE_FILE): + return None + with open(CACHE_FILE) as f: + return json.load(f) + + +def is_cache_stale(): + """Checks if the cached pricing data is stale.""" + if not os.path.exists(CACHE_FILE): + return True + + last_modified_time = datetime.fromtimestamp(os.path.getmtime(CACHE_FILE)) + return datetime.now() - last_modified_time > CACHE_EXPIRATION + + +def update_prices_if_stale(): + """Updates the pricing data if the cache is stale.""" + if is_cache_stale(): + data = fetch_pricing_data() + save_pricing_data(data) diff --git a/src/extrai/core/prompt_builder.py b/src/extrai/core/prompt_builder.py index 1f4e226..7708f96 100644 --- a/src/extrai/core/prompt_builder.py +++ b/src/extrai/core/prompt_builder.py @@ -4,19 +4,9 @@ """ import logging -from typing import List, Optional, Tuple -from extrai.core.model_registry import ModelRegistry +from extrai.core.model_registry import ModelRegistry from extrai.core.prompts.common import generate_user_prompt_for_docs -from extrai.core.prompts.extraction import ( - generate_system_prompt, -) -from extrai.core.prompts.structured_extraction import ( - generate_structured_system_prompt, -) -from extrai.core.prompts.sqlmodel import ( - generate_sqlmodel_creation_system_prompt, -) from extrai.core.prompts.counting import ( generate_entity_counting_system_prompt, generate_entity_counting_user_prompt, @@ -24,6 +14,15 @@ from extrai.core.prompts.examples import ( generate_prompt_for_example_json_generation, ) +from extrai.core.prompts.extraction import ( + generate_system_prompt, +) +from 
extrai.core.prompts.sqlmodel import ( + generate_sqlmodel_creation_system_prompt, +) +from extrai.core.prompts.structured_extraction import ( + generate_structured_system_prompt, +) class PromptBuilder: @@ -33,24 +32,24 @@ class PromptBuilder: """ def __init__( - self, model_registry: ModelRegistry, logger: Optional[logging.Logger] = None + self, model_registry: ModelRegistry, logger: logging.Logger | None = None ): self.model_registry = model_registry self.logger = logger or logging.getLogger(__name__) def build_prompts( self, - input_strings: List[str], + input_strings: list[str], schema_json: str, extraction_example_json: str = "", custom_extraction_process: str = "", custom_extraction_guidelines: str = "", custom_final_checklist: str = "", custom_context: str = "", - expected_entity_descriptions: Optional[List[str]] = None, - previous_entities: Optional[List[dict]] = None, - target_model_name: Optional[str] = None, - ) -> Tuple[str, str]: + expected_entity_descriptions: list[dict] | None = None, + previous_entities: list[dict] | None = None, + target_model_name: str | None = None, + ) -> tuple[str, str]: """ Builds system and user prompts for extraction. """ @@ -72,15 +71,15 @@ def build_prompts( def build_structured_prompts( self, - input_strings: List[str], + input_strings: list[str], custom_extraction_process: str = "", custom_extraction_guidelines: str = "", custom_context: str = "", extraction_example_json: str = "", - expected_entity_descriptions: Optional[List[str]] = None, - previous_entities: Optional[List[dict]] = None, - target_model_name: Optional[str] = None, - ) -> Tuple[str, str]: + expected_entity_descriptions: list[dict] | None = None, + previous_entities: list[dict] | None = None, + target_model_name: str | None = None, + ) -> tuple[str, str]: """ Builds prompts for structured extraction. 
""" diff --git a/src/extrai/core/prompts/__init__.py b/src/extrai/core/prompts/__init__.py index 5898c7b..121c76a 100644 --- a/src/extrai/core/prompts/__init__.py +++ b/src/extrai/core/prompts/__init__.py @@ -1,11 +1,11 @@ -from .extraction import generate_system_prompt from .common import generate_user_prompt_for_docs -from .sqlmodel import generate_sqlmodel_creation_system_prompt from .counting import ( generate_entity_counting_system_prompt, generate_entity_counting_user_prompt, ) from .examples import generate_prompt_for_example_json_generation +from .extraction import generate_system_prompt +from .sqlmodel import generate_sqlmodel_creation_system_prompt __all__ = [ "generate_system_prompt", diff --git a/src/extrai/core/prompts/common.py b/src/extrai/core/prompts/common.py index d94553b..b36fea1 100644 --- a/src/extrai/core/prompts/common.py +++ b/src/extrai/core/prompts/common.py @@ -1,8 +1,5 @@ -from typing import List - - def generate_user_prompt_for_docs( - documents: List[str], custom_context: str = "" + documents: list[str], custom_context: str = "" ) -> str: """ Generates a generic user prompt containing the documents for extraction. diff --git a/src/extrai/core/prompts/counting.py b/src/extrai/core/prompts/counting.py index 6dafaf9..ba20786 100644 --- a/src/extrai/core/prompts/counting.py +++ b/src/extrai/core/prompts/counting.py @@ -1,12 +1,14 @@ import json -from typing import List, Dict, Any, Optional +from typing import Any def generate_entity_counting_system_prompt( model_names: list[str], schema_json: str = None, custom_counting_context: str = "", - previous_entities: Optional[List[Dict[str, Any]]] = None, + previous_entities: list[dict[str, Any]] | None = None, + examples: str = "", + conflicting_revisions: list[dict[str, Any]] | None = None, ) -> str: """ Generates a system prompt for counting entities in the provided documents. 
@@ -17,6 +19,8 @@ def generate_entity_counting_system_prompt( This helps the LLM understand the structure of the entities to count. custom_counting_context: Optional custom context to guide the counting phase. previous_entities: Optional list of previously extracted entities for context. + examples: Optional string containing examples of the entities to count. + conflicting_revisions: Optional list of previous conflicting counting attempts to merge. Returns: A string representing the system prompt for entity counting. @@ -40,7 +44,8 @@ def generate_entity_counting_system_prompt( # PREVIOUSLY EXTRACTED ENTITIES: {entities_json} -IMPORTANT: If the entities you are counting are related to any of the previously extracted entities above, you MUST specify the unique ID (or temp_id) of that related entity in your description string. This ensures correct linking in subsequent steps. +IMPORTANT: If the entities you are counting are related to any of the previously extracted entities above, you MUST specify the unique ID (or temp_id) of that related entity in your description string. This ensures correct linking in subsequent steps. Therefore take a good look at the previously extracted entities and ensure they are linked correctly. +Do not hesitate to add details to help identify those links. Also note that it's possible that there are no objects to extract! """ prompt += f""" @@ -49,30 +54,35 @@ def generate_entity_counting_system_prompt( ```json {schema_json} ``` +""" + + if examples: + prompt += f""" +# EXAMPLES: +Here are some examples of the objects that will be extracted on the next step. Your goal is to facilitate the extraction of these objects in the future: +{examples} +""" + + if conflicting_revisions: + revisions_json = json.dumps(conflicting_revisions, indent=2) + prompt += f""" +# MERGE REQUIRED: +Previous extraction attempts returned conflicting results. 
Here are the conflicting revisions: +{revisions_json} + +Your task is to cross-reference these previous attempts with the text and provide the final, comprehensive, and correct list of entities, resolving any discrepancies. """ prompt += """ # OUTPUT INSTRUCTIONS: -1. **Output Format:** Your output must be a single, valid JSON object. -2. **Keys:** The JSON object keys must be the exact names of the entities provided above. -3. **Values:** The values must be a list of strings, where each string is a description of the entity found. -4. **Order:** The order of the descriptions in the list must match the order of appearance in the document. -5. **Relational Detail:** If an entity relates to a previously extracted entity (e.g., a child entity belonging to a parent), your description MUST include the ID of that parent entity from the provided context. -6. **No Extra Text:** Do NOT include any explanations, markdown formatting, or text outside the JSON object. - -Example Output: -{{ - "Invoice": [ - "Invoice #123 from ABC Corp with a value of 50euros", - "Invoice #456 from XYZ Inc with a value of 506euros", - "Invoice #789 from Foo Bar with a value of 30euros" - ], - "LineItem": [ - "Item A - Widget linked to Invoice ID: invoice_123", - "Item B - Gadget linked to Invoice ID: invoice_123", - "Item C - Doohickey linked to Invoice ID: invoice_456", - ] -}} +1. **Output Format:** Your output must be a single JSON object with a `counted_entities` array. +2. **Array Items:** Each item in the array must be an object containing: + - `model`: the exact name of the entity model + - `temp_id`: a unique temporary string identifier for this specific entity instance + - `related_ids`: a list of string identifiers (temp_id or actual id) of any related entities + - `description`: a detailed description of the entity found +3. **Order:** The order of the entities in the list should generally match the order of appearance in the document. +4. 
**Relational Detail:** If an entity relates to a previously extracted entity (e.g., a child entity belonging to a parent), you MUST include the ID of that parent entity in the `related_ids` list and optionally in the `description`. Proceed with identifying and describing the entities in the user's documents. """.strip() @@ -100,6 +110,6 @@ def generate_entity_counting_user_prompt(documents: list[str]) -> str: {combined_documents} --- -Remember: Your output must be only a single, valid JSON object mapping entity names to counts. +Remember: Your output must match the structured format requested (an object with a `counted_entities` array). """.strip() return prompt diff --git a/src/extrai/core/prompts/extraction.py b/src/extrai/core/prompts/extraction.py index 8bca2a2..29a62be 100644 --- a/src/extrai/core/prompts/extraction.py +++ b/src/extrai/core/prompts/extraction.py @@ -1,5 +1,5 @@ import json -from typing import Optional, List, Dict, Any +from typing import Any def generate_system_prompt( @@ -9,9 +9,9 @@ def generate_system_prompt( custom_extraction_guidelines: str = "", custom_final_checklist: str = "", custom_context: str = "", - expected_entity_descriptions: Optional[List[str]] = None, - previous_entities: Optional[List[Dict[str, Any]]] = None, - target_model_name: Optional[str] = None, + expected_entity_descriptions: list[dict] | None = None, + previous_entities: list[dict[str, Any]] | None = None, + target_model_name: str | None = None, ) -> str: """ Generates a generic system prompt for guiding an LLM to extract information @@ -111,16 +111,27 @@ def generate_system_prompt( "Do not extract other entity types in this step." 
) - if expected_entity_descriptions: + if expected_entity_descriptions is not None: prompt_parts.append("\n# EXPECTED ENTITIES & ORDER:") - prompt_parts.append( - "You MUST extract entities matching the following descriptions, in this exact order:" - ) - for i, desc in enumerate(expected_entity_descriptions, 1): - prompt_parts.append(f"{i}. {desc}") - prompt_parts.append( - f"\nYou must extract EXACTLY {len(expected_entity_descriptions)} items/entities corresponding to these descriptions." - ) + if len(expected_entity_descriptions) == 0: + prompt_parts.append( + "Based on the counting phase, there are NO entities of this type to extract. " + "You MUST return an empty array/list. Extract exactly 0 entities." + ) + else: + prompt_parts.append( + "You MUST extract entities matching the following descriptions, in this exact order:" + ) + for i, entity_dict in enumerate(expected_entity_descriptions, 1): + model = entity_dict.get("model", "Unknown") + desc = entity_dict.get("description", "") + related_ids = entity_dict.get("related_ids", []) + related_str = f" | Related IDs: {', '.join(related_ids)}" if related_ids else "" + prompt_parts.append(f"{i}. [Model: {model}] Description: {desc}{related_str}") + + prompt_parts.append( + f"\nYou must extract EXACTLY {len(expected_entity_descriptions)} items/entities corresponding to these descriptions." 
+ ) if custom_context: prompt_parts.append("\n# ADDITIONAL CONTEXT:") diff --git a/src/extrai/core/prompts/structured_extraction.py b/src/extrai/core/prompts/structured_extraction.py index 8b8b430..1fdd6d5 100644 --- a/src/extrai/core/prompts/structured_extraction.py +++ b/src/extrai/core/prompts/structured_extraction.py @@ -1,5 +1,5 @@ import json -from typing import Optional, List, Dict, Any +from typing import Any def generate_structured_system_prompt( @@ -7,9 +7,9 @@ def generate_structured_system_prompt( custom_extraction_guidelines: str = "", custom_context: str = "", extraction_example_json: str = "", - expected_entity_descriptions: Optional[List[str]] = None, - previous_entities: Optional[List[Dict[str, Any]]] = None, - target_model_name: Optional[str] = None, + expected_entity_descriptions: list[dict] | None = None, + previous_entities: list[dict[str, Any]] | None = None, + target_model_name: str | None = None, ) -> str: """ Generates a system prompt tailored for structured output extraction. @@ -37,6 +37,7 @@ def generate_structured_system_prompt( 3. **Accuracy:** Ensure all extracted data is accurate and supported by the text. 4. **Inference:** If a field is missing but can be reasonably inferred from context, you may do so. Otherwise, leave it as null/None. 5. **Relationships:** Capture relationships by nesting entities as defined in the structure. +6. **IDs:** If the schema contains an `id` field, you MUST populate it. If the `id` is an integer, start counting from 1. If it is a string, create a deterministic ID based on the content. """ parts = [default_instructions] @@ -48,16 +49,27 @@ def generate_structured_system_prompt( "Do not extract other entity types in this step." 
) - if expected_entity_descriptions: + if expected_entity_descriptions is not None: parts.append("# EXPECTED ENTITIES & ORDER") - parts.append( - "You MUST extract entities matching the following descriptions, in this exact order:" - ) - for i, desc in enumerate(expected_entity_descriptions, 1): - parts.append(f"{i}. {desc}") - parts.append( - f"\nYou must extract EXACTLY {len(expected_entity_descriptions)} items/entities corresponding to these descriptions." - ) + if len(expected_entity_descriptions) == 0: + parts.append( + "Based on the counting phase, there are NO entities of this type to extract. " + "You MUST return an empty array/list. Extract exactly 0 entities." + ) + else: + parts.append( + "You MUST extract entities matching the following descriptions, in this exact order:" + ) + for i, entity_dict in enumerate(expected_entity_descriptions, 1): + model = entity_dict.get("model", "Unknown") + desc = entity_dict.get("description", "") + related_ids = entity_dict.get("related_ids", []) + related_str = f" | Related IDs: {', '.join(related_ids)}" if related_ids else "" + parts.append(f"{i}. [Model: {model}] Description: {desc}{related_str}") + + parts.append( + f"\nYou must extract EXACTLY {len(expected_entity_descriptions)} items/entities corresponding to these descriptions." 
+ ) # Assemble comprehensive custom instructions instructions_parts = [] diff --git a/src/extrai/core/result_processor.py b/src/extrai/core/result_processor.py index d106e91..c3f01df 100644 --- a/src/extrai/core/result_processor.py +++ b/src/extrai/core/result_processor.py @@ -1,23 +1,21 @@ +import json import logging import uuid from typing import ( - List, - Dict, Any, - Optional, - Type, - get_origin, - get_args, - Union, NamedTuple, + Union, + get_args, + get_origin, ) -from sqlalchemy.orm import Session + from sqlalchemy import create_engine, inspect from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from sqlmodel import SQLModel -from .model_registry import ModelRegistry from .errors import HydrationError, WorkflowError +from .model_registry import ModelRegistry SQLModelInstance = SQLModel @@ -29,8 +27,8 @@ class DatabaseWriterError(Exception): class PrimaryKeyInfo(NamedTuple): - name: Optional[str] - type: Optional[Type[Any]] + name: str | None + type: type[Any] | None has_uuid_factory: bool @@ -45,9 +43,9 @@ class DirectHydrator: def __init__( self, session: Session, - logger: Optional[logging.Logger] = None, - original_pk_map: Dict[tuple[str, Any], SQLModelInstance] = None, - all_instances: List[SQLModelInstance] = None, + logger: logging.Logger | None = None, + original_pk_map: dict[tuple[str, Any], SQLModelInstance] = None, + all_instances: list[SQLModelInstance] = None, ): self.session = session self.logger = logger or logging.getLogger(__name__) @@ -56,13 +54,19 @@ def __init__( def hydrate( self, - data: List[Dict[str, Any]], - model_map: Dict[str, Type[SQLModel]], - default_model_class: Optional[Type[SQLModel]] = None, - ) -> List[SQLModelInstance]: + data: list[dict[str, Any]], + model_map: dict[str, type[SQLModel]], + default_model_class: type[SQLModel] | None = None, + ) -> list[SQLModelInstance]: instances = [] for item in data: try: + if isinstance(item, str): + try: + item = json.loads(item) + except 
json.JSONDecodeError: + self.logger.error(f"Failed to decode JSON string: {item}") + continue # Determine model class _type = item.get("_type") model_class = None @@ -89,9 +93,9 @@ def hydrate( def _hydrate_recursive( self, - data: Dict[str, Any], - model_class: Type[SQLModel], - model_map: Dict[str, Type[SQLModel]], + data: dict[str, Any], + model_class: type[SQLModel], + model_map: dict[str, type[SQLModel]], ) -> SQLModelInstance: """ Recursively hydrates an instance and its relationships. @@ -190,9 +194,9 @@ class SQLAlchemyHydrator: def __init__( self, session: Session, - logger: Optional[logging.Logger] = None, - original_pk_map: Dict[tuple[str, Any], SQLModelInstance] = None, - all_instances: List[SQLModelInstance] = None, + logger: logging.Logger | None = None, + original_pk_map: dict[tuple[str, Any], SQLModelInstance] = None, + all_instances: list[SQLModelInstance] = None, ): """ Initializes the Hydrator. @@ -203,14 +207,14 @@ def __init__( logger: Optional logger instance. """ self.session: Session = session - self.temp_id_to_instance_map: Dict[ + self.temp_id_to_instance_map: dict[ str, SQLModelInstance ] = {} # Stores _temp_id -> SQLModel instance self.original_pk_map = original_pk_map if original_pk_map is not None else {} self.all_instances = all_instances if all_instances is not None else [] self.logger = logger or logging.getLogger(__name__) - def _filter_special_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + def _filter_special_fields(self, data: dict[str, Any]) -> dict[str, Any]: """Removes _temp_id, _type, and relationship reference fields before Pydantic validation.""" return { k: v @@ -220,7 +224,7 @@ def _filter_special_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: and not k.endswith("_ref_ids") } - def _validate_entities_list(self, entities_list: List[Dict[str, Any]]) -> None: + def _validate_entities_list(self, entities_list: list[dict[str, Any]]) -> None: """Performs initial validation on the input entities list.""" if not 
isinstance(entities_list, list): raise TypeError( @@ -235,7 +239,7 @@ def _validate_entities_list(self, entities_list: List[Dict[str, Any]]) -> None: f"Found an item of type: {type(first_non_dict)}." ) - def _get_primary_key_info(self, model_class: Type[SQLModel]) -> PrimaryKeyInfo: + def _get_primary_key_info(self, model_class: type[SQLModel]) -> PrimaryKeyInfo: """Introspects the model to find primary key details.""" for field_name, model_field in model_class.model_fields.items(): if getattr(model_field, "primary_key", False): @@ -268,7 +272,7 @@ def _get_primary_key_info(self, model_class: Type[SQLModel]) -> PrimaryKeyInfo: return PrimaryKeyInfo(name=None, type=None, has_uuid_factory=False) def _generate_pk_if_needed( - self, instance: SQLModelInstance, model_class: Type[SQLModel] + self, instance: SQLModelInstance, model_class: type[SQLModel] ) -> None: """Generates a primary key for the instance if it's needed.""" pk_info = self._get_primary_key_info(model_class) @@ -288,8 +292,8 @@ def _generate_pk_if_needed( def _create_single_instance( self, - entity_data: Dict[str, Any], - model_schema_map: Dict[str, Type[SQLModel]], + entity_data: dict[str, Any], + model_schema_map: dict[str, type[SQLModel]], ) -> None: """Creates a single SQLModel instance from its dictionary representation.""" _temp_id = entity_data.get("_temp_id") @@ -312,7 +316,7 @@ def _create_single_instance( filtered_data = self._filter_special_fields(entity_data.copy()) - pk_field_name: Optional[str] = None + pk_field_name: str | None = None for field_name, model_field in model_class.model_fields.items(): if getattr(model_field, "primary_key", False): pk_field_name = field_name @@ -345,8 +349,8 @@ def _create_single_instance( def _create_and_map_instances( self, - entities_list: List[Dict[str, Any]], - model_schema_map: Dict[str, Type[SQLModel]], + entities_list: list[dict[str, Any]], + model_schema_map: dict[str, type[SQLModel]], ) -> None: """Pass 1: Creates and maps all SQLModel instances.""" 
for entity_data in entities_list: @@ -357,7 +361,7 @@ def _link_to_one_relation( instance: SQLModelInstance, relation_name: str, ref_id: Any, - entity_data: Dict[str, Any], + entity_data: dict[str, Any], ) -> None: """Handles the logic for a single to-one relationship.""" if ref_id is None: @@ -380,7 +384,7 @@ def _link_to_many_relation( instance: SQLModelInstance, relation_name: str, ref_ids: Any, - entity_data: Dict[str, Any], + entity_data: dict[str, Any], ) -> None: """Handles the logic for a single to-many relationship.""" _temp_id = entity_data.get("_temp_id", "N/A") @@ -405,7 +409,7 @@ def _link_to_many_relation( ) setattr(instance, relation_name, related_instances) - def _link_relations_for_instance(self, entity_data: Dict[str, Any]) -> None: + def _link_relations_for_instance(self, entity_data: dict[str, Any]) -> None: """Links relationships for a single instance by dispatching to specialized helpers.""" _temp_id = entity_data["_temp_id"] instance = self.temp_id_to_instance_map[_temp_id] @@ -424,7 +428,7 @@ def _link_relations_for_instance(self, entity_data: Dict[str, Any]) -> None: instance, relation_name, value, entity_data ) - def _link_relationships(self, entities_list: List[Dict[str, Any]]) -> None: + def _link_relationships(self, entities_list: list[dict[str, Any]]) -> None: """Pass 2: Links all created instances together.""" for entity_data in entities_list: self._link_relations_for_instance(entity_data) @@ -436,9 +440,9 @@ def _add_instances_to_session(self) -> None: def hydrate( self, - entities_list: List[Dict[str, Any]], - model_schema_map: Dict[str, Type[SQLModel]], - ) -> List[SQLModelInstance]: + entities_list: list[dict[str, Any]], + model_schema_map: dict[str, type[SQLModel]], + ) -> list[SQLModelInstance]: """ Hydrates SQLModel objects from a list of entity data dictionaries. 
""" @@ -458,7 +462,7 @@ def hydrate( def persist_objects( - db_session: Session, objects_to_persist: List[Any], logger: logging.Logger + db_session: Session, objects_to_persist: list[Any], logger: logging.Logger ) -> None: """ Persists a list of SQLAlchemy objects to the database using the provided session. @@ -518,15 +522,15 @@ def __init__( self.model_registry = model_registry self.analytics_collector = analytics_collector self.logger = logger - self.original_pk_map: Dict[tuple[str, Any], SQLModelInstance] = {} - self.all_hydrated_instances: List[SQLModelInstance] = [] + self.original_pk_map: dict[tuple[str, Any], SQLModelInstance] = {} + self.all_hydrated_instances: list[SQLModelInstance] = [] def hydrate( self, - results: List[Dict[str, Any]], - db_session: Optional[Session] = None, - default_model_type: Optional[str] = None, - ) -> List[Any]: + results: list[dict[str, Any]], + db_session: Session | None = None, + default_model_type: str | None = None, + ) -> list[Any]: """ Hydrates dictionaries into SQLModel objects. @@ -604,7 +608,7 @@ def hydrate( if db_session is None and session: session.close() - def persist(self, objects: List[Any], db_session: Session): + def persist(self, objects: list[Any], db_session: Session): """Persists objects to database.""" if not objects: self.logger.info("No objects to persist") @@ -626,7 +630,7 @@ def persist(self, objects: List[Any], db_session: Session): raise WorkflowError(f"Persistence failed: {e}") from e def _link_foreign_keys( - self, instances: Optional[List[SQLModelInstance]] = None + self, instances: list[SQLModelInstance] | None = None ) -> None: """ Links foreign keys for all hydrated instances before persisting. 
@@ -639,8 +643,8 @@ def _link_foreign_keys( def _perform_fk_recovery( self, - instances: List[SQLModelInstance], - original_pk_map: Dict[tuple[str, Any], SQLModelInstance], + instances: list[SQLModelInstance], + original_pk_map: dict[tuple[str, Any], SQLModelInstance], ) -> None: """ Scans all hydrated instances for Foreign Key fields that are set (not None) @@ -691,7 +695,7 @@ def _perform_fk_recovery( f"Universal FK Recovery: Restored {count_recovered} relationships." ) - def _get_or_create_session(self, db_session: Optional[Session]) -> Session: + def _get_or_create_session(self, db_session: Session | None) -> Session: """Creates temporary in-memory session if none provided.""" if db_session: return db_session diff --git a/src/extrai/core/schema_inspector.py b/src/extrai/core/schema_inspector.py index 70544e2..bc3c8c1 100644 --- a/src/extrai/core/schema_inspector.py +++ b/src/extrai/core/schema_inspector.py @@ -1,23 +1,24 @@ +import enum import json import logging -import enum -from typing import Type, List, Optional, Any, Dict, Set, Tuple -from sqlalchemy import inspect, Column, Table -from sqlalchemy.orm import RelationshipProperty +from typing import Any + +from sqlalchemy import Column, Table, inspect from sqlalchemy.exc import NoInspectionAvailable -from sqlalchemy.schema import UniqueConstraint, PrimaryKeyConstraint +from sqlalchemy.orm import RelationshipProperty +from sqlalchemy.schema import PrimaryKeyConstraint, UniqueConstraint from sqlmodel import SQLModel from extrai.utils.type_mapping import ( - map_sql_type_to_llm_type, get_python_type_str_from_pydantic_annotation, + map_sql_type_to_llm_type, ) class SchemaInspector: """Helper class to inspect SQLAlchemy models and generate LLM schemas.""" - def __init__(self, logger: Optional[logging.Logger] = None): + def __init__(self, logger: logging.Logger | None = None): self.logger = logger or logging.getLogger(__name__) def _is_column_unique(self, column_obj: Column) -> bool: @@ -44,7 +45,7 @@ def 
_get_python_type_from_column(self, column_obj: Column) -> str: def _build_column_info( self, column_obj: Column, is_unique: bool, python_type_name: str - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Builds the column information dictionary.""" enum_values = None # Handle SQLAlchemy Enum types (both class-based and string-based) @@ -72,7 +73,7 @@ def _build_column_info( col_info["foreign_key_to"] = str(fk_constraint_obj.column) return col_info - def _get_columns_from_inspector(self, inspector) -> Dict[str, Any]: + def _get_columns_from_inspector(self, inspector) -> dict[str, Any]: """Extracts all column properties from a SQLAlchemy inspector.""" columns_info = {} for col_attr in inspector.column_attrs: @@ -86,9 +87,9 @@ def _get_columns_from_inspector(self, inspector) -> Dict[str, Any]: ) return columns_info - def _get_fks_from_secondary_table(self, rel_prop: RelationshipProperty) -> Set[str]: + def _get_fks_from_secondary_table(self, rel_prop: RelationshipProperty) -> set[str]: """Handles relationships that use a secondary table.""" - involved_fk_columns: Set[str] = set() + involved_fk_columns: set[str] = set() if rel_prop.secondary is not None: for fk_constraint in rel_prop.secondary.foreign_key_constraints: for col in fk_constraint.columns: @@ -97,9 +98,9 @@ def _get_fks_from_secondary_table(self, rel_prop: RelationshipProperty) -> Set[s def _get_fks_from_synchronize_pairs( self, rel_prop: RelationshipProperty - ) -> Set[str]: + ) -> set[str]: """Handles relationships that use synchronize_pairs.""" - involved_fk_columns: Set[str] = set() + involved_fk_columns: set[str] = set() if rel_prop.synchronize_pairs: for local_join_col, remote_join_col in rel_prop.synchronize_pairs: if ( @@ -116,15 +117,15 @@ def _get_fks_from_synchronize_pairs( def _get_fks_from_direct_foreign_keys( self, rel_prop: RelationshipProperty - ) -> Set[str]: + ) -> set[str]: """Handles relationships that have direct foreign_keys.""" - involved_fk_columns: Set[str] = set() + 
involved_fk_columns: set[str] = set() if hasattr(rel_prop, "foreign_keys") and rel_prop.foreign_keys is not None: for fk_col in rel_prop.foreign_keys: involved_fk_columns.add(str(fk_col)) return involved_fk_columns - def _get_involved_foreign_keys(self, rel_prop: RelationshipProperty) -> Set[str]: + def _get_involved_foreign_keys(self, rel_prop: RelationshipProperty) -> set[str]: """ Finds all foreign key columns involved in a relationship by dispatching to helper functions. """ @@ -142,9 +143,9 @@ def _get_involved_foreign_keys(self, rel_prop: RelationshipProperty) -> Set[str] def _build_relationship_info( self, rel_prop: RelationshipProperty, - involved_fk_columns: Set[str], - recursion_path_tracker: Set[Type[Any]], - ) -> Dict[str, Any]: + involved_fk_columns: set[str], + recursion_path_tracker: set[type[Any]], + ) -> dict[str, Any]: """Builds the relationship information dictionary, including recursion.""" related_model_class = rel_prop.mapper.class_ return { @@ -169,8 +170,8 @@ def _build_relationship_info( } def _get_relationships_from_inspector( - self, inspector, recursion_path_tracker: Set[Type[Any]] - ) -> Dict[str, Any]: + self, inspector, recursion_path_tracker: set[type[Any]] + ) -> dict[str, Any]: """Extracts all relationship properties from a SQLAlchemy inspector.""" relationships_info = {} for name, rel_prop in inspector.relationships.items(): @@ -182,8 +183,8 @@ def _get_relationships_from_inspector( return relationships_info def _inspect_sqlalchemy_model_recursive( - self, model_class: Type[Any], recursion_path_tracker: Set[Type[Any]] - ) -> Dict[str, Any]: + self, model_class: type[Any], recursion_path_tracker: set[type[Any]] + ) -> dict[str, Any]: """ Internal recursive function to introspect a SQLAlchemy model class. 
""" @@ -225,7 +226,7 @@ def _inspect_sqlalchemy_model_recursive( recursion_path_tracker.add(model_class) - schema_info: Dict[str, Any] = { + schema_info: dict[str, Any] = { "table_name": table_name_str, "model_name": model_class.__name__, "info_dict": table_info_dict, @@ -239,7 +240,7 @@ def _inspect_sqlalchemy_model_recursive( recursion_path_tracker.remove(model_class) return schema_info - def inspect_sqlalchemy_model(self, model_class: Type[Any]) -> Dict[str, Any]: + def inspect_sqlalchemy_model(self, model_class: type[Any]) -> dict[str, Any]: """ Public wrapper function to start the SQLAlchemy model introspection. """ @@ -247,9 +248,9 @@ def inspect_sqlalchemy_model(self, model_class: Type[Any]) -> Dict[str, Any]: def _collect_all_sqla_models_recursively( self, - current_model_class: Type[Any], - all_discovered_models: List[Type[Any]], - recursion_guard: Set[Type[Any]], + current_model_class: type[Any], + all_discovered_models: list[type[Any]], + recursion_guard: set[type[Any]], ) -> None: """ Recursively collects all unique SQLAlchemy model classes related to current_model_class. @@ -283,11 +284,11 @@ def _collect_all_sqla_models_recursively( def _get_prioritized_description( self, *, - custom_desc: Optional[str] = None, - pydantic_desc: Optional[str] = None, - info_dict: Optional[Dict[str, Any]] = None, - comment: Optional[str] = None, - ) -> Tuple[Optional[str], Dict[str, Any]]: + custom_desc: str | None = None, + pydantic_desc: str | None = None, + info_dict: dict[str, Any] | None = None, + comment: str | None = None, + ) -> tuple[str | None, dict[str, Any]]: """ Centralized helper to determine the best description from multiple sources. 
""" @@ -314,11 +315,11 @@ def _get_prioritized_description( def _process_column_for_llm_schema( self, col_name: str, - col_data: Dict[str, Any], - pydantic_fields: Dict[str, Any], - custom_descs: Dict[str, str], + col_data: dict[str, Any], + pydantic_fields: dict[str, Any], + custom_descs: dict[str, str], model_name: str, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Processes a single column to generate its LLM schema representation.""" python_type_for_mapping = str(col_data.get("python_type", "")) pydantic_field_description = None @@ -374,8 +375,8 @@ def _process_column_for_llm_schema( return col_name, formatted_string def _process_relationship_for_llm_schema( - self, rel_name: str, rel_data: Dict[str, Any], custom_descs: Dict[str, str] - ) -> Optional[Tuple[str, str]]: + self, rel_name: str, rel_data: dict[str, Any], custom_descs: dict[str, str] + ) -> tuple[str, str] | None: """Processes a single relationship to generate its LLM schema representation.""" related_model_name = rel_data.get("related_model_name", "UnknownRelatedModel") @@ -426,7 +427,7 @@ def _process_relationship_for_llm_schema( return ref_field_name_for_llm, formatted_string def _generate_model_level_description( - self, model_name: str, raw_schema: Dict[str, Any], custom_descs: Dict[str, str] + self, model_name: str, raw_schema: dict[str, Any], custom_descs: dict[str, str] ) -> str: """Generates the complete model-level description block.""" description, other_info = self._get_prioritized_description( @@ -455,8 +456,8 @@ def _generate_model_level_description( def generate_llm_schema_from_models( self, - initial_model_classes: List[Type[SQLModel]], - custom_field_descriptions: Optional[Dict[str, Dict[str, str]]] = None, + initial_model_classes: list[type[SQLModel]], + custom_field_descriptions: dict[str, dict[str, str]] | None = None, ) -> str: """ Generates an LLM-friendly schema representation for a list of SQLAlchemy models. 
@@ -464,7 +465,7 @@ def generate_llm_schema_from_models( if custom_field_descriptions is None: custom_field_descriptions = {} - all_sqla_models_to_document: List[Type[Any]] = [] + all_sqla_models_to_document: list[type[Any]] = [] for root_model_class in initial_model_classes: self._collect_all_sqla_models_recursively( root_model_class, all_sqla_models_to_document, set() @@ -530,8 +531,8 @@ def generate_llm_schema_from_models( def discover_sqlmodels_from_root( self, - root_sqlmodel_class: Type[SQLModel], - ) -> List[Type[SQLModel]]: + root_sqlmodel_class: type[SQLModel], + ) -> list[type[SQLModel]]: """ Discovers all unique SQLModel classes starting from a root SQLModel class. """ @@ -539,7 +540,7 @@ def discover_sqlmodels_from_root( self.logger.warning(f"{root_sqlmodel_class} is not a valid SQLModel class.") return [] - all_discovered_models: List[Type[SQLModel]] = [] + all_discovered_models: list[type[SQLModel]] = [] try: self._collect_all_sqla_models_recursively( current_model_class=root_sqlmodel_class, diff --git a/src/extrai/core/shared/consensus_runner.py b/src/extrai/core/shared/consensus_runner.py new file mode 100644 index 0000000..e94713c --- /dev/null +++ b/src/extrai/core/shared/consensus_runner.py @@ -0,0 +1,75 @@ +import logging +from typing import Any + +from extrai.core.analytics_collector import WorkflowAnalyticsCollector +from extrai.core.errors import ConsensusProcessError +from extrai.core.extraction_config import ExtractionConfig +from extrai.core.json_consensus import JSONConsensus, default_conflict_resolver +from extrai.utils.alignment_utils import normalize_json_revisions + + +class ConsensusRunner: + def __init__( + self, + config: ExtractionConfig, + analytics_collector: WorkflowAnalyticsCollector, + logger: logging.Logger, + ): + self.config = config + self.analytics_collector = analytics_collector + self.logger = logger + self.consensus = JSONConsensus( + consensus_threshold=config.consensus_threshold, + 
conflict_resolver=config.conflict_resolver or default_conflict_resolver, + logger=logger, + ) + + def run(self, revisions: list[list[dict]]) -> list[dict]: + try: + self.logger.debug(f"Running consensus on {len(revisions)} revisions") + + normalized_revisions = normalize_json_revisions(revisions) + + consensus_output, details = self.consensus.get_consensus( + normalized_revisions + ) + + if details: + self.analytics_collector.record_consensus_run_details(details) + self.logger.debug(f"Consensus details: {details}") + + processed = self._process_output(consensus_output) + + self.logger.debug(f"Consensus produced {len(processed)} entities") + + return processed + + except ConsensusProcessError: + raise + + except Exception as e: + self.logger.error(f"Consensus processing failed: {e}") + raise ConsensusProcessError( + f"Failed during JSON consensus processing: {e}" + ) from e + + def _process_output(self, consensus_output: Any) -> list[dict[str, Any]]: + if consensus_output is None: + self.logger.warning("Consensus returned None, returning empty list") + return [] + + if isinstance(consensus_output, list): + return consensus_output + + if isinstance(consensus_output, dict): + if "results" in consensus_output and isinstance( + consensus_output["results"], list + ): + return consensus_output["results"] + + return [consensus_output] + + raise ConsensusProcessError( + f"Unexpected consensus output type: {type(consensus_output)}. " + f"Expected None, list, or dict." + ) diff --git a/src/extrai/core/shared/hierarchical_coordinator.py b/src/extrai/core/shared/hierarchical_coordinator.py new file mode 100644 index 0000000..6cd6a9f --- /dev/null +++ b/src/extrai/core/shared/hierarchical_coordinator.py @@ -0,0 +1,32 @@ +import logging +from typing import Any + + +class HierarchicalCoordinator: + """ + Captures the model-level iteration policy (order, context passing, termination check) + shared between standard and batch pipelines. 
+ """ + + def __init__(self, model_registry: Any, logger: logging.Logger): + self.model_registry = model_registry + self.logger = logger + + def get_models(self) -> list[type]: + """Returns the list of models in the order they should be processed.""" + return self.model_registry.models + + def is_final_step(self, index: int) -> bool: + """Checks if the current step is the final one in the hierarchy.""" + return index >= len(self.model_registry.models) - 1 + + def next_index(self, index: int) -> int: + """Returns the index of the next step.""" + return index + 1 + + def collect_previous_entities(self, completed_steps: list[dict]) -> list[dict]: + """ + Collects entities from previous steps to be used as context for the current step. + In the current implementation, this is a pass-through of all entities found so far. + """ + return completed_steps diff --git a/src/extrai/core/sqlmodel_generator.py b/src/extrai/core/sqlmodel_generator.py index 9be472d..307ab1f 100644 --- a/src/extrai/core/sqlmodel_generator.py +++ b/src/extrai/core/sqlmodel_generator.py @@ -1,35 +1,36 @@ -import logging -from typing import Any, Dict, Type, List as TypingList, Optional, Generator -import tempfile import importlib.util -import sys +import json +import logging import os +import sys +import tempfile import uuid -import json +from collections.abc import Generator from contextlib import contextmanager +from typing import Any from pydantic import ValidationError from sqlmodel import SQLModel +from extrai.core.analytics_collector import ( + WorkflowAnalyticsCollector, +) +from extrai.core.base_llm_client import BaseLLMClient +from extrai.core.code_generation.python_builder import PythonModelBuilder from extrai.core.errors import ( - SQLModelCodeGeneratorError, - SQLModelInstantiationValidationError, - LLMInteractionError, ConfigurationError, + LLMAPICallError, LLMConfigurationError, + LLMInteractionError, LLMOutputParseError, LLMOutputValidationError, - LLMAPICallError, -) -from 
extrai.core.base_llm_client import BaseLLMClient -from extrai.core.analytics_collector import ( - WorkflowAnalyticsCollector, + SQLModelCodeGeneratorError, + SQLModelInstantiationValidationError, ) from extrai.core.prompt_builder import ( generate_sqlmodel_creation_system_prompt, generate_user_prompt_for_docs, ) -from extrai.core.code_generation.python_builder import PythonModelBuilder class SQLModelCodeGenerator: @@ -40,7 +41,7 @@ class SQLModelCodeGenerator: The generated code is then dynamically loaded. """ - _sqlmodel_description_schema_cache: Optional[Dict[str, Any]] = None + _sqlmodel_description_schema_cache: dict[str, Any] | None = None # Adjusted path to be relative to this file (sqlmodel_generator.py) _SCHEMA_FILE_PATH = os.path.join( os.path.dirname(__file__), "schemas", "sqlmodel_description_schema.json" @@ -49,8 +50,8 @@ class SQLModelCodeGenerator: def __init__( self, llm_client: BaseLLMClient, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - logger: Optional[logging.Logger] = None, + analytics_collector: WorkflowAnalyticsCollector | None = None, + logger: logging.Logger | None = None, ): """ Initializes the SQLModelCodeGenerator. @@ -71,7 +72,7 @@ def __init__( else: self.analytics_collector = analytics_collector - def _load_sqlmodel_description_schema(self) -> Dict[str, Any]: + def _load_sqlmodel_description_schema(self) -> dict[str, Any]: """ Loads the SQLModel description JSON schema from file. Caches the schema after the first load. 
@@ -87,7 +88,7 @@ def _load_sqlmodel_description_schema(self) -> Dict[str, Any]: current_dir, "schemas", "sqlmodel_description_schema.json" ) - with open(schema_file_path, "r") as f: + with open(schema_file_path) as f: schema = json.load(f) SQLModelCodeGenerator._sqlmodel_description_schema_cache = schema except FileNotFoundError: @@ -100,7 +101,7 @@ def _load_sqlmodel_description_schema(self) -> Dict[str, Any]: ) return SQLModelCodeGenerator._sqlmodel_description_schema_cache - def _generate_code_from_description(self, llm_json_output: Dict[str, Any]) -> str: + def _generate_code_from_description(self, llm_json_output: dict[str, Any]) -> str: """ Delegates the code generation to the PythonModelBuilder. """ @@ -156,10 +157,10 @@ def _import_module_from_path(self, module_name: str, path: str) -> Any: def _extract_models_from_module( self, module: Any, - model_names: TypingList[str], + model_names: list[str], generated_code: str, module_name: str, - ) -> Dict[str, Type[SQLModel]]: + ) -> dict[str, type[SQLModel]]: """Extracts and validates SQLModel classes from a loaded module.""" loaded_classes_map = {} for name_to_load in model_names: @@ -176,7 +177,7 @@ def _extract_models_from_module( return loaded_classes_map def _rebuild_and_validate_models( - self, loaded_classes: Dict[str, Type[SQLModel]], generated_code: str + self, loaded_classes: dict[str, type[SQLModel]], generated_code: str ): """Calls model_rebuild and validates instantiation for all loaded models.""" for cls_to_rebuild in loaded_classes.values(): @@ -198,8 +199,8 @@ def _rebuild_and_validate_models( ) from inst_e def _generate_and_load_class_from_description( - self, model_description: Dict[str, Any] - ) -> tuple[Dict[str, Type[SQLModel]], str]: + self, model_description: dict[str, Any] + ) -> tuple[dict[str, type[SQLModel]], str]: """ Generates SQLModel Python code from a given description, dynamically loads it, and returns the generated SQLModel classes and the generated code. 
@@ -246,11 +247,11 @@ def _generate_and_load_class_from_description( async def generate_and_load_models_via_llm( self, - input_documents: TypingList[str], + input_documents: list[str], user_task_description: str, num_model_revisions: int = 1, max_retries_per_model_revision: int = 2, - ) -> tuple[Dict[str, Type[SQLModel]], str]: + ) -> tuple[dict[str, type[SQLModel]], str]: """ Generates SQLModel description(s) via LLM, then uses internal methods to generate Python code and dynamically load the class(es). @@ -294,8 +295,8 @@ async def generate_and_load_models_via_llm( ) try: - validated_descriptions: TypingList[ - Dict[str, Any] + validated_descriptions: list[ + dict[str, Any] ] = await self.llm_client.generate_and_validate_raw_json_output( system_prompt=system_prompt_for_model_gen, user_prompt=user_prompt_for_model_gen, diff --git a/src/extrai/core/workflow_orchestrator.py b/src/extrai/core/workflow_orchestrator.py index ca15e89..2141f5c 100644 --- a/src/extrai/core/workflow_orchestrator.py +++ b/src/extrai/core/workflow_orchestrator.py @@ -1,19 +1,20 @@ # extrai/core/workflow_orchestrator.py -import asyncio import logging -from typing import List, Dict, Any, Type, Optional, Union -from extrai.core.base_llm_client import BaseLLMClient -from extrai.core.batch_models import BatchJobStatus, BatchProcessResult +from typing import Any + from sqlalchemy.orm import Session from sqlmodel import SQLModel +from extrai.core.base_llm_client import BaseLLMClient +from extrai.core.batch.batch_pipeline import BatchPipeline +from extrai.core.batch_models import BatchJobStatus, BatchProcessResult + +from .analytics_collector import WorkflowAnalyticsCollector from .extraction_config import ExtractionConfig from .extraction_pipeline import ExtractionPipeline -from .batch_pipeline import BatchPipeline -from .result_processor import ResultProcessor from .model_registry import ModelRegistry -from .analytics_collector import WorkflowAnalyticsCollector +from .result_processor import 
ResultProcessor class WorkflowOrchestrator: @@ -29,17 +30,19 @@ class WorkflowOrchestrator: def __init__( self, - root_sqlmodel_class: Type[SQLModel], - llm_client: Union[BaseLLMClient, List[BaseLLMClient]], + root_sqlmodel_class: type[SQLModel], + llm_client: BaseLLMClient | list[BaseLLMClient], num_llm_revisions: int = 3, + num_counting_revisions: int = 3, max_validation_retries_per_revision: int = 2, consensus_threshold: float = 0.51, + counting_levenshtein_threshold: float = 0.85, conflict_resolver=None, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, + analytics_collector: WorkflowAnalyticsCollector | None = None, use_hierarchical_extraction: bool = False, use_structured_output: bool = False, - logger: Optional[logging.Logger] = None, - counting_llm_client: Optional[BaseLLMClient] = None, + logger: logging.Logger | None = None, + counting_llm_client: BaseLLMClient | None = None, ): self.logger = logger or self._create_default_logger() @@ -49,8 +52,10 @@ def __init__( # Create shared config self.config = ExtractionConfig( num_llm_revisions=num_llm_revisions, + num_counting_revisions=num_counting_revisions, max_validation_retries_per_revision=max_validation_retries_per_revision, consensus_threshold=consensus_threshold, + counting_levenshtein_threshold=counting_levenshtein_threshold, conflict_resolver=conflict_resolver, use_hierarchical_extraction=use_hierarchical_extraction, use_structured_output=use_structured_output, @@ -94,17 +99,17 @@ def _create_default_logger(self) -> logging.Logger: async def synthesize( self, - input_strings: List[str], - db_session_for_hydration: Optional[Session] = None, + input_strings: list[str], + db_session_for_hydration: Session | None = None, extraction_example_json: str = "", - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None, - custom_extraction_process: str = "", - custom_extraction_guidelines: str = "", - custom_final_checklist: str = "", - custom_context: str = "", + 
extraction_example_object: SQLModel | list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + custom_context: str | list[str] = "", count_entities: bool = False, - custom_counting_context: str = "", - ) -> List[Any]: + custom_counting_context: str | list[str] = "", + ) -> list[Any]: """Executes extraction pipeline and returns hydrated objects.""" if not input_strings: raise ValueError("Input strings list cannot be empty.") @@ -129,17 +134,17 @@ async def synthesize( async def synthesize_and_save( self, - input_strings: List[str], + input_strings: list[str], db_session: Session, extraction_example_json: str = "", - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None, - custom_extraction_process: str = "", - custom_extraction_guidelines: str = "", - custom_final_checklist: str = "", - custom_context: str = "", + extraction_example_object: SQLModel | list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + custom_context: str | list[str] = "", count_entities: bool = False, - custom_counting_context: str = "", - ) -> List[Any]: + custom_counting_context: str | list[str] = "", + ) -> list[Any]: """Synthesizes and persists objects in a single transaction.""" hydrated_objects = await self.synthesize( input_strings=input_strings, @@ -163,19 +168,19 @@ async def synthesize_and_save( async def synthesize_batch( self, - input_strings: List[str], + input_strings: list[str], db_session: Session, extraction_example_json: str = "", - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None, - custom_extraction_process: str = "", - custom_extraction_guidelines: str = "", - custom_final_checklist: str = "", - custom_context: str = "", + extraction_example_object: SQLModel | 
list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + custom_context: str | list[str] = "", count_entities: bool = False, - custom_counting_context: str = "", + custom_counting_context: str | list[str] = "", wait_for_completion: bool = False, poll_interval: int = 60, - ) -> Union[str, BatchProcessResult]: + ) -> str | BatchProcessResult: """Submits a batch job. Args: @@ -215,16 +220,16 @@ async def create_continuation_batch( db_session: Session, start_from_step_index: int, extraction_example_json: str = "", - extraction_example_object: Optional[Union[SQLModel, List[SQLModel]]] = None, - custom_extraction_process: str = "", - custom_extraction_guidelines: str = "", - custom_final_checklist: str = "", - custom_context: str = "", + extraction_example_object: SQLModel | list[SQLModel] | None = None, + custom_extraction_process: str | list[str] = "", + custom_extraction_guidelines: str | list[str] = "", + custom_final_checklist: str | list[str] = "", + custom_context: str | list[str] = "", count_entities: bool = False, - custom_counting_context: str = "", + custom_counting_context: str | list[str] = "", wait_for_completion: bool = False, poll_interval: int = 60, - ) -> Union[str, BatchProcessResult]: + ) -> str | BatchProcessResult: """ Creates a new batch cycle continuing from a previous batch's state. Copies completed steps up to start_from_step_index into the new batch. 
@@ -274,17 +279,8 @@ async def process_batch( db_session, ) - if result.status.name == "COMPLETED" and result.hydrated_objects: - try: - # Add the PK map from the batch pipeline to the main result processor - if result.original_pk_map: - self.result_processor.original_pk_map.update(result.original_pk_map) - - self.result_processor.persist(result.hydrated_objects, db_session) - except Exception as e: - self.logger.error(f"Persistence failed for batch {root_batch_id}: {e}") - result.message = f"Extraction successful but persistence failed: {e}" - raise + if result.status.name == "COMPLETED" and result.hydrated_objects and result.original_pk_map: + self.result_processor.original_pk_map.update(result.original_pk_map) return result @@ -295,62 +291,24 @@ async def monitor_batch_job( Polls the batch job status until it reaches a terminal state. Automatically handles hierarchical extraction steps by re-polling if an intermediate step is submitted. - - Useful for scripts or simple workflows where blocking is acceptable. """ - self.logger.info(f"Monitoring batch job {root_batch_id}...") - - while True: - status = await self.get_batch_status(root_batch_id, db_session) - self.logger.info(f"Batch Status: {status}") - - if status in [ - BatchJobStatus.READY_TO_PROCESS, - BatchJobStatus.COUNTING_READY_TO_PROCESS, - ]: - self.logger.info("Batch ready! Processing...") - result = await self.process_batch(root_batch_id, db_session) - - if result.status == BatchJobStatus.COMPLETED: - self.logger.info("Batch workflow completed successfully.") - return result - - elif result.status in [ - BatchJobStatus.PROCESSING, - BatchJobStatus.SUBMITTED, - ]: - self.logger.info( - f"Intermediate step processed (new status: {result.status}). Continuing workflow..." 
- ) - continue - - else: - self.logger.error(f"Batch processing failed: {result.message}") - return result - - elif status in [ - BatchJobStatus.COMPLETED, - BatchJobStatus.FAILED, - BatchJobStatus.CANCELLED, - ]: - # If it's already COMPLETED (e.g. checked before monitoring started), retrieve results - if status == BatchJobStatus.COMPLETED: - self.logger.info("Batch already completed. Retrieving results...") - return await self.process_batch(root_batch_id, db_session) - - self.logger.error(f"Batch job ended with status: {status}") - return BatchProcessResult( - status=status, message=f"Batch ended with status: {status}" - ) - - await asyncio.sleep(poll_interval) + return await self.batch_pipeline.monitor_batch_job( + root_batch_id, db_session, poll_interval + ) # ==================== Analytics ==================== - def get_analytics_report(self) -> Dict[str, Any]: + def get_analytics_report(self) -> dict[str, Any]: """Retrieves analytics report.""" return self.analytics_collector.get_report() def get_analytics_collector(self) -> WorkflowAnalyticsCollector: """Returns the analytics collector instance.""" return self.analytics_collector + + def get_total_steps(self, count_entities: bool) -> int: + """Calculates the total number of steps for a workflow.""" + num_models = len(self.model_registry.models) + if self.config.use_hierarchical_extraction and count_entities: + return num_models * 2 + return num_models diff --git a/src/extrai/data/model_prices.json b/src/extrai/data/model_prices.json new file mode 100644 index 0000000..d06d5de --- /dev/null +++ b/src/extrai/data/model_prices.json @@ -0,0 +1,821 @@ +{ + "updated_at": "2026-03-17", + "prices": [ + { + "id": "amazon-nova-micro", + "vendor": "amazon", + "name": "Amazon Nova Micro", + "input": 0.035, + "output": 0.14, + "input_cached": null + }, + { + "id": "amazon-nova-lite", + "vendor": "amazon", + "name": "Amazon Nova Lite", + "input": 0.06, + "output": 0.24, + "input_cached": null + }, + { + "id": 
"amazon-nova-pro", + "vendor": "amazon", + "name": "Amazon Nova Pro", + "input": 0.8, + "output": 3.2, + "input_cached": null + }, + { + "id": "amazon-nova-premier", + "vendor": "amazon", + "name": "Amazon Nova Premier", + "input": 2.5, + "output": 12.5, + "input_cached": null + }, + { + "id": "claude-3.7-sonnet", + "vendor": "anthropic", + "name": "Claude 3.7 Sonnet", + "input": 3, + "output": 15, + "input_cached": null + }, + { + "id": "claude-3.5-sonnet", + "vendor": "anthropic", + "name": "Claude 3.5 Sonnet", + "input": 3, + "output": 15, + "input_cached": null + }, + { + "id": "claude-3-opus", + "vendor": "anthropic", + "name": "Claude 3 Opus", + "input": 15, + "output": 75, + "input_cached": null + }, + { + "id": "claude-3-haiku", + "vendor": "anthropic", + "name": "Claude 3 Haiku", + "input": 0.25, + "output": 1.25, + "input_cached": null + }, + { + "id": "claude-3.5-haiku", + "vendor": "anthropic", + "name": "Claude 3.5 Haiku", + "input": 0.8, + "output": 4, + "input_cached": null + }, + { + "id": "claude-4.5-haiku", + "vendor": "anthropic", + "name": "Claude 4.5 Haiku", + "input": 1, + "output": 5, + "input_cached": null + }, + { + "id": "claude-sonnet-4.5", + "vendor": "anthropic", + "name": "Claude Sonnet 4 and 4.5 \u2264200k", + "input": 3, + "output": 15, + "input_cached": null + }, + { + "id": "claude-sonnet-4.5-200k", + "vendor": "anthropic", + "name": "Claude Sonnet 4 and 4.5 >200k", + "input": 6, + "output": 22.5, + "input_cached": null + }, + { + "id": "claude-opus-4", + "vendor": "anthropic", + "name": "Claude Opus 4", + "input": 15, + "output": 75, + "input_cached": null + }, + { + "id": "claude-opus-4-1", + "vendor": "anthropic", + "name": "Claude Opus 4.1", + "input": 15, + "output": 75, + "input_cached": null + }, + { + "id": "claude-opus-4-5", + "vendor": "anthropic", + "name": "Claude Opus 4.5", + "input": 5, + "output": 25, + "input_cached": null + }, + { + "id": "deepseek-chat", + "vendor": "deepseek", + "name": "DeepSeek Chat", + 
"input": 0.27, + "output": 1.1, + "input_cached": null + }, + { + "id": "deepseek-reasoner", + "vendor": "deepseek", + "name": "DeepSeek Reasoner", + "input": 0.55, + "output": 2.19, + "input_cached": null + }, + { + "id": "gemini-2.5-pro-preview-03-25", + "vendor": "google", + "name": "Gemini 2.5 Pro Preview \u2264200k", + "input": 1.25, + "output": 10, + "input_cached": null + }, + { + "id": "gemini-2.5-pro-preview-03-25-200k", + "vendor": "google", + "name": "Gemini 2.5 Pro Preview >200k", + "input": 2.5, + "output": 15, + "input_cached": null + }, + { + "id": "gemini-2.0-flash-lite", + "vendor": "google", + "name": "Gemini 2.0 Flash Lite", + "input": 0.075, + "output": 0.3, + "input_cached": null + }, + { + "id": "gemini-2.0-flash", + "vendor": "google", + "name": "Gemini 2.0 Flash", + "input": 0.1, + "output": 0.4, + "input_cached": null + }, + { + "id": "gemini-1.5-flash", + "vendor": "google", + "name": "Gemini 1.5 Flash \u2264128k", + "input": 0.075, + "output": 0.3, + "input_cached": null + }, + { + "id": "gemini-1.5-flash-128k", + "vendor": "google", + "name": "Gemini 1.5 Flash >128k", + "input": 0.15, + "output": 0.6, + "input_cached": null + }, + { + "id": "gemini-1.5-flash-8b", + "vendor": "google", + "name": "Gemini 1.5 Flash-8B \u2264128k", + "input": 0.0375, + "output": 0.15, + "input_cached": null + }, + { + "id": "gemini-1.5-flash-8b-128k", + "vendor": "google", + "name": "Gemini 1.5 Flash-8B >128k", + "input": 0.075, + "output": 0.3, + "input_cached": null + }, + { + "id": "gemini-1.5-pro", + "vendor": "google", + "name": "Gemini 1.5 Pro \u2264128k", + "input": 1.25, + "output": 5, + "input_cached": null + }, + { + "id": "gemini-1.5-pro-128k", + "vendor": "google", + "name": "Gemini 1.5 Pro >128k", + "input": 2.5, + "output": 10, + "input_cached": null + }, + { + "id": "gemini-2.5-flash", + "vendor": "google", + "name": "Gemini 2.5 Flash", + "input": 0.3, + "output": 2.5, + "input_cached": 0.03 + }, + { + "id": "gemini-2.5-flash-lite", + 
"vendor": "google", + "name": "Gemini 2.5 Flash-Lite", + "input": 0.1, + "output": 0.4, + "input_cached": 0.01 + }, + { + "id": "gemini-2.5-flash-preview-09-2025", + "vendor": "google", + "name": "Gemini 2.5 Flash Preview (09-2025)", + "input": 0.3, + "output": 2.5, + "input_cached": 0.03 + }, + { + "id": "gemini-2.5-pro", + "vendor": "google", + "name": "Gemini 2.5 Pro \u2264200k", + "input": 1.25, + "output": 10, + "input_cached": 0.125 + }, + { + "id": "gemini-2.5-pro-200k", + "vendor": "google", + "name": "Gemini 2.5 Pro >200k", + "input": 2.5, + "output": 15, + "input_cached": 0.25 + }, + { + "id": "gemini-3-pro-preview", + "vendor": "google", + "name": "Gemini 3 Pro \u2264200k", + "input": 2, + "output": 12, + "input_cached": null + }, + { + "id": "gemini-3-pro-preview-200k", + "vendor": "google", + "name": "Gemini 3 Pro >200k", + "input": 4, + "output": 18, + "input_cached": null + }, + { + "id": "gemini-3-flash-preview", + "vendor": "google", + "name": "Gemini 3 Flash Preview", + "input": 0.5, + "output": 3, + "input_cached": null + }, + { + "id": "gemini-3-1-pro-preview", + "vendor": "google", + "name": "Gemini 3.1 Pro \u2264200k", + "input": 2, + "output": 12, + "input_cached": null + }, + { + "id": "gemini-3-1-pro-preview-200k", + "vendor": "google", + "name": "Gemini 3.1 Pro >200k", + "input": 4, + "output": 18, + "input_cached": null + }, + { + "id": "gemini-3.1-flash-lite-preview", + "vendor": "google", + "name": "Gemini 3.1 Flash-Lite", + "input": 0.25, + "output": 1.5, + "input_cached": 0.025 + }, + { + "id": "minimax-m2", + "vendor": "minimax", + "name": "MiniMax M2", + "input": 0.3, + "output": 1.2, + "input_cached": null + }, + { + "id": "pixtral-12b", + "vendor": "mistral", + "name": "Pixtral 12B", + "input": 0.15, + "output": 0.15, + "input_cached": null + }, + { + "id": "mistral-small-latest", + "vendor": "mistral", + "name": "Mistral Small 3.1", + "input": 0.1, + "output": 0.3, + "input_cached": null + }, + { + "id": "mistral-medium-2505", + 
"vendor": "mistral", + "name": "Mistral Medium 3", + "input": 0.4, + "output": 2, + "input_cached": null + }, + { + "id": "mistral-nemo", + "vendor": "mistral", + "name": "Mistral NeMo", + "input": 0.15, + "output": 0.15, + "input_cached": null + }, + { + "id": "open-mistral-7b", + "vendor": "mistral", + "name": "Mistral 7B", + "input": 0.25, + "output": 0.25, + "input_cached": null + }, + { + "id": "open-mixtral-8x7b", + "vendor": "mistral", + "name": "Mixtral 8x7B", + "input": 0.7, + "output": 0.7, + "input_cached": null + }, + { + "id": "open-mixtral-8x22b", + "vendor": "mistral", + "name": "Mixtral 8x22B", + "input": 2, + "output": 6, + "input_cached": null + }, + { + "id": "mistral-large-latest", + "vendor": "mistral", + "name": "Mistral Large 24.11", + "input": 2, + "output": 6, + "input_cached": null + }, + { + "id": "pixtral-large-latest", + "vendor": "mistral", + "name": "Pixtral Large", + "input": 2, + "output": 6, + "input_cached": null + }, + { + "id": "mistral-saba-latest", + "vendor": "mistral", + "name": "Mistral Saba", + "input": 0.2, + "output": 0.6, + "input_cached": null + }, + { + "id": "codestral-latest", + "vendor": "mistral", + "name": "Codestral", + "input": 0.3, + "output": 0.9, + "input_cached": null + }, + { + "id": "ministral-8b-latest", + "vendor": "mistral", + "name": "Ministral 8B 24.10", + "input": 0.1, + "output": 0.1, + "input_cached": null + }, + { + "id": "ministral-3b-latest", + "vendor": "mistral", + "name": "Ministral 3B 24.10", + "input": 0.04, + "output": 0.04, + "input_cached": null + }, + { + "id": "magistral-medium-latest", + "vendor": "mistral", + "name": "Magistral Medium", + "input": 2, + "output": 5, + "input_cached": null + }, + { + "id": "kimi-k2-0905-preview", + "vendor": "moonshot-ai", + "name": "Kimi K2 0905 Preview", + "input": 0.6, + "output": 2.5, + "input_cached": 0.15 + }, + { + "id": "kimi-k2-0711-preview", + "vendor": "moonshot-ai", + "name": "Kimi K2 0711 Preview", + "input": 0.6, + "output": 2.5, + 
"input_cached": 0.15 + }, + { + "id": "kimi-k2-turbo-preview", + "vendor": "moonshot-ai", + "name": "Kimi K2 Turbo Preview", + "input": 1.15, + "output": 8.0, + "input_cached": 0.15 + }, + { + "id": "kimi-k2-thinking", + "vendor": "moonshot-ai", + "name": "Kimi K2 Thinking", + "input": 0.6, + "output": 2.5, + "input_cached": 0.15 + }, + { + "id": "kimi-k2-thinking-turbo", + "vendor": "moonshot-ai", + "name": "Kimi K2 Thinking Turbo", + "input": 1.15, + "output": 8.0, + "input_cached": 0.15 + }, + { + "id": "text-davinci-003", + "vendor": "openai", + "name": "GPT-3 Text Davinci 003", + "input": 20, + "output": 20, + "input_cached": null + }, + { + "id": "gpt-4.5", + "vendor": "openai", + "name": "GPT-4.5", + "input": 75, + "output": 150, + "input_cached": 37.5 + }, + { + "id": "gpt-4o", + "vendor": "openai", + "name": "GPT-4o", + "input": 2.5, + "output": 10, + "input_cached": 1.25 + }, + { + "id": "gpt-4o-mini", + "vendor": "openai", + "name": "GPT-4o Mini", + "input": 0.15, + "output": 0.6, + "input_cached": 0.075 + }, + { + "id": "chatgpt-4o-latest", + "vendor": "openai", + "name": "ChatGPT 4o Latest", + "input": 5, + "output": 15, + "input_cached": null + }, + { + "id": "o1-preview", + "vendor": "openai", + "name": "o1 and o1-preview", + "input": 15, + "output": 60, + "input_cached": 7.5 + }, + { + "id": "o1-pro", + "vendor": "openai", + "name": "o1 Pro", + "input": 150, + "output": 600, + "input_cached": null + }, + { + "id": "o1-mini", + "vendor": "openai", + "name": "o1-mini", + "input": 1.1, + "output": 4.4, + "input_cached": 0.55 + }, + { + "id": "o3-mini", + "vendor": "openai", + "name": "o3-mini", + "input": 1.1, + "output": 4.4, + "input_cached": 0.55 + }, + { + "id": "gpt-4.1", + "vendor": "openai", + "name": "GPT-4.1", + "input": 2, + "output": 8, + "input_cached": 0.5 + }, + { + "id": "gpt-4.1-mini", + "vendor": "openai", + "name": "GPT-4.1 Mini", + "input": 0.4, + "output": 1.6, + "input_cached": 0.1 + }, + { + "id": "gpt-4.1-nano", + "vendor": 
"openai", + "name": "GPT-4.1 Nano", + "input": 0.1, + "output": 0.4, + "input_cached": 0.025 + }, + { + "id": "o3", + "vendor": "openai", + "name": "o3", + "input": 10, + "output": 40, + "input_cached": 0.5 + }, + { + "id": "o4-mini", + "vendor": "openai", + "name": "o4-mini", + "input": 1.1, + "output": 4.4, + "input_cached": 0.275 + }, + { + "id": "gpt-5-nano", + "vendor": "openai", + "name": "GPT-5 Nano", + "input": 0.05, + "output": 0.4, + "input_cached": 0.005 + }, + { + "id": "gpt-5-mini", + "vendor": "openai", + "name": "GPT-5 Mini", + "input": 0.25, + "output": 2, + "input_cached": 0.025 + }, + { + "id": "gpt-5", + "vendor": "openai", + "name": "GPT-5", + "input": 1.25, + "output": 10, + "input_cached": 0.125 + }, + { + "id": "gpt-image-1", + "vendor": "openai", + "name": "gpt-image-1 (image gen)", + "input": 10, + "output": 40, + "input_cached": 1.25 + }, + { + "id": "gpt-image-1-mini", + "vendor": "openai", + "name": "gpt-image-1-mini (image gen)", + "input": 2, + "output": 8, + "input_cached": 0.2 + }, + { + "id": "gpt-5-pro", + "vendor": "openai", + "name": "GPT-5 Pro", + "input": 15, + "output": 120, + "input_cached": null + }, + { + "id": "o3-pro", + "vendor": "openai", + "name": "o3 Pro", + "input": 20, + "output": 80, + "input_cached": null + }, + { + "id": "o4-mini-deep-research", + "vendor": "openai", + "name": "o4-mini Deep Research", + "input": 2, + "output": 8, + "input_cached": 0.5 + }, + { + "id": "o3-deep-research", + "vendor": "openai", + "name": "o3 Deep Research", + "input": 10, + "output": 40, + "input_cached": 2.5 + }, + { + "id": "gpt-5.1-codex-mini", + "vendor": "openai", + "name": "GPT-5.1 Codex mini", + "input": 0.25, + "output": 2.0, + "input_cached": 0.025 + }, + { + "id": "gpt-5.1-codex", + "vendor": "openai", + "name": "GPT-5.1 Codex", + "input": 1.25, + "output": 10.0, + "input_cached": 0.125 + }, + { + "id": "gpt-5.1", + "vendor": "openai", + "name": "GPT-5.1", + "input": 1.25, + "output": 10.0, + "input_cached": 0.125 + }, + 
{ + "id": "gpt-5.2", + "vendor": "openai", + "name": "GPT-5.2", + "input": 1.75, + "output": 14.0, + "input_cached": 0.175 + }, + { + "id": "gpt-5.2-pro", + "vendor": "openai", + "name": "GPT-5.2 Pro", + "input": 21.0, + "output": 168.0, + "input_cached": null + }, + { + "id": "gpt-5.4", + "vendor": "openai", + "name": "GPT-5.4 \u2264272k", + "input": 2.5, + "output": 15.0, + "input_cached": 0.25 + }, + { + "id": "gpt-5.4-272k", + "vendor": "openai", + "name": "GPT-5.4 >272k", + "input": 5.0, + "output": 22.5, + "input_cached": 0.5 + }, + { + "id": "gpt-5.4-pro", + "vendor": "openai", + "name": "GPT-5.4 Pro \u2264272k", + "input": 30.0, + "output": 180.0, + "input_cached": null + }, + { + "id": "gpt-5.4-pro-272k", + "vendor": "openai", + "name": "GPT-5.4 Pro >272k", + "input": 60.0, + "output": 270.0, + "input_cached": null + }, + { + "id": "gpt-5.4-mini", + "vendor": "openai", + "name": "GPT-5.4 Mini", + "input": 0.75, + "output": 4.5, + "input_cached": 0.075 + }, + { + "id": "gpt-5.4-nano", + "vendor": "openai", + "name": "GPT-5.4 Nano", + "input": 0.2, + "output": 1.25, + "input_cached": 0.02 + }, + { + "id": "grok-3", + "vendor": "xai", + "name": "Grok 3", + "input": 3, + "output": 15, + "input_cached": 0.75 + }, + { + "id": "grok-3-mini", + "vendor": "xai", + "name": "Grok 3 Mini", + "input": 0.3, + "output": 0.5, + "input_cached": 0.075 + }, + { + "id": "grok-4-fast", + "vendor": "xai", + "name": "Grok 4 Fast", + "input": 0.2, + "output": 0.5, + "input_cached": 0.05 + }, + { + "id": "grok-4", + "vendor": "xai", + "name": "Grok 4 \u2264128k", + "input": 3, + "output": 15, + "input_cached": 0.75 + }, + { + "id": "grok-4-128k", + "vendor": "xai", + "name": "Grok 4 >128k", + "input": 6, + "output": 30, + "input_cached": 0.75 + }, + { + "id": "grok-4-fast", + "vendor": "xai", + "name": "Grok 4 Fast \u2264128k", + "input": 0.2, + "output": 0.5, + "input_cached": 0.05 + }, + { + "id": "grok-4-fast-128k", + "vendor": "xai", + "name": "Grok 4 Fast >128k", + "input": 
0.4, + "output": 1.0, + "input_cached": 0.05 + }, + { + "id": "grok-4-fast-reasoning", + "vendor": "xai", + "name": "Grok 4 Fast Reasoning \u2264128k", + "input": 0.2, + "output": 0.5, + "input_cached": 0.05 + }, + { + "id": "grok-4-fast-reasoning-128k", + "vendor": "xai", + "name": "Grok 4 Fast Reasoning >128k", + "input": 0.4, + "output": 1.0, + "input_cached": 0.05 + }, + { + "id": "grok-code-fast-1", + "vendor": "xai", + "name": "Grok Code Fast 1", + "input": 0.2, + "output": 1.5, + "input_cached": 0.02 + } + ] +} \ No newline at end of file diff --git a/src/extrai/llm_providers/__init__.py b/src/extrai/llm_providers/__init__.py index 8810b56..b0fee95 100644 --- a/src/extrai/llm_providers/__init__.py +++ b/src/extrai/llm_providers/__init__.py @@ -1,13 +1,17 @@ +from .base_google_client import BaseGoogleGenAIClient +from .deepseek_client import DeepSeekClient from .gemini_client import GeminiClient +from .generic_openai_client import GenericOpenAIClient from .huggingface_client import HuggingFaceClient -from .deepseek_client import DeepSeekClient from .ollama_client import OllamaClient from .openai_client import OpenAIClient -from .generic_openai_client import GenericOpenAIClient +from .vertex_ai_client import VertexAIClient __all__ = [ # Clients + "BaseGoogleGenAIClient", "GeminiClient", + "VertexAIClient", "HuggingFaceClient", "DeepSeekClient", "OllamaClient", diff --git a/src/extrai/llm_providers/base_google_client.py b/src/extrai/llm_providers/base_google_client.py new file mode 100644 index 0000000..fcbc7da --- /dev/null +++ b/src/extrai/llm_providers/base_google_client.py @@ -0,0 +1,267 @@ +import asyncio +import json +import logging +from typing import Any + +from extrai.core.base_llm_client import ProviderBatchStatus +from .generic_openai_client import GenericOpenAIClient + +try: + from google import genai +except ImportError: + genai = None + + +def _resolve_refs(s, root=None): + if root is None: + root = s + if isinstance(s, dict): + 
s.pop("additionalProperties", None) + if "$ref" in s: + ref_path = s["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] + import copy + resolved_def = copy.deepcopy(root.get("$defs", {}).get(def_name, {})) + resolved_def = _resolve_refs(resolved_def, root) + new_schema = {} + for k, v in s.items(): + if k != "$ref": + if k.startswith("$"): + continue + new_schema[k] = _resolve_refs(v, root) + for k, v in resolved_def.items(): + if k.startswith("$"): + continue + new_schema[k] = v + return new_schema + new_schema = {} + for k, v in s.items(): + if k.startswith("$"): + continue + new_schema[k] = _resolve_refs(v, root) + return new_schema + elif isinstance(s, list): + return [_resolve_refs(item, root) for item in s] + return s + + +class BaseGoogleGenAIClient(GenericOpenAIClient): + """ + Base client for Google models (Gemini and Vertex AI) that share inline batching logic using google-genai. + """ + + def __init__( + self, + api_key: str, + model_name: str, + base_url: str, + temperature: float | None = 0.3, + logger: logging.Logger | None = None, + ): + super().__init__( + api_key=api_key, + model_name=model_name, + base_url=base_url, + temperature=temperature, + logger=logger, + ) + self.genai_client = None + + def create_inline_batch_job( + self, src: list[Any], config: dict | None = None + ) -> Any: + """ + Creates an inline batch job using the Google GenAI SDK. + """ + if not self.genai_client: + raise ImportError("google-genai package is required for this feature") + + # Ensure model name has 'models/' prefix if not present, as often required by GenAI SDK + model = self.model_name + if not model.startswith("models/"): + model = f"models/{model}" + + return self.genai_client.batches.create( + model=model, + src=src, + config=config, + ) + + def get_inline_batch_job(self, name: str) -> Any: + """ + Retrieves an inline batch job using the Google GenAI SDK. 
+ """ + if not self.genai_client: + raise ImportError("google-genai package is required for this feature") + return self.genai_client.batches.get(name=name) + + async def create_batch_job( + self, + requests: list[dict[str, Any]], + endpoint: str = "/v1/chat/completions", + completion_window: str = "24h", + metadata: dict[str, str] | None = None, + response_model: Any | None = None, + ) -> Any: + """ + Creates a batch job. If google-genai is available, uses inline batching. + Overridden to support Google GenAI inline batching logic. + """ + if not self.genai_client: + return await super().create_batch_job( + requests, endpoint, completion_window, metadata, response_model + ) + + google_requests = [] + for req in requests: + body = req.get("body", {}) + messages = body.get("messages", []) + + system_instruction = None + contents = [] + + for msg in messages: + role = msg.get("role") + content = msg.get("content") + if role == "system": + system_instruction = content + elif role == "user": + contents.append({"role": "user", "parts": [{"text": content}]}) + + config = {"response_mime_type": "application/json"} + + if response_model: + try: + # Gemini does not support 'additionalProperties' in schema + import pydantic + + if hasattr(response_model, "model_json_schema"): + schema = response_model.model_json_schema() + else: + schema = pydantic.TypeAdapter(response_model).json_schema() + + schema = _resolve_refs(schema) + config["response_schema"] = schema + + if self.logger: + self.logger.info( + "BaseGoogleGenAIClient: Enabled Structured Output with Sanitized Schema." 
+ ) + except Exception as e: + if self.logger: + self.logger.error(f"Failed to generate sanitized schema: {e}") + # Fallback + config["response_schema"] = response_model + elif "response_format" in body: + rf = body["response_format"] + if rf.get("type") == "json_schema" and "json_schema" in rf: + # Pass the schema dict directly + raw_schema = rf["json_schema"].get("schema", {}) + config["response_schema"] = _resolve_refs(raw_schema) + if self.logger: + self.logger.info( + "BaseGoogleGenAIClient: Enabled Structured Output (response_schema dict) for batch request." + ) + + if system_instruction: + config["system_instruction"] = system_instruction + + google_requests.append({"contents": contents, "config": config}) + + if self.logger: + self.logger.debug( + f"Submitting {len(google_requests)} requests to inline batch." + ) + if google_requests: + self.logger.debug( + f"Sample request 0: {json.dumps(google_requests[0], default=str)}" + ) + + # Run sync call in thread to avoid blocking + return await asyncio.to_thread( + self.create_inline_batch_job, src=google_requests + ) + + async def retrieve_batch_job(self, batch_id: str) -> Any: + """ + Retrieves batch job status. + """ + if not self.genai_client: + return await super().retrieve_batch_job(batch_id) + return await asyncio.to_thread(self.get_inline_batch_job, name=batch_id) + + async def retrieve_batch_results(self, batch_id: str) -> str: + """ + Retrieves batch results. + For Google GenAI, results are inline in the job object. 
+ """ + if not self.genai_client: + return await super().retrieve_batch_results(batch_id) + + job = await self.retrieve_batch_job(batch_id) + + if self.logger: + self.logger.debug(f"Retrieved batch job: {batch_id}") + try: + self.logger.debug( + f"Job state/status: {getattr(job, 'state', getattr(job, 'status', 'unknown'))}" + ) + except Exception: + pass + + output_lines = [] + if hasattr(job, "dest") and hasattr(job.dest, "inlined_responses"): + for resp in job.dest.inlined_responses: + content_text = "" + if resp.response: + if hasattr(resp.response, "text"): + content_text = resp.response.text + else: + content_text = str(resp.response) + + openai_resp = { + "id": "batch_req", + "response": { + "status_code": 200, + "body": {"choices": [{"message": {"content": content_text}}]}, + }, + } + + # Extract usage if available + if resp.response and hasattr(resp.response, "usage_metadata"): + usage = resp.response.usage_metadata + openai_resp["response"]["body"]["usage"] = { + "prompt_tokens": getattr(usage, "prompt_token_count", 0), + "completion_tokens": getattr(usage, "candidates_token_count", 0), + "total_tokens": getattr(usage, "total_token_count", 0), + } + + output_lines.append(json.dumps(openai_resp)) + + return "\n".join(output_lines) + + async def get_batch_status(self, batch_id: str) -> "ProviderBatchStatus": + """ + Retrieves batch job status and maps it to a standardized format. 
+ """ + if not self.genai_client: + return await super().get_batch_status(batch_id) + + job = await asyncio.to_thread(self.get_inline_batch_job, name=batch_id) + + # Defensive access to state name, handling strings or enum-like objects + state = getattr(job, "state", "unknown") + state_name = getattr(state, "name", str(state)) + + if state_name in ("JOB_STATE_SUCCEEDED", "SUCCEEDED"): + return ProviderBatchStatus.COMPLETED + elif state_name in ("JOB_STATE_FAILED", "JOB_STATE_EXPIRED", "FAILED", "EXPIRED"): + return ProviderBatchStatus.FAILED + elif state_name in ("JOB_STATE_PENDING", "PENDING"): + return ProviderBatchStatus.PENDING + elif state_name in ("JOB_STATE_CANCELLED", "CANCELLED"): + return ProviderBatchStatus.CANCELLED + + # All other states (RUNNING, PROCESSING, VALIDATING, UNSPECIFIED) are treated as PROCESSING + return ProviderBatchStatus.PROCESSING diff --git a/src/extrai/llm_providers/deepseek_client.py b/src/extrai/llm_providers/deepseek_client.py index 92b6a4c..9bfd881 100644 --- a/src/extrai/llm_providers/deepseek_client.py +++ b/src/extrai/llm_providers/deepseek_client.py @@ -1,5 +1,5 @@ import logging -from typing import Optional + from .generic_openai_client import GenericOpenAIClient @@ -13,8 +13,8 @@ def __init__( api_key: str, model_name: str = "deepseek-chat", base_url: str = "https://api.deepseek.com", - temperature: Optional[float] = 0.3, - logger: Optional[logging.Logger] = None, + temperature: float | None = 0.3, + logger: logging.Logger | None = None, ): """ Initializes the DeepSeekClient. 
diff --git a/src/extrai/llm_providers/gemini_client.py b/src/extrai/llm_providers/gemini_client.py index 3e948f7..4be4ae4 100644 --- a/src/extrai/llm_providers/gemini_client.py +++ b/src/extrai/llm_providers/gemini_client.py @@ -1,39 +1,36 @@ import logging -import json -from typing import Optional, Dict, Any, List -from extrai.utils.rate_limiter import AsyncRateLimiter -from .generic_openai_client import GenericOpenAIClient -from extrai.core.errors import LLMAPICallError -from extrai.core.analytics_collector import WorkflowAnalyticsCollector +from typing import Any +from .base_google_client import BaseGoogleGenAIClient -class GeminiClient(GenericOpenAIClient): +try: + from google import genai +except ImportError: + genai = None + + +class GeminiClient(BaseGoogleGenAIClient): """ - LLM Client specifically for Google Gemini models, using an OpenAI-compatible interface. - Inherits from GenericOpenAIClient to leverage common revision generation and validation logic. + LLM Client specifically for Gemini models, inheriting from BaseGoogleGenAIClient. """ def __init__( self, api_key: str, model_name: str = "gemini-2.5-flash", - base_url: str = "https://generativelanguage.googleapis.com/v1beta/", - temperature: Optional[float] = 0.3, - logger: Optional[logging.Logger] = None, - requests_per_minute: int = 15, - tokens_per_minute: int = 32000, + base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/", + temperature: float | None = 0.3, + logger: logging.Logger | None = None, ): """ Initializes the GeminiClient. Args: - api_key: The API key for the Gemini service. - model_name: The specific Gemini model identifier. - base_url: The base URL for the Gemini API (OpenAI-compatible endpoint). + api_key: The API key for Gemini. + model_name: The model name to use (e.g., "gemini-2.5-flash"). + base_url: The base URL for the Gemini API - openai compatible. temperature: The sampling temperature for generation. logger: Logger. 
- requests_per_minute: Maximum number of requests allowed per minute. - tokens_per_minute: Maximum number of input tokens allowed per minute. """ super().__init__( api_key=api_key, @@ -42,311 +39,7 @@ def __init__( temperature=temperature, logger=logger, ) - self.request_limiter = AsyncRateLimiter( - max_capacity=requests_per_minute, period=60.0 - ) - self.token_limiter = AsyncRateLimiter( - max_capacity=tokens_per_minute, period=60.0 - ) - self.logger = logger - - async def _execute_llm_call( - self, - system_prompt: str, - user_prompt: str, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - ) -> str: - """ - Executes the LLM call with rate limiting. - """ - # Estimate token count (simple character heuristic) - # 1 token ~= 4 chars - estimated_tokens = (len(system_prompt) + len(user_prompt)) // 4 - # Minimum 1 token - estimated_tokens = max(1, estimated_tokens) - - self.logger.warning("estimated tokens: " + str(estimated_tokens)) - # Acquire rate limits - await self.request_limiter.acquire(1) - await self.token_limiter.acquire(estimated_tokens) - - return await super()._execute_llm_call( - system_prompt, user_prompt, analytics_collector=analytics_collector - ) - - def _sanitize_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]: - """ - Ensures JSON schema compatibility with Gemini REST API by inlining $defs. - Gemini REST API does not support $defs/$ref in schema payloads. - This implements a dependency-free version of the 'jsonref' workaround. 
- """ - import copy - - schema = copy.deepcopy(schema) - defs = schema.pop("$defs", {}) or schema.pop("definitions", {}) - - def _resolve(node: Any) -> Any: - if isinstance(node, dict): - if "$ref" in node: - ref = node["$ref"].split("/")[-1] - if ref in defs: - return _resolve(defs[ref]) - return {k: _resolve(v) for k, v in node.items()} - elif isinstance(node, list): - return [_resolve(x) for x in node] - return node - - return _resolve(schema) - - async def create_batch_job( - self, - requests: List[Dict[str, Any]], - endpoint: str = None, - completion_window: str = None, - metadata: Optional[Dict[str, str]] = None, - ) -> Any: - """ - Creates a Gemini batch job using the native REST API (Inline Requests). - """ - import httpx - - # Convert requests to Gemini 'contents' format - gemini_requests = [] - for i, req in enumerate(requests): - body = req.get("body", req) - custom_id = req.get("custom_id", f"req-{i}") - - messages = body.get("messages", []) - contents = [] - system_instruction = None - - for msg in messages: - role = msg.get("role") - content = msg.get("content") - if role == "system": - system_instruction = {"parts": [{"text": content}]} - elif role == "user": - contents.append({"role": "user", "parts": [{"text": content}]}) - elif role == "assistant": - contents.append({"role": "model", "parts": [{"text": content}]}) - - # Construct the request object - # Note: We need to ensure we use the correct model format - # API expects model resource name in URL usually, but can also be in request? - # Inline requests structure: { "request": { ... }, "metadata": ... 
} - - # Map configuration - generation_config = {} - if "temperature" in body: - generation_config["temperature"] = body["temperature"] - if "max_tokens" in body: - generation_config["maxOutputTokens"] = body["max_tokens"] - - # Map OpenAI response_format to Gemini generationConfig - response_format = body.get("response_format", {}) - if response_format.get("type") == "json_schema": - generation_config["responseMimeType"] = "application/json" - if ( - "json_schema" in response_format - and "schema" in response_format["json_schema"] - ): - raw_schema = response_format["json_schema"]["schema"] - generation_config["responseJsonSchema"] = ( - self._sanitize_schema_for_gemini(raw_schema) - ) - elif response_format.get("type") == "json_object": - generation_config["responseMimeType"] = "application/json" - - g_req_inner = {"contents": contents, "generationConfig": generation_config} - if system_instruction: - g_req_inner["system_instruction"] = system_instruction - - gemini_requests.append( - {"request": g_req_inner, "metadata": {"key": custom_id}} - ) - - # Construct Payload - payload = { - "batch": {"input_config": {"requests": {"requests": gemini_requests}}} - } - if metadata and "display_name" in metadata: - payload["batch"]["display_name"] = metadata["display_name"] - - url = f"{self.base_url}models/{self.model_name}:batchGenerateContent?key={self.api_key}" - - async with httpx.AsyncClient() as client: - resp = await client.post( - url, json=payload, headers={"Content-Type": "application/json"} - ) - if resp.status_code >= 400: - raise LLMAPICallError( - f"Gemini Batch Creation Failed: {resp.status_code} - {resp.text}" - ) - - return self._wrap_batch_response(resp.json()) - - async def retrieve_batch_job(self, batch_id: str) -> Any: - """ - Retrieves batch status using Native REST API. 
- """ - import httpx - - # batch_id is expected to be the full resource name e.g., "batches/12345" - url = f"{self.base_url}{batch_id}?key={self.api_key}" - - async with httpx.AsyncClient() as client: - resp = await client.get(url) - if resp.status_code >= 400: - raise LLMAPICallError( - f"Gemini Batch Retrieve Failed: {resp.status_code} - {resp.text}" - ) - return self._wrap_batch_response(resp.json()) - - async def cancel_batch_job(self, batch_id: str) -> Any: - """ - Cancels batch job using Native REST API. - """ - import httpx - - url = f"{self.base_url}{batch_id}:cancel?key={self.api_key}" - - async with httpx.AsyncClient() as client: - resp = await client.post(url) - if resp.status_code >= 400: - raise LLMAPICallError( - f"Gemini Batch Cancel Failed: {resp.status_code} - {resp.text}" - ) - # Empty response usually on success? Or updated metadata. - # We can re-fetch or just return true/empty - return True - - def _wrap_batch_response(self, data: Dict[str, Any]) -> Any: - class GeminiBatchJob: - def __init__(self, data): - self.id = data.get("name") # "batches/..." - # State might be at top level or inside metadata (depending on API version/endpoint) - self.status = data.get("state") - if not self.status and "metadata" in data: - self.status = data["metadata"].get("state") - self.original_data = data - - # If finished, it might have results - # Inline results structure? - # The docs say: response.inlinedResponses - # We can try to extract output_file_id if it exists, or handle inline. - # OpenAI interface expects output_file_id for retrieve_batch_results. - # If inline, we can't provide a file ID. - # We'll need retrieve_batch_results to handle the batch_id as file_id for inline. - - return GeminiBatchJob(data) - - async def retrieve_batch_results(self, file_id: str) -> str: - """ - Retrieves batch results. - For Gemini Inline, 'file_id' should be the batch ID. - """ - # If the batch job had 'responsesFile', we download it. 
- # If it had 'inlinedResponses', we format it as JSONL. - - # We need to fetch the batch first to see which one it is (or assume we have the object) - # But this method usually takes just an ID. - # So we fetch the batch. - - batch = await self.retrieve_batch_job(file_id) - data = batch.original_data - - # Check for inline responses - # Structure: data.get("response", {}).get("inlinedResponses", []) - # Actually docs say: batch_job.dest.inlined_responses (SDK) or .response.inlinedResponses (REST) - - # REST: .response.inlinedResponses - response_section = data.get( - "response", {} - ) # Not to be confused with 'responses' - # Wait, the example output JSON says: - # "response": { "inlinedResponses": [ ... ] } OR "response": { "responsesFile": "..." } - - inlined = response_section.get("inlinedResponses") - if inlined: - # Handle case where inlined might be a dict (unexpected but observed) - if isinstance(inlined, dict): - # If it's a dict, maybe the list is nested or it's a map? - self.logger.warning( - f"inlinedResponses is a dict, keys: {list(inlined.keys())}" - ) - # Try to find the actual list - if "inlinedResponses" in inlined: - inlined = inlined["inlinedResponses"] - elif "responses" in inlined: - inlined = inlined["responses"] - elif "results" in inlined: - inlined = inlined["results"] - else: - # Fallback: treat values as the list if they look like items - inlined = list(inlined.values()) - - # Convert to JSONL string to match OpenAI format - lines = [] - for item in inlined: - # item has 'response' or 'error' and 'requestKey' (if we used metadata.key) - # We should map back to OpenAI-like format if possible - if isinstance(item, str): - self.logger.warning( - f"Unexpected string item in inlinedResponses: {item}" - ) - continue - lines.append(json.dumps(item)) - return "\n".join(lines) - - file_name = response_section.get("responsesFile") - if file_name: - # Download file - # url: 
https://generativelanguage.googleapis.com/download/v1beta/$responses_file_name:download?alt=media - import httpx - - url = f"https://generativelanguage.googleapis.com/download/v1beta/{file_name}:download?alt=media&key={self.api_key}" - async with httpx.AsyncClient() as client: - resp = await client.get(url) - if resp.status_code >= 400: - raise LLMAPICallError( - f"Gemini Result Download Failed: {resp.status_code}" - ) - return resp.text - - raise LLMAPICallError("No results found in batch (or batch not complete).") - - async def list_batch_jobs( - self, limit: int = 20, after: Optional[str] = None - ) -> Any: - import httpx - - url = f"{self.base_url}batches?key={self.api_key}&pageSize={limit}" - if after: - url += f"&pageToken={after}" - - async with httpx.AsyncClient() as client: - resp = await client.get(url) - if resp.status_code >= 400: - raise LLMAPICallError(f"Gemini List Batches Failed: {resp.text}") - - data = resp.json() - # Wrap list? - return data - - def extract_content_from_batch_response( - self, response: Dict[str, Any] - ) -> Optional[str]: - """ - Extracts content from Gemini batch response item. 
- """ - if "error" in response: - self.logger.error(f"Batch item contains error: {response['error']}") - return None - - if "response" in response and "candidates" in response["response"]: - candidates = response["response"]["candidates"] - if candidates and "content" in candidates[0]: - parts = candidates[0]["content"].get("parts", []) - if parts: - return parts[0].get("text") - return None + if genai: + self.genai_client = genai.Client(api_key=api_key) + else: + self.genai_client = None diff --git a/src/extrai/llm_providers/generic_openai_client.py b/src/extrai/llm_providers/generic_openai_client.py index afc9561..a927255 100644 --- a/src/extrai/llm_providers/generic_openai_client.py +++ b/src/extrai/llm_providers/generic_openai_client.py @@ -1,11 +1,14 @@ +import io +import json import logging +from typing import Any + import openai -import json -import io -from typing import Optional, List, Dict, Any -from extrai.core.errors import LLMAPICallError -from extrai.core.base_llm_client import BaseLLMClient + from extrai.core.analytics_collector import WorkflowAnalyticsCollector +from extrai.core.base_llm_client import BaseLLMClient, ResponseMode +from extrai.core.cost_calculator import calculate_cost +from extrai.core.errors import LLMAPICallError class GenericOpenAIClient(BaseLLMClient): @@ -19,8 +22,8 @@ def __init__( api_key: str, model_name: str, base_url: str, - temperature: Optional[float] = 0.3, - logger: Optional[logging.Logger] = None, + temperature: float | None = 0.3, + logger: logging.Logger | None = None, ): """ Initializes the GenericOpenAIClient. 
@@ -45,22 +48,29 @@ async def _execute_llm_call( self, system_prompt: str, user_prompt: str, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - ) -> str: + response_mode: ResponseMode = ResponseMode.TEXT, + response_model: Any | None = None, + analytics_collector: WorkflowAnalyticsCollector | None = None, + **kwargs: Any, + ) -> Any: """ Makes the actual API call to an OpenAI-compatible LLM. Args: system_prompt: The system prompt for the LLM. user_prompt: The user prompt for the LLM. + response_mode: Whether to return raw text or structured output. + response_model: The Pydantic/SQLModel class for structured responses. analytics_collector: Optional analytics collector. + **kwargs: Additional arguments to pass to the API client. Returns: - The raw string content from the LLM response. Returns an empty string - if the LLM response content is None. + - TEXT mode: The raw string content from the LLM response. + - STRUCTURED mode: Instance of response_model. Raises: LLMAPICallError: If the API call fails or returns an error. + ValueError: If structured mode is requested but no response_model is provided. 
""" try: messages = [] @@ -68,95 +78,84 @@ async def _execute_llm_call( messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": user_prompt}) - chat_completion = await self.client.chat.completions.create( - model=self.model_name, - messages=messages, - response_format={"type": "json_object"}, - temperature=self.temperature - if self.temperature is not None - else openai.NOT_GIVEN, - ) + if response_mode == ResponseMode.STRUCTURED: + if response_model is None: + raise ValueError("response_model required for STRUCTURED mode") - if ( - analytics_collector - and hasattr(chat_completion, "usage") - and chat_completion.usage - ): - analytics_collector.record_llm_usage( - input_tokens=getattr(chat_completion.usage, "prompt_tokens", 0), - output_tokens=getattr( - chat_completion.usage, "completion_tokens", 0 - ), + completion = await self.client.beta.chat.completions.parse( model=self.model_name, + messages=messages, + response_format=response_model, + temperature=self.temperature + if self.temperature is not None + else openai.NOT_GIVEN, + **kwargs, ) - response_content = chat_completion.choices[0].message.content - return response_content if response_content is not None else "" - - except openai.APIError as e: - error_message = str(e) - if hasattr(e, "message") and e.message: - error_message = e.message - elif hasattr(e, "body") and e.body: - if "message" in e.body: - error_message = e.body["message"] - elif "error" in e.body and "message" in e.body["error"]: - error_message = e.body["error"]["message"] - - status_code = e.status_code if hasattr(e, "status_code") else "N/A" - raise LLMAPICallError( - f"API call failed. Status: {status_code}. 
Error: {error_message}" - ) from e - except Exception as e: - raise LLMAPICallError( - f"Unexpected error during API call: {type(e).__name__} - {str(e)}" - ) from e - - async def generate_structured( - self, - system_prompt: str, - user_prompt: str, - response_model: Any, - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - **kwargs: Any, - ) -> Any: - """ - Generates structured output using OpenAI's beta.chat.completions.parse. - """ - try: - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": user_prompt}) - - completion = await self.client.beta.chat.completions.parse( - model=self.model_name, - messages=messages, - response_format=response_model, - temperature=self.temperature - if self.temperature is not None - else openai.NOT_GIVEN, - **kwargs, - ) - - if ( - analytics_collector - and hasattr(completion, "usage") - and completion.usage - ): - analytics_collector.record_llm_usage( - input_tokens=getattr(completion.usage, "prompt_tokens", 0), - output_tokens=getattr(completion.usage, "completion_tokens", 0), + if ( + analytics_collector + and hasattr(completion, "usage") + and completion.usage + ): + input_tokens = getattr(completion.usage, "prompt_tokens", 0) + output_tokens = getattr( + completion.usage, "completion_tokens", 0 + ) + cost = calculate_cost( + self.model_name, input_tokens, output_tokens + ) + analytics_collector.record_llm_usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + model=self.model_name, + cost=cost, + ) + + message = completion.choices[0].message + if message.refusal: + raise LLMAPICallError( + f"Model refused to generate structured output: {message.refusal}" + ) + + return message.parsed + + else: # TEXT mode + # Default to json_object + if "response_format" not in kwargs: + kwargs["response_format"] = {"type": "json_object"} + + chat_completion = await self.client.chat.completions.create( 
model=self.model_name, + messages=messages, + temperature=self.temperature + if self.temperature is not None + else openai.NOT_GIVEN, + **kwargs, ) - message = completion.choices[0].message - if message.refusal: - raise LLMAPICallError( - f"Model refused to generate structured output: {message.refusal}" - ) - - return message.parsed + if ( + analytics_collector + and hasattr(chat_completion, "usage") + and chat_completion.usage + ): + input_tokens = getattr( + chat_completion.usage, "prompt_tokens", 0 + ) + output_tokens = getattr( + chat_completion.usage, "completion_tokens", 0 + ) + cost = calculate_cost( + self.model_name, input_tokens, output_tokens + ) + analytics_collector.record_llm_usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + model=self.model_name, + cost=cost, + ) + + response_content = chat_completion.choices[0].message.content + return response_content if response_content is not None else "" except openai.APIError as e: error_message = str(e) @@ -179,10 +178,11 @@ async def generate_structured( async def create_batch_job( self, - requests: List[Dict[str, Any]], + requests: list[dict[str, Any]], endpoint: str = "/v1/chat/completions", completion_window: str = "24h", - metadata: Optional[Dict[str, str]] = None, + metadata: dict[str, str] | None = None, + response_model: Any | None = None, ) -> Any: """ Creates a batch job for processing multiple requests. @@ -246,9 +246,7 @@ async def retrieve_batch_job(self, batch_id: str) -> Any: except openai.APIError as e: raise LLMAPICallError(f"Failed to retrieve batch {batch_id}: {e}") from e - async def list_batch_jobs( - self, limit: int = 20, after: Optional[str] = None - ) -> Any: + async def list_batch_jobs(self, limit: int = 20, after: str | None = None) -> Any: """ Lists batch jobs. 
""" @@ -279,8 +277,8 @@ async def retrieve_batch_results(self, file_id: str) -> str: ) from e def extract_content_from_batch_response( - self, response: Dict[str, Any] - ) -> Optional[str]: + self, response: dict[str, Any] + ) -> str | None: """ Extracts content from OpenAI batch response item. """ @@ -289,3 +287,30 @@ def extract_content_from_batch_response( if "choices" in body and body["choices"]: return body["choices"][0]["message"]["content"] return None + + def prepare_request( + self, + system_prompt: str, + user_prompt: str, + json_schema: Any | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Prepares a request dictionary for OpenAI batch processing. + """ + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": user_prompt}) + + body = { + "model": self.model_name, + "messages": messages, + "temperature": self.temperature, + **kwargs, + } + + if json_schema: + body["response_format"] = {"type": "json_object"} + + return body diff --git a/src/extrai/llm_providers/huggingface_client.py b/src/extrai/llm_providers/huggingface_client.py index 59f0e2b..0a77a93 100644 --- a/src/extrai/llm_providers/huggingface_client.py +++ b/src/extrai/llm_providers/huggingface_client.py @@ -1,5 +1,5 @@ import logging -from typing import Optional + from .generic_openai_client import GenericOpenAIClient @@ -13,8 +13,8 @@ def __init__( api_key: str, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1", base_url: str = "https://api-inference.huggingface.co/v1/", - temperature: Optional[float] = 0.3, - logger: Optional[logging.Logger] = None, + temperature: float | None = 0.3, + logger: logging.Logger | None = None, ): """ Initializes the HuggingFaceClient. 
diff --git a/src/extrai/llm_providers/ollama_client.py b/src/extrai/llm_providers/ollama_client.py index bdcb25c..055c52c 100644 --- a/src/extrai/llm_providers/ollama_client.py +++ b/src/extrai/llm_providers/ollama_client.py @@ -1,5 +1,5 @@ import logging -from typing import Optional + from .generic_openai_client import GenericOpenAIClient @@ -14,8 +14,8 @@ def __init__( api_key: str = "ollama", # Often not required, but good practice to have a default model_name: str = "llama2", base_url: str = "http://localhost:11434/v1", - temperature: Optional[float] = 0.3, - logger: Optional[logging.Logger] = None, + temperature: float | None = 0.3, + logger: logging.Logger | None = None, ): """ Initializes the OllamaClient. diff --git a/src/extrai/llm_providers/openai_client.py b/src/extrai/llm_providers/openai_client.py index be2bea2..e30865e 100644 --- a/src/extrai/llm_providers/openai_client.py +++ b/src/extrai/llm_providers/openai_client.py @@ -1,5 +1,5 @@ import logging -from typing import Optional + from .generic_openai_client import GenericOpenAIClient @@ -13,8 +13,8 @@ def __init__( api_key: str, model_name: str = "gpt-4o", base_url: str = "https://api.openai.com/v1", - temperature: Optional[float] = 0.3, - logger: Optional[logging.Logger] = None, + temperature: float | None = 0.3, + logger: logging.Logger | None = None, ): """ Initializes the OpenAIClient. 
diff --git a/src/extrai/llm_providers/vertex_ai_client.py b/src/extrai/llm_providers/vertex_ai_client.py new file mode 100644 index 0000000..3f229f1 --- /dev/null +++ b/src/extrai/llm_providers/vertex_ai_client.py @@ -0,0 +1,278 @@ +import json +import logging +import os +from typing import Any + +from google.auth.transport.requests import Request +from google.oauth2 import service_account + +try: + from google import genai +except ImportError: + genai = None + +from .base_google_client import BaseGoogleGenAIClient + + +class VertexAIClient(BaseGoogleGenAIClient): + """ + LLM Client specifically for Vertex AI models, inheriting from BaseGoogleGenAIClient. + Supports either API Key or GCP Service Account JSON. + """ + + def __init__( + self, + model_name: str = "gemini-2.5-flash", + api_key: str | None = None, + service_account_json: str | dict | None = None, + project_id: str | None = None, + location: str = "global", + temperature: float | None = 0.3, + logger: logging.Logger | None = None, + gcs_bucket_name: str | None = None, + ): + """ + Initializes the VertexAIClient. + + Args: + model_name: The model name to use (e.g., "gemini-2.5-flash"). + api_key: The API key for Vertex/Gemini (if using direct API key). + service_account_json: Path to the service account JSON file, or dict containing the credentials. + project_id: GCP Project ID (can be inferred from service_account_json if provided). + location: GCP region for Vertex AI (default: global). + temperature: The sampling temperature for generation. + logger: Logger. + gcs_bucket_name: GCS bucket name for batch API usage. 
+ """ + credentials = None + self.gcs_bucket_name = gcs_bucket_name + + if service_account_json: + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + if isinstance(service_account_json, str): + if os.path.exists(service_account_json): + credentials = service_account.Credentials.from_service_account_file( + service_account_json, scopes=scopes + ) + else: + try: + key_dict = json.loads(service_account_json) + credentials = service_account.Credentials.from_service_account_info( + key_dict, scopes=scopes + ) + except json.JSONDecodeError: + raise ValueError( + "service_account_json must be a valid file path or valid JSON string." + ) + elif isinstance(service_account_json, dict): + credentials = service_account.Credentials.from_service_account_info( + service_account_json, scopes=scopes + ) + + if credentials: + # Need an access token for the OpenAI-compatible endpoint + credentials.refresh(Request()) + api_key = credentials.token + + # Infer project_id if not explicitly provided + if not project_id and hasattr(credentials, "project_id"): + project_id = credentials.project_id + + if not project_id and credentials: + project_id = getattr(credentials, "project_id", project_id) + + self._credentials = credentials + self._project_id = project_id + + # Base URL for Vertex AI OpenAI-compatible endpoint + if project_id: + base_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/endpoints/openapi" + else: + # Fallback to standard OpenAI compatible endpoint if project is unknown (e.g. standard Gemini API) + base_url = "https://generativelanguage.googleapis.com/v1beta/openai/" + + # api_key must be passed to GenericOpenAIClient. Fallback to dummy string if none. + if not api_key: + api_key = "dummy" + + # The OpenAI-compatible Vertex endpoint expects the model to have a publisher prefix + # such as 'google/gemini-2.5-flash'. If no slash is present, add 'google/'. 
+ openai_model_name = model_name + if "/" not in openai_model_name: + openai_model_name = f"google/{openai_model_name}" + + super().__init__( + api_key=api_key, + model_name=openai_model_name, + base_url=base_url, + temperature=temperature, + logger=logger, + ) + # However, google-genai client expects the original model name, sometimes with 'models/' + self.original_model_name = model_name + + if genai: + client_args = {"vertexai": True} + if credentials: + client_args["credentials"] = credentials + if project_id: + client_args["project"] = project_id + client_args["location"] = location + elif api_key and api_key != "dummy": + client_args["api_key"] = api_key + + self.genai_client = genai.Client(**client_args) + else: + self.genai_client = None + + def create_inline_batch_job( + self, src: list[Any], config: dict | None = None + ) -> Any: + if not self.genai_client: + raise ImportError("google-genai package is required for this feature") + + model = getattr(self, "original_model_name", self.model_name) + + gcs_bucket = self.gcs_bucket_name or os.environ.get("GCS_BUCKET_NAME") + if gcs_bucket: + import json + import tempfile + import uuid + import os + from google.cloud import storage + + job_id = str(uuid.uuid4()) + jsonl_lines = [] + for req in src: + vertex_req = {"request": {"contents": req.get("contents", [])}} + req_config = req.get("config", {}) + + gen_config = {} + if "response_mime_type" in req_config: + gen_config["responseMimeType"] = req_config["response_mime_type"] + if "response_schema" in req_config: + gen_config["responseSchema"] = req_config["response_schema"] + if "temperature" in req_config: + gen_config["temperature"] = req_config["temperature"] + elif self.temperature is not None: + gen_config["temperature"] = self.temperature + + if gen_config: + vertex_req["request"]["generationConfig"] = gen_config + + if "system_instruction" in req_config: + vertex_req["request"]["systemInstruction"] = { + "role": "system", + "parts": [{"text": 
req_config["system_instruction"]}] + } + jsonl_lines.append(json.dumps(vertex_req)) + + jsonl_content = "\n".join(jsonl_lines) + + temp_prompt_dir = os.path.join(tempfile.gettempdir(), "temp_prompts") + os.makedirs(temp_prompt_dir, exist_ok=True) + prompt_file_path = os.path.join(temp_prompt_dir, f"{job_id}_prompt.jsonl") + + with open(prompt_file_path, "w") as f: + f.write(jsonl_content + "\n") + + if hasattr(self, "_credentials") and self._credentials: + storage_client = storage.Client(credentials=self._credentials, project=self._project_id) + else: + storage_client = storage.Client() + + bucket = storage_client.bucket(gcs_bucket) + input_blob_name = f"batch_inputs/{job_id}_prompt.jsonl" + blob = bucket.blob(input_blob_name) + blob.upload_from_filename(prompt_file_path) + + try: + os.remove(prompt_file_path) + except OSError: + pass + + gcs_input_uri = f"gs://{gcs_bucket}/{input_blob_name}" + gcs_output_uri = f"gs://{gcs_bucket}/batch_outputs/{job_id}/" + + if config is None: + config = {} + config["dest"] = gcs_output_uri + config["display_name"] = f"batch_{job_id}" + + if self.logger: + self.logger.info(f"VertexAIClient: Uploaded batch inputs to {gcs_input_uri}. 
Output will be at {gcs_output_uri}") + + return self.genai_client.batches.create( + model=model, + src=gcs_input_uri, + config=config, + ) + else: + if self.logger: + self.logger.warning("GCS_BUCKET_NAME not set, falling back to inline source") + + return self.genai_client.batches.create( + model=model, + src=src, + config=config, + ) + + async def retrieve_batch_results(self, batch_id: str) -> str: + if not self.genai_client: + return await super().retrieve_batch_results(batch_id) + + job = await self.retrieve_batch_job(batch_id) + + if hasattr(job, "dest") and hasattr(job.dest, "gcs_uri") and job.dest.gcs_uri: + from google.cloud import storage + import json + + if hasattr(self, "_credentials") and self._credentials: + storage_client = storage.Client(credentials=self._credentials, project=self._project_id) + else: + storage_client = storage.Client() + + gcs_uri = job.dest.gcs_uri + bucket_name, blob_prefix = gcs_uri.replace("gs://", "").split("/", 1) + bucket = storage_client.bucket(bucket_name) + + output_lines = [] + blobs = bucket.list_blobs(prefix=blob_prefix) + for blob in blobs: + if blob.name.endswith(".jsonl"): + jsonl_string = blob.download_as_string() + for line in jsonl_string.strip().split(b"\n"): + if not line.strip(): + continue + response_data = json.loads(line) + content_text = "" + if "response" in response_data: + resp = response_data["response"] + if "text" in resp: + content_text = resp["text"] + elif "candidates" in resp and resp["candidates"]: + try: + content_text = resp["candidates"][0]["content"]["parts"][0]["text"] + except (KeyError, IndexError): + pass + + openai_resp = { + "id": "batch_req", + "response": { + "status_code": 200, + "body": {"choices": [{"message": {"content": content_text}}]}, + }, + } + + if "response" in response_data and "usage_metadata" in response_data["response"]: + usage = response_data["response"]["usage_metadata"] + openai_resp["response"]["body"]["usage"] = { + "prompt_tokens": usage.get("promptTokenCount", 
usage.get("prompt_token_count", 0)), + "completion_tokens": usage.get("candidatesTokenCount", usage.get("candidates_token_count", 0)), + "total_tokens": usage.get("totalTokenCount", usage.get("total_token_count", 0)), + } + + output_lines.append(json.dumps(openai_resp)) + return "\n".join(output_lines) + + return await super().retrieve_batch_results(batch_id) diff --git a/src/extrai/utils/alignment_utils.py b/src/extrai/utils/alignment_utils.py index ee56ac0..fe6264b 100644 --- a/src/extrai/utils/alignment_utils.py +++ b/src/extrai/utils/alignment_utils.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, List from difflib import SequenceMatcher +from typing import Any -def normalize_json_revisions(revisions: List[Any]) -> List[Any]: +def normalize_json_revisions(revisions: list[Any]) -> list[Any]: """ Aligns arrays across revisions using similarity-based matching. Handles different structures and ensures consistent ordering. @@ -30,46 +30,72 @@ def normalize_json_revisions(revisions: List[Any]) -> List[Any]: def align_entity_arrays( - arrays: List[List[Dict[str, Any]]], -) -> List[List[Dict[str, Any]]]: + arrays: list[list[dict[str, Any]]], + truncate_to_min_length: bool = True, +) -> list[list[dict[str, Any]]]: """ Aligns multiple arrays of entities so similar objects are in the same positions. - Uses the first array as reference and matches objects based on similarity. + When truncate_to_min_length is True, uses the minimum length across all arrays. + When False, uses the longest array as reference and pads shorter arrays with empty dicts. """ if not arrays or not any(arrays): return arrays - # Validate all arrays have the same length lengths = [len(arr) for arr in arrays] - if len(set(lengths)) > 1: - print( - f"Warning: Arrays have different lengths {lengths}. Using minimum length." 
- ) - min_length = min(lengths) - arrays = [arr[:min_length] for arr in arrays] - - # Use first array as reference - reference = arrays[0] - aligned = [reference[:]] - - # Align each subsequent array to match the reference - for arr in arrays[1:]: + + if truncate_to_min_length: + if len(set(lengths)) > 1: + print( + f"Warning: Arrays have different lengths {lengths}. Using minimum length." + ) + min_length = min(lengths) + arrays = [arr[:min_length] for arr in arrays] + + reference = arrays[0] + else: + # Find the longest array to use as reference + max_idx = lengths.index(max(lengths)) + reference = arrays[max_idx] + + aligned_results = [] + + # Align each array to match the reference + for arr in arrays: + if arr is reference and truncate_to_min_length: + aligned_results.append(reference[:]) + continue + + if arr is reference: + # If this is the reference array (and we didn't truncate), + # we still want to add it as is + aligned_results.append(reference[:]) + continue + reordered = [] used_indices = set() for ref_obj in reference: # Find best match in current array best_idx = find_best_match(ref_obj, arr, used_indices) - reordered.append(arr[best_idx]) - used_indices.add(best_idx) + + if best_idx != -1: + reordered.append(arr[best_idx]) + used_indices.add(best_idx) + elif not truncate_to_min_length: + # Pad with empty dict if no match found and we aren't truncating + reordered.append({}) + else: + # This case shouldn't happen if we truncated to min length, + # but adding a fallback just in case + reordered.append({}) - aligned.append(reordered) + aligned_results.append(reordered) - return aligned + return aligned_results def find_best_match( - target: Dict[str, Any], candidates: List[Dict[str, Any]], used_indices: set + target: dict[str, Any], candidates: list[dict[str, Any]], used_indices: set ) -> int: """ Finds the index of the most similar object in candidates that hasn't been used. 
@@ -89,7 +115,7 @@ def find_best_match( return best_idx -def calculate_similarity(obj1: Dict[str, Any], obj2: Dict[str, Any]) -> float: +def calculate_similarity(obj1: dict[str, Any], obj2: dict[str, Any]) -> float: """ Calculates similarity score between two objects (0-1, higher is more similar). Handles different field types recursively. diff --git a/src/extrai/utils/flattening_utils.py b/src/extrai/utils/flattening_utils.py index d264e02..7dfdf84 100644 --- a/src/extrai/utils/flattening_utils.py +++ b/src/extrai/utils/flattening_utils.py @@ -1,15 +1,15 @@ -from typing import Dict, List, Any, Union, Tuple +from typing import Any, Union # Define type aliases for clarity JSONValue = Union[str, int, float, bool, None] -JSONObject = Dict[str, Any] -JSONArray = List[Any] -Path = Tuple[Union[str, int], ...] -FlattenedJSON = Dict[Path, JSONValue] +JSONObject = dict[str, Any] +JSONArray = list[Any] +Path = tuple[str | int, ...] +FlattenedJSON = dict[Path, JSONValue] def flatten_json( - nested_json: Union[JSONObject, JSONArray], + nested_json: JSONObject | JSONArray, parent_path: Path = (), separator: str = ".", ) -> FlattenedJSON: @@ -64,7 +64,7 @@ def flatten_json( def unflatten_json( flat_json: FlattenedJSON, -) -> Union[JSONObject, JSONArray, JSONValue, None]: +) -> JSONObject | JSONArray | JSONValue | None: """ Unflattens a flat dictionary (with tuple paths) back into a nested JSON-like structure. @@ -118,7 +118,7 @@ def unflatten_json( # or build lists dynamically. # For list items to be correctly placed, they need to be filled. # If paths are like {(0, 'a'): 1, (2, 'b'): 1}, we need list of size 3. 
- root: Union[JSONObject, JSONArray] = [None] * (max_index + 1) + root: JSONObject | JSONArray = [None] * (max_index + 1) else: root = {} diff --git a/src/extrai/utils/json_validation_utils.py b/src/extrai/utils/json_validation_utils.py index 511939c..a311520 100644 --- a/src/extrai/utils/json_validation_utils.py +++ b/src/extrai/utils/json_validation_utils.py @@ -1,9 +1,10 @@ +from typing import Any + import jsonschema -from typing import Any, Dict def is_json_valid( - json_data_to_validate: Any, json_schema_definition: Dict[str, Any] + json_data_to_validate: Any, json_schema_definition: dict[str, Any] ) -> bool: """ Validates JSON data against a JSON schema. diff --git a/src/extrai/utils/llm_output_processing.py b/src/extrai/utils/llm_output_processing.py index 51c3f60..b5a52cf 100644 --- a/src/extrai/utils/llm_output_processing.py +++ b/src/extrai/utils/llm_output_processing.py @@ -1,10 +1,10 @@ import json -from typing import Any, Dict, Type, Optional, Union, Tuple +from typing import Any -from extrai.core.analytics_collector import WorkflowAnalyticsCollector -from sqlmodel import SQLModel from pydantic import ValidationError as PydanticValidationError +from sqlmodel import SQLModel +from extrai.core.analytics_collector import WorkflowAnalyticsCollector from extrai.core.errors import ( LLMOutputParseError, LLMOutputValidationError, @@ -12,7 +12,7 @@ from extrai.utils.json_validation_utils import is_json_valid -def _filter_special_fields_for_validation(data: Dict[str, Any]) -> Dict[str, Any]: +def _filter_special_fields_for_validation(data: dict[str, Any]) -> dict[str, Any]: """ Removes fields that are part of an extended schema (e.g., for relationship handling or temporary IDs) but not part of the core SQLModel definition for validation. 
@@ -31,7 +31,7 @@ def _filter_special_fields_for_validation(data: Dict[str, Any]) -> Dict[str, Any } -def _unwrap_priority_keys(data: Any) -> Tuple[Any, bool]: +def _unwrap_priority_keys(data: Any) -> tuple[Any, bool]: """ Recursively unwraps priority keys (result, data, etc.) from a dictionary. Returns a tuple (unwrapped_data, was_unwrapped). @@ -82,12 +82,12 @@ def _unwrap_llm_output(data: Any) -> Any: def process_and_validate_llm_output( - raw_llm_content: Optional[str], - model_schema_map: Dict[str, Type[SQLModel]], + raw_llm_content: str | None, + model_schema_map: dict[str, type[SQLModel]], revision_info_for_error: str = "LLM Output", - analytics_collector: Optional[WorkflowAnalyticsCollector] = None, - default_model_type: Optional[str] = None, -) -> list[Dict[str, Any]]: + analytics_collector: WorkflowAnalyticsCollector | None = None, + default_model_type: str | None = None, +) -> list[dict[str, Any]]: """ Parses raw LLM JSON content, unwraps structures, and validates a list of objects against a map of SQLModel schemas. @@ -171,9 +171,9 @@ def process_and_validate_llm_output( def process_and_validate_raw_json( raw_llm_content: str, revision_info_for_error: str, - target_json_schema: Optional[Dict[str, Any]] = None, + target_json_schema: dict[str, Any] | None = None, attempt_unwrap: bool = True, -) -> Union[Dict[str, Any], list[Dict[str, Any]]]: +) -> dict[str, Any] | list[dict[str, Any]]: """ Parses, unwraps, and validates raw JSON content against a schema. 
diff --git a/src/extrai/utils/rate_limiter.py b/src/extrai/utils/rate_limiter.py index a82ec0e..523b5d5 100644 --- a/src/extrai/utils/rate_limiter.py +++ b/src/extrai/utils/rate_limiter.py @@ -1,6 +1,5 @@ import asyncio import time -from typing import List, Tuple class AsyncRateLimiter: @@ -18,7 +17,7 @@ def __init__(self, max_capacity: int, period: float = 60.0): self.max_capacity = max_capacity self.period = period # List of (timestamp, cost) - self.history: List[Tuple[float, int]] = [] + self.history: list[tuple[float, int]] = [] self._lock = asyncio.Lock() async def acquire(self, cost: int = 1): diff --git a/src/extrai/utils/serialization_utils.py b/src/extrai/utils/serialization_utils.py index 3b08680..ec246d5 100644 --- a/src/extrai/utils/serialization_utils.py +++ b/src/extrai/utils/serialization_utils.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, Set, Optional -from sqlmodel import SQLModel +from typing import Any + from sqlalchemy.orm.collections import InstrumentedList +from sqlmodel import SQLModel def serialize_sqlmodel_with_relationships( - obj: SQLModel, seen: Optional[Set[int]] = None -) -> Dict[str, Any]: + obj: SQLModel, seen: set[int] | None = None +) -> dict[str, Any]: """ Recursively serializes a SQLModel instance, including its loaded relationships. Uses model_dump(mode='json') to handle basic types (including Decimal -> str/float). @@ -67,3 +68,47 @@ def make_json_serializable(obj: Any) -> Any: elif isinstance(obj, list): return [make_json_serializable(item) for item in obj] return obj + + +def resolve_step_param( + param: str | list[str], step_index: int = 0, total_steps: int = 1 +) -> str: + """ + Resolves a parameter that can be a single string or a list of strings + to the specific string for the current step. + + Args: + param: The parameter value (str or list[str]) + step_index: The current step index (0-based) + total_steps: The total number of steps in the process + + Returns: + The string value for the current step. 
+ + Raises: + ValueError: If list length does not match requirements. + """ + if isinstance(param, str): + return param + + if not isinstance(param, list): + return str(param) if param is not None else "" + + if not param: + return "" + + if len(param) == 1: + return param[0] + + if len(param) != total_steps: + raise ValueError( + f"Parameter list has {len(param)} elements, but process has {total_steps} steps. " + "Pass a single string, a 1-element list, or a list matching the number of steps." + ) + + if step_index < 0 or step_index >= len(param): + raise ValueError( + f"Step index {step_index} out of bounds for parameter list of length {len(param)}" + ) + + return param[step_index] diff --git a/src/extrai/utils/type_mapping.py b/src/extrai/utils/type_mapping.py index 0b891db..ff01631 100644 --- a/src/extrai/utils/type_mapping.py +++ b/src/extrai/utils/type_mapping.py @@ -2,11 +2,11 @@ import enum from typing import ( Any, - Dict, - List, Optional, get_args, get_origin, +) +from typing import ( Union as TypingUnion, ) @@ -24,21 +24,21 @@ def _process_union_types(args, recurse_func): # Handler registry for different type origins ORIGIN_HANDLERS = { - Optional: lambda args, r: r(args[0]) - if args and args[0] is not type(None) - else "none", - list: lambda args, r: f"list[{','.join([r(arg) for arg in args])}]" - if args - else "list", - List: lambda args, r: f"list[{','.join([r(arg) for arg in args])}]" - if args - else "list", - dict: lambda args, r: f"dict[{r(args[0])},{r(args[1])}]" - if args and len(args) == 2 - else "dict", - Dict: lambda args, r: f"dict[{r(args[0])},{r(args[1])}]" - if args and len(args) == 2 - else "dict", + Optional: lambda args, r: ( + r(args[0]) if args and args[0] is not type(None) else "none" + ), + list: lambda args, r: ( + f"list[{','.join([r(arg) for arg in args])}]" if args else "list" + ), + list: lambda args, r: ( + f"list[{','.join([r(arg) for arg in args])}]" if args else "list" + ), + dict: lambda args, r: ( + 
f"dict[{r(args[0])},{r(args[1])}]" if args and len(args) == 2 else "dict" + ), + dict: lambda args, r: ( + f"dict[{r(args[0])},{r(args[1])}]" if args and len(args) == 2 else "dict" + ), TypingUnion: _process_union_types, } @@ -117,7 +117,7 @@ def get_python_type_str_from_pydantic_annotation(annotation: Any) -> str: # --- Handlers for complex and generic types --- -def _handle_list_type(python_type_lower: str) -> Optional[str]: +def _handle_list_type(python_type_lower: str) -> str | None: """Handles list[...] and array[...] type mappings.""" if python_type_lower.startswith("list[") and python_type_lower.endswith("]"): inner_type_str = python_type_lower[5:-1] @@ -126,7 +126,7 @@ def _handle_list_type(python_type_lower: str) -> Optional[str]: return None -def _handle_dict_type(python_type_lower: str) -> Optional[str]: +def _handle_dict_type(python_type_lower: str) -> str | None: """Handles dict[...] and object[...] type mappings.""" if python_type_lower.startswith("dict[") and python_type_lower.endswith("]"): inner_types_str = python_type_lower[5:-1] @@ -140,7 +140,7 @@ def _handle_dict_type(python_type_lower: str) -> Optional[str]: return None -def _handle_union_type(python_type_lower: str) -> Optional[str]: +def _handle_union_type(python_type_lower: str) -> str | None: """Handles union[...] 
type mappings.""" if python_type_lower.startswith("union[") and python_type_lower.endswith("]"): inner_types_str = python_type_lower[6:-1] @@ -160,7 +160,7 @@ def _handle_union_type(python_type_lower: str) -> Optional[str]: def _handle_generic_or_unknown_type( python_type_lower: str, sql_type_lower: str -) -> Optional[str]: +) -> str | None: """Handles ambiguous types like plain 'list' or 'dict' and unknown types.""" if python_type_lower == "list": if "text" in sql_type_lower: # Let the SQL keyword mapping handle this case diff --git a/tests/core/batch_pipeline/test_batch_counting.py b/tests/core/batch_pipeline/test_batch_counting.py index b3603d5..c85d62c 100644 --- a/tests/core/batch_pipeline/test_batch_counting.py +++ b/tests/core/batch_pipeline/test_batch_counting.py @@ -1,10 +1,12 @@ import unittest +import json from unittest.mock import MagicMock, patch, AsyncMock -from extrai.core.batch_pipeline import BatchPipeline, BatchJobStatus +from extrai.core.batch.batch_pipeline import BatchPipeline, BatchJobStatus from extrai.core.model_registry import ModelRegistry from extrai.core.extraction_config import ExtractionConfig -from extrai.core.base_llm_client import BaseLLMClient +from extrai.core.base_llm_client import BaseLLMClient, ProviderBatchStatus from extrai.core.batch_models import BatchJobContext +from extrai.core.config.batch_job_config import BatchJobConfig class TestBatchPipelineCounting(unittest.IsolatedAsyncioTestCase): @@ -33,14 +35,16 @@ def setUp(self): self.mock_logger = MagicMock() with ( - patch("extrai.core.batch_pipeline.ClientRotator") as MockClientRotator, patch( - "extrai.core.batch_pipeline.ExtractionContextPreparer" + "extrai.core.batch.batch_pipeline.ClientRotator" + ) as MockClientRotator, + patch( + "extrai.core.batch.batch_pipeline.ExtractionContextPreparer" ) as MockContextPreparer, - patch("extrai.core.batch_pipeline.PromptBuilder") as MockBuilder, - patch("extrai.core.batch_pipeline.EntityCounter") as MockCounter, - 
patch("extrai.core.batch_pipeline.JSONConsensus") as MockConsensus, - patch("extrai.core.batch_pipeline.ModelWrapperBuilder"), + patch("extrai.core.batch.batch_pipeline.PromptBuilder") as MockBuilder, + patch("extrai.core.batch.batch_pipeline.EntityCounter") as MockCounter, + patch("extrai.core.batch.batch_pipeline.ConsensusRunner") as MockConsensus, + patch("extrai.core.batch.batch_pipeline.ModelWrapperBuilder"), ): self.pipeline = BatchPipeline( self.mock_model_registry, @@ -53,7 +57,24 @@ def setUp(self): self.pipeline.context_preparer = MockContextPreparer.return_value self.pipeline.prompt_builder = MockBuilder.return_value self.pipeline.entity_counter = MockCounter.return_value - self.pipeline.consensus = MockConsensus.return_value + self.pipeline.consensus_runner = MockConsensus.return_value + + # Since we are testing process_batch logic which is delegated to processor, + # we need to ensure the processor uses our mocked components/config. + # The processor was instantiated with these mocks in __init__, so it should be fine. + # However, test_process_batch_counting_transition mocks `self.pipeline.entity_counter.llm_client`. + # Since processor holds a reference to entity_counter, and we updated it on pipeline, + # check if processor refers to the same object. + + # pipeline.entity_counter = MockCounter.return_value sets the attribute on pipeline instance. + # But processor.entity_counter was set during init. + # We need to update processor's reference too. 
+ self.pipeline.processor.entity_counter = self.pipeline.entity_counter + self.pipeline.processor.client_rotator = self.pipeline.client_rotator + self.pipeline.status_checker.entity_counter = self.pipeline.entity_counter + self.pipeline.status_checker.client_rotator = self.pipeline.client_rotator + self.pipeline.submitter.entity_counter = self.pipeline.entity_counter + self.pipeline.submitter.client_rotator = self.pipeline.client_rotator async def test_submit_batch_counting(self): # Setup mocks @@ -85,7 +106,7 @@ async def test_submit_batch_counting(self): self.assertEqual(added_context.status, BatchJobStatus.COUNTING_SUBMITTED) config = added_context.config - self.assertTrue(config["count_entities"]) + self.assertTrue(config.count_entities) async def test_process_batch_counting_transition(self): # Mock Context @@ -94,29 +115,25 @@ async def test_process_batch_counting_transition(self): current_batch_id="counting_batch_id", status=BatchJobStatus.COUNTING_SUBMITTED, input_strings=["doc"], - config={"count_entities": True, "custom_extraction_process": "proc"}, + config=BatchJobConfig( + count_entities=True, custom_extraction_process="proc" + ), ) self.mock_session.get.return_value = context # Mock get_status to return COUNTING_READY_TO_PROCESS - mock_provider_job = MagicMock() - mock_provider_job.status = "completed" - - # Correctly mock entity_counter.llm_client for counting status check and results retrieval - self.pipeline.entity_counter.llm_client.retrieve_batch_job = AsyncMock( - return_value=mock_provider_job + # We need to mock get_batch_status to return ProviderBatchStatus.COMPLETED + # The actual status checker will call get_batch_status + self.pipeline.entity_counter.llm_client.get_batch_status = AsyncMock( + return_value=ProviderBatchStatus.COMPLETED ) # Mock counting results - counting_results_file = '{"id": "line1"}' + # The processor expects results in the format {"ModelName": ["desc1", "desc2"]} + counting_results_file = json.dumps({"RootModel": 
["desc1"]}) self.pipeline.entity_counter.llm_client.retrieve_batch_results = AsyncMock( - return_value=counting_results_file + return_value=[counting_results_file] ) - self.pipeline.entity_counter.llm_client.extract_content_from_batch_response.return_value = '{"RootModel": ["desc1"]}' - - self.pipeline.entity_counter.validate_counts.return_value = { - "RootModel": ["desc1"] - } # Mock extraction batch submission # Extraction phase uses client_rotator client @@ -130,12 +147,21 @@ async def test_process_batch_counting_transition(self): # Ensure build_prompts returns expected tuple self.pipeline.prompt_builder.build_prompts.return_value = ("sys", "user") + # Mock counting_consensus + self.pipeline.entity_counter.counting_consensus.achieve_consensus = AsyncMock( + return_value=[{"model": "RootModel", "description": "desc1"}] + ) + # Test process + # We don't need to patch process_and_validate_llm_output because BatchProcessor + # parses the JSON manually in _process_counting_completion result = await self.pipeline.process_batch("root_1", self.mock_session) + print("RESULT", result) + print("RESULT MESSAGE", result.message) # Verify transition - self.assertEqual(result.status, BatchJobStatus.PROCESSING) - self.assertEqual(result.message, "Transitioned from counting to extraction") + self.assertEqual(result.status, BatchJobStatus.SUBMITTED) + self.assertEqual(result.message, "Counting complete, extraction submitted.") # Verify context updated self.assertEqual(context.status, BatchJobStatus.SUBMITTED) @@ -143,8 +169,11 @@ async def test_process_batch_counting_transition(self): # Verify config updated with descriptions config = context.config - self.assertIn("expected_entity_descriptions", config) - self.assertEqual(config["expected_entity_descriptions"], ["[RootModel] desc1"]) + self.assertIsNotNone(config.expected_entity_descriptions) + self.assertEqual( + config.expected_entity_descriptions, + [{"model": "RootModel", "description": "desc1"}], + ) if __name__ == 
"__main__": diff --git a/tests/core/batch_pipeline/test_batch_pipeline.py b/tests/core/batch_pipeline/test_batch_pipeline.py index 6f25376..df22b1a 100644 --- a/tests/core/batch_pipeline/test_batch_pipeline.py +++ b/tests/core/batch_pipeline/test_batch_pipeline.py @@ -1,10 +1,10 @@ import unittest from unittest.mock import MagicMock, patch, AsyncMock -from extrai.core.batch_pipeline import BatchPipeline, BatchJobStatus +from extrai.core.batch.batch_pipeline import BatchPipeline from extrai.core.model_registry import ModelRegistry from extrai.core.extraction_config import ExtractionConfig from extrai.core.base_llm_client import BaseLLMClient -from extrai.core.batch_models import BatchJobContext +from extrai.core.batch_models import BatchJobContext, BatchJobStatus class TestBatchPipeline(unittest.IsolatedAsyncioTestCase): @@ -32,14 +32,18 @@ def setUp(self): self.mock_logger = MagicMock() with ( - patch("extrai.core.batch_pipeline.ClientRotator") as MockClientRotator, + patch("extrai.core.batch.batch_pipeline.ClientRotator") as MockClientRotator, patch( - "extrai.core.batch_pipeline.ExtractionContextPreparer" + "extrai.core.batch.batch_pipeline.ExtractionContextPreparer" ) as MockContextPreparer, - patch("extrai.core.batch_pipeline.PromptBuilder") as MockBuilder, - patch("extrai.core.batch_pipeline.EntityCounter") as MockCounter, - patch("extrai.core.batch_pipeline.JSONConsensus") as MockConsensus, - patch("extrai.core.batch_pipeline.ModelWrapperBuilder"), + patch("extrai.core.batch.batch_pipeline.PromptBuilder") as MockBuilder, + patch("extrai.core.batch.batch_pipeline.EntityCounter") as MockCounter, + patch("extrai.core.batch.batch_pipeline.ConsensusRunner") as MockConsensus, + patch("extrai.core.batch.batch_pipeline.ModelWrapperBuilder"), + patch("extrai.core.batch.batch_pipeline.BatchSubmitter") as MockSubmitter, + patch("extrai.core.batch.batch_pipeline.BatchStatusChecker"), + patch("extrai.core.batch.batch_pipeline.BatchResultRetriever"), + 
patch("extrai.core.batch.batch_pipeline.BatchProcessor"), ): self.pipeline = BatchPipeline( self.mock_model_registry, @@ -48,14 +52,25 @@ def setUp(self): self.mock_analytics, self.mock_logger, ) - # We need to access the instances created inside, so we'll mock them on the pipeline instance self.pipeline.client_rotator = MockClientRotator.return_value self.pipeline.context_preparer = MockContextPreparer.return_value self.pipeline.prompt_builder = MockBuilder.return_value self.pipeline.entity_counter = MockCounter.return_value - self.pipeline.consensus = MockConsensus.return_value - + self.pipeline.consensus_runner = MockConsensus.return_value + async def test_submit_batch_success(self): + # We instantiate a real submitter to test the logic + from extrai.core.batch.batch_submitter import BatchSubmitter + self.pipeline.submitter = BatchSubmitter( + self.mock_model_registry, + self.pipeline.client_rotator, + self.mock_config, + self.pipeline.entity_counter, + self.pipeline.context_preparer, + self.pipeline.request_factory, + self.mock_logger + ) + self.pipeline.prompt_builder.build_prompts.return_value = ("sys", "user") mock_batch_job = MagicMock() @@ -65,7 +80,6 @@ async def test_submit_batch_success(self): mock_client_instance.create_batch_job = AsyncMock(return_value=mock_batch_job) self.pipeline.context_preparer.prepare_example = AsyncMock(return_value="") - self.pipeline._count_if_needed = AsyncMock(return_value=None) root_id = await self.pipeline.submit_batch(self.mock_session, ["doc"]) @@ -79,28 +93,54 @@ async def test_submit_batch_success(self): self.assertEqual(added_context.status, BatchJobStatus.SUBMITTED) async def test_retrieve_and_validate_results(self): + # Use real retriever + from extrai.core.batch.batch_result_retriever import BatchResultRetriever + self.pipeline.retriever = BatchResultRetriever(self.mock_model_registry, self.mock_logger) + mock_context = BatchJobContext(current_batch_id="prov_1") mock_client = 
self.pipeline.client_rotator.get_next_client.return_value mock_client.retrieve_batch_results = AsyncMock( - return_value='{"key": "value"}\n{"key": "value2"}' + return_value=["res1", "res2"] ) - mock_client.extract_content_from_batch_response.side_effect = [ - '{"_type": "RootModel", "id": 1}', - '{"_type": "RootModel", "id": 2}', - ] with patch( - "extrai.core.batch_pipeline.process_and_validate_llm_output" + "extrai.core.batch.batch_result_retriever.process_and_validate_llm_output" ) as mock_validate: mock_validate.side_effect = [[{"id": 1}], [{"id": 2}]] - results = await self.pipeline._retrieve_and_validate_results(mock_context) - - # normalize_json_revisions wraps each revision in a list. + results, _ = await self.pipeline.retriever.retrieve_and_validate_results( + mock_context, mock_client + ) self.assertEqual(len(results), 2) - self.assertEqual(results[0], [{"id": 1}]) - self.assertEqual(results[1], [{"id": 2}]) + async def test_monitor_batch_job_success(self): + root_batch_id = "root_123" + db_session = MagicMock() + from extrai.core.batch_models import BatchProcessResult + + # 1. 
READY_TO_PROCESS + self.pipeline.status_checker.get_status = AsyncMock( + return_value=BatchJobStatus.READY_TO_PROCESS + ) + + process_result = BatchProcessResult( + status=BatchJobStatus.COMPLETED, hydrated_objects=["obj1"] + ) + + # Mock processor.process_batch + self.pipeline.processor.process_batch = AsyncMock(return_value=process_result) + + # Ensure processor has result_processor + self.pipeline.processor.result_processor = MagicMock() + + result = await self.pipeline.monitor_batch_job( + root_batch_id, db_session, poll_interval=0.001 + ) + + self.assertEqual(result.status, BatchJobStatus.COMPLETED) + self.pipeline.processor.result_processor.persist.assert_called_once_with( + ["obj1"], db_session + ) if __name__ == "__main__": unittest.main() diff --git a/tests/core/batch_pipeline/test_batch_pipeline_structured.py b/tests/core/batch_pipeline/test_batch_pipeline_structured.py index c75c154..56dcefe 100644 --- a/tests/core/batch_pipeline/test_batch_pipeline_structured.py +++ b/tests/core/batch_pipeline/test_batch_pipeline_structured.py @@ -3,11 +3,12 @@ import json from sqlmodel import SQLModel, Field -from extrai.core.batch_pipeline import BatchPipeline +from extrai.core.batch.batch_pipeline import BatchPipeline from extrai.core.model_registry import ModelRegistry from extrai.core.extraction_config import ExtractionConfig from extrai.core.base_llm_client import BaseLLMClient from extrai.core.batch_models import BatchJobContext +from extrai.core.config.batch_job_config import BatchJobConfig class Recipe(SQLModel): @@ -41,14 +42,14 @@ def setUp(self): self.mock_logger = MagicMock() with ( - patch("extrai.core.batch_pipeline.ClientRotator") as MockClientRotator, + patch("extrai.core.batch.batch_pipeline.ClientRotator") as MockClientRotator, patch( - "extrai.core.batch_pipeline.ExtractionContextPreparer" + "extrai.core.batch.batch_pipeline.ExtractionContextPreparer" ) as MockContextPreparer, - patch("extrai.core.batch_pipeline.PromptBuilder") as MockBuilder, - 
patch("extrai.core.batch_pipeline.EntityCounter") as MockCounter, - patch("extrai.core.batch_pipeline.JSONConsensus") as MockConsensus, - patch("extrai.core.batch_pipeline.ModelWrapperBuilder"), + patch("extrai.core.batch.batch_pipeline.PromptBuilder") as MockBuilder, + patch("extrai.core.batch.batch_pipeline.EntityCounter") as MockCounter, + patch("extrai.core.batch.batch_pipeline.ConsensusRunner") as MockConsensus, + patch("extrai.core.batch.batch_pipeline.ModelWrapperBuilder"), ): self.pipeline = BatchPipeline( self.mock_model_registry, @@ -61,7 +62,7 @@ def setUp(self): self.pipeline.context_preparer = MockContextPreparer.return_value self.pipeline.prompt_builder = MockBuilder.return_value self.pipeline.entity_counter = MockCounter.return_value - self.pipeline.consensus = MockConsensus.return_value + self.pipeline.consensus_runner = MockConsensus.return_value async def test_retrieve_and_validate_results_missing_type(self): """ @@ -70,10 +71,9 @@ async def test_retrieve_and_validate_results_missing_type(self): """ mock_context = BatchJobContext( current_batch_id="prov_1", - config={ - "use_structured_output": True, - "schema_json": {}, - }, + config=BatchJobConfig( + schema_json="{}", + ), ) mock_client = self.pipeline.client_rotator.get_next_client.return_value @@ -86,48 +86,21 @@ async def test_retrieve_and_validate_results_missing_type(self): } mock_client.retrieve_batch_results = AsyncMock( - return_value=json.dumps(structured_response) - ) - - # We need to simulate how extract_content_from_batch_response behaves. - # Assuming it returns the inner JSON string or dict. - # In the original code it calls `process_and_validate_llm_output`. 
- - # For this test, we mock extract_content_from_batch_response to return the JSON string of entities wrapper - # The real client implementation varies, but let's assume it returns the raw JSON string - mock_client.extract_content_from_batch_response.return_value = json.dumps( - structured_response + return_value=[json.dumps(structured_response)] ) + mock_client.extract_content_from_batch_response.side_effect = NotImplementedError # We expect this to fail because we haven't fixed the code yet, # and process_and_validate_llm_output will look for _type. - # NOTE: process_and_validate_llm_output is imported in batch_pipeline. - # We shouldn't patch it if we want to test the failure integration, - # but process_and_validate_llm_output raises LLMOutputValidationError. - # BatchPipeline catches Exception and logs it, returning empty list if validation fails. - - # However, looking at _retrieve_and_validate_results: - # It logs warning on validation failure. - - # To assert failure, we can check that the returned list is empty - # OR we can mock process_and_validate_llm_output to see what it was called with - # OR we can let it run and see if it returns valid objects. - - # Since we want to prove it fails validation, we should let the real process_and_validate_llm_output run. - # But `process_and_validate_llm_output` requires `Recipe` (SQLModel) to be in `model_schema_map`. - # We set that up in setUp. - - results = await self.pipeline._retrieve_and_validate_results(mock_context) + # NOTE: process_and_validate_llm_output is imported in batch_result_retriever. + # We allow the real implementation to run. 
+ + # Access retrieval logic via pipeline.retriever + results, validation_errors = await self.pipeline.retriever.retrieve_and_validate_results( + mock_context, mock_client + ) - # With the fix, we expect results to be validated and returned - # Since _type is injected, it should be present in the result + # With default_model_type provided by BatchResultRetriever, missing _type should be handled self.assertEqual(len(results), 1) - self.assertEqual(len(results[0]), 1) - item = results[0][0] - self.assertEqual(item["name"], "Pancake") - self.assertEqual(item["_type"], "Recipe") - - -if __name__ == "__main__": - unittest.main() + self.assertEqual(len(validation_errors), 0) diff --git a/tests/core/batch_pipeline/test_context_passing.py b/tests/core/batch_pipeline/test_context_passing.py new file mode 100644 index 0000000..b59d615 --- /dev/null +++ b/tests/core/batch_pipeline/test_context_passing.py @@ -0,0 +1,156 @@ +import unittest +from unittest.mock import MagicMock, patch, AsyncMock +from sqlmodel import SQLModel, Session, select +from datetime import datetime, timezone +import json + +from extrai.core.batch.batch_pipeline import BatchPipeline +from extrai.core.batch_models import BatchJobContext, BatchJobStatus, BatchJobStep +from extrai.core.model_registry import ModelRegistry +from extrai.core.extraction_config import ExtractionConfig +from extrai.core.base_llm_client import BaseLLMClient +from extrai.core.config.batch_job_config import BatchJobConfig + +class TestBatchPipelineContext(unittest.IsolatedAsyncioTestCase): + def setUp(self): + # Mocks + self.mock_registry = MagicMock(spec=ModelRegistry) + models = [MagicMock(), MagicMock()] + models[0].__name__ = "ModelA" + models[1].__name__ = "ModelB" + self.mock_registry.models = models + self.mock_registry.get_all_model_names.return_value = ["ModelA", "ModelB"] + self.mock_registry.llm_schema_json = "{}" + + self.mock_client = MagicMock(spec=BaseLLMClient) + self.mock_client.create_batch_job = 
AsyncMock(return_value=MagicMock(id="job_id")) + self.mock_client.temperature = 0.0 # Required for _create_batch_requests + self.mock_client.model_name = "gpt-4o" # Optional but good for completeness + + self.mock_config = MagicMock(spec=ExtractionConfig) + self.mock_config.use_hierarchical_extraction = True + self.mock_config.num_llm_revisions = 1 + + self.mock_analytics = MagicMock() + self.mock_logger = MagicMock() + + # Patch dependencies + self.patchers = { + "ClientRotator": patch("extrai.core.batch.batch_pipeline.ClientRotator"), + "ExtractionContextPreparer": patch("extrai.core.batch.batch_pipeline.ExtractionContextPreparer"), + "PromptBuilder": patch("extrai.core.batch.batch_pipeline.PromptBuilder"), + "EntityCounter": patch("extrai.core.batch.batch_pipeline.EntityCounter"), + "ConsensusRunner": patch("extrai.core.batch.batch_pipeline.ConsensusRunner"), + "ExtractionRequestFactory": patch("extrai.core.batch.batch_pipeline.ExtractionRequestFactory"), + } + + self.mock_deps = {} + for name, p in self.patchers.items(): + self.mock_deps[name] = p.start() + + self.pipeline = BatchPipeline( + self.mock_registry, + self.mock_client, + self.mock_config, + self.mock_analytics, + self.mock_logger + ) + + # Setup ClientRotator mock to return our mock client + self.pipeline.client_rotator.get_next_client.return_value = self.mock_client + self.pipeline.client_rotator.current_client = self.mock_client + + # Access mock instances + self.pipeline.entity_counter = MagicMock() + self.pipeline.entity_counter.llm_client = self.mock_client + self.pipeline.entity_counter.prepare_counting_prompts.return_value = ("sys", "user") + + self.pipeline.request_factory = MagicMock() + self.pipeline.request_factory.prepare_request.return_value = MagicMock( + system_prompt="sys", user_prompt="user", json_schema={}, response_model=None + ) + + # Update submitter references + self.pipeline.submitter.entity_counter = self.pipeline.entity_counter + self.pipeline.submitter.request_factory = 
self.pipeline.request_factory + self.pipeline.submitter.client_rotator = self.pipeline.client_rotator + + def tearDown(self): + for p in self.patchers.values(): + p.stop() + + async def test_submit_counting_phase_passes_context(self): + # Setup context and DB session + context = BatchJobContext( + root_batch_id="batch_123", + input_strings=["doc"], + config=BatchJobConfig( + hierarchical=True, + current_model_index=1, + custom_counting_context="ctx", + ) + ) + + # Mock Session + mock_session = MagicMock(spec=Session) + + # Mock Previous Step Result + prev_result = [{"id": 1, "_type": "ModelA"}] + mock_step = BatchJobStep( + batch_id="batch_123", step_index=0, status=BatchJobStatus.COMPLETED, result=prev_result + ) + + # Mock DB query execution + # The code uses db_session.exec(select(...)).all() + mock_exec = MagicMock() + mock_exec.all.return_value = [mock_step] + mock_session.exec.return_value = mock_exec + + # Call method under test - use submitter + await self.pipeline.submitter._submit_counting_phase(context, mock_session, step_index=1) + + # Verification + # 1. Verify DB query was made to fetch previous steps + self.assertTrue(mock_session.exec.called) + + # 2. 
Verify prepare_counting_prompts called with previous_entities + self.pipeline.entity_counter.prepare_counting_prompts.assert_called_once() + call_args = self.pipeline.entity_counter.prepare_counting_prompts.call_args + + # Args: input_strings, model_names, custom_context, previous_entities + self.assertEqual(call_args[0][0], ["doc"]) # input_strings + self.assertEqual(call_args[0][1], ["ModelB"]) # model_names (index 1) + self.assertEqual(call_args[0][2], "ctx") # custom_context + + # Keyword arg 'previous_entities' check + kwargs = call_args[1] + self.assertIn("previous_entities", kwargs) + self.assertEqual(kwargs["previous_entities"], prev_result) + + async def test_submit_extraction_phase_passes_context(self): + # Setup context + context = BatchJobContext( + root_batch_id="batch_123", + input_strings=["doc"], + config=BatchJobConfig( + hierarchical=True, + current_model_index=1, + ) + ) + + mock_session = MagicMock(spec=Session) + + prev_result = [{"id": 1, "_type": "ModelA"}] + mock_step = BatchJobStep( + batch_id="batch_123", step_index=0, status=BatchJobStatus.COMPLETED, result=prev_result + ) + + mock_exec = MagicMock() + mock_exec.all.return_value = [mock_step] + mock_session.exec.return_value = mock_exec + + # Call method - use submitter + await self.pipeline.submitter._submit_extraction_phase(context, mock_session, step_index=1) + + # Verify request factory call + self.pipeline.request_factory.prepare_request.assert_called_once() diff --git a/tests/core/entity_counter/test_entity_counter.py b/tests/core/entity_counter/test_entity_counter.py index a5feed8..2cfa164 100644 --- a/tests/core/entity_counter/test_entity_counter.py +++ b/tests/core/entity_counter/test_entity_counter.py @@ -1,6 +1,6 @@ import unittest from unittest.mock import MagicMock, patch, AsyncMock -from extrai.core.entity_counter import EntityCounter +from extrai.core.entity_counter import EntityCounter, EntityCountResult from extrai.core.model_registry import ModelRegistry from 
extrai.core.extraction_config import ExtractionConfig from extrai.core.base_llm_client import BaseLLMClient @@ -12,6 +12,8 @@ def setUp(self): self.mock_client = MagicMock(spec=BaseLLMClient) self.mock_config = MagicMock(spec=ExtractionConfig) self.mock_config.max_validation_retries_per_revision = 1 + self.mock_config.num_counting_revisions = 1 + self.mock_config.use_structured_output = False self.mock_analytics = MagicMock() self.mock_logger = MagicMock() @@ -22,35 +24,39 @@ def setUp(self): self.mock_analytics, self.mock_logger, ) + self.counter.counting_consensus.achieve_consensus = AsyncMock() @patch("extrai.core.entity_counter.generate_entity_counting_system_prompt") @patch("extrai.core.entity_counter.generate_entity_counting_user_prompt") - @patch("extrai.core.entity_counter.create_model") - async def test_count_entities_success( - self, mock_create_model, mock_user_prompt, mock_system_prompt - ): + async def test_count_entities_success(self, mock_user_prompt, mock_system_prompt): # Setup mocks self.mock_model_registry.get_schema_for_models.return_value = ( '{"type": "object"}' ) + mock_result = [ + { + "counted_entities": [ + {"model": "ModelA", "temp_id": "1", "description": "desc"} + ] + } + ] self.mock_client.generate_and_validate_raw_json_output = AsyncMock( - return_value={"ModelA": 5} + return_value=mock_result ) - mock_model_instance = MagicMock() - mock_model_instance.model_dump.return_value = {"ModelA": 5} - - # Mock the dynamically created Pydantic model - MockPydanticModel = MagicMock() - MockPydanticModel.return_value = mock_model_instance - mock_create_model.return_value = MockPydanticModel + expected_consensus = [ + {"model": "ModelA", "temp_id": "1", "description": "desc"} + ] + self.counter.counting_consensus.achieve_consensus.return_value = ( + expected_consensus + ) counts = await self.counter.count_entities(["doc"], ["ModelA"]) - self.assertEqual(counts, {"ModelA": 5}) + self.assertEqual(counts, expected_consensus) 
self.mock_model_registry.get_schema_for_models.assert_called_with(["ModelA"]) self.mock_client.generate_and_validate_raw_json_output.assert_called_once() - mock_create_model.assert_called_once() + self.counter.counting_consensus.achieve_consensus.assert_called_once() async def test_count_entities_llm_failure(self): self.mock_model_registry.get_schema_for_models.return_value = "{}" @@ -58,22 +64,20 @@ async def test_count_entities_llm_failure(self): side_effect=Exception("LLM Fail") ) - with patch("extrai.core.entity_counter.create_model"): - counts = await self.counter.count_entities(["doc"], ["ModelA"]) + counts = await self.counter.count_entities(["doc"], ["ModelA"]) - self.assertEqual(counts, {}) + self.assertEqual(counts, []) self.mock_logger.error.assert_called_once() async def test_count_entities_invalid_output(self): self.mock_model_registry.get_schema_for_models.return_value = "{}" self.mock_client.generate_and_validate_raw_json_output = AsyncMock( - return_value="Not a dict" + return_value="Not a dict or list" ) - with patch("extrai.core.entity_counter.create_model"): - counts = await self.counter.count_entities(["doc"], ["ModelA"]) + counts = await self.counter.count_entities(["doc"], ["ModelA"]) - self.assertEqual(counts, {}) + self.assertEqual(counts, []) self.mock_logger.warning.assert_called_once() diff --git a/tests/core/extraction_pipeline/test_extraction_pipeline.py b/tests/core/extraction_pipeline/test_extraction_pipeline.py index 9a5c0a3..b4d8f90 100644 --- a/tests/core/extraction_pipeline/test_extraction_pipeline.py +++ b/tests/core/extraction_pipeline/test_extraction_pipeline.py @@ -132,7 +132,7 @@ async def test_count_entities_failure(self): await self.pipeline.extract(["doc"], count_entities=True) self.mock_logger.warning.assert_called_with( - "Entity counting failed: Count failed" + "Entity counting failed or returned None, proceeding with extraction without descriptions" ) def test_repr(self): diff --git 
a/tests/core/llm_runner/test_llm_runner.py b/tests/core/llm_runner/test_llm_runner.py index 0d6083b..8ef7329 100644 --- a/tests/core/llm_runner/test_llm_runner.py +++ b/tests/core/llm_runner/test_llm_runner.py @@ -46,17 +46,14 @@ def test_client_rotation(self): self.assertEqual(c2, self.mock_client2) self.assertEqual(c3, self.mock_client1) - @patch("extrai.core.llm_runner.normalize_json_revisions") - async def test_run_extraction_cycle_success(self, mock_normalize): + async def test_run_extraction_cycle_success(self): # Setup mocks self.mock_client1.generate_json_revisions = AsyncMock(return_value=[{"id": 1}]) self.mock_client2.generate_json_revisions = AsyncMock(return_value=[{"id": 1}]) - mock_normalize.return_value = [{"id": 1}, {"id": 1}] - # Mock consensus - with patch.object(self.runner, "consensus") as mock_consensus: - mock_consensus.get_consensus.return_value = ([{"id": 1}], {}) + with patch.object(self.runner, "consensus_runner") as mock_consensus_runner: + mock_consensus_runner.run.return_value = [{"id": 1}] results = await self.runner.run_extraction_cycle("sys", "user") @@ -66,8 +63,7 @@ async def test_run_extraction_cycle_success(self, mock_normalize): # Verify calls self.assertEqual(self.mock_client1.generate_json_revisions.call_count, 1) self.assertEqual(self.mock_client2.generate_json_revisions.call_count, 1) - mock_normalize.assert_called_once() - mock_consensus.get_consensus.assert_called_once() + mock_consensus_runner.run.assert_called_once() async def test_run_extraction_cycle_llm_failure(self): self.mock_client1.generate_json_revisions = AsyncMock( @@ -81,23 +77,6 @@ async def test_run_extraction_cycle_llm_failure(self): with self.assertRaises(LLMInteractionError): await self.runner.run_extraction_cycle("sys", "user") - def test_process_consensus_output(self): - # List - res = self.runner._process_consensus_output([{"a": 1}]) - self.assertEqual(res, [{"a": 1}]) - - # None - res = self.runner._process_consensus_output(None) - 
self.assertEqual(res, []) - - # Dict - res = self.runner._process_consensus_output({"a": 1}) - self.assertEqual(res, [{"a": 1}]) - - # Dict with results - res = self.runner._process_consensus_output({"results": [{"a": 1}]}) - self.assertEqual(res, [{"a": 1}]) - def test_get_client_count(self): self.assertEqual(self.runner.get_client_count(), 2) diff --git a/tests/core/test_base_llm_client.py b/tests/core/test_base_llm_client.py index 60a3ca0..4b65ad5 100644 --- a/tests/core/test_base_llm_client.py +++ b/tests/core/test_base_llm_client.py @@ -1,4 +1,5 @@ import pytest +from typing import Any, Optional, Type from unittest.mock import AsyncMock, patch, call, Mock # Added Mock from sqlmodel import SQLModel @@ -7,13 +8,14 @@ WorkflowAnalyticsCollector, ) # Added for analytics -from extrai.core.base_llm_client import BaseLLMClient +from extrai.core.base_llm_client import BaseLLMClient, ResponseMode from extrai.core.errors import ( LLMOutputParseError, LLMOutputValidationError, LLMAPICallError, LLMRevisionGenerationError, ) +from extrai.llm_providers.generic_openai_client import GenericOpenAIClient # --- Test Fixtures and Mocks --- @@ -31,7 +33,15 @@ def __init__(self, api_key: str = "test_key", model_name: str = "test_model"): # We need to provide a concrete implementation for the abstract method, # even if it's going to be replaced by a mock in most tests. 
- async def _execute_llm_call(self, system_prompt: str, user_prompt: str) -> str: + async def _execute_llm_call( + self, + system_prompt: str, + user_prompt: str, + response_mode: ResponseMode = ResponseMode.TEXT, + response_model: Optional[Type[Any]] = None, + analytics_collector: Optional[WorkflowAnalyticsCollector] = None, + **kwargs: Any, + ) -> Any: # This default implementation can be overridden by the mock object return "{}" @@ -69,10 +79,10 @@ def mock_analytics_collector() -> Mock: @pytest.mark.asyncio -async def test_generate_all_revisions_orchestrator_logic(mock_client: MockLLMClient): +async def test_generate_revisions_orchestrator_logic(mock_client: MockLLMClient): """ - Tests the internal logic of the `_generate_all_revisions` orchestrator, - mocking the validation_callable to simulate different outcomes. + Tests the internal logic of the `generate_revisions` orchestrator, + mocking the validation_fn to simulate different outcomes. """ valid_output = {"status": "success"} mock_client._execute_llm_call.side_effect = [ @@ -81,9 +91,9 @@ async def test_generate_all_revisions_orchestrator_logic(mock_client: MockLLMCli '{"status": "success"}', # Attempt 3: Success ] - # A mock validation callable - validation_callable_mock = Mock() - validation_callable_mock.side_effect = [ + # A mock validation function + validation_fn_mock = Mock() + validation_fn_mock.side_effect = [ LLMOutputParseError( "Parse Error", "invalid_json" ), # Corresponds to 2nd LLM call @@ -91,24 +101,24 @@ async def test_generate_all_revisions_orchestrator_logic(mock_client: MockLLMCli ] with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: - results = await mock_client._generate_all_revisions( + results = await mock_client.generate_revisions( system_prompt="sys", user_prompt="user", num_revisions=1, - max_validation_retries_per_revision=3, - validation_callable=validation_callable_mock, + max_attempts_per_revision=3, + validation_fn=validation_fn_mock, 
analytics_collector=None, ) assert len(results) == 1 assert results[0] == valid_output assert mock_client._execute_llm_call.call_count == 3 - assert validation_callable_mock.call_count == 2 # Not called on API error attempt + assert validation_fn_mock.call_count == 2 # Not called on API error attempt # Called for "invalid_json" and '{"status": "success"}' - validation_callable_mock.assert_has_calls( + validation_fn_mock.assert_has_calls( [ - call("invalid_json", "Revision 1, Attempt 2"), - call('{"status": "success"}', "Revision 1, Attempt 3"), + call("invalid_json"), + call('{"status": "success"}'), ] ) assert ( @@ -378,7 +388,7 @@ async def test_generate_json_revisions_unexpected_error_in_processing( with ( patch( "extrai.core.base_llm_client.process_and_validate_llm_output", - side_effect=RuntimeError("Unexpected processing issue"), + side_effect=ValueError("Unexpected processing issue"), ) as mock_validate, patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, ): @@ -490,6 +500,52 @@ async def test_generate_zero_revisions(mock_client: MockLLMClient): assert mock_client._execute_llm_call.call_count == 0 +@pytest.mark.asyncio +async def test_generate_json_revisions_with_cost_calculation( + mock_analytics_collector: Mock, +): + """Tests that cost is calculated and recorded correctly.""" + # Mock the response from the OpenAI client + mock_message = Mock() + mock_message.content = '{"_type": "MockOutputModel", "name": "Test", "value": 123}' + mock_choice = Mock() + mock_choice.message = mock_message + mock_completion = Mock() + mock_completion.usage.prompt_tokens = 1000 + mock_completion.usage.completion_tokens = 2000 + mock_completion.choices = [mock_choice] + + with patch( + "extrai.llm_providers.generic_openai_client.openai.AsyncOpenAI" + ) as mock_openai, patch( + "extrai.core.base_llm_client.process_and_validate_llm_output", + return_value=[{"_type": "MockOutputModel", "name": "Test", "value": 123}], + ), patch( + 
"extrai.llm_providers.generic_openai_client.calculate_cost", return_value=0.07 + ): + mock_openai.return_value.chat.completions.create = AsyncMock( + return_value=mock_completion + ) + client = GenericOpenAIClient( + api_key="test_key", + model_name="gpt-4-turbo", + base_url="http://localhost:8080", + ) + + await client.generate_json_revisions( + system_prompt="sys", + user_prompt="user", + num_revisions=1, + model_schema_map={"MockOutputModel": MockOutputModelWithType}, + analytics_collector=mock_analytics_collector, + ) + + mock_analytics_collector.record_llm_usage.assert_called_once() + args, kwargs = mock_analytics_collector.record_llm_usage.call_args + assert "cost" in kwargs + assert kwargs["cost"] is not None + + @pytest.mark.asyncio async def test_api_call_error_sleep_multiplier(mock_client: MockLLMClient): """ @@ -571,7 +627,7 @@ def __init__( @pytest.mark.asyncio -async def test_generate_one_revision_with_zero_attempts_raises_runtime_error( +async def test_generate_single_revision_with_zero_attempts_raises_runtime_error( mock_client: MockLLMClient, ): """ @@ -579,13 +635,13 @@ async def test_generate_one_revision_with_zero_attempts_raises_runtime_error( which should only happen if max_attempts is 0. """ with pytest.raises( - RuntimeError, match="Revision generation failed without a recorded error." 
+ RuntimeError, match="Generation failed without recorded error" ): - await mock_client._generate_one_revision_with_retries( + await mock_client._generate_single_revision( system_prompt="sys", user_prompt="user", max_attempts=0, # Set max_attempts to 0 to prevent loop from running - validation_callable=Mock(), + validation_fn=Mock(), analytics_collector=None, revision_index=0, ) @@ -623,7 +679,7 @@ async def test_batch_methods_raise_not_implemented(mock_client: MockLLMClient): await mock_client.create_batch_job([]) with pytest.raises(NotImplementedError, match="Batch processing is not supported"): - await mock_client.retrieve_batch_job("id") + await mock_client.get_batch_status("id") with pytest.raises(NotImplementedError, match="Batch processing is not supported"): await mock_client.list_batch_jobs() diff --git a/tests/core/test_cost_calculator.py b/tests/core/test_cost_calculator.py new file mode 100644 index 0000000..082835f --- /dev/null +++ b/tests/core/test_cost_calculator.py @@ -0,0 +1,34 @@ +# tests/core/test_cost_calculator.py +import unittest +from unittest.mock import patch +from extrai.core.cost_calculator import calculate_cost, ModelCosts + +class TestCostCalculator(unittest.TestCase): + def setUp(self): + self.mock_costs = { + "gpt-4-turbo": ModelCosts( + input_cost_per_million=10.0, + output_cost_per_million=30.0 + ) + } + + def test_calculate_cost_known_model(self): + # Test with a known model + with patch("extrai.core.cost_calculator.MODEL_COSTS", self.mock_costs): + cost = calculate_cost("gpt-4-turbo", 1000, 2000) + self.assertAlmostEqual(cost, 0.07) + + def test_calculate_cost_unknown_model(self): + # Test with an unknown model + with patch("extrai.core.cost_calculator.MODEL_COSTS", self.mock_costs): + cost = calculate_cost("unknown-model", 1000, 2000) + self.assertIsNone(cost) + + def test_calculate_cost_zero_tokens(self): + # Test with zero tokens + with patch("extrai.core.cost_calculator.MODEL_COSTS", self.mock_costs): + cost = 
calculate_cost("gpt-4-turbo", 0, 0) + self.assertEqual(cost, 0) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_cost_tracking.py b/tests/core/test_cost_tracking.py index f1fb1cc..37c5318 100644 --- a/tests/core/test_cost_tracking.py +++ b/tests/core/test_cost_tracking.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import MagicMock, AsyncMock from extrai.core.analytics_collector import WorkflowAnalyticsCollector +from extrai.core.base_llm_client import ResponseMode from extrai.llm_providers.generic_openai_client import GenericOpenAIClient @@ -50,8 +51,12 @@ async def test_structured_cost_tracking(): client.client.beta.chat.completions.parse = AsyncMock(return_value=mock_completion) # Execute - await client.generate_structured( - "system", "user", MagicMock(), analytics_collector=collector + await client._execute_llm_call( + "system", + "user", + response_mode=ResponseMode.STRUCTURED, + response_model=MagicMock(), + analytics_collector=collector, ) # Verify diff --git a/tests/core/test_sqlmodel_generator.py b/tests/core/test_sqlmodel_generator.py index 16ebb09..575befa 100644 --- a/tests/core/test_sqlmodel_generator.py +++ b/tests/core/test_sqlmodel_generator.py @@ -134,9 +134,7 @@ def test_load_sqlmodel_description_schema_scenarios( self.generator_for_llm_tests._load_sqlmodel_description_schema() ) assert schema == json.loads(mock_open_config["read_data"]) - mock_open_instance.assert_called_once_with( - expected_fallback_path, "r" - ) + mock_open_instance.assert_called_once_with(expected_fallback_path) finally: SQLModelCodeGenerator._SCHEMA_FILE_PATH = original_schema_path diff --git a/tests/core/workflow_orchestrator/test_integration.py b/tests/core/workflow_orchestrator/test_integration.py index b1f9bf2..d9a2944 100644 --- a/tests/core/workflow_orchestrator/test_integration.py +++ b/tests/core/workflow_orchestrator/test_integration.py @@ -87,12 +87,12 @@ async def test_successful_synthesis_clear_consensus(self): 
expected_consensus_input = [revision_content] * 2 with mock.patch.object( - self.orchestrator.pipeline.llm_runner.consensus, - "get_consensus", - return_value=(mock_consensus_output, mock_analytics_for_clear_consensus), - ) as mock_get_consensus_call: + self.orchestrator.pipeline.llm_runner.consensus_runner, + "run", + return_value=mock_consensus_output, + ) as mock_run_call: await self.orchestrator.synthesize(["input"], self.db_session) - mock_get_consensus_call.assert_called_once_with(expected_consensus_input) + mock_run_call.assert_called_once_with(expected_consensus_input) self.assertEqual(self.mock_llm_client1.call_count, 1) self.assertEqual(self.mock_llm_client2.call_count, 1) diff --git a/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py b/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py index 5f7aac8..8ac8b15 100644 --- a/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py +++ b/tests/core/workflow_orchestrator/test_workflow_orchestrator_batch.py @@ -59,20 +59,22 @@ async def test_process_batch_success(self): batch_id = "batch_123" db_session = mock.Mock(spec=Session) hydrated_objects = ["obj1"] + pk_map = {"1": "new_1"} expected_result = BatchProcessResult( - status=BatchJobStatus.COMPLETED, hydrated_objects=hydrated_objects + status=BatchJobStatus.COMPLETED, + hydrated_objects=hydrated_objects, + original_pk_map=pk_map, ) self.orchestrator.batch_pipeline.process_batch.return_value = expected_result + self.orchestrator.result_processor.original_pk_map = {} result = await self.orchestrator.process_batch(batch_id, db_session) self.orchestrator.batch_pipeline.process_batch.assert_called_once_with( batch_id, db_session ) - self.orchestrator.result_processor.persist.assert_called_once_with( - hydrated_objects, db_session - ) self.assertEqual(result, expected_result) + self.assertEqual(self.orchestrator.result_processor.original_pk_map, pk_map) async def test_process_batch_not_completed(self): batch_id = 
"batch_123" @@ -88,69 +90,26 @@ async def test_process_batch_not_completed(self): self.orchestrator.result_processor.persist.assert_not_called() self.assertEqual(result, expected_result) - async def test_process_batch_persistence_failure(self): - batch_id = "batch_123" - db_session = mock.Mock(spec=Session) - hydrated_objects = ["obj1"] - process_result = BatchProcessResult( - status=BatchJobStatus.COMPLETED, hydrated_objects=hydrated_objects - ) - - self.orchestrator.batch_pipeline.process_batch.return_value = process_result - self.orchestrator.result_processor.persist.side_effect = Exception( - "Persistence Error" - ) - - with self.assertRaisesRegex(Exception, "Persistence Error"): - await self.orchestrator.process_batch(batch_id, db_session) - - self.orchestrator.logger.error.assert_called() - self.assertIn( - "Extraction successful but persistence failed", process_result.message - ) - - async def test_monitor_batch_job_counting_transition(self): + async def test_monitor_batch_job_delegation(self): batch_id = "batch_123" db_session = mock.Mock(spec=Session) + poll_interval = 5 - # Mock status sequence: - # 1. COUNTING_READY_TO_PROCESS -> triggers first process_batch - # 2. PROCESSING -> waits - # 3. READY_TO_PROCESS -> triggers second process_batch - self.orchestrator.batch_pipeline.get_status.side_effect = [ - BatchJobStatus.COUNTING_READY_TO_PROCESS, - BatchJobStatus.PROCESSING, - BatchJobStatus.READY_TO_PROCESS, - ] - - # Mock process results - # 1. Result of processing COUNTING_READY: new batch submitted (PROCESSING) - process_result_1 = BatchProcessResult( - status=BatchJobStatus.PROCESSING, - message="Transitioned from counting to extraction", - ) - # 2. 
Result of processing READY_TO_PROCESS: completed - process_result_2 = BatchProcessResult( + expected_result = BatchProcessResult( status=BatchJobStatus.COMPLETED, hydrated_objects=["obj1"] ) + self.orchestrator.batch_pipeline.monitor_batch_job.return_value = ( + expected_result + ) - self.orchestrator.batch_pipeline.process_batch.side_effect = [ - process_result_1, - process_result_2, - ] - - # Run monitoring with short poll interval result = await self.orchestrator.monitor_batch_job( - batch_id, db_session, poll_interval=0.001 + batch_id, db_session, poll_interval ) - # Verify final result - self.assertEqual(result.status, BatchJobStatus.COMPLETED) - self.assertEqual(result.hydrated_objects, ["obj1"]) - - # Verify calls - self.assertEqual(self.orchestrator.batch_pipeline.get_status.call_count, 3) - self.assertEqual(self.orchestrator.batch_pipeline.process_batch.call_count, 2) + self.orchestrator.batch_pipeline.monitor_batch_job.assert_called_once_with( + batch_id, db_session, poll_interval + ) + self.assertEqual(result, expected_result) if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index 6271399..5e2812d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,6 @@ version = 1 revision = 3 -requires-python = ">=3.13" +requires-python = ">=3.12" [[package]] name = "accessible-pygments" @@ -47,6 +47,7 @@ version = "4.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "idna" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/16/ce/8a777047513153587e5434fd752e89334ac33e379aa3497db860eeb60377/anyio-4.12.0.tar.gz", hash = "sha256:73c693b567b0c55130c104d0b43a9baf3aa6a31fc6110116509f27bf75e21ec0", size = 228266, upload-time = "2025-11-28T23:37:38.911Z" } wheels = [ @@ -116,22 +117,21 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = 
"sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, + { url = "https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, + { url = "https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, + { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, + { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, { url = 
"https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" }, { url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" }, - { url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" }, - { url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" }, { url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" }, { url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" }, { url = 
"https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" }, { url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049, upload-time = "2025-09-08T23:23:20.853Z" }, - { url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793, upload-time = "2025-09-08T23:23:22.08Z" }, - { url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300, upload-time = "2025-09-08T23:23:23.314Z" }, { url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244, upload-time = "2025-09-08T23:23:24.541Z" }, { url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828, upload-time = "2025-09-08T23:23:26.143Z" }, { url = 
"https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926, upload-time = "2025-09-08T23:23:27.873Z" }, { url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593, upload-time = "2025-09-08T23:23:31.91Z" }, - { url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354, upload-time = "2025-09-08T23:23:33.214Z" }, - { url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480, upload-time = "2025-09-08T23:23:34.495Z" }, { url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584, upload-time = "2025-09-08T23:23:36.096Z" }, { url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443, upload-time = "2025-09-08T23:23:37.328Z" }, { url = 
"https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437, upload-time = "2025-09-08T23:23:38.945Z" }, @@ -143,6 +143,22 @@ version = "3.4.4" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, + { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, + { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = 
"sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, + { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, + { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, + { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, + { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, + { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, + { url = 
"https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, + { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, + { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, + { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, + { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, + { url = 
"https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, @@ -193,6 +209,19 @@ version = "7.12.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/89/26/4a96807b193b011588099c3b5c89fbb05294e5b90e71018e065465f34eb6/coverage-7.12.0.tar.gz", hash = "sha256:fc11e0a4e372cb5f282f16ef90d4a585034050ccda536451901abfb19a57f40c", size = 819341, upload-time = "2025-11-18T13:34:20.766Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/02/bf/638c0427c0f0d47638242e2438127f3c8ee3cfc06c7fdeb16778ed47f836/coverage-7.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:29644c928772c78512b48e14156b81255000dcfd4817574ff69def189bcb3647", size = 217704, upload-time = "2025-11-18T13:32:28.906Z" }, + { url = "https://files.pythonhosted.org/packages/08/e1/706fae6692a66c2d6b871a608bbde0da6281903fa0e9f53a39ed441da36a/coverage-7.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8638cbb002eaa5d7c8d04da667813ce1067080b9a91099801a0053086e52b736", size = 218064, upload-time = "2025-11-18T13:32:30.161Z" }, + { url = "https://files.pythonhosted.org/packages/a9/8b/eb0231d0540f8af3ffda39720ff43cb91926489d01524e68f60e961366e4/coverage-7.12.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:083631eeff5eb9992c923e14b810a179798bb598e6a0dd60586819fc23be6e60", size = 249560, upload-time = "2025-11-18T13:32:31.835Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a1/67fb52af642e974d159b5b379e4d4c59d0ebe1288677fbd04bbffe665a82/coverage-7.12.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:99d5415c73ca12d558e07776bd957c4222c687b9f1d26fa0e1b57e3598bdcde8", size = 252318, upload-time = "2025-11-18T13:32:33.178Z" }, + { url = "https://files.pythonhosted.org/packages/41/e5/38228f31b2c7665ebf9bdfdddd7a184d56450755c7e43ac721c11a4b8dab/coverage-7.12.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e949ebf60c717c3df63adb4a1a366c096c8d7fd8472608cd09359e1bd48ef59f", size = 253403, upload-time = "2025-11-18T13:32:34.45Z" }, + { url = "https://files.pythonhosted.org/packages/ec/4b/df78e4c8188f9960684267c5a4897836f3f0f20a20c51606ee778a1d9749/coverage-7.12.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6d907ddccbca819afa2cd014bc69983b146cca2735a0b1e6259b2a6c10be1e70", size = 249984, upload-time = "2025-11-18T13:32:35.747Z" }, + { url = 
"https://files.pythonhosted.org/packages/ba/51/bb163933d195a345c6f63eab9e55743413d064c291b6220df754075c2769/coverage-7.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b1518ecbad4e6173f4c6e6c4a46e49555ea5679bf3feda5edb1b935c7c44e8a0", size = 251339, upload-time = "2025-11-18T13:32:37.352Z" }, + { url = "https://files.pythonhosted.org/packages/15/40/c9b29cdb8412c837cdcbc2cfa054547dd83affe6cbbd4ce4fdb92b6ba7d1/coverage-7.12.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51777647a749abdf6f6fd8c7cffab12de68ab93aab15efc72fbbb83036c2a068", size = 249489, upload-time = "2025-11-18T13:32:39.212Z" }, + { url = "https://files.pythonhosted.org/packages/c8/da/b3131e20ba07a0de4437a50ef3b47840dfabf9293675b0cd5c2c7f66dd61/coverage-7.12.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:42435d46d6461a3b305cdfcad7cdd3248787771f53fe18305548cba474e6523b", size = 249070, upload-time = "2025-11-18T13:32:40.598Z" }, + { url = "https://files.pythonhosted.org/packages/70/81/b653329b5f6302c08d683ceff6785bc60a34be9ae92a5c7b63ee7ee7acec/coverage-7.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5bcead88c8423e1855e64b8057d0544e33e4080b95b240c2a355334bb7ced937", size = 250929, upload-time = "2025-11-18T13:32:42.915Z" }, + { url = "https://files.pythonhosted.org/packages/a3/00/250ac3bca9f252a5fb1338b5ad01331ebb7b40223f72bef5b1b2cb03aa64/coverage-7.12.0-cp312-cp312-win32.whl", hash = "sha256:dcbb630ab034e86d2a0f79aefd2be07e583202f41e037602d438c80044957baa", size = 220241, upload-time = "2025-11-18T13:32:44.665Z" }, + { url = "https://files.pythonhosted.org/packages/64/1c/77e79e76d37ce83302f6c21980b45e09f8aa4551965213a10e62d71ce0ab/coverage-7.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:2fd8354ed5d69775ac42986a691fbf68b4084278710cee9d7c3eaa0c28fa982a", size = 221051, upload-time = "2025-11-18T13:32:46.008Z" }, + { url = 
"https://files.pythonhosted.org/packages/31/f5/641b8a25baae564f9e52cac0e2667b123de961985709a004e287ee7663cc/coverage-7.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:737c3814903be30695b2de20d22bcc5428fdae305c61ba44cdc8b3252984c49c", size = 219692, upload-time = "2025-11-18T13:32:47.372Z" }, { url = "https://files.pythonhosted.org/packages/b8/14/771700b4048774e48d2c54ed0c674273702713c9ee7acdfede40c2666747/coverage-7.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:47324fffca8d8eae7e185b5bb20c14645f23350f870c1649003618ea91a78941", size = 217725, upload-time = "2025-11-18T13:32:49.22Z" }, { url = "https://files.pythonhosted.org/packages/17/a7/3aa4144d3bcb719bf67b22d2d51c2d577bf801498c13cb08f64173e80497/coverage-7.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ccf3b2ede91decd2fb53ec73c1f949c3e034129d1e0b07798ff1d02ea0c8fa4a", size = 218098, upload-time = "2025-11-18T13:32:50.78Z" }, { url = "https://files.pythonhosted.org/packages/fc/9c/b846bbc774ff81091a12a10203e70562c91ae71badda00c5ae5b613527b1/coverage-7.12.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b365adc70a6936c6b0582dc38746b33b2454148c02349345412c6e743efb646d", size = 249093, upload-time = "2025-11-18T13:32:52.554Z" }, @@ -261,10 +290,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807, upload-time = "2025-10-15T23:16:56.414Z" }, { url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615, upload-time = "2025-10-15T23:16:58.442Z" }, { url = 
"https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800, upload-time = "2025-10-15T23:17:00.378Z" }, - { url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707, upload-time = "2025-10-15T23:17:01.98Z" }, { url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541, upload-time = "2025-10-15T23:17:04.078Z" }, { url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464, upload-time = "2025-10-15T23:17:05.483Z" }, - { url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838, upload-time = "2025-10-15T23:17:07.425Z" }, { url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596, upload-time = "2025-10-15T23:17:09.343Z" }, { url = 
"https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782, upload-time = "2025-10-15T23:17:11.22Z" }, { url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381, upload-time = "2025-10-15T23:17:12.829Z" }, @@ -272,10 +299,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078, upload-time = "2025-10-15T23:17:23.042Z" }, { url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460, upload-time = "2025-10-15T23:17:24.885Z" }, { url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237, upload-time = "2025-10-15T23:17:26.449Z" }, - { url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344, upload-time = "2025-10-15T23:17:28.06Z" }, { url = 
"https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564, upload-time = "2025-10-15T23:17:29.665Z" }, { url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415, upload-time = "2025-10-15T23:17:31.686Z" }, - { url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457, upload-time = "2025-10-15T23:17:33.478Z" }, { url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074, upload-time = "2025-10-15T23:17:35.158Z" }, { url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569, upload-time = "2025-10-15T23:17:37.188Z" }, { url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941, upload-time = "2025-10-15T23:17:39.236Z" }, @@ -283,10 +308,8 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029, upload-time = "2025-10-15T23:17:49.837Z" }, { url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222, upload-time = "2025-10-15T23:17:51.357Z" }, { url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280, upload-time = "2025-10-15T23:17:52.964Z" }, - { url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958, upload-time = "2025-10-15T23:17:54.965Z" }, { url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714, upload-time = "2025-10-15T23:17:56.754Z" }, { url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970, upload-time = "2025-10-15T23:17:58.588Z" }, - { url = 
"https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236, upload-time = "2025-10-15T23:18:00.897Z" }, { url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642, upload-time = "2025-10-15T23:18:02.749Z" }, { url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126, upload-time = "2025-10-15T23:18:04.85Z" }, { url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573, upload-time = "2025-10-15T23:18:06.908Z" }, @@ -315,6 +338,7 @@ name = "extrai-workflow" version = "1.0.1" source = { editable = "." 
} dependencies = [ + { name = "google-genai" }, { name = "jsonschema" }, { name = "openai" }, { name = "pydantic" }, @@ -343,6 +367,7 @@ requires-dist = [ { name = "build", marker = "extra == 'dev'" }, { name = "coverage", marker = "extra == 'dev'" }, { name = "furo", marker = "extra == 'dev'" }, + { name = "google-genai", specifier = ">=1.59.0" }, { name = "jsonschema", specifier = "==4.25.1" }, { name = "openai", specifier = "==2.9.0" }, { name = "pydantic", specifier = "==2.12.5" }, @@ -374,16 +399,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/69/964b55f389c289e16ba2a5dfe587c3c462aac09e24123f09ddf703889584/furo-2025.9.25-py3-none-any.whl", hash = "sha256:2937f68e823b8e37b410c972c371bc2b1d88026709534927158e0cb3fac95afe", size = 340409, upload-time = "2025-09-25T21:37:17.244Z" }, ] +[[package]] +name = "google-auth" +version = "2.47.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/3c/ec64b9a275ca22fa1cd3b6e77fefcf837b0732c890aa32d2bd21313d9b33/google_auth-2.47.0.tar.gz", hash = "sha256:833229070a9dfee1a353ae9877dcd2dec069a8281a4e72e72f77d4a70ff945da", size = 323719, upload-time = "2026-01-06T21:55:31.045Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl", hash = "sha256:c516d68336bfde7cf0da26aab674a36fedcf04b37ac4edd59c597178760c3498", size = 234867, upload-time = "2026-01-06T21:55:28.6Z" }, +] + +[package.optional-dependencies] +requests = [ + { name = "requests" }, +] + +[[package]] +name = "google-genai" +version = "1.59.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "google-auth", extra = ["requests"] }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "sniffio" }, + { 
name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/34/c03bcbc759d67ac3d96077838cdc1eac85417de6ea3b65b313fe53043eee/google_genai-1.59.0.tar.gz", hash = "sha256:0b7a2dc24582850ae57294209d8dfc2c4f5fcfde0a3f11d81dc5aca75fb619e2", size = 487374, upload-time = "2026-01-15T20:29:46.619Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/53/6d00692fe50d73409b3406ae90c71bc4499c8ae7fac377ba16e283da917c/google_genai-1.59.0-py3-none-any.whl", hash = "sha256:59fc01a225d074fe9d1e626c3433da292f33249dadce4deb34edea698305a6df", size = 719099, upload-time = "2026-01-15T20:29:44.604Z" }, +] + [[package]] name = "greenlet" version = "3.3.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/c7/e5/40dbda2736893e3e53d25838e0f19a2b417dfc122b9989c91918db30b5d3/greenlet-3.3.0.tar.gz", hash = "sha256:a82bb225a4e9e4d653dd2fb7b8b2d36e4fb25bc0165422a11e48b88e9e6f78fb", size = 190651, upload-time = "2025-12-04T14:49:44.05Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, + { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, + { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = 
"sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, + { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, + { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, + { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, + { url = "https://files.pythonhosted.org/packages/6c/79/3912a94cf27ec503e51ba493692d6db1e3cd8ac7ac52b0b47c8e33d7f4f9/greenlet-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:a7a34b13d43a6b78abf828a6d0e87d3385680eaf830cd60d20d52f249faabf39", size = 301964, upload-time = "2025-12-04T14:36:58.316Z" }, { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 
599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, - { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -391,7 +461,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/7c/f0a6d0ede2c7bf092d00bc83ad5bafb7e6ec9b4aab2fbdfa6f134dc73327/greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f", size = 275671, upload-time = 
"2025-12-04T14:23:05.267Z" }, { url = "https://files.pythonhosted.org/packages/44/06/dac639ae1a50f5969d82d2e3dd9767d30d6dbdbab0e1a54010c8fe90263c/greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365", size = 646360, upload-time = "2025-12-04T14:50:10.026Z" }, { url = "https://files.pythonhosted.org/packages/e0/94/0fb76fe6c5369fba9bf98529ada6f4c3a1adf19e406a47332245ef0eb357/greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3", size = 658160, upload-time = "2025-12-04T14:57:45.41Z" }, - { url = "https://files.pythonhosted.org/packages/93/79/d2c70cae6e823fac36c3bbc9077962105052b7ef81db2f01ec3b9bf17e2b/greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45", size = 671388, upload-time = "2025-12-04T15:07:15.789Z" }, { url = "https://files.pythonhosted.org/packages/b8/14/bab308fc2c1b5228c3224ec2bf928ce2e4d21d8046c161e44a2012b5203e/greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955", size = 660166, upload-time = "2025-12-04T14:26:05.099Z" }, { url = "https://files.pythonhosted.org/packages/4b/d2/91465d39164eaa0085177f61983d80ffe746c5a1860f009811d498e7259c/greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55", size = 1615193, upload-time = "2025-12-04T15:04:27.041Z" }, { url = "https://files.pythonhosted.org/packages/42/1b/83d110a37044b92423084d52d5d5a3b3a73cafb51b547e6d7366ff62eff1/greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc", size = 1683653, upload-time = "2025-12-04T14:27:32.366Z" }, @@ 
-399,7 +468,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/66/bd6317bc5932accf351fc19f177ffba53712a202f9df10587da8df257c7e/greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931", size = 282638, upload-time = "2025-12-04T14:25:20.941Z" }, { url = "https://files.pythonhosted.org/packages/30/cf/cc81cb030b40e738d6e69502ccbd0dd1bced0588e958f9e757945de24404/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388", size = 651145, upload-time = "2025-12-04T14:50:11.039Z" }, { url = "https://files.pythonhosted.org/packages/9c/ea/1020037b5ecfe95ca7df8d8549959baceb8186031da83d5ecceff8b08cd2/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3", size = 654236, upload-time = "2025-12-04T14:57:47.007Z" }, - { url = "https://files.pythonhosted.org/packages/69/cc/1e4bae2e45ca2fa55299f4e85854606a78ecc37fead20d69322f96000504/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221", size = 662506, upload-time = "2025-12-04T15:07:16.906Z" }, { url = "https://files.pythonhosted.org/packages/57/b9/f8025d71a6085c441a7eaff0fd928bbb275a6633773667023d19179fe815/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b", size = 653783, upload-time = "2025-12-04T14:26:06.225Z" }, { url = "https://files.pythonhosted.org/packages/f6/c7/876a8c7a7485d5d6b5c6821201d542ef28be645aa024cfe1145b35c120c1/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd", size = 1614857, upload-time = "2025-12-04T15:04:28.484Z" }, 
{ url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" }, @@ -541,6 +609,19 @@ version = "0.12.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/45/9d/e0660989c1370e25848bb4c52d061c71837239738ad937e83edca174c273/jiter-0.12.0.tar.gz", hash = "sha256:64dfcd7d5c168b38d3f9f8bba7fc639edb3418abcc74f22fdbe6b8938293f30b", size = 168294, upload-time = "2025-11-09T20:49:23.302Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/92/c9/5b9f7b4983f1b542c64e84165075335e8a236fa9e2ea03a0c79780062be8/jiter-0.12.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:305e061fa82f4680607a775b2e8e0bcb071cd2205ac38e6ef48c8dd5ebe1cf37", size = 314449, upload-time = "2025-11-09T20:47:22.999Z" }, + { url = "https://files.pythonhosted.org/packages/98/6e/e8efa0e78de00db0aee82c0cf9e8b3f2027efd7f8a71f859d8f4be8e98ef/jiter-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c1860627048e302a528333c9307c818c547f214d8659b0705d2195e1a94b274", size = 319855, upload-time = "2025-11-09T20:47:24.779Z" }, + { url = "https://files.pythonhosted.org/packages/20/26/894cd88e60b5d58af53bec5c6759d1292bd0b37a8b5f60f07abf7a63ae5f/jiter-0.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df37577a4f8408f7e0ec3205d2a8f87672af8f17008358063a4d6425b6081ce3", size = 350171, upload-time = "2025-11-09T20:47:26.469Z" }, + { url = "https://files.pythonhosted.org/packages/f5/27/a7b818b9979ac31b3763d25f3653ec3a954044d5e9f5d87f2f247d679fd1/jiter-0.12.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:75fdd787356c1c13a4f40b43c2156276ef7a71eb487d98472476476d803fb2cf", size = 365590, upload-time = "2025-11-09T20:47:27.918Z" }, + { url = 
"https://files.pythonhosted.org/packages/ba/7e/e46195801a97673a83746170b17984aa8ac4a455746354516d02ca5541b4/jiter-0.12.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1eb5db8d9c65b112aacf14fcd0faae9913d07a8afea5ed06ccdd12b724e966a1", size = 479462, upload-time = "2025-11-09T20:47:29.654Z" }, + { url = "https://files.pythonhosted.org/packages/ca/75/f833bfb009ab4bd11b1c9406d333e3b4357709ed0570bb48c7c06d78c7dd/jiter-0.12.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:73c568cc27c473f82480abc15d1301adf333a7ea4f2e813d6a2c7d8b6ba8d0df", size = 378983, upload-time = "2025-11-09T20:47:31.026Z" }, + { url = "https://files.pythonhosted.org/packages/71/b3/7a69d77943cc837d30165643db753471aff5df39692d598da880a6e51c24/jiter-0.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4321e8a3d868919bcb1abb1db550d41f2b5b326f72df29e53b2df8b006eb9403", size = 361328, upload-time = "2025-11-09T20:47:33.286Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ac/a78f90caf48d65ba70d8c6efc6f23150bc39dc3389d65bbec2a95c7bc628/jiter-0.12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0a51bad79f8cc9cac2b4b705039f814049142e0050f30d91695a2d9a6611f126", size = 386740, upload-time = "2025-11-09T20:47:34.703Z" }, + { url = "https://files.pythonhosted.org/packages/39/b6/5d31c2cc8e1b6a6bcf3c5721e4ca0a3633d1ab4754b09bc7084f6c4f5327/jiter-0.12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2a67b678f6a5f1dd6c36d642d7db83e456bc8b104788262aaefc11a22339f5a9", size = 520875, upload-time = "2025-11-09T20:47:36.058Z" }, + { url = "https://files.pythonhosted.org/packages/30/b5/4df540fae4e9f68c54b8dab004bd8c943a752f0b00efd6e7d64aa3850339/jiter-0.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efe1a211fe1fd14762adea941e3cfd6c611a136e28da6c39272dbb7a1bbe6a86", size = 511457, upload-time = "2025-11-09T20:47:37.932Z" }, + { url = 
"https://files.pythonhosted.org/packages/07/65/86b74010e450a1a77b2c1aabb91d4a91dd3cd5afce99f34d75fd1ac64b19/jiter-0.12.0-cp312-cp312-win32.whl", hash = "sha256:d779d97c834b4278276ec703dc3fc1735fca50af63eb7262f05bdb4e62203d44", size = 204546, upload-time = "2025-11-09T20:47:40.47Z" }, + { url = "https://files.pythonhosted.org/packages/1c/c7/6659f537f9562d963488e3e55573498a442503ced01f7e169e96a6110383/jiter-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:e8269062060212b373316fe69236096aaf4c49022d267c6736eebd66bbbc60bb", size = 205196, upload-time = "2025-11-09T20:47:41.794Z" }, + { url = "https://files.pythonhosted.org/packages/21/f4/935304f5169edadfec7f9c01eacbce4c90bb9a82035ac1de1f3bd2d40be6/jiter-0.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:06cb970936c65de926d648af0ed3d21857f026b1cf5525cb2947aa5e01e05789", size = 186100, upload-time = "2025-11-09T20:47:43.007Z" }, { url = "https://files.pythonhosted.org/packages/3d/a6/97209693b177716e22576ee1161674d1d58029eb178e01866a0422b69224/jiter-0.12.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:6cc49d5130a14b732e0612bc76ae8db3b49898732223ef8b7599aa8d9810683e", size = 313658, upload-time = "2025-11-09T20:47:44.424Z" }, { url = "https://files.pythonhosted.org/packages/06/4d/125c5c1537c7d8ee73ad3d530a442d6c619714b95027143f1b61c0b4dfe0/jiter-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:37f27a32ce36364d2fa4f7fdc507279db604d27d239ea2e044c8f148410defe1", size = 318605, upload-time = "2025-11-09T20:47:45.973Z" }, { url = "https://files.pythonhosted.org/packages/99/bf/a840b89847885064c41a5f52de6e312e91fa84a520848ee56c97e4fa0205/jiter-0.12.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbc0944aa3d4b4773e348cda635252824a78f4ba44328e042ef1ff3f6080d1cf", size = 349803, upload-time = "2025-11-09T20:47:47.535Z" }, @@ -584,6 +665,10 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/aa/51/2cb4468b3448a8385ebcd15059d325c9ce67df4e2758d133ab9442b19834/jiter-0.12.0-cp314-cp314t-win32.whl", hash = "sha256:8bbcfe2791dfdb7c5e48baf646d37a6a3dcb5a97a032017741dea9f817dca183", size = 205110, upload-time = "2025-11-09T20:48:47.033Z" }, { url = "https://files.pythonhosted.org/packages/b2/c5/ae5ec83dec9c2d1af805fd5fe8f74ebded9c8670c5210ec7820ce0dbeb1e/jiter-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:2fa940963bf02e1d8226027ef461e36af472dea85d36054ff835aeed944dd873", size = 205223, upload-time = "2025-11-09T20:48:49.076Z" }, { url = "https://files.pythonhosted.org/packages/97/9a/3c5391907277f0e55195550cf3fa8e293ae9ee0c00fb402fec1e38c0c82f/jiter-0.12.0-cp314-cp314t-win_arm64.whl", hash = "sha256:506c9708dd29b27288f9f8f1140c3cb0e3d8ddb045956d7757b1fa0e0f39a473", size = 185564, upload-time = "2025-11-09T20:48:50.376Z" }, + { url = "https://files.pythonhosted.org/packages/cb/f5/12efb8ada5f5c9edc1d4555fe383c1fb2eac05ac5859258a72d61981d999/jiter-0.12.0-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:e8547883d7b96ef2e5fe22b88f8a4c8725a56e7f4abafff20fd5272d634c7ecb", size = 309974, upload-time = "2025-11-09T20:49:17.187Z" }, + { url = "https://files.pythonhosted.org/packages/85/15/d6eb3b770f6a0d332675141ab3962fd4a7c270ede3515d9f3583e1d28276/jiter-0.12.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:89163163c0934854a668ed783a2546a0617f71706a2551a4a0666d91ab365d6b", size = 304233, upload-time = "2025-11-09T20:49:18.734Z" }, + { url = "https://files.pythonhosted.org/packages/8c/3e/e7e06743294eea2cf02ced6aa0ff2ad237367394e37a0e2b4a1108c67a36/jiter-0.12.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d96b264ab7d34bbb2312dedc47ce07cd53f06835eacbc16dde3761f47c3a9e7f", size = 338537, upload-time = "2025-11-09T20:49:20.317Z" }, + { url = 
"https://files.pythonhosted.org/packages/2f/9c/6753e6522b8d0ef07d3a3d239426669e984fb0eba15a315cdbc1253904e4/jiter-0.12.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24e864cb30ab82311c6425655b0cdab0a98c5d973b065c66a3f020740c2324c", size = 346110, upload-time = "2025-11-09T20:49:21.817Z" }, ] [[package]] @@ -648,6 +733,17 @@ version = "3.0.3" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/72/147da192e38635ada20e0a2e1a51cf8823d2119ce8883f7053879c2199b5/markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e", size = 11615, upload-time = "2025-09-27T18:36:30.854Z" }, + { url = "https://files.pythonhosted.org/packages/9a/81/7e4e08678a1f98521201c3079f77db69fb552acd56067661f8c2f534a718/markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce", size = 12020, upload-time = "2025-09-27T18:36:31.971Z" }, + { url = "https://files.pythonhosted.org/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d", size = 24332, upload-time = "2025-09-27T18:36:32.813Z" }, + { url = "https://files.pythonhosted.org/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d", size = 22947, upload-time = "2025-09-27T18:36:33.86Z" }, + { url = "https://files.pythonhosted.org/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a", size = 21962, upload-time = "2025-09-27T18:36:35.099Z" }, + { url = "https://files.pythonhosted.org/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b", size = 23760, upload-time = "2025-09-27T18:36:36.001Z" }, + { url = "https://files.pythonhosted.org/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f", size = 21529, upload-time = "2025-09-27T18:36:36.906Z" }, + { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, + { url = "https://files.pythonhosted.org/packages/2f/e1/78ee7a023dac597a5825441ebd17170785a9dab23de95d2c7508ade94e0e/markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d", size = 14540, upload-time = "2025-09-27T18:36:38.761Z" }, + { url = "https://files.pythonhosted.org/packages/aa/5b/bec5aa9bbbb2c946ca2733ef9c4ca91c91b6a24580193e891b5f7dbe8e1e/markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c", size = 15105, upload-time = 
"2025-09-27T18:36:39.701Z" }, + { url = "https://files.pythonhosted.org/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906, upload-time = "2025-09-27T18:36:40.689Z" }, { url = "https://files.pythonhosted.org/packages/38/2f/907b9c7bbba283e68f20259574b13d005c121a0fa4c175f9bed27c4597ff/markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795", size = 11622, upload-time = "2025-09-27T18:36:41.777Z" }, { url = "https://files.pythonhosted.org/packages/9c/d9/5f7756922cdd676869eca1c4e3c0cd0df60ed30199ffd775e319089cb3ed/markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219", size = 12029, upload-time = "2025-09-27T18:36:43.257Z" }, { url = "https://files.pythonhosted.org/packages/00/07/575a68c754943058c78f30db02ee03a64b3c638586fba6a6dd56830b30a3/markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6", size = 24374, upload-time = "2025-09-27T18:36:44.508Z" }, @@ -782,6 +878,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = 
"2026-01-16T18:04:18.534Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "2.23" @@ -815,6 +932,20 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/71/70/23b021c950c2addd24ec408e9ab05d59b035b39d97cdc1130e1bce647bb6/pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e", size = 460952, upload-time = "2025-11-04T13:43:49.098Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/5d/5f6c63eebb5afee93bcaae4ce9a898f3373ca23df3ccaef086d0233a35a7/pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7", size = 2110990, upload-time = "2025-11-04T13:39:58.079Z" }, + { url = "https://files.pythonhosted.org/packages/aa/32/9c2e8ccb57c01111e0fd091f236c7b371c1bccea0fa85247ac55b1e2b6b6/pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0", size = 1896003, upload-time = "2025-11-04T13:39:59.956Z" }, + { url = "https://files.pythonhosted.org/packages/68/b8/a01b53cb0e59139fbc9e4fda3e9724ede8de279097179be4ff31f1abb65a/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69", size = 1919200, upload-time = "2025-11-04T13:40:02.241Z" }, + { url = "https://files.pythonhosted.org/packages/38/de/8c36b5198a29bdaade07b5985e80a233a5ac27137846f3bc2d3b40a47360/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75", size = 2052578, upload-time = "2025-11-04T13:40:04.401Z" }, + { url = "https://files.pythonhosted.org/packages/00/b5/0e8e4b5b081eac6cb3dbb7e60a65907549a1ce035a724368c330112adfdd/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05", size = 2208504, upload-time = "2025-11-04T13:40:06.072Z" }, + { url = "https://files.pythonhosted.org/packages/77/56/87a61aad59c7c5b9dc8caad5a41a5545cba3810c3e828708b3d7404f6cef/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc", size = 2335816, upload-time = "2025-11-04T13:40:07.835Z" }, + { url = "https://files.pythonhosted.org/packages/0d/76/941cc9f73529988688a665a5c0ecff1112b3d95ab48f81db5f7606f522d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c", size = 2075366, upload-time = "2025-11-04T13:40:09.804Z" }, + { url = 
"https://files.pythonhosted.org/packages/d3/43/ebef01f69baa07a482844faaa0a591bad1ef129253ffd0cdaa9d8a7f72d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5", size = 2171698, upload-time = "2025-11-04T13:40:12.004Z" }, + { url = "https://files.pythonhosted.org/packages/b1/87/41f3202e4193e3bacfc2c065fab7706ebe81af46a83d3e27605029c1f5a6/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c", size = 2132603, upload-time = "2025-11-04T13:40:13.868Z" }, + { url = "https://files.pythonhosted.org/packages/49/7d/4c00df99cb12070b6bccdef4a195255e6020a550d572768d92cc54dba91a/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294", size = 2329591, upload-time = "2025-11-04T13:40:15.672Z" }, + { url = "https://files.pythonhosted.org/packages/cc/6a/ebf4b1d65d458f3cda6a7335d141305dfa19bdc61140a884d165a8a1bbc7/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1", size = 2319068, upload-time = "2025-11-04T13:40:17.532Z" }, + { url = "https://files.pythonhosted.org/packages/49/3b/774f2b5cd4192d5ab75870ce4381fd89cf218af999515baf07e7206753f0/pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d", size = 1985908, upload-time = "2025-11-04T13:40:19.309Z" }, + { url = "https://files.pythonhosted.org/packages/86/45/00173a033c801cacf67c190fef088789394feaf88a98a7035b0e40d53dc9/pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815", size = 2020145, upload-time = "2025-11-04T13:40:21.548Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/22/91fbc821fa6d261b376a3f73809f907cec5ca6025642c463d3488aad22fb/pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3", size = 1976179, upload-time = "2025-11-04T13:40:23.393Z" }, { url = "https://files.pythonhosted.org/packages/87/06/8806241ff1f70d9939f9af039c6c35f2360cf16e93c2ca76f184e76b1564/pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9", size = 2120403, upload-time = "2025-11-04T13:40:25.248Z" }, { url = "https://files.pythonhosted.org/packages/94/02/abfa0e0bda67faa65fef1c84971c7e45928e108fe24333c81f3bfe35d5f5/pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34", size = 1896206, upload-time = "2025-11-04T13:40:27.099Z" }, { url = "https://files.pythonhosted.org/packages/15/df/a4c740c0943e93e6500f9eb23f4ca7ec9bf71b19e608ae5b579678c8d02f/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0", size = 1919307, upload-time = "2025-11-04T13:40:29.806Z" }, @@ -857,6 +988,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/96/5fb7d8c3c17bc8c62fdb031c47d77a1af698f1d7a406b0f79aaa1338f9ad/pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa", size = 1988906, upload-time = "2025-11-04T13:41:56.606Z" }, { url = "https://files.pythonhosted.org/packages/22/ed/182129d83032702912c2e2d8bbe33c036f342cc735737064668585dac28f/pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c", size = 1981607, upload-time = "2025-11-04T13:41:58.889Z" }, { url = 
"https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, + { url = "https://files.pythonhosted.org/packages/09/32/59b0c7e63e277fa7911c2fc70ccfb45ce4b98991e7ef37110663437005af/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd", size = 2110495, upload-time = "2025-11-04T13:42:49.689Z" }, + { url = "https://files.pythonhosted.org/packages/aa/81/05e400037eaf55ad400bcd318c05bb345b57e708887f07ddb2d20e3f0e98/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc", size = 1915388, upload-time = "2025-11-04T13:42:52.215Z" }, + { url = "https://files.pythonhosted.org/packages/6e/0d/e3549b2399f71d56476b77dbf3cf8937cec5cd70536bdc0e374a421d0599/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56", size = 1942879, upload-time = "2025-11-04T13:42:56.483Z" }, + { url = "https://files.pythonhosted.org/packages/f7/07/34573da085946b6a313d7c42f82f16e8920bfd730665de2d11c0c37a74b5/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b", size = 2139017, upload-time = "2025-11-04T13:42:59.471Z" }, ] [[package]] @@ -899,6 +1034,7 @@ version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } wheels = [ @@ -920,6 +1056,16 @@ version = "6.0.3" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" 
}, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, { url = 
"https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, @@ -971,6 +1117,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "rpds-py" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } wheels = [ @@ -1041,6 +1188,21 @@ version = "0.30.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" } wheels = [ + { url = 
"https://files.pythonhosted.org/packages/03/e7/98a2f4ac921d82f33e03f3835f5bf3a4a40aa1bfdc57975e74a97b2b4bdd/rpds_py-0.30.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a161f20d9a43006833cd7068375a94d035714d73a172b681d8881820600abfad", size = 375086, upload-time = "2025-11-30T20:22:17.93Z" }, + { url = "https://files.pythonhosted.org/packages/4d/a1/bca7fd3d452b272e13335db8d6b0b3ecde0f90ad6f16f3328c6fb150c889/rpds_py-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6abc8880d9d036ecaafe709079969f56e876fcf107f7a8e9920ba6d5a3878d05", size = 359053, upload-time = "2025-11-30T20:22:19.297Z" }, + { url = "https://files.pythonhosted.org/packages/65/1c/ae157e83a6357eceff62ba7e52113e3ec4834a84cfe07fa4b0757a7d105f/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca28829ae5f5d569bb62a79512c842a03a12576375d5ece7d2cadf8abe96ec28", size = 390763, upload-time = "2025-11-30T20:22:21.661Z" }, + { url = "https://files.pythonhosted.org/packages/d4/36/eb2eb8515e2ad24c0bd43c3ee9cd74c33f7ca6430755ccdb240fd3144c44/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1010ed9524c73b94d15919ca4d41d8780980e1765babf85f9a2f90d247153dd", size = 408951, upload-time = "2025-11-30T20:22:23.408Z" }, + { url = "https://files.pythonhosted.org/packages/d6/65/ad8dc1784a331fabbd740ef6f71ce2198c7ed0890dab595adb9ea2d775a1/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d1736cfb49381ba528cd5baa46f82fdc65c06e843dab24dd70b63d09121b3f", size = 514622, upload-time = "2025-11-30T20:22:25.16Z" }, + { url = "https://files.pythonhosted.org/packages/63/8e/0cfa7ae158e15e143fe03993b5bcd743a59f541f5952e1546b1ac1b5fd45/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d948b135c4693daff7bc2dcfc4ec57237a29bd37e60c2fabf5aff2bbacf3e2f1", size = 414492, upload-time = "2025-11-30T20:22:26.505Z" }, + { url = 
"https://files.pythonhosted.org/packages/60/1b/6f8f29f3f995c7ffdde46a626ddccd7c63aefc0efae881dc13b6e5d5bb16/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47f236970bccb2233267d89173d3ad2703cd36a0e2a6e92d0560d333871a3d23", size = 394080, upload-time = "2025-11-30T20:22:27.934Z" }, + { url = "https://files.pythonhosted.org/packages/6d/d5/a266341051a7a3ca2f4b750a3aa4abc986378431fc2da508c5034d081b70/rpds_py-0.30.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:2e6ecb5a5bcacf59c3f912155044479af1d0b6681280048b338b28e364aca1f6", size = 408680, upload-time = "2025-11-30T20:22:29.341Z" }, + { url = "https://files.pythonhosted.org/packages/10/3b/71b725851df9ab7a7a4e33cf36d241933da66040d195a84781f49c50490c/rpds_py-0.30.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a8fa71a2e078c527c3e9dc9fc5a98c9db40bcc8a92b4e8858e36d329f8684b51", size = 423589, upload-time = "2025-11-30T20:22:31.469Z" }, + { url = "https://files.pythonhosted.org/packages/00/2b/e59e58c544dc9bd8bd8384ecdb8ea91f6727f0e37a7131baeff8d6f51661/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:73c67f2db7bc334e518d097c6d1e6fed021bbc9b7d678d6cc433478365d1d5f5", size = 573289, upload-time = "2025-11-30T20:22:32.997Z" }, + { url = "https://files.pythonhosted.org/packages/da/3e/a18e6f5b460893172a7d6a680e86d3b6bc87a54c1f0b03446a3c8c7b588f/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5ba103fb455be00f3b1c2076c9d4264bfcb037c976167a6047ed82f23153f02e", size = 599737, upload-time = "2025-11-30T20:22:34.419Z" }, + { url = "https://files.pythonhosted.org/packages/5c/e2/714694e4b87b85a18e2c243614974413c60aa107fd815b8cbc42b873d1d7/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee9c752c0364588353e627da8a7e808a66873672bcb5f52890c33fd965b394", size = 563120, upload-time = "2025-11-30T20:22:35.903Z" }, + { url = 
"https://files.pythonhosted.org/packages/6f/ab/d5d5e3bcedb0a77f4f613706b750e50a5a3ba1c15ccd3665ecc636c968fd/rpds_py-0.30.0-cp312-cp312-win32.whl", hash = "sha256:1ab5b83dbcf55acc8b08fc62b796ef672c457b17dbd7820a11d6c52c06839bdf", size = 223782, upload-time = "2025-11-30T20:22:37.271Z" }, + { url = "https://files.pythonhosted.org/packages/39/3b/f786af9957306fdc38a74cef405b7b93180f481fb48453a114bb6465744a/rpds_py-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:a090322ca841abd453d43456ac34db46e8b05fd9b3b4ac0c78bcde8b089f959b", size = 240463, upload-time = "2025-11-30T20:22:39.021Z" }, + { url = "https://files.pythonhosted.org/packages/f3/d2/b91dc748126c1559042cfe41990deb92c4ee3e2b415f6b5234969ffaf0cc/rpds_py-0.30.0-cp312-cp312-win_arm64.whl", hash = "sha256:669b1805bd639dd2989b281be2cfd951c6121b65e729d9b843e9639ef1fd555e", size = 230868, upload-time = "2025-11-30T20:22:40.493Z" }, { url = "https://files.pythonhosted.org/packages/ed/dc/d61221eb88ff410de3c49143407f6f3147acf2538c86f2ab7ce65ae7d5f9/rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2", size = 374887, upload-time = "2025-11-30T20:22:41.812Z" }, { url = "https://files.pythonhosted.org/packages/fd/32/55fb50ae104061dbc564ef15cc43c013dc4a9f4527a1f4d99baddf56fe5f/rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8", size = 358904, upload-time = "2025-11-30T20:22:43.479Z" }, { url = "https://files.pythonhosted.org/packages/58/70/faed8186300e3b9bdd138d0273109784eea2396c68458ed580f885dfe7ad/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2771c6c15973347f50fece41fc447c054b7ac2ae0502388ce3b6738cd366e3d4", size = 389945, upload-time = "2025-11-30T20:22:44.819Z" }, @@ -1101,6 +1263,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" }, ] +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + [[package]] name = "ruff" version = "0.14.8" @@ -1296,6 +1470,14 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/f0/f2/840d7b9496825333f532d2e3976b8eadbf52034178aac53630d09fe6e1ef/sqlalchemy-2.0.44.tar.gz", hash = "sha256:0ae7454e1ab1d780aee69fd2aae7d6b8670a581d8847f2d1e0f7ddfbf47e5a22", size = 9819830, upload-time = "2025-10-10T14:39:12.935Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/62/c4/59c7c9b068e6813c898b771204aad36683c96318ed12d4233e1b18762164/sqlalchemy-2.0.44-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:72fea91746b5890f9e5e0997f16cbf3d53550580d76355ba2d998311b17b2250", size = 2139675, upload-time = "2025-10-10T16:03:31.064Z" }, + { url = "https://files.pythonhosted.org/packages/d6/ae/eeb0920537a6f9c5a3708e4a5fc55af25900216bdb4847ec29cfddf3bf3a/sqlalchemy-2.0.44-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:585c0c852a891450edbb1eaca8648408a3cc125f18cf433941fa6babcc359e29", size = 2127726, 
upload-time = "2025-10-10T16:03:35.934Z" }, + { url = "https://files.pythonhosted.org/packages/d8/d5/2ebbabe0379418eda8041c06b0b551f213576bfe4c2f09d77c06c07c8cc5/sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b94843a102efa9ac68a7a30cd46df3ff1ed9c658100d30a725d10d9c60a2f44", size = 3327603, upload-time = "2025-10-10T15:35:28.322Z" }, + { url = "https://files.pythonhosted.org/packages/45/e5/5aa65852dadc24b7d8ae75b7efb8d19303ed6ac93482e60c44a585930ea5/sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:119dc41e7a7defcefc57189cfa0e61b1bf9c228211aba432b53fb71ef367fda1", size = 3337842, upload-time = "2025-10-10T15:43:45.431Z" }, + { url = "https://files.pythonhosted.org/packages/41/92/648f1afd3f20b71e880ca797a960f638d39d243e233a7082c93093c22378/sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0765e318ee9179b3718c4fd7ba35c434f4dd20332fbc6857a5e8df17719c24d7", size = 3264558, upload-time = "2025-10-10T15:35:29.93Z" }, + { url = "https://files.pythonhosted.org/packages/40/cf/e27d7ee61a10f74b17740918e23cbc5bc62011b48282170dc4c66da8ec0f/sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2e7b5b079055e02d06a4308d0481658e4f06bc7ef211567edc8f7d5dce52018d", size = 3301570, upload-time = "2025-10-10T15:43:48.407Z" }, + { url = "https://files.pythonhosted.org/packages/3b/3d/3116a9a7b63e780fb402799b6da227435be878b6846b192f076d2f838654/sqlalchemy-2.0.44-cp312-cp312-win32.whl", hash = "sha256:846541e58b9a81cce7dee8329f352c318de25aa2f2bbe1e31587eb1f057448b4", size = 2103447, upload-time = "2025-10-10T15:03:21.678Z" }, + { url = "https://files.pythonhosted.org/packages/25/83/24690e9dfc241e6ab062df82cc0df7f4231c79ba98b273fa496fb3dd78ed/sqlalchemy-2.0.44-cp312-cp312-win_amd64.whl", hash = "sha256:7cbcb47fd66ab294703e1644f78971f6f2f1126424d2b300678f419aa73c7b6e", size = 2130912, upload-time = "2025-10-10T15:03:24.656Z" }, { url = 
"https://files.pythonhosted.org/packages/45/d3/c67077a2249fdb455246e6853166360054c331db4613cda3e31ab1cadbef/sqlalchemy-2.0.44-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ff486e183d151e51b1d694c7aa1695747599bb00b9f5f604092b54b74c64a8e1", size = 2135479, upload-time = "2025-10-10T16:03:37.671Z" }, { url = "https://files.pythonhosted.org/packages/2b/91/eabd0688330d6fd114f5f12c4f89b0d02929f525e6bf7ff80aa17ca802af/sqlalchemy-2.0.44-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b1af8392eb27b372ddb783b317dea0f650241cea5bd29199b22235299ca2e45", size = 2123212, upload-time = "2025-10-10T16:03:41.755Z" }, { url = "https://files.pythonhosted.org/packages/b0/bb/43e246cfe0e81c018076a16036d9b548c4cc649de241fa27d8d9ca6f85ab/sqlalchemy-2.0.44-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b61188657e3a2b9ac4e8f04d6cf8e51046e28175f79464c67f2fd35bceb0976", size = 3255353, upload-time = "2025-10-10T15:35:31.221Z" }, @@ -1320,6 +1502,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8c/92/c35e036151fe53822893979f8a13e6f235ae8191f4164a79ae60a95d66aa/sqlmodel-0.0.27-py3-none-any.whl", hash = "sha256:667fe10aa8ff5438134668228dc7d7a08306f4c5c4c7e6ad3ad68defa0e7aa49", size = 29131, upload-time = "2025-10-08T16:39:10.917Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/d4/2b0cd0fe285e14b36db076e78c93766ff1d529d70408bd1d2a5a84f1d929/tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb", size = 48036, upload-time = "2025-04-02T08:25:09.966Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, +] + [[package]] name 
= "tqdm" version = "4.67.1" @@ -1381,3 +1572,34 @@ sdist = { url = "https://files.pythonhosted.org/packages/1c/43/554c2569b62f49350 wheels = [ { url = "https://files.pythonhosted.org/packages/56/1a/9ffe814d317c5224166b23e7c47f606d6e473712a2fad0f704ea9b99f246/urllib3-2.6.0-py3-none-any.whl", hash = "sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f", size = 131083, upload-time = "2025-12-05T15:08:45.983Z" }, ] + +[[package]] +name = "websockets" +version = "15.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/e6/26d09fab466b7ca9c7737474c52be4f76a40301b08362eb2dbc19dcc16c1/websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee", size = 177016, upload-time = "2025-03-05T20:03:41.606Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/6b/4545a0d843594f5d0771e86463606a3988b5a09ca5123136f8a76580dd63/websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3", size = 175437, upload-time = "2025-03-05T20:02:16.706Z" }, + { url = "https://files.pythonhosted.org/packages/f4/71/809a0f5f6a06522af902e0f2ea2757f71ead94610010cf570ab5c98e99ed/websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665", size = 173096, upload-time = "2025-03-05T20:02:18.832Z" }, + { url = "https://files.pythonhosted.org/packages/3d/69/1a681dd6f02180916f116894181eab8b2e25b31e484c5d0eae637ec01f7c/websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2", size = 173332, upload-time = "2025-03-05T20:02:20.187Z" }, + { url = "https://files.pythonhosted.org/packages/a6/02/0073b3952f5bce97eafbb35757f8d0d54812b6174ed8dd952aa08429bcc3/websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215", size = 183152, upload-time = "2025-03-05T20:02:22.286Z" }, + { url = "https://files.pythonhosted.org/packages/74/45/c205c8480eafd114b428284840da0b1be9ffd0e4f87338dc95dc6ff961a1/websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5", size = 182096, upload-time = "2025-03-05T20:02:24.368Z" }, + { url = "https://files.pythonhosted.org/packages/14/8f/aa61f528fba38578ec553c145857a181384c72b98156f858ca5c8e82d9d3/websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65", size = 182523, upload-time = "2025-03-05T20:02:25.669Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6d/0267396610add5bc0d0d3e77f546d4cd287200804fe02323797de77dbce9/websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe", size = 182790, upload-time = "2025-03-05T20:02:26.99Z" }, + { url = "https://files.pythonhosted.org/packages/02/05/c68c5adbf679cf610ae2f74a9b871ae84564462955d991178f95a1ddb7dd/websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4", size = 182165, upload-time = "2025-03-05T20:02:30.291Z" }, + { url = "https://files.pythonhosted.org/packages/29/93/bb672df7b2f5faac89761cb5fa34f5cec45a4026c383a4b5761c6cea5c16/websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597", size = 182160, upload-time = "2025-03-05T20:02:31.634Z" }, + { url = "https://files.pythonhosted.org/packages/ff/83/de1f7709376dc3ca9b7eeb4b9a07b4526b14876b6d372a4dc62312bebee0/websockets-15.0.1-cp312-cp312-win32.whl", hash = 
"sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9", size = 176395, upload-time = "2025-03-05T20:02:33.017Z" }, + { url = "https://files.pythonhosted.org/packages/7d/71/abf2ebc3bbfa40f391ce1428c7168fb20582d0ff57019b69ea20fa698043/websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7", size = 176841, upload-time = "2025-03-05T20:02:34.498Z" }, + { url = "https://files.pythonhosted.org/packages/cb/9f/51f0cf64471a9d2b4d0fc6c534f323b664e7095640c34562f5182e5a7195/websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931", size = 175440, upload-time = "2025-03-05T20:02:36.695Z" }, + { url = "https://files.pythonhosted.org/packages/8a/05/aa116ec9943c718905997412c5989f7ed671bc0188ee2ba89520e8765d7b/websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675", size = 173098, upload-time = "2025-03-05T20:02:37.985Z" }, + { url = "https://files.pythonhosted.org/packages/ff/0b/33cef55ff24f2d92924923c99926dcce78e7bd922d649467f0eda8368923/websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151", size = 173329, upload-time = "2025-03-05T20:02:39.298Z" }, + { url = "https://files.pythonhosted.org/packages/31/1d/063b25dcc01faa8fada1469bdf769de3768b7044eac9d41f734fd7b6ad6d/websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22", size = 183111, upload-time = "2025-03-05T20:02:40.595Z" }, + { url = "https://files.pythonhosted.org/packages/93/53/9a87ee494a51bf63e4ec9241c1ccc4f7c2f45fff85d5bde2ff74fcb68b9e/websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f", size = 182054, upload-time = "2025-03-05T20:02:41.926Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b2/83a6ddf56cdcbad4e3d841fcc55d6ba7d19aeb89c50f24dd7e859ec0805f/websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8", size = 182496, upload-time = "2025-03-05T20:02:43.304Z" }, + { url = "https://files.pythonhosted.org/packages/98/41/e7038944ed0abf34c45aa4635ba28136f06052e08fc2168520bb8b25149f/websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375", size = 182829, upload-time = "2025-03-05T20:02:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/e0/17/de15b6158680c7623c6ef0db361da965ab25d813ae54fcfeae2e5b9ef910/websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d", size = 182217, upload-time = "2025-03-05T20:02:50.14Z" }, + { url = "https://files.pythonhosted.org/packages/33/2b/1f168cb6041853eef0362fb9554c3824367c5560cbdaad89ac40f8c2edfc/websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4", size = 182195, upload-time = "2025-03-05T20:02:51.561Z" }, + { url = "https://files.pythonhosted.org/packages/86/eb/20b6cdf273913d0ad05a6a14aed4b9a85591c18a987a3d47f20fa13dcc47/websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa", size = 176393, upload-time = "2025-03-05T20:02:53.814Z" }, + { url = "https://files.pythonhosted.org/packages/1b/6c/c65773d6cab416a64d191d6ee8a8b1c68a09970ea6909d16965d26bfed1e/websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = 
"sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561", size = 176837, upload-time = "2025-03-05T20:02:55.237Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, +]