Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
try:
# Handle dataset loading
data: list[EvaluationRow] = []
data_loader_rows_preprocessed = False
# Track all rows processed in the current run for error logging
processed_rows_in_run: list[EvaluationRow] = []
if "data_loaders" in kwargs and kwargs["data_loaders"] is not None:
Expand All @@ -318,6 +319,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
results = data_loader.load()
for result in results:
data.extend(result.rows)
data_loader_rows_preprocessed = data_loader_rows_preprocessed or result.preprocessed
# Apply max_dataset_rows limit to data from data loaders
if max_dataset_rows is not None:
data = data[:max_dataset_rows]
Expand Down Expand Up @@ -345,18 +347,11 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
if filtered_row_ids is not None:
data = [row for row in data if row.input_metadata.row_id in filtered_row_ids]

"""
data_loaders handles preprocess_fn internally so we want
to specially handle data_loaders here so we don't double
apply preprocess_fn.
"""
if preprocess_fn:
if not data_loaders:
# If data loaders already applied preprocessing, skip the decorator-level
# preprocess_fn to avoid running the same transform twice.
if not data_loader_rows_preprocessed:
data = preprocess_fn(data)
else:
raise ValueError(
"preprocess_fn should not be used with data_loaders. Pass preprocess_fn to data_loaders instead."
)

for row in data:
# generate a stable row_id for each row
Expand Down
92 changes: 92 additions & 0 deletions tests/pytest/test_preprocess_fn_data_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from eval_protocol.data_loader import DynamicDataLoader
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor


class InMemoryLogger(DatasetLogger):
def log(self, row: EvaluationRow) -> None:
return None

def read(self) -> list[EvaluationRow]:
return []


class StopAfterPreprocess(Exception):
pass


class StopAfterPreprocessRolloutProcessor(NoOpRolloutProcessor):
def setup(self) -> None:
raise StopAfterPreprocess("Stop after preprocessing for focused test assertions")


def _build_rows() -> list[EvaluationRow]:
return [
EvaluationRow(
messages=[
Message(role="user", content="question"),
Message(role="assistant", content="answer"),
]
)
]


async def test_preprocess_fn_runs_with_data_loader_without_loader_preprocess():
call_count = {"decorator_preprocess": 0}

def decorator_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]:
call_count["decorator_preprocess"] += 1
return rows

data_loader = DynamicDataLoader(generators=[_build_rows])

@evaluation_test(
data_loaders=data_loader,
preprocess_fn=decorator_preprocess,
rollout_processor=StopAfterPreprocessRolloutProcessor(),
logger=InMemoryLogger(),
)
def eval_fn(row: EvaluationRow) -> EvaluationRow:
row.evaluation_result = EvaluateResult(score=1.0, reason="ok")
return row

try:
await eval_fn(data_loaders=data_loader)
except StopAfterPreprocess:
pass

assert call_count["decorator_preprocess"] == 1


async def test_preprocess_fn_not_double_applied_when_data_loader_preprocess_exists():
call_count = {"loader_preprocess": 0, "decorator_preprocess": 0}

def loader_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]:
call_count["loader_preprocess"] += 1
return rows

def decorator_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]:
call_count["decorator_preprocess"] += 1
return rows

data_loader = DynamicDataLoader(generators=[_build_rows], preprocess_fn=loader_preprocess)

@evaluation_test(
data_loaders=data_loader,
preprocess_fn=decorator_preprocess,
rollout_processor=StopAfterPreprocessRolloutProcessor(),
logger=InMemoryLogger(),
)
def eval_fn(row: EvaluationRow) -> EvaluationRow:
row.evaluation_result = EvaluateResult(score=1.0, reason="ok")
return row

try:
await eval_fn(data_loaders=data_loader)
except StopAfterPreprocess:
pass

assert call_count["loader_preprocess"] == 1
assert call_count["decorator_preprocess"] == 0