diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 8e5c7d15..4913e33b 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -253,6 +253,7 @@ def __init__( project_id: Optional[str] = None, base_url: str = "https://tracing.fireworks.ai", timeout: int = 300, + api_key: Optional[str] = None, ): """Initialize the Fireworks Tracing adapter. @@ -260,10 +261,16 @@ def __init__( project_id: Optional project ID. If not provided, uses the default project configured on the server. base_url: The base URL of the tracing proxy (default: https://tracing.fireworks.ai) timeout: Request timeout in seconds (default: 300) + api_key: Optional API key. If not provided, falls back to FIREWORKS_API_KEY environment variable. """ self.project_id = project_id self.base_url = base_url.rstrip("/") self.timeout = timeout + self._api_key = api_key + + def _get_api_key(self) -> Optional[str]: + """Get the API key, preferring instance-level key over environment variable.""" + return self._api_key or os.environ.get("FIREWORKS_API_KEY") def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -> List[Dict[str, Any]]: """Fetch logs from Fireworks tracing gateway /logs endpoint. @@ -276,7 +283,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) - from ..common_utils import get_user_agent headers = { - "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "Authorization": f"Bearer {self._get_api_key()}", "User-Agent": get_user_agent(), } params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"} @@ -407,7 +414,7 @@ def get_evaluation_rows( from ..common_utils import get_user_agent headers = { - "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "Authorization": f"Bearer {self._get_api_key()}", "User-Agent": get_user_agent(), } diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 158fcbb4..7d6b1714 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -19,7 +19,9 @@ def default_fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDat def fetch_traces() -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) + # Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY + api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY") + adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key) return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -131,7 +133,9 @@ def build_init_request( final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url) # Extract API key from environment or completion_params - api_key = os.environ.get("FIREWORKS_API_KEY") + # EP_REMOTE_API_KEY takes precedence for remote rollout processors, + # falling back to FIREWORKS_API_KEY for backwards compatibility + api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY") return InitRequest( completion_params=completion_params_dict,