Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 96 additions & 38 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import os
import sys
import time
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional
import inspect
import requests
import tempfile
from pydantic import ValidationError

from ..auth import get_fireworks_api_base, get_fireworks_api_key
from ..common_utils import get_user_agent
from ..common_utils import get_user_agent, load_jsonl
from ..fireworks_rft import (
build_default_output_model,
create_dataset_from_jsonl,
detect_dataset_builder,
materialize_dataset_via_builder,
Expand All @@ -31,12 +31,88 @@
_normalize_evaluator_id,
_print_links,
_resolve_selected_test,
load_module_from_file_path,
)
from .local_test import run_evaluator_test

from fireworks import Fireworks


def _extract_dataset_adapter(
test_file_path: str, test_func_name: str
) -> Optional[Callable[[list[dict[str, Any]]], Any]]:
"""Extract dataset_adapter from an @evaluation_test wrapper via __ep_params__."""
try:
module = load_module_from_file_path(test_file_path)
wrapper = getattr(module, test_func_name, None)
if wrapper is None:
return None
ep_params = getattr(wrapper, "__ep_params__", None)
if ep_params is None:
return None
adapter = getattr(ep_params, "dataset_adapter", None)
if callable(adapter):
return adapter
return None
except Exception:
return None


def _maybe_transform_dataset_jsonl_via_adapter(
project_root: str,
dataset_jsonl: str,
test_file_path: Optional[str],
test_func_name: Optional[str],
) -> str:
"""Transform dataset_jsonl via the test's dataset_adapter (when available).

For RFT dataset uploads, we want the uploaded dataset to match what evaluation-time
would run on. If the selected evaluation test provides a dataset_adapter, that
adapter is treated as the source of truth for constructing EvaluationRows.
"""
if not dataset_jsonl:
return dataset_jsonl

if not test_file_path or not test_func_name:
return dataset_jsonl

adapter = _extract_dataset_adapter(test_file_path, test_func_name)
if not adapter:
return dataset_jsonl

raw_rows: list[dict[str, Any]] = load_jsonl(dataset_jsonl) # type: ignore[assignment]
adapted = adapter(raw_rows)
if not isinstance(adapted, list):
raise ValueError("dataset_adapter must return a list of EvaluationRow (or dicts parseable as EvaluationRow).")

eval_rows: list[EvaluationRow] = []
for item in adapted:
if isinstance(item, EvaluationRow):
eval_rows.append(item)
else:
eval_rows.append(EvaluationRow.model_validate(item))

output_dir = os.path.join(project_root, ".ep_tmp")
os.makedirs(output_dir, exist_ok=True)
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
delete=False,
suffix=".jsonl",
prefix="ep_rft_dataset_",
dir=output_dir,
) as f:
for row in eval_rows:
f.write(json.dumps(row.model_dump(mode="json", exclude_none=True), ensure_ascii=False) + "\n")
out_path = os.path.abspath(f.name)
try:
rel = os.path.relpath(out_path, project_root)
except Exception:
rel = out_path
print(f"✓ Transformed dataset via dataset_adapter into EvaluationRow JSONL: {rel} ({len(eval_rows)} rows)")
return out_path


def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
"""Import the test module and extract a JSONL path from data_loaders param if present.

Expand All @@ -45,18 +121,10 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) ->
relative to the directory of the test file.
"""
try:
import importlib.util
from pathlib import Path

spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
if not spec or not spec.loader:
return None
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module) # type: ignore[attr-defined]
if not hasattr(module, test_func_name):
module = load_module_from_file_path(test_file_path)
wrapper = getattr(module, test_func_name, None)
if wrapper is None:
return None
wrapper = getattr(module, test_func_name)
marks = getattr(wrapper, "pytestmark", [])
for m in marks:
if getattr(m, "name", "") == "parametrize":
Expand Down Expand Up @@ -105,18 +173,10 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str)
of the test file.
"""
try:
import importlib.util
from pathlib import Path

spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
if not spec or not spec.loader:
return None
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module) # type: ignore[attr-defined]
if not hasattr(module, test_func_name):
module = load_module_from_file_path(test_file_path)
wrapper = getattr(module, test_func_name, None)
if wrapper is None:
return None
wrapper = getattr(module, test_func_name)
marks = getattr(wrapper, "pytestmark", [])
for m in marks:
if getattr(m, "name", "") == "parametrize":
Expand Down Expand Up @@ -320,27 +380,15 @@ def _resolve_evaluator(
selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive)
if not selected_tests:
return None, None, None, None

if len(selected_tests) != 1:
if non_interactive and len(selected_tests) > 1:
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
print(" Please pass --evaluator or --entry to disambiguate.")
try:
# Offer candidate evaluator ids for convenience
tests = _discover_tests(project_root)
if tests:
print(" Candidate evaluator ids:")
for t in tests:
func = t.qualname.split(".")[-1]
stem = os.path.splitext(os.path.basename(t.file_path))[0]
cand = _normalize_evaluator_id(f"{stem}-{func}")
print(f" - {cand}")
except Exception:
pass
else:
print("Error: Please select exactly one evaluation test for 'create rft'.")
return None, None, None, None

# Derive evaluator_id from user's single selection
chosen = selected_tests[0]
func_name = chosen.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
Expand Down Expand Up @@ -719,6 +767,16 @@ def create_rft_command(args) -> int:
if dataset_jsonl is None and not dataset_id:
return 1

# 2.5) If the selected evaluation test provides a dataset_adapter, always use it to
# construct the EvaluationRow dataset that we upload for RFT.
if dataset_jsonl is not None:
dataset_jsonl = _maybe_transform_dataset_jsonl_via_adapter(
project_root=project_root,
dataset_jsonl=dataset_jsonl,
test_file_path=selected_test_file_path,
test_func_name=selected_test_func_name,
)

# 3) Optional local validation
if not skip_validation:
# Dataset validation (JSONL must be EvaluationRow-compatible when present)
Expand Down
11 changes: 3 additions & 8 deletions eval_protocol/cli_commands/upload.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
from eval_protocol.cli_commands.utils import DiscoveredTest
import importlib.util
import os
import re
import sys
Expand All @@ -18,6 +17,7 @@
_discover_tests,
_ensure_account_id,
_get_questionary_style,
load_module_from_file_path,
_normalize_evaluator_id,
_prompt_select,
)
Expand Down Expand Up @@ -120,13 +120,8 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]:
source_file_path = os.path.join(cwd, dotted_as_path)

# Load the module from the file path
spec = importlib.util.spec_from_file_location(Path(source_file_path).stem, source_file_path)
if not spec or not spec.loader:
raise ValueError(f"Unable to load module from path: {source_file_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module) # type: ignore[attr-defined]
module_name = spec.name
module = load_module_from_file_path(source_file_path)
module_name = getattr(module, "__name__", Path(source_file_path).stem)

if not hasattr(module, func):
raise ValueError(f"Function '{func}' not found in module '{module_name}'")
Expand Down
71 changes: 52 additions & 19 deletions eval_protocol/cli_commands/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from types import ModuleType


import os
import ast
import sys
Expand All @@ -6,23 +9,46 @@
import argparse
import typing
import types
import importlib.util
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, is_typeddict
from typing import Any, List, Optional
import typing_extensions
import inspect
from collections.abc import Callable
import pytest

from ..auth import (
get_fireworks_account_id,
get_fireworks_api_base,
get_fireworks_api_key,
verify_api_key_and_get_account_id,
)
from ..fireworks_rft import _map_api_host_to_app_host


def load_module_from_file_path(source_file_path: str) -> ModuleType:
    """Load a Python module from an absolute/relative filesystem path.

    This mirrors the CLI behavior used by `upload.py` and `create_rft.py`:
    - module name is derived from the file stem (e.g. /a/b/foo.py -> foo)
    - the module is inserted into sys.modules under that name before exec

    Args:
        source_file_path: Path to a `.py` file on disk.

    Returns:
        The fully executed module object.

    Raises:
        ValueError: If the path does not exist, is not a `.py` file, or no
            import spec/loader can be created for it.
        Exception: Whatever the module's own top-level code raises while
            executing. In that case the partially initialized module is
            removed from sys.modules so a failed load cannot shadow a later
            successful import under the same name.
    """
    abs_path = os.path.abspath(source_file_path)
    if not os.path.isfile(abs_path):
        raise ValueError(f"File not found: {abs_path}")
    if not abs_path.endswith(".py"):
        raise ValueError(f"Expected a .py file path, got: {abs_path}")

    module_name = Path(abs_path).stem
    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    if not spec or not spec.loader:
        raise ValueError(f"Unable to load module from path: {abs_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before exec so the module can resolve self-references
    # (e.g. dataclasses, pickle, recursive imports) during its own execution.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)  # type: ignore[attr-defined]
    except BaseException:
        # Clean up the broken entry (the pattern recommended by the importlib
        # docs): a module that failed to execute must not stay importable.
        sys.modules.pop(spec.name, None)
        raise
    return module


def _get_questionary_style():
"""Get the shared questionary style for CLI prompts - minimal and clean."""
try:
Expand Down Expand Up @@ -252,7 +278,7 @@ def _format_test_choice(test: DiscoveredTest, idx: int) -> str:


def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTest]:
"""Interactive selection with arrow keys using questionary."""
"""Interactive single selection with arrow keys using questionary (Enter selects highlighted)."""
try:
import questionary

Expand All @@ -263,35 +289,32 @@ def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTe
print(f"\nFound 1 test: {_format_test_choice(tests[0], 1)}")
confirm = questionary.confirm("Select this test?", default=True, style=custom_style).ask()
if confirm:
return tests
return [tests[0]]
else:
return []

# Build checkbox choices
# Build single-select choices
choices = []
for idx, t in enumerate(tests, 1):
choice_text = _format_test_choice(t, idx)
choices.append(questionary.Choice(title=choice_text, value=idx - 1, checked=False))
choices.append(questionary.Choice(title=choice_text, value=idx - 1))

print()
selected_indices = questionary.checkbox(
"Select evaluation tests to upload:",
selected_index = questionary.select(
"Select an evaluation test:",
choices=choices,
style=custom_style,
pointer=">",
instruction="(↑↓ move, space select, enter confirm)",
instruction="(↑↓ move, enter confirm)",
).ask()

if selected_indices is None: # Ctrl+C
if selected_index is None: # Ctrl+C / Esc
print("\nUpload cancelled.")
return []

if not selected_indices:
return []

selected_tests = [tests[i] for i in selected_indices]
print(f"\n✓ Selected {len(selected_tests)} test(s)")
return selected_tests
chosen = tests[int(selected_index)]
print("\n✓ Selected 1 test")
return [chosen]

except ImportError:
# Fallback to simpler implementation
Expand Down Expand Up @@ -346,9 +369,10 @@ def _prompt_select_fallback(tests: list[DiscoveredTest]) -> list[DiscoveredTest]


def _prompt_select(tests: list[DiscoveredTest], non_interactive: bool) -> list[DiscoveredTest]:
"""Prompt user to select tests to upload."""
"""Prompt user to select exactly one test."""
if non_interactive:
return tests
# In non-interactive mode, only proceed if unambiguous.
return [tests[0]] if len(tests) == 1 else []

return _prompt_select_interactive(tests)

Expand All @@ -375,7 +399,16 @@ def _discover_and_select_tests(project_root: str, non_interactive: bool) -> Opti
return None

if not selected_tests:
print("No tests selected.")
if non_interactive and len(tests) > 1:
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
print(" Please pass --evaluator or --entry to disambiguate.")
else:
print("No test selected.")
return None

# Enforce single-select at the helper level.
if len(selected_tests) != 1:
print("Error: Please select exactly one evaluation test.")
return None

return selected_tests
Expand Down
Loading
Loading