From 994242958e8e41647242776ae46a8578ceb7fa3a Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 6 Jan 2026 15:31:59 -0800 Subject: [PATCH 1/5] take out output dataloader --- .../pytest/github_action_rollout_processor.py | 8 +-- .../pytest/remote_rollout_processor.py | 8 +-- .../test_github_actions_rollout.py | 19 ------- tests/remote_server/test_remote_fireworks.py | 24 +------- .../test_remote_fireworks_propagate_status.py | 16 ------ tests/remote_server/test_remote_langfuse.py | 55 +------------------ .../remote_server/typescript-server/README.md | 1 - 7 files changed, 8 insertions(+), 123 deletions(-) diff --git a/eval_protocol/pytest/github_action_rollout_processor.py b/eval_protocol/pytest/github_action_rollout_processor.py index bbdd8b84..621fa830 100644 --- a/eval_protocol/pytest/github_action_rollout_processor.py +++ b/eval_protocol/pytest/github_action_rollout_processor.py @@ -1,13 +1,12 @@ import asyncio import os import time -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import json import requests from datetime import datetime, timezone, timedelta from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig @@ -21,7 +20,7 @@ class GithubActionRolloutProcessor(RolloutProcessor): Expected GitHub Actions workflow: - Workflow dispatch with inputs: completion_params, metadata (JSON), model_base_url, api_key - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy) - - Traces are fetched later via output_data_loader using rollout_id tags + - Traces are fetched later via Fireworks tracing proxy using rollout_id tags NOTE: GHA has a rate limit of 5000 requests per hour. """ @@ -38,7 +37,6 @@ def __init__( timeout_seconds: float = 1800.0, max_find_workflow_retries: int = 5, github_token: Optional[str] = None, - output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, ): self.owner = owner self.repo = repo @@ -52,7 +50,7 @@ def __init__( self.timeout_seconds = timeout_seconds self.max_find_workflow_retries = max_find_workflow_retries self.github_token = github_token - self._output_data_loader = output_data_loader or default_fireworks_output_data_loader + self._output_data_loader = default_fireworks_output_data_loader def _headers(self) -> Dict[str, str]: headers = {"Accept": "application/vnd.github+json"} diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index ab42bdcd..5c884874 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Any, Dict, List, Optional, Callable +from typing import Any, Dict, List, Optional import requests @@ -26,8 +26,7 @@ class RemoteRolloutProcessor(RolloutProcessor): """ Rollout processor that triggers a remote HTTP server to perform the rollout. - By default, fetches traces from the Fireworks tracing proxy using rollout_id tags. - You can provide a custom output_data_loader for different tracing backends. + Fetches traces from the Fireworks tracing proxy using rollout_id tags. See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation. """ @@ -39,7 +38,6 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, ): # Prefer constructor-provided configuration. These can be overridden via # config.kwargs at call time for backward compatibility. @@ -52,7 +50,7 @@ def __init__( self._model_base_url = _ep_model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds - self._output_data_loader = output_data_loader or default_fireworks_output_data_loader + self._output_data_loader = default_fireworks_output_data_loader self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url) def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: diff --git a/tests/github_actions/test_github_actions_rollout.py b/tests/github_actions/test_github_actions_rollout.py index b68236fd..76b643f2 100644 --- a/tests/github_actions/test_github_actions_rollout.py +++ b/tests/github_actions/test_github_actions_rollout.py @@ -12,9 +12,6 @@ from eval_protocol.models import EvaluationRow, InputMetadata from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig -from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation ROLLOUT_IDS = set() @@ -29,21 +26,6 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(config.rollout_id) - - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) - - def rows() -> List[EvaluationRow]: return [ EvaluationRow(input_metadata=InputMetadata(row_id=str(i))) @@ -68,7 +50,6 @@ def rows() -> List[EvaluationRow]: ref=os.getenv("GITHUB_REF", "main"), poll_interval=3.0, # For multi-turn, you'll likely want higher poll interval timeout_seconds=300, - output_data_loader=fireworks_output_data_loader, ), ) async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow: diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index db5fdb49..23eb9e9c 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -1,6 +1,5 @@ # AUTO SERVER STARTUP: Server is automatically started and stopped by the test -import os import subprocess import socket import time @@ -13,9 +12,6 @@ from eval_protocol.models import EvaluationRow, Message, EvaluateResult from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig ROLLOUT_IDS = set() @@ -78,21 +74,6 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(config.rollout_id) - - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) - - def rows() -> List[EvaluationRow]: """Generate local rows with rich input_metadata to verify it survives remote traces.""" base_dataset_info = { @@ -118,7 +99,6 @@ def rows() -> List[EvaluationRow]: rollout_processor=RemoteRolloutProcessor( remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", timeout_seconds=180, - output_data_loader=fireworks_output_data_loader, ), ) async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: @@ -129,13 +109,11 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat - fetch traces from Langfuse via Fireworks tracing proxy filtered by metadata via output_data_loader; FAIL if none found """ row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result") + ROLLOUT_IDS.add(row.execution_metadata.rollout_id) assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." - assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( - f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" - ) assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b" assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level" diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 8e2aaaa8..81a3436f 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -12,9 +12,6 @@ from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig def find_available_port() -> int: @@ -67,18 +64,6 @@ def setup_remote_server(): process.wait() -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = FireworksTracingAdapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) - - def rows() -> List[EvaluationRow]: row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) return [row] @@ -92,7 +77,6 @@ def rows() -> List[EvaluationRow]: rollout_processor=RemoteRolloutProcessor( remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", timeout_seconds=120, - output_data_loader=fireworks_output_data_loader, ), ) async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 8c66b136..c2680311 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -1,16 +1,3 @@ -# MANUAL SERVER STARTUP REQUIRED: -# -# For Python server testing, start: -# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) -# -# For TypeScript server testing, start: -# cd tests/remote_server/typescript-server -# npm install -# npm start -# -# The TypeScript server should be running on http://127.0.0.1:3000 -# You only need to start one of the servers! - import os from typing import List @@ -20,35 +7,6 @@ from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.langfuse import create_langfuse_adapter -from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig - -ROLLOUT_IDS = set() - - -@pytest.fixture(autouse=True) -def check_rollout_coverage(): - """Ensure we processed all expected rollout_ids""" - global ROLLOUT_IDS - ROLLOUT_IDS.clear() - yield - - assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" - - -def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(config.rollout_id) - - adapter = create_langfuse_adapter() - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) - - -def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation - ) def rows() -> List[EvaluationRow]: @@ -62,25 +20,14 @@ def rows() -> List[EvaluationRow]: data_loaders=DynamicDataLoader( generators=[rows], ), - rollout_processor=RemoteRolloutProcessor( - remote_base_url="http://127.0.0.1:3000", - timeout_seconds=30, - output_data_loader=langfuse_output_data_loader, - model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk", - ), + rollout_processor=RemoteRolloutProcessor(remote_base_url="http://127.0.0.1:3000", timeout_seconds=30), ) async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: """ End-to-end test: - - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server - trigger remote rollout via RemoteRolloutProcessor (calls init/status) - - fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found """ assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." - assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( - f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" - ) - return row diff --git a/tests/remote_server/typescript-server/README.md b/tests/remote_server/typescript-server/README.md index 434b82a8..90a7083e 100644 --- a/tests/remote_server/typescript-server/README.md +++ b/tests/remote_server/typescript-server/README.md @@ -120,7 +120,6 @@ from eval_protocol import ( data_loaders=[InlineDataLoader(messages=[[Message(role="user", content="Hello")]])], rollout_processor=RemoteRolloutProcessor( remote_base_url="http://localhost:3000", - output_data_loader=create_output_data_loader, ) ) def test_remote_http(row: EvaluationRow) -> EvaluationRow: From 7f8ab42b8e7ccd2596dc861bbb55ec3ee213fa58 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 6 Jan 2026 15:41:05 -0800 Subject: [PATCH 2/5] update --- eval_protocol/pytest/github_action_rollout_processor.py | 3 +-- eval_protocol/pytest/remote_rollout_processor.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/eval_protocol/pytest/github_action_rollout_processor.py b/eval_protocol/pytest/github_action_rollout_processor.py index 621fa830..3e4f9ec0 100644 --- a/eval_protocol/pytest/github_action_rollout_processor.py +++ b/eval_protocol/pytest/github_action_rollout_processor.py @@ -50,7 +50,6 @@ def __init__( self.timeout_seconds = timeout_seconds self.max_find_workflow_retries = max_find_workflow_retries self.github_token = github_token - self._output_data_loader = default_fireworks_output_data_loader def _headers(self) -> Dict[str, str]: headers = {"Accept": "application/vnd.github+json"} @@ -198,7 +197,7 @@ def _get_run() -> Dict[str, Any]: row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time def _update_with_trace() -> None: - return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url) + return update_row_with_remote_trace(row, default_fireworks_output_data_loader, self.model_base_url) await asyncio.to_thread(_update_with_trace) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 5c884874..aa1c5d44 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -50,7 +50,6 @@ def __init__( self._model_base_url = _ep_model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds - self._output_data_loader = default_fireworks_output_data_loader self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url) def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: @@ -186,7 +185,7 @@ def _get_status() -> Dict[str, Any]: row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time def _update_with_trace() -> None: - return update_row_with_remote_trace(row, self._output_data_loader, model_base_url) + return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url) await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place return row From 292c5ef3bb68ced16359792ecb9c31f9db732ca5 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 6 Jan 2026 15:49:37 -0800 Subject: [PATCH 3/5] test --- .../test_github_actions_rollout.py | 16 +--------------- tests/remote_server/test_remote_fireworks.py | 13 ------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/tests/github_actions/test_github_actions_rollout.py b/tests/github_actions/test_github_actions_rollout.py index 76b643f2..8fa385d2 100644 --- a/tests/github_actions/test_github_actions_rollout.py +++ b/tests/github_actions/test_github_actions_rollout.py @@ -13,18 +13,6 @@ from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor -ROLLOUT_IDS = set() - - -@pytest.fixture(autouse=True) -def check_rollout_coverage(): - """Ensure we processed all expected rollout_ids""" - global ROLLOUT_IDS - ROLLOUT_IDS.clear() - yield - - assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" - def rows() -> List[EvaluationRow]: return [ @@ -54,9 +42,7 @@ def rows() -> List[EvaluationRow]: ) async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow: """Test GitHub Actions rollout with worker-controlled dataset.""" - # Track rollout IDs for coverage check - global ROLLOUT_IDS - ROLLOUT_IDS.add(row.execution_metadata.rollout_id) + assert row.execution_metadata.rollout_id is not None # This dataset is built into github_actions/rollout_worker.py if row.messages[0].content == "What is the capital of France?": diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 23eb9e9c..36627c10 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -13,8 +13,6 @@ from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -ROLLOUT_IDS = set() - def find_available_port() -> int: """Find an available port on localhost""" @@ -64,16 +62,6 @@ def setup_remote_server(): process.wait() -@pytest.fixture(autouse=True) -def check_rollout_coverage(): - """Ensure we processed all expected rollout_ids""" - global ROLLOUT_IDS - ROLLOUT_IDS.clear() - yield - - assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" - - def rows() -> List[EvaluationRow]: """Generate local rows with rich input_metadata to verify it survives remote traces.""" base_dataset_info = { @@ -109,7 +97,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat - fetch traces from Langfuse via Fireworks tracing proxy filtered by metadata via output_data_loader; FAIL if none found """ row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result") - ROLLOUT_IDS.add(row.execution_metadata.rollout_id) assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." From c9462e9ff83ed8fc70cdba6978ed6fbde481292d Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 6 Jan 2026 16:01:20 -0800 Subject: [PATCH 4/5] fix --- tests/remote_server/remote_server.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index e364b788..4ac4fd6c 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -37,8 +37,23 @@ def _worker(): if not model: raise ValueError("model is required in completion_params") + # Convert Eval Protocol Message objects into OpenAI-compatible dicts, + # excluding any None fields (Fireworks rejects extra keys even when null). + messages_payload = [] + for m in req.messages: + if hasattr(m, "dump_mdoel_for_chat_completion_request"): + md = m.dump_mdoel_for_chat_completion_request() # type: ignore[attr-defined] + elif hasattr(m, "model_dump"): + md = m.model_dump(exclude_none=True) # type: ignore[call-arg] + elif isinstance(m, dict): + md = {k: v for k, v in m.items() if v is not None} + else: + md = {"role": getattr(m, "role", None), "content": getattr(m, "content", None)} + md = {k: v for k, v in md.items() if v is not None} + messages_payload.append(md) + # Spread all completion_params (model, temperature, max_tokens, etc.) - completion_kwargs = {"messages": req.messages, **req.completion_params} + completion_kwargs = {"messages": messages_payload, **req.completion_params} if req.tools: completion_kwargs["tools"] = req.tools From e1d751238dc0fa2debe5dd71c4986e748aa09b21 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 6 Jan 2026 16:06:53 -0800 Subject: [PATCH 5/5] fix test --- .../test_github_actions_rollout.py | 27 ++++ tests/remote_server/test_remote_fireworks.py | 27 ++++ tests/remote_server/test_remote_langfuse.py | 121 +++++++++++++----- 3 files changed, 142 insertions(+), 33 deletions(-) diff --git a/tests/github_actions/test_github_actions_rollout.py b/tests/github_actions/test_github_actions_rollout.py index 8fa385d2..f8f7775a 100644 --- a/tests/github_actions/test_github_actions_rollout.py +++ b/tests/github_actions/test_github_actions_rollout.py @@ -12,6 +12,33 @@ from eval_protocol.models import EvaluationRow, InputMetadata from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor +import eval_protocol.pytest.github_action_rollout_processor as github_action_rollout_processor_module +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig + + +ROLLOUT_IDS = set() + + +@pytest.fixture(autouse=True) +def check_rollout_coverage(monkeypatch): + """ + Ensure we attempted to fetch remote traces for each rollout. + + This wraps the built-in default_fireworks_output_data_loader (without making it configurable) + and tracks rollout_ids passed through its DataLoaderConfig. + """ + global ROLLOUT_IDS + ROLLOUT_IDS.clear() + + original_loader = github_action_rollout_processor_module.default_fireworks_output_data_loader + + def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader: + ROLLOUT_IDS.add(config.rollout_id) + return original_loader(config) + + monkeypatch.setattr(github_action_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader) + yield + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" def rows() -> List[EvaluationRow]: diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 36627c10..e172e309 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -12,6 +12,33 @@ from eval_protocol.models import EvaluationRow, Message, EvaluateResult from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +import eval_protocol.pytest.remote_rollout_processor as remote_rollout_processor_module +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig + + +ROLLOUT_IDS = set() + + +@pytest.fixture(autouse=True) +def check_rollout_coverage(monkeypatch): + """ + Ensure we attempted to fetch remote traces for each rollout. + + This wraps the built-in default_fireworks_output_data_loader (without making it configurable) + and tracks rollout_ids passed through its DataLoaderConfig. + """ + global ROLLOUT_IDS + ROLLOUT_IDS.clear() + + original_loader = remote_rollout_processor_module.default_fireworks_output_data_loader + + def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader: + ROLLOUT_IDS.add(config.rollout_id) + return original_loader(config) + + monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader) + yield + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" def find_available_port() -> int: diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index c2680311..f53e304f 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -1,33 +1,88 @@ -import os -from typing import List - -import pytest - -from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.models import EvaluationRow, Message -from eval_protocol.pytest import evaluation_test -from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor - - -def rows() -> List[EvaluationRow]: - row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) - return [row, row, row] - - -@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") -@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) -@evaluation_test( - data_loaders=DynamicDataLoader( - generators=[rows], - ), - rollout_processor=RemoteRolloutProcessor(remote_base_url="http://127.0.0.1:3000", timeout_seconds=30), -) -async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: - """ - End-to-end test: - - trigger remote rollout via RemoteRolloutProcessor (calls init/status) - """ - assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" - assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." - - return row +# NOTE: This test is deprecated. We no longer support custom output data loaders, including pulling from Langfuse. We can revisit this in the future. + +# # MANUAL SERVER STARTUP REQUIRED: +# # +# # For Python server testing, start: +# # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# # +# # For TypeScript server testing, start: +# # cd tests/remote_server/typescript-server +# # npm install +# # npm start +# # +# # The TypeScript server should be running on http://127.0.0.1:3000 +# # You only need to start one of the servers! + +# import os +# from typing import List + +# import pytest + +# from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +# from eval_protocol.models import EvaluationRow, Message +# from eval_protocol.pytest import evaluation_test +# from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +# from eval_protocol.adapters.langfuse import create_langfuse_adapter +# from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation +# from eval_protocol.types.remote_rollout_processor import DataLoaderConfig + +# ROLLOUT_IDS = set() + + +# @pytest.fixture(autouse=True) +# def check_rollout_coverage(): +# """Ensure we processed all expected rollout_ids""" +# global ROLLOUT_IDS +# ROLLOUT_IDS.clear() +# yield + +# assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" + + +# def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: +# global ROLLOUT_IDS # Track all rollout_ids we've seen +# ROLLOUT_IDS.add(config.rollout_id) + +# adapter = create_langfuse_adapter() +# return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + + +# def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: +# return DynamicDataLoader( +# generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation +# ) + + +# def rows() -> List[EvaluationRow]: +# row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) +# return [row, row, row] + + +# @pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +# @pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +# @evaluation_test( +# data_loaders=DynamicDataLoader( +# generators=[rows], +# ), +# rollout_processor=RemoteRolloutProcessor( +# remote_base_url="http://127.0.0.1:3000", +# timeout_seconds=30, +# output_data_loader=langfuse_output_data_loader, +# model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk", +# ), +# ) +# async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow: +# """ +# End-to-end test: +# - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server +# - trigger remote rollout via RemoteRolloutProcessor (calls init/status) +# - fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found +# """ +# assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" +# assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + +# assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( +# f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" +# ) + +# return row