Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ clean:
rm -rf logs/random-baseline
find . -type f -name "*.pyc" -delete
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
find . -type d -name ".checkpoints" -exec rm -rf {} + 2>/dev/null || true
214 changes: 193 additions & 21 deletions cli/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,20 @@
sys.path.insert(0, project_root)

from main import ARCTester
from arc_agi_benchmarking.utils.task_utils import read_models_config, read_provider_rate_limits
from arc_agi_benchmarking.utils.task_utils import read_models_config, read_provider_rate_limits, get_provider_timeout_config
from arc_agi_benchmarking.utils.rate_limiter import AsyncRequestRateLimiter
from arc_agi_benchmarking.utils.metrics import set_metrics_enabled, set_metrics_filename_prefix
from arc_agi_benchmarking.utils.preflight import run_preflight
from arc_agi_benchmarking.utils.logging_utils import setup_logging, StructuredFormatter
from arc_agi_benchmarking.resilience import (
CircuitBreaker,
CircuitBreakerOpenError,
TaskTimeoutError,
task_timeout,
get_circuit_breaker,
)
from arc_agi_benchmarking.checkpoint import BatchProgressManager, TaskStatus
from arc_agi_benchmarking.storage import LocalStorageBackend

from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type, before_sleep_log

Expand Down Expand Up @@ -82,6 +91,8 @@ def _record_factory(*args, **kwargs):
# Default values
DEFAULT_RATE_LIMIT_RATE = 400
DEFAULT_RATE_LIMIT_PERIOD = 60
DEFAULT_CIRCUIT_BREAKER_THRESHOLD = 5
DEFAULT_CIRCUIT_BREAKER_RECOVERY = 60

# --- Configuration ---
# Default model configuration to test if not provided via CLI.
Expand All @@ -98,6 +109,8 @@ def _record_factory(*args, **kwargs):

# --- Globals for Orchestrator ---
PROVIDER_RATE_LIMITERS: Dict[str, AsyncRequestRateLimiter] = {}
PROVIDER_CIRCUIT_BREAKERS: Dict[str, CircuitBreaker] = {}
PROVIDER_TIMEOUT_CONFIGS: Dict[str, Dict] = {}
MODEL_CONFIG_CACHE: Dict[str, Any] = {}

def get_model_config(config_name: str):
Expand Down Expand Up @@ -129,15 +142,54 @@ def get_or_create_rate_limiter(provider_name: str, all_provider_limits: Dict) ->
PROVIDER_RATE_LIMITERS[provider_name] = AsyncRequestRateLimiter(rate=actual_rate_for_limiter, capacity=actual_capacity_for_limiter)
return PROVIDER_RATE_LIMITERS[provider_name]


def get_or_create_circuit_breaker(
provider_name: str,
all_provider_limits: Dict,
threshold_override: Optional[int] = None
) -> CircuitBreaker:
if provider_name not in PROVIDER_CIRCUIT_BREAKERS:
timeout_config = get_provider_timeout_config(provider_name, all_provider_limits)
failure_threshold = threshold_override or timeout_config['circuit_breaker_threshold']
recovery_timeout = timeout_config['circuit_breaker_recovery']
logger.info(f"Initializing circuit breaker for '{provider_name}': threshold={failure_threshold}, recovery={recovery_timeout}s")
PROVIDER_CIRCUIT_BREAKERS[provider_name] = CircuitBreaker(
name=provider_name,
failure_threshold=failure_threshold,
recovery_timeout=recovery_timeout,
)
return PROVIDER_CIRCUIT_BREAKERS[provider_name]


def get_task_timeout(provider_name: str, all_provider_limits: Dict, max_task_timeout: Optional[float] = None) -> Optional[float]:
if max_task_timeout is not None and max_task_timeout > 0:
return max_task_timeout
return None

async def run_single_test_wrapper(config_name: str, task_id: str, limiter: AsyncRequestRateLimiter,
circuit_breaker: CircuitBreaker,
task_timeout_seconds: Optional[float],
data_dir: str, save_submission_dir: str,
overwrite_submission: bool, print_submission: bool,
num_attempts: int, retry_attempts: int,
logs_base_dir: Path) -> bool: # removed print_logs
logs_base_dir: Path,
progress_manager: Optional[BatchProgressManager] = None) -> Optional[bool]:
logger.info(f"[Orchestrator] Queuing task: {task_id}, config: {config_name}")

# Apply tenacity retry decorator directly to the synchronous function
# The logger passed to before_sleep_log is the module-level logger of cli.run_all
# Claim the task for this worker (if using checkpointing)
if progress_manager is not None:
if not progress_manager.claim_task(task_id):
logger.debug(f"[Orchestrator] Task {task_id} already claimed or completed, skipping")
return None # Skipped, not success or failure

try:
circuit_breaker.raise_if_open()
except CircuitBreakerOpenError as e:
logger.warning(f"[Orchestrator] Circuit breaker OPEN for {config_name}, skipping {task_id}. Recovery in {e.recovery_time:.1f}s")
if progress_manager is not None:
progress_manager.mark_failed(task_id, f"Circuit breaker open: {e}")
return False

@retry(
wait=wait_exponential(multiplier=1, min=4, max=60),
stop=stop_after_attempt(4),
Expand Down Expand Up @@ -192,26 +244,58 @@ def filter(self, record: logging.LogRecord) -> bool:

try:
async with limiter:
logger.info(f"[Orchestrator] Rate limiter acquired for: {config_name}. Executing task {task_id} with tenacity retries...")
await asyncio.to_thread(_synchronous_task_execution_attempt_with_tenacity)

logger.info(f"[Orchestrator] Successfully processed (with tenacity retries if any): {config_name} / {task_id}")
timeout_str = f"{task_timeout_seconds}s" if task_timeout_seconds else "none"
logger.info(f"[Orchestrator] Rate limiter acquired for {config_name}. Executing {task_id} (timeout={timeout_str})")
if task_timeout_seconds:
await task_timeout(
_synchronous_task_execution_attempt_with_tenacity,
task_timeout_seconds,
f"Task {task_id} ({config_name})"
)
else:
await asyncio.get_event_loop().run_in_executor(
None, _synchronous_task_execution_attempt_with_tenacity
)

circuit_breaker.record_success()
logger.info(f"[Orchestrator] Successfully processed: {config_name} / {task_id}")
if progress_manager is not None:
progress_manager.mark_completed(task_id)
return True

except TaskTimeoutError as e:
circuit_breaker.record_failure(e)
logger.error(f"[Orchestrator] Task {task_id} ({config_name}) timed out after {e.elapsed:.2f}s (limit: {e.timeout}s)")
if progress_manager is not None:
progress_manager.mark_failed(task_id, f"Timeout after {e.elapsed:.2f}s")
return False

except Exception as e:
logger.error(f"[Orchestrator] Failed to process (after all tenacity retries or due to non-retryable error): {config_name} / {task_id}. Error: {type(e).__name__} - {e}", exc_info=True)
if isinstance(e, EFFECTIVE_RETRYABLE_EXCEPTIONS):
circuit_breaker.record_failure(e)
logger.error(f"[Orchestrator] Failed after retries: {config_name} / {task_id}. {type(e).__name__}: {e}", exc_info=True)
else:
logger.error(f"[Orchestrator] Failed (non-retryable): {config_name} / {task_id}. {type(e).__name__}: {e}", exc_info=True)
if progress_manager is not None:
progress_manager.mark_failed(task_id, f"{type(e).__name__}: {e}")
return False

async def main(task_list_file: Optional[str],
config_to_test: str,
data_dir: str, save_submission_dir: str,
overwrite_submission: bool, print_submission: bool,
num_attempts: int, retry_attempts: int,
logs_base_dir: Path) -> int: # Added return type hint
# Basic logging setup is now done in if __name__ == "__main__"

logs_base_dir: Path,
max_task_timeout: Optional[float] = None,
circuit_breaker_threshold: Optional[int] = None,
resume: bool = True) -> int:
start_time = time.perf_counter()
logger.info("Starting ARC Test Orchestrator...")
logger.info(f"Testing with model configuration: {config_to_test}")
if max_task_timeout:
logger.info(f"Task timeout: {max_task_timeout}s (CLI override)")
if circuit_breaker_threshold:
logger.info(f"Circuit breaker threshold: {circuit_breaker_threshold} (CLI override)")

task_ids: List[str] = []
try:
Expand Down Expand Up @@ -245,14 +329,52 @@ async def main(task_list_file: Optional[str],
logger.error(f"Error loading tasks: {e}", exc_info=True)
return 1 # Return an error code

# --- Checkpointing Setup ---
checkpoint_dir = Path(save_submission_dir) / config_to_test / ".checkpoints"
storage = LocalStorageBackend(checkpoint_dir)
progress_manager = BatchProgressManager(
storage=storage,
run_id=config_to_test,
)

# Initialize all tasks (idempotent - only adds tasks not already tracked)
progress_manager.initialize_tasks(task_ids, attempts_per_task=num_attempts)

# Reset any stale in-progress tasks (from crashed workers)
stale_count = progress_manager.reset_stale_tasks(max_age_seconds=3600)
if stale_count > 0:
logger.info(f"Reset {stale_count} stale in-progress task(s)")

# Determine which tasks to run
if resume:
# Only run pending tasks
tasks_to_run = [
task_id for task_id in task_ids
if progress_manager.progress.tasks.get(task_id) and
progress_manager.progress.tasks[task_id].status == TaskStatus.PENDING
]
completed_count = progress_manager.progress.completed_count
failed_count = progress_manager.progress.failed_count
if completed_count > 0 or failed_count > 0:
logger.info(
f"Resuming: {completed_count} completed, {failed_count} failed, "
f"{len(tasks_to_run)} remaining"
)
else:
tasks_to_run = task_ids
logger.info("Resume disabled - running all tasks")

all_jobs_to_run: List[Tuple[str, str]] = []
for task_id in task_ids:
for task_id in tasks_to_run:
all_jobs_to_run.append((config_to_test, task_id))

if not all_jobs_to_run:
if resume and progress_manager.is_complete():
logger.info("All tasks already completed. Use --no-resume to re-run.")
return 0
logger.warning("No jobs to run (check config_to_test and task list). Exiting.")
return 1 # Return an error code
return 1

logger.info(f"Total jobs to process: {len(all_jobs_to_run)}")

try:
Expand All @@ -271,12 +393,16 @@ async def main(task_list_file: Optional[str],
model_config_obj = get_model_config(config_name)
provider_name = model_config_obj.provider
limiter = get_or_create_rate_limiter(provider_name, all_provider_limits)
circuit_breaker = get_or_create_circuit_breaker(provider_name, all_provider_limits, circuit_breaker_threshold)
task_timeout_val = get_task_timeout(provider_name, all_provider_limits, max_task_timeout)
async_tasks_to_execute.append(run_single_test_wrapper(
config_name, task_id, limiter,
circuit_breaker, task_timeout_val,
data_dir, save_submission_dir,
overwrite_submission, print_submission,
overwrite_submission, print_submission,
num_attempts, retry_attempts,
logs_base_dir
logs_base_dir,
progress_manager,
))
except ValueError as e: # Specific error for model config issues
logger.error(f"Skipping config '{config_name}' for task '{task_id}' due to model config error: {e}")
Expand All @@ -291,6 +417,7 @@ async def main(task_list_file: Optional[str],
results = await asyncio.gather(*async_tasks_to_execute, return_exceptions=True)

successful_runs = sum(1 for r in results if r is True)
skipped_runs = sum(1 for r in results if r is None)
orchestrator_level_failures = sum(1 for r in results if r is False or isinstance(r, Exception))

logger.info("--- Orchestrator Summary ---")
Expand All @@ -306,10 +433,33 @@ async def main(task_list_file: Optional[str],
elif res is False: # Wrapper reported failure
logger.warning(f" - Failure reported by wrapper for {original_job_config}/{original_job_task_id} (check ARCTester logs for this task/config)")
exit_code = 1 # Indicate failure


# Log checkpoint progress summary
logger.info("--- Checkpoint Progress Summary ---")
summary = progress_manager.get_summary()
logger.info(
f" Run: {summary['run_id']} | "
f"Total: {summary['total']} | "
f"Completed: {summary['completed']} | "
f"Failed: {summary['failed']} | "
f"Pending: {summary['pending']}"
)

# Log circuit breaker statistics
if PROVIDER_CIRCUIT_BREAKERS:
logger.info("--- Circuit Breaker Summary ---")
for provider, cb in PROVIDER_CIRCUIT_BREAKERS.items():
stats = cb.get_stats()
logger.info(
f" {provider}: state={stats['state']}, "
f"requests={stats['total_requests']}, "
f"failures={stats['failed_requests']}, "
f"rejected={stats['rejected_requests']}"
)

logger.info("Note: Individual task success/failure is logged by ARCTester within its own logger (main.py's logger).")
logger.info("Orchestrator failure indicates an issue with running the ARCTester task itself or an unhandled exception in the wrapper.")

end_time = time.perf_counter()
total_duration = end_time - start_time
logger.info("--- Orchestrator Timing ---")
Expand Down Expand Up @@ -396,9 +546,28 @@ async def main(task_list_file: Optional[str],
default=None,
help="Maximum estimated cost in USD. Abort if estimated cost exceeds this limit."
)
parser.add_argument(
"--max-task-timeout",
type=float,
default=None,
help="Maximum timeout in seconds for each task execution. Overrides provider-specific timeouts."
)
parser.add_argument(
"--circuit-breaker-threshold",
type=int,
default=None,
help="Number of failures before circuit breaker opens. Overrides provider-specific thresholds."
)
parser.add_argument(
"--no-resume",
action="store_true",
help="Disable resume - run all tasks even if some are already completed."
)

args = parser.parse_args()

resume_enabled = not args.no_resume

# Set metrics enabled status based on CLI arg
set_metrics_enabled(args.enable_metrics)

Expand Down Expand Up @@ -474,7 +643,10 @@ async def main(task_list_file: Optional[str],
print_submission=args.print_submission,
num_attempts=args.num_attempts,
retry_attempts=args.retry_attempts,
logs_base_dir=logs_base_dir
logs_base_dir=logs_base_dir,
max_task_timeout=args.max_task_timeout,
circuit_breaker_threshold=args.circuit_breaker_threshold,
resume=resume_enabled,
))

sys.exit(exit_code_from_main)
Loading