diff --git a/.gitignore b/.gitignore index 7e434174..918f7fbf 100644 --- a/.gitignore +++ b/.gitignore @@ -243,3 +243,5 @@ package.json tau2-bench *.err eval-protocol + +.vscode/launch.json diff --git a/.vscode/.gitignore b/.vscode/.gitignore new file mode 100644 index 00000000..c2dd2a37 --- /dev/null +++ b/.vscode/.gitignore @@ -0,0 +1 @@ +!launch.json.backup diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 38fff2f8..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Debug Tests", - "type": "python", - "request": "launch", - "module": "pytest", - "args": ["-s", "--tb=short", "${file}"], - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - }, - { - "name": "Python: Debug Current File", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - }, - { - "name": "Python: Debug Logs Server", - "type": "python", - "request": "launch", - "module": "eval_protocol.utils.logs_server", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - } - } - ] -} diff --git a/.vscode/launch.json.example b/.vscode/launch.json.example new file mode 100644 index 00000000..7b70e735 --- /dev/null +++ b/.vscode/launch.json.example @@ -0,0 +1,60 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "EP: Upload", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["upload"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + }, + { + "name": "EP: Local Test", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": ["local-test", "--ignore-docker"], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + }, + { + "name": "EP: Create RFT", + "type": "python", + "request": "launch", + "module": "eval_protocol.cli", + "args": [ + "create", + "rft", + "--base-model", + "accounts/fireworks/models/qwen3-0p6b", + "--chunk-size", + "10" + ], + "console": "integratedTerminal", + "justMyCode": false, + "cwd": "", + "env": { + "PYTHONPATH": "${workspaceFolder}", + "FIREWORKS_API_KEY": "${env:FIREWORKS_API_KEY}", + "FIREWORKS_BASE_URL": "${env:FIREWORKS_BASE_URL}", + "FIREWORKS_EXTRA_HEADERS": "{\"x-api-key\": \"${env:FIREWORKS_API_KEY}\", \"X-Fireworks-Gateway-Secret\": \"${env:FIREWORKS_GATEWAY_SECRET}\"}" + } + } + ] +} diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index 68ce134c..40e3c777 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -1,12 +1,75 @@ import logging import os -from typing import Optional +from typing import Dict, Optional import requests +from dotenv import dotenv_values, find_dotenv, load_dotenv logger = 
logging.getLogger(__name__) +def find_dotenv_path(search_path: Optional[str] = None) -> Optional[str]: + """ + Find the .env file path, searching .env.dev first, then .env. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Path to the .env file if found, otherwise None. + """ + # If a specific search path is provided, look there first + if search_path: + env_dev_path = os.path.join(search_path, ".env.dev") + if os.path.isfile(env_dev_path): + return env_dev_path + env_path = os.path.join(search_path, ".env") + if os.path.isfile(env_path): + return env_path + return None + + # Otherwise use find_dotenv to search up the directory tree + env_dev_path = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) + if env_dev_path: + return env_dev_path + env_path = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) + if env_path: + return env_path + return None + + +def get_dotenv_values(search_path: Optional[str] = None) -> Dict[str, Optional[str]]: + """ + Get all key-value pairs from the .env file. + + Args: + search_path: Directory to search from. If None, uses current working directory. + + Returns: + Dictionary of environment variable names to values. + """ + dotenv_path = find_dotenv_path(search_path) + if dotenv_path: + return dotenv_values(dotenv_path) + return {} + + +# --- Load .env files --- +# Attempt to load .env.dev first, then .env as a fallback. +# This happens when the module is imported. +# We use override=False (default) so that existing environment variables +# (e.g., set in the shell) are NOT overridden by .env files. +_DOTENV_PATH = find_dotenv_path() +if _DOTENV_PATH: + load_dotenv(dotenv_path=_DOTENV_PATH, override=False) + logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_DOTENV_PATH}") +else: + logger.debug( + "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables." + ) +# --- End .env loading --- + + def get_fireworks_api_key() -> Optional[str]: """ Retrieves the Fireworks API key. @@ -73,6 +136,8 @@ def verify_api_key_and_get_account_id( Args: api_key: Optional explicit API key. When None, resolves via get_fireworks_api_key(). api_base: Optional explicit API base. When None, resolves via get_fireworks_api_base(). + If api_base is api.fireworks.ai, it is used directly. Otherwise, defaults to + dev.api.fireworks.ai for the verification call. Returns: The resolved account id if verification succeeds and the header is present; otherwise None. 
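# --- Illustrative aside (not part of the patch) ---
# A minimal sketch of the precedence the module-level loading above is meant
# to give: .env.dev shadows .env, and override=False means values already in
# the shell are never clobbered. The file layout here is hypothetical.
import os
import tempfile

from dotenv import load_dotenv

with tempfile.TemporaryDirectory() as root:
    with open(os.path.join(root, ".env.dev"), "w") as f:
        f.write("FIREWORKS_API_BASE=https://dev.api.fireworks.ai\n")

    os.environ["FIREWORKS_API_BASE"] = "https://example.invalid"  # shell value
    load_dotenv(os.path.join(root, ".env.dev"), override=False)
    # override=False: the pre-existing shell value survives.
    assert os.environ["FIREWORKS_API_BASE"] == "https://example.invalid"

    del os.environ["FIREWORKS_API_BASE"]  # no shell value: .env.dev applies
    load_dotenv(os.path.join(root, ".env.dev"), override=False)
    assert os.environ["FIREWORKS_API_BASE"] == "https://dev.api.fireworks.ai"
# --- End aside ---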
@@ -81,7 +146,12 @@ def verify_api_key_and_get_account_id( resolved_key = api_key or get_fireworks_api_key() if not resolved_key: return None - resolved_base = api_base or get_fireworks_api_base() + provided_base = api_base or get_fireworks_api_base() + # Use api.fireworks.ai if explicitly provided, otherwise fall back to dev + if "api.fireworks.ai" in provided_base: + resolved_base = provided_base + else: + resolved_base = "https://dev.api.fireworks.ai" from .common_utils import get_user_agent diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 4222cab9..ea862d70 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -81,13 +81,12 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "--env-file", help="Path to .env file containing secrets to upload (default: .env in current directory)", ) - upload_parser.add_argument( - "--force", - action="store_true", - help="Overwrite existing evaluator with the same ID", - ) # Auto-generate flags from SDK Fireworks().evaluators.create() signature + # Note: We use Fireworks() directly here instead of create_fireworks_client() + # because we only need the method signature for introspection, not a fully + # authenticated client. create_fireworks_client() would trigger an HTTP request + # to verify the API key, causing delays even for --help invocations. create_evaluator_fn = Fireworks().evaluators.create upload_skip_fields = { @@ -137,7 +136,6 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending") - rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID") rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation") rft_parser.add_argument( "--ignore-docker", @@ -198,6 +196,10 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse "loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.", } + # Note: We use Fireworks() directly here instead of create_fireworks_client() + # because we only need the method signature for introspection, not a fully + # authenticated client. create_fireworks_client() would trigger an HTTP request + # to verify the API key, causing delays even for --help invocations. create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create add_args_from_callable_signature( @@ -208,6 +210,78 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse help_overrides=help_overrides, ) + # Create evj (Evaluation Job) subcommand + evj_parser = create_subparsers.add_parser( + "evj", + help="Create an Evaluation Job on Fireworks", + ) + + evj_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") + evj_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending") + evj_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation") + evj_parser.add_argument( + "--ignore-docker", + action="store_true", + help="Ignore Dockerfile even if present; run pytest on host during evaluator validation", + ) + evj_parser.add_argument( + "--docker-build-extra", + default="", + metavar="", + help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. 
\"--no-cache --pull --progress=plain\")", + ) + evj_parser.add_argument( + "--docker-run-extra", + default="", + metavar="", + help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")", + ) + evj_parser.add_argument( + "--quiet", + action="store_true", + help="If set, only errors will be printed.", + ) + + # Auto-generate flags from SDK Fireworks().evaluation_jobs.create() signature + create_evj_fn = Fireworks().evaluation_jobs.create + + evj_skip_fields = { + "__top_level__": { + "account_id", # auto-detected + "extra_headers", + "extra_query", + "extra_body", + "timeout", + }, + "evaluation_job": { + "output_stats", # read-only, set by server + }, + } + evj_aliases = { + "evaluation_job_id": ["--job-id"], + "evaluation_job.evaluator": ["--evaluator"], + "evaluation_job.input_dataset": ["--dataset"], # --input-dataset is auto-added + "evaluation_job.display_name": ["--name"], + # output_dataset, evaluator_version get their short forms auto-added + } + evj_help_overrides = { + "evaluation_job_id": "Evaluation Job ID to use", + "evaluation_job.evaluator": "Evaluator resource name (format: accounts/{account_id}/evaluators/{evaluator_id})", + "evaluation_job.input_dataset": "Input dataset resource name (format: accounts/{account_id}/datasets/{dataset_id})", + "evaluation_job.output_dataset": "Output dataset resource name where results will be written", + "evaluation_job.display_name": "Display name for the evaluation job", + "evaluation_job.evaluator_version": "Specific evaluator version to use (defaults to current version)", + "leaderboard_ids": "Optional leaderboard IDs to attach this job to upon creation", + } + + add_args_from_callable_signature( + evj_parser, + create_evj_fn, + skip_fields=evj_skip_fields, + aliases=evj_aliases, + help_overrides=evj_help_overrides, + ) + # Local test command local_test_parser = subparsers.add_parser( "local-test", @@ -349,7 +423,11 @@ def _extract_flag_value(argv_list, flag_name): from .cli_commands.create_rft import create_rft_command return create_rft_command(args) - print("Error: missing subcommand for 'create'. Try: eval-protocol create rft") + elif args.create_command == "evj": + from .cli_commands.create_evj import create_evj_command + + return create_evj_command(args) + print("Error: missing subcommand for 'create'. 
Try: eval-protocol create rft|evj") return 1 elif args.command == "local-test": from .cli_commands.local_test import local_test_command diff --git a/eval_protocol/cli_commands/create_evj.py b/eval_protocol/cli_commands/create_evj.py new file mode 100644 index 00000000..99319a54 --- /dev/null +++ b/eval_protocol/cli_commands/create_evj.py @@ -0,0 +1,249 @@ +import argparse +from fireworks._client import Fireworks +from fireworks.types.evaluation_job_create_response import EvaluationJobCreateResponse +import json +import os +import inspect +from typing import Any, Dict, Optional + +from ..auth import get_fireworks_api_base, get_fireworks_api_key +from ..fireworks_client import create_fireworks_client +from .secrets import handle_secrets_upload +from .utils import ( + _build_trimmed_dataset_id, + _build_evaluator_dashboard_url, + _ensure_account_id, + _extract_terminal_segment, + resolve_evaluator, + upload_and_ensure_evaluator, + validate_evaluator_locally, +) +from .create_rft import ( + resolve_dataset, + upload_dataset, +) + + +def _resolve_output_dataset( + account_id: str, + output_dataset_arg: Optional[str], + evaluator_id: str, +) -> tuple[str, str]: + """Resolve output dataset id and resource name. + + If not provided, auto-generates an output dataset ID based on the evaluator ID. + """ + if output_dataset_arg: + if output_dataset_arg.startswith("accounts/"): + output_dataset_resource = output_dataset_arg + output_dataset_id = _extract_terminal_segment(output_dataset_arg) + else: + output_dataset_id = output_dataset_arg + output_dataset_resource = f"accounts/{account_id}/datasets/{output_dataset_id}" + else: + # Auto-generate output dataset ID + output_dataset_id = _build_trimmed_dataset_id(f"{evaluator_id}-results") + output_dataset_resource = f"accounts/{account_id}/datasets/{output_dataset_id}" + print(f"Auto-generated output dataset ID: {output_dataset_id}") + + return output_dataset_id, output_dataset_resource + + +def _print_evj_links( + evaluator_id: str, input_dataset_id: str, output_dataset_id: str, job_name: Optional[str] +) -> None: + """Print helpful links to the Fireworks dashboard.""" + print("\nšŸ“Š Links:") + print(f" Evaluator: {_build_evaluator_dashboard_url(evaluator_id)}") + if job_name: + # Extract job id from resource name if present + job_id = _extract_terminal_segment(job_name) if "/" in job_name else job_name + print(f" Evaluation Job: https://fireworks.ai/dashboard/evaluation-jobs/{job_id}") + print(f" Input Dataset: https://fireworks.ai/dashboard/datasets/{input_dataset_id}") + print(f" Output Dataset: https://fireworks.ai/dashboard/datasets/{output_dataset_id}") + + +def _create_evj_job( + account_id: str, + api_key: str, + api_base: str, + evaluator_id: str, + evaluator_resource_name: str, + input_dataset_id: str, + input_dataset_resource: str, + output_dataset_id: str, + output_dataset_resource: str, + args: argparse.Namespace, + dry_run: bool, +) -> int: + """Build and submit the Evaluation Job request (via Fireworks SDK).""" + + signature = inspect.signature(create_fireworks_client().evaluation_jobs.create) + + # Build top-level SDK kwargs + sdk_kwargs: Dict[str, Any] = {} + + # Build the evaluation_job nested object + evaluation_job: Dict[str, Any] = { + "evaluator": evaluator_resource_name, + "input_dataset": input_dataset_resource, + "output_dataset": output_dataset_resource, + } + + args_dict = vars(args) + + # Handle evaluation_job nested fields + for k, v in args_dict.items(): + if v is None: + continue + if k.startswith("evaluation_job_") and k 
!= "evaluation_job_id": + field_name = k[len("evaluation_job_") :] + # Don't overwrite the normalized resources + if field_name in ("evaluator", "input_dataset", "output_dataset"): + continue + evaluation_job[field_name] = v + + sdk_kwargs["evaluation_job"] = evaluation_job + + # Handle top-level fields + for name in signature.parameters: + if name in ("account_id", "evaluation_job", "extra_headers", "extra_query", "extra_body", "timeout"): + continue + + value = args_dict.get(name) + if value is not None: + sdk_kwargs[name] = value + + print(f"Prepared Evaluation Job for evaluator '{evaluator_id}' using dataset '{input_dataset_id}'") + + if dry_run: + print("--dry-run: would call Fireworks().evaluation_jobs.create with kwargs:") + print(json.dumps(sdk_kwargs, indent=2)) + _print_evj_links(evaluator_id, input_dataset_id, output_dataset_id, None) + return 0 + + try: + fw: Fireworks = create_fireworks_client(api_key=api_key, base_url=api_base) + job: EvaluationJobCreateResponse = fw.evaluation_jobs.create(account_id=account_id, **sdk_kwargs) + job_name = job.name + print(f"\nāœ… Created Evaluation Job: {job_name}") + _print_evj_links(evaluator_id, input_dataset_id, output_dataset_id, job_name) + return 0 + except Exception as e: + print(f"Error creating Evaluation Job: {e}") + return 1 + + +def create_evj_command(args) -> int: + """Main entry point for the 'create evj' CLI command.""" + # Pre-flight: resolve auth and environment + api_key = get_fireworks_api_key() + if not api_key: + print("Error: FIREWORKS_API_KEY not set.") + return 1 + + account_id = _ensure_account_id() + if not account_id: + print("Error: Could not resolve Fireworks account id from FIREWORKS_API_KEY.") + return 1 + + api_base = get_fireworks_api_base() + project_root = os.getcwd() + evaluator_arg: Optional[str] = getattr(args, "evaluation_job_evaluator", None) + input_dataset_arg: Optional[str] = getattr(args, "evaluation_job_input_dataset", None) + output_dataset_arg: Optional[str] = getattr(args, "evaluation_job_output_dataset", None) + non_interactive: bool = bool(getattr(args, "yes", False)) + dry_run: bool = bool(getattr(args, "dry_run", False)) + skip_validation: bool = bool(getattr(args, "skip_validation", False)) + ignore_docker: bool = bool(getattr(args, "ignore_docker", False)) + docker_build_extra: str = getattr(args, "docker_build_extra", "") or "" + docker_run_extra: str = getattr(args, "docker_run_extra", "") or "" + + # 1) Resolve evaluator and associated local test + ( + evaluator_id, + evaluator_resource_name, + selected_test_file_path, + selected_test_func_name, + ) = resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id, command_name="create evj") + if not evaluator_id or not evaluator_resource_name: + return 1 + + # 1.5) Handle secrets upload (with double verification for existing secrets) + env_file = getattr(args, "env_file", None) + handle_secrets_upload( + project_root=project_root, + env_file=env_file, + non_interactive=non_interactive, + ) + + # 2) Resolve input dataset source (id or JSONL path) + input_dataset_id, input_dataset_resource, dataset_jsonl = resolve_dataset( + project_root=project_root, + account_id=account_id, + dataset_id_arg=input_dataset_arg, + dataset_jsonl_arg=None, # EVJ doesn't support --dataset-jsonl flag yet + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, + ) + # Require either an existing dataset id or a JSONL source to materialize from + if dataset_jsonl is None and not input_dataset_id: + return 
1 + + # 3) Resolve output dataset + output_dataset_id, output_dataset_resource = _resolve_output_dataset(account_id, output_dataset_arg, evaluator_id) + + # 4) Optional local validation + if not skip_validation: + if not validate_evaluator_locally( + project_root=project_root, + selected_test_file=selected_test_file_path, + selected_test_func=selected_test_func_name, + ignore_docker=ignore_docker, + docker_build_extra=docker_build_extra, + docker_run_extra=docker_run_extra, + ): + return 1 + + # 5) Upload dataset when using JSONL sources (no-op for existing datasets) + input_dataset_id, input_dataset_resource = upload_dataset( + project_root=project_root, + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + dataset_id=input_dataset_id, + dataset_resource=input_dataset_resource, + dataset_jsonl=dataset_jsonl, + dataset_display_name=None, # EVJ auto-generates display name + dry_run=dry_run, + ) + if not input_dataset_id or not input_dataset_resource: + return 1 + + # 6) Ensure evaluator exists and its latest version is ACTIVE (upload + poll if needed) + if not dry_run: + if not upload_and_ensure_evaluator( + project_root=project_root, + evaluator_id=evaluator_id, + api_key=api_key, + api_base=api_base, + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, + ): + return 1 + + # 7) Create the Evaluation Job + return _create_evj_job( + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + evaluator_resource_name=evaluator_resource_name, + input_dataset_id=input_dataset_id, + input_dataset_resource=input_dataset_resource, + output_dataset_id=output_dataset_id, + output_dataset_resource=output_dataset_resource, + args=args, + dry_run=dry_run, + ) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 702eb2fe..dd8d361c 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -7,19 +7,19 @@ import time from typing import Any, Callable, Dict, Optional import inspect -import requests import tempfile from pydantic import ValidationError from ..auth import get_fireworks_api_base, get_fireworks_api_key -from ..common_utils import get_user_agent, load_jsonl +from ..fireworks_client import create_fireworks_client +from ..common_utils import load_jsonl from ..fireworks_rft import ( create_dataset_from_jsonl, detect_dataset_builder, materialize_dataset_via_builder, ) from ..models import EvaluationRow -from .upload import upload_command +from .secrets import handle_secrets_upload from .utils import ( _build_entry_point, _build_trimmed_dataset_id, @@ -29,14 +29,16 @@ _ensure_account_id, _extract_terminal_segment, _normalize_evaluator_id, + _poll_evaluator_version_status, _print_links, _resolve_selected_test, load_module_from_file_path, + resolve_evaluator, + upload_and_ensure_evaluator, + validate_evaluator_locally, ) from .local_test import run_evaluator_test -from fireworks import Fireworks - def _extract_dataset_adapter( test_file_path: str, test_func_name: str @@ -223,67 +225,6 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) return None -def _poll_evaluator_status( - evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10 -) -> bool: - """ - Poll evaluator status until it becomes ACTIVE or times out. 
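# --- Illustrative aside (not part of the patch) ---
# The poll-until-ACTIVE loop being deleted below (and re-homed as
# _poll_evaluator_version_status in cli_commands/utils.py, per the imports
# above) follows a standard bounded-polling shape. A condensed sketch with a
# stubbed status check:
import time


def poll(check, timeout_s: float = 600.0, interval_s: float = 10.0) -> bool:
    """Return True if check() reports ACTIVE before timeout_s elapses."""
    start = time.time()
    while time.time() - start < timeout_s:
        state = check()
        if state == "ACTIVE":
            return True
        if state == "BUILD_FAILED":
            return False
        time.sleep(interval_s)
    return False


# Usage with a stub that becomes ACTIVE on the third probe:
states = iter(["BUILDING", "BUILDING", "ACTIVE"])
assert poll(lambda: next(states), timeout_s=5.0, interval_s=0.01)
# --- End aside ---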
- - Args: - evaluator_resource_name: Full evaluator resource name (e.g., accounts/xxx/evaluators/yyy) - api_key: Fireworks API key - api_base: Fireworks API base URL - timeout_minutes: Maximum time to wait in minutes - - Returns: - True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED - """ - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - - check_url = f"{api_base}/v1/{evaluator_resource_name}" - timeout_seconds = timeout_minutes * 60 - poll_interval = 10 # seconds - start_time = time.time() - - print(f"Polling evaluator status (timeout: {timeout_minutes}m, interval: {poll_interval}s)...") - - while time.time() - start_time < timeout_seconds: - try: - response = requests.get(check_url, headers=headers, timeout=30) - response.raise_for_status() - - evaluator_data = response.json() - state = evaluator_data.get("state", "STATE_UNSPECIFIED") - status = evaluator_data.get("status", "") - - if state == "ACTIVE": - print("āœ… Evaluator is ACTIVE and ready!") - return True - elif state == "BUILD_FAILED": - print(f"āŒ Evaluator build failed. Status: {status}") - return False - elif state == "BUILDING": - elapsed_minutes = (time.time() - start_time) / 60 - print(f"ā³ Evaluator is still building... ({elapsed_minutes:.1f}m elapsed)") - else: - print(f"ā³ Evaluator state: {state}, status: {status}") - - except requests.exceptions.RequestException as e: - print(f"Warning: Failed to check evaluator status: {e}") - - # Wait before next poll - time.sleep(poll_interval) - - # Timeout reached - elapsed_minutes = (time.time() - start_time) / 60 - print(f"ā° Timeout after {elapsed_minutes:.1f}m - evaluator is not yet ACTIVE") - return False - - def _validate_dataset_jsonl(jsonl_path: str, sample_limit: int = 50) -> bool: """Validate that a JSONL file contains rows compatible with EvaluationRow. @@ -334,108 +275,31 @@ def _validate_dataset(dataset_jsonl: Optional[str]) -> bool: return _validate_dataset_jsonl(dataset_jsonl) -def _validate_evaluator_locally( - project_root: str, - selected_test_file: Optional[str], - selected_test_func: Optional[str], - ignore_docker: bool, - docker_build_extra: str, - docker_run_extra: str, -) -> bool: - """Run pytest locally for the selected evaluation test to validate the evaluator. - - The pytest helpers always enforce a small success threshold (0.01) for - evaluation_test-based suites so that an evaluation run where all scores are - 0.0 will naturally fail with a non-zero pytest exit code, which we then treat - as a failed validator. - """ - if not selected_test_file or not selected_test_func: - # No local test associated; skip validation but warn the user. 
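# --- Illustrative aside (not part of the patch) ---
# The local-validation gate here boils down to "run pytest on one node id
# and trust the exit code" (0 means the evaluator scored above the success
# threshold). run_evaluator_test() layers Docker handling on top; this
# sketch shows only the host-side core. The node id below is hypothetical.
import subprocess
import sys


def pytest_gate(pytest_target: str, cwd: str) -> bool:
    """Return True when the targeted evaluation test passes."""
    result = subprocess.run(
        [sys.executable, "-m", "pytest", "-q", pytest_target],
        cwd=cwd,
    )
    return result.returncode == 0


# Usage: pytest_gate("tests/test_my_eval.py::test_my_eval", cwd=".")
# --- End aside ---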
- print("Warning: Could not resolve a local evaluation test for this evaluator; skipping local validation.") - return True - - pytest_target = _build_entry_point(project_root, selected_test_file, selected_test_func) - exit_code = run_evaluator_test( - project_root=project_root, - pytest_target=pytest_target, - ignore_docker=ignore_docker, - docker_build_extra=docker_build_extra, - docker_run_extra=docker_run_extra, - ) - return exit_code == 0 - - -def _resolve_evaluator( +def resolve_dataset( project_root: str, - evaluator_arg: Optional[str], - non_interactive: bool, account_id: str, -) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: - """Resolve evaluator id/resource and associated local test (file + func).""" - evaluator_id = evaluator_arg - selected_test_file_path: Optional[str] = None - selected_test_func_name: Optional[str] = None - - if not evaluator_id: - selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive) - if not selected_tests: - return None, None, None, None - - if len(selected_tests) != 1: - if non_interactive and len(selected_tests) > 1: - print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") - print(" Please pass --evaluator or --entry to disambiguate.") - else: - print("Error: Please select exactly one evaluation test for 'create rft'.") - return None, None, None, None - - chosen = selected_tests[0] - func_name = chosen.qualname.split(".")[-1] - source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0] - evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}") - # Resolve selected test once for downstream - selected_test_file_path, selected_test_func_name = _resolve_selected_test( - project_root, evaluator_id, selected_tests=selected_tests - ) - else: - # Caller provided an evaluator id or fully-qualified resource; try to resolve local test - short_id = evaluator_id - if evaluator_id.startswith("accounts/"): - short_id = _extract_terminal_segment(evaluator_id) - st_path, st_func = _resolve_selected_test(project_root, short_id) - if st_path and st_func: - selected_test_file_path = st_path - selected_test_func_name = st_func - evaluator_id = short_id - - if not evaluator_id: - return None, None, None, None - - # Resolve evaluator resource name to fully-qualified format required by API. - if evaluator_arg and evaluator_arg.startswith("accounts/"): - evaluator_resource_name = evaluator_arg - else: - evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" - - return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name - - -def _resolve_dataset( - project_root: str, - account_id: str, - args: argparse.Namespace, + dataset_id_arg: Optional[str], + dataset_jsonl_arg: Optional[str], selected_test_file_path: Optional[str], selected_test_func_name: Optional[str], ) -> tuple[Optional[str], Optional[str], Optional[str]]: """Resolve dataset source without performing any uploads. + Args: + project_root: Path to the project root directory. + account_id: The Fireworks account ID. + dataset_id_arg: Dataset ID or fully-qualified resource name (from --dataset). + dataset_jsonl_arg: Path to a local JSONL file (from --dataset-jsonl). + selected_test_file_path: Path to the selected test file (for inference). + selected_test_func_name: Name of the selected test function (for inference). 
+ Returns a tuple of: - dataset_id: existing dataset id when using --dataset or fully-qualified dataset resource - dataset_resource: fully-qualified dataset resource for existing datasets; None for JSONL sources - dataset_jsonl: local JSONL path when using --dataset-jsonl or inferred sources; None for id-only datasets """ - dataset_id = getattr(args, "dataset", None) - dataset_jsonl = getattr(args, "dataset_jsonl", None) + dataset_id = dataset_id_arg + dataset_jsonl = dataset_jsonl_arg dataset_resource_override: Optional[str] = None if dataset_id and dataset_jsonl: @@ -505,7 +369,7 @@ def _resolve_dataset( return dataset_id, dataset_resource, dataset_jsonl -def _upload_dataset( +def upload_dataset( project_root: str, account_id: str, api_key: str, @@ -514,13 +378,28 @@ def _upload_dataset( dataset_id: Optional[str], dataset_resource: Optional[str], dataset_jsonl: Optional[str], - args: argparse.Namespace, + dataset_display_name: Optional[str], dry_run: bool, ) -> tuple[Optional[str], Optional[str]]: """Create/upload the dataset when using a local JSONL source. For existing datasets (--dataset or fully-qualified ids), this is a no-op that simply ensures dataset_id and dataset_resource are populated. + + Args: + project_root: Path to the project root directory. + account_id: The Fireworks account ID. + api_key: Fireworks API key. + api_base: Fireworks API base URL. + evaluator_id: The evaluator ID (used for generating dataset ID if needed). + dataset_id: Existing dataset ID (if known). + dataset_resource: Existing dataset resource name (if known). + dataset_jsonl: Path to local JSONL file to upload (if any). + dataset_display_name: Display name for the dataset (optional). + dry_run: If True, simulate the upload without actually creating. + + Returns: + A tuple of (dataset_id, dataset_resource). """ # Existing dataset case: nothing to upload if not dataset_jsonl: @@ -532,7 +411,7 @@ def _upload_dataset( # JSONL-based dataset: upload or simulate upload inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id) - dataset_display_name = getattr(args, "dataset_display_name", None) or inferred_dataset_id + display_name = dataset_display_name or inferred_dataset_id # Resolve dataset_jsonl path relative to CWD if needed jsonl_path_for_upload = ( @@ -551,7 +430,7 @@ def _upload_dataset( api_key=api_key, api_base=api_base, dataset_id=inferred_dataset_id, - display_name=dataset_display_name, + display_name=display_name, jsonl_path=jsonl_path_for_upload, ) print(f"āœ“ Created and uploaded dataset: {dataset_id}") @@ -562,103 +441,6 @@ def _upload_dataset( return None, None -def _upload_and_ensure_evaluator( - project_root: str, - evaluator_id: str, - evaluator_resource_name: str, - api_key: str, - api_base: str, - force: bool, -) -> bool: - """Ensure the evaluator exists and is ACTIVE, uploading it if needed.""" - # Optional short-circuit: if evaluator already exists and not forcing, skip upload path - if not force: - try: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) - if resp.ok: - state = resp.json().get("state", "STATE_UNSPECIFIED") - print(f"āœ“ Evaluator exists (state: {state}). 
Skipping upload (use --force to overwrite).") - # Poll for ACTIVE before proceeding - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - if not _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ): - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\nāŒ Evaluator is not ready within the timeout period.") - print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - except requests.exceptions.RequestException: - pass - - # Ensure evaluator exists by invoking the upload flow programmatically - try: - tests = _discover_tests(project_root) - selected_entry: Optional[str] = None - st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests) - if st_path and st_func: - selected_entry = _build_entry_point(project_root, st_path, st_func) - # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators - if selected_entry is None and len(tests) > 1: - print( - f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n" - " Please re-run specifying the evaluator.\n" - " Hints:\n" - " - eval-protocol create rft --evaluator \n" - ) - return False - - upload_args = argparse.Namespace( - path=project_root, - entry=selected_entry, - id=evaluator_id, - display_name=None, - description=None, - force=force, # Pass through the --force flag - yes=True, - env_file=None, # Add the new env_file parameter - ) - - if force: - print(f"šŸ”„ Force flag enabled - will overwrite existing evaluator '{evaluator_id}'") - - rc = upload_command(upload_args) - if rc == 0: - print(f"āœ“ Uploaded/ensured evaluator: {evaluator_id}") - - # Poll for evaluator status - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - is_active = _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ) - - if not is_active: - dashboard_url = _build_evaluator_dashboard_url(evaluator_id) - print("\nāŒ Evaluator is not ready within the timeout period.") - print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return False - return True - else: - print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") - return False - except Exception as e: - print(f"Warning: Failed to upload evaluator automatically: {e}") - return False - - def _create_rft_job( account_id: str, api_key: str, @@ -672,7 +454,7 @@ def _create_rft_job( ) -> int: """Build and submit the RFT job request (via Fireworks SDK).""" - signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create) + signature = inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create) # Build top-level SDK kwargs sdk_kwargs: Dict[str, Any] = { @@ -711,7 +493,7 @@ def _create_rft_job( return 0 try: - fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base) + fw: Fireworks = create_fireworks_client(api_key=api_key, base_url=api_base) job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs) job_name = job.name print(f"\nāœ… Created Reinforcement Fine-tuning Job: {job_name}") @@ -739,7 +521,6 @@ 
def create_rft_command(args) -> int: evaluator_arg: Optional[str] = getattr(args, "evaluator", None) non_interactive: bool = bool(getattr(args, "yes", False)) dry_run: bool = bool(getattr(args, "dry_run", False)) - force: bool = bool(getattr(args, "force", False)) skip_validation: bool = bool(getattr(args, "skip_validation", False)) ignore_docker: bool = bool(getattr(args, "ignore_docker", False)) docker_build_extra: str = getattr(args, "docker_build_extra", "") or "" @@ -751,15 +532,24 @@ def create_rft_command(args) -> int: evaluator_resource_name, selected_test_file_path, selected_test_func_name, - ) = _resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id) + ) = resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id, command_name="create rft") if not evaluator_id or not evaluator_resource_name: return 1 + # 1.5) Handle secrets upload (with double verification for existing secrets) + env_file = getattr(args, "env_file", None) + handle_secrets_upload( + project_root=project_root, + env_file=env_file, + non_interactive=non_interactive, + ) + # 2) Resolve dataset source (id or JSONL path) - dataset_id, dataset_resource, dataset_jsonl = _resolve_dataset( + dataset_id, dataset_resource, dataset_jsonl = resolve_dataset( project_root=project_root, account_id=account_id, - args=args, + dataset_id_arg=getattr(args, "dataset", None), + dataset_jsonl_arg=getattr(args, "dataset_jsonl", None), selected_test_file_path=selected_test_file_path, selected_test_func_name=selected_test_func_name, ) @@ -784,7 +574,7 @@ def create_rft_command(args) -> int: return 1 # Evaluator validation (run pytest for the selected test, possibly via Docker) - if not _validate_evaluator_locally( + if not validate_evaluator_locally( project_root=project_root, selected_test_file=selected_test_file_path, selected_test_func=selected_test_func_name, @@ -795,7 +585,7 @@ def create_rft_command(args) -> int: return 1 # 4) Upload dataset when using JSONL sources (no-op for existing datasets) - dataset_id, dataset_resource = _upload_dataset( + dataset_id, dataset_resource = upload_dataset( project_root=project_root, account_id=account_id, api_key=api_key, @@ -804,20 +594,20 @@ def create_rft_command(args) -> int: dataset_id=dataset_id, dataset_resource=dataset_resource, dataset_jsonl=dataset_jsonl, - args=args, + dataset_display_name=getattr(args, "dataset_display_name", None), dry_run=dry_run, ) if not dataset_id or not dataset_resource: return 1 - # 5) Ensure evaluator exists and is ACTIVE (upload + poll if needed) - if not _upload_and_ensure_evaluator( + # 5) Ensure evaluator exists and its latest version is ACTIVE (upload + poll if needed) + if not upload_and_ensure_evaluator( project_root=project_root, evaluator_id=evaluator_id, - evaluator_resource_name=evaluator_resource_name, api_key=api_key, api_base=api_base, - force=force, + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, ): return 1 diff --git a/eval_protocol/cli_commands/secrets.py b/eval_protocol/cli_commands/secrets.py new file mode 100644 index 00000000..0d1789a3 --- /dev/null +++ b/eval_protocol/cli_commands/secrets.py @@ -0,0 +1,390 @@ +""" +Secret handling module for ep CLI commands. + +This module provides reusable functions for loading, selecting, and uploading +secrets to Fireworks. Used by 'ep upload', 'ep create rft', and 'ep create evj'. 
+""" + +import os +import sys +from typing import Dict + +from dotenv import dotenv_values + +from eval_protocol.auth import get_fireworks_api_key +from eval_protocol.platform_api import create_or_update_fireworks_secret, get_fireworks_secret +from .utils import _ensure_account_id, _get_questionary_style + + +def load_secrets_from_env_file(env_file_path: str) -> Dict[str, str]: + """ + Load secrets from a .env file that should be uploaded to Fireworks. + + Uses python-dotenv's dotenv_values() for proper parsing of .env files, + which correctly handles: + - End-of-line comments (e.g., KEY=value # comment) + - Quoted values (single and double quotes) + - Escape sequences + - Multi-line values + """ + if not os.path.exists(env_file_path): + return {} + + # Use dotenv_values for proper .env parsing (handles comments, quotes, etc.) + parsed = dotenv_values(env_file_path) + # Filter out None values and convert to Dict[str, str] + return {k: v for k, v in parsed.items() if v is not None} + + +def mask_secret_value(value: str | None) -> str: + """ + Return a masked representation of a secret showing only a small prefix/suffix. + Example: fw_3Z*******Xgnk + """ + try: + if not isinstance(value, str) or not value: + return "" + prefix_len = 6 + suffix_len = 4 + if len(value) <= prefix_len + suffix_len: + return value[0] + "***" + value[-1] + return f"{value[:prefix_len]}***{value[-suffix_len:]}" + except Exception: + return "" + + +def check_existing_secrets( + secrets: Dict[str, str], + account_id: str, +) -> set[str]: + """ + Check which secrets already exist on Fireworks. + Returns a set of secret names that already exist. + """ + existing = set() + for secret_name in secrets.keys(): + secret = get_fireworks_secret(account_id=account_id, key_name=secret_name) + if secret is not None: + existing.add(secret_name) + return existing + + +def prompt_select_secrets( + secrets: Dict[str, str], + secrets_from_env_file: Dict[str, str], + existing_secrets: set[str], + non_interactive: bool, +) -> Dict[str, str]: + """ + Prompt user to select which environment variables to upload as secrets. + Existing secrets are unchecked by default and marked with [exists]. + Returns the selected secrets. 
+ """ + if not secrets: + return {} + + if non_interactive: + # In non-interactive mode, only return new secrets (skip existing ones) + return {k: v for k, v in secrets.items() if k not in existing_secrets} + + # Check if running in a non-TTY environment (e.g., CI/CD) + if not sys.stdin.isatty(): + # In non-TTY, only return new secrets (skip existing ones) + return {k: v for k, v in secrets.items() if k not in existing_secrets} + + try: + import questionary + + custom_style = _get_questionary_style() + + # Build choices with source info and masked values + # Existing secrets are unchecked by default + choices = [] + for key, value in secrets.items(): + source = ".env" if key in secrets_from_env_file else "env" + masked = mask_secret_value(value) + is_existing = key in existing_secrets + exists_marker = " [exists]" if is_existing else "" + label = f"{key}{exists_marker} ({source}: {masked})" + # Uncheck existing secrets by default + choices.append(questionary.Choice(title=label, value=key, checked=not is_existing)) + + if len(choices) == 0: + return {} + + print("\nFound environment variables to upload as Fireworks secrets:") + if existing_secrets: + print(" (Secrets marked [exists] are unchecked - selecting them will override)") + selected_keys = questionary.checkbox( + "Select secrets to upload:", + choices=choices, + style=custom_style, + pointer=">", + instruction="(↑↓ move, space select, enter confirm)", + ).ask() + + if selected_keys is None: + # User cancelled with Ctrl+C + print("\nSecret upload cancelled.") + return {} + + return {k: v for k, v in secrets.items() if k in selected_keys} + + except ImportError: + # Fallback to simple text-based selection + return _prompt_select_secrets_fallback(secrets, secrets_from_env_file, existing_secrets) + except KeyboardInterrupt: + print("\n\nSecret upload cancelled.") + return {} + + +def _prompt_select_secrets_fallback( + secrets: Dict[str, str], + secrets_from_env_file: Dict[str, str], + existing_secrets: set[str], +) -> Dict[str, str]: + """Fallback prompt selection for when questionary is not available.""" + print("\n" + "=" * 60) + print("Found environment variables to upload as Fireworks secrets:") + print("=" * 60) + print("\nTip: Install questionary for better UX: pip install questionary\n") + + secret_list = list(secrets.items()) + new_indices = [] + for idx, (key, value) in enumerate(secret_list, 1): + source = ".env" if key in secrets_from_env_file else "env" + masked = mask_secret_value(value) + is_existing = key in existing_secrets + exists_marker = " [exists]" if is_existing else "" + print(f" [{idx}] {key}{exists_marker} ({source}: {masked})") + if not is_existing: + new_indices.append(idx) + + print("\n" + "=" * 60) + if existing_secrets: + print("Note: Secrets marked [exists] will be overridden if selected.") + default_selection = ",".join(str(i) for i in new_indices) if new_indices else "none" + print("Enter numbers (comma-separated), 'all' for all, or 'none' to skip.") + print(f"Default (new secrets only): {default_selection}") + + try: + choice = input("Selection: ").strip().lower() + except KeyboardInterrupt: + print("\nSecret upload cancelled.") + return {} + + if not choice: + # Default: only new secrets + choice = default_selection + + if choice == "none": + return {} + + if choice == "all": + return secrets + + try: + indices = [int(x.strip()) for x in choice.split(",")] + selected = {} + for idx in indices: + if 1 <= idx <= len(secret_list): + key, value = secret_list[idx - 1] + selected[key] = value + return 
selected + except ValueError: + print("Invalid input. Skipping secret upload.") + return {} + + +def confirm_override_secrets( + secrets_to_override: set[str], + non_interactive: bool, +) -> set[str]: + """ + Prompt user to confirm overriding existing secrets (double verification). + Returns the set of secrets confirmed for override. + """ + if not secrets_to_override: + return set() + + if non_interactive: + # In non-interactive mode, skip overriding existing secrets by default + print(f"\nāš ļø Skipping {len(secrets_to_override)} existing secret(s) in non-interactive mode.") + print(" Use interactive mode to confirm overriding existing secrets.") + return set() + + # Check if running in a non-TTY environment + if not sys.stdin.isatty(): + print(f"\nāš ļø Skipping {len(secrets_to_override)} existing secret(s) (non-TTY environment).") + return set() + + print(f"\nāš ļø The following {len(secrets_to_override)} secret(s) already exist on Fireworks:") + for name in sorted(secrets_to_override): + print(f" • {name}") + + try: + import questionary + + custom_style = _get_questionary_style() + + # First confirmation + confirm1 = questionary.confirm( + "Do you want to override these existing secrets?", + default=False, + style=custom_style, + ).ask() + + if confirm1 is None or not confirm1: + print("Override cancelled. Existing secrets will be skipped.") + return set() + + # Second confirmation (double verification) + confirm2 = questionary.confirm( + "āš ļø Are you SURE? This will permanently overwrite the existing secret values.", + default=False, + style=custom_style, + ).ask() + + if confirm2 is None or not confirm2: + print("Override cancelled. Existing secrets will be skipped.") + return set() + + return secrets_to_override + + except ImportError: + # Fallback to simple text-based confirmation + return _confirm_override_secrets_fallback(secrets_to_override) + except KeyboardInterrupt: + print("\n\nOverride cancelled.") + return set() + + +def _confirm_override_secrets_fallback(secrets_to_override: set[str]) -> set[str]: + """Fallback confirmation for when questionary is not available.""" + print("\n" + "=" * 60) + print("WARNING: Confirm override of existing secrets") + print("=" * 60) + + try: + # First confirmation + response1 = input("Do you want to override these existing secrets? (yes/no): ").strip().lower() + if response1 not in ("yes", "y"): + print("Override cancelled. Existing secrets will be skipped.") + return set() + + # Second confirmation (double verification) + response2 = input("āš ļø Are you SURE? Type 'override' to confirm: ").strip().lower() + if response2 != "override": + print("Override cancelled. Existing secrets will be skipped.") + return set() + + return secrets_to_override + except KeyboardInterrupt: + print("\n\nOverride cancelled.") + return set() + + +def handle_secrets_upload( + project_root: str, + env_file: str | None, + non_interactive: bool, +) -> None: + """ + Main entry point for handling secrets upload flow. + + This function: + 1. Loads secrets from .env file and environment + 2. Checks which secrets already exist on Fireworks + 3. Prompts user to select secrets (existing ones unchecked by default) + 4. Requires double confirmation for overriding existing secrets + 5. Uploads the selected/confirmed secrets + + Args: + project_root: Path to the project root directory. + env_file: Optional path to a specific .env file (overrides default). + non_interactive: If True, skip prompts and only upload new secrets. 
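# --- Illustrative aside (not part of the patch) ---
# The selection/override partitioning used by handle_secrets_upload() below,
# reduced to plain dict/set operations (names and values are placeholders):
selected = {"NEW_KEY": "a", "EXISTING_KEY": "b"}
existing = {"EXISTING_KEY"}

new_secrets = {k: v for k, v in selected.items() if k not in existing}
needs_override = {k: v for k, v in selected.items() if k in existing}

# Only overrides that survive the double confirmation are uploaded.
confirmed = set()  # e.g. the result of confirm_override_secrets(...)
to_upload = dict(new_secrets)
for name in confirmed:
    to_upload[name] = needs_override[name]

assert to_upload == {"NEW_KEY": "a"}
# --- End aside ---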
+ """ + try: + fw_account_id = _ensure_account_id() + + # Determine .env file path + if env_file: + env_file_path = env_file + else: + env_file_path = os.path.join(project_root, ".env") + + # Load secrets from .env file + secrets_from_file = load_secrets_from_env_file(env_file_path) + secrets_from_env_file = secrets_from_file.copy() # Track what came from .env file + + # Also consider FIREWORKS_API_KEY from environment, but prefer .env value + fw_api_key_value = get_fireworks_api_key() + if fw_api_key_value and "FIREWORKS_API_KEY" not in secrets_from_file: + secrets_from_file["FIREWORKS_API_KEY"] = fw_api_key_value + + if fw_account_id and secrets_from_file: + if secrets_from_env_file and os.path.exists(env_file_path): + print(f"Loading secrets from: {env_file_path}") + + # Check which secrets already exist on Fireworks BEFORE prompting + print("Checking for existing secrets on Fireworks...") + existing_secrets = check_existing_secrets(secrets_from_file, fw_account_id) + + # Prompt user to select which secrets to upload + # Existing secrets are unchecked by default + selected_secrets = prompt_select_secrets( + secrets_from_file, + secrets_from_env_file, + existing_secrets, + non_interactive, + ) + + if selected_secrets: + # Separate new secrets from existing ones that user explicitly selected + new_secrets = {k: v for k, v in selected_secrets.items() if k not in existing_secrets} + secrets_needing_override = {k: v for k, v in selected_secrets.items() if k in existing_secrets} + + # Confirm override for existing secrets (double verification) + confirmed_overrides: set[str] = set() + if secrets_needing_override: + confirmed_overrides = confirm_override_secrets( + set(secrets_needing_override.keys()), + non_interactive, + ) + + # Build final list of secrets to upload + secrets_to_upload = dict(new_secrets) + for name in confirmed_overrides: + secrets_to_upload[name] = secrets_needing_override[name] + + if not secrets_to_upload: + print("No secrets to upload (existing secrets were not confirmed for override).") + else: + print(f"\nUploading {len(secrets_to_upload)} secret(s) to Fireworks...") + for secret_name, secret_value in secrets_to_upload.items(): + source = ".env" if secret_name in secrets_from_env_file else "environment" + action = "Overriding" if secret_name in confirmed_overrides else "Creating" + print(f"{action} {secret_name} on Fireworks... ({source}: {mask_secret_value(secret_value)})") + if create_or_update_fireworks_secret( + account_id=fw_account_id, + key_name=secret_name, + secret_value=secret_value, + ): + print( + f"āœ“ {secret_name} secret {'updated' if secret_name in confirmed_overrides else 'created'} on Fireworks." + ) + else: + print( + f"Warning: Failed to {'update' if secret_name in confirmed_overrides else 'create'} {secret_name} secret on Fireworks." + ) + else: + print("No secrets selected for upload.") + else: + if not fw_account_id: + print( + "Warning: Could not resolve Fireworks account id from FIREWORKS_API_KEY; cannot register secrets." 
+ ) + if not secrets_from_file: + print("Warning: No API keys found in environment or .env file; no secrets to register.") + except Exception as e: + print(f"Warning: Skipped Fireworks secret registration due to error: {e}") diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index a8a132d6..8fe8423e 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -1,83 +1,21 @@ import argparse -from eval_protocol.cli_commands.utils import DiscoveredTest import os import re import sys from pathlib import Path -from typing import Any, Dict - -from eval_protocol.auth import get_fireworks_api_key -from eval_protocol.platform_api import create_or_update_fireworks_secret +from eval_protocol.cli_commands.utils import DiscoveredTest from eval_protocol.evaluation import create_evaluation +from .secrets import handle_secrets_upload from .utils import ( _build_entry_point, _build_evaluator_dashboard_url, _discover_and_select_tests, - _discover_tests, - _ensure_account_id, - _get_questionary_style, load_module_from_file_path, _normalize_evaluator_id, - _prompt_select, ) -def _to_pyargs_nodeid(file_path: str, func_name: str) -> str | None: - """Attempt to build a pytest nodeid suitable for `pytest `. - - Preference order: - 1) Dotted package module path with double-colon: pkg.subpkg.module::func - 2) Filesystem path with double-colon: path/to/module.py::func - - Returns dotted form when package root can be inferred (directory chain with __init__.py - leading up to a directory contained in sys.path). Returns None if no reasonable - nodeid can be created (should be rare). - """ - try: - abs_path = os.path.abspath(file_path) - dir_path, filename = os.path.split(abs_path) - module_base, ext = os.path.splitext(filename) - if ext != ".py": - # Not a python file - return None - - # Walk up while packages have __init__.py - segments: list[str] = [module_base] - current = dir_path - package_root = None - while True: - if os.path.isfile(os.path.join(current, "__init__.py")): - segments.insert(0, os.path.basename(current)) - parent = os.path.dirname(current) - # Stop if parent is not within current sys.path import roots - if parent == current: - break - current = parent - else: - package_root = current - break - - # If we found a package chain, check that the package_root is importable (in sys.path) - if package_root and any( - os.path.abspath(sp).rstrip(os.sep) == os.path.abspath(package_root).rstrip(os.sep) for sp in sys.path - ): - dotted = ".".join(segments) - return f"{dotted}::{func_name}" - - # Do not emit a dotted top-level module for non-packages; prefer path-based nodeid - - # Fallback to relative path (if under cwd) or absolute path - cwd = os.getcwd() - try: - rel = os.path.relpath(abs_path, cwd) - except Exception: - rel = abs_path - return f"{rel}::{func_name}" - except Exception: - return None - - def _parse_entry(entry: str, cwd: str) -> tuple[str, str]: # Accept module::function, path::function, or legacy module:function entry = entry.strip() @@ -130,146 +68,6 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]: return qualname, os.path.abspath(source_file_path) if source_file_path else "" -def _load_secrets_from_env_file(env_file_path: str) -> Dict[str, str]: - """ - Load secrets from a .env file that should be uploaded to Fireworks. 
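# --- Illustrative aside (not part of the patch) ---
# Why the manual parser deleted below was replaced by dotenv_values(): the
# naive split keeps end-of-line comments as part of the value; python-dotenv
# strips inline comments from unquoted values, as the new module's docstring
# notes. A self-contained comparison on a throwaway file:
import os
import tempfile

from dotenv import dotenv_values

with tempfile.NamedTemporaryFile("w", suffix=".env", delete=False) as f:
    f.write("API_KEY=abc123 # production key\n")
    path = f.name

naive = {}
with open(path) as fh:
    for line in fh:
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, value = line.split("=", 1)
            naive[key.strip()] = value.strip().strip('"').strip("'")

assert naive["API_KEY"] == "abc123 # production key"  # comment leaks in
assert dotenv_values(path)["API_KEY"] == "abc123"     # comment stripped
os.unlink(path)
# --- End aside ---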
- """ - if not os.path.exists(env_file_path): - return {} - - # Load the .env file into a temporary environment - env_vars = {} - with open(env_file_path, "r") as f: - for line in f: - line = line.strip() - if line and not line.startswith("#") and "=" in line: - key, value = line.split("=", 1) - key = key.strip() - value = value.strip().strip('"').strip("'") # Remove quotes - env_vars[key] = value - return env_vars - - -def _mask_secret_value(value: str) -> str: - """ - Return a masked representation of a secret showing only a small prefix/suffix. - Example: fw_3Z*******Xgnk - """ - try: - if not isinstance(value, str) or not value: - return "" - prefix_len = 6 - suffix_len = 4 - if len(value) <= prefix_len + suffix_len: - return value[0] + "***" + value[-1] - return f"{value[:prefix_len]}***{value[-suffix_len:]}" - except Exception: - return "" - - -def _prompt_select_secrets( - secrets: Dict[str, str], - secrets_from_env_file: Dict[str, str], - non_interactive: bool, -) -> Dict[str, str]: - """ - Prompt user to select which environment variables to upload as secrets. - Returns the selected secrets. - """ - if not secrets: - return {} - - if non_interactive: - return secrets - - # Check if running in a non-TTY environment (e.g., CI/CD) - if not sys.stdin.isatty(): - return secrets - - try: - import questionary - - custom_style = _get_questionary_style() - - # Build choices with source info and masked values - choices = [] - for key, value in secrets.items(): - source = ".env" if key in secrets_from_env_file else "env" - masked = _mask_secret_value(value) - label = f"{key} ({source}: {masked})" - choices.append(questionary.Choice(title=label, value=key, checked=True)) - - if len(choices) == 0: - return {} - - print("\nFound environment variables to upload as Fireworks secrets:") - selected_keys = questionary.checkbox( - "Select secrets to upload:", - choices=choices, - style=custom_style, - pointer=">", - instruction="(↑↓ move, space select, enter confirm)", - ).ask() - - if selected_keys is None: - # User cancelled with Ctrl+C - print("\nSecret upload cancelled.") - return {} - - return {k: v for k, v in secrets.items() if k in selected_keys} - - except ImportError: - # Fallback to simple text-based selection - return _prompt_select_secrets_fallback(secrets, secrets_from_env_file) - except KeyboardInterrupt: - print("\n\nSecret upload cancelled.") - return {} - - -def _prompt_select_secrets_fallback( - secrets: Dict[str, str], - secrets_from_env_file: Dict[str, str], -) -> Dict[str, str]: - """Fallback prompt selection for when questionary is not available.""" - print("\n" + "=" * 60) - print("Found environment variables to upload as Fireworks secrets:") - print("=" * 60) - print("\nTip: Install questionary for better UX: pip install questionary\n") - - secret_list = list(secrets.items()) - for idx, (key, value) in enumerate(secret_list, 1): - source = ".env" if key in secrets_from_env_file else "env" - masked = _mask_secret_value(value) - print(f" [{idx}] {key} ({source}: {masked})") - - print("\n" + "=" * 60) - print("Enter numbers to select (comma-separated), 'all' for all, or 'none' to skip:") - - try: - choice = input("Selection: ").strip().lower() - except KeyboardInterrupt: - print("\nSecret upload cancelled.") - return {} - - if not choice or choice == "none": - return {} - - if choice == "all": - return secrets - - try: - indices = [int(x.strip()) for x in choice.split(",")] - selected = {} - for idx in indices: - if 1 <= idx <= len(secret_list): - key, value = secret_list[idx - 
1] - selected[key] = value - return selected - except ValueError: - print("Invalid input. Skipping secret upload.") - return {} - - def upload_command(args: argparse.Namespace) -> int: root = os.path.abspath(getattr(args, "path", ".")) entries_arg = getattr(args, "entry", None) @@ -289,66 +87,14 @@ def upload_command(args: argparse.Namespace) -> int: base_id = getattr(args, "id", None) display_name = getattr(args, "display_name", None) description = getattr(args, "description", None) - force = bool(getattr(args, "force", False)) env_file = getattr(args, "env_file", None) - # Load secrets from .env file and ensure they're available on Fireworks - try: - fw_account_id = _ensure_account_id() - - # Determine .env file path - if env_file: - env_file_path = env_file - else: - env_file_path = os.path.join(root, ".env") - - # Load secrets from .env file - secrets_from_file = _load_secrets_from_env_file(env_file_path) - secrets_from_env_file = secrets_from_file.copy() # Track what came from .env file - - # Also consider FIREWORKS_API_KEY from environment, but prefer .env value - fw_api_key_value = get_fireworks_api_key() - if fw_api_key_value and "FIREWORKS_API_KEY" not in secrets_from_file: - secrets_from_file["FIREWORKS_API_KEY"] = fw_api_key_value - - if fw_account_id and secrets_from_file: - if secrets_from_env_file and os.path.exists(env_file_path): - print(f"Loading secrets from: {env_file_path}") - - # Prompt user to select which secrets to upload - selected_secrets = _prompt_select_secrets( - secrets_from_file, - secrets_from_env_file, - non_interactive, - ) - - if selected_secrets: - print(f"\nUploading {len(selected_secrets)} selected secret(s) to Fireworks...") - for secret_name, secret_value in selected_secrets.items(): - source = ".env" if secret_name in secrets_from_env_file else "environment" - print( - f"Ensuring {secret_name} is registered as a secret on Fireworks for rollout... " - f"({source}: {_mask_secret_value(secret_value)})" - ) - if create_or_update_fireworks_secret( - account_id=fw_account_id, - key_name=secret_name, - secret_value=secret_value, - ): - print(f"āœ“ {secret_name} secret created/updated on Fireworks.") - else: - print(f"Warning: Failed to create/update {secret_name} secret on Fireworks.") - else: - print("No secrets selected for upload.") - else: - if not fw_account_id: - print( - "Warning: Could not resolve Fireworks account id from FIREWORKS_API_KEY; cannot register secrets." 
- ) - if not secrets_from_file: - print("Warning: No API keys found in environment or .env file; no secrets to register.") - except Exception as e: - print(f"Warning: Skipped Fireworks secret registration due to error: {e}") + # Handle secrets upload (with double verification for existing secrets) + handle_secrets_upload( + project_root=root, + env_file=env_file, + non_interactive=non_interactive, + ) exit_code = 0 for i, (qualname, source_file_path) in enumerate(selected_specs): @@ -378,17 +124,18 @@ def upload_command(args: argparse.Namespace) -> int: print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...") try: - result = create_evaluation( + result, version_id = create_evaluation( evaluator_id=evaluator_id, display_name=display_name or evaluator_id, description=description or f"Evaluator for {qualname}", - force=force, entry_point=entry_point, ) name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id # Print success message with Fireworks dashboard link print(f"\nāœ… Successfully uploaded evaluator: {evaluator_id}") + if version_id: + print(f" Version: {version_id}") print("šŸ“Š View in Fireworks Dashboard:") dashboard_url = _build_evaluator_dashboard_url(evaluator_id) print(f" {dashboard_url}\n") diff --git a/eval_protocol/cli_commands/utils.py b/eval_protocol/cli_commands/utils.py index 1338ae31..ab733832 100644 --- a/eval_protocol/cli_commands/utils.py +++ b/eval_protocol/cli_commands/utils.py @@ -23,6 +23,7 @@ get_fireworks_api_key, verify_api_key_and_get_account_id, ) +from ..fireworks_client import create_fireworks_client from ..fireworks_rft import _map_api_host_to_app_host @@ -752,3 +753,253 @@ def add_args_from_callable_signature( help_text = help_overrides.get(name, help.get(name)) _add_flag(parser, flags, hints.get(name), help_text) + + +def validate_evaluator_locally( + project_root: str, + selected_test_file: Optional[str], + selected_test_func: Optional[str], + ignore_docker: bool, + docker_build_extra: str, + docker_run_extra: str, +) -> bool: + """Run pytest locally for the selected evaluation test to validate the evaluator. + + The pytest helpers always enforce a small success threshold (0.01) for + evaluation_test-based suites so that an evaluation run where all scores are + 0.0 will naturally fail with a non-zero pytest exit code, which we then treat + as a failed validator. + """ + # Lazy import to avoid circular dependency (local_test imports from utils) + from .local_test import run_evaluator_test + + if not selected_test_file or not selected_test_func: + # No local test associated; skip validation but warn the user. + print("Warning: Could not resolve a local evaluation test for this evaluator; skipping local validation.") + return True + + pytest_target = _build_entry_point(project_root, selected_test_file, selected_test_func) + exit_code = run_evaluator_test( + project_root=project_root, + pytest_target=pytest_target, + ignore_docker=ignore_docker, + docker_build_extra=docker_build_extra, + docker_run_extra=docker_run_extra, + ) + return exit_code == 0 + + +def resolve_evaluator( + project_root: str, + evaluator_arg: Optional[str], + non_interactive: bool, + account_id: str, + command_name: str = "create", +) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Resolve evaluator id/resource and associated local test (file + func). + + Args: + project_root: Path to the project root directory. + evaluator_arg: The evaluator argument provided by the user (id or fully-qualified resource). 
+ non_interactive: Whether to run in non-interactive mode. + account_id: The Fireworks account ID. + command_name: The CLI command name for error messages (e.g., 'create rft', 'create evj'). + + Returns: + A tuple of (evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name). + Returns (None, None, None, None) if resolution fails. + """ + evaluator_id = evaluator_arg + selected_test_file_path: Optional[str] = None + selected_test_func_name: Optional[str] = None + + if not evaluator_id: + selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive) + if not selected_tests: + return None, None, None, None + + if len(selected_tests) != 1: + if non_interactive and len(selected_tests) > 1: + print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") + print(" Please pass --evaluator or --entry to disambiguate.") + else: + print(f"Error: Please select exactly one evaluation test for '{command_name}'.") + return None, None, None, None + + chosen = selected_tests[0] + func_name = chosen.qualname.split(".")[-1] + source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0] + evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}") + # Resolve selected test once for downstream + selected_test_file_path, selected_test_func_name = _resolve_selected_test( + project_root, evaluator_id, selected_tests=selected_tests + ) + else: + # Caller provided an evaluator id or fully-qualified resource; try to resolve local test + short_id = evaluator_id + if evaluator_id.startswith("accounts/"): + short_id = _extract_terminal_segment(evaluator_id) + st_path, st_func = _resolve_selected_test(project_root, short_id) + if st_path and st_func: + selected_test_file_path = st_path + selected_test_func_name = st_func + evaluator_id = short_id + + if not evaluator_id: + return None, None, None, None + + # Resolve evaluator resource name to fully-qualified format required by API. + if evaluator_arg and evaluator_arg.startswith("accounts/"): + evaluator_resource_name = evaluator_arg + else: + evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" + + return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name + + +def _poll_evaluator_version_status( + evaluator_id: str, + version_id: str, + api_key: str, + api_base: str, + timeout_minutes: int = 10, +) -> bool: + """ + Poll a specific evaluator version status until it becomes ACTIVE or times out. + + Uses the Fireworks SDK to get the specified version of the evaluator and checks + its build state. + + Args: + evaluator_id: The evaluator ID (not full resource name) + version_id: The specific version ID to poll + api_key: Fireworks API key + api_base: Fireworks API base URL + timeout_minutes: Maximum time to wait in minutes + + Returns: + True if evaluator version becomes ACTIVE, False if timeout or BUILD_FAILED + """ + timeout_seconds = timeout_minutes * 60 + poll_interval = 10 # seconds + start_time = time.time() + + print( + f"Polling evaluator version '{version_id}' status (timeout: {timeout_minutes}m, interval: {poll_interval}s)..." 
+ ) + + client = create_fireworks_client(api_key=api_key, base_url=api_base) + + while time.time() - start_time < timeout_seconds: + try: + version = client.evaluator_versions.get(version_id, evaluator_id=evaluator_id) + state = version.state or "STATE_UNSPECIFIED" + status_msg = "" + if version.status and version.status.message: + status_msg = version.status.message + + if state == "ACTIVE": + print("āœ… Evaluator version is ACTIVE and ready!") + return True + elif state == "BUILD_FAILED": + print(f"āŒ Evaluator version build failed. Status: {status_msg}") + return False + elif state == "BUILDING": + elapsed_minutes = (time.time() - start_time) / 60 + print(f"ā³ Evaluator version is still building... ({elapsed_minutes:.1f}m elapsed)") + else: + print(f"ā³ Evaluator version state: {state}, status: {status_msg}") + + except Exception as e: + print(f"Warning: Failed to check evaluator version status: {e}") + + # Wait before next poll + time.sleep(poll_interval) + + # Timeout reached + elapsed_minutes = (time.time() - start_time) / 60 + print(f"ā° Timeout after {elapsed_minutes:.1f}m - evaluator version is not yet ACTIVE") + return False + + +def upload_and_ensure_evaluator( + project_root: str, + evaluator_id: str, + api_key: str, + api_base: str, + selected_test_file_path: Optional[str] = None, + selected_test_func_name: Optional[str] = None, +) -> bool: + """Upload evaluator and ensure its version becomes ACTIVE. + + Creates/updates the evaluator and uploads the code, then polls the specific + version until it becomes ACTIVE. This is the shared implementation used by + both 'ep upload', 'ep create rft', and 'ep create evj' commands. + + Args: + project_root: Path to the project root directory. + evaluator_id: The evaluator ID. + api_key: Fireworks API key. + api_base: Fireworks API base URL. + selected_test_file_path: Optional path to the selected test file. + selected_test_func_name: Optional name of the selected test function. + + Returns: + True if evaluator was uploaded and became ACTIVE, False otherwise. 
+ """ + from eval_protocol.evaluation import create_evaluation + + try: + tests = _discover_tests(project_root) + selected_entry: Optional[str] = None + + # Use provided test info if available, otherwise try to resolve + if selected_test_file_path and selected_test_func_name: + selected_entry = _build_entry_point(project_root, selected_test_file_path, selected_test_func_name) + else: + st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests) + if st_path and st_func: + selected_entry = _build_entry_point(project_root, st_path, st_func) + + # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators + if selected_entry is None and len(tests) > 1: + print( + f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n" + " Please re-run specifying the evaluator.\n" + ) + return False + + print(f"\nUploading evaluator '{evaluator_id}'...") + result, version_id = create_evaluation( + evaluator_id=evaluator_id, + display_name=evaluator_id, + description=f"Evaluator for {evaluator_id}", + entry_point=selected_entry, + ) + + if not version_id: + print("Warning: Evaluator created but version upload failed.") + return False + + print(f"āœ“ Uploaded evaluator: {evaluator_id} (version: {version_id})") + + # Poll for the specific evaluator version status + print(f"Waiting for evaluator '{evaluator_id}' version '{version_id}' to become ACTIVE...") + is_active = _poll_evaluator_version_status( + evaluator_id=evaluator_id, + version_id=version_id, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ) + + if not is_active: + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\nāŒ Evaluator version is not ready within the timeout period.") + print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run the command again.") + return False + return True + except Exception as e: + print(f"Warning: Failed to upload evaluator automatically: {e}") + return False diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index f6a81e1e..2b9885c7 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -42,9 +42,10 @@ class EvaluationRow(BaseModel): # type: ignore self._EvaluationRow = EvaluationRow - self._db.connect() + # Wrap connect() in retry logic since setting pragmas can fail with "database is locked" + execute_with_sqlite_retry(lambda: self._db.connect(reuse_if_open=True)) # Use safe=True to avoid errors when tables/indexes already exist - self._db.create_tables([EvaluationRow], safe=True) + execute_with_sqlite_retry(lambda: self._db.create_tables([EvaluationRow], safe=True)) @property def db_path(self) -> str: diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 128038bf..31298992 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -4,14 +4,15 @@ from typing import List, Optional import fireworks +from fireworks.types import EvaluatorVersionParam import requests -from fireworks import Fireworks from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_key, verify_api_key_and_get_account_id, ) +from eval_protocol.fireworks_client import create_fireworks_client from eval_protocol.get_pep440_version import get_pep440_version logger = logging.getLogger(__name__) 
@@ -153,7 +154,7 @@ def _create_tar_gz_with_ignores(output_path: str, source_dir: str) -> int: logger.info(f"Created {output_path} ({size_bytes:,} bytes)") return size_bytes - def create(self, evaluator_id, display_name=None, description=None, force=False): + def create(self, evaluator_id, display_name=None, description=None): auth_token = self.api_key or get_fireworks_api_key() account_id = self.account_id or get_fireworks_account_id() if not account_id and auth_token: @@ -163,7 +164,11 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) logger.error("Authentication error: API credentials appear to be invalid or incomplete.") raise ValueError("Invalid or missing API credentials.") - client = Fireworks(api_key=auth_token, base_url=self.api_base, account_id=account_id) + client = create_fireworks_client( + api_key=auth_token, + base_url=self.api_base, + account_id=account_id, + ) self.display_name = display_name or evaluator_id self.description = description or f"Evaluator created from {evaluator_id}" @@ -197,28 +202,20 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) logger.info(f"Creating evaluator '{evaluator_id}' for account '{account_id}'...") try: - if force: - try: - logger.info("Checking if evaluator exists") - existing_evaluator = client.evaluators.get(evaluator_id=evaluator_id) - if existing_evaluator: - logger.info(f"Evaluator '{evaluator_id}' already exists, deleting and recreating...") - try: - client.evaluators.delete(evaluator_id=evaluator_id) - logger.info(f"Successfully deleted evaluator '{evaluator_id}'") - except fireworks.NotFoundError: - logger.info(f"Evaluator '{evaluator_id}' not found, creating...") - except fireworks.APIError as e: - logger.warning(f"Error deleting evaluator: {str(e)}") - except fireworks.NotFoundError: - logger.info(f"Evaluator '{evaluator_id}' does not exist, creating...") - - # Create evaluator using SDK - result = client.evaluators.create( - evaluator_id=evaluator_id, - evaluator=evaluator_params, - ) - logger.info(f"Successfully created evaluator '{evaluator_id}'") + # Try to create evaluator using SDK + try: + result = client.evaluators.create( + evaluator_id=evaluator_id, + evaluator=evaluator_params, + ) + logger.info(f"Successfully created evaluator '{evaluator_id}'") + except fireworks.APIStatusError as create_error: + if create_error.status_code == 409: + # Evaluator already exists, get the existing one and proceed to create a new version + logger.info(f"Evaluator '{evaluator_id}' already exists, creating new version...") + result = client.evaluators.get(evaluator_id=evaluator_id) + else: + raise # Upload code as tar.gz to GCS evaluator_name = result.name # e.g., "accounts/pyroworks/evaluators/test-123" @@ -229,6 +226,25 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) f"Cannot proceed with code upload. 
Response: {result}" ) + evaluator_version_param: EvaluatorVersionParam = {} + if "commit_hash" in evaluator_params: + evaluator_version_param["commit_hash"] = evaluator_params["commit_hash"] + if "entry_point" in evaluator_params: + evaluator_version_param["entry_point"] = evaluator_params["entry_point"] + if "requirements" in evaluator_params: + evaluator_version_param["requirements"] = evaluator_params["requirements"] + + evaluator_version = client.evaluator_versions.create( + evaluator_id=evaluator_id, + evaluator_version=evaluator_version_param, + ) + evaluator_version_id = evaluator_version.name.split("/")[-1] if evaluator_version.name else None + if not evaluator_version_id: + raise ValueError( + "Create evaluator version response missing 'name' field. " + f"Cannot proceed with code upload. Response: {evaluator_version}" + ) + try: # Create tar.gz of current directory cwd = os.getcwd() @@ -240,7 +256,8 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) # Call GetEvaluatorUploadEndpoint using SDK logger.info(f"Requesting upload endpoint for {tar_filename}") - upload_response = client.evaluators.get_upload_endpoint( + upload_response = client.evaluator_versions.get_upload_endpoint( + version_id=evaluator_version_id, evaluator_id=evaluator_id, filename_to_size={tar_filename: str(tar_size)}, ) @@ -321,9 +338,9 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) raise # Step 3: Validate upload using SDK - client.evaluators.validate_upload( + client.evaluator_versions.validate_upload( + version_id=evaluator_version_id, evaluator_id=evaluator_id, - body={}, ) logger.info("Upload validated successfully") @@ -334,8 +351,10 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) except Exception as upload_error: logger.warning(f"Code upload failed (evaluator created but code not uploaded): {upload_error}") # Don't fail - evaluator is created, just code upload failed + # Return None for version_id since upload failed + return result, None - return result # Return after attempting upload + return result, evaluator_version_id # Return evaluator result and version ID except fireworks.APIStatusError as e: logger.error(f"Error creating evaluator: {str(e)}") logger.error(f"Status code: {e.status_code}, Response: {e.response.text}") @@ -361,7 +380,6 @@ def create_evaluation( evaluator_id: str, display_name: Optional[str] = None, description: Optional[str] = None, - force: bool = False, account_id: Optional[str] = None, api_key: Optional[str] = None, entry_point: Optional[str] = None, @@ -373,10 +391,13 @@ def create_evaluation( evaluator_id: Unique identifier for the evaluator display_name: Display name for the evaluator description: Description for the evaluator - force: If True, delete and recreate if evaluator exists account_id: Optional Fireworks account ID api_key: Optional Fireworks API key entry_point: Optional entry point (module::function or path::function) + + Returns: + A tuple of (evaluator_result, version_id) where version_id is the ID of the + created evaluator version, or None if upload failed. 
""" evaluator = Evaluator( account_id=account_id, @@ -384,4 +405,4 @@ def create_evaluation( entry_point=entry_point, ) - return evaluator.create(evaluator_id, display_name, description, force) + return evaluator.create(evaluator_id, display_name, description) diff --git a/eval_protocol/event_bus/sqlite_event_bus_database.py b/eval_protocol/event_bus/sqlite_event_bus_database.py index 5086d6e3..122fbac9 100644 --- a/eval_protocol/event_bus/sqlite_event_bus_database.py +++ b/eval_protocol/event_bus/sqlite_event_bus_database.py @@ -11,8 +11,8 @@ # Retry configuration for database operations -SQLITE_RETRY_MAX_TRIES = 5 -SQLITE_RETRY_MAX_TIME = 30 # seconds +SQLITE_RETRY_MAX_TRIES = 10 +SQLITE_RETRY_MAX_TIME = 60 # seconds def _is_database_locked_error(e: Exception) -> bool: @@ -181,9 +181,10 @@ class Event(BaseModel): # type: ignore processed = BooleanField(default=False) # Track if event has been processed self._Event = Event - self._db.connect() + # Wrap connect() in retry logic since setting pragmas can fail with "database is locked" + execute_with_sqlite_retry(lambda: self._db.connect(reuse_if_open=True)) # Use safe=True to avoid errors when tables already exist - self._db.create_tables([Event], safe=True) + execute_with_sqlite_retry(lambda: self._db.create_tables([Event], safe=True)) def publish_event(self, event_type: str, data: Any, process_id: str) -> None: """Publish an event to the database.""" diff --git a/eval_protocol/fireworks_client.py b/eval_protocol/fireworks_client.py new file mode 100644 index 00000000..d92d8bec --- /dev/null +++ b/eval_protocol/fireworks_client.py @@ -0,0 +1,132 @@ +""" +Consolidated Fireworks client factory. + +This module provides a single point of instantiation for the Fireworks SDK client, +ensuring consistent handling of environment variables and configuration across the +eval_protocol codebase. + +Environment variables: + FIREWORKS_API_KEY: API key for authentication (required) + FIREWORKS_ACCOUNT_ID: Account ID (optional, can be derived from API key) + FIREWORKS_API_BASE: Base URL for the API (default: https://api.fireworks.ai) + FIREWORKS_EXTRA_HEADERS: JSON-encoded extra headers to include in requests + Example: '{"X-Custom-Header": "value", "X-Another": "another-value"}' +""" + +import json +import logging +import os +from typing import Mapping, Optional + +from fireworks import Fireworks + +from eval_protocol.auth import ( + get_fireworks_account_id, + get_fireworks_api_base, + get_fireworks_api_key, +) + +logger = logging.getLogger(__name__) + + +def get_fireworks_extra_headers() -> Optional[Mapping[str, str]]: + """ + Retrieves extra headers from the FIREWORKS_EXTRA_HEADERS environment variable. + + The value should be a JSON-encoded object mapping header names to values. + Example: '{"X-Custom-Header": "value"}' + + Returns: + A mapping of header names to values if set and valid, otherwise None. + """ + extra_headers_str = os.environ.get("FIREWORKS_EXTRA_HEADERS") + if not extra_headers_str or not extra_headers_str.strip(): + return None + + try: + headers = json.loads(extra_headers_str) + if not isinstance(headers, dict): + logger.warning( + "FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s. Ignoring.", + type(headers).__name__, + ) + return None + # Validate all keys and values are strings + for k, v in headers.items(): + if not isinstance(k, str) or not isinstance(v, str): + logger.warning( + "FIREWORKS_EXTRA_HEADERS contains non-string key or value: %s=%s. 
Ignoring all extra headers.", + k, + v, + ) + return None + logger.debug("Using FIREWORKS_EXTRA_HEADERS: %s", list(headers.keys())) + return headers + except json.JSONDecodeError as e: + logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s. Ignoring.", e) + return None + + +def create_fireworks_client( + *, + api_key: Optional[str] = None, + account_id: Optional[str] = None, + base_url: Optional[str] = None, + extra_headers: Optional[Mapping[str, str]] = None, +) -> Fireworks: + """ + Create a Fireworks client with consistent configuration. + + This factory function centralizes the logic for creating Fireworks clients, + ensuring that environment variables are handled consistently across the codebase. + + Resolution order for each parameter: + 1. Explicit argument passed to this function + 2. Environment variable (via auth module helpers) + 3. SDK defaults (for base_url only) + + Args: + api_key: Fireworks API key. If not provided, resolves from FIREWORKS_API_KEY. + account_id: Fireworks account ID. If not provided, resolves from FIREWORKS_ACCOUNT_ID + or derives from the API key via the verifyApiKey endpoint. + base_url: Base URL for the Fireworks API. If not provided, resolves from + FIREWORKS_API_BASE or defaults to https://api.fireworks.ai. + extra_headers: Additional headers to include in all requests. If not provided, + resolves from FIREWORKS_EXTRA_HEADERS environment variable (JSON-encoded). + + Returns: + A configured Fireworks client instance. + + Raises: + fireworks.FireworksError: If api_key is not provided and FIREWORKS_API_KEY + environment variable is not set. + """ + # Resolve parameters from environment if not explicitly provided + resolved_api_key = api_key or get_fireworks_api_key() + resolved_account_id = account_id or get_fireworks_account_id() + resolved_base_url = base_url or get_fireworks_api_base() + + # Merge extra headers: env var headers first, then explicit headers override + env_extra_headers = get_fireworks_extra_headers() + merged_headers: Optional[Mapping[str, str]] = None + if env_extra_headers or extra_headers: + merged = {} + if env_extra_headers: + merged.update(env_extra_headers) + if extra_headers: + merged.update(extra_headers) + merged_headers = merged if merged else None + + logger.debug( + "Creating Fireworks client: base_url=%s, account_id=%s, extra_headers=%s", + resolved_base_url, + resolved_account_id, + list(merged_headers.keys()) if merged_headers else None, + ) + + return Fireworks( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_base_url, + default_headers=merged_headers, + ) diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 148aefde..600296ba 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -1,5 +1,4 @@ import importlib.util -import io import json import os import sys @@ -9,12 +8,8 @@ import hashlib from pathlib import Path from typing import Any, Callable, Dict, Iterable, Optional, Tuple -from urllib.parse import urlencode -import requests - -from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key -from .common_utils import get_user_agent +from .fireworks_client import create_fireworks_client def _map_api_host_to_app_host(api_base: str) -> str: @@ -142,43 +137,80 @@ def create_dataset_from_jsonl( display_name: Optional[str], jsonl_path: str, ) -> Tuple[str, Dict[str, Any]]: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": 
get_user_agent(), - } + """Create a dataset and upload a JSONL file using the Fireworks SDK client. + + This function uses the Fireworks SDK client which properly handles authentication + including extra headers set via FIREWORKS_EXTRA_HEADERS environment variable. + + Args: + account_id: The Fireworks account ID. + api_key: Fireworks API key. + api_base: Fireworks API base URL. + dataset_id: The ID for the new dataset. + display_name: Display name for the dataset (optional). + jsonl_path: Path to the JSONL file to upload. + + Returns: + A tuple of (dataset_id, dataset_response_dict). + + Raises: + RuntimeError: If dataset creation or upload fails. + """ # Count examples quickly example_count = 0 with open(jsonl_path, "r", encoding="utf-8") as f: for _ in f: example_count += 1 - dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets" - payload = { - "dataset": { - "displayName": display_name or dataset_id, - "evalProtocol": {}, - "format": "FORMAT_UNSPECIFIED", - "exampleCount": str(example_count), - }, - "datasetId": dataset_id, - } - resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60) - if resp.status_code not in (200, 201): - raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}") - ds = resp.json() - - upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" - with open(jsonl_path, "rb") as f: - files = {"file": f} - up_headers = { - "Authorization": f"Bearer {api_key}", - "User-Agent": get_user_agent(), + # Create Fireworks client with consistent configuration + client = create_fireworks_client( + api_key=api_key, + account_id=account_id, + base_url=api_base, + ) + + try: + # Create the dataset + dataset = client.datasets.create( + account_id=account_id, + dataset_id=dataset_id, + dataset={ + "display_name": display_name or dataset_id, + "eval_protocol": {}, + "format": "FORMAT_UNSPECIFIED", + "example_count": str(example_count), + }, + timeout=60.0, + ) + except Exception as e: + raise RuntimeError(f"Dataset creation failed: {e}") from e + + try: + # Upload the JSONL file + with open(jsonl_path, "rb") as f: + client.datasets.upload( + dataset_id=dataset_id, + account_id=account_id, + file=f, + timeout=600.0, + ) + except Exception as e: + raise RuntimeError(f"Dataset upload failed: {e}") from e + + # Convert SDK response to dict for backwards compatibility + ds_dict: Dict[str, Any] = {} + if hasattr(dataset, "model_dump"): + ds_dict = dataset.model_dump() + elif hasattr(dataset, "dict"): + ds_dict = dataset.dict() + else: + # Fallback: extract known fields + ds_dict = { + "name": getattr(dataset, "name", None), + "state": getattr(dataset, "state", None), } - up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) - if up_resp.status_code not in (200, 201): - raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") - return dataset_id, ds + + return dataset_id, ds_dict def build_default_dataset_id(evaluator_id: str) -> str: diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 60743ccb..8b07f4d7 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -3,39 +3,17 @@ import sys from typing import Optional -from dotenv import find_dotenv, load_dotenv - from eval_protocol.auth import ( get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, ) +from eval_protocol.fireworks_client import create_fireworks_client from fireworks.types import Secret -from fireworks import 
Fireworks, FireworksError, NotFoundError, InternalServerError +from fireworks import FireworksError, NotFoundError, InternalServerError logger = logging.getLogger(__name__) -# --- Load .env files --- -# Attempt to load .env.dev first, then .env as a fallback. -# This happens when the module is imported. -# We use override=False (default) so that existing environment variables -# (e.g., set in the shell) are NOT overridden by .env files. -ENV_DEV_PATH = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True) -if ENV_DEV_PATH: - load_dotenv(dotenv_path=ENV_DEV_PATH, override=False) - logger.info(f"eval_protocol.platform_api: Loaded environment variables from: {ENV_DEV_PATH}") -else: - ENV_PATH = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True) - if ENV_PATH: - load_dotenv(dotenv_path=ENV_PATH, override=False) - logger.info(f"eval_protocol.platform_api: Loaded environment variables from: {ENV_PATH}") - else: - logger.info( - "eval_protocol.platform_api: No .env.dev or .env file found. " - "Relying on shell/existing environment variables." - ) -# --- End .env loading --- - class PlatformAPIError(Exception): """Custom exception for platform API errors.""" @@ -88,7 +66,11 @@ def create_or_update_fireworks_secret( resolved_api_key = api_key or get_fireworks_api_key() resolved_api_base = api_base or get_fireworks_api_base() resolved_account_id = account_id # Must be provided - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) if not all([resolved_api_key, resolved_api_base, resolved_account_id]): logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") @@ -173,7 +155,11 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) resource_id = _normalize_secret_resource_id(key_name) try: @@ -215,7 +201,11 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - client = Fireworks(api_key=resolved_api_key, account_id=resolved_account_id, base_url=resolved_api_base) + client = create_fireworks_client( + api_key=resolved_api_key, + account_id=resolved_account_id, + base_url=resolved_api_base, + ) resource_id = _normalize_secret_resource_id(key_name) try: diff --git a/pyproject.toml b/pyproject.toml index e5caa497..841f5ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "pytest-asyncio>=0.21.0", "peewee>=3.18.2", "backoff>=2.2.0", - "fireworks-ai==1.0.0a20", + "fireworks-ai==1.0.0a22", "questionary>=2.0.0", # Dependencies for vendored tau2 package "toml>=0.10.0", diff --git a/tests/cli_commands/test_secrets.py b/tests/cli_commands/test_secrets.py new file mode 100644 index 00000000..3f4d62ea --- /dev/null +++ b/tests/cli_commands/test_secrets.py @@ -0,0 +1,217 @@ +"""Tests for eval_protocol.cli_commands.secrets module.""" + +import os +from pathlib import Path + +import pytest + +from eval_protocol.cli_commands.secrets import load_secrets_from_env_file, mask_secret_value + + +class TestLoadSecretsFromEnvFile: + """Tests 
for load_secrets_from_env_file function.""" + + def test_basic_key_value(self, tmp_path: Path): + """Test basic KEY=value parsing.""" + env_file = tmp_path / ".env" + env_file.write_text("MY_KEY=myvalue123\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "myvalue123"} + + def test_end_of_line_comment(self, tmp_path: Path): + """Test that end-of-line comments are properly stripped. + + This was a bug where manual parsing included the comment in the value. + """ + env_file = tmp_path / ".env" + env_file.write_text("MY_API_KEY=test_dummy_value_abc123 # this is a comment\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_API_KEY": "test_dummy_value_abc123"} + # Ensure the comment is NOT included + assert "# this" not in result["MY_API_KEY"] + assert "comment" not in result["MY_API_KEY"] + + def test_end_of_line_comment_no_space(self, tmp_path: Path): + """Test end-of-line comment without space before #.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY=value#this is a comment\n") + + result = load_secrets_from_env_file(str(env_file)) + + # python-dotenv treats # without space as part of value + # unless the value is quoted. This is the expected behavior. + assert result == {"KEY": "value#this is a comment"} + + def test_quoted_value_with_hash(self, tmp_path: Path): + """Test that quoted values preserve # character.""" + env_file = tmp_path / ".env" + env_file.write_text('KEY="value#with#hashes"\n') + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY": "value#with#hashes"} + + def test_double_quoted_values(self, tmp_path: Path): + """Test that double-quoted values are properly unquoted.""" + env_file = tmp_path / ".env" + env_file.write_text('MY_KEY="value with spaces"\n') + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "value with spaces"} + + def test_single_quoted_values(self, tmp_path: Path): + """Test that single-quoted values are properly unquoted.""" + env_file = tmp_path / ".env" + env_file.write_text("MY_KEY='value with spaces'\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "value with spaces"} + + def test_comment_lines_ignored(self, tmp_path: Path): + """Test that full-line comments are ignored.""" + env_file = tmp_path / ".env" + env_file.write_text("# This is a comment\nMY_KEY=myvalue\n# Another comment\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "myvalue"} + + def test_empty_lines_ignored(self, tmp_path: Path): + """Test that empty lines are ignored.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY1=value1\n\n\nKEY2=value2\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY1": "value1", "KEY2": "value2"} + + def test_multiple_keys(self, tmp_path: Path): + """Test parsing multiple key-value pairs.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY1=value1\nKEY2=value2\nKEY3=value3\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY1": "value1", "KEY2": "value2", "KEY3": "value3"} + + def test_file_not_found(self, tmp_path: Path): + """Test that non-existent file returns empty dict.""" + env_file = tmp_path / "nonexistent.env" + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {} + + def test_empty_file(self, tmp_path: Path): + """Test that empty file returns empty dict.""" + env_file = tmp_path / ".env" + 
env_file.write_text("") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {} + + def test_value_with_equals_sign(self, tmp_path: Path): + """Test that values containing = are properly parsed.""" + env_file = tmp_path / ".env" + env_file.write_text("KEY=value=with=equals\n") + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"KEY": "value=with=equals"} + + def test_quoted_value_with_comment(self, tmp_path: Path): + """Test quoted value followed by a comment.""" + env_file = tmp_path / ".env" + env_file.write_text('MY_KEY="myvalue123" # this is a comment\n') + + result = load_secrets_from_env_file(str(env_file)) + + assert result == {"MY_KEY": "myvalue123"} + + def test_complex_env_file(self, tmp_path: Path): + """Test a complex .env file with various formats.""" + env_content = """# Configuration file +# Last updated: 2024-01-15 + +# API Keys +SERVICE_A_KEY=dummy_key_aaa # Service A key +SERVICE_B_KEY="dummy_key_bbb" # Service B key + +# Database settings +DB_HOST=localhost +DB_PORT=5432 +DB_NAME='mydb' + +# Feature flags +ENABLE_FEATURE=true # enable new feature + +# Empty values are skipped +""" + env_file = tmp_path / ".env" + env_file.write_text(env_content) + + result = load_secrets_from_env_file(str(env_file)) + + assert result["SERVICE_A_KEY"] == "dummy_key_aaa" + assert result["SERVICE_B_KEY"] == "dummy_key_bbb" + assert result["DB_HOST"] == "localhost" + assert result["DB_PORT"] == "5432" + assert result["DB_NAME"] == "mydb" + assert result["ENABLE_FEATURE"] == "true" + # Ensure no comments leaked into values + assert "# Service" not in result.get("SERVICE_A_KEY", "") + assert "# Service" not in result.get("SERVICE_B_KEY", "") + + def test_export_prefix_handled(self, tmp_path: Path): + """Test that 'export' prefix is handled (if supported by dotenv).""" + env_file = tmp_path / ".env" + env_file.write_text("export MY_KEY=myvalue123\n") + + result = load_secrets_from_env_file(str(env_file)) + + # python-dotenv handles 'export' prefix + assert result == {"MY_KEY": "myvalue123"} + + +class TestMaskSecretValue: + """Tests for mask_secret_value function.""" + + def test_normal_length_value(self): + """Test masking a normal length secret.""" + result = mask_secret_value("abcdefghijklmnopqrstu") + assert result == "abcdef***rstu" + assert len(result) < len("abcdefghijklmnopqrstu") + + def test_short_value(self): + """Test masking a very short secret.""" + result = mask_secret_value("abc") + assert result == "a***c" + + def test_empty_value(self): + """Test masking an empty value.""" + result = mask_secret_value("") + assert result == "" + + def test_none_value(self): + """Test masking None (edge case).""" + result = mask_secret_value(None) + assert result == "" + + def test_exact_boundary_length(self): + """Test masking value at exactly prefix+suffix length (10 chars).""" + # prefix_len=6, suffix_len=4, so <= 10 chars uses short format + result = mask_secret_value("1234567890") + assert result == "1***0" + + def test_just_over_boundary(self): + """Test masking value just over the boundary.""" + # 11 chars should use the full format + result = mask_secret_value("12345678901") + assert result == "123456***8901" diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 1f1e8395..1c0df26c 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -1,7 +1,6 @@ import json import os import argparse -import requests from types import SimpleNamespace from unittest.mock import patch from typing import Any, 
cast @@ -9,6 +8,7 @@ from eval_protocol.cli_commands import create_rft as cr from eval_protocol.cli_commands import upload as upload_mod +from eval_protocol.cli_commands import local_test as local_test_mod import eval_protocol.fireworks_rft as fr from eval_protocol.cli import parse_args import eval_protocol.cli_commands.utils as cli_utils @@ -24,7 +24,7 @@ def _write_json(path: str, data: dict) -> None: def stub_fireworks(monkeypatch) -> dict[str, Any]: """ Stub Fireworks SDK so tests stay offline and so create_rft.py can inspect a stable - create() signature (it uses inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)). + create() signature (it uses inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create)). Returns: A dict containing the last captured create() kwargs under key "kwargs". @@ -72,12 +72,15 @@ def create( return SimpleNamespace(name=f"accounts/{account_id}/reinforcementFineTuningJobs/xyz") class _FakeFW: - def __init__(self, api_key=None, base_url=None): + def __init__(self, api_key=None, base_url=None, account_id=None, default_headers=None): self.api_key = api_key self.base_url = base_url + self.account_id = account_id + self.default_headers = default_headers self.reinforcement_fine_tuning_jobs = _FakeJobs() - monkeypatch.setattr(cr, "Fireworks", _FakeFW) + # Patch create_fireworks_client to return our fake client + monkeypatch.setattr(cr, "create_fireworks_client", lambda **kwargs: _FakeFW(**kwargs)) return captured @@ -101,10 +104,10 @@ def rft_test_harness(tmp_path, monkeypatch, stub_fireworks): # Account id is derived from API key; mock the verify call to keep tests offline. monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) + monkeypatch.setattr(cli_utils, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) - monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) return project @@ -223,7 +226,7 @@ def test_create_rft_evaluator_validation_fails(rft_test_harness, monkeypatch): test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# dummy eval test", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_eval_validation", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Force local evaluator validation to fail calls = {"count": 0, "pytest_target": None} @@ -233,13 +236,12 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ calls["pytest_target"] = pytest_target return 1 # non-zero exit code => validation failure - monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test) + monkeypatch.setattr(local_test_mod, "run_evaluator_test", _fake_run_evaluator_test) args = argparse.Namespace( evaluator=None, yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -283,7 +285,7 @@ def test_create_rft_evaluator_validation_passes(rft_test_harness, 
monkeypatch): test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# dummy ok eval test", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_eval_ok", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Force local evaluator validation to succeed calls = {"count": 0, "pytest_target": None} @@ -293,13 +295,12 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ calls["pytest_target"] = pytest_target return 0 # success - monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test) + monkeypatch.setattr(local_test_mod, "run_evaluator_test", _fake_run_evaluator_test) args = argparse.Namespace( evaluator=None, yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -351,7 +352,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator="my-evaluator", yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -401,7 +401,6 @@ def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_ evaluator="my-evaluator", yes=True, dry_run=True, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -444,10 +443,10 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_ one_file.write_text("# single", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_single", file_path=str(one_file)) # New flow uses _discover_and_select_tests; patch it to return our single test. - monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) captured = {"dataset_id": None} @@ -462,7 +461,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d setattr(args, "evaluator", None) setattr(args, "yes", True) setattr(args, "dry_run", False) - setattr(args, "force", False) setattr(args, "env_file", None) setattr(args, "dataset", None) setattr(args, "dataset_jsonl", str(ds_path)) @@ -508,7 +506,7 @@ def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(r # Fake discovered tests: foo and bar cal_disc = SimpleNamespace(qualname="foo_eval.test_bar_evaluation", file_path=str(cal_file)) svg_disc = SimpleNamespace(qualname="bar_eval.test_baz_evaluation", file_path=str(svg_file)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [cal_disc, svg_disc]) + monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [cal_disc, svg_disc]) # Capture dataset id used during dataset creation captured = {"dataset_id": None} @@ -530,7 +528,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"), yes=True, dry_run=False, - force=False, env_file=None, 
dataset=None, dataset_jsonl=str(ds_path), @@ -576,7 +573,7 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat test_file.write_text("# one", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_one", file_path=str(test_file)) # New flow uses _discover_and_select_tests; patch it to return our single test. - monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Capture dataset id used during dataset creation captured = {"dataset_id": None} @@ -600,7 +597,6 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -645,17 +641,8 @@ def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - # Mock evaluator exists and is ACTIVE - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) + # Mock evaluator upload and version polling - evaluator becomes ACTIVE + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) # Provide dataset via --dataset-jsonl so no test discovery needed ds_path = project / "dataset.jsonl" @@ -674,7 +661,6 @@ def raise_for_status(self): evaluator="some-eval", yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(ds_path), @@ -708,11 +694,8 @@ def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "acct123") - # Evaluator does not exist (force path into upload section) - def _raise(*a, **k): - raise requests.exceptions.RequestException("nope") - - monkeypatch.setattr(cr.requests, "get", _raise) + # Mock upload_and_ensure_evaluator to fail (ambiguous tests) + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: False) # Two discovered tests (ambiguous) f1 = project / "a.py" @@ -721,13 +704,12 @@ def _raise(*a, **k): f2.write_text("# b", encoding="utf-8") d1 = SimpleNamespace(qualname="a.test_one", file_path=str(f1)) d2 = SimpleNamespace(qualname="b.test_two", file_path=str(f2)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2]) + monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [d1, d2]) args = argparse.Namespace( evaluator="some-eval", yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(project / "dataset.jsonl"), @@ -761,9 +743,9 @@ def test_create_rft_fallback_to_dataset_builder(rft_test_harness, monkeypatch): test_file.write_text("# builder case", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_builder", file_path=str(test_file)) # New flow uses _discover_and_select_tests for evaluator resolution; patch it to return our single test. 
- monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Also patch _discover_tests for any direct calls during dataset inference. - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [single_disc]) # Dataset builder fallback out_jsonl = project / "metric" / "builder_out.jsonl" @@ -789,7 +771,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -827,7 +808,7 @@ def test_create_rft_rejects_dataloader_jsonl(rft_test_harness, monkeypatch): test_file.write_text("# loader case", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_loader", file_path=str(test_file)) # New flow uses _discover_and_select_tests; patch it to return our single test. - monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Provide JSONL via dataloader extractor dl_jsonl = project / "metric" / "loader_out.jsonl" @@ -850,7 +831,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -889,7 +869,7 @@ def test_create_rft_uses_input_dataset_jsonl_when_available(rft_test_harness, mo test_file.write_text("# input_dataset case", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_input_ds", file_path=str(test_file)) # New flow uses _discover_and_select_tests; patch it to return our single test. 
- monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Provide JSONL via input_dataset extractor id_jsonl = project / "metric" / "input_ds_out.jsonl" @@ -912,7 +892,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -955,20 +934,10 @@ def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(r f2.write_text("# beta", encoding="utf-8") d1 = SimpleNamespace(qualname="alpha.test_one", file_path=str(f1)) d2 = SimpleNamespace(qualname="beta.test_two", file_path=str(f2)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2]) + monkeypatch.setattr(cli_utils, "_discover_tests", lambda cwd: [d1, d2]) - # Evaluator exists and is ACTIVE (skip upload) - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # Evaluator upload succeeds and version becomes ACTIVE + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) # We will provide JSONL via input_dataset extractor for matching test (beta.test_two) jsonl_path = project / "data.jsonl" @@ -1007,7 +976,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=eval_id, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, @@ -1050,17 +1018,8 @@ def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatc monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") monkeypatch.setattr(cli_utils, "verify_api_key_and_get_account_id", lambda *a, **k: "pyroworks-dev") - # Mock evaluator exists and ACTIVE - class _Resp: - ok = True - - def json(self): - return {"state": "ACTIVE"} - - def raise_for_status(self): - return None - - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) + # Mock evaluator upload succeeds and version becomes ACTIVE + monkeypatch.setattr(cr, "upload_and_ensure_evaluator", lambda *a, **k: True) captured = stub_fireworks @@ -1139,11 +1098,11 @@ def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_h test_file.write_text("# prefer explicit dataset_jsonl", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_pref", file_path=str(test_file)) # New flow uses _discover_and_select_tests; patch it to return our single test. 
- monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) + monkeypatch.setattr(cli_utils, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_poll_evaluator_version_status", lambda **kwargs: True) # Prepare two JSONL paths: one explicit via --dataset-jsonl and one inferable via input_dataset explicit_jsonl = project / "metric" / "explicit.jsonl" @@ -1175,7 +1134,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=str(explicit_jsonl), @@ -1246,7 +1204,7 @@ def test_adapt(row: EvaluationRow) -> EvaluationRow: # Discovery: exactly one test, and resolve_selected_test points to our module/function single_disc = SimpleNamespace(qualname="metric.test_adapt.test_adapt", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr(cli_utils, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) monkeypatch.setattr( cr, "_resolve_selected_test", @@ -1266,7 +1224,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d evaluator=None, yes=True, dry_run=False, - force=False, env_file=None, dataset=None, dataset_jsonl=None, diff --git a/tests/test_ep_upload_e2e.py b/tests/test_ep_upload_e2e.py index 8a67fd33..aadfd0b2 100644 --- a/tests/test_ep_upload_e2e.py +++ b/tests/test_ep_upload_e2e.py @@ -80,8 +80,8 @@ def mock_gcs_upload(): @pytest.fixture def mock_fireworks_client(): - """Mock the Fireworks SDK client used in evaluation.py""" - with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class: + """Mock the Fireworks SDK client used in fireworks_client.py""" + with patch("eval_protocol.fireworks_client.Fireworks") as mock_fw_class: mock_client = MagicMock() mock_fw_class.return_value = mock_client @@ -92,8 +92,13 @@ def mock_fireworks_client(): mock_create_response.description = "Test description" mock_client.evaluators.create.return_value = mock_create_response - # Mock evaluators.get_upload_endpoint response - will be set dynamically - def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): + # Mock evaluator_versions.create response + mock_version_response = MagicMock() + mock_version_response.name = "accounts/test_account/evaluators/test-eval/versions/v1" + mock_client.evaluator_versions.create.return_value = mock_version_response + + # Mock evaluator_versions.get_upload_endpoint response - will be set dynamically + def get_upload_endpoint_side_effect(evaluator_id, version_id, filename_to_size): response = MagicMock() signed_urls = {} for filename in filename_to_size.keys(): @@ -101,35 +106,13 @@ def get_upload_endpoint_side_effect(evaluator_id, filename_to_size): response.filename_to_signed_urls = signed_urls return response - mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect + mock_client.evaluator_versions.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect - # Mock evaluators.validate_upload response + # Mock 
@@ -101,35 +106,13 @@ def get_upload_endpoint_side_effect(evaluator_id, version_id, filename_to_size):
             response.filename_to_signed_urls = signed_urls
             return response

-        mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect
+        mock_client.evaluator_versions.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect

-        # Mock evaluators.validate_upload response
+        # Mock evaluator_versions.validate_upload response
         mock_validate_response = MagicMock()
         mock_validate_response.success = True
         mock_validate_response.valid = True
-        mock_client.evaluators.validate_upload.return_value = mock_validate_response
-
-        # Mock evaluators.get (for force flow - raises NotFoundError by default)
-        import fireworks
-
-        mock_client.evaluators.get.side_effect = fireworks.NotFoundError(
-            "Evaluator not found",
-            response=MagicMock(status_code=404),
-            body={"error": "not found"},
-        )
-
-        # Mock evaluators.delete
-        mock_client.evaluators.delete.return_value = None
-
-        yield mock_client
-
-
-@pytest.fixture
-def mock_platform_api_client():
-    """Mock the Fireworks SDK client used in platform_api.py for secrets"""
-    with patch("eval_protocol.platform_api.Fireworks") as mock_fw_class:
-        mock_client = MagicMock()
-        mock_fw_class.return_value = mock_client
+        mock_client.evaluator_versions.validate_upload.return_value = mock_validate_response

         # Mock secrets.get - raise NotFoundError to simulate secret doesn't exist
         from fireworks import NotFoundError
@@ -141,13 +124,23 @@ def mock_platform_api_client():
         )

         # Mock secrets.create - successful
-        mock_create_response = MagicMock()
-        mock_create_response.name = "accounts/test_account/secrets/test-secret"
-        mock_client.secrets.create.return_value = mock_create_response
+        mock_secrets_create_response = MagicMock()
+        mock_secrets_create_response.name = "accounts/test_account/secrets/test-secret"
+        mock_client.secrets.create.return_value = mock_secrets_create_response

         yield mock_client


+@pytest.fixture
+def mock_platform_api_client(mock_fireworks_client):
+    """
+    Mock the Fireworks SDK client for secrets.
+    This is now just an alias for mock_fireworks_client since both use the same patched location.
+    The mock_fireworks_client fixture already includes secrets mocking.
+    """
+    yield mock_fireworks_client
+
+
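Rather than deleting `mock_platform_api_client` outright, the diff keeps the name alive as a pass-through fixture, so existing test signatures keep resolving after the consolidation. A toy version of this fixture-aliasing pattern:

```python
import pytest


@pytest.fixture
def consolidated_client():
    yield object()


@pytest.fixture
def legacy_client(consolidated_client):
    # Alias kept for backward compatibility with existing test signatures.
    yield consolidated_client


def test_both_names_resolve_to_one_client(consolidated_client, legacy_client):
    assert legacy_client is consolidated_client
```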
 def test_ep_upload_discovers_and_uploads_evaluation_test(
     mock_env_variables, mock_fireworks_client, mock_platform_api_client, mock_gcs_upload, monkeypatch
 ):
     """
     - Upload via upload_command
     - Verify all API calls
     """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
+    from eval_protocol.cli_commands.upload import upload_command
+    from eval_protocol.cli_commands.utils import _discover_tests

     # 1. CREATE TEST PROJECT STRUCTURE
     test_content = """
@@ -214,12 +208,11 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow:
         id="test-simple-eval",  # Explicit ID
         display_name="Simple Word Count Eval",
         description="E2E test evaluator",
-        force=False,
         yes=True,  # Non-interactive
     )

     # Mock the selection (auto-select the discovered test)
-    with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
+    with patch("eval_protocol.cli_commands.utils._prompt_select") as mock_select:
         mock_select.return_value = discovered_tests

         # Execute upload command
@@ -232,13 +225,18 @@ async def test_simple_evaluation(row: EvaluationRow) -> EvaluationRow:
     # Step 1: Create evaluator
     assert mock_fireworks_client.evaluators.create.called, "Should call evaluators.create"

-    # Step 2: Get upload endpoint
-    assert mock_fireworks_client.evaluators.get_upload_endpoint.called, (
-        "Should call evaluators.get_upload_endpoint"
+    # Step 1b: Create evaluator version
+    assert mock_fireworks_client.evaluator_versions.create.called, "Should call evaluator_versions.create"
+
+    # Step 2: Get upload endpoint (via evaluator_versions API)
+    assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, (
+        "Should call evaluator_versions.get_upload_endpoint"
     )

-    # Step 3: Validate upload
-    assert mock_fireworks_client.evaluators.validate_upload.called, "Should call evaluators.validate_upload"
+    # Step 3: Validate upload (via evaluator_versions API)
+    assert mock_fireworks_client.evaluator_versions.validate_upload.called, (
+        "Should call evaluator_versions.validate_upload"
+    )

     # Step 4: GCS upload
     assert mock_gcs_upload.send.called, "Should upload tar.gz to GCS"
@@ -283,7 +281,8 @@ def test_ep_upload_with_parametrized_test(
     Test ep upload with a parametrized @evaluation_test
     Verifies that parametrized tests are discovered and uploaded as a single evaluator
     """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
+    from eval_protocol.cli_commands.upload import upload_command
+    from eval_protocol.cli_commands.utils import _discover_tests

     test_content = """
 import pytest
@@ -327,11 +326,10 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow:
         id="test-param-eval",
         display_name="Parametrized Eval",
         description="Test parametrized evaluator",
-        force=False,
         yes=True,
     )

-    with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
+    with patch("eval_protocol.cli_commands.utils._prompt_select") as mock_select:
         mock_select.return_value = discovered_tests
         exit_code = upload_command(args)
@@ -339,8 +337,9 @@ async def test_multi_model_eval(row: EvaluationRow) -> EvaluationRow:
         # Verify upload flow completed via Fireworks SDK
         assert mock_fireworks_client.evaluators.create.called
-        assert mock_fireworks_client.evaluators.get_upload_endpoint.called
-        assert mock_fireworks_client.evaluators.validate_upload.called
+        assert mock_fireworks_client.evaluator_versions.create.called
+        assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called
+        assert mock_fireworks_client.evaluator_versions.validate_upload.called
         assert mock_gcs_upload.send.called

     finally:
@@ -355,7 +354,7 @@ def test_ep_upload_discovery_skips_problematic_files(mock_env_variables):
     Test that discovery properly skips files like setup.py, versioneer.py
     that would cause issues during pytest collection
     """
-    from eval_protocol.cli_commands.upload import _discover_tests
+    from eval_protocol.cli_commands.utils import _discover_tests
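The docstring above pins down the discovery contract: files such as setup.py and versioneer.py must be skipped, since importing them during pytest collection has side effects. The real `_discover_tests` now lives in `eval_protocol.cli_commands.utils` and is not shown in this diff; a hypothetical filter with an assumed skip list conveys the idea:

```python
from pathlib import Path

SKIP_NAMES = {"setup.py", "versioneer.py"}  # assumed examples, per the docstring above


def candidate_files(root: str):
    """Yield .py files that are safe to offer to pytest collection."""
    for path in Path(root).rglob("*.py"):
        if path.name in SKIP_NAMES:
            continue  # importing these at collection time causes trouble
        yield path
```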
     test_content = """
 from eval_protocol.pytest import evaluation_test
@@ -403,7 +402,7 @@ def test_ep_upload_discovers_non_test_prefixed_files(mock_env_variables):
     Test that discovery finds @evaluation_test in files like quickstart.py
     (files that don't start with 'test_')
     """
-    from eval_protocol.cli_commands.upload import _discover_tests
+    from eval_protocol.cli_commands.utils import _discover_tests

     test_content = """
 from eval_protocol.pytest import evaluation_test
@@ -453,7 +452,8 @@ def test_ep_upload_complete_workflow_with_entry_point_validation(
     - Full 5-step upload flow
     - Payload structure
     """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
+    from eval_protocol.cli_commands.upload import upload_command
+    from eval_protocol.cli_commands.utils import _discover_tests

     test_content = """
 from typing import List
@@ -506,11 +506,10 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
         id=None,  # Auto-generate from test name
         display_name=None,  # Auto-generate
         description=None,  # Auto-generate
-        force=False,
         yes=True,
     )

-    with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
+    with patch("eval_protocol.cli_commands.utils._prompt_select") as mock_select:
         mock_select.return_value = discovered_tests
         exit_code = upload_command(args)
@@ -520,8 +519,13 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
     # Step 1: Create evaluator
     assert mock_fireworks_client.evaluators.create.called, "Missing create call"

-    # Step 2: Get upload endpoint
-    assert mock_fireworks_client.evaluators.get_upload_endpoint.called, "Missing getUploadEndpoint call"
+    # Step 1b: Create evaluator version
+    assert mock_fireworks_client.evaluator_versions.create.called, "Missing evaluator_versions.create call"
+
+    # Step 2: Get upload endpoint (via evaluator_versions API)
+    assert mock_fireworks_client.evaluator_versions.get_upload_endpoint.called, (
+        "Missing evaluator_versions.get_upload_endpoint call"
+    )

     # Step 3: Upload to GCS
     assert mock_gcs_upload.send.called, "Missing GCS upload"
@@ -529,8 +533,10 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
     assert gcs_request.method == "PUT"
     assert "storage.googleapis.com" in gcs_request.url

-    # Step 4: Validate
-    assert mock_fireworks_client.evaluators.validate_upload.called, "Missing validateUpload call"
+    # Step 4: Validate (via evaluator_versions API)
+    assert mock_fireworks_client.evaluator_versions.validate_upload.called, (
+        "Missing evaluator_versions.validate_upload call"
+    )

     # 4. VERIFY PAYLOAD DETAILS
     create_call = mock_fireworks_client.evaluators.create.call_args
@@ -547,8 +553,8 @@ async def test_math_correctness(row: EvaluationRow) -> EvaluationRow:
     assert "test_math_eval.py::test_math_correctness" in entry_point

     # 5. VERIFY TAR.GZ WAS CREATED AND UPLOADED
-    # Check getUploadEndpoint call payload
-    upload_call = mock_fireworks_client.evaluators.get_upload_endpoint.call_args
+    # Check getUploadEndpoint call payload (via evaluator_versions API)
+    upload_call = mock_fireworks_client.evaluator_versions.get_upload_endpoint.call_args
     assert upload_call is not None
     filename_to_size = upload_call.kwargs.get("filename_to_size", {})
     assert filename_to_size, "Should have filename_to_size"
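The entry point asserted above follows pytest's `<file>::<test>` node-ID convention, which is what lets the platform invoke one specific `@evaluation_test` later. A one-line round trip over the value the test checks:

```python
entry_point = "test_math_eval.py::test_math_correctness"
file_part, test_name = entry_point.rsplit("::", 1)  # split off the test function name
assert file_part.endswith(".py") and test_name.startswith("test_")
```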
@@ -597,95 +603,3 @@ def test_create_tar_includes_dockerignored_files(tmp_path):

     for expected_path in expected_paths:
         assert expected_path in names, f"Expected {expected_path} in archive"
-
-
-def test_ep_upload_force_flag_triggers_delete_flow(
-    mock_env_variables,
-    mock_gcs_upload,
-    mock_platform_api_client,
-):
-    """
-    Test that --force flag triggers the check/delete/recreate flow
-    """
-    from eval_protocol.cli_commands.upload import upload_command, _discover_tests
-
-    test_content = """
-from eval_protocol.pytest import evaluation_test
-from eval_protocol.models import EvaluationRow
-
-@evaluation_test(input_rows=[[EvaluationRow()]])
-async def test_force_eval(row: EvaluationRow) -> EvaluationRow:
-    return row
-"""
-
-    test_project_dir, test_file_path = create_test_project_with_evaluation_test(test_content, "test_force.py")
-
-    original_cwd = os.getcwd()
-
-    try:
-        os.chdir(test_project_dir)
-
-        # Mock the Fireworks client with evaluator existing (for force flow)
-        with patch("eval_protocol.evaluation.Fireworks") as mock_fw_class:
-            mock_client = MagicMock()
-            mock_fw_class.return_value = mock_client
-
-            # Mock evaluators.get to return an existing evaluator (not raise NotFoundError)
-            mock_existing_evaluator = MagicMock()
-            mock_existing_evaluator.name = "accounts/test_account/evaluators/test-force"
-            mock_client.evaluators.get.return_value = mock_existing_evaluator
-
-            # Mock evaluators.delete
-            mock_client.evaluators.delete.return_value = None
-
-            # Mock evaluators.create response
-            mock_create_response = MagicMock()
-            mock_create_response.name = "accounts/test_account/evaluators/test-force"
-            mock_client.evaluators.create.return_value = mock_create_response
-
-            # Mock get_upload_endpoint
-            def get_upload_endpoint_side_effect(evaluator_id, filename_to_size):
-                response = MagicMock()
-                signed_urls = {}
-                for filename in filename_to_size.keys():
-                    signed_urls[filename] = f"https://storage.googleapis.com/test-bucket/{filename}?signed=true"
-                response.filename_to_signed_urls = signed_urls
-                return response
-
-            mock_client.evaluators.get_upload_endpoint.side_effect = get_upload_endpoint_side_effect
-
-            # Mock validate_upload
-            mock_client.evaluators.validate_upload.return_value = MagicMock()
-
-            discovered_tests = _discover_tests(test_project_dir)
-
-            args = argparse.Namespace(
-                path=test_project_dir,
-                entry=None,
-                id="test-force",
-                display_name=None,
-                description=None,
-                force=True,  # Force flag enabled
-                yes=True,
-            )
-
-            with patch("eval_protocol.cli_commands.upload._prompt_select") as mock_select:
-                mock_select.return_value = discovered_tests
-                exit_code = upload_command(args)
-
-            assert exit_code == 0
-
-            # Verify check happened (evaluators.get was called)
-            assert mock_client.evaluators.get.called, "Should check if evaluator exists"
-
-            # Verify delete happened (since evaluator existed)
-            assert mock_client.evaluators.delete.called, "Should delete existing evaluator"
-
-            # Verify create happened after delete
-            assert mock_client.evaluators.create.called, "Should create evaluator after delete"
-
-    finally:
-        os.chdir(original_cwd)
-        if test_project_dir in sys.path:
-            sys.path.remove(test_project_dir)
-        shutil.rmtree(test_project_dir, ignore_errors=True)
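Taken together, the assertions in this file pin a five-step flow: create the evaluator, create a version under it, fetch signed upload URLs for that version, PUT the tarball to GCS, then validate. A condensed walkthrough of the expected call order against a bare `MagicMock` (a sketch, not the real SDK):

```python
from unittest.mock import MagicMock

client = MagicMock()
client.evaluators.create(evaluator_id="demo", evaluator={})                      # 1. evaluator
client.evaluator_versions.create(evaluator_id="demo", evaluator_version={})      # 1b. version
client.evaluator_versions.get_upload_endpoint(
    evaluator_id="demo", version_id="v1", filename_to_size={"code.tar.gz": 1024}
)                                                                                # 2. signed URLs
gcs_uploaded = True  # 3. PUT to storage.googleapis.com (mock_gcs_upload.send in the tests)
client.evaluator_versions.validate_upload(evaluator_id="demo", version_id="v1")  # 4. validate

assert client.evaluator_versions.create.called and gcs_uploaded
```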
diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py
index 942c1962..0d4bb13e 100644
--- a/tests/test_evaluation.py
+++ b/tests/test_evaluation.py
@@ -41,6 +41,7 @@ def test_create_evaluation_helper(monkeypatch):

     # Track SDK calls
     create_called = False
+    version_create_called = False
     upload_endpoint_called = False
     validate_called = False
@@ -61,7 +62,16 @@ def mock_create(evaluator_id, evaluator):
         assert evaluator["description"] == "Test description"
         return mock_evaluator_result

-    def mock_get_upload_endpoint(evaluator_id, filename_to_size):
+    # Mock evaluator_versions.create
+    mock_version_result = MagicMock()
+    mock_version_result.name = "accounts/test_account/evaluators/test-eval/versions/v1"
+
+    def mock_version_create(evaluator_id, evaluator_version):
+        nonlocal version_create_called
+        version_create_called = True
+        return mock_version_result
+
+    def mock_get_upload_endpoint(evaluator_id, version_id, filename_to_size):
         nonlocal upload_endpoint_called
         upload_endpoint_called = True
         mock_response = MagicMock()
@@ -71,7 +81,7 @@ def mock_get_upload_endpoint(evaluator_id, version_id, filename_to_size):
         mock_response.filename_to_signed_urls = signed_urls
         return mock_response

-    def mock_validate_upload(evaluator_id, body):
+    def mock_validate_upload(evaluator_id, version_id):
         nonlocal validate_called
         validate_called = True
         return MagicMock()
@@ -83,20 +93,21 @@ def mock_validate_upload(evaluator_id, version_id):
     mock_gcs_response.raise_for_status = MagicMock()
     mock_session.send.return_value = mock_gcs_response

-    # Patch the Fireworks client
-    with patch("eval_protocol.evaluation.Fireworks") as mock_fireworks_class:
+    # Patch the Fireworks client at the location where it's imported
+    with patch("eval_protocol.fireworks_client.Fireworks") as mock_fireworks_class:
         mock_client = MagicMock()
         mock_fireworks_class.return_value = mock_client
         mock_client.evaluators.create = mock_create
-        mock_client.evaluators.get_upload_endpoint = mock_get_upload_endpoint
-        mock_client.evaluators.validate_upload = mock_validate_upload
+        mock_client.evaluator_versions.create = mock_version_create
+        mock_client.evaluator_versions.get_upload_endpoint = mock_get_upload_endpoint
+        mock_client.evaluator_versions.validate_upload = mock_validate_upload

         # Patch requests.Session for GCS upload
         monkeypatch.setattr("requests.Session", lambda: mock_session)

         try:
             os.chdir(tmp_dir)
-            api_response = create_evaluation(
+            api_response, version_id = create_evaluation(
                 evaluator_id="test-eval",
                 display_name="Test Evaluator",
                 description="Test description",
@@ -107,8 +118,12 @@ def mock_validate_upload(evaluator_id, version_id):
             assert api_response.display_name == "Test Evaluator"
             assert api_response.description == "Test description"

+            # Verify version ID was returned
+            assert version_id == "v1", "Version ID should be returned"
+
             # Verify full upload flow was executed
             assert create_called, "Create endpoint should be called"
+            assert version_create_called, "Version create should be called"
             assert upload_endpoint_called, "GetUploadEndpoint should be called"
             assert validate_called, "ValidateUpload should be called"
             assert mock_session.send.called, "GCS upload should happen"
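The signature change threaded through this test is that `create_evaluation` now returns a pair instead of a bare response. The fakes later in this diff model the same contract, so a caller-side sketch is just tuple unpacking:

```python
def fake_create_evaluation(**kwargs):
    # Mirrors the fakes in tests/test_upload_entrypoint.py below.
    return {"name": kwargs.get("evaluator_id", "eval")}, "v1"


api_response, version_id = fake_create_evaluation(evaluator_id="test-eval")
assert api_response["name"] == "test-eval" and version_id == "v1"
# Old-style callers expecting a single return value would now silently receive
# a 2-tuple, which is why every fake and call site changes together here.
```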
client factory.""" + +import os +from unittest.mock import patch + +import pytest + +from eval_protocol.fireworks_client import ( + create_fireworks_client, + get_fireworks_extra_headers, +) + + +class TestGetFireworksExtraHeaders: + """Tests for get_fireworks_extra_headers function.""" + + def test_returns_none_when_env_var_not_set(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is not set.""" + with patch.dict(os.environ, {}, clear=True): + # Remove the env var if it exists + os.environ.pop("FIREWORKS_EXTRA_HEADERS", None) + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_empty_string(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is empty.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": ""}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_whitespace_only(self): + """Should return None when FIREWORKS_EXTRA_HEADERS is whitespace only.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": " "}): + result = get_fireworks_extra_headers() + assert result is None + + def test_parses_valid_json_object(self): + """Should parse valid JSON object into dict.""" + headers = '{"X-Custom": "value", "X-Another": "test"}' + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": headers}): + result = get_fireworks_extra_headers() + assert result == {"X-Custom": "value", "X-Another": "test"} + + def test_returns_none_for_invalid_json(self): + """Should return None and log warning for invalid JSON.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": "not json"}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_json_array(self): + """Should return None when JSON is an array instead of object.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '["item1", "item2"]'}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_json_string(self): + """Should return None when JSON is a string instead of object.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '"just a string"'}): + result = get_fireworks_extra_headers() + assert result is None + + def test_returns_none_for_non_string_values(self): + """Should return None when JSON object has non-string values.""" + with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": '{"key": 123}'}): + result = get_fireworks_extra_headers() + assert result is None + + +class TestCreateFireworksClient: + """Tests for create_fireworks_client function.""" + + def test_creates_client_with_explicit_api_key(self): + """Should create client with explicitly provided API key.""" + client = create_fireworks_client(api_key="test-api-key") + assert client.api_key == "test-api-key" + + def test_creates_client_with_explicit_base_url(self): + """Should create client with explicitly provided base URL.""" + client = create_fireworks_client( + api_key="test-api-key", + base_url="https://custom.api.example.com", + ) + assert str(client.base_url).rstrip("/") == "https://custom.api.example.com" + + def test_creates_client_with_explicit_account_id(self): + """Should create client with explicitly provided account ID.""" + client = create_fireworks_client( + api_key="test-api-key", + account_id="test-account-123", + ) + assert client.account_id == "test-account-123" + + def test_creates_client_with_explicit_extra_headers(self): + """Should create client with explicitly provided extra headers.""" + extra_headers = {"X-Custom-Header": "test-value"} + 
+
+
+class TestCreateFireworksClient:
+    """Tests for create_fireworks_client function."""
+
+    def test_creates_client_with_explicit_api_key(self):
+        """Should create client with explicitly provided API key."""
+        client = create_fireworks_client(api_key="test-api-key")
+        assert client.api_key == "test-api-key"
+
+    def test_creates_client_with_explicit_base_url(self):
+        """Should create client with explicitly provided base URL."""
+        client = create_fireworks_client(
+            api_key="test-api-key",
+            base_url="https://custom.api.example.com",
+        )
+        assert str(client.base_url).rstrip("/") == "https://custom.api.example.com"
+
+    def test_creates_client_with_explicit_account_id(self):
+        """Should create client with explicitly provided account ID."""
+        client = create_fireworks_client(
+            api_key="test-api-key",
+            account_id="test-account-123",
+        )
+        assert client.account_id == "test-account-123"
+
+    def test_creates_client_with_explicit_extra_headers(self):
+        """Should create client with explicitly provided extra headers."""
+        extra_headers = {"X-Custom-Header": "test-value"}
+        client = create_fireworks_client(
+            api_key="test-api-key",
+            extra_headers=extra_headers,
+        )
+        assert "X-Custom-Header" in client._custom_headers
+        assert client._custom_headers["X-Custom-Header"] == "test-value"
+
+    def test_merges_env_and_explicit_extra_headers(self):
+        """Should merge env var headers with explicit headers, explicit taking precedence."""
+        env_headers = '{"X-Env-Header": "env-value", "X-Override": "env"}'
+        explicit_headers = {"X-Explicit-Header": "explicit-value", "X-Override": "explicit"}
+
+        with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": env_headers}):
+            client = create_fireworks_client(
+                api_key="test-api-key",
+                extra_headers=explicit_headers,
+            )
+            # Both headers should be present
+            assert client._custom_headers["X-Env-Header"] == "env-value"
+            assert client._custom_headers["X-Explicit-Header"] == "explicit-value"
+            # Explicit should override env
+            assert client._custom_headers["X-Override"] == "explicit"
+
+    def test_uses_env_extra_headers_when_no_explicit(self):
+        """Should use env var extra headers when no explicit headers provided."""
+        env_headers = '{"X-Env-Header": "env-value"}'
+
+        with patch.dict(os.environ, {"FIREWORKS_EXTRA_HEADERS": env_headers}):
+            client = create_fireworks_client(api_key="test-api-key")
+            assert client._custom_headers["X-Env-Header"] == "env-value"
+
+    def test_resolves_api_key_from_env(self):
+        """Should resolve API key from environment when not explicitly provided."""
+        with patch.dict(os.environ, {"FIREWORKS_API_KEY": "env-api-key"}):
+            client = create_fireworks_client()
+            assert client.api_key == "env-api-key"
+
+    def test_resolves_base_url_from_env(self):
+        """Should resolve base URL from environment when not explicitly provided."""
+        with patch.dict(
+            os.environ,
+            {
+                "FIREWORKS_API_KEY": "test-key",
+                "FIREWORKS_API_BASE": "https://env.api.example.com",
+            },
+        ):
+            client = create_fireworks_client()
+            assert str(client.base_url).rstrip("/") == "https://env.api.example.com"
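The merge test above reduces to ordinary dict-unpacking semantics: apply environment headers first and explicit headers second, so the explicit side wins any key collision. In isolation:

```python
env_headers = {"X-Env-Header": "env-value", "X-Override": "env"}
explicit_headers = {"X-Explicit-Header": "explicit-value", "X-Override": "explicit"}

merged = {**env_headers, **explicit_headers}  # later unpacking wins on conflicts
assert merged["X-Override"] == "explicit"
assert merged["X-Env-Header"] == "env-value"
```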
diff --git a/tests/test_upload_entrypoint.py b/tests/test_upload_entrypoint.py
index 2ae23024..076a6f79 100644
--- a/tests/test_upload_entrypoint.py
+++ b/tests/test_upload_entrypoint.py
@@ -28,8 +28,8 @@ def test_llm_judge(row=None):

     def fake_create_evaluation(**kwargs):
         captured.update(kwargs)
-        # Simulate API response
-        return {"name": kwargs.get("evaluator_id", "eval")}
+        # Simulate API response - returns (result, version_id) tuple
+        return {"name": kwargs.get("evaluator_id", "eval")}, "v1"

     monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation)
@@ -40,7 +40,6 @@ def fake_create_evaluation(**kwargs):
         id=None,
         display_name=None,
         description=None,
-        force=False,
         yes=True,
     )
@@ -72,7 +71,8 @@ def test_llm_judge(row=None):

     def fake_create_evaluation(**kwargs):
         captured.update(kwargs)
-        return {"name": kwargs.get("evaluator_id", "eval")}
+        # Simulate API response - returns (result, version_id) tuple
+        return {"name": kwargs.get("evaluator_id", "eval")}, "v1"

     monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation)
@@ -83,7 +83,6 @@ def fake_create_evaluation(**kwargs):
         id=None,
         display_name=None,
         description=None,
-        force=False,
         yes=True,
     )
@@ -119,7 +118,8 @@ def test_llm_judge(row=None):

     def fake_create_evaluation(**kwargs):
         captured.update(kwargs)
-        return {"name": kwargs.get("evaluator_id", "eval")}
+        # Simulate API response - returns (result, version_id) tuple
+        return {"name": kwargs.get("evaluator_id", "eval")}, "v1"

     monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation)
@@ -130,7 +130,6 @@ def fake_create_evaluation(**kwargs):
         id=None,
         display_name=None,
         description=None,
-        force=False,
         yes=True,
     )
@@ -163,8 +162,8 @@ def test_llm_judge(row=None):
     monkeypatch.setenv("FIREWORKS_API_BASE", "https://dev.api.fireworks.ai")

     def fake_create_evaluation(**kwargs):
-        # Simulate creation result with evaluator name
-        return {"name": kwargs.get("evaluator_id", "eval")}
+        # Simulate creation result with evaluator name - returns (result, version_id) tuple
+        return {"name": kwargs.get("evaluator_id", "eval")}, "v1"

     monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation)
@@ -174,7 +173,6 @@ def fake_create_evaluation(**kwargs):
         id="quickstart-test-llm-judge",
         display_name=None,
         description=None,
-        force=True,
         yes=True,
     )
@@ -204,7 +202,8 @@ def test_llm_judge(row=None):
     monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")

     def fake_create_evaluation(**kwargs):
-        return {"name": kwargs.get("evaluator_id", "eval")}
+        # Simulate API response - returns (result, version_id) tuple
+        return {"name": kwargs.get("evaluator_id", "eval")}, "v1"

     monkeypatch.setattr(upload_mod, "create_evaluation", fake_create_evaluation)
@@ -214,7 +213,6 @@ def fake_create_evaluation(**kwargs):
         id="quickstart-test-llm-judge",
         display_name=None,
         description=None,
-        force=False,
         yes=True,
     )
diff --git a/uv.lock b/uv.lock
index c175b81f..188413f6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1312,7 +1312,7 @@ requires-dist = [
     { name = "dspy", marker = "extra == 'dspy'", specifier = ">=3.0.0" },
     { name = "e2b", marker = "extra == 'dev'" },
     { name = "fastapi", specifier = ">=0.116.1" },
-    { name = "fireworks-ai", specifier = "==1.0.0a20" },
+    { name = "fireworks-ai", specifier = "==1.0.0a22" },
     { name = "google-auth", marker = "extra == 'bigquery'", specifier = ">=2.0.0" },
     { name = "google-cloud-bigquery", marker = "extra == 'bigquery'", specifier = ">=3.0.0" },
     { name = "gymnasium", marker = "extra == 'dev'", specifier = ">=1.2.0" },
@@ -1582,7 +1582,7 @@ wheels = [

 [[package]]
 name = "fireworks-ai"
-version = "1.0.0a20"
+version = "1.0.0a22"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "aiohttp" },
@@ -1594,9 +1594,9 @@ dependencies = [
     { name = "sniffio" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1d/c6/cdc6c152876ee1253491e6f72c65c2cdaf7b22b320be0cec7ac5778d3b1c/fireworks_ai-1.0.0a20.tar.gz", hash = "sha256:c84f702445679ea768461dba8fb027175b82255021832a89f9ece65821a2ab25", size = 564097, upload-time = "2025-12-23T19:21:17.891Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ef/16/073cf6855d18e43c14972d4b8f8fe59a43e41b581d430a7fad1dae3b8ddf/fireworks_ai-1.0.0a22.tar.gz", hash = "sha256:ab6fc7ad2beb8d69454b8c8c34ccd5d97ffa8cefa308a5cac7e568676e4b1188", size = 572510, upload-time = "2026-01-13T23:52:12.538Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/c5/a4/e2bc9c4af291786bc7fe364ae63503ba2c8161c2e71223d570a77f0a1415/fireworks_ai-1.0.0a20-py3-none-any.whl", hash = "sha256:b5e199978f71b564b2e19cf55a71c1ac20906d9a7b4ae75135fdccb245227722", size = 304153, upload-time = "2025-12-23T19:21:15.943Z" },
+    { url = "https://files.pythonhosted.org/packages/26/ef/a932f1fc357b7847258c212d53c074df3956ffbcbf74b2d5c3fdf14fd805/fireworks_ai-1.0.0a22-py3-none-any.whl", hash = "sha256:4ee18a0cb454585baab4803d82ec647d70fd8078a737a7ca4be7a686bc468ce3", size = 316745, upload-time = "2026-01-13T23:52:11.268Z" },
 ]

 [[package]]
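The lockfile pins `fireworks-ai` to the exact prerelease `1.0.0a22`, which is the build providing the `evaluator_versions` surface mocked throughout the tests above. A quick, standard-library-only runtime check that the pinned build is the one actually installed:

```python
from importlib.metadata import version

# Fails loudly if the environment drifted from the lockfile pin.
assert version("fireworks-ai") == "1.0.0a22"
```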