Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions src/eva/orchestrator/validation_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Validation metrics runner for benchmark validation mode."""

import asyncio
import json
from dataclasses import dataclass, field
from pathlib import Path

Expand Down Expand Up @@ -72,28 +73,50 @@ async def run_validation(self) -> dict[str, ValidationResult]:
gate_run = await gate_runner.run(contexts=contexts)

gate_passed, not_finished, agent_timeout_ids = self._partition(check_ids, gate_run.all_metrics)

# Separate time-limit-exceeded records from truly not_finished — they get LLM
# metrics evaluated with gate bypass (same as validate_one(skip_gate=True)).
time_limit_ids = []
truly_not_finished = []
for record_id in not_finished:
result_path = self.run_dir / "records" / record_id / "result.json"
try:
with open(result_path) as f:
result_data = json.load(f)
if result_data.get("conversation_ended_reason") == "time_limit_exceeded":
time_limit_ids.append(record_id)
continue
except (FileNotFoundError, json.JSONDecodeError):
pass
truly_not_finished.append(record_id)

logger.info(
f"Gate: {len(gate_passed)} passed ({len(agent_timeout_ids)} agent_timeout_on_user_turn), "
f"{len(not_finished)} not_finished"
f"{len(truly_not_finished)} not_finished, {len(time_limit_ids)} time_limit_exceeded"
)

for record_id in not_finished:
for record_id in truly_not_finished:
validation_results[record_id] = ValidationResult(passed=False)

if gate_passed:
# Run LLM metrics on gate-passed and time-limit records together.
# gate_passed records get GATE_METRIC=1.0 in their scores; time-limit records do not.
all_llm_ids = gate_passed + time_limit_ids
if all_llm_ids:
metrics_runner = MetricsRunner(
run_dir=self.run_dir,
dataset=self.dataset,
metric_names=LLM_METRICS,
metric_configs=self.metric_configs,
record_ids=gate_passed,
record_ids=all_llm_ids,
)
passed_contexts = {rid: contexts[rid] for rid in gate_passed if rid in contexts}
metrics_run = await metrics_runner.run(contexts=passed_contexts)
all_llm_contexts = {rid: contexts[rid] for rid in all_llm_ids if rid in contexts}
metrics_run = await metrics_runner.run(contexts=all_llm_contexts)

gate_passed_set = set(gate_passed)
for record_id, record_metrics in metrics_run.all_metrics.items():
vr = self._evaluate_record(record_id, record_metrics, LLM_METRICS)
vr.scores[GATE_METRIC] = 1.0
if record_id in gate_passed_set:
vr.scores[GATE_METRIC] = 1.0
validation_results[record_id] = vr

passed_count = sum(1 for vr in validation_results.values() if vr.passed)
Expand Down
8 changes: 4 additions & 4 deletions src/eva/orchestrator/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ async def run(self) -> ConversationResult:
except asyncio.CancelledError:
conversation_ended_reason = "cancelled"
logger.info(f"Conversation {self.record.id} was cancelled")
finally:
# Collect stats regardless of how the conversation ended
if self._assistant_server:
self._conversation_stats = self._assistant_server.get_conversation_stats()

except asyncio.CancelledError:
conversation_ended_reason = "cancelled"
Expand Down Expand Up @@ -374,10 +378,6 @@ async def _run_conversation(self) -> str:

ended_reason = await self._user_simulator.run_conversation()

# Collect stats from assistant
if self._assistant_server:
self._conversation_stats = self._assistant_server.get_conversation_stats()

return ended_reason

def _resolve_framework_logs_path(self) -> str:
Expand Down
82 changes: 74 additions & 8 deletions tests/unit/orchestrator/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

Expand Down Expand Up @@ -219,18 +219,84 @@ async def test_raises_when_simulator_not_initialized(self, tmp_path):
await worker._run_conversation()

@pytest.mark.asyncio
async def test_returns_ended_reason_and_captures_stats(self, tmp_path):
async def test_returns_ended_reason(self, tmp_path):
worker = _make_worker(tmp_path)
mock_sim = MagicMock()
mock_sim.run_conversation = AsyncMock(return_value="goodbye")
worker._user_simulator = mock_sim

stats = {"num_turns": 5, "num_tool_calls": 2, "tools_called": ["get_reservation"]}
mock_server = MagicMock()
mock_server.get_conversation_stats.return_value = stats
worker._assistant_server = mock_server

result = await worker._run_conversation()

assert result == "goodbye"
assert worker._conversation_stats == stats


def _setup_run_mocks(worker: ConversationWorker, stats: dict, run_conversation_side_effect=None):
    """Wire up the mocks needed to run worker.run() in isolation.

    Sets worker._assistant_server to a mock that returns *stats* from
    get_conversation_stats(), writes the two DB files run() expects on disk,
    and returns an async context manager that patches the expensive internal
    methods for the duration of the ``async with`` body.

    Args:
        worker: Worker under test. Its output directory is created if missing.
        stats: Value the mocked get_conversation_stats() returns.
        run_conversation_side_effect: Optional side effect (e.g. an exception
            instance) for the patched _run_conversation. When None, the patched
            coroutine returns "goodbye".

    Returns:
        An async context manager. Entering it applies every patch and yields
        the manager itself; exiting it undoes the patches in reverse order.
    """
    # Function-scope import so this helper does not depend on the module's import block.
    from contextlib import ExitStack

    worker.output_dir.mkdir(parents=True, exist_ok=True)
    (worker.output_dir / "initial_scenario_db.json").write_text(json.dumps({}))
    (worker.output_dir / "final_scenario_db.json").write_text(json.dumps({}))

    mock_server = MagicMock()
    mock_server.get_conversation_stats.return_value = stats
    worker._assistant_server = mock_server

    run_conv_mock = AsyncMock(
        side_effect=run_conversation_side_effect,
        return_value="goodbye" if run_conversation_side_effect is None else None,
    )

    patches = [
        patch.object(worker, "_start_assistant", AsyncMock()),
        patch.object(worker, "_start_user_simulator", AsyncMock()),
        patch.object(worker, "_cleanup", AsyncMock()),
        patch.object(worker, "_run_conversation", run_conv_mock),
        patch.object(worker, "_calculate_llm_latency", return_value=None),
        patch.object(worker, "_calculate_stt_latency", return_value=None),
        patch.object(worker, "_calculate_tts_latency", return_value=None),
        patch.object(worker, "_calculate_model_response_latency", return_value=None),
        patch("eva.orchestrator.worker.add_record_log_file", return_value=MagicMock()),
    ]

    class _Ctx:
        """Async wrapper that applies and removes the patches as one unit."""

        def __init__(self):
            # Replaced in __aenter__ by the stack that owns the active patches.
            self._stack = ExitStack()

        async def __aenter__(self):
            # If any patch fails to start, the ``with`` block unwinds the ones
            # already applied, so a half-entered context never leaks patches
            # into later tests.
            with ExitStack() as stack:
                for p in patches:
                    stack.enter_context(p)
                # Every patch started: hand ownership to the instance so the
                # ``with`` exit above does not undo them.
                self._stack = stack.pop_all()
            return self

        async def __aexit__(self, *_args):
            # Undoes patches in reverse order and keeps unwinding even if one
            # of them raises while being removed.
            self._stack.close()

    return _Ctx()


class TestConversationStatsInRun:
    """Checks that run() reports the assistant server's stats on its result."""

    @pytest.mark.asyncio
    async def test_stats_captured_on_normal_completion(self, tmp_path):
        """A conversation that ends normally exposes the mocked stats on the result."""
        server_stats = {"num_turns": 4, "num_tool_calls": 2, "tools_called": ["lookup_user"]}
        worker = _make_worker(tmp_path)

        async with _setup_run_mocks(worker, server_stats):
            outcome = await worker.run()

        assert outcome.num_turns == 4
        assert outcome.num_tool_calls == 2
        assert outcome.conversation_ended_reason != "error"

    @pytest.mark.asyncio
    async def test_stats_captured_on_time_limit_exceeded(self, tmp_path):
        """Regression test: stats must still be reported when the conversation times out."""
        server_stats = {"num_turns": 3, "num_tool_calls": 1, "tools_called": ["lookup_user"]}
        worker = _make_worker(tmp_path)
        # The patched _run_conversation raises this instead of returning.
        timeout = TimeoutError()

        async with _setup_run_mocks(worker, server_stats, run_conversation_side_effect=timeout):
            outcome = await worker.run()

        assert outcome.num_turns == 3
        assert outcome.num_tool_calls == 1
        assert outcome.conversation_ended_reason == "time_limit_exceeded"
Loading