diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index b8e4445d..4e6050e5 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -39,7 +39,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row asynchronously.""" start_time = time.perf_counter() - + if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 71958510..605b973e 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -1,4 +1,5 @@ import asyncio +import copy import logging import os import time @@ -20,25 +21,39 @@ ENABLE_SPECULATION = os.getenv("ENABLE_SPECULATION", "0").strip() == "1" + +@dataclass +class SampleState: + """ + Tracks state for a single dataset sample across multiple runs. + Enables streaming scheduling where each completed run immediately triggers the next. + """ + row: EvaluationRow + row_index: int + config: RolloutProcessorConfig + history: List[str] = field(default_factory=list) # Accumulated history from completed runs + next_run_idx: int = 0 # Next run index to schedule + active_runs: int = 0 # Currently executing runs for this sample + completed_runs: int = 0 # Total completed runs for this sample + lock: asyncio.Lock = field(default_factory=asyncio.Lock) # Protect state updates + + @dataclass(order=True) class RolloutTask: """ - Represents a single unit of work for the worker pool. - Priority tuple structure: (status, row_index) - - status: 0 = High Priority (e.g., subsequent micro-batches of an already started sample) - 1 = Low Priority (e.g., starting a new sample) - - row_index: Used to maintain dataset order for initial scheduling + Represents a single rollout task (one run for one sample). + Priority tuple structure: (status, row_index, run_index) + - status: 0 = High Priority (continuing a started sample) + 1 = Low Priority (starting a new sample) + - row_index: Dataset order + - run_index: Run order within sample """ - priority: tuple[int, int] + priority: tuple[int, int, int] # Payload (excluded from comparison) - row: EvaluationRow = field(compare=False) - run_indices: List[int] = field(compare=False) # Which runs to execute in this task - config: RolloutProcessorConfig = field(compare=False) - row_index: int = field(compare=False) # To track which sample this belongs to - - # History for speculation (injected from previous micro-batches) - history: List[str] = field(compare=False, default_factory=list) + sample_state: SampleState = field(compare=False) + run_idx: int = field(compare=False) # Single run index for this task + history_snapshot: List[str] = field(compare=False, default_factory=list) # History at scheduling time class PriorityRolloutScheduler: """ @@ -56,7 +71,7 @@ def __init__( output_buffer: Optional[MicroBatchDataBuffer] = None, rollout_n: int = 0, mode: str = "pointwise", - in_group_minibatch_size: int = 0, # for one sample, how many runs to execute at the same time + in_group_minibatch_size: Optional[int] = None, # for one sample, how many runs to execute at the same time evaluation_test_kwargs: Dict[str, Any] = {}, ): self.rollout_processor = rollout_processor @@ -70,7 +85,7 @@ def __init__( # Priority Queue: Stores RolloutTask self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() - # Concurrency Control + # Concurrency Control (rollout concurrency is handled by rollout_processor's semaphore) self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) # Results storage @@ -80,6 +95,11 @@ def __init__( self.background_tasks = set() # run evaluations in the background asynchronously self.rollout_n = rollout_n + if in_group_minibatch_size is None: + if ENABLE_SPECULATION: + in_group_minibatch_size = rollout_n // 2 + else: + in_group_minibatch_size = rollout_n self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n self.evaluation_test_kwargs = evaluation_test_kwargs @@ -95,32 +115,46 @@ def __init__( self.active_evals: int = 0 self.active_evals_lock = asyncio.Lock() + self.sample_states: Dict[int, SampleState] = {} + async def schedule_dataset( self, dataset: List[EvaluationRow], base_config: RolloutProcessorConfig, ): """ - Populates the queue with initial tasks (the first micro-batch for each sample). + Populates the queue with initial tasks. + For each sample, schedules up to in_group_minibatch_size concurrent runs. """ for i, row in enumerate(dataset): - # Calculate ranges for the first in-group minibatch - batch_start = 0 - batch_end = min(self.in_group_minibatch_size, self.rollout_n) - run_indices = list(range(batch_start, batch_end)) - - # Initial priority: Low (1), ordered by dataset index - priority = (1, i) - - task = RolloutTask( - priority=priority, + # Create sample state + sample_state = SampleState( row=row, - run_indices=run_indices, - config=base_config, row_index=i, - history=[] # Initial batch has no history + config=base_config, + history=[], + next_run_idx=0, + active_runs=0, + completed_runs=0, + lock=asyncio.Lock(), ) - self.queue.put_nowait(task) + self.sample_states[i] = sample_state + + # Schedule initial runs (up to in_group_minibatch_size) + initial_runs = min(self.in_group_minibatch_size, self.rollout_n) + for run_idx in range(initial_runs): + # Initial priority: Low (1), ordered by dataset index, then run index + priority = (1, i, run_idx) + + task = RolloutTask( + priority=priority, + sample_state=sample_state, + run_idx=run_idx, + history_snapshot=[], # First runs have no history + ) + self.queue.put_nowait(task) + sample_state.next_run_idx = run_idx + 1 + sample_state.active_runs += 1 async def worker(self): """ @@ -133,7 +167,7 @@ async def worker(self): try: await self._process_task(task) except Exception as e: - logging.error(f"Error processing task for row {task.row.input_metadata.row_id}: {e}", exc_info=True) + logging.error(f"Error processing task for row {task.sample_state.row.input_metadata.row_id} run {task.run_idx}: {e}", exc_info=True) finally: self.queue.task_done() @@ -210,112 +244,127 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): if self.eval_pbar: self.eval_pbar.set_postfix_str(f"active={self.active_evals}") - # 1. Prepare Config & Row for this micro-batch - current_batch_rows = [] - for run_idx in task.run_indices: - row_copy = task.row.model_copy(deep=True) - - row_copy.execution_metadata.run_id = generate_id() - row_copy.execution_metadata.rollout_id = generate_id() - if row_copy.execution_metadata.extra is None: - row_copy.execution_metadata.extra = {} - row_copy.execution_metadata.extra["run_index"] = run_idx - - # Inject Speculation History - if ENABLE_SPECULATION and task.history: - cp = row_copy.input_metadata.completion_params - max_tokens = cp.get("max_tokens", 2048) - # Ensure safe dict access - if not isinstance(cp, dict): - cp = {} - # Need to check and initialize nested dicts - extra_body = cp.get("extra_body") - if extra_body is None or not isinstance(extra_body, dict): - extra_body = {} - # for speculation, see - # https://docs.fireworks.ai/guides/predicted-outputs - # https://platform.openai.com/docs/guides/predicted-outputs?lang=python - extra_body["prediction"] = {"type": "content", "content": " ".join(task.history)[:max_tokens]} - cp["extra_body"] = extra_body - row_copy.input_metadata.completion_params = cp - - current_batch_rows.append((run_idx, row_copy)) - self.active_logger.log(row_copy) + sample_state = task.sample_state + run_idx = task.run_idx + row_index = sample_state.row_index - - # 2. Execute Rollout - batch_results: List[EvaluationRow] = [] - if current_batch_rows: - for idx, row in current_batch_rows: - # Track this rollout as active - async with self.active_rollouts_lock: - self.active_rollouts[task.row_index].add(idx) - await self._update_rollout_pbar_postfix() - - try: - async for result_row in rollout_processor_with_retry( - self.rollout_processor, [row], task.config, idx, disable_tqdm=True - ): - batch_results.append(result_row) - - # Update rollout progress bar - if self.rollout_pbar: - self.rollout_pbar.update(1) - - # in pointwise, we start evaluation immediately - if self.mode == "pointwise": - t = asyncio.create_task(_run_eval(result_row)) - self.background_tasks.add(t) - t.add_done_callback(self.background_tasks.discard) - finally: - # Remove from active tracking - async with self.active_rollouts_lock: - self.active_rollouts[task.row_index].discard(idx) - if not self.active_rollouts[task.row_index]: - del self.active_rollouts[task.row_index] - await self._update_rollout_pbar_postfix() + # Rollout concurrency is handled by rollout_processor's internal semaphore + # 1. Prepare row for this single run + row_copy = sample_state.row.model_copy(deep=True) + row_copy.execution_metadata.run_id = generate_id() + row_copy.execution_metadata.rollout_id = generate_id() + if row_copy.execution_metadata.extra is None: + row_copy.execution_metadata.extra = {} + row_copy.execution_metadata.extra["run_index"] = run_idx - # 3. Evaluate and Collect History - current_batch_history_updates = [] - # Extract history from rollout results (assuming eval doesn't change content needed for history) - for res in batch_results: - last_msg = res.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") - - # in groupwise, we send all rows to evaluator in one go when the whole group is complete - if self.mode == "groupwise": - self.groups_buffer[task.row_index].extend(batch_results) - if len(self.groups_buffer[task.row_index]) >= self.rollout_n: - full_group = self.groups_buffer.pop(task.row_index) - t = asyncio.create_task(_run_eval(full_group)) - self.background_tasks.add(t) - t.add_done_callback(self.background_tasks.discard) - - # 4. Schedule Next Micro-batch (High Priority) - last_run_idx = task.run_indices[-1] if task.run_indices else -1 - next_start = last_run_idx + 1 + # Make a copy of config for this specific run (to inject per-run speculation) + run_config = sample_state.config - if next_start < self.rollout_n: - next_end = min(next_start + self.in_group_minibatch_size, self.rollout_n) - next_indices = list(range(next_start, next_end)) - new_history = task.history + current_batch_history_updates + # Inject Speculation History into config.completion_params (use snapshot from when task was scheduled) + if ENABLE_SPECULATION and task.history_snapshot: + # Deep copy to avoid concurrent mutation of shared nested dicts + cp = copy.deepcopy(sample_state.config.completion_params) if sample_state.config.completion_params else {} + max_tokens = cp.get("max_tokens", 2048) + if "extra_body" not in cp: + cp["extra_body"] = {} - # Priority 0 (High) to ensure we finish this sample ASAP - new_priority = (0, task.row_index) + cp["extra_body"]["prediction"] = {"type": "content", "content": " ".join(task.history_snapshot)[:max_tokens]} - new_task = RolloutTask( - priority=new_priority, - row=task.row, - run_indices=next_indices, - config=task.config, - row_index=task.row_index, - history=new_history + # Create a new config with the modified completion_params (copy all fields) + base_config = sample_state.config + run_config = RolloutProcessorConfig( + completion_params=cp, + mcp_config_path=base_config.mcp_config_path, + semaphore=base_config.semaphore, + server_script_path=base_config.server_script_path, + steps=base_config.steps, + logger=base_config.logger, + kwargs=base_config.kwargs, + exception_handler_config=base_config.exception_handler_config, + post_processor=base_config.post_processor, ) - self.queue.put_nowait(new_task) + + self.active_logger.log(row_copy) + + # 2. Track this rollout as active + async with self.active_rollouts_lock: + self.active_rollouts[row_index].add(run_idx) + await self._update_rollout_pbar_postfix() + + # 3. Execute the rollout + result_row: Optional[EvaluationRow] = None + start_time = time.perf_counter() + try: + async for result in rollout_processor_with_retry( + self.rollout_processor, [row_copy], run_config, run_idx, disable_tqdm=True + ): + result_row = result + result_row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time + + # Update rollout progress bar + if self.rollout_pbar: + self.rollout_pbar.update(1) + + # In pointwise mode, start evaluation immediately + if self.mode == "pointwise": + t = asyncio.create_task(_run_eval(result_row)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) + finally: + # Remove from active tracking + async with self.active_rollouts_lock: + self.active_rollouts[row_index].discard(run_idx) + if not self.active_rollouts[row_index]: + del self.active_rollouts[row_index] + await self._update_rollout_pbar_postfix() + + # 4. Update sample state and schedule next run (streaming) + # Must be in finally to ensure state is updated even on exception + async with sample_state.lock: + sample_state.active_runs -= 1 + sample_state.completed_runs += 1 + + # Extract history from this run's result + if result_row: + last_msg = result_row.last_assistant_message() + if last_msg and last_msg.content: + sample_state.history.append(str(last_msg.content)) + else: + sample_state.history.append("") + + # In groupwise mode, buffer results + if self.mode == "groupwise": + if result_row: + self.groups_buffer[row_index].append(result_row) + # Check if all runs for this sample are complete + if sample_state.completed_runs >= self.rollout_n: + full_group = self.groups_buffer.pop(row_index, []) + if full_group: + t = asyncio.create_task(_run_eval(full_group)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) + + # Schedule next run if: + # 1. There are more runs to do + # 2. We haven't hit in_group_minibatch_size concurrent runs for this sample + if (sample_state.next_run_idx < self.rollout_n and + sample_state.active_runs < self.in_group_minibatch_size): + + next_run_idx = sample_state.next_run_idx + sample_state.next_run_idx += 1 + sample_state.active_runs += 1 + + # High priority (0) to finish this sample ASAP + # Use current accumulated history for speculation + priority = (0, row_index, next_run_idx) + + new_task = RolloutTask( + priority=priority, + sample_state=sample_state, + run_idx=next_run_idx, + history_snapshot=list(sample_state.history), # Snapshot current history + ) + self.queue.put_nowait(new_task) def _format_active_rollouts(self) -> str: """Format active rollouts for display in progress bar.""" @@ -410,7 +459,8 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: Ro await self.schedule_dataset(dataset, base_config) # 2. Start Workers - # If we have separate limits, we need enough workers to saturate both stages + # With semaphore-based concurrency control, workers can be equal to max_concurrent_rollouts + # The semaphore will limit actual concurrent executions num_workers = self.max_concurrent_rollouts workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] @@ -460,7 +510,6 @@ async def execute_priority_rollouts( max_concurrent_evaluations=max_concurrent_evaluations, rollout_n=num_runs, mode=mode, - in_group_minibatch_size=(num_runs // 2), evaluation_test_kwargs=evaluation_test_kwargs, ) return await scheduler.run(dataset, num_runs, config) diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py index f5b4fa31..a6e44ff0 100644 --- a/tests/test_priority_scheduler.py +++ b/tests/test_priority_scheduler.py @@ -5,7 +5,7 @@ from typing import List, Union from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, EvaluateResult -from eval_protocol.pytest.priority_scheduler import PriorityRolloutScheduler, execute_priority_rollouts, RolloutTask +from eval_protocol.pytest.priority_scheduler import PriorityRolloutScheduler, execute_priority_rollouts, RolloutTask, SampleState from eval_protocol.pytest.types import RolloutProcessorConfig from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -244,13 +244,21 @@ async def worker(self): async def schedule_dataset(self, *args): # Put enough items to ensure all workers wake up and grab one for i in range(expected_workers): - task = RolloutTask( - priority=(1, i), + sample_state = SampleState( row=dataset[0], - run_indices=[], - config=base_config, row_index=0, - history=[] + config=base_config, + history=[], + next_run_idx=0, + active_runs=0, + completed_runs=0, + lock=asyncio.Lock(), + ) + task = RolloutTask( + priority=(1, i, 0), + sample_state=sample_state, + run_idx=0, + history_snapshot=[], ) await self.queue.put(task)