diff --git a/cli/run_all.py b/cli/run_all.py
index 4a16e4d9..81406745 100644
--- a/cli/run_all.py
+++ b/cli/run_all.py
@@ -25,6 +25,7 @@
from arc_agi_benchmarking.utils.task_utils import read_models_config, read_provider_rate_limits
from arc_agi_benchmarking.utils.rate_limiter import AsyncRequestRateLimiter
from arc_agi_benchmarking.utils.metrics import set_metrics_enabled, set_metrics_filename_prefix
+from arc_agi_benchmarking.utils.preflight import run_preflight
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type, before_sleep_log
@@ -400,6 +401,17 @@ async def main(task_list_file: Optional[str],
default="logs",
help="Base directory for JSONL logs. Per-task logs go to ///openai.jsonl (default: logs)."
)
+ parser.add_argument(
+ "--skip-preflight",
+ action="store_true",
+ help="Skip preflight validation checks (not recommended for production runs)."
+ )
+ parser.add_argument(
+ "--cost-limit",
+ type=float,
+ default=None,
+ help="Maximum estimated cost in USD. Abort if estimated cost exceeds this limit."
+ )
args = parser.parse_args()
@@ -450,6 +462,37 @@ async def main(task_list_file: Optional[str],
project_root = Path(__file__).resolve().parent.parent
logs_base_dir = (project_root / logs_base_dir).resolve()
+ # --- Preflight validation ---
+ if not args.skip_preflight:
+ logger.info("Running preflight validation...")
+ preflight_report = run_preflight(
+ config_name=config_name,
+ data_dir=args.data_dir,
+ output_dir=args.save_submission_dir,
+ num_attempts=args.num_attempts,
+ )
+ print(preflight_report)
+
+ if not preflight_report.all_passed:
+ logger.error("Preflight validation failed. Use --skip-preflight to bypass (not recommended).")
+ sys.exit(1)
+
+ # Check cost limit if specified
+ if args.cost_limit is not None and preflight_report.cost_estimate:
+ if preflight_report.cost_estimate.estimated_cost > args.cost_limit:
+ logger.error(
+ f"Estimated cost (${preflight_report.cost_estimate.estimated_cost:.2f}) "
+ f"exceeds limit (${args.cost_limit:.2f}). Aborting."
+ )
+ sys.exit(1)
+ logger.info(
+ f"Cost check passed: ${preflight_report.cost_estimate.estimated_cost:.2f} "
+ f"<= ${args.cost_limit:.2f} limit"
+ )
+ else:
+ logger.warning("Preflight validation skipped (--skip-preflight flag set)")
+ # --- End preflight validation ---
+
# Ensure `main` returns an exit code which is then used by sys.exit
exit_code_from_main = asyncio.run(main(
task_list_file=args.task_list_file,
diff --git a/src/arc_agi_benchmarking/tests/test_preflight.py b/src/arc_agi_benchmarking/tests/test_preflight.py
new file mode 100644
index 00000000..4cbbcca7
--- /dev/null
+++ b/src/arc_agi_benchmarking/tests/test_preflight.py
@@ -0,0 +1,303 @@
+"""Tests for preflight validation module."""
+
+import os
+import json
+import tempfile
+from pathlib import Path
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+from arc_agi_benchmarking.utils.preflight import (
+ validate_config_exists,
+ validate_api_key,
+ validate_data_dir,
+ validate_output_dir,
+ estimate_cost,
+ run_preflight,
+ ValidationResult,
+ CostEstimate,
+ PreflightReport,
+ PROVIDER_API_KEYS,
+)
+from arc_agi_benchmarking.schemas import ModelConfig, ModelPricing
+
+
class TestValidateConfigExists:
    """Exercise ``validate_config_exists`` for a known and an unknown config name."""

    def test_valid_config(self):
        """A config name expected to be defined is reported as found."""
        known_name = "gpt-4o-2024-11-20"
        outcome = validate_config_exists(known_name)

        assert outcome.passed is True
        # The success message echoes the requested config name.
        assert known_name in outcome.message

    def test_invalid_config(self):
        """A config name that is not defined yields a failing result."""
        outcome = validate_config_exists("nonexistent-model-xyz")

        assert outcome.passed is False
        assert "not found" in outcome.message.lower()
+
+
class TestValidateApiKey:
    """Tests for validate_api_key function."""

    def test_known_provider_with_key(self):
        """A known provider whose environment variable is set passes."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test12345678"}):
            result = validate_api_key("openai")

        assert result.passed is True
        assert "OPENAI_API_KEY" in result.message

    def test_known_provider_without_key(self):
        """A known provider with no key in the environment fails."""
        # clear=True empties os.environ for the duration of the block, so no
        # additional manual deletion of DEEPSEEK_API_KEY is needed.
        with patch.dict(os.environ, {}, clear=True):
            result = validate_api_key("deepseek")

        assert result.passed is False
        assert "not found" in result.message.lower()

    def test_random_provider_no_key_needed(self):
        """The random provider does not need an API key."""
        result = validate_api_key("random")

        assert result.passed is True
        assert "no api key required" in result.message.lower()

    def test_unknown_provider(self):
        """A provider missing from the mapping is reported as unknown."""
        result = validate_api_key("unknown_provider_xyz")

        assert result.passed is False
        assert "unknown provider" in result.message.lower()

    def test_codex_with_either_key(self):
        """Codex accepts either OPENAI_API_KEY or CODEX_API_KEY."""
        # Each environment contains exactly one of the two accepted variables.
        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test12345678"}, clear=True):
            result = validate_api_key("codex")
        assert result.passed is True

        with patch.dict(os.environ, {"CODEX_API_KEY": "codex-test12345678"}, clear=True):
            result = validate_api_key("codex")
        assert result.passed is True
+
+
class TestValidateDataDir:
    """Exercise ``validate_data_dir`` with valid, missing, empty and malformed inputs."""

    def test_valid_data_dir(self, tmp_path):
        """A directory holding one well-formed task file passes and yields its id."""
        payload = {
            "train": [{"input": [[1]], "output": [[2]]}],
            "test": [{"input": [[3]], "output": [[4]]}]
        }
        (tmp_path / "task1.json").write_text(json.dumps(payload))

        outcome, found_ids = validate_data_dir(str(tmp_path))

        assert outcome.passed is True
        assert len(found_ids) == 1
        assert "task1" in found_ids

    def test_nonexistent_dir(self):
        """A path that does not exist fails and returns no task ids."""
        outcome, found_ids = validate_data_dir("/nonexistent/path/xyz")

        assert outcome.passed is False
        assert len(found_ids) == 0
        assert "not found" in outcome.message.lower()

    def test_empty_dir(self, tmp_path):
        """A directory without any JSON files fails."""
        outcome, found_ids = validate_data_dir(str(tmp_path))

        assert outcome.passed is False
        assert len(found_ids) == 0
        assert "no task files" in outcome.message.lower()

    def test_invalid_json_file(self, tmp_path):
        """A directory whose only JSON file is unparsable fails."""
        (tmp_path / "invalid.json").write_text("not valid json")

        outcome, _ = validate_data_dir(str(tmp_path))

        assert outcome.passed is False
        assert "invalid" in outcome.message.lower()

    def test_missing_required_keys(self, tmp_path):
        """A JSON file lacking the 'train' and 'test' keys is not counted as a task."""
        (tmp_path / "bad_task.json").write_text(json.dumps({"data": [1, 2, 3]}))

        outcome, found_ids = validate_data_dir(str(tmp_path))

        assert outcome.passed is False
        assert len(found_ids) == 0
+
+
class TestValidateOutputDir:
    """Exercise ``validate_output_dir`` for existing, creatable and non-directory paths."""

    def test_existing_writable_dir(self, tmp_path):
        """An existing directory we can write to passes."""
        outcome = validate_output_dir(str(tmp_path))

        assert outcome.passed is True
        assert "writable" in outcome.message.lower()

    def test_nonexistent_dir_with_writable_parent(self, tmp_path):
        """A missing directory under a writable parent passes as creatable."""
        target = tmp_path / "new_subdir"

        outcome = validate_output_dir(str(target))

        assert outcome.passed is True
        assert "will be created" in outcome.message.lower()

    def test_file_instead_of_dir(self, tmp_path):
        """A path that points at a regular file fails."""
        regular_file = tmp_path / "not_a_dir"
        regular_file.write_text("")

        outcome = validate_output_dir(str(regular_file))

        assert outcome.passed is False
        assert "not a directory" in outcome.message.lower()
+
+
class TestEstimateCost:
    """Exercise ``estimate_cost`` arithmetic."""

    @staticmethod
    def _config(input_price, output_price):
        """Build a minimal model config carrying only the given per-1M-token prices."""
        return ModelConfig(
            name="test-model",
            model_name="test-model",
            provider="test",
            pricing=ModelPricing(date="2024-01-01", input=input_price, output=output_price)
        )

    def test_basic_cost_estimate(self):
        """Token totals and cost follow from tasks x attempts x average tokens."""
        estimate = estimate_cost(
            model_config=self._config(1.0, 2.0),
            num_tasks=10,
            num_attempts=2,
            avg_input_tokens=1000,
            avg_output_tokens=500
        )

        assert estimate.num_tasks == 10
        assert estimate.num_attempts_per_task == 2
        assert estimate.total_attempts == 20
        # 10 tasks * 2 attempts * 1000 input tokens
        assert estimate.estimated_input_tokens == 20000
        # 10 tasks * 2 attempts * 500 output tokens
        assert estimate.estimated_output_tokens == 10000
        # (20000 / 1M) * $1 + (10000 / 1M) * $2 = $0.02 + $0.02 = $0.04
        assert estimate.estimated_cost == pytest.approx(0.04, rel=0.01)

    def test_zero_tasks(self):
        """With no tasks the estimated cost is exactly zero."""
        estimate = estimate_cost(
            model_config=self._config(10.0, 30.0),
            num_tasks=0,
            num_attempts=2
        )

        assert estimate.estimated_cost == 0.0
+
+
class TestRunPreflight:
    """Exercise the ``run_preflight`` orchestration end to end."""

    def test_full_preflight_with_valid_inputs(self, tmp_path):
        """Valid config, key, data and output locations produce a passing report."""
        payload = {
            "train": [{"input": [[1]], "output": [[2]]}],
            "test": [{"input": [[3]], "output": [[4]]}]
        }
        (tmp_path / "task1.json").write_text(json.dumps(payload))
        # The output directory does not exist yet; its parent (tmp_path) does.
        output_dir = os.path.join(str(tmp_path), "output")

        with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-test12345678"}):
            report = run_preflight(
                config_name="gpt-4o-2024-11-20",
                data_dir=str(tmp_path),
                output_dir=output_dir,
                num_attempts=2
            )

        assert report.all_passed is True
        assert report.cost_estimate is not None
        # One result each for: config, api key, data dir, output dir.
        assert len(report.validations) == 4

    def test_preflight_with_invalid_config(self, tmp_path):
        """An unknown config name makes the overall report fail."""
        report = run_preflight(
            config_name="nonexistent-model",
            data_dir=str(tmp_path),
            output_dir=str(tmp_path),
            num_attempts=2
        )

        assert report.all_passed is False
        # At least one config-related result must be marked as failed.
        config_related = [v for v in report.validations if "config" in v.message.lower()]
        assert any(not v.passed for v in config_related)
+
+
class TestPreflightReport:
    """Exercise the human-readable rendering of ``PreflightReport``."""

    def test_report_string_format(self):
        """The rendered report names the config and marks passed and failed checks."""
        checks = [
            ValidationResult(passed=True, message="Check 1 passed"),
            ValidationResult(passed=False, message="Check 2 failed", details="Some error"),
        ]
        cost = CostEstimate(
            num_tasks=10,
            num_attempts_per_task=2,
            total_attempts=20,
            input_price_per_1m=1.0,
            output_price_per_1m=2.0,
            estimated_input_tokens=10000,
            estimated_output_tokens=5000,
            estimated_cost=0.02
        )
        report = PreflightReport(
            config_name="test-config",
            validations=checks,
            cost_estimate=cost,
            all_passed=False
        )

        rendered = str(report)

        assert "test-config" in rendered
        assert "✓" in rendered  # marker for the passed check
        assert "✗" in rendered  # marker for the failed check
        assert "FAILED" in rendered  # overall status line
+
+
class TestProviderApiKeyMapping:
    """Sanity checks on the ``PROVIDER_API_KEYS`` table."""

    def test_all_major_providers_covered(self):
        """Every provider listed here has an entry in the mapping."""
        expected_providers = (
            "openai", "anthropic", "gemini", "deepseek",
            "fireworks", "xai", "groq", "openrouter", "random",
        )
        for provider in expected_providers:
            assert provider in PROVIDER_API_KEYS, f"Provider {provider} not in PROVIDER_API_KEYS"

    def test_random_provider_has_no_keys(self):
        """The random provider maps to an empty list of environment variables."""
        assert PROVIDER_API_KEYS["random"] == []
diff --git a/src/arc_agi_benchmarking/utils/__main__.py b/src/arc_agi_benchmarking/utils/__main__.py
new file mode 100644
index 00000000..b3ceefb1
--- /dev/null
+++ b/src/arc_agi_benchmarking/utils/__main__.py
@@ -0,0 +1,5 @@
+"""Run preflight validation as a module: python -m arc_agi_benchmarking.utils"""
+from arc_agi_benchmarking.utils.preflight import main
+
if __name__ == "__main__":
    # Dispatch to the preflight CLI only when executed (python -m ...), never on import.
    main()
diff --git a/src/arc_agi_benchmarking/utils/preflight.py b/src/arc_agi_benchmarking/utils/preflight.py
new file mode 100644
index 00000000..fd961323
--- /dev/null
+++ b/src/arc_agi_benchmarking/utils/preflight.py
@@ -0,0 +1,433 @@
+"""
+Pre-flight validation and cost estimation for ARC-AGI benchmarking runs.
+
+Run this before expensive batch operations to catch configuration errors early
+and estimate costs before spending money.
+"""
+
+import os
+import json
+import logging
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+from dataclasses import dataclass
+
+from arc_agi_benchmarking.utils.task_utils import read_models_config
+from arc_agi_benchmarking.schemas import ModelConfig
+
+logger = logging.getLogger(__name__)
+
# Maps each provider name to the environment variables that may hold its API
# key. When a provider lists several variables, any one of them being set is
# enough; an empty list means the provider needs no key at all.
PROVIDER_API_KEYS: Dict[str, List[str]] = {
    "openai": ["OPENAI_API_KEY"],
    "anthropic": ["ANTHROPIC_API_KEY"],
    "claude_agent_sdk": ["ANTHROPIC_API_KEY"],
    "gemini": ["GOOGLE_API_KEY"],
    "google": ["GOOGLE_API_KEY"],
    "deepseek": ["DEEPSEEK_API_KEY"],
    "fireworks": ["FIREWORKS_API_KEY"],
    "xai": ["XAI_API_KEY"],
    "grok": ["XAI_API_KEY"],
    "groq": ["GROQ_API_KEY"],
    "openrouter": ["OPENROUTER_API_KEY"],
    # Either variable is accepted for codex.
    "codex": ["OPENAI_API_KEY", "CODEX_API_KEY"],
    # No credentials involved.
    "random": [],
}

# Fallback per-attempt token counts used for cost estimation when the caller
# does not supply its own averages. These are rough figures, not measurements
# of any particular dataset.
DEFAULT_AVG_INPUT_TOKENS_PER_TASK = 2500
DEFAULT_AVG_OUTPUT_TOKENS_PER_TASK = 500
+
+
@dataclass
class ValidationResult:
    """Outcome of one preflight check."""
    # True when the check succeeded.
    passed: bool
    # Short one-line summary of what was checked and how it went.
    message: str
    # Optional extra context (e.g. a path or an error string).
    details: Optional[str] = None
+
+
@dataclass
class CostEstimate:
    """Projected token usage and dollar cost for a benchmark run."""
    num_tasks: int
    num_attempts_per_task: int
    total_attempts: int
    # Prices are expressed in USD per one million tokens.
    input_price_per_1m: float
    output_price_per_1m: float
    estimated_input_tokens: int
    estimated_output_tokens: int
    estimated_cost: float

    def __str__(self) -> str:
        """Render the estimate as a multi-line, human-readable summary."""
        rows = [
            "Cost Estimate:",
            f" Tasks: {self.num_tasks}",
            f" Attempts per task: {self.num_attempts_per_task}",
            f" Total attempts: {self.total_attempts}",
            f" Estimated input tokens: {self.estimated_input_tokens:,}",
            f" Estimated output tokens: {self.estimated_output_tokens:,}",
            f" Input price: ${self.input_price_per_1m:.2f}/1M tokens",
            f" Output price: ${self.output_price_per_1m:.2f}/1M tokens",
            f" Estimated cost: ${self.estimated_cost:.2f}",
        ]
        # No trailing newline after the last row.
        return "\n".join(rows)
+
+
@dataclass
class PreflightReport:
    """Aggregated result of a preflight run: every check plus an optional cost estimate."""
    config_name: str
    validations: List[ValidationResult]
    # None when no estimate could be computed.
    cost_estimate: Optional[CostEstimate]
    all_passed: bool

    def __str__(self) -> str:
        """Render the report as a banner-delimited text block."""
        rule = "=" * 60
        out = [
            rule,
            "PREFLIGHT VALIDATION REPORT",
            rule,
            f"Config: {self.config_name}",
            "",
            "Validations:",
        ]

        for check in self.validations:
            if check.passed:
                out.append(f" ✓ {check.message}")
                continue
            out.append(f" ✗ {check.message}")
            # Details are only surfaced for failed checks.
            if check.details:
                out.append(f" └─ {check.details}")

        out.append("")

        if self.cost_estimate:
            out.extend([str(self.cost_estimate), ""])

        out.append(rule)
        if self.all_passed:
            verdict = "✓ All preflight checks PASSED"
        else:
            verdict = "✗ Preflight checks FAILED - fix issues before running"
        out.append(verdict)
        out.append(rule)

        return "\n".join(out)
+
+
def validate_config_exists(config_name: str) -> ValidationResult:
    """Report whether ``config_name`` can be loaded through ``read_models_config``.

    Args:
        config_name: Name of the model configuration to look up.

    Returns:
        A passing result describing the model and provider when the lookup
        succeeds; a failing result carrying the error text otherwise. A
        ``ValueError`` is reported as "not found", any other exception as a
        read error.
    """
    try:
        loaded = read_models_config(config_name)
        # Attribute access stays inside the try so a malformed config object
        # is reported as a read error instead of propagating.
        summary = f"Model: {loaded.model_name}, Provider: {loaded.provider}"
        return ValidationResult(
            passed=True,
            message=f"Config '{config_name}' found",
            details=summary,
        )
    except ValueError as err:
        failure_message = f"Config '{config_name}' not found"
        failure_details = str(err)
    except Exception as err:
        failure_message = f"Error reading config '{config_name}'"
        failure_details = str(err)

    return ValidationResult(
        passed=False,
        message=failure_message,
        details=failure_details,
    )
+
+
def validate_api_key(provider: str) -> ValidationResult:
    """Check that an API key for ``provider`` is available in the environment.

    Args:
        provider: Provider name; must be a key of ``PROVIDER_API_KEYS``.

    Returns:
        A failing result for an unknown provider or when none of the
        provider's environment variables holds a non-blank value; a passing
        result when the provider needs no key or when one of its variables
        is set. On success ``details`` contains a masked form of the key.
    """
    candidate_vars = PROVIDER_API_KEYS.get(provider)

    if candidate_vars is None:
        return ValidationResult(
            passed=False,
            message=f"Unknown provider '{provider}'",
            details=f"Known providers: {', '.join(PROVIDER_API_KEYS)}"
        )

    if not candidate_vars:
        return ValidationResult(
            passed=True,
            message=f"No API key required for '{provider}'"
        )

    # The first variable holding a usable value wins.
    for key_name in candidate_vars:
        # A value made only of whitespace (e.g. `KEY= ` in a .env file) is
        # not a usable credential, so treat it the same as unset.
        key_value = os.environ.get(key_name, "").strip()
        if not key_value:
            continue

        # Only reveal a prefix/suffix when the key is long enough that the
        # eight shown characters are a small part of it; otherwise hide it
        # entirely.
        if len(key_value) >= 16:
            masked = f"{key_value[:4]}...{key_value[-4:]}"
        else:
            masked = "***"
        return ValidationResult(
            passed=True,
            message=f"API key '{key_name}' found",
            details=f"Value: {masked}"
        )

    return ValidationResult(
        passed=False,
        message=f"API key not found for '{provider}'",
        details=f"Set one of: {', '.join(candidate_vars)}"
    )
+
+
def validate_data_dir(data_dir: str) -> Tuple[ValidationResult, List[str]]:
    """Check that ``data_dir`` exists and contains usable task files.

    A file counts as a valid task when it is named ``*.json``, parses as a
    JSON object, and that object has both a ``train`` and a ``test`` key.

    Args:
        data_dir: Path of the directory to scan (non-recursively).

    Returns:
        A ``(result, task_ids)`` tuple. ``task_ids`` holds the file stems of
        the valid task files in sorted filename order. ``result`` fails when
        the path is missing, is not a directory, holds no ``*.json`` files,
        or holds no valid task; it passes (with the invalid files listed in
        ``details``) when at least one file is valid.
    """
    path = Path(data_dir)
    location = str(path.absolute())
    task_ids: List[str] = []

    if not path.exists():
        return ValidationResult(
            passed=False,
            message="Data directory not found",
            details=location
        ), task_ids

    if not path.is_dir():
        return ValidationResult(
            passed=False,
            message="Data path is not a directory",
            details=location
        ), task_ids

    # glob() order is filesystem-dependent; sort so the returned ids and the
    # "Invalid: ..." listing are deterministic.
    json_files = sorted(path.glob("*.json"))

    if not json_files:
        return ValidationResult(
            passed=False,
            message="No task files found in data directory",
            details=location
        ), task_ids

    invalid_files: List[str] = []

    for json_file in json_files:
        try:
            # Decode as UTF-8 explicitly rather than relying on the
            # platform's locale encoding.
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError:
            invalid_files.append(f"{json_file.name} (invalid JSON)")
            continue
        except Exception as e:
            # Preflight must report problems, not crash on one unreadable file.
            invalid_files.append(f"{json_file.name} ({e})")
            continue

        # Require a JSON object: a list such as ["train", "test"] would
        # otherwise satisfy the membership checks without being a task.
        if isinstance(data, dict) and "train" in data and "test" in data:
            task_ids.append(json_file.stem)
        else:
            invalid_files.append(json_file.name)

    valid_count = len(task_ids)

    if invalid_files:
        shown = ", ".join(invalid_files[:5])
        hidden = len(invalid_files) - 5
        suffix = f" (+{hidden} more)" if hidden > 0 else ""
        return ValidationResult(
            passed=valid_count > 0,  # Partial pass if some files are valid
            message=f"Found {valid_count} valid tasks, {len(invalid_files)} invalid",
            details=f"Invalid: {shown}{suffix}"
        ), task_ids

    return ValidationResult(
        passed=True,
        message=f"Found {valid_count} valid task files",
        details=location
    ), task_ids
+
+
def validate_output_dir(output_dir: str) -> ValidationResult:
    """Check that ``output_dir`` is, or can become, a writable directory.

    Args:
        output_dir: Path where output is expected to be written.

    Returns:
        A passing result when the path is an existing writable directory, or
        when it does not exist but its nearest existing ancestor is a
        writable directory (so the missing levels can be created). A failing
        result when the path exists but is not a directory, is not writable,
        or cannot be created.
    """
    path = Path(output_dir)
    location = str(path.absolute())
    # Creating entries inside a directory needs both write and search permission.
    can_create_in = os.W_OK | os.X_OK

    if path.exists():
        if not path.is_dir():
            return ValidationResult(
                passed=False,
                message="Output path exists but is not a directory",
                details=location
            )
        if not os.access(path, can_create_in):
            return ValidationResult(
                passed=False,
                message="Output directory not writable",
                details=location
            )
        return ValidationResult(
            passed=True,
            message="Output directory exists and is writable",
            details=location
        )

    # The target is missing: walk up to the closest ancestor that exists.
    # Checking only the immediate parent would reject nested targets such as
    # "runs/2024/model" when "runs" exists but "runs/2024" does not yet.
    # NOTE(review): assumes the runner creates the directory recursively
    # (os.makedirs-style) - confirm against the code that writes output.
    ancestor = path.absolute().parent
    while not ancestor.exists() and ancestor != ancestor.parent:
        ancestor = ancestor.parent

    if not ancestor.is_dir():
        return ValidationResult(
            passed=False,
            message="Cannot create output directory",
            details=f"Parent is not a directory: {ancestor}"
        )

    if not os.access(ancestor, can_create_in):
        return ValidationResult(
            passed=False,
            message="Cannot create output directory",
            details=f"Parent not writable: {ancestor}"
        )

    return ValidationResult(
        passed=True,
        message="Output directory will be created",
        details=location
    )
+
+
def estimate_cost(
    model_config: ModelConfig,
    num_tasks: int,
    num_attempts: int = 2,
    avg_input_tokens: int = DEFAULT_AVG_INPUT_TOKENS_PER_TASK,
    avg_output_tokens: int = DEFAULT_AVG_OUTPUT_TOKENS_PER_TASK,
) -> CostEstimate:
    """Project token usage and cost for ``num_tasks`` tasks.

    Args:
        model_config: Config whose ``pricing.input`` / ``pricing.output`` are
            read as prices per one million tokens.
        num_tasks: Number of tasks in the run.
        num_attempts: Attempts made per task.
        avg_input_tokens: Assumed input tokens per attempt.
        avg_output_tokens: Assumed output tokens per attempt.

    Returns:
        A ``CostEstimate`` with the attempt count, token totals, the prices
        used, and the summed input + output cost.
    """
    attempts = num_tasks * num_attempts
    tokens_in = attempts * avg_input_tokens
    tokens_out = attempts * avg_output_tokens

    price_in = model_config.pricing.input
    price_out = model_config.pricing.output

    # Prices are per 1M tokens, so scale the token totals down first.
    cost_in = (tokens_in / 1_000_000) * price_in
    cost_out = (tokens_out / 1_000_000) * price_out

    return CostEstimate(
        num_tasks=num_tasks,
        num_attempts_per_task=num_attempts,
        total_attempts=attempts,
        input_price_per_1m=price_in,
        output_price_per_1m=price_out,
        estimated_input_tokens=tokens_in,
        estimated_output_tokens=tokens_out,
        estimated_cost=cost_in + cost_out,
    )
+
+
def run_preflight(
    config_name: str,
    data_dir: str,
    output_dir: str,
    num_attempts: int = 2,
) -> PreflightReport:
    """
    Run all preflight validations and return a comprehensive report.

    The checks run in a fixed order and each contributes exactly one entry to
    the report: config lookup, API key, data directory, output directory. A
    cost estimate is attached only when the config loaded and at least one
    valid task was found.

    Args:
        config_name: Name of the model configuration from models.yml
        data_dir: Directory containing task JSON files
        output_dir: Directory where submissions will be saved
        num_attempts: Number of attempts per task

    Returns:
        PreflightReport with all validation results and cost estimate;
        ``all_passed`` is True only if every individual check passed.
    """
    validations: List[ValidationResult] = []
    cost_estimate = None
    model_config = None
    config_load_error: Optional[str] = None

    # 1. Validate config exists
    config_result = validate_config_exists(config_name)
    validations.append(config_result)

    if config_result.passed:
        try:
            model_config = read_models_config(config_name)
        except Exception as exc:
            # The lookup just succeeded, so a failure here is unexpected.
            # Record it instead of silently dropping it, so the report can
            # explain why the dependent checks were skipped.
            config_load_error = str(exc)
            logger.warning(
                "Config '%s' passed validation but could not be loaded: %s",
                config_name,
                exc,
            )

    # 2. Validate API key (only if config was found)
    if model_config is not None:
        validations.append(validate_api_key(model_config.provider))
    else:
        # Counted as a failure: the key cannot be verified without a provider.
        validations.append(ValidationResult(
            passed=False,
            message="Skipping API key validation (config not found)",
            details=config_load_error,
        ))

    # 3. Validate data directory
    data_result, task_ids = validate_data_dir(data_dir)
    validations.append(data_result)
    num_tasks = len(task_ids)

    # 4. Validate output directory
    validations.append(validate_output_dir(output_dir))

    # 5. Calculate cost estimate (only if we have valid config and tasks)
    if model_config is not None and num_tasks > 0:
        cost_estimate = estimate_cost(
            model_config=model_config,
            num_tasks=num_tasks,
            num_attempts=num_attempts,
        )

    return PreflightReport(
        config_name=config_name,
        validations=validations,
        cost_estimate=cost_estimate,
        all_passed=all(v.passed for v in validations),
    )
+
+
def main():
    """CLI entry point for preflight validation.

    Parses ``--config`` (required), ``--data_dir``, ``--output_dir`` and
    ``--num_attempts``, runs ``run_preflight`` with them, prints the report,
    and terminates the process with status 0 when every check passed and 1
    otherwise.
    """
    import argparse
    import sys

    from dotenv import load_dotenv

    # Called before any check runs; presumably so keys kept in a .env file
    # are visible to the API-key validation - confirm against dotenv usage.
    load_dotenv()

    parser = argparse.ArgumentParser(
        description="Run preflight validation before benchmark runs"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Model configuration name from models.yml"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/sample/tasks",
        help="Directory containing task JSON files"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="submissions",
        help="Directory for saving submissions"
    )
    parser.add_argument(
        "--num_attempts",
        type=int,
        default=2,
        help="Number of attempts per task"
    )

    args = parser.parse_args()

    report = run_preflight(
        config_name=args.config,
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        num_attempts=args.num_attempts,
    )

    print(report)

    # Use sys.exit rather than the exit() builtin: the latter is injected by
    # the `site` module for interactive use and is not guaranteed to exist.
    sys.exit(0 if report.all_passed else 1)


if __name__ == "__main__":
    main()