From 29c6d06b81b1ee6dd77ba6281c30e892dc99ce65 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Wed, 17 Dec 2025 15:38:34 -0800 Subject: [PATCH 1/6] fix scheduler concurrency and orchestration --- .../default_single_turn_rollout_process.py | 16 +- eval_protocol/pytest/priority_scheduler.py | 305 ++++++++++-------- tests/test_priority_scheduler.py | 20 +- 3 files changed, 204 insertions(+), 137 deletions(-) diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index b8e4445d..0f770f22 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -18,6 +18,15 @@ logger = logging.getLogger(__name__) +litellm._turn_on_debug() + +# Configure logger with timestamp format if not already configured +if not logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(asctime)s.%(msecs)03d %(message)s', datefmt='%H:%M:%S')) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + class SingleTurnRolloutProcessor(RolloutProcessor): """Single turn rollout processor for direct LLM calls.""" @@ -39,7 +48,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row asynchronously.""" start_time = time.perf_counter() - + if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") @@ -97,7 +106,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: chunks.append(chunk) response = litellm.stream_chunk_builder(chunks, messages_payload) else: + logger.warning(f"******** rolling out {row.execution_metadata.run_id} ") + logger.warning(json.dumps(request_params)) + start_time = time.perf_counter() response = await acompletion(**request_params) + rollout_duration = time.perf_counter() - start_time + logger.warning(f"******** rollout duration for {row.execution_metadata.run_id} {rollout_duration} seconds") assert response is not None, "Response is None" assert isinstance(response, ModelResponse), "Response should be ModelResponse" diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 71958510..397a783b 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -20,25 +20,39 @@ ENABLE_SPECULATION = os.getenv("ENABLE_SPECULATION", "0").strip() == "1" + +@dataclass +class SampleState: + """ + Tracks state for a single dataset sample across multiple runs. + Enables streaming scheduling where each completed run immediately triggers the next. + """ + row: EvaluationRow + row_index: int + config: RolloutProcessorConfig + history: List[str] = field(default_factory=list) # Accumulated history from completed runs + next_run_idx: int = 0 # Next run index to schedule + active_runs: int = 0 # Currently executing runs for this sample + completed_runs: int = 0 # Total completed runs for this sample + lock: asyncio.Lock = field(default_factory=asyncio.Lock) # Protect state updates + + @dataclass(order=True) class RolloutTask: """ - Represents a single unit of work for the worker pool. - Priority tuple structure: (status, row_index) - - status: 0 = High Priority (e.g., subsequent micro-batches of an already started sample) - 1 = Low Priority (e.g., starting a new sample) - - row_index: Used to maintain dataset order for initial scheduling + Represents a single rollout task (one run for one sample). + Priority tuple structure: (status, row_index, run_index) + - status: 0 = High Priority (continuing a started sample) + 1 = Low Priority (starting a new sample) + - row_index: Dataset order + - run_index: Run order within sample """ - priority: tuple[int, int] + priority: tuple[int, int, int] # Payload (excluded from comparison) - row: EvaluationRow = field(compare=False) - run_indices: List[int] = field(compare=False) # Which runs to execute in this task - config: RolloutProcessorConfig = field(compare=False) - row_index: int = field(compare=False) # To track which sample this belongs to - - # History for speculation (injected from previous micro-batches) - history: List[str] = field(compare=False, default_factory=list) + sample_state: SampleState = field(compare=False) + run_idx: int = field(compare=False) # Single run index for this task + history_snapshot: List[str] = field(compare=False, default_factory=list) # History at scheduling time class PriorityRolloutScheduler: """ @@ -70,7 +84,7 @@ def __init__( # Priority Queue: Stores RolloutTask self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue() - # Concurrency Control + # Concurrency Control (rollout concurrency is handled by rollout_processor's semaphore) self.eval_sem = asyncio.Semaphore(max_concurrent_evaluations) # Results storage @@ -94,6 +108,9 @@ def __init__( # Track active evaluations self.active_evals: int = 0 self.active_evals_lock = asyncio.Lock() + + # Per-sample state for streaming scheduling + self.sample_states: Dict[int, SampleState] = {} async def schedule_dataset( self, @@ -101,26 +118,38 @@ async def schedule_dataset( base_config: RolloutProcessorConfig, ): """ - Populates the queue with initial tasks (the first micro-batch for each sample). + Populates the queue with initial tasks. + For each sample, schedules up to in_group_minibatch_size concurrent runs. """ for i, row in enumerate(dataset): - # Calculate ranges for the first in-group minibatch - batch_start = 0 - batch_end = min(self.in_group_minibatch_size, self.rollout_n) - run_indices = list(range(batch_start, batch_end)) - - # Initial priority: Low (1), ordered by dataset index - priority = (1, i) - - task = RolloutTask( - priority=priority, + # Create sample state + sample_state = SampleState( row=row, - run_indices=run_indices, - config=base_config, row_index=i, - history=[] # Initial batch has no history + config=base_config, + history=[], + next_run_idx=0, + active_runs=0, + completed_runs=0, + lock=asyncio.Lock(), ) - self.queue.put_nowait(task) + self.sample_states[i] = sample_state + + # Schedule initial runs (up to in_group_minibatch_size) + initial_runs = min(self.in_group_minibatch_size, self.rollout_n) + for run_idx in range(initial_runs): + # Initial priority: Low (1), ordered by dataset index, then run index + priority = (1, i, run_idx) + + task = RolloutTask( + priority=priority, + sample_state=sample_state, + run_idx=run_idx, + history_snapshot=[], # First runs have no history + ) + self.queue.put_nowait(task) + sample_state.next_run_idx = run_idx + 1 + sample_state.active_runs += 1 async def worker(self): """ @@ -133,7 +162,7 @@ async def worker(self): try: await self._process_task(task) except Exception as e: - logging.error(f"Error processing task for row {task.row.input_metadata.row_id}: {e}", exc_info=True) + logging.error(f"Error processing task for row {task.sample_state.row.input_metadata.row_id} run {task.run_idx}: {e}", exc_info=True) finally: self.queue.task_done() @@ -210,112 +239,127 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): if self.eval_pbar: self.eval_pbar.set_postfix_str(f"active={self.active_evals}") - # 1. Prepare Config & Row for this micro-batch - current_batch_rows = [] - for run_idx in task.run_indices: - row_copy = task.row.model_copy(deep=True) - - row_copy.execution_metadata.run_id = generate_id() - row_copy.execution_metadata.rollout_id = generate_id() - if row_copy.execution_metadata.extra is None: - row_copy.execution_metadata.extra = {} - row_copy.execution_metadata.extra["run_index"] = run_idx + sample_state = task.sample_state + run_idx = task.run_idx + row_index = sample_state.row_index + + # Rollout concurrency is handled by rollout_processor's internal semaphore + # 1. Prepare row for this single run + row_copy = sample_state.row.model_copy(deep=True) + row_copy.execution_metadata.run_id = generate_id() + row_copy.execution_metadata.rollout_id = generate_id() + if row_copy.execution_metadata.extra is None: + row_copy.execution_metadata.extra = {} + row_copy.execution_metadata.extra["run_index"] = run_idx + + # Make a copy of config for this specific run (to inject per-run speculation) + run_config = sample_state.config + + # Inject Speculation History into config.completion_params (use snapshot from when task was scheduled) + if ENABLE_SPECULATION and task.history_snapshot: + # Deep copy completion_params to avoid mutating shared config + cp = dict(sample_state.config.completion_params) if sample_state.config.completion_params else {} + max_tokens = cp.get("max_tokens", 2048) + if "extra_body" not in cp: + cp["extra_body"] = {} - # Inject Speculation History - if ENABLE_SPECULATION and task.history: - cp = row_copy.input_metadata.completion_params - max_tokens = cp.get("max_tokens", 2048) - # Ensure safe dict access - if not isinstance(cp, dict): - cp = {} - # Need to check and initialize nested dicts - extra_body = cp.get("extra_body") - if extra_body is None or not isinstance(extra_body, dict): - extra_body = {} - # for speculation, see - # https://docs.fireworks.ai/guides/predicted-outputs - # https://platform.openai.com/docs/guides/predicted-outputs?lang=python - extra_body["prediction"] = {"type": "content", "content": " ".join(task.history)[:max_tokens]} - cp["extra_body"] = extra_body - row_copy.input_metadata.completion_params = cp + cp["extra_body"]["prediction"] = " ".join(task.history_snapshot)[:max_tokens] - current_batch_rows.append((run_idx, row_copy)) - self.active_logger.log(row_copy) + # Create a new config with the modified completion_params (copy all fields) + base_config = sample_state.config + run_config = RolloutProcessorConfig( + completion_params=cp, + mcp_config_path=base_config.mcp_config_path, + semaphore=base_config.semaphore, + server_script_path=base_config.server_script_path, + steps=base_config.steps, + logger=base_config.logger, + kwargs=base_config.kwargs, + exception_handler_config=base_config.exception_handler_config, + post_processor=base_config.post_processor, + ) - - # 2. Execute Rollout - batch_results: List[EvaluationRow] = [] - if current_batch_rows: - for idx, row in current_batch_rows: - # Track this rollout as active - async with self.active_rollouts_lock: - self.active_rollouts[task.row_index].add(idx) - await self._update_rollout_pbar_postfix() - - try: - async for result_row in rollout_processor_with_retry( - self.rollout_processor, [row], task.config, idx, disable_tqdm=True - ): - batch_results.append(result_row) - - # Update rollout progress bar - if self.rollout_pbar: - self.rollout_pbar.update(1) - - # in pointwise, we start evaluation immediately - if self.mode == "pointwise": - t = asyncio.create_task(_run_eval(result_row)) - self.background_tasks.add(t) - t.add_done_callback(self.background_tasks.discard) - finally: - # Remove from active tracking - async with self.active_rollouts_lock: - self.active_rollouts[task.row_index].discard(idx) - if not self.active_rollouts[task.row_index]: - del self.active_rollouts[task.row_index] - await self._update_rollout_pbar_postfix() + self.active_logger.log(row_copy) - # 3. Evaluate and Collect History - current_batch_history_updates = [] - # Extract history from rollout results (assuming eval doesn't change content needed for history) - for res in batch_results: - last_msg = res.last_assistant_message() - if last_msg and last_msg.content: - content = last_msg.content - current_batch_history_updates.append(str(content)) - else: - current_batch_history_updates.append("") - - # in groupwise, we send all rows to evaluator in one go when the whole group is complete - if self.mode == "groupwise": - self.groups_buffer[task.row_index].extend(batch_results) - if len(self.groups_buffer[task.row_index]) >= self.rollout_n: - full_group = self.groups_buffer.pop(task.row_index) - t = asyncio.create_task(_run_eval(full_group)) - self.background_tasks.add(t) - t.add_done_callback(self.background_tasks.discard) - - # 4. Schedule Next Micro-batch (High Priority) - last_run_idx = task.run_indices[-1] if task.run_indices else -1 - next_start = last_run_idx + 1 + # 2. Track this rollout as active + async with self.active_rollouts_lock: + self.active_rollouts[row_index].add(run_idx) + await self._update_rollout_pbar_postfix() + + # 3. Execute the rollout + result_row: Optional[EvaluationRow] = None + start_time = time.perf_counter() + try: + async for result in rollout_processor_with_retry( + self.rollout_processor, [row_copy], run_config, run_idx, disable_tqdm=True + ): + result_row = result + result_row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time + + # Update rollout progress bar + if self.rollout_pbar: + self.rollout_pbar.update(1) + + # In pointwise mode, start evaluation immediately + if self.mode == "pointwise": + t = asyncio.create_task(_run_eval(result_row)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) + finally: + # Remove from active tracking + async with self.active_rollouts_lock: + self.active_rollouts[row_index].discard(run_idx) + if not self.active_rollouts[row_index]: + del self.active_rollouts[row_index] + await self._update_rollout_pbar_postfix() - if next_start < self.rollout_n: - next_end = min(next_start + self.in_group_minibatch_size, self.rollout_n) - next_indices = list(range(next_start, next_end)) - new_history = task.history + current_batch_history_updates + # 4. Update sample state and schedule next run (streaming) + async with sample_state.lock: + sample_state.active_runs -= 1 + sample_state.completed_runs += 1 - # Priority 0 (High) to ensure we finish this sample ASAP - new_priority = (0, task.row_index) + # Extract history from this run's result + if result_row: + last_msg = result_row.last_assistant_message() + if last_msg and last_msg.content: + print(f"******** history: {str(last_msg.content)}") + sample_state.history.append(str(last_msg.content)) + else: + sample_state.history.append("") - new_task = RolloutTask( - priority=new_priority, - row=task.row, - run_indices=next_indices, - config=task.config, - row_index=task.row_index, - history=new_history - ) - self.queue.put_nowait(new_task) + # In groupwise mode, buffer results + if self.mode == "groupwise": + if result_row: + self.groups_buffer[row_index].append(result_row) + # Check if all runs for this sample are complete + if sample_state.completed_runs >= self.rollout_n: + full_group = self.groups_buffer.pop(row_index, []) + if full_group: + t = asyncio.create_task(_run_eval(full_group)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) + + # Schedule next run if: + # 1. There are more runs to do + # 2. We haven't hit in_group_minibatch_size concurrent runs for this sample + if (sample_state.next_run_idx < self.rollout_n and + sample_state.active_runs < self.in_group_minibatch_size): + + next_run_idx = sample_state.next_run_idx + sample_state.next_run_idx += 1 + sample_state.active_runs += 1 + + # High priority (0) to finish this sample ASAP + # Use current accumulated history for speculation + priority = (0, row_index, next_run_idx) + + new_task = RolloutTask( + priority=priority, + sample_state=sample_state, + run_idx=next_run_idx, + history_snapshot=list(sample_state.history), # Snapshot current history + ) + self.queue.put_nowait(new_task) def _format_active_rollouts(self) -> str: """Format active rollouts for display in progress bar.""" @@ -410,7 +454,8 @@ async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: Ro await self.schedule_dataset(dataset, base_config) # 2. Start Workers - # If we have separate limits, we need enough workers to saturate both stages + # With semaphore-based concurrency control, workers can be equal to max_concurrent_rollouts + # The semaphore will limit actual concurrent executions num_workers = self.max_concurrent_rollouts workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] diff --git a/tests/test_priority_scheduler.py b/tests/test_priority_scheduler.py index f5b4fa31..a6e44ff0 100644 --- a/tests/test_priority_scheduler.py +++ b/tests/test_priority_scheduler.py @@ -5,7 +5,7 @@ from typing import List, Union from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, EvaluateResult -from eval_protocol.pytest.priority_scheduler import PriorityRolloutScheduler, execute_priority_rollouts, RolloutTask +from eval_protocol.pytest.priority_scheduler import PriorityRolloutScheduler, execute_priority_rollouts, RolloutTask, SampleState from eval_protocol.pytest.types import RolloutProcessorConfig from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -244,13 +244,21 @@ async def worker(self): async def schedule_dataset(self, *args): # Put enough items to ensure all workers wake up and grab one for i in range(expected_workers): - task = RolloutTask( - priority=(1, i), + sample_state = SampleState( row=dataset[0], - run_indices=[], - config=base_config, row_index=0, - history=[] + config=base_config, + history=[], + next_run_idx=0, + active_runs=0, + completed_runs=0, + lock=asyncio.Lock(), + ) + task = RolloutTask( + priority=(1, i, 0), + sample_state=sample_state, + run_idx=0, + history_snapshot=[], ) await self.queue.put(task) From c9344c3110119da58e92c59c6afe92a391770a10 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Wed, 17 Dec 2025 15:42:39 -0800 Subject: [PATCH 2/6] remove comment --- .../default_single_turn_rollout_process.py | 16 +--------------- eval_protocol/pytest/priority_scheduler.py | 1 - 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 0f770f22..4e6050e5 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -18,15 +18,6 @@ logger = logging.getLogger(__name__) -litellm._turn_on_debug() - -# Configure logger with timestamp format if not already configured -if not logger.handlers: - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter('%(asctime)s.%(msecs)03d %(message)s', datefmt='%H:%M:%S')) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - class SingleTurnRolloutProcessor(RolloutProcessor): """Single turn rollout processor for direct LLM calls.""" @@ -48,7 +39,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row asynchronously.""" start_time = time.perf_counter() - + if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") @@ -106,12 +97,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: chunks.append(chunk) response = litellm.stream_chunk_builder(chunks, messages_payload) else: - logger.warning(f"******** rolling out {row.execution_metadata.run_id} ") - logger.warning(json.dumps(request_params)) - start_time = time.perf_counter() response = await acompletion(**request_params) - rollout_duration = time.perf_counter() - start_time - logger.warning(f"******** rollout duration for {row.execution_metadata.run_id} {rollout_duration} seconds") assert response is not None, "Response is None" assert isinstance(response, ModelResponse), "Response should be ModelResponse" diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 397a783b..166aef85 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -322,7 +322,6 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): if result_row: last_msg = result_row.last_assistant_message() if last_msg and last_msg.content: - print(f"******** history: {str(last_msg.content)}") sample_state.history.append(str(last_msg.content)) else: sample_state.history.append("") From 1dd4169322c0fce18ed8510e9957d7aaf7e26a4a Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 18 Dec 2025 12:13:56 -0800 Subject: [PATCH 3/6] add --- eval_protocol/pytest/priority_scheduler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 166aef85..18712df6 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -70,7 +70,7 @@ def __init__( output_buffer: Optional[MicroBatchDataBuffer] = None, rollout_n: int = 0, mode: str = "pointwise", - in_group_minibatch_size: int = 0, # for one sample, how many runs to execute at the same time + in_group_minibatch_size: Optional[int] = None, # for one sample, how many runs to execute at the same time evaluation_test_kwargs: Dict[str, Any] = {}, ): self.rollout_processor = rollout_processor @@ -94,6 +94,11 @@ def __init__( self.background_tasks = set() # run evaluations in the background asynchronously self.rollout_n = rollout_n + if in_group_minibatch_size is None: + if ENABLE_SPECULATION: + in_group_minibatch_size = rollout_n // 2 + else: + in_group_minibatch_size = rollout_n self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n self.evaluation_test_kwargs = evaluation_test_kwargs @@ -108,8 +113,7 @@ def __init__( # Track active evaluations self.active_evals: int = 0 self.active_evals_lock = asyncio.Lock() - - # Per-sample state for streaming scheduling + self.sample_states: Dict[int, SampleState] = {} async def schedule_dataset( @@ -504,7 +508,6 @@ async def execute_priority_rollouts( max_concurrent_evaluations=max_concurrent_evaluations, rollout_n=num_runs, mode=mode, - in_group_minibatch_size=(num_runs // 2), evaluation_test_kwargs=evaluation_test_kwargs, ) return await scheduler.run(dataset, num_runs, config) From 692f3ad6a8eb2f507537cdb6c51d680fa602a0b1 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 18 Dec 2025 13:35:57 -0800 Subject: [PATCH 4/6] resolve comments --- eval_protocol/pytest/priority_scheduler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 18712df6..0545c9dd 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -1,4 +1,5 @@ import asyncio +import copy import logging import os import time @@ -261,13 +262,13 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): # Inject Speculation History into config.completion_params (use snapshot from when task was scheduled) if ENABLE_SPECULATION and task.history_snapshot: - # Deep copy completion_params to avoid mutating shared config - cp = dict(sample_state.config.completion_params) if sample_state.config.completion_params else {} + # Deep copy to avoid concurrent mutation of shared nested dicts + cp = copy.deepcopy(sample_state.config.completion_params) if sample_state.config.completion_params else {} max_tokens = cp.get("max_tokens", 2048) if "extra_body" not in cp: cp["extra_body"] = {} - cp["extra_body"]["prediction"] = " ".join(task.history_snapshot)[:max_tokens] + cp["extra_body"]["prediction"] = {"type": "content", "content": " ".join(task.history_snapshot)[:max_tokens]} # Create a new config with the modified completion_params (copy all fields) base_config = sample_state.config From c726a578295821115f69634fc9d8f6e260d2eebd Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 18 Dec 2025 18:18:29 -0800 Subject: [PATCH 5/6] fix --- .../default_single_turn_rollout_process.py | 3 + eval_protocol/pytest/priority_scheduler.py | 87 ++++++++++--------- 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 4e6050e5..75306a9c 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -97,7 +97,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: chunks.append(chunk) response = litellm.stream_chunk_builder(chunks, messages_payload) else: + tc = time.perf_counter() + # print(f"run_id {row.execution_metadata.run_id} request_params: {json.dumps(request_params)}") response = await acompletion(**request_params) + print(f"run_id {row.execution_metadata.run_id} time taken: {time.perf_counter() - tc} speculation_enabled: {request_params.get('extra_body', {}).get('prediction', None) is not None}") assert response is not None, "Response is None" assert isinstance(response, ModelResponse), "Response should be ModelResponse" diff --git a/eval_protocol/pytest/priority_scheduler.py b/eval_protocol/pytest/priority_scheduler.py index 0545c9dd..605b973e 100644 --- a/eval_protocol/pytest/priority_scheduler.py +++ b/eval_protocol/pytest/priority_scheduler.py @@ -318,52 +318,53 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): del self.active_rollouts[row_index] await self._update_rollout_pbar_postfix() - # 4. Update sample state and schedule next run (streaming) - async with sample_state.lock: - sample_state.active_runs -= 1 - sample_state.completed_runs += 1 - - # Extract history from this run's result - if result_row: - last_msg = result_row.last_assistant_message() - if last_msg and last_msg.content: - sample_state.history.append(str(last_msg.content)) - else: - sample_state.history.append("") - - # In groupwise mode, buffer results - if self.mode == "groupwise": - if result_row: - self.groups_buffer[row_index].append(result_row) - # Check if all runs for this sample are complete - if sample_state.completed_runs >= self.rollout_n: - full_group = self.groups_buffer.pop(row_index, []) - if full_group: - t = asyncio.create_task(_run_eval(full_group)) - self.background_tasks.add(t) - t.add_done_callback(self.background_tasks.discard) - - # Schedule next run if: - # 1. There are more runs to do - # 2. We haven't hit in_group_minibatch_size concurrent runs for this sample - if (sample_state.next_run_idx < self.rollout_n and - sample_state.active_runs < self.in_group_minibatch_size): + # 4. Update sample state and schedule next run (streaming) + # Must be in finally to ensure state is updated even on exception + async with sample_state.lock: + sample_state.active_runs -= 1 + sample_state.completed_runs += 1 - next_run_idx = sample_state.next_run_idx - sample_state.next_run_idx += 1 - sample_state.active_runs += 1 + # Extract history from this run's result + if result_row: + last_msg = result_row.last_assistant_message() + if last_msg and last_msg.content: + sample_state.history.append(str(last_msg.content)) + else: + sample_state.history.append("") - # High priority (0) to finish this sample ASAP - # Use current accumulated history for speculation - priority = (0, row_index, next_run_idx) + # In groupwise mode, buffer results + if self.mode == "groupwise": + if result_row: + self.groups_buffer[row_index].append(result_row) + # Check if all runs for this sample are complete + if sample_state.completed_runs >= self.rollout_n: + full_group = self.groups_buffer.pop(row_index, []) + if full_group: + t = asyncio.create_task(_run_eval(full_group)) + self.background_tasks.add(t) + t.add_done_callback(self.background_tasks.discard) - new_task = RolloutTask( - priority=priority, - sample_state=sample_state, - run_idx=next_run_idx, - history_snapshot=list(sample_state.history), # Snapshot current history - ) - self.queue.put_nowait(new_task) + # Schedule next run if: + # 1. There are more runs to do + # 2. We haven't hit in_group_minibatch_size concurrent runs for this sample + if (sample_state.next_run_idx < self.rollout_n and + sample_state.active_runs < self.in_group_minibatch_size): + + next_run_idx = sample_state.next_run_idx + sample_state.next_run_idx += 1 + sample_state.active_runs += 1 + + # High priority (0) to finish this sample ASAP + # Use current accumulated history for speculation + priority = (0, row_index, next_run_idx) + + new_task = RolloutTask( + priority=priority, + sample_state=sample_state, + run_idx=next_run_idx, + history_snapshot=list(sample_state.history), # Snapshot current history + ) + self.queue.put_nowait(new_task) def _format_active_rollouts(self) -> str: """Format active rollouts for display in progress bar.""" From fb8debeba67fc6cb163be0b399fa53b8670dcfd3 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 18 Dec 2025 20:24:57 -0800 Subject: [PATCH 6/6] remove print --- eval_protocol/pytest/default_single_turn_rollout_process.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 75306a9c..4e6050e5 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -97,10 +97,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: chunks.append(chunk) response = litellm.stream_chunk_builder(chunks, messages_payload) else: - tc = time.perf_counter() - # print(f"run_id {row.execution_metadata.run_id} request_params: {json.dumps(request_params)}") response = await acompletion(**request_params) - print(f"run_id {row.execution_metadata.run_id} time taken: {time.perf_counter() - tc} speculation_enabled: {request_params.get('extra_body', {}).get('prediction', None) is not None}") assert response is not None, "Response is None" assert isinstance(response, ModelResponse), "Response should be ModelResponse"