Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
95 changes: 95 additions & 0 deletions evaluators/bertscore/bertscore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""BERTScore semantic similarity evaluator.

Computes BERTScore (precision, recall, F1) between each invocation's response
and a reference text using contextual embeddings and cosine similarity.

Config:
    expected (str): Reference text to compare against; if omitted, the evaluator returns NOT_EVALUATED.
model_name (str, default "distilbert-base-uncased"): HuggingFace model for embeddings.
metric (str, default "f1"): Primary score component: "precision", "recall", or "f1".

Usage:
config:
expected: "reference answer"
metric: "f1"
"""

from __future__ import annotations

import sys

import torch
from transformers import AutoModel, AutoTokenizer

from agentevals_evaluator_sdk import EvalInput, EvalResult, EvalStatus, evaluator

_VALID_METRICS = ("precision", "recall", "f1")


def _compute_bertscore(
candidate: str,
reference: str,
tokenizer: AutoTokenizer,
model: AutoModel,
device: torch.device,
) -> dict[str, float]:
"""Compute BERTScore precision, recall, and F1 for a candidate-reference pair."""
cand = tokenizer(candidate, return_tensors="pt", padding=True, truncation=True).to(device)
ref = tokenizer(reference, return_tensors="pt", padding=True, truncation=True).to(device)

with torch.no_grad():
cand_emb = torch.nn.functional.normalize(model(**cand).last_hidden_state, dim=-1)
ref_emb = torch.nn.functional.normalize(model(**ref).last_hidden_state, dim=-1)

sim = torch.bmm(cand_emb, ref_emb.transpose(1, 2))

precision = sim.max(dim=2)[0].mean().item()
recall = sim.max(dim=1)[0].mean().item()
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

return {"precision": precision, "recall": recall, "f1": f1}


@evaluator
def bertscore(input: EvalInput) -> EvalResult:
    """Evaluate each invocation's final response against a reference text.

    Reads from ``input.config``:
        expected: reference text; when absent the result is NOT_EVALUATED.
        metric: BERTScore component used as the score ("precision",
            "recall", or "f1"); invalid values fall back to "f1" with a
            warning on stderr.
        model_name: HuggingFace model used for token embeddings.

    Returns an EvalResult whose overall score is the mean of the selected
    metric across invocations (0.0 when there are none), with full
    per-invocation precision/recall/F1 in ``details``.
    """
    config = input.config

    expected = config.get("expected")
    if expected is None:
        # Nothing to compare against: mark every invocation unevaluated.
        return EvalResult(
            score=0.0,
            status=EvalStatus.NOT_EVALUATED,
            per_invocation_scores=[None] * len(input.invocations),
            details={"reason": "missing config: expected"},
        )

    metric = config.get("metric", "f1")
    if metric not in _VALID_METRICS:
        print(f"WARNING: invalid metric '{metric}', using 'f1'", file=sys.stderr)
        metric = "f1"

    model_name = config.get("model_name", "distilbert-base-uncased")
    reference = str(expected)

    # Load the embedding model once and reuse it across all invocations.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    # One row per invocation: id plus its precision/recall/f1 components.
    rows = [
        {
            "invocation_id": inv.invocation_id,
            **_compute_bertscore(inv.final_response or "", reference, tokenizer, model, device),
        }
        for inv in input.invocations
    ]
    scores = [row[metric] for row in rows]

    mean_score = sum(scores) / len(scores) if scores else 0.0
    return EvalResult(
        score=mean_score,
        per_invocation_scores=scores,
        details={"metric": metric, "model_name": model_name, "per_invocation": rows},
    )


if __name__ == "__main__":
    # Script entry point: delegate to the SDK-provided run() on the
    # decorated evaluator (reads input, emits the EvalResult).
    bertscore.run()
6 changes: 6 additions & 0 deletions evaluators/bertscore/evaluator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Manifest for the BERTScore evaluator, consumed by the evaluator registry.
name: bertscore
description: Semantic similarity scoring using BERTScore (precision, recall, F1) between response and reference text
language: python
# Script the runner executes; expected to live next to this manifest.
entrypoint: bertscore.py
tags: [semantic, similarity, bert, nlp]
author: agentevals-dev
3 changes: 3 additions & 0 deletions evaluators/bertscore/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Use the CPU-only PyTorch wheel index so installs don't pull CUDA packages.
--extra-index-url https://download.pytorch.org/whl/cpu
torch
transformers
14 changes: 14 additions & 0 deletions scripts/validate_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def validate_syntax(evaluator_dir: Path, manifest: dict) -> bool:
return False
_ok("Imports, decorator, and explicit run() present")

req_file = evaluator_dir / "requirements.txt"
if req_file.exists():
lines = req_file.read_text().strip().splitlines()
valid_lines = [l.strip() for l in lines if l.strip() and not l.strip().startswith("#")]
if not valid_lines:
_fail("requirements.txt exists but contains no dependencies")
return False
_ok(f"requirements.txt found ({len(valid_lines)} dependencies)")

elif language in ("javascript", "typescript"):
ext = Path(entrypoint).suffix
expected = LANGUAGE_EXTENSIONS.get(language, set())
Expand All @@ -124,6 +133,11 @@ def validate_syntax(evaluator_dir: Path, manifest: dict) -> bool:

def validate_smoke_run(evaluator_dir: Path, manifest: dict) -> bool:
"""Run the evaluator with synthetic input and validate the output."""
req_file = evaluator_dir / "requirements.txt"
if req_file.exists() and req_file.read_text().strip():
_ok("Skipping smoke run: evaluator has requirements.txt (deps may not be installed)")
return True

language = manifest.get("language", "python")
entrypoint = manifest["entrypoint"]
entry_path = evaluator_dir / entrypoint
Expand Down
Loading