From 904680467d1617feb72578460dc7ec6bdd1db62d Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 8 Jan 2026 11:38:40 -0800 Subject: [PATCH 1/3] auto no prefix needed --- eval_protocol/pytest/evaluation_test.py | 46 ++++++++++--------- eval_protocol/pytest/evaluation_test_utils.py | 19 ++++++++ 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index d9450a48..5ec8a7e0 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -20,13 +20,12 @@ EvaluationRow, EvaluationThreshold, EvaluationThresholdDict, - EvaluateResult, Status, EPParameters, ) from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper from eval_protocol.pytest.evaluation_test_postprocess import postprocess -from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling +from eval_protocol.pytest.execution import execute_pytest_with_exception_handling from eval_protocol.pytest.priority_scheduler import execute_priority_rollouts from eval_protocol.pytest.generate_parameter_combinations import ( ParameterizedTestKwargs, @@ -56,6 +55,7 @@ AggregationMethod, add_cost_metrics, log_eval_status_and_rows, + normalize_fireworks_model, parse_ep_completion_params, parse_ep_completion_params_overwrite, parse_ep_max_concurrent_rollouts, @@ -205,6 +205,7 @@ def evaluation_test( max_dataset_rows = parse_ep_max_rows(max_dataset_rows) completion_params = parse_ep_completion_params(completion_params) completion_params = parse_ep_completion_params_overwrite(completion_params) + completion_params = [normalize_fireworks_model(cp) for cp in completion_params] original_completion_params = completion_params passed_threshold = parse_ep_passed_threshold(passed_threshold) data_loaders = parse_ep_dataloaders(data_loaders) @@ -409,21 +410,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo 
rollout_processor.setup() - use_priority_scheduler = ( - ( - os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1" - and not isinstance(rollout_processor, MCPGymRolloutProcessor) - ) - ) + use_priority_scheduler = os.environ.get( + "EP_USE_PRIORITY_SCHEDULER", "0" + ) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor) if use_priority_scheduler: microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None) output_dir = os.environ.get("EP_OUTPUT_DIR", None) if microbatch_output_size and output_dir: - output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl")) + output_buffer = MicroBatchDataBuffer( + num_runs=num_runs, + batch_size=int(microbatch_output_size), + output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"), + ) else: output_buffer = None - + try: priority_results = await execute_priority_rollouts( dataset=data, @@ -441,12 +443,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo finally: if output_buffer: await output_buffer.close() - + for res in priority_results: run_idx = (res.execution_metadata.extra or {}).get("run_index", 0) if run_idx < len(all_results): all_results[run_idx].append(res) - + processed_rows_in_run.append(res) postprocess( @@ -462,6 +464,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo ) else: + async def execute_run(run_idx: int, config: RolloutProcessorConfig): nonlocal all_results @@ -506,9 +509,7 @@ async def _execute_pointwise_eval_with_semaphore( raise ValueError( f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." 
) - result.execution_metadata.eval_duration_seconds = ( - time.perf_counter() - start_time - ) + result.execution_metadata.eval_duration_seconds = time.perf_counter() - start_time return result async def _execute_groupwise_eval_with_semaphore( @@ -519,7 +520,9 @@ async def _execute_groupwise_eval_with_semaphore( evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {} primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None group_rollout_ids = [ - r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id + r.execution_metadata.rollout_id + for r in rows + if r.execution_metadata.rollout_id ] async with rollout_logging_context( primary_rollout_id or "", @@ -596,7 +599,9 @@ async def _collect_result(config, lst): row_groups[row.input_metadata.row_id].append(row) tasks = [] for _, rows in row_groups.items(): - tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) + tasks.append( + asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)) + ) results = [] for task in tasks: res = await task @@ -692,9 +697,9 @@ async def _collect_result(config, lst): # For other processors, create all tasks at once and run in parallel # Concurrency is now controlled by the shared semaphore in each rollout processor await run_tasks_with_run_progress(execute_run, num_runs, config) - + experiment_duration_seconds = time.perf_counter() - experiment_start_time - + # for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them # rollout_id is used to differentiate the result from different completion_params if mode == "groupwise": @@ -730,15 +735,12 @@ async def _collect_result(config, lst): experiment_duration_seconds, ) - - if not all(r.evaluation_result is not None for run_results in all_results for r in run_results): raise AssertionError( "Some EvaluationRow instances are missing evaluation_result. 
" "Your @evaluation_test function must set `row.evaluation_result`" ) - except AssertionError: _log_eval_error( Status.eval_finished(), diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index 64f0c8b3..e953617f 100644 --- a/eval_protocol/pytest/evaluation_test_utils.py +++ b/eval_protocol/pytest/evaluation_test_utils.py @@ -619,3 +619,22 @@ def build_rollout_processor_config( server_script_path=None, kwargs=rollout_processor_kwargs, ) + + +def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None: + """Fireworks model names like 'accounts/<account>/models/<model>' need the fireworks_ai/ + prefix when routing through LiteLLM. This function adds the prefix if missing. + """ + if completion_params is None: + return None + + model = completion_params.get("model") + if ( + model + and isinstance(model, str) + and not model.startswith("fireworks_ai/") + and re.match(r"^accounts/[^/]+/models/.+", model) + ): + completion_params = completion_params.copy() + completion_params["model"] = f"fireworks_ai/{model}" + return completion_params From 609048829b93fb8fe83773ed377e6d485efdfaee Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 8 Jan 2026 11:44:03 -0800 Subject: [PATCH 2/3] update --- eval_protocol/pytest/evaluation_test.py | 1 + tests/remote_server/test_remote_fireworks.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 5ec8a7e0..6c8a647a 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -366,6 +366,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo row.input_metadata.row_id = generate_id(seed=0, index=index) completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None + completion_params = normalize_fireworks_model(completion_params) # Create eval metadata with 
test function info and current commit hash eval_metadata = EvalMetadata( name=test_func.__name__, diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index e172e309..43da29ed 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -105,7 +105,7 @@ def rows() -> List[EvaluationRow]: @pytest.mark.parametrize( "completion_params", - [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}], + [{"model": "accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}], ) @evaluation_test( data_loaders=DynamicDataLoader( From 764ac4f132c35fe01c354b4150cbc19c7eedea12 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Thu, 8 Jan 2026 13:29:17 -0800 Subject: [PATCH 3/3] update test --- tests/pytest/test_pydantic_agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/pytest/test_pydantic_agent.py b/tests/pytest/test_pydantic_agent.py index c22faf64..52fc3826 100644 --- a/tests/pytest/test_pydantic_agent.py +++ b/tests/pytest/test_pydantic_agent.py @@ -10,7 +10,10 @@ def agent_factory(config: RolloutProcessorConfig) -> Agent: - model = OpenAIChatModel(config.completion_params["model"], provider="fireworks") + model_name = config.completion_params["model"] + if model_name.startswith("fireworks_ai/"): + model_name = model_name[len("fireworks_ai/") :] + model = OpenAIChatModel(model_name, provider="fireworks") return Agent(model=model)