diff --git a/eval_protocol/pytest/github_action_rollout_processor.py b/eval_protocol/pytest/github_action_rollout_processor.py index bbdd8b84..3e4f9ec0 100644 --- a/eval_protocol/pytest/github_action_rollout_processor.py +++ b/eval_protocol/pytest/github_action_rollout_processor.py @@ -1,13 +1,12 @@ import asyncio import os import time -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import json import requests from datetime import datetime, timezone, timedelta from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig @@ -21,7 +20,7 @@ class GithubActionRolloutProcessor(RolloutProcessor): Expected GitHub Actions workflow: - Workflow dispatch with inputs: completion_params, metadata (JSON), model_base_url, api_key - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy) - - Traces are fetched later via output_data_loader using rollout_id tags + - Traces are fetched later via Fireworks tracing proxy using rollout_id tags NOTE: GHA has a rate limit of 5000 requests per hour. """ @@ -38,7 +37,6 @@ def __init__( timeout_seconds: float = 1800.0, max_find_workflow_retries: int = 5, github_token: Optional[str] = None, - output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, ): self.owner = owner self.repo = repo @@ -52,7 +50,6 @@ def __init__( self.timeout_seconds = timeout_seconds self.max_find_workflow_retries = max_find_workflow_retries self.github_token = github_token - self._output_data_loader = output_data_loader or default_fireworks_output_data_loader def _headers(self) -> Dict[str, str]: headers = {"Accept": "application/vnd.github+json"} @@ -200,7 +197,7 @@ def _get_run() -> Dict[str, Any]: row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time def _update_with_trace() -> None: - return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url) + return update_row_with_remote_trace(row, default_fireworks_output_data_loader, self.model_base_url) await asyncio.to_thread(_update_with_trace) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index ab42bdcd..aa1c5d44 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Any, Dict, List, Optional, Callable +from typing import Any, Dict, List, Optional import requests @@ -26,8 +26,7 @@ class RemoteRolloutProcessor(RolloutProcessor): """ Rollout processor that triggers a remote HTTP server to perform the rollout. - By default, fetches traces from the Fireworks tracing proxy using rollout_id tags. - You can provide a custom output_data_loader for different tracing backends. + Fetches traces from the Fireworks tracing proxy using rollout_id tags. See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation. """ @@ -39,7 +38,6 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, ): # Prefer constructor-provided configuration. These can be overridden via # config.kwargs at call time for backward compatibility. @@ -52,7 +50,6 @@ def __init__( self._model_base_url = _ep_model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds - self._output_data_loader = output_data_loader or default_fireworks_output_data_loader self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url) def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: @@ -188,7 +185,7 @@ def _get_status() -> Dict[str, Any]: row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time def _update_with_trace() -> None: - return update_row_with_remote_trace(row, self._output_data_loader, model_base_url) + return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url) await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place return row diff --git a/tests/github_actions/test_github_actions_rollout.py b/tests/github_actions/test_github_actions_rollout.py index b68236fd..f8f7775a 100644 --- a/tests/github_actions/test_github_actions_rollout.py +++ b/tests/github_actions/test_github_actions_rollout.py @@ -12,36 +12,33 @@ from eval_protocol.models import EvaluationRow, InputMetadata from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor +import eval_protocol.pytest.github_action_rollout_processor as github_action_rollout_processor_module from eval_protocol.types.remote_rollout_processor import DataLoaderConfig -from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation + ROLLOUT_IDS = set() @pytest.fixture(autouse=True) -def check_rollout_coverage(): - """Ensure we processed all expected rollout_ids""" +def check_rollout_coverage(monkeypatch): + """ + Ensure we attempted to fetch remote traces for each rollout. + + This wraps the built-in default_fireworks_output_data_loader (without making it configurable) + and tracks rollout_ids passed through its DataLoaderConfig. + """ global ROLLOUT_IDS ROLLOUT_IDS.clear() - yield - - assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + original_loader = github_action_rollout_processor_module.default_fireworks_output_data_loader -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(config.rollout_id) + def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader: + ROLLOUT_IDS.add(config.rollout_id) + return original_loader(config) - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) + monkeypatch.setattr(github_action_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader) + yield + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" def rows() -> List[EvaluationRow]: @@ -68,14 +65,11 @@ def rows() -> List[EvaluationRow]: ref=os.getenv("GITHUB_REF", "main"), poll_interval=3.0, # For multi-turn, you'll likely want higher poll interval timeout_seconds=300, - output_data_loader=fireworks_output_data_loader, ), ) async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow: """Test GitHub Actions rollout with worker-controlled dataset.""" - # Track rollout IDs for coverage check - global ROLLOUT_IDS - ROLLOUT_IDS.add(row.execution_metadata.rollout_id) + assert row.execution_metadata.rollout_id is not None # This dataset is built into github_actions/rollout_worker.py if row.messages[0].content == "What is the capital of France?": diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index e364b788..4ac4fd6c 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -37,8 +37,23 @@ def _worker(): if not model: raise ValueError("model is required in completion_params") + # Convert Eval Protocol Message objects into OpenAI-compatible dicts, + # excluding any None fields (Fireworks rejects extra keys even when null). + messages_payload = [] + for m in req.messages: + if hasattr(m, "dump_mdoel_for_chat_completion_request"): + md = m.dump_mdoel_for_chat_completion_request() # type: ignore[attr-defined] + elif hasattr(m, "model_dump"): + md = m.model_dump(exclude_none=True) # type: ignore[call-arg] + elif isinstance(m, dict): + md = {k: v for k, v in m.items() if v is not None} + else: + md = {"role": getattr(m, "role", None), "content": getattr(m, "content", None)} + md = {k: v for k, v in md.items() if v is not None} + messages_payload.append(md) + # Spread all completion_params (model, temperature, max_tokens, etc.) - completion_kwargs = {"messages": req.messages, **req.completion_params} + completion_kwargs = {"messages": messages_payload, **req.completion_params} if req.tools: completion_kwargs["tools"] = req.tools diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index db5fdb49..e172e309 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -1,6 +1,5 @@ # AUTO SERVER STARTUP: Server is automatically started and stopped by the test -import os import subprocess import socket import time @@ -13,13 +12,35 @@ from eval_protocol.models import EvaluationRow, Message, EvaluateResult from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation +import eval_protocol.pytest.remote_rollout_processor as remote_rollout_processor_module from eval_protocol.types.remote_rollout_processor import DataLoaderConfig + ROLLOUT_IDS = set() +@pytest.fixture(autouse=True) +def check_rollout_coverage(monkeypatch): + """ + Ensure we attempted to fetch remote traces for each rollout. + + This wraps the built-in default_fireworks_output_data_loader (without making it configurable) + and tracks rollout_ids passed through its DataLoaderConfig. + """ + global ROLLOUT_IDS + ROLLOUT_IDS.clear() + + original_loader = remote_rollout_processor_module.default_fireworks_output_data_loader + + def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader: + ROLLOUT_IDS.add(config.rollout_id) + return original_loader(config) + + monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader) + yield + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + + def find_available_port() -> int: """Find an available port on localhost""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -68,31 +89,6 @@ def setup_remote_server(): process.wait() -@pytest.fixture(autouse=True) -def check_rollout_coverage(): - """Ensure we processed all expected rollout_ids""" - global ROLLOUT_IDS - ROLLOUT_IDS.clear() - yield - - assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" - - -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(config.rollout_id) - - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) - - def rows() -> List[EvaluationRow]: """Generate local rows with rich input_metadata to verify it survives remote traces.""" base_dataset_info = { @@ -118,7 +114,6 @@ def rows() -> List[EvaluationRow]: rollout_processor=RemoteRolloutProcessor( remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", timeout_seconds=180, - output_data_loader=fireworks_output_data_loader, ), ) async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: @@ -133,9 +128,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." - assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( - f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" - ) assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b" assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level" diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 8e2aaaa8..81a3436f 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -12,9 +12,6 @@ from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig def find_available_port() -> int: @@ -67,18 +64,6 @@ def setup_remote_server(): process.wait() -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) - - def rows() -> List[EvaluationRow]: row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) return [row] @@ -92,7 +77,6 @@ def rows() -> List[EvaluationRow]: rollout_processor=RemoteRolloutProcessor( remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", timeout_seconds=120, - output_data_loader=fireworks_output_data_loader, ), ) async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 8c66b136..f53e304f 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -1,86 +1,88 @@ -# MANUAL SERVER STARTUP REQUIRED: -# -# For Python server testing, start: -# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) -# -# For TypeScript server testing, start: -# cd tests/remote_server/typescript-server -# npm install -# npm start -# -# The TypeScript server should be running on http://127.0.0.1:3000 -# You only need to start one of the servers! - -import os -from typing import List - -import pytest - -from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.models import EvaluationRow, Message -from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.langfuse import create_langfuse_adapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig - -ROLLOUT_IDS = set() - - -@pytest.fixture(autouse=True) -def check_rollout_coverage(): - """Ensure we processed all expected rollout_ids""" - global ROLLOUT_IDS - ROLLOUT_IDS.clear() - yield - - assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" - - -def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(config.rollout_id) - - adapter = create_langfuse_adapter() - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) - - -def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation - ) - - -def rows() -> List[EvaluationRow]: - row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) - return [row, row, row] - - -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") -@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) -@evaluation_test( - data_loaders=DynamicDataLoader( - generators=[rows], - ), - rollout_processor=RemoteRolloutProcessor( - remote_base_url="http://127.0.0.1:3000", - timeout_seconds=30, - output_data_loader=langfuse_output_data_loader, - model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk", - ), -) -async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: - """ - End-to-end test: - - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server - - trigger remote rollout via RemoteRolloutProcessor (calls init/status) - - fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found - """ - assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" - assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." - - assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( - f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" - ) - - return row +# NOTE: This test is deprecated. We no longer support custom output data loaders, including pulling from Langfuse. We can revisit this in the future. + +# # MANUAL SERVER STARTUP REQUIRED: +# # +# # For Python server testing, start: +# # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# # +# # For TypeScript server testing, start: +# # cd tests/remote_server/typescript-server +# # npm install +# # npm start +# # +# # The TypeScript server should be running on http://127.0.0.1:3000 +# # You only need to start one of the servers! + +# import os +# from typing import List + +# import pytest + +# from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +# from eval_protocol.models import EvaluationRow, Message +# from eval_protocol.pytest import evaluation_test +# from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +# from eval_protocol.adapters.langfuse import create_langfuse_adapter +# from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation +# from eval_protocol.types.remote_rollout_processor import DataLoaderConfig + +# ROLLOUT_IDS = set() + + +# @pytest.fixture(autouse=True) +# def check_rollout_coverage(): +# """Ensure we processed all expected rollout_ids""" +# global ROLLOUT_IDS +# ROLLOUT_IDS.clear() +# yield + +# assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" + + +# def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: +# global ROLLOUT_IDS # Track all rollout_ids we've seen +# ROLLOUT_IDS.add(config.rollout_id) + +# adapter = create_langfuse_adapter() +# return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + + +# def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: +# return DynamicDataLoader( +# generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation +# ) + + +# def rows() -> List[EvaluationRow]: +# row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) +# return [row, row, row] + + +# @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +# @pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +# @evaluation_test( +# data_loaders=DynamicDataLoader( +# generators=[rows], +# ), +# rollout_processor=RemoteRolloutProcessor( +# remote_base_url="http://127.0.0.1:3000", +# timeout_seconds=30, +# output_data_loader=langfuse_output_data_loader, +# model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk", +# ), +# ) +# async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: +# """ +# End-to-end test: +# - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server +# - trigger remote rollout via RemoteRolloutProcessor (calls init/status) +# - fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found +# """ +# assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" +# assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + +# assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( +# f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" +# ) + +# return row diff --git a/tests/remote_server/typescript-server/README.md b/tests/remote_server/typescript-server/README.md index 434b82a8..90a7083e 100644 --- a/tests/remote_server/typescript-server/README.md +++ b/tests/remote_server/typescript-server/README.md @@ -120,7 +120,6 @@ from eval_protocol import ( data_loaders=[InlineDataLoader(messages=[[Message(role="user", content="Hello")]])], rollout_processor=RemoteRolloutProcessor( remote_base_url="http://localhost:3000", - output_data_loader=create_output_data_loader, ) ) def test_remote_http(row: EvaluationRow) -> EvaluationRow: