diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 2256b35f..e70f8794 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -307,6 +307,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo try: # Handle dataset loading data: list[EvaluationRow] = [] + data_loader_rows_preprocessed = False # Track all rows processed in the current run for error logging processed_rows_in_run: list[EvaluationRow] = [] if "data_loaders" in kwargs and kwargs["data_loaders"] is not None: @@ -318,6 +319,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo results = data_loader.load() for result in results: data.extend(result.rows) + data_loader_rows_preprocessed = data_loader_rows_preprocessed or result.preprocessed # Apply max_dataset_rows limit to data from data loaders if max_dataset_rows is not None: data = data[:max_dataset_rows] @@ -345,18 +347,11 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo if filtered_row_ids is not None: data = [row for row in data if row.input_metadata.row_id in filtered_row_ids] - """ - data_loaders handles preprocess_fn internally so we want - to specially handle data_loaders here so we don't double - apply preprocess_fn. - """ if preprocess_fn: - if not data_loaders: + # If data loaders already applied preprocessing, skip the decorator-level + # preprocess_fn to avoid running the same transform twice. + if not data_loader_rows_preprocessed: data = preprocess_fn(data) - else: - raise ValueError( - "preprocess_fn should not be used with data_loaders. Pass preprocess_fn to data_loaders instead." - ) for row in data: # generate a stable row_id for each row diff --git a/tests/pytest/test_preprocess_fn_data_loaders.py b/tests/pytest/test_preprocess_fn_data_loaders.py new file mode 100644 index 00000000..60b91ebc --- /dev/null +++ b/tests/pytest/test_preprocess_fn_data_loaders.py @@ -0,0 +1,92 @@ +from eval_protocol.data_loader import DynamicDataLoader +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger +from eval_protocol.models import EvaluateResult, EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor + + +class InMemoryLogger(DatasetLogger): + def log(self, row: EvaluationRow) -> None: + return None + + def read(self) -> list[EvaluationRow]: + return [] + + +class StopAfterPreprocess(Exception): + pass + + +class StopAfterPreprocessRolloutProcessor(NoOpRolloutProcessor): + def setup(self) -> None: + raise StopAfterPreprocess("Stop after preprocessing for focused test assertions") + + +def _build_rows() -> list[EvaluationRow]: + return [ + EvaluationRow( + messages=[ + Message(role="user", content="question"), + Message(role="assistant", content="answer"), + ] + ) + ] + + +async def test_preprocess_fn_runs_with_data_loader_without_loader_preprocess(): + call_count = {"decorator_preprocess": 0} + + def decorator_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]: + call_count["decorator_preprocess"] += 1 + return rows + + data_loader = DynamicDataLoader(generators=[_build_rows]) + + @evaluation_test( + data_loaders=data_loader, + preprocess_fn=decorator_preprocess, + rollout_processor=StopAfterPreprocessRolloutProcessor(), + logger=InMemoryLogger(), + ) + def eval_fn(row: EvaluationRow) -> EvaluationRow: + row.evaluation_result = EvaluateResult(score=1.0, reason="ok") + return row + + try: + await eval_fn(data_loaders=data_loader) + except StopAfterPreprocess: + pass + + assert call_count["decorator_preprocess"] == 1 + + +async def test_preprocess_fn_not_double_applied_when_data_loader_preprocess_exists(): + call_count = {"loader_preprocess": 0, "decorator_preprocess": 0} + + def loader_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]: + call_count["loader_preprocess"] += 1 + return rows + + def decorator_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]: + call_count["decorator_preprocess"] += 1 + return rows + + data_loader = DynamicDataLoader(generators=[_build_rows], preprocess_fn=loader_preprocess) + + @evaluation_test( + data_loaders=data_loader, + preprocess_fn=decorator_preprocess, + rollout_processor=StopAfterPreprocessRolloutProcessor(), + logger=InMemoryLogger(), + ) + def eval_fn(row: EvaluationRow) -> EvaluationRow: + row.evaluation_result = EvaluateResult(score=1.0, reason="ok") + return row + + try: + await eval_fn(data_loaders=data_loader) + except StopAfterPreprocess: + pass + + assert call_count["loader_preprocess"] == 1 + assert call_count["decorator_preprocess"] == 0