diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py
index ec6f983b..b3c7c334 100644
--- a/eval_protocol/pytest/default_agent_rollout_processor.py
+++ b/eval_protocol/pytest/default_agent_rollout_processor.py
@@ -22,6 +22,7 @@
 from openai.types import CompletionUsage
 
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 from pydantic import BaseModel
 from typing import Optional
@@ -251,8 +252,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             """Process a single row with agent rollout."""
             start_time = time.perf_counter()
 
+            # Normalize Fireworks model names for LiteLLM routing
+            completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
+            row.input_metadata.completion_params = completion_params
             agent = Agent(
-                model=row.input_metadata.completion_params["model"],
+                model=completion_params["model"],
                 row=row,
                 config_path=config.mcp_config_path,
                 logger=config.logger,
diff --git a/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py b/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
index 27d44b80..23d727cc 100644
--- a/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
+++ b/eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
@@ -11,6 +11,7 @@
 from eval_protocol.models import EvaluationRow
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 from eval_protocol.pytest.default_agent_rollout_processor import Agent
 
 from klavis import Klavis
@@ -30,7 +31,7 @@ def __init__(
         self.server_name = server_name
         self.initialize_data_factory = initialize_data_factory
         self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))
-    
+
     def _init_sandbox(self) -> CreateSandboxResponse:
         try:
             server_name_enum = SandboxMcpServer(self.server_name)
@@ -38,7 +39,7 @@ def _init_sandbox(self) -> CreateSandboxResponse:
         except Exception as e:
             logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
             raise
-    
+
     @staticmethod
     def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
         """Create a temporary MCP config file and return its path."""
@@ -47,26 +48,24 @@ def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str
                 server_key: {
                     "url": server_url,
                     "transport": "streamable_http",
-                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {})
+                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
                 }
             }
         }
-        
+
         # Create a temp file that persists for the session
        fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
-        with os.fdopen(fd, 'w') as f:
+        with os.fdopen(fd, "w") as f:
             json.dump(config, f)
         return path
 
-    def __call__(
-        self, rows: List[EvaluationRow], config: RolloutProcessorConfig
-    ) -> List[asyncio.Task[EvaluationRow]]:
+    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
         """Process evaluation rows with Klavis sandbox lifecycle management"""
         semaphore = config.semaphore
 
         async def process_row(row: EvaluationRow) -> EvaluationRow:
             """Process a single row with complete sandbox lifecycle"""
-            
             start_time = time.perf_counter()
             agent: Agent | None = None
             temp_config_path: str | None = None
@@ -88,25 +87,32 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                     if row.input_metadata is not None
                     else None
                 )
-                
+
                 if init_data:
-                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
+                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")  # pyright: ignore[reportOptionalMemberAccess]
                     initialize_method = getattr(
-                        self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
+                        self.klavis_client.sandbox,
+                        f"initialize_{sandbox.server_name.value}_sandbox",  # pyright: ignore[reportOptionalMemberAccess]
                     )
-                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
+                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)  # pyright: ignore[reportOptionalMemberAccess]
                     logger.info(f"Initialization response: {init_response}")
-                
+
                 # Step 2: Create temporary MCP config with sandbox URL
                 temp_config_path = self.create_mcp_config(
-                    server_url=sandbox.server_url, server_key=sandbox.server_name.value
+                    server_url=sandbox.server_url,  # pyright: ignore[reportOptionalMemberAccess]
+                    server_key=sandbox.server_name.value,  # pyright: ignore[reportOptionalMemberAccess]
                 )
                 logger.info(f"MCP config created: {temp_config_path}")
 
                 # Step 3: Run agent with sandbox MCP server
-                logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
+                logger.info(
+                    f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox"
+                )
+                # Normalize Fireworks model names for LiteLLM routing
+                completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
+                row.input_metadata.completion_params = completion_params
                 agent = Agent(
-                    model=row.input_metadata.completion_params["model"],
+                    model=completion_params["model"],
                     row=row,
                     config_path=temp_config_path,
                     logger=config.logger,
@@ -124,8 +130,8 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                 logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")
 
                 # Step 4: Export sandbox data
-                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
-                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
+                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")  # pyright: ignore[reportOptionalMemberAccess]
+                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)  # pyright: ignore[reportOptionalMemberAccess]
                 sandbox_data = dump_response.data
                 logger.info(f"Sandbox data: {sandbox_data}")
@@ -133,7 +139,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                 if not row.execution_metadata.extra:
                     row.execution_metadata.extra = {}
                 row.execution_metadata.extra["sandbox_data"] = sandbox_data
-                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
+                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id  # pyright: ignore[reportOptionalMemberAccess]
                 row.execution_metadata.extra["server_name"] = self.server_name
 
             except Exception as e:
@@ -149,7 +155,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                     await agent.mcp_client.cleanup()
                 if temp_config_path and os.path.exists(temp_config_path):
                     os.unlink(temp_config_path)
-                
+
                 # Release sandbox
                 if sandbox and sandbox.sandbox_id:
                     try:
diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
index 4587be0a..8418740a 100644
--- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
+++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -14,6 +14,7 @@
 from eval_protocol.models import EvaluationRow
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 
 class MCPServerManager:
@@ -280,17 +281,20 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
                     "Cannot retry without existing server/environments. Call with start_server=True first."
                 )
 
-        model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
-        temperature = config.completion_params.get("temperature", 0.0)
-        max_tokens = config.completion_params.get("max_tokens", 4096)
+        # Normalize Fireworks model names for LiteLLM routing
+        completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+        # Update all rows with normalized completion_params
+        for row in rows:
+            row.input_metadata.completion_params = completion_params
+        model_id = str(completion_params.get("model") or "gpt-4o-mini")
+        temperature = completion_params.get("temperature", 0.0)
+        max_tokens = completion_params.get("max_tokens", 4096)
 
         # Pass all other completion_params (e.g. stream=True) via kwargs
         other_params = {
-            k: v
-            for k, v in (config.completion_params or {}).items()
-            if k not in ["model", "temperature", "max_tokens", "extra_body"]
+            k: v for k, v in completion_params.items() if k not in ["model", "temperature", "max_tokens", "extra_body"]
         }
-        extra_body = config.completion_params.get("extra_body", {}) or {}
+        extra_body = completion_params.get("extra_body", {}) or {}
 
         self.policy = ep.LiteLLMPolicy(
             model_id=model_id,
diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py
index c3e09ba3..cabab274 100644
--- a/eval_protocol/pytest/default_single_turn_rollout_process.py
+++ b/eval_protocol/pytest/default_single_turn_rollout_process.py
@@ -17,6 +17,7 @@
 from openai.types import CompletionUsage
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
         async def process_row(row: EvaluationRow) -> EvaluationRow:
             """Process a single row asynchronously."""
             start_time = time.perf_counter()
-            
+
             if len(row.messages) == 0:
                 raise ValueError("Messages is empty. Please provide a non-empty dataset")
Please provide a non-empty dataset") @@ -77,7 +78,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: # Use the Message class method that excludes unsupported fields messages_payload = [message.dump_mdoel_for_chat_completion_request() for message in messages_for_request] - request_params = {"messages": messages_payload, **config.completion_params} + # Normalize Fireworks model names for LiteLLM routing + completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {} + row.input_metadata.completion_params = completion_params + request_params = {"messages": messages_payload, **completion_params} # Ensure caching is disabled only for this request (review feedback) request_params["cache"] = {"no-cache": True} @@ -87,18 +91,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: # Single-level reasoning effort: expect `reasoning_effort` only effort_val = None - if ( - "reasoning_effort" in config.completion_params - and config.completion_params["reasoning_effort"] is not None - ): - effort_val = str(config.completion_params["reasoning_effort"]) # flat shape + if "reasoning_effort" in completion_params and completion_params["reasoning_effort"] is not None: + effort_val = str(completion_params["reasoning_effort"]) # flat shape elif ( - isinstance(config.completion_params.get("extra_body"), dict) - and "reasoning_effort" in config.completion_params["extra_body"] - and config.completion_params["extra_body"]["reasoning_effort"] is not None + isinstance(completion_params.get("extra_body"), dict) + and "reasoning_effort" in completion_params["extra_body"] + and completion_params["extra_body"]["reasoning_effort"] is not None ): # Accept if user passed it directly inside extra_body - effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body + effort_val = str(completion_params["extra_body"]["reasoning_effort"]) # already in extra_body if effort_val: # Always under extra_body so LiteLLM forwards to provider-specific param set diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 16daa18c..84a66805 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -55,7 +55,6 @@ AggregationMethod, add_cost_metrics, log_eval_status_and_rows, - normalize_fireworks_model, parse_ep_completion_params, parse_ep_completion_params_overwrite, parse_ep_max_concurrent_rollouts, @@ -205,7 +204,6 @@ def evaluation_test( max_dataset_rows = parse_ep_max_rows(max_dataset_rows) completion_params = parse_ep_completion_params(completion_params) completion_params = parse_ep_completion_params_overwrite(completion_params) - completion_params = [normalize_fireworks_model(cp) for cp in completion_params] original_completion_params = completion_params passed_threshold = parse_ep_passed_threshold(passed_threshold) data_loaders = parse_ep_dataloaders(data_loaders) @@ -366,7 +364,6 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo row.input_metadata.row_id = generate_id(seed=0, index=index) completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None - completion_params = normalize_fireworks_model(completion_params) # Create eval metadata with test function info and current commit hash eval_metadata = EvalMetadata( name=test_func.__name__, diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index e953617f..64f0c8b3 100644 --- 
+++ b/eval_protocol/pytest/evaluation_test_utils.py
@@ -619,22 +619,3 @@ def build_rollout_processor_config(
         server_script_path=None,
         kwargs=rollout_processor_kwargs,
     )
-
-
-def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
-    """Fireworks model names like 'accounts/<account>/models/<model>' need the fireworks_ai/
-    prefix when routing through LiteLLM. This function adds the prefix if missing.
-    """
-    if completion_params is None:
-        return None
-
-    model = completion_params.get("model")
-    if (
-        model
-        and isinstance(model, str)
-        and not model.startswith("fireworks_ai/")
-        and re.match(r"^accounts/[^/]+/models/.+", model)
-    ):
-        completion_params = completion_params.copy()
-        completion_params["model"] = f"fireworks_ai/{model}"
-    return completion_params
diff --git a/eval_protocol/pytest/github_action_rollout_processor.py b/eval_protocol/pytest/github_action_rollout_processor.py
index 3e4f9ec0..9e4cbb50 100644
--- a/eval_protocol/pytest/github_action_rollout_processor.py
+++ b/eval_protocol/pytest/github_action_rollout_processor.py
@@ -11,6 +11,7 @@
 from .rollout_processor import RolloutProcessor
 from .types import RolloutProcessorConfig
 from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
+from .utils import normalize_fireworks_model_for_litellm
 
 
 class GithubActionRolloutProcessor(RolloutProcessor):
@@ -80,6 +81,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
             if row.input_metadata.row_id is None:
                 raise ValueError("Row ID is required in GithubActionRolloutProcessor")
 
+            # Normalize Fireworks model names for LiteLLM routing
+            config.completion_params = (
+                normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
+            )
+            row.input_metadata.completion_params = config.completion_params
+
             init_request = build_init_request(row, config, self.model_base_url)
 
             def _dispatch_workflow():
diff --git a/eval_protocol/pytest/openenv_rollout_processor.py b/eval_protocol/pytest/openenv_rollout_processor.py
index 0f662692..82d80cd1 100644
--- a/eval_protocol/pytest/openenv_rollout_processor.py
+++ b/eval_protocol/pytest/openenv_rollout_processor.py
@@ -24,6 +24,7 @@
 from eval_protocol.models import EvaluationRow, Message
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 logger = logging.getLogger(__name__)
 
@@ -177,15 +178,18 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             logger.debug("[OpenEnvRolloutProcessor] Environment client created successfully")
             try:
+                # Normalize Fireworks model names for LiteLLM routing
+                completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+                row.input_metadata.completion_params = completion_params
                 # Get model config
-                raw_model = config.completion_params.get("model", "gpt-4o-mini")
+                raw_model = completion_params.get("model", "gpt-4o-mini")
                 model = raw_model
-                temperature = config.completion_params.get("temperature", 0.0)
-                max_tokens = config.completion_params.get("max_tokens", 100)
+                temperature = completion_params.get("temperature", 0.0)
+                max_tokens = completion_params.get("max_tokens", 100)
 
                 # Optional: direct routing or provider overrides (e.g., base_url, api_key, top_p, stop, etc.)
-                base_url = config.completion_params.get("base_url")
+                base_url = completion_params.get("base_url")
                 # Forward any extra completion params to LiteLLMPolicy (they will be sent per-request)
-                extra_params: Dict[str, Any] = dict(config.completion_params or {})
+                extra_params: Dict[str, Any] = dict(completion_params)
                 for _k in ("model", "temperature", "max_tokens", "base_url"):
                     try:
                         extra_params.pop(_k, None)
@@ -247,7 +251,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             messages = list(row.messages)  # Copy initial messages
             # Inject system prompt if provided and not already present
             has_system = any(m.role == "system" for m in messages)
-            system_prompt = config.completion_params.get("system_prompt")
+            system_prompt = completion_params.get("system_prompt")
             if system_prompt and not has_system:
                 messages.insert(0, Message(role="system", content=system_prompt))
             usage = {
diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py
index bf49e543..f2abca78 100644
--- a/eval_protocol/pytest/remote_rollout_processor.py
+++ b/eval_protocol/pytest/remote_rollout_processor.py
@@ -11,6 +11,7 @@
 from .rollout_processor import RolloutProcessor
 from .types import RolloutProcessorConfig
 from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
+from .utils import normalize_fireworks_model_for_litellm
 
 import logging
 import os
@@ -87,6 +88,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
             if row.input_metadata.row_id is None:
                 raise ValueError("Row ID is required in RemoteRolloutProcessor")
 
+            # Normalize Fireworks model names for LiteLLM routing
+            config.completion_params = (
+                normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
+            )
+            row.input_metadata.completion_params = config.completion_params
+
             init_payload = build_init_request(row, config, model_base_url)
 
             # Fire-and-poll
diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py
new file mode 100644
index 00000000..01389f84
--- /dev/null
+++ b/eval_protocol/pytest/utils.py
@@ -0,0 +1,24 @@
+"""Utility functions for model name handling."""
+
+import re
+
+from eval_protocol.models import CompletionParams
+
+
+def normalize_fireworks_model_for_litellm(completion_params: CompletionParams | None) -> CompletionParams | None:
+    """Fireworks model names like 'accounts/<account>/models/<model>' or 'accounts/<account>/deployments/<deployment>'
+    need the fireworks_ai/ prefix when routing through LiteLLM. This function adds the prefix if missing.
+    """
+    if completion_params is None:
+        return None
+
+    model = completion_params.get("model")
+    if (
+        model
+        and isinstance(model, str)
+        and not model.startswith("fireworks_ai/")
+        and re.match(r"^accounts/[^/]+/(models|deployments)/.+", model)
+    ):
+        completion_params = completion_params.copy()
+        completion_params["model"] = f"fireworks_ai/{model}"
+    return completion_params
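For reference, a minimal sketch of the new helper's behavior. The model names below are illustrative, and CompletionParams is treated as the dict-like mapping that its .get()/.copy() usage in the patch implies:

    from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm

    # A bare Fireworks model id (illustrative name) gains the LiteLLM routing prefix.
    params = {"model": "accounts/fireworks/models/llama-v3p1-8b-instruct"}
    assert normalize_fireworks_model_for_litellm(params)["model"] == (
        "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct"
    )

    # Already-prefixed or non-Fireworks names, and None, pass through unchanged.
    assert normalize_fireworks_model_for_litellm({"model": "gpt-4o-mini"}) == {"model": "gpt-4o-mini"}
    assert normalize_fireworks_model_for_litellm(None) is None

Note that the helper copies completion_params before rewriting "model", so a caller that keeps a reference to the un-normalized mapping is not mutated; each rollout processor instead rebinds the normalized result onto config or row.input_metadata explicitly, as the hunks above show.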