Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
EvaluationRow,
EvaluationThreshold,
EvaluationThresholdDict,
EvaluateResult,
Status,
EPParameters,
)
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling
from eval_protocol.pytest.execution import execute_pytest_with_exception_handling
from eval_protocol.pytest.priority_scheduler import execute_priority_rollouts
from eval_protocol.pytest.generate_parameter_combinations import (
ParameterizedTestKwargs,
Expand Down Expand Up @@ -56,6 +55,7 @@
AggregationMethod,
add_cost_metrics,
log_eval_status_and_rows,
normalize_fireworks_model,
parse_ep_completion_params,
parse_ep_completion_params_overwrite,
parse_ep_max_concurrent_rollouts,
Expand Down Expand Up @@ -205,6 +205,7 @@ def evaluation_test(
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
completion_params = parse_ep_completion_params(completion_params)
completion_params = parse_ep_completion_params_overwrite(completion_params)
completion_params = [normalize_fireworks_model(cp) for cp in completion_params]
original_completion_params = completion_params
passed_threshold = parse_ep_passed_threshold(passed_threshold)
data_loaders = parse_ep_dataloaders(data_loaders)
Expand Down Expand Up @@ -365,6 +366,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
row.input_metadata.row_id = generate_id(seed=0, index=index)

completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
completion_params = normalize_fireworks_model(completion_params)
# Create eval metadata with test function info and current commit hash
eval_metadata = EvalMetadata(
name=test_func.__name__,
Expand Down Expand Up @@ -409,21 +411,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo

rollout_processor.setup()

use_priority_scheduler = (
(
os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1"
and not isinstance(rollout_processor, MCPGymRolloutProcessor)
)
)
use_priority_scheduler = os.environ.get(
"EP_USE_PRIORITY_SCHEDULER", "0"
) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)

if use_priority_scheduler:
microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None)
output_dir = os.environ.get("EP_OUTPUT_DIR", None)
if microbatch_output_size and output_dir:
output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
output_buffer = MicroBatchDataBuffer(
num_runs=num_runs,
batch_size=int(microbatch_output_size),
output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"),
)
else:
output_buffer = None

try:
priority_results = await execute_priority_rollouts(
dataset=data,
Expand All @@ -441,12 +444,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
finally:
if output_buffer:
await output_buffer.close()

for res in priority_results:
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
if run_idx < len(all_results):
all_results[run_idx].append(res)

processed_rows_in_run.append(res)

postprocess(
Expand All @@ -462,6 +465,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
)

else:

async def execute_run(run_idx: int, config: RolloutProcessorConfig):
nonlocal all_results

Expand Down Expand Up @@ -506,9 +510,7 @@ async def _execute_pointwise_eval_with_semaphore(
raise ValueError(
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
)
result.execution_metadata.eval_duration_seconds = (
time.perf_counter() - start_time
)
result.execution_metadata.eval_duration_seconds = time.perf_counter() - start_time
return result

async def _execute_groupwise_eval_with_semaphore(
Expand All @@ -519,7 +521,9 @@ async def _execute_groupwise_eval_with_semaphore(
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
group_rollout_ids = [
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
r.execution_metadata.rollout_id
for r in rows
if r.execution_metadata.rollout_id
]
async with rollout_logging_context(
primary_rollout_id or "",
Expand Down Expand Up @@ -596,7 +600,9 @@ async def _collect_result(config, lst):
row_groups[row.input_metadata.row_id].append(row)
tasks = []
for _, rows in row_groups.items():
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
tasks.append(
asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))
)
results = []
for task in tasks:
res = await task
Expand Down Expand Up @@ -692,9 +698,9 @@ async def _collect_result(config, lst):
# For other processors, create all tasks at once and run in parallel
# Concurrency is now controlled by the shared semaphore in each rollout processor
await run_tasks_with_run_progress(execute_run, num_runs, config)

experiment_duration_seconds = time.perf_counter() - experiment_start_time

# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
# rollout_id is used to differentiate the result from different completion_params
if mode == "groupwise":
Expand Down Expand Up @@ -730,15 +736,12 @@ async def _collect_result(config, lst):
experiment_duration_seconds,
)



if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
raise AssertionError(
"Some EvaluationRow instances are missing evaluation_result. "
"Your @evaluation_test function must set `row.evaluation_result`"
)


except AssertionError:
_log_eval_error(
Status.eval_finished(),
Expand Down
19 changes: 19 additions & 0 deletions eval_protocol/pytest/evaluation_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,22 @@ def build_rollout_processor_config(
server_script_path=None,
kwargs=rollout_processor_kwargs,
)


def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
    """Ensure Fireworks model names carry the LiteLLM routing prefix.

    Bare Fireworks names of the form 'accounts/<org>/models/<model>' are
    rewritten to 'fireworks_ai/accounts/<org>/models/<model>'; anything else
    (None, non-string models, already-prefixed names, non-Fireworks names)
    is returned unchanged. The input mapping is never mutated — a shallow
    copy is returned when a rewrite is needed.
    """
    if completion_params is None:
        return None

    model = completion_params.get("model")
    # Pass through everything that is not a bare Fireworks-style model name.
    if not isinstance(model, str) or not model:
        return completion_params
    if model.startswith("fireworks_ai/"):
        return completion_params
    if re.match(r"^accounts/[^/]+/models/.+", model) is None:
        return completion_params

    normalized = completion_params.copy()
    normalized["model"] = f"fireworks_ai/{model}"
    return normalized
5 changes: 4 additions & 1 deletion tests/pytest/test_pydantic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@


def agent_factory(config: RolloutProcessorConfig) -> Agent:
    """Build a pydantic-ai Agent backed by a Fireworks OpenAI-compatible model.

    The completion params may carry a LiteLLM-routed model name prefixed with
    "fireworks_ai/", but the Fireworks provider expects the bare
    "accounts/<org>/models/<model>" name, so any such prefix is stripped first.
    """
    # str.removeprefix is a no-op when the prefix is absent, replacing the
    # manual startswith/slice dance.
    model_name = config.completion_params["model"].removeprefix("fireworks_ai/")
    model = OpenAIChatModel(model_name, provider="fireworks")
    return Agent(model=model)


Expand Down
2 changes: 1 addition & 1 deletion tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def rows() -> List[EvaluationRow]:

@pytest.mark.parametrize(
"completion_params",
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
[{"model": "accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
)
@evaluation_test(
data_loaders=DynamicDataLoader(
Expand Down
Loading