diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index b62e68dc..d9450a48 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -489,6 +489,7 @@ async def execute_run(run_idx: int, config: RolloutProcessorConfig):
     async def _execute_pointwise_eval_with_semaphore(
         row: EvaluationRow,
     ) -> EvaluationRow:
+        start_time = time.perf_counter()
         async with semaphore:
             evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
             async with rollout_logging_context(
@@ -505,11 +506,15 @@ async def _execute_pointwise_eval_with_semaphore(
                     raise ValueError(
                         f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                     )
+                result.execution_metadata.eval_duration_seconds = (
+                    time.perf_counter() - start_time
+                )
                 return result
 
     async def _execute_groupwise_eval_with_semaphore(
         rows: list[EvaluationRow],
     ) -> list[EvaluationRow]:
+        start_time = time.perf_counter()
         async with semaphore:
             evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
             primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
@@ -531,6 +536,9 @@ async def _execute_groupwise_eval_with_semaphore(
                     raise ValueError(
                         f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                     )
+            eval_duration = time.perf_counter() - start_time
+            for r in results:
+                r.execution_metadata.eval_duration_seconds = eval_duration
             return results
 
     if mode == "pointwise":
@@ -617,11 +625,16 @@ async def _collect_result(config, lst):
                 run_id=run_id,
                 rollout_ids=group_rollout_ids or None,
             ):
+                start_time = time.perf_counter()
                 results = await execute_pytest_with_exception_handling(
                     test_func=test_func,
                     evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                     processed_dataset=input_dataset,
                 )
+                if isinstance(results, list):
+                    eval_duration = time.perf_counter() - start_time
+                    for r in results:
+                        r.execution_metadata.eval_duration_seconds = eval_duration
                 if (
                     results is None
                     or not isinstance(results, list)