diff --git a/Makefile b/Makefile index 080a4415..1a1d9898 100644 --- a/Makefile +++ b/Makefile @@ -51,3 +51,4 @@ clean: rm -rf logs/random-baseline find . -type f -name "*.pyc" -delete find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + find . -type d -name ".checkpoints" -exec rm -rf {} + 2>/dev/null || true diff --git a/cli/run_all.py b/cli/run_all.py index 68ab470d..a48a50aa 100644 --- a/cli/run_all.py +++ b/cli/run_all.py @@ -20,11 +20,20 @@ sys.path.insert(0, project_root) from main import ARCTester -from arc_agi_benchmarking.utils.task_utils import read_models_config, read_provider_rate_limits +from arc_agi_benchmarking.utils.task_utils import read_models_config, read_provider_rate_limits, get_provider_timeout_config from arc_agi_benchmarking.utils.rate_limiter import AsyncRequestRateLimiter from arc_agi_benchmarking.utils.metrics import set_metrics_enabled, set_metrics_filename_prefix from arc_agi_benchmarking.utils.preflight import run_preflight from arc_agi_benchmarking.utils.logging_utils import setup_logging, StructuredFormatter +from arc_agi_benchmarking.resilience import ( + CircuitBreaker, + CircuitBreakerOpenError, + TaskTimeoutError, + task_timeout, + get_circuit_breaker, +) +from arc_agi_benchmarking.checkpoint import BatchProgressManager, TaskStatus +from arc_agi_benchmarking.storage import LocalStorageBackend from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type, before_sleep_log @@ -82,6 +91,8 @@ def _record_factory(*args, **kwargs): # Default values DEFAULT_RATE_LIMIT_RATE = 400 DEFAULT_RATE_LIMIT_PERIOD = 60 +DEFAULT_CIRCUIT_BREAKER_THRESHOLD = 5 +DEFAULT_CIRCUIT_BREAKER_RECOVERY = 60 # --- Configuration --- # Default model configuration to test if not provided via CLI. @@ -98,6 +109,8 @@ def _record_factory(*args, **kwargs): # --- Globals for Orchestrator --- PROVIDER_RATE_LIMITERS: Dict[str, AsyncRequestRateLimiter] = {} +PROVIDER_CIRCUIT_BREAKERS: Dict[str, CircuitBreaker] = {} +PROVIDER_TIMEOUT_CONFIGS: Dict[str, Dict] = {} MODEL_CONFIG_CACHE: Dict[str, Any] = {} def get_model_config(config_name: str): @@ -129,15 +142,54 @@ def get_or_create_rate_limiter(provider_name: str, all_provider_limits: Dict) -> PROVIDER_RATE_LIMITERS[provider_name] = AsyncRequestRateLimiter(rate=actual_rate_for_limiter, capacity=actual_capacity_for_limiter) return PROVIDER_RATE_LIMITERS[provider_name] + +def get_or_create_circuit_breaker( + provider_name: str, + all_provider_limits: Dict, + threshold_override: Optional[int] = None +) -> CircuitBreaker: + if provider_name not in PROVIDER_CIRCUIT_BREAKERS: + timeout_config = get_provider_timeout_config(provider_name, all_provider_limits) + failure_threshold = threshold_override or timeout_config['circuit_breaker_threshold'] + recovery_timeout = timeout_config['circuit_breaker_recovery'] + logger.info(f"Initializing circuit breaker for '{provider_name}': threshold={failure_threshold}, recovery={recovery_timeout}s") + PROVIDER_CIRCUIT_BREAKERS[provider_name] = CircuitBreaker( + name=provider_name, + failure_threshold=failure_threshold, + recovery_timeout=recovery_timeout, + ) + return PROVIDER_CIRCUIT_BREAKERS[provider_name] + + +def get_task_timeout(provider_name: str, all_provider_limits: Dict, max_task_timeout: Optional[float] = None) -> Optional[float]: + if max_task_timeout is not None and max_task_timeout > 0: + return max_task_timeout + return None + async def run_single_test_wrapper(config_name: str, task_id: str, limiter: AsyncRequestRateLimiter, + circuit_breaker: CircuitBreaker, + task_timeout_seconds: Optional[float], data_dir: str, save_submission_dir: str, overwrite_submission: bool, print_submission: bool, num_attempts: int, retry_attempts: int, - logs_base_dir: Path) -> bool: # removed print_logs + logs_base_dir: Path, + progress_manager: Optional[BatchProgressManager] = None) -> Optional[bool]: logger.info(f"[Orchestrator] Queuing task: {task_id}, config: {config_name}") - # Apply tenacity retry decorator directly to the synchronous function - # The logger passed to before_sleep_log is the module-level logger of cli.run_all + # Claim the task for this worker (if using checkpointing) + if progress_manager is not None: + if not progress_manager.claim_task(task_id): + logger.debug(f"[Orchestrator] Task {task_id} already claimed or completed, skipping") + return None # Skipped, not success or failure + + try: + circuit_breaker.raise_if_open() + except CircuitBreakerOpenError as e: + logger.warning(f"[Orchestrator] Circuit breaker OPEN for {config_name}, skipping {task_id}. Recovery in {e.recovery_time:.1f}s") + if progress_manager is not None: + progress_manager.mark_failed(task_id, f"Circuit breaker open: {e}") + return False + @retry( wait=wait_exponential(multiplier=1, min=4, max=60), stop=stop_after_attempt(4), @@ -192,13 +244,40 @@ def filter(self, record: logging.LogRecord) -> bool: try: async with limiter: - logger.info(f"[Orchestrator] Rate limiter acquired for: {config_name}. Executing task {task_id} with tenacity retries...") - await asyncio.to_thread(_synchronous_task_execution_attempt_with_tenacity) - - logger.info(f"[Orchestrator] Successfully processed (with tenacity retries if any): {config_name} / {task_id}") + timeout_str = f"{task_timeout_seconds}s" if task_timeout_seconds else "none" + logger.info(f"[Orchestrator] Rate limiter acquired for {config_name}. Executing {task_id} (timeout={timeout_str})") + if task_timeout_seconds: + await task_timeout( + _synchronous_task_execution_attempt_with_tenacity, + task_timeout_seconds, + f"Task {task_id} ({config_name})" + ) + else: + await asyncio.get_event_loop().run_in_executor( + None, _synchronous_task_execution_attempt_with_tenacity + ) + + circuit_breaker.record_success() + logger.info(f"[Orchestrator] Successfully processed: {config_name} / {task_id}") + if progress_manager is not None: + progress_manager.mark_completed(task_id) return True + + except TaskTimeoutError as e: + circuit_breaker.record_failure(e) + logger.error(f"[Orchestrator] Task {task_id} ({config_name}) timed out after {e.elapsed:.2f}s (limit: {e.timeout}s)") + if progress_manager is not None: + progress_manager.mark_failed(task_id, f"Timeout after {e.elapsed:.2f}s") + return False + except Exception as e: - logger.error(f"[Orchestrator] Failed to process (after all tenacity retries or due to non-retryable error): {config_name} / {task_id}. Error: {type(e).__name__} - {e}", exc_info=True) + if isinstance(e, EFFECTIVE_RETRYABLE_EXCEPTIONS): + circuit_breaker.record_failure(e) + logger.error(f"[Orchestrator] Failed after retries: {config_name} / {task_id}. {type(e).__name__}: {e}", exc_info=True) + else: + logger.error(f"[Orchestrator] Failed (non-retryable): {config_name} / {task_id}. {type(e).__name__}: {e}", exc_info=True) + if progress_manager is not None: + progress_manager.mark_failed(task_id, f"{type(e).__name__}: {e}") return False async def main(task_list_file: Optional[str], @@ -206,12 +285,17 @@ async def main(task_list_file: Optional[str], data_dir: str, save_submission_dir: str, overwrite_submission: bool, print_submission: bool, num_attempts: int, retry_attempts: int, - logs_base_dir: Path) -> int: # Added return type hint - # Basic logging setup is now done in if __name__ == "__main__" - + logs_base_dir: Path, + max_task_timeout: Optional[float] = None, + circuit_breaker_threshold: Optional[int] = None, + resume: bool = True) -> int: start_time = time.perf_counter() logger.info("Starting ARC Test Orchestrator...") logger.info(f"Testing with model configuration: {config_to_test}") + if max_task_timeout: + logger.info(f"Task timeout: {max_task_timeout}s (CLI override)") + if circuit_breaker_threshold: + logger.info(f"Circuit breaker threshold: {circuit_breaker_threshold} (CLI override)") task_ids: List[str] = [] try: @@ -245,14 +329,52 @@ async def main(task_list_file: Optional[str], logger.error(f"Error loading tasks: {e}", exc_info=True) return 1 # Return an error code + # --- Checkpointing Setup --- + checkpoint_dir = Path(save_submission_dir) / config_to_test / ".checkpoints" + storage = LocalStorageBackend(checkpoint_dir) + progress_manager = BatchProgressManager( + storage=storage, + run_id=config_to_test, + ) + + # Initialize all tasks (idempotent - only adds tasks not already tracked) + progress_manager.initialize_tasks(task_ids, attempts_per_task=num_attempts) + + # Reset any stale in-progress tasks (from crashed workers) + stale_count = progress_manager.reset_stale_tasks(max_age_seconds=3600) + if stale_count > 0: + logger.info(f"Reset {stale_count} stale in-progress task(s)") + + # Determine which tasks to run + if resume: + # Only run pending tasks + tasks_to_run = [ + task_id for task_id in task_ids + if progress_manager.progress.tasks.get(task_id) and + progress_manager.progress.tasks[task_id].status == TaskStatus.PENDING + ] + completed_count = progress_manager.progress.completed_count + failed_count = progress_manager.progress.failed_count + if completed_count > 0 or failed_count > 0: + logger.info( + f"Resuming: {completed_count} completed, {failed_count} failed, " + f"{len(tasks_to_run)} remaining" + ) + else: + tasks_to_run = task_ids + logger.info("Resume disabled - running all tasks") + all_jobs_to_run: List[Tuple[str, str]] = [] - for task_id in task_ids: + for task_id in tasks_to_run: all_jobs_to_run.append((config_to_test, task_id)) - + if not all_jobs_to_run: + if resume and progress_manager.is_complete(): + logger.info("All tasks already completed. Use --no-resume to re-run.") + return 0 logger.warning("No jobs to run (check config_to_test and task list). Exiting.") - return 1 # Return an error code - + return 1 + logger.info(f"Total jobs to process: {len(all_jobs_to_run)}") try: @@ -271,12 +393,16 @@ async def main(task_list_file: Optional[str], model_config_obj = get_model_config(config_name) provider_name = model_config_obj.provider limiter = get_or_create_rate_limiter(provider_name, all_provider_limits) + circuit_breaker = get_or_create_circuit_breaker(provider_name, all_provider_limits, circuit_breaker_threshold) + task_timeout_val = get_task_timeout(provider_name, all_provider_limits, max_task_timeout) async_tasks_to_execute.append(run_single_test_wrapper( config_name, task_id, limiter, + circuit_breaker, task_timeout_val, data_dir, save_submission_dir, - overwrite_submission, print_submission, + overwrite_submission, print_submission, num_attempts, retry_attempts, - logs_base_dir + logs_base_dir, + progress_manager, )) except ValueError as e: # Specific error for model config issues logger.error(f"Skipping config '{config_name}' for task '{task_id}' due to model config error: {e}") @@ -291,6 +417,7 @@ async def main(task_list_file: Optional[str], results = await asyncio.gather(*async_tasks_to_execute, return_exceptions=True) successful_runs = sum(1 for r in results if r is True) + skipped_runs = sum(1 for r in results if r is None) orchestrator_level_failures = sum(1 for r in results if r is False or isinstance(r, Exception)) logger.info("--- Orchestrator Summary ---") @@ -306,10 +433,33 @@ async def main(task_list_file: Optional[str], elif res is False: # Wrapper reported failure logger.warning(f" - Failure reported by wrapper for {original_job_config}/{original_job_task_id} (check ARCTester logs for this task/config)") exit_code = 1 # Indicate failure - + + # Log checkpoint progress summary + logger.info("--- Checkpoint Progress Summary ---") + summary = progress_manager.get_summary() + logger.info( + f" Run: {summary['run_id']} | " + f"Total: {summary['total']} | " + f"Completed: {summary['completed']} | " + f"Failed: {summary['failed']} | " + f"Pending: {summary['pending']}" + ) + + # Log circuit breaker statistics + if PROVIDER_CIRCUIT_BREAKERS: + logger.info("--- Circuit Breaker Summary ---") + for provider, cb in PROVIDER_CIRCUIT_BREAKERS.items(): + stats = cb.get_stats() + logger.info( + f" {provider}: state={stats['state']}, " + f"requests={stats['total_requests']}, " + f"failures={stats['failed_requests']}, " + f"rejected={stats['rejected_requests']}" + ) + logger.info("Note: Individual task success/failure is logged by ARCTester within its own logger (main.py's logger).") logger.info("Orchestrator failure indicates an issue with running the ARCTester task itself or an unhandled exception in the wrapper.") - + end_time = time.perf_counter() total_duration = end_time - start_time logger.info("--- Orchestrator Timing ---") @@ -396,9 +546,28 @@ async def main(task_list_file: Optional[str], default=None, help="Maximum estimated cost in USD. Abort if estimated cost exceeds this limit." ) + parser.add_argument( + "--max-task-timeout", + type=float, + default=None, + help="Maximum timeout in seconds for each task execution. Overrides provider-specific timeouts." + ) + parser.add_argument( + "--circuit-breaker-threshold", + type=int, + default=None, + help="Number of failures before circuit breaker opens. Overrides provider-specific thresholds." + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Disable resume - run all tasks even if some are already completed." + ) args = parser.parse_args() + resume_enabled = not args.no_resume + # Set metrics enabled status based on CLI arg set_metrics_enabled(args.enable_metrics) @@ -474,7 +643,10 @@ async def main(task_list_file: Optional[str], print_submission=args.print_submission, num_attempts=args.num_attempts, retry_attempts=args.retry_attempts, - logs_base_dir=logs_base_dir + logs_base_dir=logs_base_dir, + max_task_timeout=args.max_task_timeout, + circuit_breaker_threshold=args.circuit_breaker_threshold, + resume=resume_enabled, )) sys.exit(exit_code_from_main) diff --git a/scripts/demo_checkpoint.py b/scripts/demo_checkpoint.py new file mode 100644 index 00000000..2bd44c3d --- /dev/null +++ b/scripts/demo_checkpoint.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +"""Demo script to test checkpointing with fake tasks.""" + +import random +import time +from decimal import Decimal +from pathlib import Path +from tempfile import TemporaryDirectory + +from arc_agi_benchmarking.checkpoint import ( + BatchProgressManager, + TaskCheckpointManager, +) +from arc_agi_benchmarking.storage import LocalStorageBackend + + +def simulate_task(task_checkpoint: TaskCheckpointManager, num_test_pairs: int = 3, max_attempts: int = 2, fail_task: bool = False): + """Simulate running a task with checkpointing. + + Returns True if task succeeded, False if it failed. + """ + all_succeeded = True + + for test_pair_idx in range(num_test_pairs): + pair_succeeded = False + while True: + attempt_idx = task_checkpoint.get_next_attempt_index(test_pair_idx, max_attempts) + if attempt_idx is None: + break + + print(f" Running test pair {test_pair_idx}, attempt {attempt_idx}...") + time.sleep(0.1) + + if fail_task or random.random() < 0.3: + print(f" ❌ Simulated failure!") + task_checkpoint.record_attempt( + test_pair_index=test_pair_idx, + attempt_index=attempt_idx, + response=None, + error="Simulated API error", + cost_usd=Decimal("0.001"), + ) + else: + print(f" ✓ Success!") + task_checkpoint.record_attempt( + test_pair_index=test_pair_idx, + attempt_index=attempt_idx, + response={"answer": [[1, 2], [3, 4]]}, + cost_usd=Decimal("0.005"), + tokens_input=100, + tokens_output=50, + ) + pair_succeeded = True + break + + if not pair_succeeded: + all_succeeded = False + + return all_succeeded + + +def run_batch(storage: LocalStorageBackend, task_ids: list[str], interrupt_after: int | None = None, force_fail_task: str | None = None): + """Run a batch of tasks with checkpointing.""" + batch_manager = BatchProgressManager(storage, run_id="demo_run") + batch_manager.initialize_tasks(task_ids) + + print(f"\n{'='*50}") + print(f"Starting batch: {batch_manager.get_summary()}") + print(f"{'='*50}\n") + + tasks_processed = 0 + while True: + task_id = batch_manager.claim_next_task() + if not task_id: + break + + print(f"\n▶ Processing task: {task_id}") + + task_checkpoint = TaskCheckpointManager(storage, task_id) + existing = len(task_checkpoint.get_completed_attempts()) + if existing > 0: + print(f" Resuming from checkpoint ({existing} attempts already done)") + + try: + fail_this_task = (task_id == force_fail_task) + success = simulate_task(task_checkpoint, fail_task=fail_this_task) + + if success: + batch_manager.mark_completed( + task_id, + cost_usd=task_checkpoint.checkpoint.total_cost_usd, + tokens_input=task_checkpoint.checkpoint.total_tokens_input, + tokens_output=task_checkpoint.checkpoint.total_tokens_output, + ) + task_checkpoint.delete_checkpoint() + print(f" ✓ Task {task_id} completed") + else: + batch_manager.mark_failed( + task_id, + "All attempts exhausted", + cost_usd=task_checkpoint.checkpoint.total_cost_usd, + tokens_input=task_checkpoint.checkpoint.total_tokens_input, + tokens_output=task_checkpoint.checkpoint.total_tokens_output, + ) + task_checkpoint.delete_checkpoint() + print(f" ✗ Task {task_id} failed (all attempts exhausted)") + + except Exception as e: + batch_manager.mark_failed(task_id, str(e)) + print(f" ✗ Task {task_id} failed: {e}") + + tasks_processed += 1 + if interrupt_after and tasks_processed >= interrupt_after: + print(f"\n⚠️ Simulating interruption after {tasks_processed} tasks!") + return False + + print(f"\n{'='*50}") + print(f"Batch complete: {batch_manager.get_summary()}") + print(f"{'='*50}\n") + return batch_manager + + +def main(): + task_ids = [f"task_{i:03d}" for i in range(5)] + + with TemporaryDirectory() as tmpdir: + storage = LocalStorageBackend(Path(tmpdir)) + print(f"Using temp storage: {tmpdir}") + + print("\n" + "="*60) + print("RUN 1: Processing tasks (task_002 will fail)") + print("="*60) + batch_manager = run_batch(storage, task_ids, force_fail_task="task_002") + + print(f"\nFailed tasks: {batch_manager.progress.failed_count}") + + print("\n" + "="*60) + print("RUN 2: Retrying failed tasks") + print("="*60) + reset_count = batch_manager.retry_failed_tasks() + print(f"Reset {reset_count} failed task(s) back to pending") + + run_batch(storage, task_ids) + + print("\n" + "="*60) + print("RUN 3: Running again (should be no-op)") + print("="*60) + run_batch(storage, task_ids) + + +if __name__ == "__main__": + main() diff --git a/src/arc_agi_benchmarking/checkpoint/__init__.py b/src/arc_agi_benchmarking/checkpoint/__init__.py new file mode 100644 index 00000000..c5cb8501 --- /dev/null +++ b/src/arc_agi_benchmarking/checkpoint/__init__.py @@ -0,0 +1,21 @@ +"""Checkpointing and progress tracking for benchmark runs.""" + +from arc_agi_benchmarking.checkpoint.models import ( + TaskStatus, + TaskProgress, + BatchProgress, + TaskCheckpoint, + AttemptResult, +) +from arc_agi_benchmarking.checkpoint.batch_progress import BatchProgressManager +from arc_agi_benchmarking.checkpoint.task_checkpoint import TaskCheckpointManager + +__all__ = [ + "TaskStatus", + "TaskProgress", + "BatchProgress", + "TaskCheckpoint", + "AttemptResult", + "BatchProgressManager", + "TaskCheckpointManager", +] diff --git a/src/arc_agi_benchmarking/checkpoint/batch_progress.py b/src/arc_agi_benchmarking/checkpoint/batch_progress.py new file mode 100644 index 00000000..7bc9807f --- /dev/null +++ b/src/arc_agi_benchmarking/checkpoint/batch_progress.py @@ -0,0 +1,246 @@ +"""Batch progress manager for tracking overall run progress.""" + +import json +import logging +import os +from datetime import datetime, timezone +from decimal import Decimal + +from arc_agi_benchmarking.checkpoint.models import ( + BatchProgress, + TaskProgress, + TaskStatus, +) +from arc_agi_benchmarking.storage import StorageBackend + +logger = logging.getLogger(__name__) + + +class BatchProgressManager: + """Manages batch-level progress tracking. + + Tracks which tasks are pending, in-progress, completed, or failed. + Persists progress to storage for resume capability. + """ + + def __init__( + self, + storage: StorageBackend, + run_id: str, + progress_key: str = "progress.json", + ): + self.storage = storage + self.run_id = run_id + self.progress_key = progress_key + self._progress: BatchProgress | None = None + self._worker_id = f"{os.getpid()}@{os.uname().nodename}" + + @property + def progress(self) -> BatchProgress: + if self._progress is None: + self._progress = self._load_or_create() + return self._progress + + def _load_or_create(self) -> BatchProgress: + data = self.storage.read_text(self.progress_key) + if data: + try: + progress = BatchProgress.from_dict(json.loads(data)) + if progress.run_id != self.run_id: + logger.warning( + f"Run ID mismatch: expected {self.run_id}, " + f"found {progress.run_id}. Starting fresh." + ) + return BatchProgress(run_id=self.run_id) + return progress + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning(f"Failed to load progress, starting fresh: {e}") + return BatchProgress(run_id=self.run_id) + + def _save(self) -> None: + self.progress.updated_at = datetime.now(timezone.utc) + self.storage.write_text( + self.progress_key, + json.dumps(self.progress.to_dict(), indent=2), + ) + + def initialize_tasks(self, task_ids: list[str], attempts_per_task: int = 1) -> None: + """Initialize progress tracking for a list of tasks. + + Only adds tasks that don't already exist (preserves resumed state). + """ + for task_id in task_ids: + if task_id not in self.progress.tasks: + self.progress.tasks[task_id] = TaskProgress( + task_id=task_id, + attempts_total=attempts_per_task, + ) + self._save() + + def claim_task(self, task_id: str) -> bool: + """Attempt to claim a task for processing. + + Returns True if the task was successfully claimed, False if it's + already being processed or completed. + """ + task = self.progress.tasks.get(task_id) + if not task: + return False + + if task.status != TaskStatus.PENDING: + return False + + task.status = TaskStatus.IN_PROGRESS + task.worker_id = self._worker_id + task.started_at = datetime.now(timezone.utc) + self._save() + return True + + def get_next_pending_task(self) -> str | None: + """Get the next pending task ID, or None if all tasks are done.""" + for task_id, task in self.progress.tasks.items(): + if task.status == TaskStatus.PENDING: + return task_id + return None + + def claim_next_task(self) -> str | None: + """Claim the next available pending task. + + Returns the task ID if successful, None if no tasks available. + Uses a retry loop to handle race conditions where another worker + claims a task between get_next_pending_task() and claim_task(). + """ + while True: + task_id = self.get_next_pending_task() + if task_id is None: + return None + if self.claim_task(task_id): + return task_id + + def mark_completed( + self, + task_id: str, + cost_usd: Decimal = Decimal("0"), + tokens_input: int = 0, + tokens_output: int = 0, + ) -> None: + """Mark a task as completed.""" + task = self.progress.tasks.get(task_id) + if not task: + return + + task.status = TaskStatus.COMPLETED + task.completed_at = datetime.now(timezone.utc) + task.cost_usd = cost_usd + + self.progress.total_cost_usd += cost_usd + self.progress.total_tokens_input += tokens_input + self.progress.total_tokens_output += tokens_output + self._save() + + def mark_failed( + self, + task_id: str, + error: str, + cost_usd: Decimal = Decimal("0"), + tokens_input: int = 0, + tokens_output: int = 0, + ) -> None: + """Mark a task as failed, accumulating any costs incurred.""" + task = self.progress.tasks.get(task_id) + if not task: + return + + task.status = TaskStatus.FAILED + task.error = error + task.completed_at = datetime.now(timezone.utc) + task.cost_usd = cost_usd + + self.progress.total_cost_usd += cost_usd + self.progress.total_tokens_input += tokens_input + self.progress.total_tokens_output += tokens_output + self._save() + + def update_task_progress( + self, + task_id: str, + attempts_completed: int, + cost_usd: Decimal = Decimal("0"), + ) -> None: + """Update progress within a task (e.g., after each attempt).""" + task = self.progress.tasks.get(task_id) + if not task: + return + + task.attempts_completed = attempts_completed + task.cost_usd = cost_usd + self._save() + + def reset_stale_tasks(self, max_age_seconds: int = 3600) -> int: + """Reset tasks that have been in-progress too long (stale workers). + + Returns the number of tasks reset. + """ + now = datetime.now(timezone.utc) + reset_count = 0 + + for task in self.progress.tasks.values(): + if task.status != TaskStatus.IN_PROGRESS: + continue + if task.started_at is None: + continue + + age = (now - task.started_at).total_seconds() + if age > max_age_seconds: + logger.info( + f"Resetting stale task {task.task_id} " + f"(age: {age:.0f}s, worker: {task.worker_id})" + ) + task.status = TaskStatus.PENDING + task.worker_id = None + task.started_at = None + reset_count += 1 + + if reset_count > 0: + self._save() + + return reset_count + + def is_complete(self) -> bool: + """Check if all tasks are completed or failed.""" + return all( + t.status in (TaskStatus.COMPLETED, TaskStatus.FAILED) + for t in self.progress.tasks.values() + ) + + def retry_failed_tasks(self) -> int: + """Reset failed tasks back to pending for retry. + + Returns the number of tasks reset. + """ + reset_count = 0 + for task in self.progress.tasks.values(): + if task.status == TaskStatus.FAILED: + task.status = TaskStatus.PENDING + task.error = None + task.worker_id = None + task.started_at = None + task.completed_at = None + reset_count += 1 + + if reset_count > 0: + self._save() + + return reset_count + + def get_summary(self) -> dict: + """Get a summary of the current progress.""" + return { + "run_id": self.run_id, + "total": self.progress.total_count, + "pending": self.progress.pending_count, + "in_progress": self.progress.in_progress_count, + "completed": self.progress.completed_count, + "failed": self.progress.failed_count, + "total_cost_usd": str(self.progress.total_cost_usd), + } diff --git a/src/arc_agi_benchmarking/checkpoint/models.py b/src/arc_agi_benchmarking/checkpoint/models.py new file mode 100644 index 00000000..02d534ea --- /dev/null +++ b/src/arc_agi_benchmarking/checkpoint/models.py @@ -0,0 +1,221 @@ +"""Data models for checkpointing.""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from decimal import Decimal +from enum import Enum +from typing import Any + + +def _utc_now() -> datetime: + """Return current UTC time as timezone-aware datetime.""" + return datetime.now(timezone.utc) + + +class TaskStatus(str, Enum): + """Status of a task in the batch.""" + + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class AttemptResult: + """Result of a single attempt within a task.""" + + attempt_index: int + test_pair_index: int + response: Any + cost_usd: Decimal = Decimal("0") + tokens_input: int = 0 + tokens_output: int = 0 + duration_seconds: float = 0.0 + error: str | None = None + timestamp: datetime = field(default_factory=_utc_now) + + def to_dict(self) -> dict: + return { + "attempt_index": self.attempt_index, + "test_pair_index": self.test_pair_index, + "response": self.response, + "cost_usd": str(self.cost_usd), + "tokens_input": self.tokens_input, + "tokens_output": self.tokens_output, + "duration_seconds": self.duration_seconds, + "error": self.error, + "timestamp": self.timestamp.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict) -> "AttemptResult": + return cls( + attempt_index=data["attempt_index"], + test_pair_index=data["test_pair_index"], + response=data["response"], + cost_usd=Decimal(data.get("cost_usd", "0")), + tokens_input=data.get("tokens_input", 0), + tokens_output=data.get("tokens_output", 0), + duration_seconds=data.get("duration_seconds", 0.0), + error=data.get("error"), + timestamp=datetime.fromisoformat(data["timestamp"]), + ) + + +@dataclass +class TaskCheckpoint: + """Checkpoint for within-task progress.""" + + schema_version: int = 1 + task_id: str = "" + completed_attempts: list[AttemptResult] = field(default_factory=list) + total_cost_usd: Decimal = Decimal("0") + total_tokens_input: int = 0 + total_tokens_output: int = 0 + started_at: datetime = field(default_factory=_utc_now) + updated_at: datetime = field(default_factory=_utc_now) + + def to_dict(self) -> dict: + return { + "schema_version": self.schema_version, + "task_id": self.task_id, + "completed_attempts": [a.to_dict() for a in self.completed_attempts], + "total_cost_usd": str(self.total_cost_usd), + "total_tokens_input": self.total_tokens_input, + "total_tokens_output": self.total_tokens_output, + "started_at": self.started_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict) -> "TaskCheckpoint": + version = data.get("schema_version", 1) + if version != 1: + raise ValueError(f"Unsupported checkpoint schema version: {version}") + return cls( + schema_version=version, + task_id=data["task_id"], + completed_attempts=[ + AttemptResult.from_dict(a) for a in data.get("completed_attempts", []) + ], + total_cost_usd=Decimal(data.get("total_cost_usd", "0")), + total_tokens_input=data.get("total_tokens_input", 0), + total_tokens_output=data.get("total_tokens_output", 0), + started_at=datetime.fromisoformat(data["started_at"]), + updated_at=datetime.fromisoformat(data["updated_at"]), + ) + + +@dataclass +class TaskProgress: + """Progress of a single task within a batch.""" + + task_id: str + status: TaskStatus = TaskStatus.PENDING + attempts_completed: int = 0 + attempts_total: int = 0 + cost_usd: Decimal = Decimal("0") + error: str | None = None + worker_id: str | None = None + started_at: datetime | None = None + completed_at: datetime | None = None + + def to_dict(self) -> dict: + return { + "task_id": self.task_id, + "status": self.status.value, + "attempts_completed": self.attempts_completed, + "attempts_total": self.attempts_total, + "cost_usd": str(self.cost_usd), + "error": self.error, + "worker_id": self.worker_id, + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + } + + @classmethod + def from_dict(cls, data: dict) -> "TaskProgress": + return cls( + task_id=data["task_id"], + status=TaskStatus(data["status"]), + attempts_completed=data.get("attempts_completed", 0), + attempts_total=data.get("attempts_total", 0), + cost_usd=Decimal(data.get("cost_usd", "0")), + error=data.get("error"), + worker_id=data.get("worker_id"), + started_at=( + datetime.fromisoformat(data["started_at"]) + if data.get("started_at") + else None + ), + completed_at=( + datetime.fromisoformat(data["completed_at"]) + if data.get("completed_at") + else None + ), + ) + + +@dataclass +class BatchProgress: + """Overall progress of a benchmark batch.""" + + schema_version: int = 1 + run_id: str = "" + tasks: dict[str, TaskProgress] = field(default_factory=dict) + total_cost_usd: Decimal = Decimal("0") + total_tokens_input: int = 0 + total_tokens_output: int = 0 + started_at: datetime = field(default_factory=_utc_now) + updated_at: datetime = field(default_factory=_utc_now) + + @property + def pending_count(self) -> int: + return sum(1 for t in self.tasks.values() if t.status == TaskStatus.PENDING) + + @property + def in_progress_count(self) -> int: + return sum(1 for t in self.tasks.values() if t.status == TaskStatus.IN_PROGRESS) + + @property + def completed_count(self) -> int: + return sum(1 for t in self.tasks.values() if t.status == TaskStatus.COMPLETED) + + @property + def failed_count(self) -> int: + return sum(1 for t in self.tasks.values() if t.status == TaskStatus.FAILED) + + @property + def total_count(self) -> int: + return len(self.tasks) + + def to_dict(self) -> dict: + return { + "schema_version": self.schema_version, + "run_id": self.run_id, + "tasks": {tid: t.to_dict() for tid, t in self.tasks.items()}, + "total_cost_usd": str(self.total_cost_usd), + "total_tokens_input": self.total_tokens_input, + "total_tokens_output": self.total_tokens_output, + "started_at": self.started_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict) -> "BatchProgress": + version = data.get("schema_version", 1) + if version != 1: + raise ValueError(f"Unsupported batch progress schema version: {version}") + return cls( + schema_version=version, + run_id=data["run_id"], + tasks={ + tid: TaskProgress.from_dict(t) for tid, t in data.get("tasks", {}).items() + }, + total_cost_usd=Decimal(data.get("total_cost_usd", "0")), + total_tokens_input=data.get("total_tokens_input", 0), + total_tokens_output=data.get("total_tokens_output", 0), + started_at=datetime.fromisoformat(data["started_at"]), + updated_at=datetime.fromisoformat(data["updated_at"]), + ) diff --git a/src/arc_agi_benchmarking/checkpoint/task_checkpoint.py b/src/arc_agi_benchmarking/checkpoint/task_checkpoint.py new file mode 100644 index 00000000..c88c991e --- /dev/null +++ b/src/arc_agi_benchmarking/checkpoint/task_checkpoint.py @@ -0,0 +1,145 @@ +"""Task checkpoint manager for within-task progress tracking.""" + +import json +import logging +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any + +from arc_agi_benchmarking.checkpoint.models import ( + AttemptResult, + TaskCheckpoint, +) +from arc_agi_benchmarking.storage import StorageBackend + +logger = logging.getLogger(__name__) + + +class TaskCheckpointManager: + """Manages within-task checkpointing. + + Tracks completed attempts within a task and enables resuming + from the last successful attempt after interruption. + """ + + def __init__( + self, + storage: StorageBackend, + task_id: str, + checkpoint_dir: str = "checkpoints", + ): + self.storage = storage + self.task_id = task_id + self.checkpoint_dir = checkpoint_dir + self._checkpoint: TaskCheckpoint | None = None + + @property + def checkpoint_key(self) -> str: + return f"{self.checkpoint_dir}/{self.task_id}.json" + + @property + def checkpoint(self) -> TaskCheckpoint: + if self._checkpoint is None: + self._checkpoint = self._load_or_create() + return self._checkpoint + + def _load_or_create(self) -> TaskCheckpoint: + data = self.storage.read_text(self.checkpoint_key) + if data: + try: + checkpoint = TaskCheckpoint.from_dict(json.loads(data)) + logger.info( + f"Resumed checkpoint for {self.task_id} " + f"with {len(checkpoint.completed_attempts)} completed attempts" + ) + return checkpoint + except (json.JSONDecodeError, KeyError, ValueError) as e: + logger.warning(f"Failed to load checkpoint, starting fresh: {e}") + return TaskCheckpoint(task_id=self.task_id) + + def _save(self) -> None: + self.checkpoint.updated_at = datetime.now(timezone.utc) + data = json.dumps(self.checkpoint.to_dict(), indent=2) + self.storage.write_text(self.checkpoint_key, data) + + def get_completed_attempts(self) -> list[AttemptResult]: + """Get list of completed attempts.""" + return list(self.checkpoint.completed_attempts) + + def get_next_attempt_index(self, test_pair_index: int, max_attempts: int) -> int | None: + """Get the next attempt index to run for a test pair. + + Returns None if all attempts for this test pair are complete. + """ + completed = { + a.attempt_index + for a in self.checkpoint.completed_attempts + if a.test_pair_index == test_pair_index + } + + for i in range(max_attempts): + if i not in completed: + return i + return None + + def is_test_pair_complete(self, test_pair_index: int, max_attempts: int) -> bool: + """Check if all attempts for a test pair are complete.""" + return self.get_next_attempt_index(test_pair_index, max_attempts) is None + + def record_attempt( + self, + test_pair_index: int, + attempt_index: int, + response: Any, + cost_usd: Decimal = Decimal("0"), + tokens_input: int = 0, + tokens_output: int = 0, + duration_seconds: float = 0.0, + error: str | None = None, + ) -> None: + """Record a completed attempt and save checkpoint.""" + result = AttemptResult( + attempt_index=attempt_index, + test_pair_index=test_pair_index, + response=response, + cost_usd=cost_usd, + tokens_input=tokens_input, + tokens_output=tokens_output, + duration_seconds=duration_seconds, + error=error, + ) + + self.checkpoint.completed_attempts.append(result) + self.checkpoint.total_cost_usd += cost_usd + self.checkpoint.total_tokens_input += tokens_input + self.checkpoint.total_tokens_output += tokens_output + self._save() + + logger.debug( + f"Checkpointed attempt {attempt_index} for test pair {test_pair_index} " + f"of task {self.task_id}" + ) + + def get_results_for_test_pair(self, test_pair_index: int) -> list[AttemptResult]: + """Get all completed attempts for a specific test pair.""" + return [ + a + for a in self.checkpoint.completed_attempts + if a.test_pair_index == test_pair_index + ] + + def delete_checkpoint(self) -> None: + """Delete the checkpoint file after successful task completion.""" + self.storage.delete(self.checkpoint_key) + self._checkpoint = None + logger.debug(f"Deleted checkpoint for task {self.task_id}") + + def get_summary(self) -> dict: + """Get a summary of the checkpoint state.""" + return { + "task_id": self.task_id, + "completed_attempts": len(self.checkpoint.completed_attempts), + "total_cost_usd": str(self.checkpoint.total_cost_usd), + "total_tokens_input": self.checkpoint.total_tokens_input, + "total_tokens_output": self.checkpoint.total_tokens_output, + } diff --git a/src/arc_agi_benchmarking/resilience/__init__.py b/src/arc_agi_benchmarking/resilience/__init__.py new file mode 100644 index 00000000..ad688473 --- /dev/null +++ b/src/arc_agi_benchmarking/resilience/__init__.py @@ -0,0 +1,43 @@ +""" +Resilience module for timeout and circuit breaker functionality. + +This module provides resilience patterns to handle: +- Request timeouts to prevent indefinite hangs +- Circuit breakers to prevent cascading failures + +Usage: + from arc_agi_benchmarking.resilience import ( + CircuitBreaker, + CircuitBreakerOpenError, + CircuitBreakerState, + TaskTimeoutError, + request_timeout, + task_timeout, + ) +""" + +from arc_agi_benchmarking.resilience.timeout import ( + TaskTimeoutError, + request_timeout, + task_timeout, +) +from arc_agi_benchmarking.resilience.circuit_breaker import ( + CircuitBreaker, + CircuitBreakerOpenError, + CircuitBreakerState, + CircuitBreakerRegistry, + get_circuit_breaker, + get_circuit_breaker_registry, +) + +__all__ = [ + "CircuitBreaker", + "CircuitBreakerOpenError", + "CircuitBreakerRegistry", + "CircuitBreakerState", + "TaskTimeoutError", + "get_circuit_breaker", + "get_circuit_breaker_registry", + "request_timeout", + "task_timeout", +] diff --git a/src/arc_agi_benchmarking/resilience/circuit_breaker.py b/src/arc_agi_benchmarking/resilience/circuit_breaker.py new file mode 100644 index 00000000..e132251b --- /dev/null +++ b/src/arc_agi_benchmarking/resilience/circuit_breaker.py @@ -0,0 +1,269 @@ +"""Circuit breaker implementation for preventing cascading failures.""" + +import logging +import threading +import time +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Set + +logger = logging.getLogger(__name__) + + +class CircuitBreakerState(str, Enum): + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" + + +class CircuitBreakerOpenError(Exception): + """Raised when the circuit breaker is open.""" + + def __init__( + self, + message: str, + provider: Optional[str] = None, + recovery_time: Optional[float] = None, + failure_count: Optional[int] = None, + ): + super().__init__(message) + self.provider = provider + self.recovery_time = recovery_time + self.failure_count = failure_count + + +@dataclass +class CircuitBreakerConfig: + failure_threshold: int = 5 + recovery_timeout: float = 60.0 + success_threshold: int = 2 + failure_window: float = 0.0 + failure_exceptions: Optional[Set[type]] = None + excluded_exceptions: Optional[Set[type]] = None + + +@dataclass +class CircuitBreakerStats: + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + rejected_requests: int = 0 + state_transitions: int = 0 + last_failure_time: Optional[float] = None + last_success_time: Optional[float] = None + last_state_change_time: Optional[float] = None + current_consecutive_failures: int = 0 + current_consecutive_successes: int = 0 + + +class CircuitBreaker: + """Circuit breaker with configurable thresholds.""" + + def __init__( + self, + name: str, + failure_threshold: int = 5, + recovery_timeout: float = 60.0, + success_threshold: int = 2, + failure_window: float = 0.0, + failure_exceptions: Optional[Set[type]] = None, + excluded_exceptions: Optional[Set[type]] = None, + ): + self.name = name + self.config = CircuitBreakerConfig( + failure_threshold=failure_threshold, + recovery_timeout=recovery_timeout, + success_threshold=success_threshold, + failure_window=failure_window, + failure_exceptions=failure_exceptions, + excluded_exceptions=excluded_exceptions, + ) + self.stats = CircuitBreakerStats() + self._state = CircuitBreakerState.CLOSED + self._failure_times: list[float] = [] + self._lock = threading.RLock() + + @property + def state(self) -> CircuitBreakerState: + with self._lock: + if self._state == CircuitBreakerState.OPEN: + if self._should_attempt_recovery(): + self._transition_to(CircuitBreakerState.HALF_OPEN) + return self._state + + def _should_attempt_recovery(self) -> bool: + if self.stats.last_failure_time is None: + return True + elapsed = time.time() - self.stats.last_failure_time + return elapsed >= self.config.recovery_timeout + + def _transition_to(self, new_state: CircuitBreakerState) -> None: + old_state = self._state + self._state = new_state + self.stats.state_transitions += 1 + self.stats.last_state_change_time = time.time() + logger.info(f"Circuit breaker '{self.name}': {old_state.value} -> {new_state.value}") + + def _count_recent_failures(self) -> int: + if self.config.failure_window <= 0: + return self.stats.current_consecutive_failures + cutoff = time.time() - self.config.failure_window + self._failure_times = [t for t in self._failure_times if t > cutoff] + return len(self._failure_times) + + def _should_count_as_failure(self, exception: Exception) -> bool: + exc_type = type(exception) + if self.config.excluded_exceptions and exc_type in self.config.excluded_exceptions: + return False + if self.config.failure_exceptions: + return exc_type in self.config.failure_exceptions + return True + + def can_execute(self) -> bool: + return self.state != CircuitBreakerState.OPEN + + def raise_if_open(self) -> None: + if not self.can_execute(): + with self._lock: + self.stats.rejected_requests += 1 + recovery_time = None + if self.stats.last_failure_time: + recovery_time = max( + 0, self.config.recovery_timeout - (time.time() - self.stats.last_failure_time) + ) + raise CircuitBreakerOpenError( + f"Circuit breaker '{self.name}' is OPEN. Retry after {recovery_time:.1f}s" if recovery_time else "", + provider=self.name, + recovery_time=recovery_time, + failure_count=self.stats.current_consecutive_failures, + ) + + def record_success(self) -> None: + with self._lock: + self.stats.total_requests += 1 + self.stats.successful_requests += 1 + self.stats.last_success_time = time.time() + self.stats.current_consecutive_failures = 0 + self.stats.current_consecutive_successes += 1 + + if self._state == CircuitBreakerState.HALF_OPEN: + if self.stats.current_consecutive_successes >= self.config.success_threshold: + self._transition_to(CircuitBreakerState.CLOSED) + self.stats.current_consecutive_successes = 0 + + def record_failure(self, exception: Optional[Exception] = None) -> None: + with self._lock: + if exception and not self._should_count_as_failure(exception): + return + + self.stats.total_requests += 1 + self.stats.failed_requests += 1 + self.stats.last_failure_time = time.time() + self.stats.current_consecutive_failures += 1 + self.stats.current_consecutive_successes = 0 + self._failure_times.append(time.time()) + + failure_count = self._count_recent_failures() + logger.warning( + f"Circuit breaker '{self.name}' failure ({failure_count}/{self.config.failure_threshold})" + ) + + if self._state == CircuitBreakerState.HALF_OPEN: + self._transition_to(CircuitBreakerState.OPEN) + elif self._state == CircuitBreakerState.CLOSED: + if failure_count >= self.config.failure_threshold: + self._transition_to(CircuitBreakerState.OPEN) + + def reset(self) -> None: + with self._lock: + self._state = CircuitBreakerState.CLOSED + self.stats = CircuitBreakerStats() + self._failure_times = [] + + def get_stats(self) -> dict[str, Any]: + with self._lock: + return { + "name": self.name, + "state": self._state.value, + "total_requests": self.stats.total_requests, + "successful_requests": self.stats.successful_requests, + "failed_requests": self.stats.failed_requests, + "rejected_requests": self.stats.rejected_requests, + "state_transitions": self.stats.state_transitions, + "consecutive_failures": self.stats.current_consecutive_failures, + "consecutive_successes": self.stats.current_consecutive_successes, + "last_failure_time": self.stats.last_failure_time, + "last_success_time": self.stats.last_success_time, + } + + def __repr__(self) -> str: + return ( + f"CircuitBreaker(name='{self.name}', state={self._state.value}, " + f"failures={self.stats.current_consecutive_failures}/{self.config.failure_threshold})" + ) + + +class CircuitBreakerRegistry: + """Registry for managing multiple circuit breakers by name.""" + + def __init__(self, default_config: Optional[CircuitBreakerConfig] = None): + self._breakers: dict[str, CircuitBreaker] = {} + self._lock = threading.RLock() + self._default_config = default_config or CircuitBreakerConfig() + + def get(self, name: str) -> Optional[CircuitBreaker]: + with self._lock: + return self._breakers.get(name) + + def get_or_create( + self, + name: str, + failure_threshold: Optional[int] = None, + recovery_timeout: Optional[float] = None, + **kwargs: Any, + ) -> CircuitBreaker: + with self._lock: + if name not in self._breakers: + self._breakers[name] = CircuitBreaker( + name=name, + failure_threshold=failure_threshold or self._default_config.failure_threshold, + recovery_timeout=recovery_timeout or self._default_config.recovery_timeout, + success_threshold=kwargs.get("success_threshold", self._default_config.success_threshold), + failure_window=kwargs.get("failure_window", self._default_config.failure_window), + failure_exceptions=kwargs.get("failure_exceptions", self._default_config.failure_exceptions), + excluded_exceptions=kwargs.get("excluded_exceptions", self._default_config.excluded_exceptions), + ) + return self._breakers[name] + + def get_all_stats(self) -> dict[str, dict[str, Any]]: + with self._lock: + return {name: cb.get_stats() for name, cb in self._breakers.items()} + + def reset_all(self) -> None: + with self._lock: + for cb in self._breakers.values(): + cb.reset() + + def remove(self, name: str) -> bool: + with self._lock: + if name in self._breakers: + del self._breakers[name] + return True + return False + + +_global_registry: Optional[CircuitBreakerRegistry] = None +_global_registry_lock = threading.Lock() + + +def get_circuit_breaker_registry() -> CircuitBreakerRegistry: + global _global_registry + if _global_registry is None: + with _global_registry_lock: + if _global_registry is None: + _global_registry = CircuitBreakerRegistry() + return _global_registry + + +def get_circuit_breaker(name: str, **kwargs: Any) -> CircuitBreaker: + return get_circuit_breaker_registry().get_or_create(name, **kwargs) diff --git a/src/arc_agi_benchmarking/resilience/timeout.py b/src/arc_agi_benchmarking/resilience/timeout.py new file mode 100644 index 00000000..d8a01df9 --- /dev/null +++ b/src/arc_agi_benchmarking/resilience/timeout.py @@ -0,0 +1,137 @@ +"""Timeout utilities for API calls and task execution.""" + +import asyncio +import logging +import sys +import time +from contextlib import asynccontextmanager, contextmanager +from functools import wraps +from typing import Any, Callable, Optional, TypeVar + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +# Python 3.11+ has asyncio.timeout, older versions need a polyfill +_HAS_ASYNCIO_TIMEOUT = sys.version_info >= (3, 11) + + +class TaskTimeoutError(Exception): + """Raised when a task or request times out.""" + + def __init__(self, message: str, elapsed: Optional[float] = None, timeout: Optional[float] = None): + super().__init__(message) + self.elapsed = elapsed + self.timeout = timeout + + +if _HAS_ASYNCIO_TIMEOUT: + @asynccontextmanager + async def request_timeout(seconds: float, operation: str = "request"): + """Async context manager for request timeouts.""" + if seconds <= 0: + yield + return + + start_time = time.monotonic() + try: + async with asyncio.timeout(seconds): + yield + except asyncio.TimeoutError: + elapsed = time.monotonic() - start_time + raise TaskTimeoutError( + f"{operation} timed out after {elapsed:.2f}s (limit: {seconds}s)", + elapsed=elapsed, + timeout=seconds, + ) +else: + @asynccontextmanager + async def request_timeout(seconds: float, operation: str = "request"): + """Async context manager for request timeouts (Python 3.10 compatible).""" + if seconds <= 0: + yield + return + + start_time = time.monotonic() + task = asyncio.current_task() + loop = asyncio.get_event_loop() + + timeout_handle = loop.call_later(seconds, task.cancel) + try: + yield + except asyncio.CancelledError: + elapsed = time.monotonic() - start_time + if elapsed >= seconds: + raise TaskTimeoutError( + f"{operation} timed out after {elapsed:.2f}s (limit: {seconds}s)", + elapsed=elapsed, + timeout=seconds, + ) + raise + finally: + timeout_handle.cancel() + + +@contextmanager +def sync_timeout(seconds: float, operation: str = "operation"): + """Tracks elapsed time and logs warning if exceeded. Does NOT interrupt.""" + start_time = time.monotonic() + try: + yield + finally: + elapsed = time.monotonic() - start_time + if elapsed > seconds: + logger.warning(f"{operation} exceeded duration: {elapsed:.2f}s > {seconds}s") + + +async def task_timeout( + coro_or_func: Callable[..., T], + timeout_seconds: float, + operation: str = "task", + *args: Any, + **kwargs: Any, +) -> T: + """Execute a coroutine or sync function with a timeout. + + Note: For sync functions, the timeout only cancels the await, not the thread. + The underlying thread continues running until completion. + """ + if timeout_seconds <= 0: + if asyncio.iscoroutinefunction(coro_or_func): + return await coro_or_func(*args, **kwargs) + else: + return await asyncio.to_thread(coro_or_func, *args, **kwargs) + + start_time = time.monotonic() + try: + if asyncio.iscoroutinefunction(coro_or_func): + coro = coro_or_func(*args, **kwargs) + else: + coro = asyncio.to_thread(coro_or_func, *args, **kwargs) + return await asyncio.wait_for(coro, timeout=timeout_seconds) + except asyncio.TimeoutError: + elapsed = time.monotonic() - start_time + raise TaskTimeoutError( + f"{operation} timed out after {elapsed:.2f}s (limit: {timeout_seconds}s)", + elapsed=elapsed, + timeout=timeout_seconds, + ) + + +def with_timeout(timeout_seconds: float, operation: Optional[str] = None): + """Decorator to add timeout to async functions.""" + def decorator(func: Callable[..., T]) -> Callable[..., T]: + op_name = operation or func.__name__ + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + return await task_timeout(func, timeout_seconds, op_name, *args, **kwargs) + + return wrapper + + return decorator + + +DEFAULT_REQUEST_TIMEOUT = 300 +DEFAULT_REASONING_TIMEOUT = 900 +DEFAULT_TASK_TIMEOUT = 1800 diff --git a/src/arc_agi_benchmarking/storage/__init__.py b/src/arc_agi_benchmarking/storage/__init__.py new file mode 100644 index 00000000..d25c9ef3 --- /dev/null +++ b/src/arc_agi_benchmarking/storage/__init__.py @@ -0,0 +1,15 @@ +"""Storage abstraction layer for checkpoints and submissions.""" + +from arc_agi_benchmarking.storage.base import StorageBackend +from arc_agi_benchmarking.storage.filesystem import LocalStorageBackend + +__all__ = [ + "StorageBackend", + "LocalStorageBackend", +] + +try: + from arc_agi_benchmarking.storage.s3 import S3StorageBackend + __all__.append("S3StorageBackend") +except ImportError: + pass diff --git a/src/arc_agi_benchmarking/storage/base.py b/src/arc_agi_benchmarking/storage/base.py new file mode 100644 index 00000000..8a532a07 --- /dev/null +++ b/src/arc_agi_benchmarking/storage/base.py @@ -0,0 +1,65 @@ +"""Abstract base class for storage backends.""" + +from abc import ABC, abstractmethod +from typing import Optional + + +class StorageBackend(ABC): + """Abstract storage backend for checkpoints and submissions. + + Keys are path-like strings (e.g., "checkpoints/task_123.json"). + """ + + @abstractmethod + def read(self, key: str) -> Optional[bytes]: + """Read data from storage. Returns None if key doesn't exist.""" + pass + + @abstractmethod + def write(self, key: str, data: bytes) -> None: + """Write data to storage atomically.""" + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """Check if a key exists in storage.""" + pass + + @abstractmethod + def delete(self, key: str) -> None: + """Delete a key from storage. No-op if key doesn't exist.""" + pass + + @abstractmethod + def list_keys(self, prefix: str) -> list[str]: + """List all keys with a given prefix.""" + pass + + def read_text(self, key: str, encoding: str = "utf-8") -> Optional[str]: + """Read data as text.""" + data = self.read(key) + if data is None: + return None + return data.decode(encoding) + + def write_text(self, key: str, text: str, encoding: str = "utf-8") -> None: + """Write text data.""" + self.write(key, text.encode(encoding)) + + +class StorageError(Exception): + """Base exception for storage errors.""" + + pass + + +class StorageWriteError(StorageError): + """Raised when a write operation fails.""" + + pass + + +class StorageReadError(StorageError): + """Raised when a read operation fails unexpectedly.""" + + pass diff --git a/src/arc_agi_benchmarking/storage/filesystem.py b/src/arc_agi_benchmarking/storage/filesystem.py new file mode 100644 index 00000000..44b406bb --- /dev/null +++ b/src/arc_agi_benchmarking/storage/filesystem.py @@ -0,0 +1,91 @@ +"""Local filesystem storage backend.""" + +import logging +import os +from pathlib import Path +from typing import Optional + +from arc_agi_benchmarking.storage.base import ( + StorageBackend, + StorageReadError, + StorageWriteError, +) + +logger = logging.getLogger(__name__) + + +class LocalStorageBackend(StorageBackend): + """Filesystem-based storage backend. Writes are atomic via temp file + rename.""" + + def __init__(self, base_dir: Path | str): + self.base_dir = Path(base_dir) + self.base_dir.mkdir(parents=True, exist_ok=True) + + def _resolve_path(self, key: str) -> Path: + """Resolve a key to a full filesystem path within base_dir.""" + path = (self.base_dir / key).resolve() + try: + path.relative_to(self.base_dir.resolve()) + except ValueError: + raise ValueError(f"Key '{key}' would escape base directory") + return path + + def read(self, key: str) -> Optional[bytes]: + path = self._resolve_path(key) + if not path.exists(): + return None + try: + return path.read_bytes() + except PermissionError as e: + raise StorageReadError(f"Permission denied reading {key}: {e}") + except OSError as e: + raise StorageReadError(f"Failed to read {key}: {e}") + + def write(self, key: str, data: bytes) -> None: + path = self._resolve_path(key) + path.parent.mkdir(parents=True, exist_ok=True) + temp_path = path.with_suffix(path.suffix + ".tmp") + try: + temp_path.write_bytes(data) + os.replace(str(temp_path), str(path)) + except PermissionError as e: + temp_path.unlink(missing_ok=True) + raise StorageWriteError(f"Permission denied writing {key}: {e}") + except OSError as e: + temp_path.unlink(missing_ok=True) + raise StorageWriteError(f"Failed to write {key}: {e}") + + def exists(self, key: str) -> bool: + return self._resolve_path(key).exists() + + def delete(self, key: str) -> None: + self._resolve_path(key).unlink(missing_ok=True) + + def list_keys(self, prefix: str) -> list[str]: + prefix_path = self._resolve_path(prefix) + + if prefix_path.is_dir(): + keys = [] + for path in prefix_path.rglob("*"): + if path.is_file(): + keys.append(str(path.relative_to(self.base_dir))) + return sorted(keys) + + parent = prefix_path.parent + if not parent.exists(): + return [] + + prefix_name = prefix_path.name + keys = [] + for path in parent.iterdir(): + if path.is_file() and path.name.startswith(prefix_name): + keys.append(str(path.relative_to(self.base_dir))) + elif path.is_dir() and path.name.startswith(prefix_name): + for subpath in path.rglob("*"): + if subpath.is_file(): + keys.append(str(subpath.relative_to(self.base_dir))) + + return sorted(keys) + + def __repr__(self) -> str: + return f"LocalStorageBackend(base_dir={self.base_dir})" diff --git a/src/arc_agi_benchmarking/storage/s3.py b/src/arc_agi_benchmarking/storage/s3.py new file mode 100644 index 00000000..a15c3a10 --- /dev/null +++ b/src/arc_agi_benchmarking/storage/s3.py @@ -0,0 +1,99 @@ +"""S3 storage backend for AWS deployments.""" + +import logging +from typing import Optional + +try: + import boto3 + from botocore.exceptions import ClientError +except ImportError: + raise ImportError( + "boto3 is required for S3StorageBackend. " + "Install it with: pip install boto3" + ) + +from arc_agi_benchmarking.storage.base import ( + StorageBackend, + StorageReadError, + StorageWriteError, +) + +logger = logging.getLogger(__name__) + + +class S3StorageBackend(StorageBackend): + """S3-based storage backend. S3 PutObject is atomic.""" + + def __init__( + self, + bucket: str, + prefix: str = "", + region_name: Optional[str] = None, + ): + self.bucket = bucket + self.prefix = prefix.strip("/") + self.s3 = boto3.client("s3", region_name=region_name) + + def _full_key(self, key: str) -> str: + if self.prefix: + return f"{self.prefix}/{key}" + return key + + def _strip_prefix(self, full_key: str) -> str: + if self.prefix and full_key.startswith(f"{self.prefix}/"): + return full_key[len(self.prefix) + 1 :] + return full_key + + def read(self, key: str) -> Optional[bytes]: + full_key = self._full_key(key) + try: + response = self.s3.get_object(Bucket=self.bucket, Key=full_key) + return response["Body"].read() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "NoSuchKey": + return None + logger.error(f"S3 read error for {key}: {e}") + raise StorageReadError(f"Failed to read {key} from S3: {e}") + + def write(self, key: str, data: bytes) -> None: + full_key = self._full_key(key) + try: + self.s3.put_object(Bucket=self.bucket, Key=full_key, Body=data) + except ClientError as e: + logger.error(f"S3 write error for {key}: {e}") + raise StorageWriteError(f"Failed to write {key} to S3: {e}") + + def exists(self, key: str) -> bool: + full_key = self._full_key(key) + try: + self.s3.head_object(Bucket=self.bucket, Key=full_key) + return True + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "") + if error_code == "404": + return False + raise StorageReadError(f"Failed to check existence of {key}: {e}") + + def delete(self, key: str) -> None: + full_key = self._full_key(key) + try: + self.s3.delete_object(Bucket=self.bucket, Key=full_key) + except ClientError as e: + logger.warning(f"S3 delete error for {key}: {e}") + + def list_keys(self, prefix: str) -> list[str]: + full_prefix = self._full_key(prefix) + keys = [] + paginator = self.s3.get_paginator("list_objects_v2") + try: + for page in paginator.paginate(Bucket=self.bucket, Prefix=full_prefix): + for obj in page.get("Contents", []): + keys.append(self._strip_prefix(obj["Key"])) + except ClientError as e: + raise StorageReadError(f"Failed to list keys with prefix {prefix}: {e}") + return sorted(keys) + + def __repr__(self) -> str: + prefix_str = f", prefix={self.prefix}" if self.prefix else "" + return f"S3StorageBackend(bucket={self.bucket}{prefix_str})" diff --git a/src/arc_agi_benchmarking/tests/test_checkpoint.py b/src/arc_agi_benchmarking/tests/test_checkpoint.py new file mode 100644 index 00000000..1a313222 --- /dev/null +++ b/src/arc_agi_benchmarking/tests/test_checkpoint.py @@ -0,0 +1,455 @@ +"""Tests for checkpointing and progress tracking.""" + +from datetime import datetime, timedelta, timezone +from decimal import Decimal +from pathlib import Path + +import pytest + +from arc_agi_benchmarking.checkpoint import ( + AttemptResult, + BatchProgress, + BatchProgressManager, + TaskCheckpoint, + TaskCheckpointManager, + TaskProgress, + TaskStatus, +) +from arc_agi_benchmarking.storage import LocalStorageBackend + + +class TestTaskStatus: + def test_status_values(self): + assert TaskStatus.PENDING.value == "pending" + assert TaskStatus.IN_PROGRESS.value == "in_progress" + assert TaskStatus.COMPLETED.value == "completed" + assert TaskStatus.FAILED.value == "failed" + + +class TestAttemptResult: + def test_to_dict_and_from_dict(self): + result = AttemptResult( + attempt_index=0, + test_pair_index=1, + response={"answer": [[1, 2], [3, 4]]}, + cost_usd=Decimal("0.005"), + tokens_input=100, + tokens_output=50, + duration_seconds=1.5, + ) + + data = result.to_dict() + restored = AttemptResult.from_dict(data) + + assert restored.attempt_index == 0 + assert restored.test_pair_index == 1 + assert restored.response == {"answer": [[1, 2], [3, 4]]} + assert restored.cost_usd == Decimal("0.005") + assert restored.tokens_input == 100 + assert restored.tokens_output == 50 + assert restored.duration_seconds == 1.5 + + def test_with_error(self): + result = AttemptResult( + attempt_index=0, + test_pair_index=0, + response=None, + error="API timeout", + ) + + data = result.to_dict() + restored = AttemptResult.from_dict(data) + + assert restored.error == "API timeout" + assert restored.response is None + + +class TestTaskCheckpoint: + def test_to_dict_and_from_dict(self): + checkpoint = TaskCheckpoint( + task_id="task_001", + total_cost_usd=Decimal("0.01"), + ) + checkpoint.completed_attempts.append( + AttemptResult( + attempt_index=0, + test_pair_index=0, + response={"answer": [[1]]}, + ) + ) + + data = checkpoint.to_dict() + restored = TaskCheckpoint.from_dict(data) + + assert restored.task_id == "task_001" + assert restored.total_cost_usd == Decimal("0.01") + assert len(restored.completed_attempts) == 1 + + def test_unsupported_version_raises(self): + data = {"schema_version": 99, "task_id": "test"} + with pytest.raises(ValueError, match="Unsupported checkpoint schema version"): + TaskCheckpoint.from_dict(data) + + +class TestTaskProgress: + def test_to_dict_and_from_dict(self): + progress = TaskProgress( + task_id="task_001", + status=TaskStatus.IN_PROGRESS, + attempts_completed=2, + attempts_total=3, + cost_usd=Decimal("0.05"), + worker_id="worker_1", + started_at=datetime.utcnow(), + ) + + data = progress.to_dict() + restored = TaskProgress.from_dict(data) + + assert restored.task_id == "task_001" + assert restored.status == TaskStatus.IN_PROGRESS + assert restored.attempts_completed == 2 + assert restored.attempts_total == 3 + assert restored.cost_usd == Decimal("0.05") + + +class TestBatchProgress: + def test_to_dict_and_from_dict(self): + batch = BatchProgress(run_id="run_123") + batch.tasks["task_1"] = TaskProgress(task_id="task_1", status=TaskStatus.COMPLETED) + batch.tasks["task_2"] = TaskProgress(task_id="task_2", status=TaskStatus.PENDING) + + data = batch.to_dict() + restored = BatchProgress.from_dict(data) + + assert restored.run_id == "run_123" + assert len(restored.tasks) == 2 + assert restored.tasks["task_1"].status == TaskStatus.COMPLETED + + def test_count_properties(self): + batch = BatchProgress(run_id="test") + batch.tasks["t1"] = TaskProgress(task_id="t1", status=TaskStatus.PENDING) + batch.tasks["t2"] = TaskProgress(task_id="t2", status=TaskStatus.IN_PROGRESS) + batch.tasks["t3"] = TaskProgress(task_id="t3", status=TaskStatus.COMPLETED) + batch.tasks["t4"] = TaskProgress(task_id="t4", status=TaskStatus.FAILED) + + assert batch.pending_count == 1 + assert batch.in_progress_count == 1 + assert batch.completed_count == 1 + assert batch.failed_count == 1 + assert batch.total_count == 4 + + +class TestBatchProgressManager: + @pytest.fixture + def storage(self, tmp_path: Path) -> LocalStorageBackend: + return LocalStorageBackend(tmp_path) + + @pytest.fixture + def manager(self, storage: LocalStorageBackend) -> BatchProgressManager: + return BatchProgressManager(storage, run_id="test_run") + + def test_initialize_empty_task_list(self, manager: BatchProgressManager): + manager.initialize_tasks([]) + assert manager.progress.total_count == 0 + + def test_claim_nonexistent_task(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"]) + assert not manager.claim_task("nonexistent") + + def test_initialize_preserves_existing(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"], attempts_per_task=2) + manager.claim_task("task_1") + manager.mark_completed("task_1") + + manager.initialize_tasks(["task_1", "task_2"], attempts_per_task=3) + + assert manager.progress.tasks["task_1"].status == TaskStatus.COMPLETED + assert manager.progress.tasks["task_1"].attempts_total == 2 + assert manager.progress.tasks["task_2"].attempts_total == 3 + + def test_corrupted_json_recovery(self, storage: LocalStorageBackend): + storage.write_text("progress.json", "not valid json{{{") + manager = BatchProgressManager(storage, run_id="test_run") + assert manager.progress.total_count == 0 + + def test_initialize_tasks(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1", "task_2", "task_3"], attempts_per_task=2) + + assert manager.progress.total_count == 3 + assert all( + t.status == TaskStatus.PENDING for t in manager.progress.tasks.values() + ) + assert all(t.attempts_total == 2 for t in manager.progress.tasks.values()) + + def test_claim_task(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"]) + + assert manager.claim_task("task_1") + assert manager.progress.tasks["task_1"].status == TaskStatus.IN_PROGRESS + assert manager.progress.tasks["task_1"].worker_id is not None + + def test_claim_task_already_claimed(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"]) + manager.claim_task("task_1") + + assert not manager.claim_task("task_1") + + def test_claim_next_task(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1", "task_2"]) + + task_id = manager.claim_next_task() + assert task_id == "task_1" + + task_id = manager.claim_next_task() + assert task_id == "task_2" + + task_id = manager.claim_next_task() + assert task_id is None + + def test_mark_completed(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"]) + manager.claim_task("task_1") + manager.mark_completed( + "task_1", + cost_usd=Decimal("0.05"), + tokens_input=100, + tokens_output=50, + ) + + task = manager.progress.tasks["task_1"] + assert task.status == TaskStatus.COMPLETED + assert task.completed_at is not None + assert manager.progress.total_cost_usd == Decimal("0.05") + + def test_mark_failed(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"]) + manager.claim_task("task_1") + manager.mark_failed("task_1", error="API error") + + task = manager.progress.tasks["task_1"] + assert task.status == TaskStatus.FAILED + assert task.error == "API error" + + def test_mark_failed_accumulates_costs(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1", "task_2"]) + manager.claim_task("task_1") + manager.mark_completed( + "task_1", + cost_usd=Decimal("0.05"), + tokens_input=100, + tokens_output=50, + ) + manager.claim_task("task_2") + manager.mark_failed( + "task_2", + error="API error", + cost_usd=Decimal("0.03"), + tokens_input=80, + tokens_output=20, + ) + + assert manager.progress.total_cost_usd == Decimal("0.08") + assert manager.progress.total_tokens_input == 180 + assert manager.progress.total_tokens_output == 70 + assert manager.progress.tasks["task_2"].cost_usd == Decimal("0.03") + + def test_run_id_mismatch_starts_fresh(self, storage: LocalStorageBackend): + manager1 = BatchProgressManager(storage, run_id="run_1") + manager1.initialize_tasks(["task_1", "task_2"]) + manager1.claim_task("task_1") + manager1.mark_completed("task_1") + + manager2 = BatchProgressManager(storage, run_id="run_2") + + assert manager2.progress.run_id == "run_2" + assert manager2.progress.total_count == 0 + + def test_is_complete(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1", "task_2"]) + + assert not manager.is_complete() + + manager.claim_task("task_1") + manager.mark_completed("task_1") + assert not manager.is_complete() + + manager.claim_task("task_2") + manager.mark_failed("task_2", "error") + assert manager.is_complete() + + def test_persistence(self, storage: LocalStorageBackend): + manager1 = BatchProgressManager(storage, run_id="test_run") + manager1.initialize_tasks(["task_1", "task_2"]) + manager1.claim_task("task_1") + manager1.mark_completed("task_1", cost_usd=Decimal("0.10")) + + manager2 = BatchProgressManager(storage, run_id="test_run") + + assert manager2.progress.total_count == 2 + assert manager2.progress.tasks["task_1"].status == TaskStatus.COMPLETED + assert manager2.progress.tasks["task_2"].status == TaskStatus.PENDING + assert manager2.progress.total_cost_usd == Decimal("0.10") + + def test_reset_stale_tasks(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1"]) + manager.claim_task("task_1") + + manager.progress.tasks["task_1"].started_at = datetime.now(timezone.utc) - timedelta( + hours=2 + ) + manager._save() + + reset_count = manager.reset_stale_tasks(max_age_seconds=3600) + + assert reset_count == 1 + assert manager.progress.tasks["task_1"].status == TaskStatus.PENDING + + def test_get_summary(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1", "task_2"]) + manager.claim_task("task_1") + manager.mark_completed("task_1", cost_usd=Decimal("0.05")) + + summary = manager.get_summary() + + assert summary["total"] == 2 + assert summary["completed"] == 1 + assert summary["pending"] == 1 + assert summary["total_cost_usd"] == "0.05" + + def test_retry_failed_tasks(self, manager: BatchProgressManager): + manager.initialize_tasks(["task_1", "task_2", "task_3"]) + manager.claim_task("task_1") + manager.mark_completed("task_1") + manager.claim_task("task_2") + manager.mark_failed("task_2", "error1") + manager.claim_task("task_3") + manager.mark_failed("task_3", "error2") + + assert manager.progress.completed_count == 1 + assert manager.progress.failed_count == 2 + + reset_count = manager.retry_failed_tasks() + + assert reset_count == 2 + assert manager.progress.completed_count == 1 + assert manager.progress.failed_count == 0 + assert manager.progress.pending_count == 2 + assert manager.progress.tasks["task_2"].status == TaskStatus.PENDING + assert manager.progress.tasks["task_2"].error is None + assert manager.progress.tasks["task_3"].status == TaskStatus.PENDING + + +class TestTaskCheckpointManager: + @pytest.fixture + def storage(self, tmp_path: Path) -> LocalStorageBackend: + return LocalStorageBackend(tmp_path) + + @pytest.fixture + def manager(self, storage: LocalStorageBackend) -> TaskCheckpointManager: + return TaskCheckpointManager(storage, task_id="task_001") + + def test_corrupted_json_recovery(self, storage: LocalStorageBackend): + storage.write_text("checkpoints/task_001.json", "{invalid") + manager = TaskCheckpointManager(storage, task_id="task_001") + assert len(manager.get_completed_attempts()) == 0 + + def test_multiple_test_pairs(self, manager: TaskCheckpointManager): + manager.record_attempt(test_pair_index=0, attempt_index=0, response="a") + manager.record_attempt(test_pair_index=1, attempt_index=0, response="b") + manager.record_attempt(test_pair_index=2, attempt_index=0, response="c") + + assert len(manager.get_results_for_test_pair(0)) == 1 + assert len(manager.get_results_for_test_pair(1)) == 1 + assert len(manager.get_results_for_test_pair(2)) == 1 + assert manager.checkpoint.total_cost_usd == Decimal("0") + + def test_record_attempt(self, manager: TaskCheckpointManager): + manager.record_attempt( + test_pair_index=0, + attempt_index=0, + response={"answer": [[1, 2]]}, + cost_usd=Decimal("0.01"), + tokens_input=50, + tokens_output=25, + ) + + attempts = manager.get_completed_attempts() + assert len(attempts) == 1 + assert attempts[0].test_pair_index == 0 + assert attempts[0].attempt_index == 0 + + def test_get_next_attempt_index(self, manager: TaskCheckpointManager): + assert manager.get_next_attempt_index(test_pair_index=0, max_attempts=2) == 0 + + manager.record_attempt( + test_pair_index=0, + attempt_index=0, + response=None, + ) + assert manager.get_next_attempt_index(test_pair_index=0, max_attempts=2) == 1 + + manager.record_attempt( + test_pair_index=0, + attempt_index=1, + response=None, + ) + assert manager.get_next_attempt_index(test_pair_index=0, max_attempts=2) is None + + def test_is_test_pair_complete(self, manager: TaskCheckpointManager): + assert not manager.is_test_pair_complete(test_pair_index=0, max_attempts=2) + + manager.record_attempt(test_pair_index=0, attempt_index=0, response=None) + assert not manager.is_test_pair_complete(test_pair_index=0, max_attempts=2) + + manager.record_attempt(test_pair_index=0, attempt_index=1, response=None) + assert manager.is_test_pair_complete(test_pair_index=0, max_attempts=2) + + def test_get_results_for_test_pair(self, manager: TaskCheckpointManager): + manager.record_attempt(test_pair_index=0, attempt_index=0, response="a") + manager.record_attempt(test_pair_index=1, attempt_index=0, response="b") + manager.record_attempt(test_pair_index=0, attempt_index=1, response="c") + + results = manager.get_results_for_test_pair(0) + assert len(results) == 2 + assert results[0].response == "a" + assert results[1].response == "c" + + def test_persistence(self, storage: LocalStorageBackend): + manager1 = TaskCheckpointManager(storage, task_id="task_001") + manager1.record_attempt( + test_pair_index=0, + attempt_index=0, + response={"answer": [[1]]}, + cost_usd=Decimal("0.02"), + ) + + manager2 = TaskCheckpointManager(storage, task_id="task_001") + + attempts = manager2.get_completed_attempts() + assert len(attempts) == 1 + assert manager2.checkpoint.total_cost_usd == Decimal("0.02") + + def test_delete_checkpoint(self, manager: TaskCheckpointManager, storage: LocalStorageBackend): + manager.record_attempt(test_pair_index=0, attempt_index=0, response=None) + assert storage.exists(manager.checkpoint_key) + + manager.delete_checkpoint() + assert not storage.exists(manager.checkpoint_key) + + def test_get_summary(self, manager: TaskCheckpointManager): + manager.record_attempt( + test_pair_index=0, + attempt_index=0, + response=None, + cost_usd=Decimal("0.01"), + tokens_input=100, + tokens_output=50, + ) + + summary = manager.get_summary() + + assert summary["task_id"] == "task_001" + assert summary["completed_attempts"] == 1 + assert summary["total_cost_usd"] == "0.01" + assert summary["total_tokens_input"] == 100 diff --git a/src/arc_agi_benchmarking/tests/test_resilience.py b/src/arc_agi_benchmarking/tests/test_resilience.py new file mode 100644 index 00000000..bb7ee3ff --- /dev/null +++ b/src/arc_agi_benchmarking/tests/test_resilience.py @@ -0,0 +1,562 @@ +""" +Tests for the resilience module (timeout and circuit breaker functionality). +""" + +import asyncio +import time +import pytest +from unittest.mock import patch, MagicMock + +from arc_agi_benchmarking.resilience import ( + CircuitBreaker, + CircuitBreakerOpenError, + CircuitBreakerState, + TaskTimeoutError, + request_timeout, + task_timeout, +) +from arc_agi_benchmarking.resilience.circuit_breaker import ( + CircuitBreakerConfig, + CircuitBreakerRegistry, + get_circuit_breaker, + get_circuit_breaker_registry, +) +from arc_agi_benchmarking.resilience.timeout import ( + with_timeout, + sync_timeout, + DEFAULT_REQUEST_TIMEOUT, + DEFAULT_REASONING_TIMEOUT, + DEFAULT_TASK_TIMEOUT, +) + + +# ============================================================================= +# Timeout Tests +# ============================================================================= + + +@pytest.mark.asyncio +class TestRequestTimeout: + """Tests for the request_timeout async context manager.""" + + async def test_successful_operation_within_timeout(self): + """Test that operations completing within timeout succeed.""" + async with request_timeout(1.0, "test operation"): + await asyncio.sleep(0.1) + # Should complete without raising + + async def test_timeout_raises_task_timeout_error(self): + """Test that operations exceeding timeout raise TaskTimeoutError.""" + with pytest.raises(TaskTimeoutError) as exc_info: + async with request_timeout(0.1, "slow operation"): + await asyncio.sleep(1.0) + + assert "slow operation" in str(exc_info.value) + assert exc_info.value.timeout == 0.1 + assert exc_info.value.elapsed is not None + assert exc_info.value.elapsed >= 0.1 + + async def test_invalid_timeout_value_logs_warning(self): + """Test that invalid (zero or negative) timeout values are handled.""" + # With zero timeout, should execute without timeout enforcement + async with request_timeout(0, "zero timeout"): + await asyncio.sleep(0.05) + + # With negative timeout, should execute without timeout enforcement + async with request_timeout(-1, "negative timeout"): + await asyncio.sleep(0.05) + + +@pytest.mark.asyncio +class TestTaskTimeout: + """Tests for the task_timeout function.""" + + async def test_async_function_within_timeout(self): + """Test async function completing within timeout.""" + async def async_func(value: int) -> int: + await asyncio.sleep(0.05) + return value * 2 + + result = await task_timeout(async_func, 1.0, "async test", 5) + assert result == 10 + + async def test_sync_function_within_timeout(self): + """Test sync function completing within timeout.""" + def sync_func(value: int) -> int: + time.sleep(0.05) + return value * 3 + + result = await task_timeout(sync_func, 1.0, "sync test", 7) + assert result == 21 + + async def test_async_function_timeout(self): + """Test async function exceeding timeout.""" + async def slow_async(): + await asyncio.sleep(2.0) + return "never reached" + + with pytest.raises(TaskTimeoutError) as exc_info: + await task_timeout(slow_async, 0.1, "slow async") + + assert "slow async" in str(exc_info.value) + assert exc_info.value.timeout == 0.1 + + async def test_sync_function_timeout(self): + """Test sync function exceeding timeout.""" + def slow_sync(): + time.sleep(2.0) + return "never reached" + + with pytest.raises(TaskTimeoutError) as exc_info: + await task_timeout(slow_sync, 0.1, "slow sync") + + assert "slow sync" in str(exc_info.value) + + async def test_kwargs_passed_correctly(self): + """Test that kwargs are passed to the function.""" + async def func_with_kwargs(a: int, b: int = 0) -> int: + return a + b + + result = await task_timeout(func_with_kwargs, 1.0, "kwargs test", 5, b=10) + assert result == 15 + + +@pytest.mark.asyncio +class TestWithTimeoutDecorator: + """Tests for the with_timeout decorator.""" + + async def test_decorator_on_async_function(self): + """Test decorator on async function.""" + @with_timeout(1.0, "decorated function") + async def decorated_func(): + await asyncio.sleep(0.05) + return "success" + + result = await decorated_func() + assert result == "success" + + async def test_decorator_timeout(self): + """Test decorator raises timeout.""" + @with_timeout(0.1, "slow decorated") + async def slow_decorated(): + await asyncio.sleep(1.0) + + with pytest.raises(TaskTimeoutError): + await slow_decorated() + + async def test_decorator_uses_function_name_as_default(self): + """Test decorator uses function name when operation not specified.""" + @with_timeout(0.1) + async def my_slow_function(): + await asyncio.sleep(1.0) + + with pytest.raises(TaskTimeoutError) as exc_info: + await my_slow_function() + + assert "my_slow_function" in str(exc_info.value) + + +class TestSyncTimeout: + """Tests for the sync_timeout context manager.""" + + def test_logs_warning_when_exceeded(self): + """Test that sync_timeout logs a warning when exceeded.""" + # sync_timeout doesn't interrupt, just logs warnings + with sync_timeout(0.01, "quick operation"): + time.sleep(0.05) # Will exceed but won't raise + + def test_no_warning_when_within_limit(self): + """Test that sync_timeout doesn't warn when within limit.""" + with sync_timeout(1.0, "within limit"): + time.sleep(0.01) + + +class TestDefaultTimeoutValues: + """Tests for default timeout constants.""" + + def test_default_values_are_reasonable(self): + """Verify default timeout values are sensible.""" + assert DEFAULT_REQUEST_TIMEOUT == 300 # 5 minutes + assert DEFAULT_REASONING_TIMEOUT == 900 # 15 minutes + assert DEFAULT_TASK_TIMEOUT == 1800 # 30 minutes + assert DEFAULT_REQUEST_TIMEOUT < DEFAULT_REASONING_TIMEOUT < DEFAULT_TASK_TIMEOUT + + +# ============================================================================= +# Circuit Breaker Tests +# ============================================================================= + + +class TestCircuitBreakerBasics: + """Tests for basic CircuitBreaker functionality.""" + + def test_initial_state_is_closed(self): + """Test that a new circuit breaker starts in CLOSED state.""" + cb = CircuitBreaker("test") + assert cb.state == CircuitBreakerState.CLOSED + assert cb.can_execute() is True + + def test_record_success(self): + """Test recording a successful request.""" + cb = CircuitBreaker("test") + cb.record_success() + + stats = cb.get_stats() + assert stats["total_requests"] == 1 + assert stats["successful_requests"] == 1 + assert stats["failed_requests"] == 0 + assert stats["consecutive_successes"] == 1 + + def test_record_failure(self): + """Test recording a failed request.""" + cb = CircuitBreaker("test", failure_threshold=5) + cb.record_failure() + + stats = cb.get_stats() + assert stats["total_requests"] == 1 + assert stats["failed_requests"] == 1 + assert stats["consecutive_failures"] == 1 + assert cb.state == CircuitBreakerState.CLOSED # Still closed + + def test_opens_after_threshold(self): + """Test that circuit opens after reaching failure threshold.""" + cb = CircuitBreaker("test", failure_threshold=3) + + for i in range(3): + cb.record_failure() + + assert cb.state == CircuitBreakerState.OPEN + assert cb.can_execute() is False + + def test_raise_if_open(self): + """Test that raise_if_open raises when circuit is open.""" + cb = CircuitBreaker("test", failure_threshold=1) + cb.record_failure() + + with pytest.raises(CircuitBreakerOpenError) as exc_info: + cb.raise_if_open() + + assert exc_info.value.provider == "test" + assert exc_info.value.failure_count == 1 + + def test_success_resets_consecutive_failures(self): + """Test that success resets consecutive failure count.""" + cb = CircuitBreaker("test", failure_threshold=5) + + cb.record_failure() + cb.record_failure() + assert cb.stats.current_consecutive_failures == 2 + + cb.record_success() + assert cb.stats.current_consecutive_failures == 0 + + +class TestCircuitBreakerStateTransitions: + """Tests for circuit breaker state transitions.""" + + def test_closed_to_open_transition(self): + """Test transition from CLOSED to OPEN.""" + cb = CircuitBreaker("test", failure_threshold=2) + + assert cb.state == CircuitBreakerState.CLOSED + cb.record_failure() + assert cb.state == CircuitBreakerState.CLOSED + cb.record_failure() + assert cb.state == CircuitBreakerState.OPEN + + def test_open_to_half_open_after_recovery(self): + """Test transition from OPEN to HALF_OPEN after recovery timeout.""" + cb = CircuitBreaker("test", failure_threshold=1, recovery_timeout=0.1) + cb.record_failure() + assert cb.state == CircuitBreakerState.OPEN + + # Wait for recovery timeout + time.sleep(0.15) + + # Accessing state triggers the transition + assert cb.state == CircuitBreakerState.HALF_OPEN + + def test_half_open_to_closed_on_success(self): + """Test transition from HALF_OPEN to CLOSED on success.""" + cb = CircuitBreaker("test", failure_threshold=1, recovery_timeout=0.1, success_threshold=2) + cb.record_failure() + time.sleep(0.15) + assert cb.state == CircuitBreakerState.HALF_OPEN + + cb.record_success() + assert cb.state == CircuitBreakerState.HALF_OPEN # Need 2 successes + + cb.record_success() + assert cb.state == CircuitBreakerState.CLOSED + + def test_half_open_to_open_on_failure(self): + """Test transition from HALF_OPEN to OPEN on failure.""" + cb = CircuitBreaker("test", failure_threshold=1, recovery_timeout=0.1) + cb.record_failure() + time.sleep(0.15) + assert cb.state == CircuitBreakerState.HALF_OPEN + + cb.record_failure() + assert cb.state == CircuitBreakerState.OPEN + + +class TestCircuitBreakerReset: + """Tests for circuit breaker reset functionality.""" + + def test_reset_returns_to_closed(self): + """Test that reset returns circuit to CLOSED state.""" + cb = CircuitBreaker("test", failure_threshold=1) + cb.record_failure() + assert cb.state == CircuitBreakerState.OPEN + + cb.reset() + assert cb.state == CircuitBreakerState.CLOSED + assert cb.stats.total_requests == 0 + assert cb.stats.failed_requests == 0 + + +class TestCircuitBreakerExceptionFiltering: + """Tests for circuit breaker exception filtering.""" + + def test_excluded_exceptions_not_counted(self): + """Test that excluded exceptions are not counted as failures.""" + cb = CircuitBreaker( + "test", + failure_threshold=2, + excluded_exceptions={ValueError}, + ) + + cb.record_failure(ValueError("excluded")) + assert cb.stats.failed_requests == 0 + assert cb.state == CircuitBreakerState.CLOSED + + cb.record_failure(RuntimeError("counted")) + cb.record_failure(RuntimeError("counted")) + assert cb.state == CircuitBreakerState.OPEN + + def test_only_specified_exceptions_counted(self): + """Test that only specified failure_exceptions are counted.""" + cb = CircuitBreaker( + "test", + failure_threshold=2, + failure_exceptions={TimeoutError}, + ) + + # This shouldn't count + cb.record_failure(ValueError("not counted")) + assert cb.stats.failed_requests == 0 + + # These should count + cb.record_failure(TimeoutError("counted")) + cb.record_failure(TimeoutError("counted")) + assert cb.state == CircuitBreakerState.OPEN + + +class TestCircuitBreakerRegistry: + """Tests for the CircuitBreakerRegistry.""" + + def test_get_or_create(self): + """Test get_or_create returns same instance for same name.""" + registry = CircuitBreakerRegistry() + + cb1 = registry.get_or_create("provider1", failure_threshold=5) + cb2 = registry.get_or_create("provider1", failure_threshold=10) # Different threshold + + assert cb1 is cb2 # Same instance + assert cb1.config.failure_threshold == 5 # Original threshold kept + + def test_different_names_different_instances(self): + """Test different names create different instances.""" + registry = CircuitBreakerRegistry() + + cb1 = registry.get_or_create("provider1") + cb2 = registry.get_or_create("provider2") + + assert cb1 is not cb2 + + def test_get_returns_none_if_not_exists(self): + """Test get returns None if circuit breaker doesn't exist.""" + registry = CircuitBreakerRegistry() + assert registry.get("nonexistent") is None + + def test_get_all_stats(self): + """Test getting stats for all circuit breakers.""" + registry = CircuitBreakerRegistry() + registry.get_or_create("provider1") + registry.get_or_create("provider2") + + stats = registry.get_all_stats() + assert "provider1" in stats + assert "provider2" in stats + + def test_reset_all(self): + """Test resetting all circuit breakers.""" + registry = CircuitBreakerRegistry() + cb1 = registry.get_or_create("provider1", failure_threshold=1) + cb2 = registry.get_or_create("provider2", failure_threshold=1) + + cb1.record_failure() + cb2.record_failure() + + registry.reset_all() + + assert cb1.state == CircuitBreakerState.CLOSED + assert cb2.state == CircuitBreakerState.CLOSED + + def test_remove(self): + """Test removing a circuit breaker from registry.""" + registry = CircuitBreakerRegistry() + registry.get_or_create("provider1") + + assert registry.remove("provider1") is True + assert registry.get("provider1") is None + assert registry.remove("nonexistent") is False + + +class TestGlobalCircuitBreakerRegistry: + """Tests for the global circuit breaker registry.""" + + def test_get_circuit_breaker_registry_returns_same_instance(self): + """Test global registry returns same instance.""" + registry1 = get_circuit_breaker_registry() + registry2 = get_circuit_breaker_registry() + assert registry1 is registry2 + + def test_get_circuit_breaker_convenience_function(self): + """Test convenience function for getting circuit breakers.""" + cb = get_circuit_breaker("test_provider", failure_threshold=3) + assert cb.name == "test_provider" + assert cb.config.failure_threshold == 3 + + +class TestCircuitBreakerStats: + """Tests for circuit breaker statistics.""" + + def test_stats_track_all_metrics(self): + """Test that stats track all expected metrics.""" + cb = CircuitBreaker("test", failure_threshold=5) + + cb.record_success() + cb.record_failure() + cb.record_success() + + stats = cb.get_stats() + + assert stats["name"] == "test" + assert stats["state"] == "closed" + assert stats["total_requests"] == 3 + assert stats["successful_requests"] == 2 + assert stats["failed_requests"] == 1 + assert stats["consecutive_successes"] == 1 + assert stats["consecutive_failures"] == 0 + + def test_rejected_requests_counted(self): + """Test that rejected requests are counted when circuit is open.""" + cb = CircuitBreaker("test", failure_threshold=1) + cb.record_failure() + + assert cb.state == CircuitBreakerState.OPEN + assert cb.stats.rejected_requests == 0 + + # raise_if_open should increment rejected_requests + with pytest.raises(CircuitBreakerOpenError): + cb.raise_if_open() + assert cb.stats.rejected_requests == 1 + + with pytest.raises(CircuitBreakerOpenError): + cb.raise_if_open() + assert cb.stats.rejected_requests == 2 + + +class TestCircuitBreakerThreadSafety: + """Tests for circuit breaker thread safety.""" + + def test_concurrent_record_operations(self): + """Test that concurrent operations are handled safely.""" + import threading + + cb = CircuitBreaker("test", failure_threshold=100) + errors = [] + + def record_ops(): + try: + for _ in range(50): + cb.record_failure() + cb.record_success() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=record_ops) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert cb.stats.total_requests == 1000 # 10 threads * 100 ops each + + +class TestCircuitBreakerRepr: + """Tests for circuit breaker string representation.""" + + def test_repr(self): + """Test __repr__ output.""" + cb = CircuitBreaker("test", failure_threshold=5) + cb.record_failure() + cb.record_failure() + + repr_str = repr(cb) + assert "test" in repr_str + assert "closed" in repr_str + assert "2/5" in repr_str + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +@pytest.mark.asyncio +class TestTimeoutAndCircuitBreakerIntegration: + """Integration tests for timeout and circuit breaker working together.""" + + async def test_timeout_feeds_circuit_breaker(self): + """Test that timeout errors can be recorded by circuit breaker.""" + cb = CircuitBreaker("test", failure_threshold=2) + + async def slow_operation(): + await asyncio.sleep(1.0) + + for _ in range(2): + try: + await task_timeout(slow_operation, 0.05, "slow op") + except TaskTimeoutError as e: + cb.record_failure(e) + + assert cb.state == CircuitBreakerState.OPEN + + async def test_circuit_breaker_prevents_timeout_attempts(self): + """Test that circuit breaker prevents further attempts after opening.""" + cb = CircuitBreaker("test", failure_threshold=1) + cb.record_failure() # Open the circuit + + # Circuit is open, should not attempt + with pytest.raises(CircuitBreakerOpenError): + cb.raise_if_open() + + # The slow operation is never called + call_count = 0 + + async def tracked_slow_operation(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(1.0) + + # Check circuit before attempting + try: + cb.raise_if_open() + await task_timeout(tracked_slow_operation, 0.1, "tracked") + except CircuitBreakerOpenError: + pass # Expected + + assert call_count == 0 # Operation was never called diff --git a/src/arc_agi_benchmarking/tests/test_run_all_retries.py b/src/arc_agi_benchmarking/tests/test_run_all_retries.py index d9ea3f85..66ff5dbb 100644 --- a/src/arc_agi_benchmarking/tests/test_run_all_retries.py +++ b/src/arc_agi_benchmarking/tests/test_run_all_retries.py @@ -11,6 +11,7 @@ # If 'src' is part of the import path (e.g. 'from src.cli.run_all ...'), adjust as needed. # For now, sticking to the simpler form based on common pytest setups from project root. from cli.run_all import run_single_test_wrapper, AsyncRequestRateLimiter +from arc_agi_benchmarking.resilience import CircuitBreaker # This class was unused and causing a PytestCollectionWarning. Removing it. # class TestRetryableException(Exception): @@ -72,12 +73,15 @@ async def test_retry_and_eventual_success(caplog): # Only pytest fixtures like c mock_arc_instance.generate_task_solution.side_effect = simulator.simulate_generate_task_solution limiter = AsyncRequestRateLimiter(rate=1000, capacity=1000) + circuit_breaker = CircuitBreaker("test_provider", failure_threshold=10) # Execute the function under test result = await run_single_test_wrapper( config_name, task_id, limiter, + circuit_breaker=circuit_breaker, + task_timeout_seconds=300.0, # 5 minute timeout for tests data_dir=TEST_DATA_DIR, # DEFAULT_DATA_DIR save_submission_dir="submissions_test_retries", overwrite_submission=True, # DEFAULT_OVERWRITE_SUBMISSION is False, but True for test clarity @@ -149,10 +153,14 @@ async def test_failure_after_all_retries(caplog): mock_arc_instance.generate_task_solution.side_effect = simulator.simulate_generate_task_solution limiter = AsyncRequestRateLimiter(rate=1000, capacity=1000) + circuit_breaker = CircuitBreaker("test_provider", failure_threshold=10) + result = await run_single_test_wrapper( config_name, task_id, limiter, + circuit_breaker=circuit_breaker, + task_timeout_seconds=300.0, data_dir=TEST_DATA_DIR, save_submission_dir="submissions_test_retries", overwrite_submission=True, @@ -200,10 +208,14 @@ async def test_non_retryable_exception(caplog): mock_arc_instance.generate_task_solution.side_effect = simulator.simulate_generate_task_solution limiter = AsyncRequestRateLimiter(rate=1000, capacity=1000) + circuit_breaker = CircuitBreaker("test_provider", failure_threshold=10) + result = await run_single_test_wrapper( config_name, task_id, limiter, + circuit_breaker=circuit_breaker, + task_timeout_seconds=300.0, data_dir=TEST_DATA_DIR, save_submission_dir="submissions_test_retries", overwrite_submission=True, diff --git a/src/arc_agi_benchmarking/tests/test_storage.py b/src/arc_agi_benchmarking/tests/test_storage.py new file mode 100644 index 00000000..895cfc2b --- /dev/null +++ b/src/arc_agi_benchmarking/tests/test_storage.py @@ -0,0 +1,296 @@ +"""Tests for storage backends.""" + +import os +import tempfile +from pathlib import Path + +import pytest + +from arc_agi_benchmarking.storage.base import ( + StorageBackend, + StorageReadError, + StorageWriteError, +) +from arc_agi_benchmarking.storage.filesystem import LocalStorageBackend + + +class TestLocalStorageBackend: + """Tests for LocalStorageBackend.""" + + @pytest.fixture + def storage(self, tmp_path: Path) -> LocalStorageBackend: + """Create a LocalStorageBackend with a temporary directory.""" + return LocalStorageBackend(tmp_path) + + def test_init_creates_directory(self, tmp_path: Path): + """Test that initialization creates the base directory.""" + new_dir = tmp_path / "new_storage_dir" + assert not new_dir.exists() + + storage = LocalStorageBackend(new_dir) + + assert new_dir.exists() + assert new_dir.is_dir() + + def test_write_and_read(self, storage: LocalStorageBackend): + """Test basic write and read operations.""" + data = b"hello world" + storage.write("test.txt", data) + + result = storage.read("test.txt") + assert result == data + + def test_write_and_read_text(self, storage: LocalStorageBackend): + """Test text convenience methods.""" + text = "hello world" + storage.write_text("test.txt", text) + + result = storage.read_text("test.txt") + assert result == text + + def test_read_nonexistent_returns_none(self, storage: LocalStorageBackend): + """Test that reading a nonexistent key returns None.""" + result = storage.read("nonexistent.txt") + assert result is None + + def test_write_creates_parent_directories(self, storage: LocalStorageBackend): + """Test that write creates necessary parent directories.""" + data = b"nested data" + storage.write("deeply/nested/path/file.txt", data) + + result = storage.read("deeply/nested/path/file.txt") + assert result == data + + def test_write_is_atomic(self, storage: LocalStorageBackend, tmp_path: Path): + """Test that writes are atomic (no partial writes visible).""" + key = "atomic_test.txt" + data1 = b"original data" + data2 = b"new data that is longer" + + # Write initial data + storage.write(key, data1) + + # Write new data - should be atomic + storage.write(key, data2) + + # Verify new data is complete + result = storage.read(key) + assert result == data2 + + # Verify no temp files left behind + files = list(tmp_path.rglob("*.tmp")) + assert len(files) == 0 + + def test_exists(self, storage: LocalStorageBackend): + """Test exists method.""" + assert not storage.exists("test.txt") + + storage.write("test.txt", b"data") + + assert storage.exists("test.txt") + + def test_delete(self, storage: LocalStorageBackend): + """Test delete method.""" + storage.write("test.txt", b"data") + assert storage.exists("test.txt") + + storage.delete("test.txt") + + assert not storage.exists("test.txt") + + def test_delete_nonexistent_does_not_raise(self, storage: LocalStorageBackend): + """Test that deleting a nonexistent key doesn't raise.""" + # Should not raise + storage.delete("nonexistent.txt") + + def test_list_keys_in_directory(self, storage: LocalStorageBackend): + """Test listing keys in a directory.""" + storage.write("dir/file1.txt", b"1") + storage.write("dir/file2.txt", b"2") + storage.write("dir/subdir/file3.txt", b"3") + storage.write("other/file4.txt", b"4") + + keys = storage.list_keys("dir") + + assert sorted(keys) == [ + "dir/file1.txt", + "dir/file2.txt", + "dir/subdir/file3.txt", + ] + + def test_list_keys_with_prefix(self, storage: LocalStorageBackend): + """Test listing keys with a filename prefix.""" + storage.write("checkpoint_001.json", b"1") + storage.write("checkpoint_002.json", b"2") + storage.write("submission_001.json", b"3") + + keys = storage.list_keys("checkpoint") + + assert sorted(keys) == ["checkpoint_001.json", "checkpoint_002.json"] + + def test_list_keys_empty(self, storage: LocalStorageBackend): + """Test listing keys when no matches.""" + keys = storage.list_keys("nonexistent") + assert keys == [] + + def test_path_traversal_blocked(self, storage: LocalStorageBackend): + """Test that path traversal attacks are blocked.""" + with pytest.raises(ValueError, match="escape base directory"): + storage.write("../../../etc/passwd", b"malicious") + + with pytest.raises(ValueError, match="escape base directory"): + storage.read("../../../etc/passwd") + + def test_overwrite_existing_file(self, storage: LocalStorageBackend): + """Test overwriting an existing file.""" + storage.write("test.txt", b"original") + storage.write("test.txt", b"updated") + + result = storage.read("test.txt") + assert result == b"updated" + + def test_empty_data(self, storage: LocalStorageBackend): + """Test writing and reading empty data.""" + storage.write("empty.txt", b"") + + result = storage.read("empty.txt") + assert result == b"" + + def test_binary_data(self, storage: LocalStorageBackend): + """Test writing and reading binary data.""" + data = bytes(range(256)) + storage.write("binary.dat", data) + + result = storage.read("binary.dat") + assert result == data + + def test_unicode_in_text(self, storage: LocalStorageBackend): + """Test unicode characters in text mode.""" + text = "Hello \u4e16\u754c \U0001f600" # "Hello 世界 😀" + storage.write_text("unicode.txt", text) + + result = storage.read_text("unicode.txt") + assert result == text + + def test_repr(self, storage: LocalStorageBackend, tmp_path: Path): + """Test string representation.""" + repr_str = repr(storage) + assert "LocalStorageBackend" in repr_str + assert str(tmp_path) in repr_str + + +class TestS3StorageBackend: + """Tests for S3StorageBackend using moto for mocking.""" + + @pytest.fixture + def s3_storage(self): + """Create an S3StorageBackend with mocked S3.""" + try: + import boto3 + from moto import mock_aws + except ImportError: + pytest.skip("boto3 and moto required for S3 tests") + + with mock_aws(): + # Create the bucket + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="test-bucket") + + from arc_agi_benchmarking.storage.s3 import S3StorageBackend + + storage = S3StorageBackend( + bucket="test-bucket", + prefix="test-prefix", + region_name="us-east-1", + ) + yield storage + + def test_write_and_read(self, s3_storage): + """Test basic write and read operations.""" + data = b"hello world" + s3_storage.write("test.txt", data) + + result = s3_storage.read("test.txt") + assert result == data + + def test_read_nonexistent_returns_none(self, s3_storage): + """Test that reading a nonexistent key returns None.""" + result = s3_storage.read("nonexistent.txt") + assert result is None + + def test_exists(self, s3_storage): + """Test exists method.""" + assert not s3_storage.exists("test.txt") + + s3_storage.write("test.txt", b"data") + + assert s3_storage.exists("test.txt") + + def test_delete(self, s3_storage): + """Test delete method.""" + s3_storage.write("test.txt", b"data") + assert s3_storage.exists("test.txt") + + s3_storage.delete("test.txt") + + assert not s3_storage.exists("test.txt") + + def test_delete_nonexistent_does_not_raise(self, s3_storage): + """Test that deleting a nonexistent key doesn't raise.""" + # Should not raise + s3_storage.delete("nonexistent.txt") + + def test_list_keys(self, s3_storage): + """Test listing keys with a prefix.""" + s3_storage.write("dir/file1.txt", b"1") + s3_storage.write("dir/file2.txt", b"2") + s3_storage.write("other/file3.txt", b"3") + + keys = s3_storage.list_keys("dir/") + + assert sorted(keys) == ["dir/file1.txt", "dir/file2.txt"] + + def test_prefix_applied_correctly(self, s3_storage): + """Test that the storage prefix is applied correctly.""" + # Write through our storage + s3_storage.write("myfile.txt", b"content") + + # Verify the full S3 key includes our prefix + import boto3 + + s3 = boto3.client("s3", region_name="us-east-1") + response = s3.get_object(Bucket="test-bucket", Key="test-prefix/myfile.txt") + assert response["Body"].read() == b"content" + + def test_repr(self, s3_storage): + """Test string representation.""" + repr_str = repr(s3_storage) + assert "S3StorageBackend" in repr_str + assert "test-bucket" in repr_str + assert "test-prefix" in repr_str + + +class TestStorageBackendInterface: + """Tests to verify the interface contract.""" + + def test_local_implements_interface(self, tmp_path: Path): + """Test that LocalStorageBackend implements StorageBackend.""" + storage = LocalStorageBackend(tmp_path) + assert isinstance(storage, StorageBackend) + + def test_s3_implements_interface(self): + """Test that S3StorageBackend implements StorageBackend.""" + try: + from moto import mock_aws + import boto3 + except ImportError: + pytest.skip("boto3 and moto required for S3 tests") + + with mock_aws(): + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="test-bucket") + + from arc_agi_benchmarking.storage.s3 import S3StorageBackend + + storage = S3StorageBackend(bucket="test-bucket") + assert isinstance(storage, StorageBackend) diff --git a/src/arc_agi_benchmarking/utils/task_utils.py b/src/arc_agi_benchmarking/utils/task_utils.py index 68ddab25..047a37f0 100644 --- a/src/arc_agi_benchmarking/utils/task_utils.py +++ b/src/arc_agi_benchmarking/utils/task_utils.py @@ -152,16 +152,45 @@ def read_provider_rate_limits() -> dict: rate_limits_data = yaml.safe_load(f) if not isinstance(rate_limits_data, dict): raise yaml.YAMLError("provider_config.yml root should be a dictionary of providers.") - # Basic validation for each provider's config + # Basic validation for each provider's config (skip 'defaults' key) for provider, limits in rate_limits_data.items(): + if provider == 'defaults': + continue # Skip defaults section if not isinstance(limits, dict) or 'rate' not in limits or 'period' not in limits: raise yaml.YAMLError( f"Provider '{provider}' in provider_config.yml must have 'rate' and 'period' keys." ) - if not isinstance(limits['rate'], int) or not isinstance(limits['period'], int): + if not isinstance(limits['rate'], (int, float)) or not isinstance(limits['period'], (int, float)): raise yaml.YAMLError( - f"'rate' and 'period' for provider '{provider}' must be integers." + f"'rate' and 'period' for provider '{provider}' must be numbers." ) return rate_limits_data except yaml.YAMLError as e: - raise yaml.YAMLError(f"Error parsing provider_config.yml: {e}") \ No newline at end of file + raise yaml.YAMLError(f"Error parsing provider_config.yml: {e}") + + +from arc_agi_benchmarking.resilience.timeout import ( + DEFAULT_REQUEST_TIMEOUT, + DEFAULT_REASONING_TIMEOUT, +) + +DEFAULT_CIRCUIT_BREAKER_THRESHOLD = 5 +DEFAULT_CIRCUIT_BREAKER_RECOVERY = 60 + + +def get_provider_timeout_config(provider_name: str, all_provider_limits: dict) -> dict: + """Get timeout and circuit breaker configuration for a provider.""" + defaults = all_provider_limits.get('defaults', {}) + default_request_timeout = defaults.get('request_timeout', DEFAULT_REQUEST_TIMEOUT) + default_reasoning_timeout = defaults.get('reasoning_timeout', DEFAULT_REASONING_TIMEOUT) + default_cb_threshold = defaults.get('circuit_breaker_threshold', DEFAULT_CIRCUIT_BREAKER_THRESHOLD) + default_cb_recovery = defaults.get('circuit_breaker_recovery', DEFAULT_CIRCUIT_BREAKER_RECOVERY) + + provider_config = all_provider_limits.get(provider_name, {}) + + return { + 'request_timeout': provider_config.get('request_timeout', default_request_timeout), + 'reasoning_timeout': provider_config.get('reasoning_timeout', default_reasoning_timeout), + 'circuit_breaker_threshold': provider_config.get('circuit_breaker_threshold', default_cb_threshold), + 'circuit_breaker_recovery': provider_config.get('circuit_breaker_recovery', default_cb_recovery), + } \ No newline at end of file