Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions eval_protocol/pytest/github_action_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import os
import time
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional
import json
import requests
from datetime import datetime, timezone, timedelta
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
Expand All @@ -21,7 +20,7 @@ class GithubActionRolloutProcessor(RolloutProcessor):
Expected GitHub Actions workflow:
- Workflow dispatch with inputs: completion_params, metadata (JSON), model_base_url, api_key
- Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
- Traces are fetched later via output_data_loader using rollout_id tags
- Traces are fetched later via Fireworks tracing proxy using rollout_id tags

NOTE: GHA has a rate limit of 5000 requests per hour.
"""
Expand All @@ -38,7 +37,6 @@ def __init__(
timeout_seconds: float = 1800.0,
max_find_workflow_retries: int = 5,
github_token: Optional[str] = None,
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
):
self.owner = owner
self.repo = repo
Expand All @@ -52,7 +50,6 @@ def __init__(
self.timeout_seconds = timeout_seconds
self.max_find_workflow_retries = max_find_workflow_retries
self.github_token = github_token
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader

def _headers(self) -> Dict[str, str]:
headers = {"Accept": "application/vnd.github+json"}
Expand Down Expand Up @@ -200,7 +197,7 @@ def _get_run() -> Dict[str, Any]:
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

def _update_with_trace() -> None:
return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, self.model_base_url)

await asyncio.to_thread(_update_with_trace)

Expand Down
9 changes: 3 additions & 6 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import time
from typing import Any, Dict, List, Optional, Callable
from typing import Any, Dict, List, Optional

import requests

Expand All @@ -26,8 +26,7 @@ class RemoteRolloutProcessor(RolloutProcessor):
"""
Rollout processor that triggers a remote HTTP server to perform the rollout.

By default, fetches traces from the Fireworks tracing proxy using rollout_id tags.
You can provide a custom output_data_loader for different tracing backends.
Fetches traces from the Fireworks tracing proxy using rollout_id tags.

See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
"""
Expand All @@ -39,7 +38,6 @@ def __init__(
model_base_url: str = "https://tracing.fireworks.ai",
poll_interval: float = 1.0,
timeout_seconds: float = 120.0,
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
):
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
Expand All @@ -52,7 +50,6 @@ def __init__(
self._model_base_url = _ep_model_base_url
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
Expand Down Expand Up @@ -188,7 +185,7 @@ def _get_status() -> Dict[str, Any]:
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

def _update_with_trace() -> None:
return update_row_with_remote_trace(row, self._output_data_loader, model_base_url)
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url)

await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place
return row
Expand Down
40 changes: 17 additions & 23 deletions tests/github_actions/test_github_actions_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,33 @@
from eval_protocol.models import EvaluationRow, InputMetadata
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor
import eval_protocol.pytest.github_action_rollout_processor as github_action_rollout_processor_module
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation


ROLLOUT_IDS = set()


@pytest.fixture(autouse=True)
def check_rollout_coverage():
"""Ensure we processed all expected rollout_ids"""
def check_rollout_coverage(monkeypatch):
"""
Ensure we attempted to fetch remote traces for each rollout.

This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
and tracks rollout_ids passed through its DataLoaderConfig.
"""
global ROLLOUT_IDS
ROLLOUT_IDS.clear()
yield

assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"

original_loader = github_action_rollout_processor_module.default_fireworks_output_data_loader

def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
global ROLLOUT_IDS # Track all rollout_ids we've seen
ROLLOUT_IDS.add(config.rollout_id)
def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
ROLLOUT_IDS.add(config.rollout_id)
return original_loader(config)

base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)
monkeypatch.setattr(github_action_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
yield
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"


def rows() -> List[EvaluationRow]:
Expand All @@ -68,14 +65,11 @@ def rows() -> List[EvaluationRow]:
ref=os.getenv("GITHUB_REF", "main"),
poll_interval=3.0, # For multi-turn, you'll likely want higher poll interval
timeout_seconds=300,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow:
"""Test GitHub Actions rollout with worker-controlled dataset."""
# Track rollout IDs for coverage check
global ROLLOUT_IDS
ROLLOUT_IDS.add(row.execution_metadata.rollout_id)
assert row.execution_metadata.rollout_id is not None

# This dataset is built into github_actions/rollout_worker.py
if row.messages[0].content == "What is the capital of France?":
Expand Down
17 changes: 16 additions & 1 deletion tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,23 @@ def _worker():
if not model:
raise ValueError("model is required in completion_params")

# Convert Eval Protocol Message objects into OpenAI-compatible dicts,
# excluding any None fields (Fireworks rejects extra keys even when null).
messages_payload = []
for m in req.messages:
if hasattr(m, "dump_mdoel_for_chat_completion_request"):
md = m.dump_mdoel_for_chat_completion_request() # type: ignore[attr-defined]
elif hasattr(m, "model_dump"):
md = m.model_dump(exclude_none=True) # type: ignore[call-arg]
elif isinstance(m, dict):
md = {k: v for k, v in m.items() if v is not None}
else:
md = {"role": getattr(m, "role", None), "content": getattr(m, "content", None)}
md = {k: v for k, v in md.items() if v is not None}
messages_payload.append(md)

# Spread all completion_params (model, temperature, max_tokens, etc.)
completion_kwargs = {"messages": req.messages, **req.completion_params}
completion_kwargs = {"messages": messages_payload, **req.completion_params}

if req.tools:
completion_kwargs["tools"] = req.tools
Expand Down
56 changes: 24 additions & 32 deletions tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test

import os
import subprocess
import socket
import time
Expand All @@ -13,13 +12,35 @@
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
import eval_protocol.pytest.remote_rollout_processor as remote_rollout_processor_module
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig


ROLLOUT_IDS = set()


@pytest.fixture(autouse=True)
def check_rollout_coverage(monkeypatch):
"""
Ensure we attempted to fetch remote traces for each rollout.

This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
and tracks rollout_ids passed through its DataLoaderConfig.
"""
global ROLLOUT_IDS
ROLLOUT_IDS.clear()

original_loader = remote_rollout_processor_module.default_fireworks_output_data_loader

def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
ROLLOUT_IDS.add(config.rollout_id)
return original_loader(config)

monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
yield
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"


def find_available_port() -> int:
"""Find an available port on localhost"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand Down Expand Up @@ -68,31 +89,6 @@ def setup_remote_server():
process.wait()


@pytest.fixture(autouse=True)
def check_rollout_coverage():
"""Ensure we processed all expected rollout_ids"""
global ROLLOUT_IDS
ROLLOUT_IDS.clear()
yield

assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
global ROLLOUT_IDS # Track all rollout_ids we've seen
ROLLOUT_IDS.add(config.rollout_id)

base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)


def rows() -> List[EvaluationRow]:
"""Generate local rows with rich input_metadata to verify it survives remote traces."""
base_dataset_info = {
Expand All @@ -118,7 +114,6 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=180,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow:
Expand All @@ -133,9 +128,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."

assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
)
assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"

Expand Down
16 changes: 0 additions & 16 deletions tests/remote_server/test_remote_fireworks_propagate_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig


def find_available_port() -> int:
Expand Down Expand Up @@ -67,18 +64,6 @@ def setup_remote_server():
process.wait()


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)


def rows() -> List[EvaluationRow]:
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
return [row]
Expand All @@ -92,7 +77,6 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=120,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
Expand Down
Loading
Loading