From 0a00bcdd486c2e583776774baf92754068bf4044 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 7 Jan 2026 13:07:09 -0800 Subject: [PATCH 01/39] Refactor Fireworks client integration - Introduced a new `fireworks_client.py` module to centralize Fireworks SDK client creation. - Updated CLI and evaluation modules to use the new `create_fireworks_client` function instead of direct instantiation of the Fireworks class. - Enhanced handling of API key, account ID, base URL, and extra headers through environment variables. - Added tests for the new Fireworks client factory to ensure proper functionality and configuration. --- eval_protocol/cli.py | 7 +- eval_protocol/cli_commands/create_rft.py | 7 +- eval_protocol/evaluation.py | 17 ++- eval_protocol/fireworks_client.py | 132 +++++++++++++++++++++ eval_protocol/platform_api.py | 21 +++- tests/test_fireworks_client.py | 143 +++++++++++++++++++++++ 6 files changed, 310 insertions(+), 17 deletions(-) create mode 100644 eval_protocol/fireworks_client.py create mode 100644 tests/test_fireworks_client.py diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index ac8a8d9d..743fe15d 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -8,10 +8,9 @@ import sys from pathlib import Path -from fireworks import Fireworks - from .cli_commands.common import setup_logging from .cli_commands.utils import add_args_from_callable_signature +from .fireworks_client import create_fireworks_client logger = logging.getLogger(__name__) @@ -88,7 +87,7 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse ) # Auto-generate flags from SDK Fireworks().evaluators.create() signature - create_evaluator_fn = Fireworks().evaluators.create + create_evaluator_fn = create_fireworks_client().evaluators.create upload_skip_fields = { "__top_level__": { @@ -198,7 +197,7 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "loss_config.method": "RL loss method for underlying trainers. 
One of {grpo,dapo}.", } - create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create + create_rft_job_fn = create_fireworks_client().reinforcement_fine_tuning_jobs.create add_args_from_callable_signature( rft_parser, diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 702eb2fe..4865a62f 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -12,6 +12,7 @@ from pydantic import ValidationError from ..auth import get_fireworks_api_base, get_fireworks_api_key +from ..fireworks_client import create_fireworks_client from ..common_utils import get_user_agent, load_jsonl from ..fireworks_rft import ( create_dataset_from_jsonl, @@ -35,8 +36,6 @@ ) from .local_test import run_evaluator_test -from fireworks import Fireworks - def _extract_dataset_adapter( test_file_path: str, test_func_name: str @@ -672,7 +671,7 @@ def _create_rft_job( ) -> int: """Build and submit the RFT job request (via Fireworks SDK).""" - signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create) + signature = inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create) # Build top-level SDK kwargs sdk_kwargs: Dict[str, Any] = { @@ -711,7 +710,7 @@ def _create_rft_job( return 0 try: - fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base) + fw: Fireworks = create_fireworks_client(api_key=api_key, base_url=api_base) job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs) job_name = job.name print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}") diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 9c84d34e..38b9e011 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -5,13 +5,13 @@ import fireworks import requests -from fireworks import Fireworks from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_key, verify_api_key_and_get_account_id, ) +from eval_protocol.fireworks_client import create_fireworks_client from eval_protocol.get_pep440_version import get_pep440_version logger = logging.getLogger(__name__) @@ -164,7 +164,11 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) logger.error("Authentication error: API credentials appear to be invalid or incomplete.") raise ValueError("Invalid or missing API credentials.") - client = Fireworks(api_key=auth_token, base_url=self.api_base, account_id=account_id) + client = create_fireworks_client( + api_key=auth_token, + base_url=self.api_base, + account_id=account_id, + ) self.display_name = display_name or evaluator_id self.description = description or f"Evaluator created from {evaluator_id}" @@ -239,9 +243,12 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) tar_size = self._create_tar_gz_with_ignores(tar_path, cwd) + version_id = "test" + # Call GetEvaluatorUploadEndpoint using SDK logger.info(f"Requesting upload endpoint for {tar_filename}") - upload_response = client.evaluators.get_upload_endpoint( + upload_response = client.evaluator_versions.get_upload_endpoint( + version_id=version_id, evaluator_id=evaluator_id, filename_to_size={tar_filename: str(tar_size)}, ) @@ -322,9 +329,9 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) raise # Step 3: Validate upload using SDK - client.evaluators.validate_upload( + client.evaluator_versions.validate_upload( + version_id=version_id, 
evaluator_id=evaluator_id, - body={}, ) logger.info("Upload validated successfully") diff --git a/eval_protocol/fireworks_client.py b/eval_protocol/fireworks_client.py new file mode 100644 index 00000000..d92d8bec --- /dev/null +++ b/eval_protocol/fireworks_client.py @@ -0,0 +1,132 @@ +""" +Consolidated Fireworks client factory. + +This module provides a single point of instantiation for the Fireworks SDK client, +ensuring consistent handling of environment variables and configuration across the +eval_protocol codebase. + +Environment variables: + FIREWORKS_API_KEY: API key for authentication (required) + FIREWORKS_ACCOUNT_ID: Account ID (optional, can be derived from API key) + FIREWORKS_API_BASE: Base URL for the API (default: https://api.fireworks.ai) + FIREWORKS_EXTRA_HEADERS: JSON-encoded extra headers to include in requests + Example: '{"X-Custom-Header": "value", "X-Another": "another-value"}' +""" + +import json +import logging +import os +from typing import Mapping, Optional + +from fireworks import Fireworks + +from eval_protocol.auth import ( + get_fireworks_account_id, + get_fireworks_api_base, + get_fireworks_api_key, +) + +logger = logging.getLogger(__name__) + + +def get_fireworks_extra_headers() -> Optional[Mapping[str, str]]: + """ + Retrieves extra headers from the FIREWORKS_EXTRA_HEADERS environment variable. + + The value should be a JSON-encoded object mapping header names to values. + Example: '{"X-Custom-Header": "value"}' + + Returns: + A mapping of header names to values if set and valid, otherwise None. + """ + extra_headers_str = os.environ.get("FIREWORKS_EXTRA_HEADERS") + if not extra_headers_str or not extra_headers_str.strip(): + return None + + try: + headers = json.loads(extra_headers_str) + if not isinstance(headers, dict): + logger.warning( + "FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s. Ignoring.", + type(headers).__name__, + ) + return None + # Validate all keys and values are strings + for k, v in headers.items(): + if not isinstance(k, str) or not isinstance(v, str): + logger.warning( + "FIREWORKS_EXTRA_HEADERS contains non-string key or value: %s=%s. Ignoring all extra headers.", + k, + v, + ) + return None + logger.debug("Using FIREWORKS_EXTRA_HEADERS: %s", list(headers.keys())) + return headers + except json.JSONDecodeError as e: + logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s. Ignoring.", e) + return None + + +def create_fireworks_client( + *, + api_key: Optional[str] = None, + account_id: Optional[str] = None, + base_url: Optional[str] = None, + extra_headers: Optional[Mapping[str, str]] = None, +) -> Fireworks: + """ + Create a Fireworks client with consistent configuration. + + This factory function centralizes the logic for creating Fireworks clients, + ensuring that environment variables are handled consistently across the codebase. + + Resolution order for each parameter: + 1. Explicit argument passed to this function + 2. Environment variable (via auth module helpers) + 3. SDK defaults (for base_url only) + + Args: + api_key: Fireworks API key. If not provided, resolves from FIREWORKS_API_KEY. + account_id: Fireworks account ID. If not provided, resolves from FIREWORKS_ACCOUNT_ID + or derives from the API key via the verifyApiKey endpoint. + base_url: Base URL for the Fireworks API. If not provided, resolves from + FIREWORKS_API_BASE or defaults to https://api.fireworks.ai. + extra_headers: Additional headers to include in all requests. 
If not provided, + resolves from FIREWORKS_EXTRA_HEADERS environment variable (JSON-encoded). + + Returns: + A configured Fireworks client instance. + + Raises: + fireworks.FireworksError: If api_key is not provided and FIREWORKS_API_KEY + environment variable is not set. + """ + # Resolve parameters from environment if not explicitly provided + resolved_api_key = api_key or get_fireworks_api_key() + resolved_account_id = account_id or get_fireworks_account_id() + resolved_base_url = base_url or get_fireworks_api_base() + + # Merge extra headers: env var headers first, then explicit headers override + env_extra_headers = get_fireworks_extra_headers() + merged_headers: Optional[Mapping[str, str]] = None + if env_extra_headers or extra_headers: + merged = {} + if env_extra_headers: + merged.update(env_extra_headers) + if extra_headers: + merged.update(extra_headers) + merged_headers = merged if merged else None + + logger.debug( + "Creating Fireworks client: base_url=%s, account_id=%s, extra_headers=%s", + resolved_base_url, + resolved_account_id, + list(merged_headers.keys()) if merged_headers else None, + ) + + return Fireworks( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_base_url, + default_headers=merged_headers, + ) diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 60743ccb..f6dd2d89 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -10,8 +10,9 @@ get_fireworks_api_base, get_fireworks_api_key, ) +from eval_protocol.fireworks_client import create_fireworks_client from fireworks.types import Secret -from fireworks import Fireworks, FireworksError, NotFoundError, InternalServerError +from fireworks import FireworksError, NotFoundError, InternalServerError logger = logging.getLogger(__name__) @@ -88,7 +89,11 @@ def create_or_update_fireworks_secret( resolved_api_key = api_key or get_fireworks_api_key() resolved_api_base = api_base or get_fireworks_api_base() resolved_account_id = account_id # Must be provided - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) if not all([resolved_api_key, resolved_api_base, resolved_account_id]): logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") @@ -173,7 +178,11 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) resource_id = _normalize_secret_resource_id(key_name) try: @@ -215,7 +224,11 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) resource_id = _normalize_secret_resource_id(key_name) try: diff --git a/tests/test_fireworks_client.py b/tests/test_fireworks_client.py new file mode 100644 index 00000000..db0b08c6 --- /dev/null +++ b/tests/test_fireworks_client.py @@ -0,0 +1,143 @@ +"""Tests for the 
consolidated Fireworks client factory.""" + +import os +from unittest.mock import patch + +import pytest + +from eval_protocol.fireworks_client import ( + create_fireworks_client, + get_fireworks_extra_headers, +) + + +class TestGetFireworksExtraHeaders: + """Tests for get_fireworks_extra_headers function.""" + + def test_returns_none_when_env_var_not_set(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is not set.""" + with patch.dict(os.environ, {}, clear=True): + # Remove the env var if it exists + os.environ.pop("FIREWORKS_EXTRA_HEADERS", None) + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_empty_string(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is empty.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": ""}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_whitespace_only(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is whitespace only.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": " "}): + result = get_fireworks_extra_headers() + assert result is None + + def test_parses_valid_json_object(self): + """Should parse valid JSON object into dict.""" + headers = '{"X-Custom": "value", "X-Another": "test"}' + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": headers}): + result = get_fireworks_extra_headers() + assert result == {"X-Custom": "value", "X-Another": "test"} + + def test_returns_none_for_invalid_json(self): + """Should return None and log warning for invalid JSON.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": "not json"}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_json_array(self): + """Should return None when JSON is an array instead of object.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '["item1", "item2"]'}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_json_string(self): + """Should return None when JSON is a string instead of object.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '"just a string"'}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_non_string_values(self): + """Should return None when JSON object has non-string values.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '{"key": 123}'}): + result = get_fireworks_extra_headers() + assert result is None + + +class TestCreateFireworksClient: + """Tests for create_fireworks_client function.""" + + def test_creates_client_with_explicit_api_key(self): + """Should create client with explicitly provided API key.""" + client = create_fireworks_client(api_key="test-api-key") + assert client.api_key == "test-api-key" + + def test_creates_client_with_explicit_base_url(self): + """Should create client with explicitly provided base URL.""" + client = create_fireworks_client( + api_key="test-api-key", + base_url="https://custom.api.example.com", + ) + assert str(client.base_url).rstrip("/") == "https://custom.api.example.com" + + def test_creates_client_with_explicit_account_id(self): + """Should create client with explicitly provided account ID.""" + client = create_fireworks_client( + api_key="test-api-key", + account_id="test-account-123", + ) + assert client.account_id == "test-account-123" + + def test_creates_client_with_explicit_extra_headers(self): + """Should create client with explicitly provided extra headers.""" + extra_headers = 
{"X-Custom-Header": "test-value"} + client = create_fireworks_client( + api_key="test-api-key", + extra_headers=extra_headers, + ) + assert "X-Custom-Header" in client._custom_headers + assert client._custom_headers["X-Custom-Header"] == "test-value" + + def test_merges_env_and_explicit_extra_headers(self): + """Should merge env var headers with explicit headers, explicit taking precedence.""" + env_headers = '{"X-Env-Header": "env-value", "X-Override": "env"}' + explicit_headers = {"X-Explicit-Header": "explicit-value", "X-Override": "explicit"} + + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": env_headers}): + client = create_fireworks_client( + api_key="test-api-key", + extra_headers=explicit_headers, + ) + # Both headers should be present + assert client._custom_headers["X-Env-Header"] == "env-value" + assert client._custom_headers["X-Explicit-Header"] == "explicit-value" + # Explicit should override env + assert client._custom_headers["X-Override"] == "explicit" + + def test_uses_env_extra_headers_when_no_explicit(self): + """Should use env var extra headers when no explicit headers provided.""" + env_headers = '{"X-Env-Header": "env-value"}' + + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": env_headers}): + client = create_fireworks_client(api_key="test-api-key") + assert client._custom_headers["X-Env-Header"] == "env-value" + + def test_resolves_api_key_from_env(self): + """Should resolve API key from environment when not explicitly provided.""" + with patch.dict(os.environ, {"FIREWORKS_API_KEY": "env-api-key"}): + client = create_fireworks_client() + assert client.api_key == "env-api-key" + + def test_resolves_base_url_from_env(self): + """Should resolve base URL from environment when not explicitly provided.""" + with patch.dict( + os.environ, + { + "FIREWORKS_API_KEY": "test-key", + "FIREWORKS_API_BASE": "https://env.api.example.com", + }, + ): + client = create_fireworks_client() + assert str(client.base_url).rstrip("/") == "https://env.api.example.com" From d465a896113fa3a4da9d4ed1c839ad46c3240039 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 7 Jan 2026 14:40:09 -0800 Subject: [PATCH 02/39] remove launch.json --- .vscode/launch.json | 39 --------------------------------------- 1 file changed, 39 deletions(-) delete mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 38fff2f8..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Debug Tests", - "type": "python", - "request": "launch", - "module": "pytest", - "args": ["-s", "--tb=short", "${file}"], - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - }, - { - "name": "Python: Debug Current File", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - }, - { - "name": "Python: Debug Logs Server", - "type": "python", - "request": "launch", - "module": "eval_protocol.utils.logs_server", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - } - ] -} From 348bb58d617c4eb8fa52d805c53677bd6af20dc6 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 7 Jan 2026 15:33:25 -0800 Subject: [PATCH 03/39] Add .vscode/launch.json to .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore 
index 7e434174..918f7fbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -243,3 +243,5 @@ package.json
 tau2-bench
 *.err
 eval-protocol
+
+.vscode/launch.json

From acaa901670805064f49a3a23f4f42440e3147561 Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Wed, 7 Jan 2026 15:33:40 -0800
Subject: [PATCH 04/39] Enhance environment variable loading in auth module

- Added functionality to load environment variables from .env.dev or .env as a fallback when the auth module is imported.
- Updated the API key verification to use the provided base URL when it contains api.fireworks.ai, falling back to dev.api.fireworks.ai otherwise.
- Removed redundant environment variable loading code from platform_api module.
---
 eval_protocol/auth.py         | 30 +++++++++++++++++++++++++++++-
 eval_protocol/platform_api.py | 23 -----------------------
 2 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py
index 68ce134c..7be1aed5 100644
--- a/eval_protocol/auth.py
+++ b/eval_protocol/auth.py
@@ -3,9 +3,30 @@
 from typing import Optional
 
 import requests
+from dotenv import find_dotenv, load_dotenv
 
 logger = logging.getLogger(__name__)
 
+# --- Load .env files ---
+# Attempt to load .env.dev first, then .env as a fallback.
+# This happens when the module is imported.
+# We use override=False (default) so that existing environment variables
+# (e.g., set in the shell) are NOT overridden by .env files.
+_ENV_DEV_PATH = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True)
+if _ENV_DEV_PATH:
+    load_dotenv(dotenv_path=_ENV_DEV_PATH, override=False)
+    logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_ENV_DEV_PATH}")
+else:
+    _ENV_PATH = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True)
+    if _ENV_PATH:
+        load_dotenv(dotenv_path=_ENV_PATH, override=False)
+        logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_ENV_PATH}")
+    else:
+        logger.debug(
+            "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables."
+        )
+# --- End .env loading ---
+
 
 def get_fireworks_api_key() -> Optional[str]:
     """
@@ -73,6 +94,8 @@ def verify_api_key_and_get_account_id(
     Args:
         api_key: Optional explicit API key. When None, resolves via get_fireworks_api_key().
        api_base: Optional explicit API base. When None, resolves via get_fireworks_api_base().
+            If the resolved base contains "api.fireworks.ai", it is used directly; otherwise the
+            verification call falls back to dev.api.fireworks.ai.
 
     Returns:
         The resolved account id if verification succeeds and the header is present; otherwise None.
@@ -81,7 +104,12 @@ def verify_api_key_and_get_account_id( resolved_key = api_key or get_fireworks_api_key() if not resolved_key: return None - resolved_base = api_base or get_fireworks_api_base() + provided_base = api_base or get_fireworks_api_base() + # Use api.fireworks.ai if explicitly provided, otherwise fall back to dev + if "api.fireworks.ai" in provided_base: + resolved_base = provided_base + else: + resolved_base = "https://dev.api.fireworks.ai" from .common_utils import get_user_agent diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index f6dd2d89..8b07f4d7 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -3,8 +3,6 @@ import sys from typing import Optional -from dotenv import find_dotenv, load_dotenv - from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_base, @@ -16,27 +14,6 @@ logger = logging.getLogger(__name__) -# --- Load .env files --- -# Attempt to load .env.dev first, then .env as a fallback. -# This happens when the module is imported. -# We use override=False (default) so that existing environment variables -# (e.g., set in the shell) are NOT overridden by .env files. -ENV_DEV_PATH = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) -if ENV_DEV_PATH: - load_dotenv(dotenv_path=ENV_DEV_PATH, override=False) - logger.info(f"eval_protocol.platform_api: Loaded environment variables from: {ENV_DEV_PATH}") -else: - ENV_PATH = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) - if ENV_PATH: - load_dotenv(dotenv_path=ENV_PATH, override=False) - logger.info(f"eval_protocol.platform_api: Loaded environment variables from: {ENV_PATH}") - else: - logger.info( - "eval_protocol.platform_api: No .env.dev or .env file found. " - "Relying on shell/existing environment variables." - ) -# --- End .env loading --- - class PlatformAPIError(Exception): """Custom exception for platform API errors.""" From 4b71ddb91ca2c0a412686d78938953305f6de9a6 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 7 Jan 2026 15:33:46 -0800 Subject: [PATCH 05/39] Add evaluator version creation in evaluation module - Introduced functionality to create evaluator versions using parameters such as commit hash, entry point, and requirements. - Updated the upload endpoint call to utilize the newly created evaluator version ID instead of a hardcoded test version ID. - Added error handling for missing evaluator version ID in the response to ensure robustness during code uploads. --- eval_protocol/evaluation.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 38b9e011..f43ee20f 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -4,6 +4,7 @@ from typing import List, Optional import fireworks +from fireworks.types import EvaluatorVersionParam import requests from eval_protocol.auth import ( @@ -234,6 +235,25 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) f"Cannot proceed with code upload. 
Response: {result}" ) + evaluator_version_param: EvaluatorVersionParam = {} + if "commit_hash" in evaluator_params: + evaluator_version_param["commit_hash"] = evaluator_params["commit_hash"] + if "entry_point" in evaluator_params: + evaluator_version_param["entry_point"] = evaluator_params["entry_point"] + if "requirements" in evaluator_params: + evaluator_version_param["requirements"] = evaluator_params["requirements"] + + evaluator_version = client.evaluator_versions.create( + evaluator_id=evaluator_id, + evaluator_version=evaluator_version_param, + ) + evaluator_version_id = evaluator_version.name.split("/")[-1] if evaluator_version.name else None + if not evaluator_version_id: + raise ValueError( + "Create evaluator version response missing 'name' field. " + f"Cannot proceed with code upload. Response: {evaluator_version}" + ) + try: # Create tar.gz of current directory cwd = os.getcwd() @@ -243,12 +263,10 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) tar_size = self._create_tar_gz_with_ignores(tar_path, cwd) - version_id = "test" - # Call GetEvaluatorUploadEndpoint using SDK logger.info(f"Requesting upload endpoint for {tar_filename}") upload_response = client.evaluator_versions.get_upload_endpoint( - version_id=version_id, + version_id=evaluator_version_id, evaluator_id=evaluator_id, filename_to_size={tar_filename: str(tar_size)}, ) @@ -330,7 +348,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) # Step 3: Validate upload using SDK client.evaluator_versions.validate_upload( - version_id=version_id, + version_id=evaluator_version_id, evaluator_id=evaluator_id, ) logger.info("Upload validated successfully") From 3dbcd5980210119b6b3c29ec6173453af624bcd1 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 8 Jan 2026 15:20:04 -0800 Subject: [PATCH 06/39] test --- eval_protocol/cli.py | 6 -- eval_protocol/cli_commands/create_rft.py | 64 ++++++-------- eval_protocol/cli_commands/upload.py | 2 - eval_protocol/evaluation.py | 22 +---- tests/test_cli_create_rft.py | 15 ---- tests/test_ep_upload_e2e.py | 107 ----------------------- 6 files changed, 30 insertions(+), 186 deletions(-) diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 743fe15d..40dce34c 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -80,11 +80,6 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "--env-file", help="Path to .env file containing secrets to upload (default: .env in current directory)", ) - upload_parser.add_argument( - "--force", - action="store_true", - help="Overwrite existing evaluator with the same ID", - ) # Auto-generate flags from SDK Fireworks().evaluators.create() signature create_evaluator_fn = create_fireworks_client().evaluators.create @@ -136,7 +131,6 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending") - rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID") rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation") rft_parser.add_argument( "--ignore-docker", diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 4865a62f..fc4d20b4 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ 
b/eval_protocol/cli_commands/create_rft.py @@ -567,37 +567,35 @@ def _upload_and_ensure_evaluator( evaluator_resource_name: str, api_key: str, api_base: str, - force: bool, ) -> bool: """Ensure the evaluator exists and is ACTIVE, uploading it if needed.""" - # Optional short-circuit: if evaluator already exists and not forcing, skip upload path - if not force: - try: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) - if resp.ok: - state = resp.json().get("state", "STATE_UNSPECIFIED") - print(f"✓ Evaluator exists (state: {state}). Skipping upload (use --force to overwrite).") - # Poll for ACTIVE before proceeding - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - if not _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ): - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\n❌ Evaluator is not ready within the timeout period.") - print(f"📊 Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - except requests.exceptions.RequestException: - pass + # Check if evaluator already exists + try: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } + resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) + if resp.ok: + state = resp.json().get("state", "STATE_UNSPECIFIED") + print(f"✓ Evaluator exists (state: {state}). Skipping upload.") + # Poll for ACTIVE before proceeding + print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") + if not _poll_evaluator_status( + evaluator_resource_name=evaluator_resource_name, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ): + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\n❌ Evaluator is not ready within the timeout period.") + print(f"📊 Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") + return False + return True + except requests.exceptions.RequestException: + pass # Ensure evaluator exists by invoking the upload flow programmatically try: @@ -622,14 +620,10 @@ def _upload_and_ensure_evaluator( id=evaluator_id, display_name=None, description=None, - force=force, # Pass through the --force flag yes=True, - env_file=None, # Add the new env_file parameter + env_file=None, ) - if force: - print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'") - rc = upload_command(upload_args) if rc == 0: print(f"✓ Uploaded/ensured evaluator: {evaluator_id}") @@ -738,7 +732,6 @@ def create_rft_command(args) -> int: evaluator_arg: Optional[str] = getattr(args, "evaluator", None) non_interactive: bool = bool(getattr(args, "yes", False)) dry_run: bool = bool(getattr(args, "dry_run", False)) - force: bool = bool(getattr(args, "force", False)) skip_validation: bool = bool(getattr(args, "skip_validation", False)) ignore_docker: bool = bool(getattr(args, "ignore_docker", False)) docker_build_extra: str = getattr(args, "docker_build_extra", "") or "" @@ -816,7 +809,6 @@ def create_rft_command(args) -> int: evaluator_resource_name=evaluator_resource_name, api_key=api_key, 
api_base=api_base, - force=force, ): return 1 diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index a8a132d6..d61b31ae 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -289,7 +289,6 @@ def upload_command(args: argparse.Namespace) -> int: base_id = getattr(args, "id", None) display_name = getattr(args, "display_name", None) description = getattr(args, "description", None) - force = bool(getattr(args, "force", False)) env_file = getattr(args, "env_file", None) # Load secrets from .env file and ensure they're available on Fireworks @@ -382,7 +381,6 @@ def upload_command(args: argparse.Namespace) -> int: evaluator_id=evaluator_id, display_name=display_name or evaluator_id, description=description or f"Evaluator for {qualname}", - force=force, entry_point=entry_point, ) name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index f43ee20f..0187ee3e 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -155,7 +155,7 @@ def _create_tar_gz_with_ignores(output_path: str, source_dir: str) -> int: logger.info(f"Created {output_path} ({size_bytes:,} bytes)") return size_bytes - def create(self, evaluator_id, display_name=None, description=None, force=False): + def create(self, evaluator_id, display_name=None, description=None): auth_token = self.api_key or get_fireworks_api_key() account_id = self.account_id or get_fireworks_account_id() if not account_id and auth_token: @@ -203,22 +203,6 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) logger.info(f"Creating evaluator '{evaluator_id}' for account '{account_id}'...") try: - if force: - try: - logger.info("Checking if evaluator exists") - existing_evaluator = client.evaluators.get(evaluator_id=evaluator_id) - if existing_evaluator: - logger.info(f"Evaluator '{evaluator_id}' already exists, deleting and recreating...") - try: - client.evaluators.delete(evaluator_id=evaluator_id) - logger.info(f"Successfully deleted evaluator '{evaluator_id}'") - except fireworks.NotFoundError: - logger.info(f"Evaluator '{evaluator_id}' not found, creating...") - except fireworks.APIError as e: - logger.warning(f"Error deleting evaluator: {str(e)}") - except fireworks.NotFoundError: - logger.info(f"Evaluator '{evaluator_id}' does not exist, creating...") - # Create evaluator using SDK result = client.evaluators.create( evaluator_id=evaluator_id, @@ -387,7 +371,6 @@ def create_evaluation( evaluator_id: str, display_name: Optional[str] = None, description: Optional[str] = None, - force: bool = False, account_id: Optional[str] = None, api_key: Optional[str] = None, entry_point: Optional[str] = None, @@ -399,7 +382,6 @@ def create_evaluation( evaluator_id: Unique identifier for the evaluator display_name: Display name for the evaluator description: Description for the evaluator - force: If True, delete and recreate if evaluator exists account_id: Optional Fireworks account ID api_key: Optional Fireworks API key entry_point: Optional entry point (module::function or path::function) @@ -410,4 +392,4 @@ def create_evaluation( entry_point=entry_point, ) - return evaluator.create(evaluator_id, display_name, description, force) + return evaluator.create(evaluator_id, display_name, description) diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 1f1e8395..9536ef30 100644 --- a/tests/test_cli_create_rft.py +++ 
b/tests/test_cli_create_rft.py @@ -239,7 +239,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator=None, yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -299,7 +298,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator=None, yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -351,7 +349,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator="my-evaluator", yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -401,7 +398,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator="my-evaluator", yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -462,7 +458,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d setattr(args, "evaluator", None) setattr(args, "yes", True) setattr(args, "dry_run", False) - setattr(args, "force", False) setattr(args, "env_file", None) setattr(args, "dataset", None) setattr(args, "dataset_jsonl", str(ds_path)) @@ -530,7 +525,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"), yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -600,7 +594,6 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -674,7 +667,6 @@ def raise_for_status(self): evaluator="some-eval", yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -727,7 +719,6 @@ def _raise(*a, **k): evaluator="some-eval", yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(project / "dataset.jsonl"), @@ -789,7 +780,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -850,7 +840,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -912,7 +901,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -1007,7 +995,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=eval_id, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -1175,7 +1162,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(explicit_jsonl), @@ -1266,7 +1252,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, diff --git a/tests/test_ep_upload_e2e.py b/tests/test_ep_upload_e2e.py index 56de5fea..9acecc79 100644 --- a/tests/test_ep_upload_e2e.py +++ b/tests/test_ep_upload_e2e.py @@ -108,18 +108,6 @@ def 
get_upload_endpoint_side_effect(evaluator_id, filename_to_size): mock_validate_response.valid = True mock_client.evaluators.validate_upload.return_value = mock_validate_response - # Mock evaluators.get (for force flow - raises NotFoundError by default) - import fireworks - - mock_client.evaluators.get.side_effect = fireworks.NotFoundError( - "Evaluator not found", - response=MagicMock(status_code=404), - body={"error": "not found"}, - ) - - # Mock evaluators.delete - mock_client.evaluators.delete.return_value = None - yield mock_client @@ -213,7 +201,6 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow: id="test-simple-eval", # Explicit ID display_name="Simple Word Count Eval", description="E2E test evaluator", - force=False, yes=True, # Non-interactive ) @@ -326,7 +313,6 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow: id="test-param-eval", display_name="Parametrized Eval", description="Test parametrized evaluator", - force=False, yes=True, ) @@ -505,7 +491,6 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: id=None, # Auto-generate from test name display_name=None, # Auto-generate description=None, # Auto-generate - force=False, yes=True, ) @@ -564,95 +549,3 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: if test_project_dir in sys.path: sys.path.remove(test_project_dir) shutil.rmtree(test_project_dir, ignore_errors=True) - - -def test_ep_upload_force_flag_triggers_delete_flow( - mock_env_variables, - mock_gcs_upload, - mock_platform_api_client, -): - """ - Test that --force flag triggers the check/delete/recreate flow - """ - from eval_protocol.cli_commands.upload import upload_command, _discover_tests - - test_content = """ -from eval_protocol.pytest import evaluation_test -from eval_protocol.models import EvaluationRow - -@evaluation_test(input_rows=[[EvaluationRow()]]) -async def test_force_eval(row: EvaluationRow) -> EvaluationRow: - return row -""" - - test_project_dir, test_file_path = create_test_project_with_evaluation_test(test_content, "test_force.py") - - original_cwd = os.getcwd() - - try: - os.chdir(test_project_dir) - - # Mock the Fireworks client with evaluator existing (for force flow) - with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class: - mock_client = MagicMock() - mock_fw_class.return_value = mock_client - - # Mock evaluators.get to return an existing evaluator (not raise NotFoundError) - mock_existing_evaluator = MagicMock() - mock_existing_evaluator.name = "accounts/test_account/evaluators/test-force" - mock_client.evaluators.get.return_value = mock_existing_evaluator - - # Mock evaluators.delete - mock_client.evaluators.delete.return_value = None - - # Mock evaluators.create response - mock_create_response = MagicMock() - mock_create_response.name = "accounts/test_account/evaluators/test-force" - mock_client.evaluators.create.return_value = mock_create_response - - # Mock get_upload_endpoint - def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): - response = MagicMock() - signed_urls = {} - for filename in filename_to_size.keys(): - signed_urls[filename] = f"https://storage.googleapis.com/test-bucket/{filename}?signed=true" - response.filename_to_signed_urls = signed_urls - return response - - mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect - - # Mock validate_upload - mock_client.evaluators.validate_upload.return_value = MagicMock() - - discovered_tests = _discover_tests(test_project_dir) - - args = 
argparse.Namespace(
-            path=test_project_dir,
-            entry=None,
-            id="test-force",
-            display_name=None,
-            description=None,
-            force=True,  # Force flag enabled
-            yes=True,
-        )
-
-        with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
-            mock_select.return_value = discovered_tests
-            exit_code = upload_command(args)
-
-        assert exit_code == 0
-
-        # Verify check happened (evaluators.get was called)
-        assert mock_client.evaluators.get.called, "Should check if evaluator exists"
-
-        # Verify delete happened (since evaluator existed)
-        assert mock_client.evaluators.delete.called, "Should delete existing evaluator"
-
-        # Verify create happened after delete
-        assert mock_client.evaluators.create.called, "Should create evaluator after delete"
-
-    finally:
-        os.chdir(original_cwd)
-        if test_project_dir in sys.path:
-            sys.path.remove(test_project_dir)
-        shutil.rmtree(test_project_dir, ignore_errors=True)

From 532e071d5204d3c2ec2696e17bad9afdf642bdf0 Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Thu, 8 Jan 2026 15:22:21 -0800
Subject: [PATCH 07/39] REVERT this later; update to the latest SDK once it is
 published with these changes

---
 pyproject.toml |  2 +-
 uv.lock        | 19 +++++++++++++++----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e5caa497..80e52b77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
     "pytest-asyncio>=0.21.0",
     "peewee>=3.18.2",
     "backoff>=2.2.0",
-    "fireworks-ai==1.0.0a20",
+    "fireworks-ai @ https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl",
     "questionary>=2.0.0",
     # Dependencies for vendored tau2 package
     "toml>=0.10.0",
diff --git a/uv.lock b/uv.lock
index c175b81f..b8b23a7d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1312,7 +1312,7 @@ requires-dist = [
     { name = "dspy", marker = "extra == 'dspy'", specifier = ">=3.0.0" },
     { name = "e2b", marker = "extra == 'dev'" },
     { name = "fastapi", specifier = ">=0.116.1" },
-    { name = "fireworks-ai", specifier = "==1.0.0a20" },
+    { name = "fireworks-ai", url = "https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl" },
     { name = "google-auth", marker = "extra == 'bigquery'", specifier = ">=2.0.0" },
     { name = "google-cloud-bigquery", marker = "extra == 'bigquery'", specifier = ">=3.0.0" },
     { name = "gymnasium", marker = "extra == 'dev'", specifier = ">=1.2.0" },
@@ -1583,7 +1583,7 @@ wheels = [
 
 [[package]]
 name = "fireworks-ai"
 version = "1.0.0a20"
-source = { registry = "https://pypi.org/simple" }
+source = { url = "https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl" }
 dependencies = [
     { name = "aiohttp" },
     { name = "anyio" },
     { name = "distro" },
     { name = "httpx" },
     { name = "httpx-aiohttp" },
     { name = "pydantic" },
     { name = "sniffio" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1d/c6/cdc6c152876ee1253491e6f72c65c2cdaf7b22b320be0cec7ac5778d3b1c/fireworks_ai-1.0.0a20.tar.gz", hash = "sha256:c84f702445679ea768461dba8fb027175b82255021832a89f9ece65821a2ab25", size = 564097, upload-time = "2025-12-23T19:21:17.891Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/c5/a4/e2bc9c4af291786bc7fe364ae63503ba2c8161c2e71223d570a77f0a1415/fireworks_ai-1.0.0a20-py3-none-any.whl", hash = "sha256:b5e199978f71b564b2e19cf55a71c1ac20906d9a7b4ae75135fdccb245227722", size = 304153, upload-time = "2025-12-23T19:21:15.943Z" },
+    { url = 
"https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl", hash = "sha256:d0fb6d84bc93d161276be6b8f134d77e0cbc7f12f3477482485fa4bfc1491d5a" }, +] + +[package.metadata] +requires-dist = [ + { name = "aiohttp" }, + { name = "anyio", specifier = ">=3.5.0,<5" }, + { name = "distro", specifier = ">=1.7.0,<2" }, + { name = "httpx", specifier = ">=0.23.0,<1" }, + { name = "httpx-aiohttp", specifier = ">=0.1.9" }, + { name = "pydantic", specifier = ">=1.9.0,<3" }, + { name = "sniffio" }, + { name = "typing-extensions", specifier = ">=4.10,<5" }, ] [[package]] From 060d72c2a778994f1ea7ce506f56cbd001b6ff84 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Fri, 9 Jan 2026 10:24:48 -0800 Subject: [PATCH 08/39] fix mock tests --- tests/test_cli_create_rft.py | 9 +- tests/test_ep_upload_e2e.py | 176 ++++++++++------------------------- tests/test_evaluation.py | 24 +++-- 3 files changed, 75 insertions(+), 134 deletions(-) diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 9536ef30..7b989028 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -24,7 +24,7 @@ def _write_json(path: str, data: dict) -> None: def stub_fireworks(monkeypatch) -> dict[str, Any]: """ Stub Fireworks SDK so tests stay offline and so create_rft.py can inspect a stable - create() signature (it uses inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)). + create() signature (it uses inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create)). Returns: A dict containing the last captured create() kwargs under key "kwargs". @@ -72,12 +72,15 @@ def create( return SimpleNamespace(name=f"accounts/{account_id}/reinforcementFineTuningJobs/xyz") class _FakeFW: - def __init__(self, api_key=None, base_url=None): + def __init__(self, api_key=None, base_url=None, account_id=None, default_headers=None): self.api_key = api_key self.base_url = base_url + self.account_id = account_id + self.default_headers = default_headers self.reinforcement_fine_tuning_jobs = _FakeJobs() - monkeypatch.setattr(cr, "Fireworks", _FakeFW) + # Patch create_fireworks_client to return our fake client + monkeypatch.setattr(cr, "create_fireworks_client", lambda **kwargs: _FakeFW(**kwargs)) return captured diff --git a/tests/test_ep_upload_e2e.py b/tests/test_ep_upload_e2e.py index 005dac60..e76ac246 100644 --- a/tests/test_ep_upload_e2e.py +++ b/tests/test_ep_upload_e2e.py @@ -80,8 +80,8 @@ def mock_gcs_upload(): @pytest.fixture def mock_fireworks_client(): - """Mock the Fireworks SDK client used in evaluation.py""" - with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class: + """Mock the Fireworks SDK client used in fireworks_client.py""" + with patch("eval_protocol.fireworks_client.Fireworks") as mock_fw_class: mock_client = MagicMock() mock_fw_class.return_value = mock_client @@ -92,8 +92,13 @@ def mock_fireworks_client(): mock_create_response.description = "Test description" mock_client.evaluators.create.return_value = mock_create_response - # Mock evaluators.get_upload_endpoint response - will be set dynamically - def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): + # Mock evaluator_versions.create response + mock_version_response = MagicMock() + mock_version_response.name = "accounts/test_account/evaluators/test-eval/versions/v1" + mock_client.evaluator_versions.create.return_value = mock_version_response + + # Mock evaluator_versions.get_upload_endpoint response - will be set 
dynamically + def get_upload_endpoint_side_effect(evaluator_id, version_id, filename_to_size): response = MagicMock() signed_urls = {} for filename in filename_to_size.keys(): @@ -101,23 +106,13 @@ def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): response.filename_to_signed_urls = signed_urls return response - mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect + mock_client.evaluator_versions.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect - # Mock evaluators.validate_upload response + # Mock evaluator_versions.validate_upload response mock_validate_response = MagicMock() mock_validate_response.success = True mock_validate_response.valid = True - mock_client.evaluators.validate_upload.return_value = mock_validate_response - - yield mock_client - - -@pytest.fixture -def mock_platform_api_client(): - """Mock the Fireworks SDK client used in platform_api.py for secrets""" - with patch("eval_protocol.platform_api.Fireworks") as mock_fw_class: - mock_client = MagicMock() - mock_fw_class.return_value = mock_client + mock_client.evaluator_versions.validate_upload.return_value = mock_validate_response # Mock secrets.get - raise NotFoundError to simulate secret doesn't exist from fireworks import NotFoundError @@ -129,13 +124,23 @@ def mock_platform_api_client(): ) # Mock secrets.create - successful - mock_create_response = MagicMock() - mock_create_response.name = "accounts/test_account/secrets/test-secret" - mock_client.secrets.create.return_value = mock_create_response + mock_secrets_create_response = MagicMock() + mock_secrets_create_response.name = "accounts/test_account/secrets/test-secret" + mock_client.secrets.create.return_value = mock_secrets_create_response yield mock_client +@pytest.fixture +def mock_platform_api_client(mock_fireworks_client): + """ + Mock the Fireworks SDK client for secrets. + This is now just an alias for mock_fireworks_client since both use the same patched location. + The mock_fireworks_client fixture already includes secrets mocking. 
+ """ + yield mock_fireworks_client + + def test_ep_upload_discovers_and_uploads_evaluation_test( mock_env_variables, mock_fireworks_client, mock_platform_api_client, mock_gcs_upload, monkeypatch ): @@ -219,13 +224,18 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow: # Step 1: Create evaluator assert mock_fireworks_client.evaluators.create.called, "Should call evaluators.create" - # Step 2: Get upload endpoint - assert mock_fireworks_client.evaluators.get_upload_endpoint.called, ( - "Should call evaluators.get_upload_endpoint" + # Step 1b: Create evaluator version + assert mock_fireworks_client.evaluator_versions.create.called, "Should call evaluator_versions.create" + + # Step 2: Get upload endpoint (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, ( + "Should call evaluator_versions.get_upload_endpoint" ) - # Step 3: Validate upload - assert mock_fireworks_client.evaluators.validate_upload.called, "Should call evaluators.validate_upload" + # Step 3: Validate upload (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.validate_upload.called, ( + "Should call evaluator_versions.validate_upload" + ) # Step 4: GCS upload assert mock_gcs_upload.send.called, "Should upload tar.gz to GCS" @@ -325,8 +335,9 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow: # Verify upload flow completed via Fireworks SDK assert mock_fireworks_client.evaluators.create.called - assert mock_fireworks_client.evaluators.get_upload_endpoint.called - assert mock_fireworks_client.evaluators.validate_upload.called + assert mock_fireworks_client.evaluator_versions.create.called + assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called + assert mock_fireworks_client.evaluator_versions.validate_upload.called assert mock_gcs_upload.send.called finally: @@ -505,8 +516,13 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: # Step 1: Create evaluator assert mock_fireworks_client.evaluators.create.called, "Missing create call" - # Step 2: Get upload endpoint - assert mock_fireworks_client.evaluators.get_upload_endpoint.called, "Missing getUploadEndpoint call" + # Step 1b: Create evaluator version + assert mock_fireworks_client.evaluator_versions.create.called, "Missing evaluator_versions.create call" + + # Step 2: Get upload endpoint (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, ( + "Missing evaluator_versions.get_upload_endpoint call" + ) # Step 3: Upload to GCS assert mock_gcs_upload.send.called, "Missing GCS upload" @@ -514,8 +530,10 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: assert gcs_request.method == "PUT" assert "storage.googleapis.com" in gcs_request.url - # Step 4: Validate - assert mock_fireworks_client.evaluators.validate_upload.called, "Missing validateUpload call" + # Step 4: Validate (via evaluator_versions API) + assert mock_fireworks_client.evaluator_versions.validate_upload.called, ( + "Missing evaluator_versions.validate_upload call" + ) # 4. VERIFY PAYLOAD DETAILS create_call = mock_fireworks_client.evaluators.create.call_args @@ -532,8 +550,8 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow: assert "test_math_eval.py::test_math_correctness" in entry_point # 5. 
VERIFY TAR.GZ WAS CREATED AND UPLOADED - # Check getUploadEndpoint call payload - upload_call = mock_fireworks_client.evaluators.get_upload_endpoint.call_args + # Check getUploadEndpoint call payload (via evaluator_versions API) + upload_call = mock_fireworks_client.evaluator_versions.get_upload_endpoint.call_args assert upload_call is not None filename_to_size = upload_call.kwargs.get("filename_to_size", {}) assert filename_to_size, "Should have filename_to_size" @@ -582,95 +600,3 @@ def test_create_tar_includes_dockerignored_files(tmp_path): for expected_path in expected_paths: assert expected_path in names, f"Expected {expected_path} in archive" - - -def test_ep_upload_force_flag_triggers_delete_flow( - mock_env_variables, - mock_gcs_upload, - mock_platform_api_client, -): - """ - Test that --force flag triggers the check/delete/recreate flow - """ - from eval_protocol.cli_commands.upload import upload_command, _discover_tests - - test_content = """ -from eval_protocol.pytest import evaluation_test -from eval_protocol.models import EvaluationRow - -@evaluation_test(input_rows=[[EvaluationRow()]]) -async def test_force_eval(row: EvaluationRow) -> EvaluationRow: - return row -""" - - test_project_dir, test_file_path = create_test_project_with_evaluation_test(test_content, "test_force.py") - - original_cwd = os.getcwd() - - try: - os.chdir(test_project_dir) - - # Mock the Fireworks client with evaluator existing (for force flow) - with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class: - mock_client = MagicMock() - mock_fw_class.return_value = mock_client - - # Mock evaluators.get to return an existing evaluator (not raise NotFoundError) - mock_existing_evaluator = MagicMock() - mock_existing_evaluator.name = "accounts/test_account/evaluators/test-force" - mock_client.evaluators.get.return_value = mock_existing_evaluator - - # Mock evaluators.delete - mock_client.evaluators.delete.return_value = None - - # Mock evaluators.create response - mock_create_response = MagicMock() - mock_create_response.name = "accounts/test_account/evaluators/test-force" - mock_client.evaluators.create.return_value = mock_create_response - - # Mock get_upload_endpoint - def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): - response = MagicMock() - signed_urls = {} - for filename in filename_to_size.keys(): - signed_urls[filename] = f"https://storage.googleapis.com/test-bucket/{filename}?signed=true" - response.filename_to_signed_urls = signed_urls - return response - - mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect - - # Mock validate_upload - mock_client.evaluators.validate_upload.return_value = MagicMock() - - discovered_tests = _discover_tests(test_project_dir) - - args = argparse.Namespace( - path=test_project_dir, - entry=None, - id="test-force", - display_name=None, - description=None, - force=True, # Force flag enabled - yes=True, - ) - - with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select: - mock_select.return_value = discovered_tests - exit_code = upload_command(args) - - assert exit_code == 0 - - # Verify check happened (evaluators.get was called) - assert mock_client.evaluators.get.called, "Should check if evaluator exists" - - # Verify delete happened (since evaluator existed) - assert mock_client.evaluators.delete.called, "Should delete existing evaluator" - - # Verify create happened after delete - assert mock_client.evaluators.create.called, "Should create evaluator after delete" - - finally: - 
os.chdir(original_cwd) - if test_project_dir in sys.path: - sys.path.remove(test_project_dir) - shutil.rmtree(test_project_dir, ignore_errors=True) diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 942c1962..1dad3b19 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -41,6 +41,7 @@ def test_create_evaluation_helper(monkeypatch): # Track SDK calls create_called = False + version_create_called = False upload_endpoint_called = False validate_called = False @@ -61,7 +62,16 @@ def mock_create(evaluator_id, evaluator): assert evaluator["description"] == "Test description" return mock_evaluator_result - def mock_get_upload_endpoint(evaluator_id, filename_to_size): + # Mock evaluator_versions.create + mock_version_result = MagicMock() + mock_version_result.name = "accounts/test_account/evaluators/test-eval/versions/v1" + + def mock_version_create(evaluator_id, evaluator_version): + nonlocal version_create_called + version_create_called = True + return mock_version_result + + def mock_get_upload_endpoint(evaluator_id, version_id, filename_to_size): nonlocal upload_endpoint_called upload_endpoint_called = True mock_response = MagicMock() @@ -71,7 +81,7 @@ def mock_get_upload_endpoint(evaluator_id, filename_to_size): mock_response.filename_to_signed_urls = signed_urls return mock_response - def mock_validate_upload(evaluator_id, body): + def mock_validate_upload(evaluator_id, version_id): nonlocal validate_called validate_called = True return MagicMock() @@ -83,13 +93,14 @@ def mock_validate_upload(evaluator_id, body): mock_gcs_response.raise_for_status = MagicMock() mock_session.send.return_value = mock_gcs_response - # Patch the Fireworks client - with patch("eval_protocol.evaluation.Fireworks") as mock_fireworks_class: + # Patch the Fireworks client at the location where it's imported + with patch("eval_protocol.fireworks_client.Fireworks") as mock_fireworks_class: mock_client = MagicMock() mock_fireworks_class.return_value = mock_client mock_client.evaluators.create = mock_create - mock_client.evaluators.get_upload_endpoint = mock_get_upload_endpoint - mock_client.evaluators.validate_upload = mock_validate_upload + mock_client.evaluator_versions.create = mock_version_create + mock_client.evaluator_versions.get_upload_endpoint = mock_get_upload_endpoint + mock_client.evaluator_versions.validate_upload = mock_validate_upload # Patch requests.Session for GCS upload monkeypatch.setattr("requests.Session", lambda: mock_session) @@ -109,6 +120,7 @@ def mock_validate_upload(evaluator_id, body): # Verify full upload flow was executed assert create_called, "Create endpoint should be called" + assert version_create_called, "Version create should be called" assert upload_endpoint_called, "GetUploadEndpoint should be called" assert validate_called, "ValidateUpload should be called" assert mock_session.send.called, "GCS upload should happen" From bc31c9f5dc0ea67c292e902f5824dce57cd3f33d Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Fri, 9 Jan 2026 14:32:32 -0800 Subject: [PATCH 09/39] Add error handling for evaluator creation in evaluation module - Implemented a try-except block to handle APIStatusError during evaluator creation. - Added logic to check for existing evaluators and retrieve the existing one if a conflict occurs (status code 409). - Enhanced logging for better traceability of evaluator creation process. 
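In isolation, the create-or-fallback flow looks like the sketch below (the helper name create_or_get_evaluator is illustrative; the SDK calls and the 409 check mirror the diff that follows):

    import logging

    import fireworks

    logger = logging.getLogger(__name__)


    def create_or_get_evaluator(client, evaluator_id, evaluator_params):
        """Create the evaluator; on a 409 conflict, reuse the existing one."""
        try:
            result = client.evaluators.create(
                evaluator_id=evaluator_id,
                evaluator=evaluator_params,
            )
            logger.info(f"Successfully created evaluator '{evaluator_id}'")
            return result
        except fireworks.APIStatusError as create_error:
            if create_error.status_code != 409:
                raise
            # Already exists: fetch it so a new version can be uploaded.
            logger.info(f"Evaluator '{evaluator_id}' already exists, creating new version...")
            return client.evaluators.get(evaluator_id=evaluator_id)

Treating 409 as "already exists" keeps evaluator creation idempotent, so repeated uploads simply add new versions instead of failing.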
--- eval_protocol/evaluation.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 8c047309..ee98fc1e 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -202,12 +202,20 @@ def create(self, evaluator_id, display_name=None, description=None): logger.info(f"Creating evaluator '{evaluator_id}' for account '{account_id}'...") try: - # Create evaluator using SDK - result = client.evaluators.create( - evaluator_id=evaluator_id, - evaluator=evaluator_params, - ) - logger.info(f"Successfully created evaluator '{evaluator_id}'") + # Try to create evaluator using SDK + try: + result = client.evaluators.create( + evaluator_id=evaluator_id, + evaluator=evaluator_params, + ) + logger.info(f"Successfully created evaluator '{evaluator_id}'") + except fireworks.APIStatusError as create_error: + if create_error.status_code == 409: + # Evaluator already exists, get the existing one and proceed to create a new version + logger.info(f"Evaluator '{evaluator_id}' already exists, creating new version...") + result = client.evaluators.get(evaluator_id=evaluator_id) + else: + raise # Upload code as tar.gz to GCS evaluator_name = result.name # e.g., "accounts/pyroworks/evaluators/test-123" From ea08062cd0f4a1af3cb3197e5c20fea9590da17e Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Fri, 9 Jan 2026 15:42:07 -0800 Subject: [PATCH 10/39] Support EP_REMOTE_API_KEY --- eval_protocol/adapters/fireworks_tracing.py | 11 +++++++++-- eval_protocol/pytest/tracing_utils.py | 8 ++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 8e5c7d15..4913e33b 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -253,6 +253,7 @@ def __init__( project_id: Optional[str] = None, base_url: str = "https://tracing.fireworks.ai", timeout: int = 300, + api_key: Optional[str] = None, ): """Initialize the Fireworks Tracing adapter. @@ -260,10 +261,16 @@ def __init__( project_id: Optional project ID. If not provided, uses the default project configured on the server. base_url: The base URL of the tracing proxy (default: https://tracing.fireworks.ai) timeout: Request timeout in seconds (default: 300) + api_key: Optional API key. If not provided, falls back to FIREWORKS_API_KEY environment variable. """ self.project_id = project_id self.base_url = base_url.rstrip("/") self.timeout = timeout + self._api_key = api_key + + def _get_api_key(self) -> Optional[str]: + """Get the API key, preferring instance-level key over environment variable.""" + return self._api_key or os.environ.get("FIREWORKS_API_KEY") def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -> List[Dict[str, Any]]: """Fetch logs from Fireworks tracing gateway /logs endpoint. 
@@ -276,7 +283,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) - from ..common_utils import get_user_agent headers = { - "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "Authorization": f"Bearer {self._get_api_key()}", "User-Agent": get_user_agent(), } params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"} @@ -407,7 +414,7 @@ def get_evaluation_rows( from ..common_utils import get_user_agent headers = { - "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "Authorization": f"Bearer {self._get_api_key()}", "User-Agent": get_user_agent(), } diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 158fcbb4..7d6b1714 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -19,7 +19,9 @@ def default_fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDat def fetch_traces() -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) + # Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY + api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY") + adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key) return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -131,7 +133,9 @@ def build_init_request( final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url) # Extract API key from environment or completion_params - api_key = os.environ.get("FIREWORKS_API_KEY") + # EP_REMOTE_API_KEY takes precedence for remote rollout processors, + # falling back to FIREWORKS_API_KEY for backwards compatibility + api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY") return InitRequest( completion_params=completion_params_dict, From 6b53ac1a5f45a0c4691c79aa9d31076f7361ec0e Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 12 Jan 2026 10:28:42 -0800 Subject: [PATCH 11/39] include launch.json.backup --- .vscode/.gitignore | 1 + .vscode/launch.json.backup | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 .vscode/.gitignore create mode 100644 .vscode/launch.json.backup diff --git a/.vscode/.gitignore b/.vscode/.gitignore new file mode 100644 index 00000000..c2dd2a37 --- /dev/null +++ b/.vscode/.gitignore @@ -0,0 +1 @@ +!launch.json.backup diff --git a/.vscode/launch.json.backup b/.vscode/launch.json.backup new file mode 100644 index 00000000..8088d350 --- /dev/null +++ b/.vscode/launch.json.backup @@ -0,0 +1,37 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "EP: Upload", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["upload"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + }, + { + "name": "EP: Local Test", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["local-test"], + 
"console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + } + ] +} From ec0c8ca8a8fb6e61a32b65ae038336945305f05c Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 12 Jan 2026 15:48:56 -0800 Subject: [PATCH 12/39] rename to .example and add docker run extra arg --- .vscode/{launch.json.backup => launch.json.example} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename .vscode/{launch.json.backup => launch.json.example} (94%) diff --git a/.vscode/launch.json.backup b/.vscode/launch.json.example similarity index 94% rename from .vscode/launch.json.backup rename to .vscode/launch.json.example index 8088d350..74014b37 100644 --- a/.vscode/launch.json.backup +++ b/.vscode/launch.json.example @@ -22,7 +22,7 @@ "type": "python", "request": "launch", "module": "eval_protocol.cli", - "args": ["local-test"], + "args": ["local-test", "--docker-run-extra", "--env-file .env"], "console": "integratedTerminal", "justMyCode": false, "cwd": "", From fc036f5ea5cf6eeebe9e170f107ba790196c559c Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 12 Jan 2026 15:59:22 -0800 Subject: [PATCH 13/39] use ignore-docker by default --- .vscode/launch.json.backup | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 .vscode/launch.json.backup diff --git a/.vscode/launch.json.backup b/.vscode/launch.json.backup new file mode 100644 index 00000000..df851aa3 --- /dev/null +++ b/.vscode/launch.json.backup @@ -0,0 +1,37 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "EP: Upload", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["upload"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + }, + { + "name": "EP: Local Test", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["local-test", "--ignore-docker"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + } + ] +} From 45665846e505479404441c77d04b7121249e2060 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 12 Jan 2026 16:04:51 -0800 Subject: [PATCH 14/39] delete backup --- .vscode/launch.json.backup | 37 ------------------------------------- 1 file changed, 37 deletions(-) delete mode 100644 .vscode/launch.json.backup diff --git a/.vscode/launch.json.backup b/.vscode/launch.json.backup deleted file mode 100644 index df851aa3..00000000 --- a/.vscode/launch.json.backup +++ /dev/null @@ -1,37 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "EP: Upload", - "type": "python", - "request": "launch", - "module": "eval_protocol.cli", - "args": ["upload"], - "console": 
"integratedTerminal", - "justMyCode": false, - "cwd": "", - "env": { - "PYTHONPATH": "${workspaceFolder}", - "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", - "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", - "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" - } - }, - { - "name": "EP: Local Test", - "type": "python", - "request": "launch", - "module": "eval_protocol.cli", - "args": ["local-test", "--ignore-docker"], - "console": "integratedTerminal", - "justMyCode": false, - "cwd": "", - "env": { - "PYTHONPATH": "${workspaceFolder}", - "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", - "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", - "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" - } - } - ] -} From f103b693ab6bd3a1b2339bb5af388729ab9884c2 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 12 Jan 2026 16:05:07 -0800 Subject: [PATCH 15/39] ignore-docker by default in dev --- .vscode/launch.json.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/launch.json.example b/.vscode/launch.json.example index 74014b37..df851aa3 100644 --- a/.vscode/launch.json.example +++ b/.vscode/launch.json.example @@ -22,7 +22,7 @@ "type": "python", "request": "launch", "module": "eval_protocol.cli", - "args": ["local-test", "--docker-run-extra", "--env-file .env"], + "args": ["local-test", "--ignore-docker"], "console": "integratedTerminal", "justMyCode": false, "cwd": "", From 9c3e4175246d7dc0cbeb3040d52520fc22ba5bf7 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 12 Jan 2026 17:14:03 -0800 Subject: [PATCH 16/39] Refactor evaluator function calls to use Fireworks directly for method signature introspection, avoiding unnecessary API requests during help invocations. --- eval_protocol/cli.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 40dce34c..51b5d4dd 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -8,9 +8,10 @@ import sys from pathlib import Path +from fireworks import Fireworks + from .cli_commands.common import setup_logging from .cli_commands.utils import add_args_from_callable_signature -from .fireworks_client import create_fireworks_client logger = logging.getLogger(__name__) @@ -82,7 +83,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse ) # Auto-generate flags from SDK Fireworks().evaluators.create() signature - create_evaluator_fn = create_fireworks_client().evaluators.create + # Note: We use Fireworks() directly here instead of create_fireworks_client() + # because we only need the method signature for introspection, not a fully + # authenticated client. create_fireworks_client() would trigger an HTTP request + # to verify the API key, causing delays even for --help invocations. + create_evaluator_fn = Fireworks().evaluators.create upload_skip_fields = { "__top_level__": { @@ -191,7 +196,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.", } - create_rft_job_fn = create_fireworks_client().reinforcement_fine_tuning_jobs.create + # Note: We use Fireworks() directly here instead of create_fireworks_client() + # because we only need the method signature for introspection, not a fully + # authenticated client. 
create_fireworks_client() would trigger an HTTP request + # to verify the API key, causing delays even for --help invocations. + create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create add_args_from_callable_signature( rft_parser, From ea673f45e28a6cb248fc427cefbce5bb55dfc577 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 12:03:41 -0800 Subject: [PATCH 17/39] use in-flight SDK version --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 80e52b77..4b2a92b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "pytest-asyncio>=0.21.0", "peewee>=3.18.2", "backoff>=2.2.0", - "fireworks-ai @ https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl", + "fireworks-ai @ https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl", "questionary>=2.0.0", # Dependencies for vendored tau2 package "toml>=0.10.0", diff --git a/uv.lock b/uv.lock index b8b23a7d..c5df7b2f 100644 --- a/uv.lock +++ b/uv.lock @@ -1312,7 +1312,7 @@ requires-dist = [ { name = "dspy", marker = "extra == 'dspy'", specifier = ">=3.0.0" }, { name = "e2b", marker = "extra == 'dev'" }, { name = "fastapi", specifier = ">=0.116.1" }, - { name = "fireworks-ai", url = "https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl" }, + { name = "fireworks-ai", url = "https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl" }, { name = "google-auth", marker = "extra == 'bigquery'", specifier = ">=2.0.0" }, { name = "google-cloud-bigquery", marker = "extra == 'bigquery'", specifier = ">=3.0.0" }, { name = "gymnasium", marker = "extra == 'dev'", specifier = ">=1.2.0" }, @@ -1582,8 +1582,8 @@ wheels = [ [[package]] name = "fireworks-ai" -version = "1.0.0a20" -source = { url = "https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl" } +version = "1.0.0a21" +source = { url = "https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl" } dependencies = [ { name = "aiohttp" }, { name = "anyio" }, @@ -1595,7 +1595,7 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://pkg.stainless.com/s/fireworks-ai-python/edd49b0b378db786d6e6b043839baa9aeb6cf0c0/fireworks_ai-1.0.0a20-py3-none-any.whl", hash = "sha256:d0fb6d84bc93d161276be6b8f134d77e0cbc7f12f3477482485fa4bfc1491d5a" }, + { url = "https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl", hash = "sha256:882c45957fa4a5be55680b9a8972381ec83875834f097ab743981716819aecb7" }, ] [package.metadata] From 26fbc2de81a971bf328cd29606f625f9ea7a6fb7 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 13:09:53 -0800 Subject: [PATCH 18/39] Enhance evaluator handling by returning version ID on creation and updating polling functions to target specific evaluator versions. Refactor related CLI commands and tests to accommodate these changes, ensuring clearer status messages and improved error handling. 
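Condensed, the new polling helper amounts to the following sketch (the function name and defaults are illustrative; the evaluator_versions.get call and the state values match the diff that follows):

    import time


    def wait_for_version_active(client, evaluator_id, version_id,
                                timeout_minutes=10, poll_interval=10):
        """Poll one evaluator version until ACTIVE, BUILD_FAILED, or timeout."""
        deadline = time.time() + timeout_minutes * 60
        while time.time() < deadline:
            version = client.evaluator_versions.get(version_id, evaluator_id=evaluator_id)
            state = version.state or "STATE_UNSPECIFIED"
            if state == "ACTIVE":
                return True
            if state == "BUILD_FAILED":
                return False
            time.sleep(poll_interval)  # still BUILDING (or unknown); keep waiting
        return False

Polling a specific version rather than the evaluator resource avoids racing against an older ACTIVE version when new code is uploaded.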
--- eval_protocol/cli_commands/create_rft.py | 155 ++++++++++------------- eval_protocol/cli_commands/upload.py | 4 +- eval_protocol/evaluation.py | 8 +- tests/test_cli_create_rft.py | 54 ++------ tests/test_evaluation.py | 5 +- 5 files changed, 90 insertions(+), 136 deletions(-) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index fc4d20b4..6a6123f1 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -7,20 +7,18 @@ import time from typing import Any, Callable, Dict, Optional import inspect -import requests import tempfile from pydantic import ValidationError from ..auth import get_fireworks_api_base, get_fireworks_api_key from ..fireworks_client import create_fireworks_client -from ..common_utils import get_user_agent, load_jsonl +from ..common_utils import load_jsonl from ..fireworks_rft import ( create_dataset_from_jsonl, detect_dataset_builder, materialize_dataset_via_builder, ) from ..models import EvaluationRow -from .upload import upload_command from .utils import ( _build_entry_point, _build_trimmed_dataset_id, @@ -222,64 +220,68 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) return None -def _poll_evaluator_status( - evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10 +def _poll_evaluator_version_status( + evaluator_id: str, + version_id: str, + api_key: str, + api_base: str, + timeout_minutes: int = 10, ) -> bool: """ - Poll evaluator status until it becomes ACTIVE or times out. + Poll a specific evaluator version status until it becomes ACTIVE or times out. + + Uses the Fireworks SDK to get the specified version of the evaluator and checks + its build state. Args: - evaluator_resource_name: Full evaluator resource name (e.g., accounts/xxx/evaluators/yyy) + evaluator_id: The evaluator ID (not full resource name) + version_id: The specific version ID to poll api_key: Fireworks API key api_base: Fireworks API base URL timeout_minutes: Maximum time to wait in minutes Returns: - True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED + True if evaluator version becomes ACTIVE, False if timeout or BUILD_FAILED """ - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - - check_url = f"{api_base}/v1/{evaluator_resource_name}" timeout_seconds = timeout_minutes * 60 poll_interval = 10 # seconds start_time = time.time() - print(f"Polling evaluator status (timeout: {timeout_minutes}m, interval: {poll_interval}s)...") + print( + f"Polling evaluator version '{version_id}' status (timeout: {timeout_minutes}m, interval: {poll_interval}s)..." + ) + + client = create_fireworks_client(api_key=api_key, base_url=api_base) while time.time() - start_time < timeout_seconds: try: - response = requests.get(check_url, headers=headers, timeout=30) - response.raise_for_status() - - evaluator_data = response.json() - state = evaluator_data.get("state", "STATE_UNSPECIFIED") - status = evaluator_data.get("status", "") + version = client.evaluator_versions.get(version_id, evaluator_id=evaluator_id) + state = version.state or "STATE_UNSPECIFIED" + status_msg = "" + if version.status and version.status.message: + status_msg = version.status.message if state == "ACTIVE": - print("✅ Evaluator is ACTIVE and ready!") + print("✅ Evaluator version is ACTIVE and ready!") return True elif state == "BUILD_FAILED": - print(f"❌ Evaluator build failed. 
Status: {status}") + print(f"❌ Evaluator version build failed. Status: {status_msg}") return False elif state == "BUILDING": elapsed_minutes = (time.time() - start_time) / 60 - print(f"⏳ Evaluator is still building... ({elapsed_minutes:.1f}m elapsed)") + print(f"⏳ Evaluator version is still building... ({elapsed_minutes:.1f}m elapsed)") else: - print(f"⏳ Evaluator state: {state}, status: {status}") + print(f"⏳ Evaluator version state: {state}, status: {status_msg}") - except requests.exceptions.RequestException as e: - print(f"Warning: Failed to check evaluator status: {e}") + except Exception as e: + print(f"Warning: Failed to check evaluator version status: {e}") # Wait before next poll time.sleep(poll_interval) # Timeout reached elapsed_minutes = (time.time() - start_time) / 60 - print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator is not yet ACTIVE") + print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator version is not yet ACTIVE") return False @@ -564,40 +566,16 @@ def _upload_dataset( def _upload_and_ensure_evaluator( project_root: str, evaluator_id: str, - evaluator_resource_name: str, api_key: str, api_base: str, ) -> bool: - """Ensure the evaluator exists and is ACTIVE, uploading it if needed.""" - # Check if evaluator already exists - try: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) - if resp.ok: - state = resp.json().get("state", "STATE_UNSPECIFIED") - print(f"✓ Evaluator exists (state: {state}). Skipping upload.") - # Poll for ACTIVE before proceeding - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - if not _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ): - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\n❌ Evaluator is not ready within the timeout period.") - print(f"📊 Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - except requests.exceptions.RequestException: - pass + """Upload evaluator and ensure its version becomes ACTIVE. + + Creates/updates the evaluator and uploads the code, then polls the specific + version until it becomes ACTIVE. 
+ """ + from eval_protocol.evaluation import create_evaluation - # Ensure evaluator exists by invoking the upload flow programmatically try: tests = _discover_tests(project_root) selected_entry: Optional[str] = None @@ -614,39 +592,37 @@ def _upload_and_ensure_evaluator( ) return False - upload_args = argparse.Namespace( - path=project_root, - entry=selected_entry, - id=evaluator_id, - display_name=None, - description=None, - yes=True, - env_file=None, + print(f"\nUploading evaluator '{evaluator_id}'...") + result, version_id = create_evaluation( + evaluator_id=evaluator_id, + display_name=evaluator_id, + description=f"Evaluator for {evaluator_id}", + entry_point=selected_entry, ) - rc = upload_command(upload_args) - if rc == 0: - print(f"✓ Uploaded/ensured evaluator: {evaluator_id}") - - # Poll for evaluator status - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - is_active = _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ) + if not version_id: + print("Warning: Evaluator created but version upload failed.") + return False - if not is_active: - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\n❌ Evaluator is not ready within the timeout period.") - print(f"📊 Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - else: - print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") + print(f"✓ Uploaded evaluator: {evaluator_id} (version: {version_id})") + + # Poll for the specific evaluator version status + print(f"Waiting for evaluator '{evaluator_id}' version '{version_id}' to become ACTIVE...") + is_active = _poll_evaluator_version_status( + evaluator_id=evaluator_id, + version_id=version_id, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ) + + if not is_active: + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\n❌ Evaluator version is not ready within the timeout period.") + print(f"📊 Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") return False + return True except Exception as e: print(f"Warning: Failed to upload evaluator automatically: {e}") return False @@ -802,11 +778,10 @@ def create_rft_command(args) -> int: if not dataset_id or not dataset_resource: return 1 - # 5) Ensure evaluator exists and is ACTIVE (upload + poll if needed) + # 5) Ensure evaluator exists and its latest version is ACTIVE (upload + poll if needed) if not _upload_and_ensure_evaluator( project_root=project_root, evaluator_id=evaluator_id, - evaluator_resource_name=evaluator_resource_name, api_key=api_key, api_base=api_base, ): diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index d61b31ae..5abe49e8 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -377,7 +377,7 @@ def upload_command(args: argparse.Namespace) -> int: print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...") try: - result = create_evaluation( + result, version_id = create_evaluation( evaluator_id=evaluator_id, display_name=display_name or evaluator_id, description=description or f"Evaluator for {qualname}", @@ -387,6 +387,8 @@ def upload_command(args: argparse.Namespace) -> int: # Print success message with Fireworks dashboard 
link print(f"\n✅ Successfully uploaded evaluator: {evaluator_id}") + if version_id: + print(f" Version: {version_id}") print("📊 View in Fireworks Dashboard:") dashboard_url = _build_evaluator_dashboard_url(evaluator_id) print(f" {dashboard_url}\n") diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index ee98fc1e..31298992 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -351,8 +351,10 @@ def create(self, evaluator_id, display_name=None, description=None): except Exception as upload_error: logger.warning(f"Code upload failed (evaluator created but code not uploaded): {upload_error}") # Don't fail - evaluator is created, just code upload failed + # Return None for version_id since upload failed + return result, None - return result # Return after attempting upload + return result, evaluator_version_id # Return evaluator result and version ID except fireworks.APIStatusError as e: logger.error(f"Error creating evaluator: {str(e)}") logger.error(f"Status code: {e.status_code}, Response: {e.response.text}") @@ -392,6 +394,10 @@ def create_evaluation( account_id: Optional Fireworks account ID api_key: Optional Fireworks API key entry_point: Optional entry point (module::function or path::function) + + Returns: + A tuple of (evaluator_result, version_id) where version_id is the ID of the + created evaluator version, or None if upload failed. """ evaluator = Evaluator( account_id=account_id, diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 7b989028..9832aec2 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -1,7 +1,6 @@ import json import os import argparse -import requests from types import SimpleNamespace from unittest.mock import patch from typing import Any, cast @@ -106,7 +105,7 @@ def rft_test_harness(tmp_path, monkeypatch, stub_fireworks): monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) return project @@ -446,7 +445,7 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_ monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) captured = {"dataset_id": None} @@ -641,17 +640,8 @@ def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - # Mock evaluator exists and is ACTIVE - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) + # Mock evaluator upload and version polling - evaluator becomes ACTIVE + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) # Provide dataset via --dataset-jsonl so no test 
discovery needed ds_path = project / "dataset.jsonl" @@ -703,11 +693,8 @@ def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - # Evaluator does not exist (force path into upload section) - def _raise(*a, **k): - raise requests.exceptions.RequestException("nope") - - monkeypatch.setattr(cr.requests, "get", _raise) + # Mock _upload_and_ensure_evaluator to fail (ambiguous tests) + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: False) # Two discovered tests (ambiguous) f1 = project / "a.py" @@ -948,18 +935,8 @@ def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(r d2 = SimpleNamespace(qualname="beta.test_two", file_path=str(f2)) monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2]) - # Evaluator exists and is ACTIVE (skip upload) - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # Evaluator upload succeeds and version becomes ACTIVE + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) # We will provide JSONL via input_dataset extractor for matching test (beta.test_two) jsonl_path = project / "data.jsonl" @@ -1040,17 +1017,8 @@ def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatc monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "pyroworks-dev") - # Mock evaluator exists and ACTIVE - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) + # Mock evaluator upload succeeds and version becomes ACTIVE + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) captured = stub_fireworks @@ -1133,7 +1101,7 @@ def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_h monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) # Prepare two JSONL paths: one explicit via --dataset-jsonl and one inferable via input_dataset explicit_jsonl = project / "metric" / "explicit.jsonl" diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 1dad3b19..0d4bb13e 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -107,7 +107,7 @@ def mock_validate_upload(evaluator_id, version_id): try: os.chdir(tmp_dir) - api_response = create_evaluation( + api_response, version_id = create_evaluation( evaluator_id="test-eval", display_name="Test Evaluator", description="Test description", @@ -118,6 +118,9 @@ def mock_validate_upload(evaluator_id, version_id): assert api_response.display_name == "Test Evaluator" assert api_response.description == "Test description" + # Verify version ID was returned + assert version_id == "v1", "Version ID should be returned" + # Verify full upload flow was executed assert create_called, "Create endpoint should 
be called" assert version_create_called, "Version create should be called" From 470230785c8bcb75c06728910dc608124dd2eb23 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 15:52:45 -0800 Subject: [PATCH 19/39] update --- tests/test_upload_entrypoint.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_upload_entrypoint.py b/tests/test_upload_entrypoint.py index 2ae23024..076a6f79 100644 --- a/tests/test_upload_entrypoint.py +++ b/tests/test_upload_entrypoint.py @@ -28,8 +28,8 @@ def test_llm_judge(row=None): def fake_create_evaluation(**kwargs): captured.update(kwargs) - # Simulate API response - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -40,7 +40,6 @@ def fake_create_evaluation(**kwargs): id=None, display_name=None, description=None, - force=False, yes=True, ) @@ -72,7 +71,8 @@ def test_llm_judge(row=None): def fake_create_evaluation(**kwargs): captured.update(kwargs) - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -83,7 +83,6 @@ def fake_create_evaluation(**kwargs): id=None, display_name=None, description=None, - force=False, yes=True, ) @@ -119,7 +118,8 @@ def test_llm_judge(row=None): def fake_create_evaluation(**kwargs): captured.update(kwargs) - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -130,7 +130,6 @@ def fake_create_evaluation(**kwargs): id=None, display_name=None, description=None, - force=False, yes=True, ) @@ -163,8 +162,8 @@ def test_llm_judge(row=None): monkeypatch.setenv("FIREWORKS_API_BASE", "https://dev.api.fireworks.ai") def fake_create_evaluation(**kwargs): - # Simulate creation result with evaluator name - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate creation result with evaluator name - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -174,7 +173,6 @@ def fake_create_evaluation(**kwargs): id="quickstart-test-llm-judge", display_name=None, description=None, - force=True, yes=True, ) @@ -204,7 +202,8 @@ def test_llm_judge(row=None): monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") def fake_create_evaluation(**kwargs): - return {"name": kwargs.get("evaluator_id", "eval")} + # Simulate API response - returns (result, version_id) tuple + return {"name": kwargs.get("evaluator_id", "eval")}, "v1" monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation) @@ -214,7 +213,6 @@ def fake_create_evaluation(**kwargs): id="quickstart-test-llm-judge", display_name=None, description=None, - force=False, yes=True, ) From 9d1bc74b09317ac3e519a31969faa743ac0e5e5f Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 15:53:13 -0800 Subject: [PATCH 20/39] use published a22 of fireworks-ai --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 
4b2a92b1..841f5ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "pytest-asyncio>=0.21.0", "peewee>=3.18.2", "backoff>=2.2.0", - "fireworks-ai @ https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl", + "fireworks-ai==1.0.0a22", "questionary>=2.0.0", # Dependencies for vendored tau2 package "toml>=0.10.0", From 3314becfcdf35f771c41988a24f38dcb91593203 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 15:54:22 -0800 Subject: [PATCH 21/39] uv lock --- uv.lock | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/uv.lock b/uv.lock index c5df7b2f..188413f6 100644 --- a/uv.lock +++ b/uv.lock @@ -1312,7 +1312,7 @@ requires-dist = [ { name = "dspy", marker = "extra == 'dspy'", specifier = ">=3.0.0" }, { name = "e2b", marker = "extra == 'dev'" }, { name = "fastapi", specifier = ">=0.116.1" }, - { name = "fireworks-ai", url = "https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl" }, + { name = "fireworks-ai", specifier = "==1.0.0a22" }, { name = "google-auth", marker = "extra == 'bigquery'", specifier = ">=2.0.0" }, { name = "google-cloud-bigquery", marker = "extra == 'bigquery'", specifier = ">=3.0.0" }, { name = "gymnasium", marker = "extra == 'dev'", specifier = ">=1.2.0" }, @@ -1582,8 +1582,8 @@ wheels = [ [[package]] name = "fireworks-ai" -version = "1.0.0a21" -source = { url = "https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl" } +version = "1.0.0a22" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, { name = "anyio" }, @@ -1594,20 +1594,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/ef/16/073cf6855d18e43c14972d4b8f8fe59a43e41b581d430a7fad1dae3b8ddf/fireworks_ai-1.0.0a22.tar.gz", hash = "sha256:ab6fc7ad2beb8d69454b8c8c34ccd5d97ffa8cefa308a5cac7e568676e4b1188", size = 572510, upload-time = "2026-01-13T23:52:12.538Z" } wheels = [ - { url = "https://pkg.stainless.com/s/fireworks-ai-python/c35a37f39bced36305e3f2949484c3409a8f892e/fireworks_ai-1.0.0a21-py3-none-any.whl", hash = "sha256:882c45957fa4a5be55680b9a8972381ec83875834f097ab743981716819aecb7" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiohttp" }, - { name = "anyio", specifier = ">=3.5.0,<5" }, - { name = "distro", specifier = ">=1.7.0,<2" }, - { name = "httpx", specifier = ">=0.23.0,<1" }, - { name = "httpx-aiohttp", specifier = ">=0.1.9" }, - { name = "pydantic", specifier = ">=1.9.0,<3" }, - { name = "sniffio" }, - { name = "typing-extensions", specifier = ">=4.10,<5" }, + { url = "https://files.pythonhosted.org/packages/26/ef/a932f1fc357b7847258c212d53c074df3956ffbcbf74b2d5c3fdf14fd805/fireworks_ai-1.0.0a22-py3-none-any.whl", hash = "sha256:4ee18a0cb454585baab4803d82ec647d70fd8078a737a7ca4be7a686bc468ce3", size = 316745, upload-time = "2026-01-13T23:52:11.268Z" }, ] [[package]] From 66f191a09db5364b9cd9bb21230e1f48e50be724 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 16:25:00 -0800 Subject: [PATCH 22/39] Refactor dotenv handling in auth module and integrate environment variable loading into local test command. Introduced functions to find and retrieve values from .env files, enhancing configuration management for Docker tests. 
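A usage sketch for the new helpers (the docker_env_flags wrapper is hypothetical; get_dotenv_values and the -e KEY=VALUE forwarding match the diff that follows):

    from typing import List

    from eval_protocol.auth import get_dotenv_values


    def docker_env_flags(project_root: str) -> List[str]:
        """Translate .env.dev/.env entries into docker run -e flags."""
        flags: List[str] = []
        for key, value in get_dotenv_values(project_root).items():
            if value is not None:
                flags += ["-e", f"{key}={value}"]
        return flags


    # e.g. cmd = ["docker", "run", "--rm", *docker_env_flags("."), image]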
--- eval_protocol/auth.py | 70 +++++++++++++++++++----- eval_protocol/cli_commands/local_test.py | 7 +++ 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index 7be1aed5..40e3c777 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -1,30 +1,72 @@ import logging import os -from typing import Optional +from typing import Dict, Optional import requests -from dotenv import find_dotenv, load_dotenv +from dotenv import dotenv_values, find_dotenv, load_dotenv logger = logging.getLogger(__name__) + +def find_dotenv_path(search_path: Optional[str] = None) -> Optional[str]: + """ + Find the .env file path, searching .env.dev first, then .env. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Path to the .env file if found, otherwise None. + """ + # If a specific search path is provided, look there first + if search_path: + env_dev_path = os.path.join(search_path, ".env.dev") + if os.path.isfile(env_dev_path): + return env_dev_path + env_path = os.path.join(search_path, ".env") + if os.path.isfile(env_path): + return env_path + return None + + # Otherwise use find_dotenv to search up the directory tree + env_dev_path = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) + if env_dev_path: + return env_dev_path + env_path = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) + if env_path: + return env_path + return None + + +def get_dotenv_values(search_path: Optional[str] = None) -> Dict[str, Optional[str]]: + """ + Get all key-value pairs from the .env file. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Dictionary of environment variable names to values. + """ + dotenv_path = find_dotenv_path(search_path) + if dotenv_path: + return dotenv_values(dotenv_path) + return {} + + # --- Load .env files --- # Attempt to load .env.dev first, then .env as a fallback. # This happens when the module is imported. # We use override=False (default) so that existing environment variables # (e.g., set in the shell) are NOT overridden by .env files. -_ENV_DEV_PATH = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) -if _ENV_DEV_PATH: - load_dotenv(dotenv_path=_ENV_DEV_PATH, override=False) - logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_ENV_DEV_PATH}") +_DOTENV_PATH = find_dotenv_path() +if _DOTENV_PATH: + load_dotenv(dotenv_path=_DOTENV_PATH, override=False) + logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_DOTENV_PATH}") else: - _ENV_PATH = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) - if _ENV_PATH: - load_dotenv(dotenv_path=_ENV_PATH, override=False) - logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_ENV_PATH}") - else: - logger.debug( - "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables." - ) + logger.debug( + "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables." 
+ ) # --- End .env loading --- diff --git a/eval_protocol/cli_commands/local_test.py b/eval_protocol/cli_commands/local_test.py index 97e02e9f..43a59a3f 100644 --- a/eval_protocol/cli_commands/local_test.py +++ b/eval_protocol/cli_commands/local_test.py @@ -5,6 +5,7 @@ import sys from typing import List +from ..auth import get_dotenv_values from .utils import _build_entry_point, _discover_and_select_tests @@ -71,6 +72,12 @@ def _run_pytest_in_docker( workdir, ] + # Forward environment variables from .env file to the container + dotenv_vars = get_dotenv_values(project_root) + for key, value in dotenv_vars.items(): + if value is not None: + cmd += ["-e", f"{key}={value}"] + # If EP_SUMMARY_JSON is set on the host, mirror it into the container so that # pytest evaluation tests can write summary artifacts that are visible to the # host. We map paths under the host logs directory (~/.eval_protocol) into the From 165afe1d5f0e256272c2fa09eb7662d59448f3c1 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 16:29:07 -0800 Subject: [PATCH 23/39] add create rft launch configuration --- .vscode/launch.json.example | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/.vscode/launch.json.example b/.vscode/launch.json.example index df851aa3..7b70e735 100644 --- a/.vscode/launch.json.example +++ b/.vscode/launch.json.example @@ -32,6 +32,29 @@ "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" } + }, + { + "name": "EP: Create RFT", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": [ + "create", + "rft", + "--base-model", + "accounts/fireworks/models/qwen3-0p6b", + "--chunk-size", + "10" + ], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } } ] } From 838c7a50ec75b9e64a95b206f541863400d49668 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 13 Jan 2026 16:25:00 -0800 Subject: [PATCH 24/39] Refactor dotenv handling in auth module and integrate environment variable loading into local test command. Introduced functions to find and retrieve values from .env files, enhancing configuration management for Docker tests. --- eval_protocol/auth.py | 65 +++++++++++++++++++++++- eval_protocol/cli_commands/local_test.py | 7 +++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index 68ce134c..19b3e76d 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -1,12 +1,75 @@ import logging import os -from typing import Optional +from typing import Dict, Optional import requests +from dotenv import dotenv_values, find_dotenv, load_dotenv logger = logging.getLogger(__name__) +def find_dotenv_path(search_path: Optional[str] = None) -> Optional[str]: + """ + Find the .env file path, searching .env.dev first, then .env. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Path to the .env file if found, otherwise None. 
+ """ + # If a specific search path is provided, look there first + if search_path: + env_dev_path = os.path.join(search_path, ".env.dev") + if os.path.isfile(env_dev_path): + return env_dev_path + env_path = os.path.join(search_path, ".env") + if os.path.isfile(env_path): + return env_path + return None + + # Otherwise use find_dotenv to search up the directory tree + env_dev_path = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) + if env_dev_path: + return env_dev_path + env_path = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) + if env_path: + return env_path + return None + + +def get_dotenv_values(search_path: Optional[str] = None) -> Dict[str, Optional[str]]: + """ + Get all key-value pairs from the .env file. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Dictionary of environment variable names to values. + """ + dotenv_path = find_dotenv_path(search_path) + if dotenv_path: + return dotenv_values(dotenv_path) + return {} + + +# --- Load .env files --- +# Attempt to load .env.dev first, then .env as a fallback. +# This happens when the module is imported. +# We use override=False (default) so that existing environment variables +# (e.g., set in the shell) are NOT overridden by .env files. +_DOTENV_PATH = find_dotenv_path() +if _DOTENV_PATH: + load_dotenv(dotenv_path=_DOTENV_PATH, override=False) + logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_DOTENV_PATH}") +else: + logger.debug( + "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables." + ) +# --- End .env loading --- + + def get_fireworks_api_key() -> Optional[str]: """ Retrieves the Fireworks API key. diff --git a/eval_protocol/cli_commands/local_test.py b/eval_protocol/cli_commands/local_test.py index 97e02e9f..43a59a3f 100644 --- a/eval_protocol/cli_commands/local_test.py +++ b/eval_protocol/cli_commands/local_test.py @@ -5,6 +5,7 @@ import sys from typing import List +from ..auth import get_dotenv_values from .utils import _build_entry_point, _discover_and_select_tests @@ -71,6 +72,12 @@ def _run_pytest_in_docker( workdir, ] + # Forward environment variables from .env file to the container + dotenv_vars = get_dotenv_values(project_root) + for key, value in dotenv_vars.items(): + if value is not None: + cmd += ["-e", f"{key}={value}"] + # If EP_SUMMARY_JSON is set on the host, mirror it into the container so that # pytest evaluation tests can write summary artifacts that are visible to the # host. 
We map paths under the host logs directory (~/.eval_protocol) into the From 0144c9fd67690c630927d340ed595772c2fa9ea5 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 14 Jan 2026 10:20:11 -0800 Subject: [PATCH 25/39] actually not necessary for local test since local-test mounts the workspace so it should include the .env file --- eval_protocol/cli_commands/local_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/eval_protocol/cli_commands/local_test.py b/eval_protocol/cli_commands/local_test.py index 43a59a3f..97e02e9f 100644 --- a/eval_protocol/cli_commands/local_test.py +++ b/eval_protocol/cli_commands/local_test.py @@ -5,7 +5,6 @@ import sys from typing import List -from ..auth import get_dotenv_values from .utils import _build_entry_point, _discover_and_select_tests @@ -72,12 +71,6 @@ def _run_pytest_in_docker( workdir, ] - # Forward environment variables from .env file to the container - dotenv_vars = get_dotenv_values(project_root) - for key, value in dotenv_vars.items(): - if value is not None: - cmd += ["-e", f"{key}={value}"] - # If EP_SUMMARY_JSON is set on the host, mirror it into the container so that # pytest evaluation tests can write summary artifacts that are visible to the # host. We map paths under the host logs directory (~/.eval_protocol) into the From c8774a6de60b14df9421798f7e031aae499495f3 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 14 Jan 2026 10:23:12 -0800 Subject: [PATCH 26/39] increase sql retries --- eval_protocol/event_bus/sqlite_event_bus_database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 5086d6e3..59a026ed 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -11,8 +11,8 @@ # Retry configuration for database operations -SQLITE_RETRY_MAX_TRIES = 5 -SQLITE_RETRY_MAX_TIME = 30 # seconds +SQLITE_RETRY_MAX_TRIES = 10 +SQLITE_RETRY_MAX_TIME = 60 # seconds def _is_database_locked_error(e: Exception) -> bool: From 2076f0a5808d6243ae69997d376f307029cc8444 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 14 Jan 2026 12:15:32 -0800 Subject: [PATCH 27/39] Refactor dotenv loading to use explicit paths in CLI and API modules - improving environment variable management and preventing conflicts with other .env files. 
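The convention this enforces, shown in isolation (the paths are the project-local defaults used in the diff that follows):

    from pathlib import Path

    from dotenv import load_dotenv

    # Implicit form: find_dotenv() walks up the directory tree and can load an
    # unrelated .env from a parent checkout.
    # load_dotenv(override=True)          # avoided after this change

    # Explicit form: only files in the current working directory are considered.
    load_dotenv(dotenv_path=Path(".") / ".env.dev", override=True)
    load_dotenv(dotenv_path=Path(".") / ".env", override=True)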
--- eval_protocol/cli.py | 4 +- eval_protocol/mcp/mcp_multi_client.py | 3 - .../svg_agent/vercel_svg_server/api/init.py | 5 +- tests/test_no_implicit_dotenv.py | 209 ++++++++++++++++++ 4 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 tests/test_no_implicit_dotenv.py diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 51b5d4dd..9b3bb320 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -286,8 +286,10 @@ def main(): from dotenv import load_dotenv # .env.dev for development-specific overrides, .env for general + # Use explicit paths to avoid find_dotenv() searching up the directory tree + # and potentially finding a different .env file (e.g., in some other repo) load_dotenv(dotenv_path=Path(".") / ".env.dev", override=True) - load_dotenv(override=True) + load_dotenv(dotenv_path=Path(".") / ".env", override=True) except ImportError: pass diff --git a/eval_protocol/mcp/mcp_multi_client.py b/eval_protocol/mcp/mcp_multi_client.py index 4c138796..faa774a9 100644 --- a/eval_protocol/mcp/mcp_multi_client.py +++ b/eval_protocol/mcp/mcp_multi_client.py @@ -13,7 +13,6 @@ class FunctionLike(BaseModel): parameters: Any = None -from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client @@ -26,8 +25,6 @@ class FunctionLike(BaseModel): MCPMultiClientConfiguration, ) -load_dotenv() # load environment variables from .env - class MCPMultiClient: """ diff --git a/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py b/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py index ffd8b9ea..87db9acb 100644 --- a/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py +++ b/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py @@ -13,11 +13,14 @@ from flask import Flask, request, jsonify from openai import OpenAI import openai +from pathlib import Path + from dotenv import load_dotenv from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter -load_dotenv() +# Use explicit path to avoid find_dotenv() searching up the directory tree +load_dotenv(dotenv_path=Path(".") / ".env") # Configure logging so INFO and below go to stdout, WARNING+ to stderr. # This avoids Vercel marking INFO logs as [error] (stderr). diff --git a/tests/test_no_implicit_dotenv.py b/tests/test_no_implicit_dotenv.py new file mode 100644 index 00000000..04855821 --- /dev/null +++ b/tests/test_no_implicit_dotenv.py @@ -0,0 +1,209 @@ +""" +Test to ensure load_dotenv() is never called without an explicit path. + +When load_dotenv() is called without a dotenv_path argument, it uses find_dotenv() +which searches up the directory tree for a .env file. This can cause unexpected +behavior when running the CLI from a subdirectory, as it may find a .env file +in a parent directory (e.g., the python-sdk repo's .env) instead of the intended +project's .env file. + +This test scans all Python files in the SDK to ensure that every call to +load_dotenv() includes an explicit dotenv_path argument. 
+""" + +import ast +import os +from pathlib import Path +from typing import List, Set, Tuple + +# Directories to scan for implicit load_dotenv calls +SCAN_DIRECTORIES = [ + "eval_protocol", +] + +# Directories to exclude from scanning (relative to repo root) +EXCLUDE_DIRECTORIES: Set[str] = { + ".venv", + ".git", + "__pycache__", + ".pytest_cache", + ".mypy_cache", + "node_modules", + "build", + "dist", + ".eggs", + "*.egg-info", +} + + +def find_implicit_load_dotenv_calls(file_path: Path) -> List[Tuple[int, str]]: + """ + Parse a Python file and find any load_dotenv() calls without explicit dotenv_path. + + Returns a list of (line_number, code_snippet) tuples for violations. + """ + violations = [] + + try: + with open(file_path, "r", encoding="utf-8") as f: + source = f.read() + except (IOError, UnicodeDecodeError): + return violations + + try: + tree = ast.parse(source, filename=str(file_path)) + except SyntaxError: + return violations + + for node in ast.walk(tree): + if isinstance(node, ast.Call): + # Check if this is a call to load_dotenv + func_name = None + if isinstance(node.func, ast.Name): + func_name = node.func.id + elif isinstance(node.func, ast.Attribute): + func_name = node.func.attr + + if func_name == "load_dotenv": + # Check if dotenv_path is provided as a positional or keyword argument + has_explicit_path = False + + # Check positional arguments (dotenv_path is the first positional arg) + if node.args: + has_explicit_path = True + + # Check keyword arguments + for keyword in node.keywords: + if keyword.arg == "dotenv_path": + has_explicit_path = True + break + + if not has_explicit_path: + # Get the source line for context + try: + lines = source.splitlines() + line = lines[node.lineno - 1].strip() if node.lineno <= len(lines) else "" + except (IndexError, AttributeError): + line = "" + + violations.append((node.lineno, line)) + + return violations + + +def _should_exclude_dir(dir_name: str) -> bool: + """Check if a directory should be excluded from scanning.""" + return dir_name in EXCLUDE_DIRECTORIES or dir_name.startswith(".") + + +def _scan_directory(directory: Path, repo_root: Path) -> List[Tuple[Path, int, str]]: + """Scan a directory for implicit load_dotenv calls.""" + all_violations: List[Tuple[Path, int, str]] = [] + + for root, dirs, files in os.walk(directory): + # Filter out excluded directories in-place to prevent os.walk from descending into them + dirs[:] = [d for d in dirs if not _should_exclude_dir(d)] + + for filename in files: + if not filename.endswith(".py"): + continue + + file_path = Path(root) / filename + violations = find_implicit_load_dotenv_calls(file_path) + + for line_no, code in violations: + all_violations.append((file_path, line_no, code)) + + return all_violations + + +def test_no_implicit_load_dotenv_calls(): + """ + Ensure no load_dotenv() calls exist without an explicit dotenv_path argument. + + This prevents the CLI from accidentally loading .env files from parent directories + when running from a subdirectory. 
+ """
+ repo_root = Path(__file__).parent.parent
+
+ all_violations: List[Tuple[Path, int, str]] = []
+
+ for scan_dir in SCAN_DIRECTORIES:
+ directory = repo_root / scan_dir
+ if directory.exists():
+ violations = _scan_directory(directory, repo_root)
+ all_violations.extend(violations)
+
+ if all_violations:
+ error_msg = [
+ "Found load_dotenv() calls without explicit dotenv_path argument.",
+ "This can cause the CLI to load .env files from parent directories unexpectedly.",
+ "",
+ "Violations:",
+ ]
+ for file_path, line_no, code in all_violations:
+ try:
+ rel_path = file_path.relative_to(repo_root)
+ except ValueError:
+ rel_path = file_path
+ error_msg.append(f" {rel_path}:{line_no}: {code}")
+
+ error_msg.extend(
+ [
+ "",
+ "Fix by providing an explicit path:",
+ " load_dotenv(dotenv_path=Path('.') / '.env', override=True)",
+ "",
+ ]
+ )
+
+ assert False, "\n".join(error_msg)
+
+
+def test_load_dotenv_ast_detection():
+ """Test that our AST detection correctly identifies implicit vs explicit calls."""
+ import tempfile
+
+ # Test case: implicit call (should be detected)
+ implicit_code = """
+from dotenv import load_dotenv
+load_dotenv()
+load_dotenv(override=True)
+load_dotenv(verbose=True, override=True)
+"""
+
+ # Test case: explicit call (should NOT be detected)
+ explicit_code = """
+from dotenv import load_dotenv
+load_dotenv(dotenv_path='.env')
+load_dotenv('.env')
+load_dotenv(Path('.') / '.env')
+load_dotenv(dotenv_path=Path('.') / '.env', override=True)
+load_dotenv(env_file_path) # positional arg counts as explicit
+"""
+
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
+ f.write(implicit_code)
+ implicit_file = Path(f.name)
+
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
+ f.write(explicit_code)
+ explicit_file = Path(f.name)
+
+ try:
+ implicit_violations = find_implicit_load_dotenv_calls(implicit_file)
+ explicit_violations = find_implicit_load_dotenv_calls(explicit_file)
+
+ # Should find 3 violations in implicit code
+ assert len(implicit_violations) == 3, (
+ f"Expected 3 implicit violations, got {len(implicit_violations)}: {implicit_violations}"
+ )
+
+ # Should find 0 violations in explicit code
+ assert len(explicit_violations) == 0, (
+ f"Expected 0 explicit violations, got {len(explicit_violations)}: {explicit_violations}"
+ )
+
+ finally:
+ implicit_file.unlink()
+ explicit_file.unlink()
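The detection rule the new test enforces can be restated in a few standalone lines (an illustrative re-statement, not the module's own code): a load_dotenv call is flagged only when it has neither a positional argument nor a dotenv_path keyword.

    import ast

    snippet = "from dotenv import load_dotenv\nload_dotenv(override=True)\n"
    for node in ast.walk(ast.parse(snippet)):
        if isinstance(node, ast.Call):
            # Handles both bare calls (ast.Name) and dotted calls (ast.Attribute).
            name = getattr(node.func, "id", None) or getattr(node.func, "attr", None)
            if name == "load_dotenv":
                explicit = bool(node.args) or any(k.arg == "dotenv_path" for k in node.keywords)
                print(f"line {node.lineno}: {'explicit' if explicit else 'implicit (violation)'}")

Run as-is this prints "line 2: implicit (violation)", which is exactly the case the guard exists to catch.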
From 3c2db598adc0df5ef1f34927fa3db35b599384ec Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Wed, 14 Jan 2026 13:58:14 -0800
Subject: [PATCH 29/39] "ep create evj"

---
 eval_protocol/cli.py | 78 +++++++-
 eval_protocol/cli_commands/create_evj.py | 227 +++++++++++++++++++++++
 eval_protocol/cli_commands/create_rft.py | 141 ++++----------
 eval_protocol/cli_commands/utils.py | 102 ++++++++++
 4 files changed, 447 insertions(+), 101 deletions(-)
 create mode 100644 eval_protocol/cli_commands/create_evj.py

diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py
index 9b3bb320..ea862d70 100644
--- a/eval_protocol/cli.py
+++ b/eval_protocol/cli.py
@@ -210,6 +210,78 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
+ # Create evj (Evaluation Job) subcommand
+ evj_parser = create_subparsers.add_parser(
+ "evj",
+ help="Create an Evaluation Job on Fireworks",
+ )
+
+ evj_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
+ evj_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without
sending") + evj_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation") + evj_parser.add_argument( + "--ignore-docker", + action="store_true", + help="Ignore Dockerfile even if present; run pytest on host during evaluator validation", + ) + evj_parser.add_argument( + "--docker-build-extra", + default="", + metavar="", + help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")", + ) + evj_parser.add_argument( + "--docker-run-extra", + default="", + metavar="", + help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")", + ) + evj_parser.add_argument( + "--quiet", + action="store_true", + help="If set, only errors will be printed.", + ) + + # Auto-generate flags from SDK Fireworks().evaluation_jobs.create() signature + create_evj_fn = Fireworks().evaluation_jobs.create + + evj_skip_fields = { + "__top_level__": { + "account_id", # auto-detected + "extra_headers", + "extra_query", + "extra_body", + "timeout", + }, + "evaluation_job": { + "output_stats", # read-only, set by server + }, + } + evj_aliases = { + "evaluation_job_id": ["--job-id"], + "evaluation_job.evaluator": ["--evaluator"], + "evaluation_job.input_dataset": ["--dataset"], # --input-dataset is auto-added + "evaluation_job.display_name": ["--name"], + # output_dataset, evaluator_version get their short forms auto-added + } + evj_help_overrides = { + "evaluation_job_id": "Evaluation Job ID to use", + "evaluation_job.evaluator": "Evaluator resource name (format: accounts/{account_id}/evaluators/{evaluator_id})", + "evaluation_job.input_dataset": "Input dataset resource name (format: accounts/{account_id}/datasets/{dataset_id})", + "evaluation_job.output_dataset": "Output dataset resource name where results will be written", + "evaluation_job.display_name": "Display name for the evaluation job", + "evaluation_job.evaluator_version": "Specific evaluator version to use (defaults to current version)", + "leaderboard_ids": "Optional leaderboard IDs to attach this job to upon creation", + } + + add_args_from_callable_signature( + evj_parser, + create_evj_fn, + skip_fields=evj_skip_fields, + aliases=evj_aliases, + help_overrides=evj_help_overrides, + ) + # Local test command local_test_parser = subparsers.add_parser( "local-test", @@ -351,7 +423,11 @@ def _extract_flag_value(argv_list, flag_name): from .cli_commands.create_rft import create_rft_command return create_rft_command(args) - print("Error: missing subcommand for 'create'. Try: eval-protocol create rft") + elif args.create_command == "evj": + from .cli_commands.create_evj import create_evj_command + + return create_evj_command(args) + print("Error: missing subcommand for 'create'. 
Try: eval-protocol create rft|evj") return 1 elif args.command == "local-test": from .cli_commands.local_test import local_test_command diff --git a/eval_protocol/cli_commands/create_evj.py b/eval_protocol/cli_commands/create_evj.py new file mode 100644 index 00000000..4b054561 --- /dev/null +++ b/eval_protocol/cli_commands/create_evj.py @@ -0,0 +1,227 @@ +import argparse +from fireworks._client import Fireworks +from fireworks.types.evaluation_job_create_response import EvaluationJobCreateResponse +import json +import os +import inspect +from typing import Any, Dict, Optional + +from ..auth import get_fireworks_api_base, get_fireworks_api_key +from ..fireworks_client import create_fireworks_client +from .utils import ( + _build_trimmed_dataset_id, + _build_evaluator_dashboard_url, + _ensure_account_id, + _extract_terminal_segment, + resolve_evaluator, + validate_evaluator_locally, +) +from .create_rft import ( + resolve_dataset, + upload_dataset, +) + + +def _resolve_output_dataset( + account_id: str, + output_dataset_arg: Optional[str], + evaluator_id: str, +) -> tuple[str, str]: + """Resolve output dataset id and resource name. + + If not provided, auto-generates an output dataset ID based on the evaluator ID. + """ + if output_dataset_arg: + if output_dataset_arg.startswith("accounts/"): + output_dataset_resource = output_dataset_arg + output_dataset_id = _extract_terminal_segment(output_dataset_arg) + else: + output_dataset_id = output_dataset_arg + output_dataset_resource = f"accounts/{account_id}/datasets/{output_dataset_id}" + else: + # Auto-generate output dataset ID + output_dataset_id = _build_trimmed_dataset_id(f"{evaluator_id}-results") + output_dataset_resource = f"accounts/{account_id}/datasets/{output_dataset_id}" + print(f"Auto-generated output dataset ID: {output_dataset_id}") + + return output_dataset_id, output_dataset_resource + + +def _print_evj_links( + evaluator_id: str, input_dataset_id: str, output_dataset_id: str, job_name: Optional[str] +) -> None: + """Print helpful links to the Fireworks dashboard.""" + print("\n📊 Links:") + print(f" Evaluator: {_build_evaluator_dashboard_url(evaluator_id)}") + if job_name: + # Extract job id from resource name if present + job_id = _extract_terminal_segment(job_name) if "/" in job_name else job_name + print(f" Evaluation Job: https://fireworks.ai/dashboard/evaluation-jobs/{job_id}") + print(f" Input Dataset: https://fireworks.ai/dashboard/datasets/{input_dataset_id}") + print(f" Output Dataset: https://fireworks.ai/dashboard/datasets/{output_dataset_id}") + + +def _create_evj_job( + account_id: str, + api_key: str, + api_base: str, + evaluator_id: str, + evaluator_resource_name: str, + input_dataset_id: str, + input_dataset_resource: str, + output_dataset_id: str, + output_dataset_resource: str, + args: argparse.Namespace, + dry_run: bool, +) -> int: + """Build and submit the Evaluation Job request (via Fireworks SDK).""" + + signature = inspect.signature(create_fireworks_client().evaluation_jobs.create) + + # Build top-level SDK kwargs + sdk_kwargs: Dict[str, Any] = {} + + # Build the evaluation_job nested object + evaluation_job: Dict[str, Any] = { + "evaluator": evaluator_resource_name, + "input_dataset": input_dataset_resource, + "output_dataset": output_dataset_resource, + } + + args_dict = vars(args) + + # Handle evaluation_job nested fields + for k, v in args_dict.items(): + if v is None: + continue + if k.startswith("evaluation_job_") and k != "evaluation_job_id": + field_name = k[len("evaluation_job_") :] + # Don't 
overwrite the normalized resources + if field_name in ("evaluator", "input_dataset", "output_dataset"): + continue + evaluation_job[field_name] = v + + sdk_kwargs["evaluation_job"] = evaluation_job + + # Handle top-level fields + for name in signature.parameters: + if name in ("account_id", "evaluation_job", "extra_headers", "extra_query", "extra_body", "timeout"): + continue + + value = args_dict.get(name) + if value is not None: + sdk_kwargs[name] = value + + print(f"Prepared Evaluation Job for evaluator '{evaluator_id}' using dataset '{input_dataset_id}'") + + if dry_run: + print("--dry-run: would call Fireworks().evaluation_jobs.create with kwargs:") + print(json.dumps(sdk_kwargs, indent=2)) + _print_evj_links(evaluator_id, input_dataset_id, output_dataset_id, None) + return 0 + + try: + fw: Fireworks = create_fireworks_client(api_key=api_key, base_url=api_base) + job: EvaluationJobCreateResponse = fw.evaluation_jobs.create(account_id=account_id, **sdk_kwargs) + job_name = job.name + print(f"\n✅ Created Evaluation Job: {job_name}") + _print_evj_links(evaluator_id, input_dataset_id, output_dataset_id, job_name) + return 0 + except Exception as e: + print(f"Error creating Evaluation Job: {e}") + return 1 + + +def create_evj_command(args) -> int: + """Main entry point for the 'create evj' CLI command.""" + # Pre-flight: resolve auth and environment + api_key = get_fireworks_api_key() + if not api_key: + print("Error: FIREWORKS_API_KEY not set.") + return 1 + + account_id = _ensure_account_id() + if not account_id: + print("Error: Could not resolve Fireworks account id from FIREWORKS_API_KEY.") + return 1 + + api_base = get_fireworks_api_base() + project_root = os.getcwd() + evaluator_arg: Optional[str] = getattr(args, "evaluation_job_evaluator", None) + input_dataset_arg: Optional[str] = getattr(args, "evaluation_job_input_dataset", None) + output_dataset_arg: Optional[str] = getattr(args, "evaluation_job_output_dataset", None) + non_interactive: bool = bool(getattr(args, "yes", False)) + dry_run: bool = bool(getattr(args, "dry_run", False)) + skip_validation: bool = bool(getattr(args, "skip_validation", False)) + ignore_docker: bool = bool(getattr(args, "ignore_docker", False)) + docker_build_extra: str = getattr(args, "docker_build_extra", "") or "" + docker_run_extra: str = getattr(args, "docker_run_extra", "") or "" + + # 1) Resolve evaluator and associated local test + ( + evaluator_id, + evaluator_resource_name, + selected_test_file_path, + selected_test_func_name, + ) = resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id, command_name="create evj") + if not evaluator_id or not evaluator_resource_name: + return 1 + + # 2) Resolve input dataset source (id or JSONL path) + input_dataset_id, input_dataset_resource, dataset_jsonl = resolve_dataset( + project_root=project_root, + account_id=account_id, + dataset_id_arg=input_dataset_arg, + dataset_jsonl_arg=None, # EVJ doesn't support --dataset-jsonl flag yet + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, + ) + # Require either an existing dataset id or a JSONL source to materialize from + if dataset_jsonl is None and not input_dataset_id: + return 1 + + # 3) Resolve output dataset + output_dataset_id, output_dataset_resource = _resolve_output_dataset(account_id, output_dataset_arg, evaluator_id) + + # 4) Optional local validation + if not skip_validation: + if not validate_evaluator_locally( + project_root=project_root, + 
selected_test_file=selected_test_file_path, + selected_test_func=selected_test_func_name, + ignore_docker=ignore_docker, + docker_build_extra=docker_build_extra, + docker_run_extra=docker_run_extra, + ): + return 1 + + # 5) Upload dataset when using JSONL sources (no-op for existing datasets) + input_dataset_id, input_dataset_resource = upload_dataset( + project_root=project_root, + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + dataset_id=input_dataset_id, + dataset_resource=input_dataset_resource, + dataset_jsonl=dataset_jsonl, + dataset_display_name=None, # EVJ auto-generates display name + dry_run=dry_run, + ) + if not input_dataset_id or not input_dataset_resource: + return 1 + + # 6) Create the Evaluation Job + return _create_evj_job( + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + evaluator_resource_name=evaluator_resource_name, + input_dataset_id=input_dataset_id, + input_dataset_resource=input_dataset_resource, + output_dataset_id=output_dataset_id, + output_dataset_resource=output_dataset_resource, + args=args, + dry_run=dry_run, + ) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 6a6123f1..12ad898c 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -31,6 +31,8 @@ _print_links, _resolve_selected_test, load_module_from_file_path, + resolve_evaluator, + validate_evaluator_locally, ) from .local_test import run_evaluator_test @@ -335,108 +337,31 @@ def _validate_dataset(dataset_jsonl: Optional[str]) -> bool: return _validate_dataset_jsonl(dataset_jsonl) -def _validate_evaluator_locally( - project_root: str, - selected_test_file: Optional[str], - selected_test_func: Optional[str], - ignore_docker: bool, - docker_build_extra: str, - docker_run_extra: str, -) -> bool: - """Run pytest locally for the selected evaluation test to validate the evaluator. - - The pytest helpers always enforce a small success threshold (0.01) for - evaluation_test-based suites so that an evaluation run where all scores are - 0.0 will naturally fail with a non-zero pytest exit code, which we then treat - as a failed validator. - """ - if not selected_test_file or not selected_test_func: - # No local test associated; skip validation but warn the user. 
- print("Warning: Could not resolve a local evaluation test for this evaluator; skipping local validation.") - return True - - pytest_target = _build_entry_point(project_root, selected_test_file, selected_test_func) - exit_code = run_evaluator_test( - project_root=project_root, - pytest_target=pytest_target, - ignore_docker=ignore_docker, - docker_build_extra=docker_build_extra, - docker_run_extra=docker_run_extra, - ) - return exit_code == 0 - - -def _resolve_evaluator( - project_root: str, - evaluator_arg: Optional[str], - non_interactive: bool, - account_id: str, -) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: - """Resolve evaluator id/resource and associated local test (file + func).""" - evaluator_id = evaluator_arg - selected_test_file_path: Optional[str] = None - selected_test_func_name: Optional[str] = None - - if not evaluator_id: - selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive) - if not selected_tests: - return None, None, None, None - - if len(selected_tests) != 1: - if non_interactive and len(selected_tests) > 1: - print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") - print(" Please pass --evaluator or --entry to disambiguate.") - else: - print("Error: Please select exactly one evaluation test for 'create rft'.") - return None, None, None, None - - chosen = selected_tests[0] - func_name = chosen.qualname.split(".")[-1] - source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0] - evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}") - # Resolve selected test once for downstream - selected_test_file_path, selected_test_func_name = _resolve_selected_test( - project_root, evaluator_id, selected_tests=selected_tests - ) - else: - # Caller provided an evaluator id or fully-qualified resource; try to resolve local test - short_id = evaluator_id - if evaluator_id.startswith("accounts/"): - short_id = _extract_terminal_segment(evaluator_id) - st_path, st_func = _resolve_selected_test(project_root, short_id) - if st_path and st_func: - selected_test_file_path = st_path - selected_test_func_name = st_func - evaluator_id = short_id - - if not evaluator_id: - return None, None, None, None - - # Resolve evaluator resource name to fully-qualified format required by API. - if evaluator_arg and evaluator_arg.startswith("accounts/"): - evaluator_resource_name = evaluator_arg - else: - evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" - - return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name - - -def _resolve_dataset( +def resolve_dataset( project_root: str, account_id: str, - args: argparse.Namespace, + dataset_id_arg: Optional[str], + dataset_jsonl_arg: Optional[str], selected_test_file_path: Optional[str], selected_test_func_name: Optional[str], ) -> tuple[Optional[str], Optional[str], Optional[str]]: """Resolve dataset source without performing any uploads. + Args: + project_root: Path to the project root directory. + account_id: The Fireworks account ID. + dataset_id_arg: Dataset ID or fully-qualified resource name (from --dataset). + dataset_jsonl_arg: Path to a local JSONL file (from --dataset-jsonl). + selected_test_file_path: Path to the selected test file (for inference). + selected_test_func_name: Name of the selected test function (for inference). 
+ Returns a tuple of: - dataset_id: existing dataset id when using --dataset or fully-qualified dataset resource - dataset_resource: fully-qualified dataset resource for existing datasets; None for JSONL sources - dataset_jsonl: local JSONL path when using --dataset-jsonl or inferred sources; None for id-only datasets """ - dataset_id = getattr(args, "dataset", None) - dataset_jsonl = getattr(args, "dataset_jsonl", None) + dataset_id = dataset_id_arg + dataset_jsonl = dataset_jsonl_arg dataset_resource_override: Optional[str] = None if dataset_id and dataset_jsonl: @@ -506,7 +431,7 @@ def _resolve_dataset( return dataset_id, dataset_resource, dataset_jsonl -def _upload_dataset( +def upload_dataset( project_root: str, account_id: str, api_key: str, @@ -515,13 +440,28 @@ def _upload_dataset( dataset_id: Optional[str], dataset_resource: Optional[str], dataset_jsonl: Optional[str], - args: argparse.Namespace, + dataset_display_name: Optional[str], dry_run: bool, ) -> tuple[Optional[str], Optional[str]]: """Create/upload the dataset when using a local JSONL source. For existing datasets (--dataset or fully-qualified ids), this is a no-op that simply ensures dataset_id and dataset_resource are populated. + + Args: + project_root: Path to the project root directory. + account_id: The Fireworks account ID. + api_key: Fireworks API key. + api_base: Fireworks API base URL. + evaluator_id: The evaluator ID (used for generating dataset ID if needed). + dataset_id: Existing dataset ID (if known). + dataset_resource: Existing dataset resource name (if known). + dataset_jsonl: Path to local JSONL file to upload (if any). + dataset_display_name: Display name for the dataset (optional). + dry_run: If True, simulate the upload without actually creating. + + Returns: + A tuple of (dataset_id, dataset_resource). 
""" # Existing dataset case: nothing to upload if not dataset_jsonl: @@ -533,7 +473,7 @@ def _upload_dataset( # JSONL-based dataset: upload or simulate upload inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id) - dataset_display_name = getattr(args, "dataset_display_name", None) or inferred_dataset_id + display_name = dataset_display_name or inferred_dataset_id # Resolve dataset_jsonl path relative to CWD if needed jsonl_path_for_upload = ( @@ -552,7 +492,7 @@ def _upload_dataset( api_key=api_key, api_base=api_base, dataset_id=inferred_dataset_id, - display_name=dataset_display_name, + display_name=display_name, jsonl_path=jsonl_path_for_upload, ) print(f"✓ Created and uploaded dataset: {dataset_id}") @@ -719,15 +659,16 @@ def create_rft_command(args) -> int: evaluator_resource_name, selected_test_file_path, selected_test_func_name, - ) = _resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id) + ) = resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id, command_name="create rft") if not evaluator_id or not evaluator_resource_name: return 1 # 2) Resolve dataset source (id or JSONL path) - dataset_id, dataset_resource, dataset_jsonl = _resolve_dataset( + dataset_id, dataset_resource, dataset_jsonl = resolve_dataset( project_root=project_root, account_id=account_id, - args=args, + dataset_id_arg=getattr(args, "dataset", None), + dataset_jsonl_arg=getattr(args, "dataset_jsonl", None), selected_test_file_path=selected_test_file_path, selected_test_func_name=selected_test_func_name, ) @@ -752,7 +693,7 @@ def create_rft_command(args) -> int: return 1 # Evaluator validation (run pytest for the selected test, possibly via Docker) - if not _validate_evaluator_locally( + if not validate_evaluator_locally( project_root=project_root, selected_test_file=selected_test_file_path, selected_test_func=selected_test_func_name, @@ -763,7 +704,7 @@ def create_rft_command(args) -> int: return 1 # 4) Upload dataset when using JSONL sources (no-op for existing datasets) - dataset_id, dataset_resource = _upload_dataset( + dataset_id, dataset_resource = upload_dataset( project_root=project_root, account_id=account_id, api_key=api_key, @@ -772,7 +713,7 @@ def create_rft_command(args) -> int: dataset_id=dataset_id, dataset_resource=dataset_resource, dataset_jsonl=dataset_jsonl, - args=args, + dataset_display_name=getattr(args, "dataset_display_name", None), dry_run=dry_run, ) if not dataset_id or not dataset_resource: diff --git a/eval_protocol/cli_commands/utils.py b/eval_protocol/cli_commands/utils.py index 1338ae31..804a2ae6 100644 --- a/eval_protocol/cli_commands/utils.py +++ b/eval_protocol/cli_commands/utils.py @@ -752,3 +752,105 @@ def add_args_from_callable_signature( help_text = help_overrides.get(name, help.get(name)) _add_flag(parser, flags, hints.get(name), help_text) + + +def validate_evaluator_locally( + project_root: str, + selected_test_file: Optional[str], + selected_test_func: Optional[str], + ignore_docker: bool, + docker_build_extra: str, + docker_run_extra: str, +) -> bool: + """Run pytest locally for the selected evaluation test to validate the evaluator. + + The pytest helpers always enforce a small success threshold (0.01) for + evaluation_test-based suites so that an evaluation run where all scores are + 0.0 will naturally fail with a non-zero pytest exit code, which we then treat + as a failed validator. 
+ """ + # Lazy import to avoid circular dependency (local_test imports from utils) + from .local_test import run_evaluator_test + + if not selected_test_file or not selected_test_func: + # No local test associated; skip validation but warn the user. + print("Warning: Could not resolve a local evaluation test for this evaluator; skipping local validation.") + return True + + pytest_target = _build_entry_point(project_root, selected_test_file, selected_test_func) + exit_code = run_evaluator_test( + project_root=project_root, + pytest_target=pytest_target, + ignore_docker=ignore_docker, + docker_build_extra=docker_build_extra, + docker_run_extra=docker_run_extra, + ) + return exit_code == 0 + + +def resolve_evaluator( + project_root: str, + evaluator_arg: Optional[str], + non_interactive: bool, + account_id: str, + command_name: str = "create", +) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Resolve evaluator id/resource and associated local test (file + func). + + Args: + project_root: Path to the project root directory. + evaluator_arg: The evaluator argument provided by the user (id or fully-qualified resource). + non_interactive: Whether to run in non-interactive mode. + account_id: The Fireworks account ID. + command_name: The CLI command name for error messages (e.g., 'create rft', 'create evj'). + + Returns: + A tuple of (evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name). + Returns (None, None, None, None) if resolution fails. + """ + evaluator_id = evaluator_arg + selected_test_file_path: Optional[str] = None + selected_test_func_name: Optional[str] = None + + if not evaluator_id: + selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive) + if not selected_tests: + return None, None, None, None + + if len(selected_tests) != 1: + if non_interactive and len(selected_tests) > 1: + print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") + print(" Please pass --evaluator or --entry to disambiguate.") + else: + print(f"Error: Please select exactly one evaluation test for '{command_name}'.") + return None, None, None, None + + chosen = selected_tests[0] + func_name = chosen.qualname.split(".")[-1] + source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0] + evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}") + # Resolve selected test once for downstream + selected_test_file_path, selected_test_func_name = _resolve_selected_test( + project_root, evaluator_id, selected_tests=selected_tests + ) + else: + # Caller provided an evaluator id or fully-qualified resource; try to resolve local test + short_id = evaluator_id + if evaluator_id.startswith("accounts/"): + short_id = _extract_terminal_segment(evaluator_id) + st_path, st_func = _resolve_selected_test(project_root, short_id) + if st_path and st_func: + selected_test_file_path = st_path + selected_test_func_name = st_func + evaluator_id = short_id + + if not evaluator_id: + return None, None, None, None + + # Resolve evaluator resource name to fully-qualified format required by API. 
+ if evaluator_arg and evaluator_arg.startswith("accounts/"): + evaluator_resource_name = evaluator_arg + else: + evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" + + return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name From 17eb18f428e756cde431f0319d4f8373fffc7ab1 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 14 Jan 2026 13:58:25 -0800 Subject: [PATCH 30/39] use SDK for Dataset API calls --- eval_protocol/fireworks_rft.py | 104 +++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 36 deletions(-) diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 148aefde..600296ba 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -1,5 +1,4 @@ import importlib.util -import io import json import os import sys @@ -9,12 +8,8 @@ import hashlib from pathlib import Path from typing import Any, Callable, Dict, Iterable, Optional, Tuple -from urllib.parse import urlencode -import requests - -from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key -from .common_utils import get_user_agent +from .fireworks_client import create_fireworks_client def _map_api_host_to_app_host(api_base: str) -> str: @@ -142,43 +137,80 @@ def create_dataset_from_jsonl( display_name: Optional[str], jsonl_path: str, ) -> Tuple[str, Dict[str, Any]]: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + """Create a dataset and upload a JSONL file using the Fireworks SDK client. + + This function uses the Fireworks SDK client which properly handles authentication + including extra headers set via FIREWORKS_EXTRA_HEADERS environment variable. + + Args: + account_id: The Fireworks account ID. + api_key: Fireworks API key. + api_base: Fireworks API base URL. + dataset_id: The ID for the new dataset. + display_name: Display name for the dataset (optional). + jsonl_path: Path to the JSONL file to upload. + + Returns: + A tuple of (dataset_id, dataset_response_dict). + + Raises: + RuntimeError: If dataset creation or upload fails. 
+ """ # Count examples quickly example_count = 0 with open(jsonl_path, "r", encoding="utf-8") as f: for _ in f: example_count += 1 - dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets" - payload = { - "dataset": { - "displayName": display_name or dataset_id, - "evalProtocol": {}, - "format": "FORMAT_UNSPECIFIED", - "exampleCount": str(example_count), - }, - "datasetId": dataset_id, - } - resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60) - if resp.status_code not in (200, 201): - raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}") - ds = resp.json() - - upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" - with open(jsonl_path, "rb") as f: - files = {"file": f} - up_headers = { - "Authorization": f"Bearer {api_key}", - "User-Agent": get_user_agent(), + # Create Fireworks client with consistent configuration + client = create_fireworks_client( + api_key=api_key, + account_id=account_id, + base_url=api_base, + ) + + try: + # Create the dataset + dataset = client.datasets.create( + account_id=account_id, + dataset_id=dataset_id, + dataset={ + "display_name": display_name or dataset_id, + "eval_protocol": {}, + "format": "FORMAT_UNSPECIFIED", + "example_count": str(example_count), + }, + timeout=60.0, + ) + except Exception as e: + raise RuntimeError(f"Dataset creation failed: {e}") from e + + try: + # Upload the JSONL file + with open(jsonl_path, "rb") as f: + client.datasets.upload( + dataset_id=dataset_id, + account_id=account_id, + file=f, + timeout=600.0, + ) + except Exception as e: + raise RuntimeError(f"Dataset upload failed: {e}") from e + + # Convert SDK response to dict for backwards compatibility + ds_dict: Dict[str, Any] = {} + if hasattr(dataset, "model_dump"): + ds_dict = dataset.model_dump() + elif hasattr(dataset, "dict"): + ds_dict = dataset.dict() + else: + # Fallback: extract known fields + ds_dict = { + "name": getattr(dataset, "name", None), + "state": getattr(dataset, "state", None), } - up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) - if up_resp.status_code not in (200, 201): - raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") - return dataset_id, ds + + return dataset_id, ds_dict def build_default_dataset_id(evaluator_id: str) -> str: From 1fd66f75409f914ebcbc143e71eb809bfe613294 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 14:27:54 -0800 Subject: [PATCH 31/39] Implement evaluator upload and status polling in create commands - Added `upload_and_ensure_evaluator` function to handle evaluator uploads and ensure the latest version is ACTIVE. - Updated `create_evj_command` and `create_rft_command` to utilize the new upload function. - Removed redundant polling logic from `create_rft.py` and `create_evj.py`, centralizing it in the new utility function. - Adjusted tests to mock the new upload function correctly. 
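The polling contract this commit centralizes is small enough to state on its own. A minimal sketch, assuming the version states that appear in the diff below (ACTIVE, BUILDING, BUILD_FAILED); get_version_state is a hypothetical stand-in for the SDK lookup (client.evaluator_versions.get(...).state):

    import time

    def wait_until_active(get_version_state, timeout_s: float = 600.0, interval_s: float = 10.0) -> bool:
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            state = get_version_state()
            if state == "ACTIVE":
                return True   # build finished; safe to submit jobs against this version
            if state == "BUILD_FAILED":
                return False  # terminal failure; no point in waiting longer
            time.sleep(interval_s)  # BUILDING or unknown states: keep polling
        return False  # timed out without reaching ACTIVE

The real implementation layers logging and error tolerance on top (a transient lookup failure is logged and retried rather than treated as terminal), but the two terminal states plus the timeout are the essential shape.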
--- eval_protocol/cli_commands/create_evj.py | 15 ++- eval_protocol/cli_commands/create_rft.py | 136 +-------------------- eval_protocol/cli_commands/utils.py | 149 +++++++++++++++++++++++ tests/test_cli_create_rft.py | 12 +- 4 files changed, 174 insertions(+), 138 deletions(-) diff --git a/eval_protocol/cli_commands/create_evj.py b/eval_protocol/cli_commands/create_evj.py index 4b054561..13063f77 100644 --- a/eval_protocol/cli_commands/create_evj.py +++ b/eval_protocol/cli_commands/create_evj.py @@ -14,6 +14,7 @@ _ensure_account_id, _extract_terminal_segment, resolve_evaluator, + upload_and_ensure_evaluator, validate_evaluator_locally, ) from .create_rft import ( @@ -211,7 +212,19 @@ def create_evj_command(args) -> int: if not input_dataset_id or not input_dataset_resource: return 1 - # 6) Create the Evaluation Job + # 6) Ensure evaluator exists and its latest version is ACTIVE (upload + poll if needed) + if not dry_run: + if not upload_and_ensure_evaluator( + project_root=project_root, + evaluator_id=evaluator_id, + api_key=api_key, + api_base=api_base, + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, + ): + return 1 + + # 7) Create the Evaluation Job return _create_evj_job( account_id=account_id, api_key=api_key, diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 12ad898c..b1550dbf 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -28,10 +28,12 @@ _ensure_account_id, _extract_terminal_segment, _normalize_evaluator_id, + _poll_evaluator_version_status, _print_links, _resolve_selected_test, load_module_from_file_path, resolve_evaluator, + upload_and_ensure_evaluator, validate_evaluator_locally, ) from .local_test import run_evaluator_test @@ -222,71 +224,6 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) return None -def _poll_evaluator_version_status( - evaluator_id: str, - version_id: str, - api_key: str, - api_base: str, - timeout_minutes: int = 10, -) -> bool: - """ - Poll a specific evaluator version status until it becomes ACTIVE or times out. - - Uses the Fireworks SDK to get the specified version of the evaluator and checks - its build state. - - Args: - evaluator_id: The evaluator ID (not full resource name) - version_id: The specific version ID to poll - api_key: Fireworks API key - api_base: Fireworks API base URL - timeout_minutes: Maximum time to wait in minutes - - Returns: - True if evaluator version becomes ACTIVE, False if timeout or BUILD_FAILED - """ - timeout_seconds = timeout_minutes * 60 - poll_interval = 10 # seconds - start_time = time.time() - - print( - f"Polling evaluator version '{version_id}' status (timeout: {timeout_minutes}m, interval: {poll_interval}s)..." - ) - - client = create_fireworks_client(api_key=api_key, base_url=api_base) - - while time.time() - start_time < timeout_seconds: - try: - version = client.evaluator_versions.get(version_id, evaluator_id=evaluator_id) - state = version.state or "STATE_UNSPECIFIED" - status_msg = "" - if version.status and version.status.message: - status_msg = version.status.message - - if state == "ACTIVE": - print("✅ Evaluator version is ACTIVE and ready!") - return True - elif state == "BUILD_FAILED": - print(f"❌ Evaluator version build failed. Status: {status_msg}") - return False - elif state == "BUILDING": - elapsed_minutes = (time.time() - start_time) / 60 - print(f"⏳ Evaluator version is still building... 
({elapsed_minutes:.1f}m elapsed)") - else: - print(f"⏳ Evaluator version state: {state}, status: {status_msg}") - - except Exception as e: - print(f"Warning: Failed to check evaluator version status: {e}") - - # Wait before next poll - time.sleep(poll_interval) - - # Timeout reached - elapsed_minutes = (time.time() - start_time) / 60 - print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator version is not yet ACTIVE") - return False - - def _validate_dataset_jsonl(jsonl_path: str, sample_limit: int = 50) -> bool: """Validate that a JSONL file contains rows compatible with EvaluationRow. @@ -503,71 +440,6 @@ def upload_dataset( return None, None -def _upload_and_ensure_evaluator( - project_root: str, - evaluator_id: str, - api_key: str, - api_base: str, -) -> bool: - """Upload evaluator and ensure its version becomes ACTIVE. - - Creates/updates the evaluator and uploads the code, then polls the specific - version until it becomes ACTIVE. - """ - from eval_protocol.evaluation import create_evaluation - - try: - tests = _discover_tests(project_root) - selected_entry: Optional[str] = None - st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests) - if st_path and st_func: - selected_entry = _build_entry_point(project_root, st_path, st_func) - # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators - if selected_entry is None and len(tests) > 1: - print( - f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n" - " Please re-run specifying the evaluator.\n" - " Hints:\n" - " - eval-protocol create rft --evaluator \n" - ) - return False - - print(f"\nUploading evaluator '{evaluator_id}'...") - result, version_id = create_evaluation( - evaluator_id=evaluator_id, - display_name=evaluator_id, - description=f"Evaluator for {evaluator_id}", - entry_point=selected_entry, - ) - - if not version_id: - print("Warning: Evaluator created but version upload failed.") - return False - - print(f"✓ Uploaded evaluator: {evaluator_id} (version: {version_id})") - - # Poll for the specific evaluator version status - print(f"Waiting for evaluator '{evaluator_id}' version '{version_id}' to become ACTIVE...") - is_active = _poll_evaluator_version_status( - evaluator_id=evaluator_id, - version_id=version_id, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ) - - if not is_active: - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\n❌ Evaluator version is not ready within the timeout period.") - print(f"📊 Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - except Exception as e: - print(f"Warning: Failed to upload evaluator automatically: {e}") - return False - - def _create_rft_job( account_id: str, api_key: str, @@ -720,11 +592,13 @@ def create_rft_command(args) -> int: return 1 # 5) Ensure evaluator exists and its latest version is ACTIVE (upload + poll if needed) - if not _upload_and_ensure_evaluator( + if not upload_and_ensure_evaluator( project_root=project_root, evaluator_id=evaluator_id, api_key=api_key, api_base=api_base, + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, ): return 1 diff --git a/eval_protocol/cli_commands/utils.py b/eval_protocol/cli_commands/utils.py index 804a2ae6..ab733832 100644 --- a/eval_protocol/cli_commands/utils.py +++ 
b/eval_protocol/cli_commands/utils.py
@@ -23,6 +23,7 @@
 get_fireworks_api_key,
 verify_api_key_and_get_account_id,
 )
+from ..fireworks_client import create_fireworks_client
 from ..fireworks_rft import _map_api_host_to_app_host


@@ -854,3 +855,151 @@ def resolve_evaluator(
 evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

 return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name
+
+
+def _poll_evaluator_version_status(
+ evaluator_id: str,
+ version_id: str,
+ api_key: str,
+ api_base: str,
+ timeout_minutes: int = 10,
+) -> bool:
+ """
+ Poll a specific evaluator version status until it becomes ACTIVE or times out.
+
+ Uses the Fireworks SDK to get the specified version of the evaluator and checks
+ its build state.
+
+ Args:
+ evaluator_id: The evaluator ID (not full resource name)
+ version_id: The specific version ID to poll
+ api_key: Fireworks API key
+ api_base: Fireworks API base URL
+ timeout_minutes: Maximum time to wait in minutes
+
+ Returns:
+ True if evaluator version becomes ACTIVE, False if timeout or BUILD_FAILED
+ """
+ timeout_seconds = timeout_minutes * 60
+ poll_interval = 10 # seconds
+ start_time = time.time()
+
+ print(
+ f"Polling evaluator version '{version_id}' status (timeout: {timeout_minutes}m, interval: {poll_interval}s)..."
+ )
+
+ client = create_fireworks_client(api_key=api_key, base_url=api_base)
+
+ while time.time() - start_time < timeout_seconds:
+ try:
+ version = client.evaluator_versions.get(version_id, evaluator_id=evaluator_id)
+ state = version.state or "STATE_UNSPECIFIED"
+ status_msg = ""
+ if version.status and version.status.message:
+ status_msg = version.status.message
+
+ if state == "ACTIVE":
+ print("✅ Evaluator version is ACTIVE and ready!")
+ return True
+ elif state == "BUILD_FAILED":
+ print(f"❌ Evaluator version build failed. Status: {status_msg}")
+ return False
+ elif state == "BUILDING":
+ elapsed_minutes = (time.time() - start_time) / 60
+ print(f"⏳ Evaluator version is still building... ({elapsed_minutes:.1f}m elapsed)")
+ else:
+ print(f"⏳ Evaluator version state: {state}, status: {status_msg}")
+
+ except Exception as e:
+ print(f"Warning: Failed to check evaluator version status: {e}")
+
+ # Wait before next poll
+ time.sleep(poll_interval)
+
+ # Timeout reached
+ elapsed_minutes = (time.time() - start_time) / 60
+ print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator version is not yet ACTIVE")
+ return False
+
+
+def upload_and_ensure_evaluator(
+ project_root: str,
+ evaluator_id: str,
+ api_key: str,
+ api_base: str,
+ selected_test_file_path: Optional[str] = None,
+ selected_test_func_name: Optional[str] = None,
+) -> bool:
+ """Upload evaluator and ensure its version becomes ACTIVE.
+
+ Creates/updates the evaluator and uploads the code, then polls the specific
+ version until it becomes ACTIVE. This is the shared implementation used by
+ the 'ep upload', 'ep create rft', and 'ep create evj' commands.
+
+ Args:
+ project_root: Path to the project root directory.
+ evaluator_id: The evaluator ID.
+ api_key: Fireworks API key.
+ api_base: Fireworks API base URL.
+ selected_test_file_path: Optional path to the selected test file.
+ selected_test_func_name: Optional name of the selected test function.
+
+ Returns:
+ True if evaluator was uploaded and became ACTIVE, False otherwise.
+ """ + from eval_protocol.evaluation import create_evaluation + + try: + tests = _discover_tests(project_root) + selected_entry: Optional[str] = None + + # Use provided test info if available, otherwise try to resolve + if selected_test_file_path and selected_test_func_name: + selected_entry = _build_entry_point(project_root, selected_test_file_path, selected_test_func_name) + else: + st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests) + if st_path and st_func: + selected_entry = _build_entry_point(project_root, st_path, st_func) + + # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators + if selected_entry is None and len(tests) > 1: + print( + f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n" + " Please re-run specifying the evaluator.\n" + ) + return False + + print(f"\nUploading evaluator '{evaluator_id}'...") + result, version_id = create_evaluation( + evaluator_id=evaluator_id, + display_name=evaluator_id, + description=f"Evaluator for {evaluator_id}", + entry_point=selected_entry, + ) + + if not version_id: + print("Warning: Evaluator created but version upload failed.") + return False + + print(f"✓ Uploaded evaluator: {evaluator_id} (version: {version_id})") + + # Poll for the specific evaluator version status + print(f"Waiting for evaluator '{evaluator_id}' version '{version_id}' to become ACTIVE...") + is_active = _poll_evaluator_version_status( + evaluator_id=evaluator_id, + version_id=version_id, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ) + + if not is_active: + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\n❌ Evaluator version is not ready within the timeout period.") + print(f"📊 Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run the command again.") + return False + return True + except Exception as e: + print(f"Warning: Failed to upload evaluator automatically: {e}") + return False diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 9832aec2..ef4e7288 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -106,7 +106,7 @@ def rft_test_harness(tmp_path, monkeypatch, stub_fireworks): monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) - monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) return project @@ -641,7 +641,7 @@ def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch, monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") # Mock evaluator upload and version polling - evaluator becomes ACTIVE - monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) # Provide dataset via --dataset-jsonl so no test discovery needed ds_path = project / "dataset.jsonl" @@ -693,8 +693,8 @@ def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: 
"acct123") - # Mock _upload_and_ensure_evaluator to fail (ambiguous tests) - monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: False) + # Mock upload_and_ensure_evaluator to fail (ambiguous tests) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: False) # Two discovered tests (ambiguous) f1 = project / "a.py" @@ -936,7 +936,7 @@ def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(r monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2]) # Evaluator upload succeeds and version becomes ACTIVE - monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) # We will provide JSONL via input_dataset extractor for matching test (beta.test_two) jsonl_path = project / "data.jsonl" @@ -1018,7 +1018,7 @@ def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatc monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "pyroworks-dev") # Mock evaluator upload succeeds and version becomes ACTIVE - monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) captured = stub_fireworks From fc4f91377f55b5ca377a8c6ddfed9854219add0f Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 15:22:38 -0800 Subject: [PATCH 32/39] Add secret management for uploads in CLI - Implemented functions to check for existing secrets and confirm overrides before uploading to Fireworks. - Enhanced user interaction with double confirmation for overriding existing secrets, including fallback for non-interactive environments. - Updated the upload command to handle new and existing secrets separately, ensuring proper management during uploads. --- eval_protocol/cli_commands/upload.py | 160 ++++++++++++++++++++++++--- 1 file changed, 145 insertions(+), 15 deletions(-) diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index 5abe49e8..93d6be63 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -7,7 +7,7 @@ from typing import Any, Dict from eval_protocol.auth import get_fireworks_api_key -from eval_protocol.platform_api import create_or_update_fireworks_secret +from eval_protocol.platform_api import create_or_update_fireworks_secret, get_fireworks_secret from eval_protocol.evaluation import create_evaluation from .utils import ( @@ -270,6 +270,110 @@ def _prompt_select_secrets_fallback( return {} +def _check_existing_secrets( + secrets: Dict[str, str], + account_id: str, +) -> set[str]: + """ + Check which secrets already exist on Fireworks. + Returns a set of secret names that already exist. + """ + existing = set() + for secret_name in secrets.keys(): + secret = get_fireworks_secret(account_id=account_id, key_name=secret_name) + if secret is not None: + existing.add(secret_name) + return existing + + +def _confirm_override_secrets( + secrets_to_override: set[str], + non_interactive: bool, +) -> set[str]: + """ + Prompt user to confirm overriding existing secrets. + Returns the set of secrets confirmed for override. 
+ """ + if not secrets_to_override: + return set() + + if non_interactive: + # In non-interactive mode, skip overriding existing secrets by default + print(f"\n⚠️ Skipping {len(secrets_to_override)} existing secret(s) in non-interactive mode.") + print(" Use interactive mode to confirm overriding existing secrets.") + return set() + + # Check if running in a non-TTY environment + if not sys.stdin.isatty(): + print(f"\n⚠️ Skipping {len(secrets_to_override)} existing secret(s) (non-TTY environment).") + return set() + + print(f"\n⚠️ The following {len(secrets_to_override)} secret(s) already exist on Fireworks:") + for name in sorted(secrets_to_override): + print(f" • {name}") + + try: + import questionary + + custom_style = _get_questionary_style() + + # First confirmation + confirm1 = questionary.confirm( + "Do you want to override these existing secrets?", + default=False, + style=custom_style, + ).ask() + + if confirm1 is None or not confirm1: + print("Override cancelled. Existing secrets will be skipped.") + return set() + + # Second confirmation (double verification) + confirm2 = questionary.confirm( + "⚠️ Are you SURE? This will permanently overwrite the existing secret values.", + default=False, + style=custom_style, + ).ask() + + if confirm2 is None or not confirm2: + print("Override cancelled. Existing secrets will be skipped.") + return set() + + return secrets_to_override + + except ImportError: + # Fallback to simple text-based confirmation + return _confirm_override_secrets_fallback(secrets_to_override) + except KeyboardInterrupt: + print("\n\nOverride cancelled.") + return set() + + +def _confirm_override_secrets_fallback(secrets_to_override: set[str]) -> set[str]: + """Fallback confirmation for when questionary is not available.""" + print("\n" + "=" * 60) + print("WARNING: Confirm override of existing secrets") + print("=" * 60) + + try: + # First confirmation + response1 = input("Do you want to override these existing secrets? (yes/no): ").strip().lower() + if response1 not in ("yes", "y"): + print("Override cancelled. Existing secrets will be skipped.") + return set() + + # Second confirmation (double verification) + response2 = input("⚠️ Are you SURE? Type 'override' to confirm: ").strip().lower() + if response2 != "override": + print("Override cancelled. Existing secrets will be skipped.") + return set() + + return secrets_to_override + except KeyboardInterrupt: + print("\n\nOverride cancelled.") + return set() + + def upload_command(args: argparse.Namespace) -> int: root = os.path.abspath(getattr(args, "path", ".")) entries_arg = getattr(args, "entry", None) @@ -322,21 +426,47 @@ def upload_command(args: argparse.Namespace) -> int: ) if selected_secrets: - print(f"\nUploading {len(selected_secrets)} selected secret(s) to Fireworks...") - for secret_name, secret_value in selected_secrets.items(): - source = ".env" if secret_name in secrets_from_env_file else "environment" - print( - f"Ensuring {secret_name} is registered as a secret on Fireworks for rollout... 
" - f"({source}: {_mask_secret_value(secret_value)})" + # Check which secrets already exist + print("\nChecking for existing secrets on Fireworks...") + existing_secrets = _check_existing_secrets(selected_secrets, fw_account_id) + + # Separate new secrets from existing ones + new_secrets = {k: v for k, v in selected_secrets.items() if k not in existing_secrets} + secrets_needing_override = {k: v for k, v in selected_secrets.items() if k in existing_secrets} + + # Confirm override for existing secrets (double verification) + confirmed_overrides: set[str] = set() + if secrets_needing_override: + confirmed_overrides = _confirm_override_secrets( + set(secrets_needing_override.keys()), + non_interactive, ) - if create_or_update_fireworks_secret( - account_id=fw_account_id, - key_name=secret_name, - secret_value=secret_value, - ): - print(f"✓ {secret_name} secret created/updated on Fireworks.") - else: - print(f"Warning: Failed to create/update {secret_name} secret on Fireworks.") + + # Build final list of secrets to upload + secrets_to_upload = dict(new_secrets) + for name in confirmed_overrides: + secrets_to_upload[name] = secrets_needing_override[name] + + if not secrets_to_upload: + print("No secrets to upload (existing secrets were not confirmed for override).") + else: + print(f"\nUploading {len(secrets_to_upload)} secret(s) to Fireworks...") + for secret_name, secret_value in secrets_to_upload.items(): + source = ".env" if secret_name in secrets_from_env_file else "environment" + action = "Overriding" if secret_name in confirmed_overrides else "Creating" + print(f"{action} {secret_name} on Fireworks... ({source}: {_mask_secret_value(secret_value)})") + if create_or_update_fireworks_secret( + account_id=fw_account_id, + key_name=secret_name, + secret_value=secret_value, + ): + print( + f"✓ {secret_name} secret {'updated' if secret_name in confirmed_overrides else 'created'} on Fireworks." + ) + else: + print( + f"Warning: Failed to {'update' if secret_name in confirmed_overrides else 'create'} {secret_name} secret on Fireworks." + ) else: print("No secrets selected for upload.") else: From 2f884283835820fe068255500e32936de75ddec8 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 15:25:41 -0800 Subject: [PATCH 33/39] handle existing secrets with caution --- eval_protocol/cli_commands/upload.py | 52 +++++++++++++++++++++------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index 93d6be63..5ecd4594 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -170,21 +170,25 @@ def _mask_secret_value(value: str) -> str: def _prompt_select_secrets( secrets: Dict[str, str], secrets_from_env_file: Dict[str, str], + existing_secrets: set[str], non_interactive: bool, ) -> Dict[str, str]: """ Prompt user to select which environment variables to upload as secrets. + Existing secrets are unchecked by default and marked with [exists]. Returns the selected secrets. 
""" if not secrets: return {} if non_interactive: - return secrets + # In non-interactive mode, only return new secrets (skip existing ones) + return {k: v for k, v in secrets.items() if k not in existing_secrets} # Check if running in a non-TTY environment (e.g., CI/CD) if not sys.stdin.isatty(): - return secrets + # In non-TTY, only return new secrets (skip existing ones) + return {k: v for k, v in secrets.items() if k not in existing_secrets} try: import questionary @@ -192,17 +196,23 @@ def _prompt_select_secrets( custom_style = _get_questionary_style() # Build choices with source info and masked values + # Existing secrets are unchecked by default choices = [] for key, value in secrets.items(): source = ".env" if key in secrets_from_env_file else "env" masked = _mask_secret_value(value) - label = f"{key} ({source}: {masked})" - choices.append(questionary.Choice(title=label, value=key, checked=True)) + is_existing = key in existing_secrets + exists_marker = " [exists]" if is_existing else "" + label = f"{key}{exists_marker} ({source}: {masked})" + # Uncheck existing secrets by default + choices.append(questionary.Choice(title=label, value=key, checked=not is_existing)) if len(choices) == 0: return {} print("\nFound environment variables to upload as Fireworks secrets:") + if existing_secrets: + print(" (Secrets marked [exists] are unchecked - selecting them will override)") selected_keys = questionary.checkbox( "Select secrets to upload:", choices=choices, @@ -220,7 +230,7 @@ def _prompt_select_secrets( except ImportError: # Fallback to simple text-based selection - return _prompt_select_secrets_fallback(secrets, secrets_from_env_file) + return _prompt_select_secrets_fallback(secrets, secrets_from_env_file, existing_secrets) except KeyboardInterrupt: print("\n\nSecret upload cancelled.") return {} @@ -229,6 +239,7 @@ def _prompt_select_secrets( def _prompt_select_secrets_fallback( secrets: Dict[str, str], secrets_from_env_file: Dict[str, str], + existing_secrets: set[str], ) -> Dict[str, str]: """Fallback prompt selection for when questionary is not available.""" print("\n" + "=" * 60) @@ -237,13 +248,22 @@ def _prompt_select_secrets_fallback( print("\nTip: Install questionary for better UX: pip install questionary\n") secret_list = list(secrets.items()) + new_indices = [] for idx, (key, value) in enumerate(secret_list, 1): source = ".env" if key in secrets_from_env_file else "env" masked = _mask_secret_value(value) - print(f" [{idx}] {key} ({source}: {masked})") + is_existing = key in existing_secrets + exists_marker = " [exists]" if is_existing else "" + print(f" [{idx}] {key}{exists_marker} ({source}: {masked})") + if not is_existing: + new_indices.append(idx) print("\n" + "=" * 60) - print("Enter numbers to select (comma-separated), 'all' for all, or 'none' to skip:") + if existing_secrets: + print("Note: Secrets marked [exists] will be overridden if selected.") + default_selection = ",".join(str(i) for i in new_indices) if new_indices else "none" + print("Enter numbers (comma-separated), 'all' for all, or 'none' to skip.") + print(f"Default (new secrets only): {default_selection}") try: choice = input("Selection: ").strip().lower() @@ -251,7 +271,11 @@ def _prompt_select_secrets_fallback( print("\nSecret upload cancelled.") return {} - if not choice or choice == "none": + if not choice: + # Default: only new secrets + choice = default_selection + + if choice == "none": return {} if choice == "all": @@ -418,19 +442,21 @@ def upload_command(args: argparse.Namespace) -> int: if 
secrets_from_env_file and os.path.exists(env_file_path):
                 print(f"Loading secrets from: {env_file_path}")
 
+            # Check which secrets already exist on Fireworks BEFORE prompting
+            print("Checking for existing secrets on Fireworks...")
+            existing_secrets = _check_existing_secrets(secrets_from_file, fw_account_id)
+
             # Prompt user to select which secrets to upload
+            # Existing secrets are unchecked by default
             selected_secrets = _prompt_select_secrets(
                 secrets_from_file,
                 secrets_from_env_file,
+                existing_secrets,
                 non_interactive,
             )
 
             if selected_secrets:
-                # Check which secrets already exist
-                print("\nChecking for existing secrets on Fireworks...")
-                existing_secrets = _check_existing_secrets(selected_secrets, fw_account_id)
-
-                # Separate new secrets from existing ones
+                # Separate new secrets from existing ones that user explicitly selected
                 new_secrets = {k: v for k, v in selected_secrets.items() if k not in existing_secrets}
                 secrets_needing_override = {k: v for k, v in selected_secrets.items() if k in existing_secrets}

From c6a8c5135a17c24ec98ba0f704b1de8c196ad61d Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Thu, 15 Jan 2026 15:32:25 -0800
Subject: [PATCH 34/39] Integrate secrets upload handling in CLI commands

- Added a new `secrets.py` module providing a `handle_secrets_upload` function,
  now called from `create_evj.py`, `create_rft.py`, and `upload.py`, for
  managing secrets with double verification for existing entries.
- Streamlined the upload process by consolidating secret management logic,
  enhancing user interaction during uploads.
- Removed the now-redundant secret loading functions from `upload.py` to
  improve code clarity and maintainability.

---
 eval_protocol/cli_commands/create_evj.py |   9 +
 eval_protocol/cli_commands/create_rft.py |   9 +
 eval_protocol/cli_commands/secrets.py    | 388 +++++++++++++++++++++++
 eval_protocol/cli_commands/upload.py     | 370 +--------------------
 4 files changed, 414 insertions(+), 362 deletions(-)
 create mode 100644 eval_protocol/cli_commands/secrets.py

diff --git a/eval_protocol/cli_commands/create_evj.py b/eval_protocol/cli_commands/create_evj.py
index 13063f77..99319a54 100644
--- a/eval_protocol/cli_commands/create_evj.py
+++ b/eval_protocol/cli_commands/create_evj.py
@@ -8,6 +8,7 @@
 from ..auth import get_fireworks_api_base, get_fireworks_api_key
 from ..fireworks_client import create_fireworks_client
+from .secrets import handle_secrets_upload
 from .utils import (
     _build_trimmed_dataset_id,
     _build_evaluator_dashboard_url,
@@ -168,6 +169,14 @@ def create_evj_command(args) -> int:
     if not evaluator_id or not evaluator_resource_name:
         return 1
 
+    # 1.5) Handle secrets upload (with double verification for existing secrets)
+    env_file = getattr(args, "env_file", None)
+    handle_secrets_upload(
+        project_root=project_root,
+        env_file=env_file,
+        non_interactive=non_interactive,
+    )
+
     # 2) Resolve input dataset source (id or JSONL path)
     input_dataset_id, input_dataset_resource, dataset_jsonl = resolve_dataset(
         project_root=project_root,
diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py
index b1550dbf..dd8d361c 100644
--- a/eval_protocol/cli_commands/create_rft.py
+++ b/eval_protocol/cli_commands/create_rft.py
@@ -19,6 +19,7 @@
     materialize_dataset_via_builder,
 )
 from ..models import EvaluationRow
+from .secrets import handle_secrets_upload
 from .utils import (
     _build_entry_point,
     _build_trimmed_dataset_id,
@@ -535,6 +536,14 @@ def create_rft_command(args) -> int:
     if not evaluator_id or not evaluator_resource_name:
         return 1
 
+    # 1.5) Handle secrets upload (with double verification
for existing secrets)
+    env_file = getattr(args, "env_file", None)
+    handle_secrets_upload(
+        project_root=project_root,
+        env_file=env_file,
+        non_interactive=non_interactive,
+    )
+
     # 2) Resolve dataset source (id or JSONL path)
     dataset_id, dataset_resource, dataset_jsonl = resolve_dataset(
         project_root=project_root,
diff --git a/eval_protocol/cli_commands/secrets.py b/eval_protocol/cli_commands/secrets.py
new file mode 100644
index 00000000..25278294
--- /dev/null
+++ b/eval_protocol/cli_commands/secrets.py
@@ -0,0 +1,388 @@
+"""
+Secret handling module for ep CLI commands.
+
+This module provides reusable functions for loading, selecting, and uploading
+secrets to Fireworks. Used by 'ep upload', 'ep create rft', and 'ep create evj'.
+"""
+
+import os
+import sys
+from typing import Dict
+
+from eval_protocol.auth import get_fireworks_api_key
+from eval_protocol.platform_api import create_or_update_fireworks_secret, get_fireworks_secret
+from .utils import _ensure_account_id, _get_questionary_style
+
+
+def load_secrets_from_env_file(env_file_path: str) -> Dict[str, str]:
+    """
+    Load secrets from a .env file that should be uploaded to Fireworks.
+    """
+    if not os.path.exists(env_file_path):
+        return {}
+
+    # Load the .env file into a temporary environment
+    env_vars = {}
+    with open(env_file_path, "r") as f:
+        for line in f:
+            line = line.strip()
+            if line and not line.startswith("#") and "=" in line:
+                key, value = line.split("=", 1)
+                key = key.strip()
+                value = value.strip().strip('"').strip("'")  # Remove quotes
+                env_vars[key] = value
+    return env_vars
+
+
+def mask_secret_value(value: str) -> str:
+    """
+    Return a masked representation of a secret showing only a small prefix/suffix.
+    Example: fw_3Za***Xgnk
+    """
+    try:
+        if not isinstance(value, str) or not value:
+            return ""
+        prefix_len = 6
+        suffix_len = 4
+        if len(value) <= prefix_len + suffix_len:
+            return value[0] + "***" + value[-1]
+        return f"{value[:prefix_len]}***{value[-suffix_len:]}"
+    except Exception:
+        return ""
+
+
+def check_existing_secrets(
+    secrets: Dict[str, str],
+    account_id: str,
+) -> set[str]:
+    """
+    Check which secrets already exist on Fireworks.
+    Returns a set of secret names that already exist.
+    """
+    existing = set()
+    for secret_name in secrets.keys():
+        secret = get_fireworks_secret(account_id=account_id, key_name=secret_name)
+        if secret is not None:
+            existing.add(secret_name)
+    return existing
+
+
+def prompt_select_secrets(
+    secrets: Dict[str, str],
+    secrets_from_env_file: Dict[str, str],
+    existing_secrets: set[str],
+    non_interactive: bool,
+) -> Dict[str, str]:
+    """
+    Prompt user to select which environment variables to upload as secrets.
+    Existing secrets are unchecked by default and marked with [exists].
+    Returns the selected secrets.
+ """ + if not secrets: + return {} + + if non_interactive: + # In non-interactive mode, only return new secrets (skip existing ones) + return {k: v for k, v in secrets.items() if k not in existing_secrets} + + # Check if running in a non-TTY environment (e.g., CI/CD) + if not sys.stdin.isatty(): + # In non-TTY, only return new secrets (skip existing ones) + return {k: v for k, v in secrets.items() if k not in existing_secrets} + + try: + import questionary + + custom_style = _get_questionary_style() + + # Build choices with source info and masked values + # Existing secrets are unchecked by default + choices = [] + for key, value in secrets.items(): + source = ".env" if key in secrets_from_env_file else "env" + masked = mask_secret_value(value) + is_existing = key in existing_secrets + exists_marker = " [exists]" if is_existing else "" + label = f"{key}{exists_marker} ({source}: {masked})" + # Uncheck existing secrets by default + choices.append(questionary.Choice(title=label, value=key, checked=not is_existing)) + + if len(choices) == 0: + return {} + + print("\nFound environment variables to upload as Fireworks secrets:") + if existing_secrets: + print(" (Secrets marked [exists] are unchecked - selecting them will override)") + selected_keys = questionary.checkbox( + "Select secrets to upload:", + choices=choices, + style=custom_style, + pointer=">", + instruction="(↑↓ move, space select, enter confirm)", + ).ask() + + if selected_keys is None: + # User cancelled with Ctrl+C + print("\nSecret upload cancelled.") + return {} + + return {k: v for k, v in secrets.items() if k in selected_keys} + + except ImportError: + # Fallback to simple text-based selection + return _prompt_select_secrets_fallback(secrets, secrets_from_env_file, existing_secrets) + except KeyboardInterrupt: + print("\n\nSecret upload cancelled.") + return {} + + +def _prompt_select_secrets_fallback( + secrets: Dict[str, str], + secrets_from_env_file: Dict[str, str], + existing_secrets: set[str], +) -> Dict[str, str]: + """Fallback prompt selection for when questionary is not available.""" + print("\n" + "=" * 60) + print("Found environment variables to upload as Fireworks secrets:") + print("=" * 60) + print("\nTip: Install questionary for better UX: pip install questionary\n") + + secret_list = list(secrets.items()) + new_indices = [] + for idx, (key, value) in enumerate(secret_list, 1): + source = ".env" if key in secrets_from_env_file else "env" + masked = mask_secret_value(value) + is_existing = key in existing_secrets + exists_marker = " [exists]" if is_existing else "" + print(f" [{idx}] {key}{exists_marker} ({source}: {masked})") + if not is_existing: + new_indices.append(idx) + + print("\n" + "=" * 60) + if existing_secrets: + print("Note: Secrets marked [exists] will be overridden if selected.") + default_selection = ",".join(str(i) for i in new_indices) if new_indices else "none" + print("Enter numbers (comma-separated), 'all' for all, or 'none' to skip.") + print(f"Default (new secrets only): {default_selection}") + + try: + choice = input("Selection: ").strip().lower() + except KeyboardInterrupt: + print("\nSecret upload cancelled.") + return {} + + if not choice: + # Default: only new secrets + choice = default_selection + + if choice == "none": + return {} + + if choice == "all": + return secrets + + try: + indices = [int(x.strip()) for x in choice.split(",")] + selected = {} + for idx in indices: + if 1 <= idx <= len(secret_list): + key, value = secret_list[idx - 1] + selected[key] = value + return 
selected + except ValueError: + print("Invalid input. Skipping secret upload.") + return {} + + +def confirm_override_secrets( + secrets_to_override: set[str], + non_interactive: bool, +) -> set[str]: + """ + Prompt user to confirm overriding existing secrets (double verification). + Returns the set of secrets confirmed for override. + """ + if not secrets_to_override: + return set() + + if non_interactive: + # In non-interactive mode, skip overriding existing secrets by default + print(f"\n⚠️ Skipping {len(secrets_to_override)} existing secret(s) in non-interactive mode.") + print(" Use interactive mode to confirm overriding existing secrets.") + return set() + + # Check if running in a non-TTY environment + if not sys.stdin.isatty(): + print(f"\n⚠️ Skipping {len(secrets_to_override)} existing secret(s) (non-TTY environment).") + return set() + + print(f"\n⚠️ The following {len(secrets_to_override)} secret(s) already exist on Fireworks:") + for name in sorted(secrets_to_override): + print(f" • {name}") + + try: + import questionary + + custom_style = _get_questionary_style() + + # First confirmation + confirm1 = questionary.confirm( + "Do you want to override these existing secrets?", + default=False, + style=custom_style, + ).ask() + + if confirm1 is None or not confirm1: + print("Override cancelled. Existing secrets will be skipped.") + return set() + + # Second confirmation (double verification) + confirm2 = questionary.confirm( + "⚠️ Are you SURE? This will permanently overwrite the existing secret values.", + default=False, + style=custom_style, + ).ask() + + if confirm2 is None or not confirm2: + print("Override cancelled. Existing secrets will be skipped.") + return set() + + return secrets_to_override + + except ImportError: + # Fallback to simple text-based confirmation + return _confirm_override_secrets_fallback(secrets_to_override) + except KeyboardInterrupt: + print("\n\nOverride cancelled.") + return set() + + +def _confirm_override_secrets_fallback(secrets_to_override: set[str]) -> set[str]: + """Fallback confirmation for when questionary is not available.""" + print("\n" + "=" * 60) + print("WARNING: Confirm override of existing secrets") + print("=" * 60) + + try: + # First confirmation + response1 = input("Do you want to override these existing secrets? (yes/no): ").strip().lower() + if response1 not in ("yes", "y"): + print("Override cancelled. Existing secrets will be skipped.") + return set() + + # Second confirmation (double verification) + response2 = input("⚠️ Are you SURE? Type 'override' to confirm: ").strip().lower() + if response2 != "override": + print("Override cancelled. Existing secrets will be skipped.") + return set() + + return secrets_to_override + except KeyboardInterrupt: + print("\n\nOverride cancelled.") + return set() + + +def handle_secrets_upload( + project_root: str, + env_file: str | None, + non_interactive: bool, +) -> None: + """ + Main entry point for handling secrets upload flow. + + This function: + 1. Loads secrets from .env file and environment + 2. Checks which secrets already exist on Fireworks + 3. Prompts user to select secrets (existing ones unchecked by default) + 4. Requires double confirmation for overriding existing secrets + 5. Uploads the selected/confirmed secrets + + Args: + project_root: Path to the project root directory. + env_file: Optional path to a specific .env file (overrides default). + non_interactive: If True, skip prompts and only upload new secrets. 
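+
+    Example (a minimal sketch with illustrative arguments):
+
+        handle_secrets_upload(
+            project_root=".",
+            env_file=None,  # None falls back to <project_root>/.env
+            non_interactive=False,
+        )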
+ """ + try: + fw_account_id = _ensure_account_id() + + # Determine .env file path + if env_file: + env_file_path = env_file + else: + env_file_path = os.path.join(project_root, ".env") + + # Load secrets from .env file + secrets_from_file = load_secrets_from_env_file(env_file_path) + secrets_from_env_file = secrets_from_file.copy() # Track what came from .env file + + # Also consider FIREWORKS_API_KEY from environment, but prefer .env value + fw_api_key_value = get_fireworks_api_key() + if fw_api_key_value and "FIREWORKS_API_KEY" not in secrets_from_file: + secrets_from_file["FIREWORKS_API_KEY"] = fw_api_key_value + + if fw_account_id and secrets_from_file: + if secrets_from_env_file and os.path.exists(env_file_path): + print(f"Loading secrets from: {env_file_path}") + + # Check which secrets already exist on Fireworks BEFORE prompting + print("Checking for existing secrets on Fireworks...") + existing_secrets = check_existing_secrets(secrets_from_file, fw_account_id) + + # Prompt user to select which secrets to upload + # Existing secrets are unchecked by default + selected_secrets = prompt_select_secrets( + secrets_from_file, + secrets_from_env_file, + existing_secrets, + non_interactive, + ) + + if selected_secrets: + # Separate new secrets from existing ones that user explicitly selected + new_secrets = {k: v for k, v in selected_secrets.items() if k not in existing_secrets} + secrets_needing_override = {k: v for k, v in selected_secrets.items() if k in existing_secrets} + + # Confirm override for existing secrets (double verification) + confirmed_overrides: set[str] = set() + if secrets_needing_override: + confirmed_overrides = confirm_override_secrets( + set(secrets_needing_override.keys()), + non_interactive, + ) + + # Build final list of secrets to upload + secrets_to_upload = dict(new_secrets) + for name in confirmed_overrides: + secrets_to_upload[name] = secrets_needing_override[name] + + if not secrets_to_upload: + print("No secrets to upload (existing secrets were not confirmed for override).") + else: + print(f"\nUploading {len(secrets_to_upload)} secret(s) to Fireworks...") + for secret_name, secret_value in secrets_to_upload.items(): + source = ".env" if secret_name in secrets_from_env_file else "environment" + action = "Overriding" if secret_name in confirmed_overrides else "Creating" + print(f"{action} {secret_name} on Fireworks... ({source}: {mask_secret_value(secret_value)})") + if create_or_update_fireworks_secret( + account_id=fw_account_id, + key_name=secret_name, + secret_value=secret_value, + ): + print( + f"✓ {secret_name} secret {'updated' if secret_name in confirmed_overrides else 'created'} on Fireworks." + ) + else: + print( + f"Warning: Failed to {'update' if secret_name in confirmed_overrides else 'create'} {secret_name} secret on Fireworks." + ) + else: + print("No secrets selected for upload.") + else: + if not fw_account_id: + print( + "Warning: Could not resolve Fireworks account id from FIREWORKS_API_KEY; cannot register secrets." 
+ ) + if not secrets_from_file: + print("Warning: No API keys found in environment or .env file; no secrets to register.") + except Exception as e: + print(f"Warning: Skipped Fireworks secret registration due to error: {e}") diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index 5ecd4594..87046ab2 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -1,25 +1,18 @@ import argparse -from eval_protocol.cli_commands.utils import DiscoveredTest import os import re import sys from pathlib import Path -from typing import Any, Dict - -from eval_protocol.auth import get_fireworks_api_key -from eval_protocol.platform_api import create_or_update_fireworks_secret, get_fireworks_secret +from eval_protocol.cli_commands.utils import DiscoveredTest from eval_protocol.evaluation import create_evaluation +from .secrets import handle_secrets_upload from .utils import ( _build_entry_point, _build_evaluator_dashboard_url, _discover_and_select_tests, - _discover_tests, - _ensure_account_id, - _get_questionary_style, load_module_from_file_path, _normalize_evaluator_id, - _prompt_select, ) @@ -130,274 +123,6 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]: return qualname, os.path.abspath(source_file_path) if source_file_path else "" -def _load_secrets_from_env_file(env_file_path: str) -> Dict[str, str]: - """ - Load secrets from a .env file that should be uploaded to Fireworks. - """ - if not os.path.exists(env_file_path): - return {} - - # Load the .env file into a temporary environment - env_vars = {} - with open(env_file_path, "r") as f: - for line in f: - line = line.strip() - if line and not line.startswith("#") and "=" in line: - key, value = line.split("=", 1) - key = key.strip() - value = value.strip().strip('"').strip("'") # Remove quotes - env_vars[key] = value - return env_vars - - -def _mask_secret_value(value: str) -> str: - """ - Return a masked representation of a secret showing only a small prefix/suffix. - Example: fw_3Z*******Xgnk - """ - try: - if not isinstance(value, str) or not value: - return "" - prefix_len = 6 - suffix_len = 4 - if len(value) <= prefix_len + suffix_len: - return value[0] + "***" + value[-1] - return f"{value[:prefix_len]}***{value[-suffix_len:]}" - except Exception: - return "" - - -def _prompt_select_secrets( - secrets: Dict[str, str], - secrets_from_env_file: Dict[str, str], - existing_secrets: set[str], - non_interactive: bool, -) -> Dict[str, str]: - """ - Prompt user to select which environment variables to upload as secrets. - Existing secrets are unchecked by default and marked with [exists]. - Returns the selected secrets. 
- """ - if not secrets: - return {} - - if non_interactive: - # In non-interactive mode, only return new secrets (skip existing ones) - return {k: v for k, v in secrets.items() if k not in existing_secrets} - - # Check if running in a non-TTY environment (e.g., CI/CD) - if not sys.stdin.isatty(): - # In non-TTY, only return new secrets (skip existing ones) - return {k: v for k, v in secrets.items() if k not in existing_secrets} - - try: - import questionary - - custom_style = _get_questionary_style() - - # Build choices with source info and masked values - # Existing secrets are unchecked by default - choices = [] - for key, value in secrets.items(): - source = ".env" if key in secrets_from_env_file else "env" - masked = _mask_secret_value(value) - is_existing = key in existing_secrets - exists_marker = " [exists]" if is_existing else "" - label = f"{key}{exists_marker} ({source}: {masked})" - # Uncheck existing secrets by default - choices.append(questionary.Choice(title=label, value=key, checked=not is_existing)) - - if len(choices) == 0: - return {} - - print("\nFound environment variables to upload as Fireworks secrets:") - if existing_secrets: - print(" (Secrets marked [exists] are unchecked - selecting them will override)") - selected_keys = questionary.checkbox( - "Select secrets to upload:", - choices=choices, - style=custom_style, - pointer=">", - instruction="(↑↓ move, space select, enter confirm)", - ).ask() - - if selected_keys is None: - # User cancelled with Ctrl+C - print("\nSecret upload cancelled.") - return {} - - return {k: v for k, v in secrets.items() if k in selected_keys} - - except ImportError: - # Fallback to simple text-based selection - return _prompt_select_secrets_fallback(secrets, secrets_from_env_file, existing_secrets) - except KeyboardInterrupt: - print("\n\nSecret upload cancelled.") - return {} - - -def _prompt_select_secrets_fallback( - secrets: Dict[str, str], - secrets_from_env_file: Dict[str, str], - existing_secrets: set[str], -) -> Dict[str, str]: - """Fallback prompt selection for when questionary is not available.""" - print("\n" + "=" * 60) - print("Found environment variables to upload as Fireworks secrets:") - print("=" * 60) - print("\nTip: Install questionary for better UX: pip install questionary\n") - - secret_list = list(secrets.items()) - new_indices = [] - for idx, (key, value) in enumerate(secret_list, 1): - source = ".env" if key in secrets_from_env_file else "env" - masked = _mask_secret_value(value) - is_existing = key in existing_secrets - exists_marker = " [exists]" if is_existing else "" - print(f" [{idx}] {key}{exists_marker} ({source}: {masked})") - if not is_existing: - new_indices.append(idx) - - print("\n" + "=" * 60) - if existing_secrets: - print("Note: Secrets marked [exists] will be overridden if selected.") - default_selection = ",".join(str(i) for i in new_indices) if new_indices else "none" - print("Enter numbers (comma-separated), 'all' for all, or 'none' to skip.") - print(f"Default (new secrets only): {default_selection}") - - try: - choice = input("Selection: ").strip().lower() - except KeyboardInterrupt: - print("\nSecret upload cancelled.") - return {} - - if not choice: - # Default: only new secrets - choice = default_selection - - if choice == "none": - return {} - - if choice == "all": - return secrets - - try: - indices = [int(x.strip()) for x in choice.split(",")] - selected = {} - for idx in indices: - if 1 <= idx <= len(secret_list): - key, value = secret_list[idx - 1] - selected[key] = value - return 
selected - except ValueError: - print("Invalid input. Skipping secret upload.") - return {} - - -def _check_existing_secrets( - secrets: Dict[str, str], - account_id: str, -) -> set[str]: - """ - Check which secrets already exist on Fireworks. - Returns a set of secret names that already exist. - """ - existing = set() - for secret_name in secrets.keys(): - secret = get_fireworks_secret(account_id=account_id, key_name=secret_name) - if secret is not None: - existing.add(secret_name) - return existing - - -def _confirm_override_secrets( - secrets_to_override: set[str], - non_interactive: bool, -) -> set[str]: - """ - Prompt user to confirm overriding existing secrets. - Returns the set of secrets confirmed for override. - """ - if not secrets_to_override: - return set() - - if non_interactive: - # In non-interactive mode, skip overriding existing secrets by default - print(f"\n⚠️ Skipping {len(secrets_to_override)} existing secret(s) in non-interactive mode.") - print(" Use interactive mode to confirm overriding existing secrets.") - return set() - - # Check if running in a non-TTY environment - if not sys.stdin.isatty(): - print(f"\n⚠️ Skipping {len(secrets_to_override)} existing secret(s) (non-TTY environment).") - return set() - - print(f"\n⚠️ The following {len(secrets_to_override)} secret(s) already exist on Fireworks:") - for name in sorted(secrets_to_override): - print(f" • {name}") - - try: - import questionary - - custom_style = _get_questionary_style() - - # First confirmation - confirm1 = questionary.confirm( - "Do you want to override these existing secrets?", - default=False, - style=custom_style, - ).ask() - - if confirm1 is None or not confirm1: - print("Override cancelled. Existing secrets will be skipped.") - return set() - - # Second confirmation (double verification) - confirm2 = questionary.confirm( - "⚠️ Are you SURE? This will permanently overwrite the existing secret values.", - default=False, - style=custom_style, - ).ask() - - if confirm2 is None or not confirm2: - print("Override cancelled. Existing secrets will be skipped.") - return set() - - return secrets_to_override - - except ImportError: - # Fallback to simple text-based confirmation - return _confirm_override_secrets_fallback(secrets_to_override) - except KeyboardInterrupt: - print("\n\nOverride cancelled.") - return set() - - -def _confirm_override_secrets_fallback(secrets_to_override: set[str]) -> set[str]: - """Fallback confirmation for when questionary is not available.""" - print("\n" + "=" * 60) - print("WARNING: Confirm override of existing secrets") - print("=" * 60) - - try: - # First confirmation - response1 = input("Do you want to override these existing secrets? (yes/no): ").strip().lower() - if response1 not in ("yes", "y"): - print("Override cancelled. Existing secrets will be skipped.") - return set() - - # Second confirmation (double verification) - response2 = input("⚠️ Are you SURE? Type 'override' to confirm: ").strip().lower() - if response2 != "override": - print("Override cancelled. 
Existing secrets will be skipped.") - return set() - - return secrets_to_override - except KeyboardInterrupt: - print("\n\nOverride cancelled.") - return set() - - def upload_command(args: argparse.Namespace) -> int: root = os.path.abspath(getattr(args, "path", ".")) entries_arg = getattr(args, "entry", None) @@ -419,91 +144,12 @@ def upload_command(args: argparse.Namespace) -> int: description = getattr(args, "description", None) env_file = getattr(args, "env_file", None) - # Load secrets from .env file and ensure they're available on Fireworks - try: - fw_account_id = _ensure_account_id() - - # Determine .env file path - if env_file: - env_file_path = env_file - else: - env_file_path = os.path.join(root, ".env") - - # Load secrets from .env file - secrets_from_file = _load_secrets_from_env_file(env_file_path) - secrets_from_env_file = secrets_from_file.copy() # Track what came from .env file - - # Also consider FIREWORKS_API_KEY from environment, but prefer .env value - fw_api_key_value = get_fireworks_api_key() - if fw_api_key_value and "FIREWORKS_API_KEY" not in secrets_from_file: - secrets_from_file["FIREWORKS_API_KEY"] = fw_api_key_value - - if fw_account_id and secrets_from_file: - if secrets_from_env_file and os.path.exists(env_file_path): - print(f"Loading secrets from: {env_file_path}") - - # Check which secrets already exist on Fireworks BEFORE prompting - print("Checking for existing secrets on Fireworks...") - existing_secrets = _check_existing_secrets(secrets_from_file, fw_account_id) - - # Prompt user to select which secrets to upload - # Existing secrets are unchecked by default - selected_secrets = _prompt_select_secrets( - secrets_from_file, - secrets_from_env_file, - existing_secrets, - non_interactive, - ) - - if selected_secrets: - # Separate new secrets from existing ones that user explicitly selected - new_secrets = {k: v for k, v in selected_secrets.items() if k not in existing_secrets} - secrets_needing_override = {k: v for k, v in selected_secrets.items() if k in existing_secrets} - - # Confirm override for existing secrets (double verification) - confirmed_overrides: set[str] = set() - if secrets_needing_override: - confirmed_overrides = _confirm_override_secrets( - set(secrets_needing_override.keys()), - non_interactive, - ) - - # Build final list of secrets to upload - secrets_to_upload = dict(new_secrets) - for name in confirmed_overrides: - secrets_to_upload[name] = secrets_needing_override[name] - - if not secrets_to_upload: - print("No secrets to upload (existing secrets were not confirmed for override).") - else: - print(f"\nUploading {len(secrets_to_upload)} secret(s) to Fireworks...") - for secret_name, secret_value in secrets_to_upload.items(): - source = ".env" if secret_name in secrets_from_env_file else "environment" - action = "Overriding" if secret_name in confirmed_overrides else "Creating" - print(f"{action} {secret_name} on Fireworks... ({source}: {_mask_secret_value(secret_value)})") - if create_or_update_fireworks_secret( - account_id=fw_account_id, - key_name=secret_name, - secret_value=secret_value, - ): - print( - f"✓ {secret_name} secret {'updated' if secret_name in confirmed_overrides else 'created'} on Fireworks." - ) - else: - print( - f"Warning: Failed to {'update' if secret_name in confirmed_overrides else 'create'} {secret_name} secret on Fireworks." 
- ) - else: - print("No secrets selected for upload.") - else: - if not fw_account_id: - print( - "Warning: Could not resolve Fireworks account id from FIREWORKS_API_KEY; cannot register secrets." - ) - if not secrets_from_file: - print("Warning: No API keys found in environment or .env file; no secrets to register.") - except Exception as e: - print(f"Warning: Skipped Fireworks secret registration due to error: {e}") + # Handle secrets upload (with double verification for existing secrets) + handle_secrets_upload( + project_root=root, + env_file=env_file, + non_interactive=non_interactive, + ) exit_code = 0 for i, (qualname, source_file_path) in enumerate(selected_specs): From a2165fbbc6c117a215bc1f5de362ce7012ef3bb5 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 15:34:04 -0800 Subject: [PATCH 35/39] Remove unused `_to_pyargs_nodeid` function from `upload.py` to enhance code clarity and maintainability. --- eval_protocol/cli_commands/upload.py | 55 ---------------------------- 1 file changed, 55 deletions(-) diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index 87046ab2..8fe8423e 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -16,61 +16,6 @@ ) -def _to_pyargs_nodeid(file_path: str, func_name: str) -> str | None: - """Attempt to build a pytest nodeid suitable for `pytest `. - - Preference order: - 1) Dotted package module path with double-colon: pkg.subpkg.module::func - 2) Filesystem path with double-colon: path/to/module.py::func - - Returns dotted form when package root can be inferred (directory chain with __init__.py - leading up to a directory contained in sys.path). Returns None if no reasonable - nodeid can be created (should be rare). 
- """ - try: - abs_path = os.path.abspath(file_path) - dir_path, filename = os.path.split(abs_path) - module_base, ext = os.path.splitext(filename) - if ext != ".py": - # Not a python file - return None - - # Walk up while packages have __init__.py - segments: list[str] = [module_base] - current = dir_path - package_root = None - while True: - if os.path.isfile(os.path.join(current, "__init__.py")): - segments.insert(0, os.path.basename(current)) - parent = os.path.dirname(current) - # Stop if parent is not within current sys.path import roots - if parent == current: - break - current = parent - else: - package_root = current - break - - # If we found a package chain, check that the package_root is importable (in sys.path) - if package_root and any( - os.path.abspath(sp).rstrip(os.sep) == os.path.abspath(package_root).rstrip(os.sep) for sp in sys.path - ): - dotted = ".".join(segments) - return f"{dotted}::{func_name}" - - # Do not emit a dotted top-level module for non-packages; prefer path-based nodeid - - # Fallback to relative path (if under cwd) or absolute path - cwd = os.getcwd() - try: - rel = os.path.relpath(abs_path, cwd) - except Exception: - rel = abs_path - return f"{rel}::{func_name}" - except Exception: - return None - - def _parse_entry(entry: str, cwd: str) -> tuple[str, str]: # Accept module::function, path::function, or legacy module:function entry = entry.strip() From 1445d753d949b4aeb1e3f931ba049185ad889bf3 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 14 Jan 2026 10:23:12 -0800 Subject: [PATCH 36/39] increase sql retries --- eval_protocol/event_bus/sqlite_event_bus_database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 5086d6e3..59a026ed 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -11,8 +11,8 @@ # Retry configuration for database operations -SQLITE_RETRY_MAX_TRIES = 5 -SQLITE_RETRY_MAX_TIME = 30 # seconds +SQLITE_RETRY_MAX_TRIES = 10 +SQLITE_RETRY_MAX_TIME = 60 # seconds def _is_database_locked_error(e: Exception) -> bool: From 7969a6ec542b87a8a21d0f39fb8d24cd46fd99f9 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 15:50:35 -0800 Subject: [PATCH 37/39] Refactor secret loading in CLI to use python-dotenv - Replaced manual parsing of .env files with `dotenv_values()` for improved handling of comments, quotes, and multi-line values. - Updated `load_secrets_from_env_file` function to return a filtered dictionary of environment variables, enhancing code clarity and maintainability. 
--- eval_protocol/cli_commands/secrets.py | 26 +-- tests/cli_commands/test_secrets.py | 217 ++++++++++++++++++++++++++ 2 files changed, 231 insertions(+), 12 deletions(-) create mode 100644 tests/cli_commands/test_secrets.py diff --git a/eval_protocol/cli_commands/secrets.py b/eval_protocol/cli_commands/secrets.py index 25278294..0d1789a3 100644 --- a/eval_protocol/cli_commands/secrets.py +++ b/eval_protocol/cli_commands/secrets.py @@ -9,6 +9,8 @@ import sys from typing import Dict +from dotenv import dotenv_values + from eval_protocol.auth import get_fireworks_api_key from eval_protocol.platform_api import create_or_update_fireworks_secret, get_fireworks_secret from .utils import _ensure_account_id, _get_questionary_style @@ -17,24 +19,24 @@ def load_secrets_from_env_file(env_file_path: str) -> Dict[str, str]: """ Load secrets from a .env file that should be uploaded to Fireworks. + + Uses python-dotenv's dotenv_values() for proper parsing of .env files, + which correctly handles: + - End-of-line comments (e.g., KEY=value # comment) + - Quoted values (single and double quotes) + - Escape sequences + - Multi-line values """ if not os.path.exists(env_file_path): return {} - # Load the .env file into a temporary environment - env_vars = {} - with open(env_file_path, "r") as f: - for line in f: - line = line.strip() - if line and not line.startswith("#") and "=" in line: - key, value = line.split("=", 1) - key = key.strip() - value = value.strip().strip('"').strip("'") # Remove quotes - env_vars[key] = value - return env_vars + # Use dotenv_values for proper .env parsing (handles comments, quotes, etc.) + parsed = dotenv_values(env_file_path) + # Filter out None values and convert to Dict[str, str] + return {k: v for k, v in parsed.items() if v is not None} -def mask_secret_value(value: str) -> str: +def mask_secret_value(value: str | None) -> str: """ Return a masked representation of a secret showing only a small prefix/suffix. Example: fw_3Z*******Xgnk diff --git a/tests/cli_commands/test_secrets.py b/tests/cli_commands/test_secrets.py new file mode 100644 index 00000000..3f4d62ea --- /dev/null +++ b/tests/cli_commands/test_secrets.py @@ -0,0 +1,217 @@ +"""Tests for eval_protocol.cli_commands.secrets module.""" + +import os +from pathlib import Path + +import pytest + +from eval_protocol.cli_commands.secrets import load_secrets_from_env_file, mask_secret_value + + +class TestLoadSecretsFromEnvFile: + """Tests for load_secrets_from_env_file function.""" + + def test_basic_key_value(self, tmp_path: Path): + """Test basic KEY=value parsing.""" + env_file = tmp_path / ".env" + env_file.write_text("MY_KEY=myvalue123\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "myvalue123"} + + def test_end_of_line_comment(self, tmp_path: Path): + """Test that end-of-line comments are properly stripped. + + This was a bug where manual parsing included the comment in the value. 
+ """ + env_file = tmp_path / ".env" + env_file.write_text("MY_API_KEY=test_dummy_value_abc123 # this is a comment\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_API_KEY": "test_dummy_value_abc123"} + # Ensure the comment is NOT included + assert "# this" not in result["MY_API_KEY"] + assert "comment" not in result["MY_API_KEY"] + + def test_end_of_line_comment_no_space(self, tmp_path: Path): + """Test end-of-line comment without space before #.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY=value#this is a comment\n") + + result = load_secrets_from_env_file(str(env_file)) + + # python-dotenv treats # without space as part of value + # unless the value is quoted. This is the expected behavior. + assert result == {"KEY": "value#this is a comment"} + + def test_quoted_value_with_hash(self, tmp_path: Path): + """Test that quoted values preserve # character.""" + env_file = tmp_path / ".env" + env_file.write_text('KEY="value#with#hashes"\n') + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY": "value#with#hashes"} + + def test_double_quoted_values(self, tmp_path: Path): + """Test that double-quoted values are properly unquoted.""" + env_file = tmp_path / ".env" + env_file.write_text('MY_KEY="value with spaces"\n') + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "value with spaces"} + + def test_single_quoted_values(self, tmp_path: Path): + """Test that single-quoted values are properly unquoted.""" + env_file = tmp_path / ".env" + env_file.write_text("MY_KEY='value with spaces'\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "value with spaces"} + + def test_comment_lines_ignored(self, tmp_path: Path): + """Test that full-line comments are ignored.""" + env_file = tmp_path / ".env" + env_file.write_text("# This is a comment\nMY_KEY=myvalue\n# Another comment\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "myvalue"} + + def test_empty_lines_ignored(self, tmp_path: Path): + """Test that empty lines are ignored.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY1=value1\n\n\nKEY2=value2\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY1": "value1", "KEY2": "value2"} + + def test_multiple_keys(self, tmp_path: Path): + """Test parsing multiple key-value pairs.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY1=value1\nKEY2=value2\nKEY3=value3\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY1": "value1", "KEY2": "value2", "KEY3": "value3"} + + def test_file_not_found(self, tmp_path: Path): + """Test that non-existent file returns empty dict.""" + env_file = tmp_path / "nonexistent.env" + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {} + + def test_empty_file(self, tmp_path: Path): + """Test that empty file returns empty dict.""" + env_file = tmp_path / ".env" + env_file.write_text("") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {} + + def test_value_with_equals_sign(self, tmp_path: Path): + """Test that values containing = are properly parsed.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY=value=with=equals\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY": "value=with=equals"} + + def test_quoted_value_with_comment(self, tmp_path: Path): + """Test quoted value followed by a comment.""" + 
env_file = tmp_path / ".env" + env_file.write_text('MY_KEY="myvalue123" # this is a comment\n') + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "myvalue123"} + + def test_complex_env_file(self, tmp_path: Path): + """Test a complex .env file with various formats.""" + env_content = """# Configuration file +# Last updated: 2024-01-15 + +# API Keys +SERVICE_A_KEY=dummy_key_aaa # Service A key +SERVICE_B_KEY="dummy_key_bbb" # Service B key + +# Database settings +DB_HOST=localhost +DB_PORT=5432 +DB_NAME='mydb' + +# Feature flags +ENABLE_FEATURE=true # enable new feature + +# Empty values are skipped +""" + env_file = tmp_path / ".env" + env_file.write_text(env_content) + + result = load_secrets_from_env_file(str(env_file)) + + assert result["SERVICE_A_KEY"] == "dummy_key_aaa" + assert result["SERVICE_B_KEY"] == "dummy_key_bbb" + assert result["DB_HOST"] == "localhost" + assert result["DB_PORT"] == "5432" + assert result["DB_NAME"] == "mydb" + assert result["ENABLE_FEATURE"] == "true" + # Ensure no comments leaked into values + assert "# Service" not in result.get("SERVICE_A_KEY", "") + assert "# Service" not in result.get("SERVICE_B_KEY", "") + + def test_export_prefix_handled(self, tmp_path: Path): + """Test that 'export' prefix is handled (if supported by dotenv).""" + env_file = tmp_path / ".env" + env_file.write_text("export MY_KEY=myvalue123\n") + + result = load_secrets_from_env_file(str(env_file)) + + # python-dotenv handles 'export' prefix + assert result == {"MY_KEY": "myvalue123"} + + +class TestMaskSecretValue: + """Tests for mask_secret_value function.""" + + def test_normal_length_value(self): + """Test masking a normal length secret.""" + result = mask_secret_value("abcdefghijklmnopqrstu") + assert result == "abcdef***rstu" + assert len(result) < len("abcdefghijklmnopqrstu") + + def test_short_value(self): + """Test masking a very short secret.""" + result = mask_secret_value("abc") + assert result == "a***c" + + def test_empty_value(self): + """Test masking an empty value.""" + result = mask_secret_value("") + assert result == "" + + def test_none_value(self): + """Test masking None (edge case).""" + result = mask_secret_value(None) + assert result == "" + + def test_exact_boundary_length(self): + """Test masking value at exactly prefix+suffix length (10 chars).""" + # prefix_len=6, suffix_len=4, so <= 10 chars uses short format + result = mask_secret_value("1234567890") + assert result == "1***0" + + def test_just_over_boundary(self): + """Test masking value just over the boundary.""" + # 11 chars should use the full format + result = mask_secret_value("12345678901") + assert result == "123456***8901" From d4a445b0ad425d4ddfb3c7ecde367c66768eb8f7 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 16:02:27 -0800 Subject: [PATCH 38/39] make connection more robust --- eval_protocol/dataset_logger/sqlite_evaluation_row_store.py | 5 +++-- eval_protocol/event_bus/sqlite_event_bus_database.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index f6a81e1e..2b9885c7 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -42,9 +42,10 @@ class EvaluationRow(BaseModel): # type: ignore self._EvaluationRow = EvaluationRow - self._db.connect() + # Wrap connect() in retry logic since setting pragmas can fail 
with "database is locked" + execute_with_sqlite_retry(lambda: self._db.connect(reuse_if_open=True)) # Use safe=True to avoid errors when tables/indexes already exist - self._db.create_tables([EvaluationRow], safe=True) + execute_with_sqlite_retry(lambda: self._db.create_tables([EvaluationRow], safe=True)) @property def db_path(self) -> str: diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 59a026ed..122fbac9 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -181,9 +181,10 @@ class Event(BaseModel): # type: ignore processed = BooleanField(default=False) # Track if event has been processed self._Event = Event - self._db.connect() + # Wrap connect() in retry logic since setting pragmas can fail with "database is locked" + execute_with_sqlite_retry(lambda: self._db.connect(reuse_if_open=True)) # Use safe=True to avoid errors when tables already exist - self._db.create_tables([Event], safe=True) + execute_with_sqlite_retry(lambda: self._db.create_tables([Event], safe=True)) def publish_event(self, event_type: str, data: Any, process_id: str) -> None: """Publish an event to the database.""" From 37f48567576fd9b4603257ae126ec25ede312ae5 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Thu, 15 Jan 2026 17:05:20 -0800 Subject: [PATCH 39/39] passes --- tests/test_cli_create_rft.py | 37 ++++++++++++++++++------------------ tests/test_ep_upload_e2e.py | 19 ++++++++++-------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index ef4e7288..1c0df26c 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -8,6 +8,7 @@ from eval_protocol.cli_commands import create_rft as cr from eval_protocol.cli_commands import upload as upload_mod +from eval_protocol.cli_commands import local_test as local_test_mod import eval_protocol.fireworks_rft as fr from eval_protocol.cli import parse_args import eval_protocol.cli_commands.utils as cli_utils @@ -103,7 +104,7 @@ def rft_test_harness(tmp_path, monkeypatch, stub_fireworks): # Account id is derived from API key; mock the verify call to keep tests offline. 
From 37f48567576fd9b4603257ae126ec25ede312ae5 Mon Sep 17 00:00:00 2001
From: Dylan Huang
Date: Thu, 15 Jan 2026 17:05:20 -0800
Subject: [PATCH 39/39] Fix test patch targets so the suite passes

- Updated tests to patch `_discover_tests`, `_discover_and_select_tests`, and
  `_prompt_select` on `eval_protocol.cli_commands.utils`, where these helpers
  now live, instead of on the `create_rft` and `upload` modules.
- Updated tests to patch `run_evaluator_test` on
  `eval_protocol.cli_commands.local_test`.
---
 tests/test_cli_create_rft.py | 37 ++++++++++++++++++-------------------
 tests/test_ep_upload_e2e.py  | 19 ++++++++++--------
 2 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py
index ef4e7288..1c0df26c 100644
--- a/tests/test_cli_create_rft.py
+++ b/tests/test_cli_create_rft.py
@@ -8,6 +8,7 @@
 
 from eval_protocol.cli_commands import create_rft as cr
 from eval_protocol.cli_commands import upload as upload_mod
+from eval_protocol.cli_commands import local_test as local_test_mod
 import eval_protocol.fireworks_rft as fr
 from eval_protocol.cli import parse_args
 import eval_protocol.cli_commands.utils as cli_utils
@@ -103,7 +104,7 @@ def rft_test_harness(tmp_path, monkeypatch, stub_fireworks):
     # Account id is derived from API key; mock the verify call to keep tests offline.
     monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123")
-    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(cli_utils, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
     monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
     monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True)
     monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True)
@@ -225,7 +226,7 @@ def test_create_rft_evaluator_validation_fails(rft_test_harness, monkeypatch):
     test_file.parent.mkdir(parents=True, exist_ok=True)
     test_file.write_text("# dummy eval test", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_eval_validation", file_path=str(test_file))
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
 
     # Force local evaluator validation to fail
     calls = {"count": 0, "pytest_target": None}
@@ -235,7 +236,7 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_
         calls["pytest_target"] = pytest_target
         return 1  # non-zero exit code => validation failure
 
-    monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test)
+    monkeypatch.setattr(local_test_mod, "run_evaluator_test", _fake_run_evaluator_test)
 
     args = argparse.Namespace(
         evaluator=None,
@@ -284,7 +285,7 @@ def test_create_rft_evaluator_validation_passes(rft_test_harness, monkeypatch):
     test_file.parent.mkdir(parents=True, exist_ok=True)
     test_file.write_text("# dummy ok eval test", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_eval_ok", file_path=str(test_file))
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
 
     # Force local evaluator validation to succeed
     calls = {"count": 0, "pytest_target": None}
@@ -294,7 +295,7 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_
         calls["pytest_target"] = pytest_target
         return 0  # success
 
-    monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test)
+    monkeypatch.setattr(local_test_mod, "run_evaluator_test", _fake_run_evaluator_test)
 
     args = argparse.Namespace(
         evaluator=None,
@@ -442,8 +443,8 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_
     one_file.write_text("# single", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_single", file_path=str(one_file))
     # New flow uses _discover_and_select_tests; patch it to return our single test.
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
-    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
     monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
     monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True)
 
@@ -505,7 +506,7 @@ def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(r
     # Fake discovered tests: foo and bar
     cal_disc = SimpleNamespace(qualname="foo_eval.test_bar_evaluation", file_path=str(cal_file))
     svg_disc = SimpleNamespace(qualname="bar_eval.test_baz_evaluation", file_path=str(svg_file))
-    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [cal_disc, svg_disc])
+    monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [cal_disc, svg_disc])
 
     # Capture dataset id used during dataset creation
     captured = {"dataset_id": None}
@@ -572,7 +573,7 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat
     test_file.write_text("# one", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_one", file_path=str(test_file))
     # New flow uses _discover_and_select_tests; patch it to return our single test.
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
 
     # Capture dataset id used during dataset creation
     captured = {"dataset_id": None}
@@ -703,7 +704,7 @@ def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path,
     f2.write_text("# b", encoding="utf-8")
     d1 = SimpleNamespace(qualname="a.test_one", file_path=str(f1))
     d2 = SimpleNamespace(qualname="b.test_two", file_path=str(f2))
-    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2])
+    monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [d1, d2])
 
     args = argparse.Namespace(
         evaluator="some-eval",
@@ -742,9 +743,9 @@ def test_create_rft_fallback_to_dataset_builder(rft_test_harness, monkeypatch):
     test_file.write_text("# builder case", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_builder", file_path=str(test_file))
     # New flow uses _discover_and_select_tests for evaluator resolution; patch it to return our single test.
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
     # Also patch _discover_tests for any direct calls during dataset inference.
-    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [single_disc])
 
     # Dataset builder fallback
     out_jsonl = project / "metric" / "builder_out.jsonl"
@@ -807,7 +808,7 @@ def test_create_rft_rejects_dataloader_jsonl(rft_test_harness, monkeypatch):
     test_file.write_text("# loader case", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_loader", file_path=str(test_file))
     # New flow uses _discover_and_select_tests; patch it to return our single test.
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
 
     # Provide JSONL via dataloader extractor
     dl_jsonl = project / "metric" / "loader_out.jsonl"
@@ -868,7 +869,7 @@ def test_create_rft_uses_input_dataset_jsonl_when_available(rft_test_harness, mo
     test_file.write_text("# input_dataset case", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_input_ds", file_path=str(test_file))
     # New flow uses _discover_and_select_tests; patch it to return our single test.
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
 
     # Provide JSONL via input_dataset extractor
     id_jsonl = project / "metric" / "input_ds_out.jsonl"
@@ -933,7 +934,7 @@ def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(r
     f2.write_text("# beta", encoding="utf-8")
     d1 = SimpleNamespace(qualname="alpha.test_one", file_path=str(f1))
     d2 = SimpleNamespace(qualname="beta.test_two", file_path=str(f2))
-    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2])
+    monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [d1, d2])
 
     # Evaluator upload succeeds and version becomes ACTIVE
     monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True)
@@ -1097,9 +1098,9 @@ def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_h
     test_file.write_text("# prefer explicit dataset_jsonl", encoding="utf-8")
     single_disc = SimpleNamespace(qualname="metric.test_pref", file_path=str(test_file))
     # New flow uses _discover_and_select_tests; patch it to return our single test.
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
 
-    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(cli_utils, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
     monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
     monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True)
@@ -1203,7 +1204,7 @@ def test_adapt(row: EvaluationRow) -> EvaluationRow:
 
     # Discovery: exactly one test, and resolve_selected_test points to our module/function
     single_disc = SimpleNamespace(qualname="metric.test_adapt.test_adapt", file_path=str(test_file))
-    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
+    monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
     monkeypatch.setattr(
         cr,
         "_resolve_selected_test",
diff --git a/tests/test_ep_upload_e2e.py b/tests/test_ep_upload_e2e.py
index e76ac246..aadfd0b2 100644
--- a/tests/test_ep_upload_e2e.py
+++ b/tests/test_ep_upload_e2e.py
@@ -151,7 +151,8 @@ def test_ep_upload_discovers_and_uploads_evaluation_test(
     - Upload via upload_command
     - Verify all API calls
     """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
+    from eval_protocol.cli_commands.upload import upload_command
+    from eval_protocol.cli_commands.utils import _discover_tests
 
     # 1. CREATE TEST PROJECT STRUCTURE
     test_content = """
@@ -211,7 +212,7 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow:
     )
 
     # Mock the selection (auto-select the discovered test)
-    with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
+    with patch("eval_protocol.cli_commands.utils._prompt_select") as mock_select:
         mock_select.return_value = discovered_tests
 
         # Execute upload command
@@ -280,7 +281,8 @@ def test_ep_upload_with_parametrized_test(
     Test ep upload with a parametrized @evaluation_test
     Verifies that parametrized tests are discovered and uploaded as single evaluator
     """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
+    from eval_protocol.cli_commands.upload import upload_command
+    from eval_protocol.cli_commands.utils import _discover_tests
 
     test_content = """
import pytest
@@ -327,7 +329,7 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow:
         yes=True,
     )
 
-    with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
+    with patch("eval_protocol.cli_commands.utils._prompt_select") as mock_select:
         mock_select.return_value = discovered_tests
 
         exit_code = upload_command(args)
@@ -352,7 +354,7 @@ def test_ep_upload_discovery_skips_problematic_files(mock_env_variables):
     Test that discovery properly skips files like setup.py, versioneer.py
     that would cause issues during pytest collection
     """
-    from eval_protocol.cli_commands.upload import _discover_tests
+    from eval_protocol.cli_commands.utils import _discover_tests
 
     test_content = """
from eval_protocol.pytest import evaluation_test
@@ -400,7 +402,7 @@ def test_ep_upload_discovers_non_test_prefixed_files(mock_env_variables):
     Test that discovery finds @evaluation_test in files like quickstart.py
     (files that don't start with 'test_')
     """
-    from eval_protocol.cli_commands.upload import _discover_tests
+    from eval_protocol.cli_commands.utils import _discover_tests
 
     test_content = """
from eval_protocol.pytest import evaluation_test
@@ -450,7 +452,8 @@ def test_ep_upload_complete_workflow_with_entry_point_validation(
     - Full 5-step upload flow
     - Payload structure
     """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
+    from eval_protocol.cli_commands.upload import upload_command
+    from eval_protocol.cli_commands.utils import _discover_tests
 
     test_content = """
from typing import List
@@ -506,7 +509,7 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
         yes=True,
     )
 
-    with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
+    with patch("eval_protocol.cli_commands.utils._prompt_select") as mock_select:
         mock_select.return_value = discovered_tests
 
         exit_code = upload_command(args)
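
Note on the secrets tests earlier in this series: the TestMaskSecretValue cases
fully pin down the masking behavior (6-character prefix, 4-character suffix,
and a short format for values of 10 characters or fewer). An implementation
consistent with those tests might look like the following sketch; the
repository's actual helper may differ in its details.

    from typing import Optional

    def mask_secret_value(value: Optional[str]) -> str:
        """Mask a secret while keeping a short identifying prefix and suffix."""
        if not value:
            # Covers both "" and None.
            return ""
        prefix_len, suffix_len = 6, 4
        if len(value) <= prefix_len + suffix_len:
            # Too short to keep 6+4 characters; show only the first and last.
            return f"{value[0]}***{value[-1]}"
        return f"{value[:prefix_len]}***{value[-suffix_len:]}"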