diff --git a/cli/run_all.py b/cli/run_all.py index 4a16e4d9..81406745 100644 --- a/cli/run_all.py +++ b/cli/run_all.py @@ -25,6 +25,7 @@ from arc_agi_benchmarking.utils.task_utils import read_models_config, read_provider_rate_limits from arc_agi_benchmarking.utils.rate_limiter import AsyncRequestRateLimiter from arc_agi_benchmarking.utils.metrics import set_metrics_enabled, set_metrics_filename_prefix +from arc_agi_benchmarking.utils.preflight import run_preflight from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type, before_sleep_log @@ -400,6 +401,17 @@ async def main(task_list_file: Optional[str], default="logs", help="Base directory for JSONL logs. Per-task logs go to ///openai.jsonl (default: logs)." ) + parser.add_argument( + "--skip-preflight", + action="store_true", + help="Skip preflight validation checks (not recommended for production runs)." + ) + parser.add_argument( + "--cost-limit", + type=float, + default=None, + help="Maximum estimated cost in USD. Abort if estimated cost exceeds this limit." + ) args = parser.parse_args() @@ -450,6 +462,37 @@ async def main(task_list_file: Optional[str], project_root = Path(__file__).resolve().parent.parent logs_base_dir = (project_root / logs_base_dir).resolve() + # --- Preflight validation --- + if not args.skip_preflight: + logger.info("Running preflight validation...") + preflight_report = run_preflight( + config_name=config_name, + data_dir=args.data_dir, + output_dir=args.save_submission_dir, + num_attempts=args.num_attempts, + ) + print(preflight_report) + + if not preflight_report.all_passed: + logger.error("Preflight validation failed. 
Use --skip-preflight to bypass (not recommended).") + sys.exit(1) + + # Check cost limit if specified + if args.cost_limit is not None and preflight_report.cost_estimate: + if preflight_report.cost_estimate.estimated_cost > args.cost_limit: + logger.error( + f"Estimated cost (${preflight_report.cost_estimate.estimated_cost:.2f}) " + f"exceeds limit (${args.cost_limit:.2f}). Aborting." + ) + sys.exit(1) + logger.info( + f"Cost check passed: ${preflight_report.cost_estimate.estimated_cost:.2f} " + f"<= ${args.cost_limit:.2f} limit" + ) + else: + logger.warning("Preflight validation skipped (--skip-preflight flag set)") + # --- End preflight validation --- + # Ensure `main` returns an exit code which is then used by sys.exit exit_code_from_main = asyncio.run(main( task_list_file=args.task_list_file, diff --git a/src/arc_agi_benchmarking/tests/test_preflight.py b/src/arc_agi_benchmarking/tests/test_preflight.py new file mode 100644 index 00000000..4cbbcca7 --- /dev/null +++ b/src/arc_agi_benchmarking/tests/test_preflight.py @@ -0,0 +1,303 @@ +"""Tests for preflight validation module.""" + +import os +import json +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +from arc_agi_benchmarking.utils.preflight import ( + validate_config_exists, + validate_api_key, + validate_data_dir, + validate_output_dir, + estimate_cost, + run_preflight, + ValidationResult, + CostEstimate, + PreflightReport, + PROVIDER_API_KEYS, +) +from arc_agi_benchmarking.schemas import ModelConfig, ModelPricing + + +class TestValidateConfigExists: + """Tests for validate_config_exists function.""" + + def test_valid_config(self): + """Test that a valid config is found.""" + result = validate_config_exists("gpt-4o-2024-11-20") + assert result.passed is True + assert "gpt-4o-2024-11-20" in result.message + + def test_invalid_config(self): + """Test that an invalid config returns failure.""" + result = validate_config_exists("nonexistent-model-xyz") 
class TestValidateApiKey:
    """Tests for validate_api_key function.

    Tests that depend on API keys patch ``os.environ`` so the outcome does not
    depend on whatever is exported in the developer's shell.
    """

    def test_known_provider_with_key(self):
        """Test that a known provider with an API key passes."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test12345678"}):
            result = validate_api_key("openai")
            assert result.passed is True
            assert "OPENAI_API_KEY" in result.message

    def test_known_provider_without_key(self):
        """Test that a known provider without an API key fails."""
        # clear=True empties os.environ for the duration of the block, so
        # DEEPSEEK_API_KEY is guaranteed to be absent; no manual delete needed.
        with patch.dict(os.environ, {}, clear=True):
            result = validate_api_key("deepseek")
            assert result.passed is False
            assert "not found" in result.message.lower()

    def test_random_provider_no_key_needed(self):
        """Test that the random provider doesn't need an API key."""
        result = validate_api_key("random")
        assert result.passed is True
        assert "no api key required" in result.message.lower()

    def test_unknown_provider(self):
        """Test that an unknown provider returns failure."""
        result = validate_api_key("unknown_provider_xyz")
        assert result.passed is False
        assert "unknown provider" in result.message.lower()

    def test_codex_with_either_key(self):
        """Test that codex accepts either OPENAI_API_KEY or CODEX_API_KEY."""
        # Each environment contains exactly one of the two accepted variables.
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test12345678"}, clear=True):
            result = validate_api_key("codex")
            assert result.passed is True

        with patch.dict(os.environ, {"CODEX_API_KEY": "codex-test12345678"}, clear=True):
            result = validate_api_key("codex")
            assert result.passed is True
# Create a valid task file + task_file = Path(tmpdir) / "task1.json" + task_file.write_text(json.dumps({ + "train": [{"input": [[1]], "output": [[2]]}], + "test": [{"input": [[3]], "output": [[4]]}] + })) + + result, task_ids = validate_data_dir(tmpdir) + assert result.passed is True + assert len(task_ids) == 1 + assert "task1" in task_ids + + def test_nonexistent_dir(self): + """Test validation of a nonexistent directory.""" + result, task_ids = validate_data_dir("/nonexistent/path/xyz") + assert result.passed is False + assert len(task_ids) == 0 + assert "not found" in result.message.lower() + + def test_empty_dir(self): + """Test validation of an empty directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + result, task_ids = validate_data_dir(tmpdir) + assert result.passed is False + assert len(task_ids) == 0 + assert "no task files" in result.message.lower() + + def test_invalid_json_file(self): + """Test validation with an invalid JSON file.""" + with tempfile.TemporaryDirectory() as tmpdir: + invalid_file = Path(tmpdir) / "invalid.json" + invalid_file.write_text("not valid json") + + result, task_ids = validate_data_dir(tmpdir) + assert result.passed is False + assert "invalid" in result.message.lower() + + def test_missing_required_keys(self): + """Test validation with a JSON file missing required keys.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a JSON file without 'train' and 'test' keys + task_file = Path(tmpdir) / "bad_task.json" + task_file.write_text(json.dumps({"data": [1, 2, 3]})) + + result, task_ids = validate_data_dir(tmpdir) + assert result.passed is False + assert len(task_ids) == 0 + + +class TestValidateOutputDir: + """Tests for validate_output_dir function.""" + + def test_existing_writable_dir(self): + """Test validation of an existing writable directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + result = validate_output_dir(tmpdir) + assert result.passed is True + assert "writable" in 
class TestEstimateCost:
    """Tests for estimate_cost function."""

    @staticmethod
    def _config_with_pricing(input_price, output_price):
        """Build a throwaway ModelConfig; only the pricing matters to these tests."""
        pricing = ModelPricing(date="2024-01-01", input=input_price, output=output_price)
        return ModelConfig(
            name="test-model",
            model_name="test-model",
            provider="test",
            pricing=pricing,
        )

    def test_basic_cost_estimate(self):
        """Token totals and cost scale with tasks * attempts."""
        estimate = estimate_cost(
            model_config=self._config_with_pricing(1.0, 2.0),
            num_tasks=10,
            num_attempts=2,
            avg_input_tokens=1000,
            avg_output_tokens=500,
        )

        assert estimate.num_tasks == 10
        assert estimate.num_attempts_per_task == 2
        assert estimate.total_attempts == 20
        assert estimate.estimated_input_tokens == 20000  # 10 * 2 * 1000
        assert estimate.estimated_output_tokens == 10000  # 10 * 2 * 500

        # Cost: (20000/1M) * $1 + (10000/1M) * $2 = $0.02 + $0.02 = $0.04
        assert estimate.estimated_cost == pytest.approx(0.04, rel=0.01)

    def test_zero_tasks(self):
        """With no tasks there are no attempts, so the estimate is exactly zero."""
        estimate = estimate_cost(
            model_config=self._config_with_pricing(10.0, 30.0),
            num_tasks=0,
            num_attempts=2,
        )

        assert estimate.estimated_cost == 0.0
test_full_preflight_with_valid_inputs(self): + """Test full preflight with all valid inputs.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a valid task file + task_file = Path(tmpdir) / "task1.json" + task_file.write_text(json.dumps({ + "train": [{"input": [[1]], "output": [[2]]}], + "test": [{"input": [[3]], "output": [[4]]}] + })) + + output_dir = os.path.join(tmpdir, "output") + + with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test12345678"}): + report = run_preflight( + config_name="gpt-4o-2024-11-20", + data_dir=tmpdir, + output_dir=output_dir, + num_attempts=2 + ) + + assert report.all_passed is True + assert report.cost_estimate is not None + assert len(report.validations) == 4 # config, api key, data dir, output dir + + def test_preflight_with_invalid_config(self): + """Test preflight with an invalid config name.""" + with tempfile.TemporaryDirectory() as tmpdir: + report = run_preflight( + config_name="nonexistent-model", + data_dir=tmpdir, + output_dir=tmpdir, + num_attempts=2 + ) + + assert report.all_passed is False + # Should have failed on config validation + config_results = [v for v in report.validations if "config" in v.message.lower()] + assert any(not v.passed for v in config_results) + + +class TestPreflightReport: + """Tests for PreflightReport string formatting.""" + + def test_report_string_format(self): + """Test that the report string is properly formatted.""" + report = PreflightReport( + config_name="test-config", + validations=[ + ValidationResult(passed=True, message="Check 1 passed"), + ValidationResult(passed=False, message="Check 2 failed", details="Some error"), + ], + cost_estimate=CostEstimate( + num_tasks=10, + num_attempts_per_task=2, + total_attempts=20, + input_price_per_1m=1.0, + output_price_per_1m=2.0, + estimated_input_tokens=10000, + estimated_output_tokens=5000, + estimated_cost=0.02 + ), + all_passed=False + ) + + report_str = str(report) + assert "test-config" in report_str + assert "✓" in 
class TestProviderApiKeyMapping:
    """Tests for PROVIDER_API_KEYS mapping."""

    def test_all_major_providers_covered(self):
        """Every provider the benchmark is expected to support has an entry."""
        expected_providers = (
            "openai",
            "anthropic",
            "gemini",
            "deepseek",
            "fireworks",
            "xai",
            "groq",
            "openrouter",
            "random",
        )
        # Assert one provider at a time so a failure names the missing entry.
        for name in expected_providers:
            assert name in PROVIDER_API_KEYS, f"Provider {name} not in PROVIDER_API_KEYS"

    def test_random_provider_has_no_keys(self):
        """The random provider maps to an empty list of key names."""
        random_keys = PROVIDER_API_KEYS["random"]
        assert random_keys == []
@dataclass
class ValidationResult:
    """Result of a single validation check.

    Produced by the ``validate_*`` helpers in this module and collected into
    ``PreflightReport.validations``.
    """
    # True when the check succeeded; run_preflight marks the whole report as
    # passed only if every collected result has passed=True.
    passed: bool
    # One-line summary of the check; always shown in the rendered report.
    message: str
    # Optional extra context (a path, error text, or hint). The rendered
    # report prints it only for checks that did not pass.
    details: Optional[str] = None
def validate_config_exists(config_name: str) -> ValidationResult:
    """Check if the model configuration exists in models.yml.

    Args:
        config_name: Name of the configuration to look up.

    Returns:
        A passing ValidationResult naming the model and provider when
        ``read_models_config`` succeeds. A ``ValueError`` from the lookup is
        reported as "not found"; any other exception is reported as a read
        error. In both failure cases ``details`` carries the exception text.
    """
    try:
        loaded = read_models_config(config_name)
        # Attribute access stays inside the try so an unexpected config
        # object is reported as a read error instead of escaping.
        passed = True
        message = f"Config '{config_name}' found"
        details = f"Model: {loaded.model_name}, Provider: {loaded.provider}"
    except ValueError as lookup_error:
        passed = False
        message = f"Config '{config_name}' not found"
        details = str(lookup_error)
    except Exception as read_error:
        passed = False
        message = f"Error reading config '{config_name}'"
        details = str(read_error)

    return ValidationResult(passed=passed, message=message, details=details)
def _task_file_problem(json_file: Path) -> Optional[str]:
    """Return a description of why *json_file* is not a usable task, or None if it is."""
    try:
        # JSON is defined as UTF-8; do not depend on the platform default encoding.
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError:
        return f"{json_file.name} (invalid JSON)"
    except Exception as e:
        # Best-effort per-file boundary: an unreadable or undecodable file is
        # reported as invalid rather than aborting the whole preflight run.
        return f"{json_file.name} ({str(e)})"

    # A top-level string or list would satisfy a bare `'train' in data` check
    # (substring / element membership), so require a JSON object first.
    if not isinstance(data, dict):
        return f"{json_file.name} (not a JSON object)"

    if "train" not in data or "test" not in data:
        return json_file.name

    return None


def validate_data_dir(data_dir: str) -> Tuple[ValidationResult, List[str]]:
    """Check if the data directory exists and contains valid task files.

    A task file is a ``*.json`` file directly inside ``data_dir`` whose
    top-level value is a JSON object containing both ``train`` and ``test``
    keys.

    Args:
        data_dir: Directory expected to contain the task JSON files.

    Returns:
        A ``(result, task_ids)`` tuple. ``task_ids`` holds the file stems of
        the valid task files in sorted filename order. ``result.passed`` is
        False when the directory is missing, is not a directory, contains no
        JSON files, or contains no valid task file. When some files are valid
        and some are not, the result still passes and ``details`` lists up to
        five of the invalid files.
    """
    path = Path(data_dir)
    task_ids: List[str] = []

    if not path.exists():
        return ValidationResult(
            passed=False,
            message="Data directory not found",
            details=str(path.absolute())
        ), task_ids

    if not path.is_dir():
        return ValidationResult(
            passed=False,
            message="Data path is not a directory",
            details=str(path.absolute())
        ), task_ids

    # glob() yields entries in arbitrary order; sort so task_ids and the
    # reported invalid files are the same on every run.
    json_files = sorted(path.glob("*.json"))

    if not json_files:
        return ValidationResult(
            passed=False,
            message="No task files found in data directory",
            details=str(path.absolute())
        ), task_ids

    # Validate each file
    invalid_files: List[str] = []
    for json_file in json_files:
        problem = _task_file_problem(json_file)
        if problem is None:
            task_ids.append(json_file.stem)
        else:
            invalid_files.append(problem)

    valid_count = len(task_ids)

    if invalid_files:
        shown = ", ".join(invalid_files[:5])
        hidden = len(invalid_files) - 5
        overflow = f" (+{hidden} more)" if hidden > 0 else ""
        return ValidationResult(
            passed=valid_count > 0,  # Partial pass if some files are valid
            message=f"Found {valid_count} valid tasks, {len(invalid_files)} invalid",
            details=f"Invalid: {shown}" + overflow
        ), task_ids

    return ValidationResult(
        passed=True,
        message=f"Found {valid_count} valid task files",
        details=str(path.absolute())
    ), task_ids
def run_preflight(
    config_name: str,
    data_dir: str,
    output_dir: str,
    num_attempts: int = 2,
) -> PreflightReport:
    """
    Run all preflight validations and return a comprehensive report.

    Four validations are always recorded, in this order: config lookup,
    API key, data directory, output directory. A cost estimate is attached
    only when the config could be loaded and at least one valid task file
    was found.

    Args:
        config_name: Name of the model configuration from models.yml
        data_dir: Directory containing task JSON files
        output_dir: Directory where submissions will be saved
        num_attempts: Number of attempts per task

    Returns:
        PreflightReport with all validation results and cost estimate
    """
    validations: List[ValidationResult] = []
    cost_estimate = None
    model_config = None
    config_load_error: Optional[str] = None

    # 1. Validate config exists
    config_result = validate_config_exists(config_name)
    validations.append(config_result)

    if config_result.passed:
        # validate_config_exists only reports pass/fail, so the config is read
        # again here to obtain the provider and pricing.
        try:
            model_config = read_models_config(config_name)
        except Exception as e:
            # Do not swallow silently: keep the cause so the report and the
            # log explain why the API-key check and cost estimate are missing.
            config_load_error = str(e)
            logger.warning(
                "Config '%s' passed validation but could not be loaded: %s",
                config_name,
                e,
            )

    # 2. Validate API key (only if config was found)
    if model_config is not None:
        validations.append(validate_api_key(model_config.provider))
    else:
        validations.append(ValidationResult(
            passed=False,
            message="Skipping API key validation (config not found)",
            details=config_load_error,
        ))

    # 3. Validate data directory
    data_result, task_ids = validate_data_dir(data_dir)
    validations.append(data_result)
    num_tasks = len(task_ids)

    # 4. Validate output directory
    validations.append(validate_output_dir(output_dir))

    # 5. Calculate cost estimate (only if we have valid config and tasks)
    if model_config is not None and num_tasks > 0:
        cost_estimate = estimate_cost(
            model_config=model_config,
            num_tasks=num_tasks,
            num_attempts=num_attempts,
        )

    all_passed = all(v.passed for v in validations)

    return PreflightReport(
        config_name=config_name,
        validations=validations,
        cost_estimate=cost_estimate,
        all_passed=all_passed,
    )