Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def evaluation_test(
filtered_row_ids: Sequence[str] | None = None,
max_dataset_rows: int | None = None,
mcp_config_path: str | None = None,
max_concurrent_rollouts: int = 8,
max_concurrent_evaluations: int = 64,
max_concurrent_rollouts: int = 96,
max_concurrent_evaluations: int = 96,
server_script_path: str | None = None,
steps: int = 30,
mode: EvaluationTestMode = "pointwise",
Expand Down Expand Up @@ -409,21 +409,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo

rollout_processor.setup()

use_priority_scheduler = (
(
os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1"
and not isinstance(rollout_processor, MCPGymRolloutProcessor)
)
)
use_priority_scheduler = os.environ.get(
"EP_USE_PRIORITY_SCHEDULER", "0"
) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)

if use_priority_scheduler:
microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None)
output_dir = os.environ.get("EP_OUTPUT_DIR", None)
if microbatch_output_size and output_dir:
output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
output_buffer = MicroBatchDataBuffer(
num_runs=num_runs,
batch_size=int(microbatch_output_size),
output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"),
)
else:
output_buffer = None

try:
priority_results = await execute_priority_rollouts(
dataset=data,
Expand All @@ -441,12 +442,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
finally:
if output_buffer:
await output_buffer.close()

for res in priority_results:
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
if run_idx < len(all_results):
all_results[run_idx].append(res)

processed_rows_in_run.append(res)

postprocess(
Expand All @@ -462,6 +463,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
)

else:

async def execute_run(run_idx: int, config: RolloutProcessorConfig):
nonlocal all_results

Expand Down Expand Up @@ -506,9 +508,7 @@ async def _execute_pointwise_eval_with_semaphore(
raise ValueError(
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
)
result.execution_metadata.eval_duration_seconds = (
time.perf_counter() - start_time
)
result.execution_metadata.eval_duration_seconds = time.perf_counter() - start_time
return result

async def _execute_groupwise_eval_with_semaphore(
Expand All @@ -519,7 +519,9 @@ async def _execute_groupwise_eval_with_semaphore(
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
group_rollout_ids = [
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
r.execution_metadata.rollout_id
for r in rows
if r.execution_metadata.rollout_id
]
async with rollout_logging_context(
primary_rollout_id or "",
Expand Down Expand Up @@ -596,7 +598,9 @@ async def _collect_result(config, lst):
row_groups[row.input_metadata.row_id].append(row)
tasks = []
for _, rows in row_groups.items():
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
tasks.append(
asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))
)
results = []
for task in tasks:
res = await task
Expand Down Expand Up @@ -692,9 +696,9 @@ async def _collect_result(config, lst):
# For other processors, create all tasks at once and run in parallel
# Concurrency is now controlled by the shared semaphore in each rollout processor
await run_tasks_with_run_progress(execute_run, num_runs, config)

experiment_duration_seconds = time.perf_counter() - experiment_start_time

# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
# rollout_id is used to differentiate the result from different completion_params
if mode == "groupwise":
Expand Down Expand Up @@ -730,15 +734,12 @@ async def _collect_result(config, lst):
experiment_duration_seconds,
)



if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
raise AssertionError(
"Some EvaluationRow instances are missing evaluation_result. "
"Your @evaluation_test function must set `row.evaluation_result`"
)


except AssertionError:
_log_eval_error(
Status.eval_finished(),
Expand Down
Loading