diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index f95ccdcc..702eb2fe 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -5,15 +5,15 @@ import os import sys import time -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional import inspect import requests +import tempfile from pydantic import ValidationError from ..auth import get_fireworks_api_base, get_fireworks_api_key -from ..common_utils import get_user_agent +from ..common_utils import get_user_agent, load_jsonl from ..fireworks_rft import ( - build_default_output_model, create_dataset_from_jsonl, detect_dataset_builder, materialize_dataset_via_builder, @@ -31,12 +31,88 @@ _normalize_evaluator_id, _print_links, _resolve_selected_test, + load_module_from_file_path, ) from .local_test import run_evaluator_test from fireworks import Fireworks +def _extract_dataset_adapter( + test_file_path: str, test_func_name: str +) -> Optional[Callable[[list[dict[str, Any]]], Any]]: + """Extract dataset_adapter from an @evaluation_test wrapper via __ep_params__.""" + try: + module = load_module_from_file_path(test_file_path) + wrapper = getattr(module, test_func_name, None) + if wrapper is None: + return None + ep_params = getattr(wrapper, "__ep_params__", None) + if ep_params is None: + return None + adapter = getattr(ep_params, "dataset_adapter", None) + if callable(adapter): + return adapter + return None + except Exception: + return None + + +def _maybe_transform_dataset_jsonl_via_adapter( + project_root: str, + dataset_jsonl: str, + test_file_path: Optional[str], + test_func_name: Optional[str], +) -> str: + """Transform dataset_jsonl via the test's dataset_adapter (when available). + + For RFT dataset uploads, we want the uploaded dataset to match what evaluation-time + would run on. If the selected evaluation test provides a dataset_adapter, that + adapter is treated as the source of truth for constructing EvaluationRows. + """ + if not dataset_jsonl: + return dataset_jsonl + + if not test_file_path or not test_func_name: + return dataset_jsonl + + adapter = _extract_dataset_adapter(test_file_path, test_func_name) + if not adapter: + return dataset_jsonl + + raw_rows: list[dict[str, Any]] = load_jsonl(dataset_jsonl) # type: ignore[assignment] + adapted = adapter(raw_rows) + if not isinstance(adapted, list): + raise ValueError("dataset_adapter must return a list of EvaluationRow (or dicts parseable as EvaluationRow).") + + eval_rows: list[EvaluationRow] = [] + for item in adapted: + if isinstance(item, EvaluationRow): + eval_rows.append(item) + else: + eval_rows.append(EvaluationRow.model_validate(item)) + + output_dir = os.path.join(project_root, ".ep_tmp") + os.makedirs(output_dir, exist_ok=True) + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + suffix=".jsonl", + prefix="ep_rft_dataset_", + dir=output_dir, + ) as f: + for row in eval_rows: + f.write(json.dumps(row.model_dump(mode="json", exclude_none=True), ensure_ascii=False) + "\n") + out_path = os.path.abspath(f.name) + try: + rel = os.path.relpath(out_path, project_root) + except Exception: + rel = out_path + print(f"✓ Transformed dataset via dataset_adapter into EvaluationRow JSONL: {rel} ({len(eval_rows)} rows)") + return out_path + + def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]: """Import the test module and extract a JSONL path from data_loaders param if present. @@ -45,18 +121,10 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> relative to the directory of the test file. """ try: - import importlib.util - from pathlib import Path - - spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path) - if not spec or not spec.loader: - return None - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) # type: ignore[attr-defined] - if not hasattr(module, test_func_name): + module = load_module_from_file_path(test_file_path) + wrapper = getattr(module, test_func_name, None) + if wrapper is None: return None - wrapper = getattr(module, test_func_name) marks = getattr(wrapper, "pytestmark", []) for m in marks: if getattr(m, "name", "") == "parametrize": @@ -105,18 +173,10 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) of the test file. """ try: - import importlib.util - from pathlib import Path - - spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path) - if not spec or not spec.loader: - return None - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) # type: ignore[attr-defined] - if not hasattr(module, test_func_name): + module = load_module_from_file_path(test_file_path) + wrapper = getattr(module, test_func_name, None) + if wrapper is None: return None - wrapper = getattr(module, test_func_name) marks = getattr(wrapper, "pytestmark", []) for m in marks: if getattr(m, "name", "") == "parametrize": @@ -320,27 +380,15 @@ def _resolve_evaluator( selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive) if not selected_tests: return None, None, None, None + if len(selected_tests) != 1: if non_interactive and len(selected_tests) > 1: print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") print(" Please pass --evaluator or --entry to disambiguate.") - try: - # Offer candidate evaluator ids for convenience - tests = _discover_tests(project_root) - if tests: - print(" Candidate evaluator ids:") - for t in tests: - func = t.qualname.split(".")[-1] - stem = os.path.splitext(os.path.basename(t.file_path))[0] - cand = _normalize_evaluator_id(f"{stem}-{func}") - print(f" - {cand}") - except Exception: - pass else: print("Error: Please select exactly one evaluation test for 'create rft'.") return None, None, None, None - # Derive evaluator_id from user's single selection chosen = selected_tests[0] func_name = chosen.qualname.split(".")[-1] source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0] @@ -719,6 +767,16 @@ def create_rft_command(args) -> int: if dataset_jsonl is None and not dataset_id: return 1 + # 2.5) If the selected evaluation test provides a dataset_adapter, always use it to + # construct the EvaluationRow dataset that we upload for RFT. + if dataset_jsonl is not None: + dataset_jsonl = _maybe_transform_dataset_jsonl_via_adapter( + project_root=project_root, + dataset_jsonl=dataset_jsonl, + test_file_path=selected_test_file_path, + test_func_name=selected_test_func_name, + ) + # 3) Optional local validation if not skip_validation: # Dataset validation (JSONL must be EvaluationRow-compatible when present) diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index 8a25b49b..a8a132d6 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -1,6 +1,5 @@ import argparse from eval_protocol.cli_commands.utils import DiscoveredTest -import importlib.util import os import re import sys @@ -18,6 +17,7 @@ _discover_tests, _ensure_account_id, _get_questionary_style, + load_module_from_file_path, _normalize_evaluator_id, _prompt_select, ) @@ -120,13 +120,8 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]: source_file_path = os.path.join(cwd, dotted_as_path) # Load the module from the file path - spec = importlib.util.spec_from_file_location(Path(source_file_path).stem, source_file_path) - if not spec or not spec.loader: - raise ValueError(f"Unable to load module from path: {source_file_path}") - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) # type: ignore[attr-defined] - module_name = spec.name + module = load_module_from_file_path(source_file_path) + module_name = getattr(module, "__name__", Path(source_file_path).stem) if not hasattr(module, func): raise ValueError(f"Function '{func}' not found in module '{module_name}'") diff --git a/eval_protocol/cli_commands/utils.py b/eval_protocol/cli_commands/utils.py index 3f941d4b..1338ae31 100644 --- a/eval_protocol/cli_commands/utils.py +++ b/eval_protocol/cli_commands/utils.py @@ -1,3 +1,6 @@ +from types import ModuleType + + import os import ast import sys @@ -6,16 +9,16 @@ import argparse import typing import types +import importlib.util from dataclasses import dataclass from pathlib import Path -from typing import Any, List, Optional, is_typeddict +from typing import Any, List, Optional import typing_extensions import inspect from collections.abc import Callable import pytest from ..auth import ( - get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, verify_api_key_and_get_account_id, @@ -23,6 +26,29 @@ from ..fireworks_rft import _map_api_host_to_app_host +def load_module_from_file_path(source_file_path: str) -> ModuleType: + """Load a Python module from an absolute/relative filesystem path. + + This mirrors the CLI behavior used by `upload.py` and `create_rft.py`: + - module name is derived from the file stem (e.g. /a/b/foo.py -> foo) + - the module is inserted into sys.modules under that name before exec + """ + abs_path = os.path.abspath(source_file_path) + if not os.path.isfile(abs_path): + raise ValueError(f"File not found: {abs_path}") + if not abs_path.endswith(".py"): + raise ValueError(f"Expected a .py file path, got: {abs_path}") + + module_name = Path(abs_path).stem + spec = importlib.util.spec_from_file_location(module_name, abs_path) + if not spec or not spec.loader: + raise ValueError(f"Unable to load module from path: {abs_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) # type: ignore[attr-defined] + return module + + def _get_questionary_style(): """Get the shared questionary style for CLI prompts - minimal and clean.""" try: @@ -252,7 +278,7 @@ def _format_test_choice(test: DiscoveredTest, idx: int) -> str: def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTest]: - """Interactive selection with arrow keys using questionary.""" + """Interactive single selection with arrow keys using questionary (Enter selects highlighted).""" try: import questionary @@ -263,35 +289,32 @@ def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTe print(f"\nFound 1 test: {_format_test_choice(tests[0], 1)}") confirm = questionary.confirm("Select this test?", default=True, style=custom_style).ask() if confirm: - return tests + return [tests[0]] else: return [] - # Build checkbox choices + # Build single-select choices choices = [] for idx, t in enumerate(tests, 1): choice_text = _format_test_choice(t, idx) - choices.append(questionary.Choice(title=choice_text, value=idx - 1, checked=False)) + choices.append(questionary.Choice(title=choice_text, value=idx - 1)) print() - selected_indices = questionary.checkbox( - "Select evaluation tests to upload:", + selected_index = questionary.select( + "Select an evaluation test:", choices=choices, style=custom_style, pointer=">", - instruction="(↑↓ move, space select, enter confirm)", + instruction="(↑↓ move, enter confirm)", ).ask() - if selected_indices is None: # Ctrl+C + if selected_index is None: # Ctrl+C / Esc print("\nUpload cancelled.") return [] - if not selected_indices: - return [] - - selected_tests = [tests[i] for i in selected_indices] - print(f"\n✓ Selected {len(selected_tests)} test(s)") - return selected_tests + chosen = tests[int(selected_index)] + print("\n✓ Selected 1 test") + return [chosen] except ImportError: # Fallback to simpler implementation @@ -346,9 +369,10 @@ def _prompt_select_fallback(tests: list[DiscoveredTest]) -> list[DiscoveredTest] def _prompt_select(tests: list[DiscoveredTest], non_interactive: bool) -> list[DiscoveredTest]: - """Prompt user to select tests to upload.""" + """Prompt user to select exactly one test.""" if non_interactive: - return tests + # In non-interactive mode, only proceed if unambiguous. + return [tests[0]] if len(tests) == 1 else [] return _prompt_select_interactive(tests) @@ -375,7 +399,16 @@ def _discover_and_select_tests(project_root: str, non_interactive: bool) -> Opti return None if not selected_tests: - print("No tests selected.") + if non_interactive and len(tests) > 1: + print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") + print(" Please pass --evaluator or --entry to disambiguate.") + else: + print("No test selected.") + return None + + # Enforce single-select at the helper level. + if len(selected_tests) != 1: + print("Error: Please select exactly one evaluation test.") return None return selected_tests diff --git a/tests/test_cli_create_rft.py b/tests/test_cli_create_rft.py index 3c9b1a78..1f1e8395 100644 --- a/tests/test_cli_create_rft.py +++ b/tests/test_cli_create_rft.py @@ -1206,3 +1206,95 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d assert captured["jsonl_path"] != str(inferred_jsonl) # And because --dataset-jsonl was provided, we should never call the input_dataset extractor assert calls["input_dataset"] == 0 + + +def test_create_rft_transforms_raw_input_dataset_via_dataset_adapter_before_upload(rft_test_harness, monkeypatch): + project = rft_test_harness + + # Create a real @evaluation_test-decorated module so create_rft can extract __ep_params__.dataset_adapter + metric_dir = project / "metric" + metric_dir.mkdir(parents=True, exist_ok=True) + + raw_jsonl = metric_dir / "raw.jsonl" + raw_jsonl.write_text('{"q":"hi","a":"ok"}\n{"q":"yo","a":"ok2"}\n', encoding="utf-8") + + test_file = metric_dir / "test_adapt.py" + test_file.write_text( + """ +from typing import Any +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test + +def my_adapter(rows: list[dict[str, Any]]) -> list[EvaluationRow]: + return [ + EvaluationRow(messages=[Message(role="user", content=r["q"])], ground_truth=r.get("a")) + for r in rows + ] + +@evaluation_test( + input_dataset=["raw.jsonl"], + dataset_adapter=my_adapter, + num_runs=1, + max_dataset_rows=2, + mode="pointwise", +) +def test_adapt(row: EvaluationRow) -> EvaluationRow: + return row +""".lstrip(), + encoding="utf-8", + ) + + # Discovery: exactly one test, and resolve_selected_test points to our module/function + single_disc = SimpleNamespace(qualname="metric.test_adapt.test_adapt", file_path=str(test_file)) + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + monkeypatch.setattr( + cr, + "_resolve_selected_test", + lambda project_root, evaluator_id, selected_tests=None: (str(test_file), "test_adapt"), + ) + + captured = {"jsonl_path": None} + + def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path): + captured["jsonl_path"] = jsonl_path + return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"} + + monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl) + + # Ensure upload path doesn't touch the network; job creation via stub_fireworks fixture + args = argparse.Namespace( + evaluator=None, + yes=True, + dry_run=False, + force=False, + env_file=None, + dataset=None, + dataset_jsonl=None, + dataset_display_name=None, + dataset_builder=None, + base_model=None, + warm_start_from="accounts/acct123/models/ft-abc123", + output_model=None, + n=None, + max_tokens=None, + learning_rate=None, + batch_size=None, + epochs=None, + lora_rank=None, + max_context_length=None, + chunk_size=None, + eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", + ) + + rc = cr.create_rft_command(args) + assert rc == 0 + assert captured["jsonl_path"] is not None + # Raw JSONL should NOT be uploaded; transformed EvaluationRow JSONL should be. + assert os.path.abspath(captured["jsonl_path"]) != os.path.abspath(str(raw_jsonl)) + assert os.path.basename(captured["jsonl_path"]).endswith(".jsonl") + # The transformed file should validate as EvaluationRow JSONL + assert cr._validate_dataset_jsonl(captured["jsonl_path"])