Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,31 @@ At the moment, we run them manually every time we ship a PR. They're a handful o
```
python evals.py
```
from within the `evals/` folder. If you have suggestions we'd love to hear.
from within the `evals/` folder.

You can select which model(s) and eval(s) to run:
```bash
# Run all evals with all models (default)
python evals.py

# Run all evals with a specific model
python evals.py -m gemini

# Run specific evals only
python evals.py -e article question

# Combine model and eval selection
python evals.py -m claude -e allower orchestrator

# List available models and evals
python evals.py --list
```

Available models: `gemini`, `claude`, `claude-aws-bedrock`, `nova-lite-aws-bedrock`

Available evals: `allower`, `orchestrator`, `summary`, `question`, `article`, `general`

If you have suggestions we'd love to hear.

# Acknowledgments

Expand Down
152 changes: 105 additions & 47 deletions evals/evals.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import asyncio
import os

Expand All @@ -14,63 +15,120 @@

console = Console()


async def main():

load_dotenv()

for model_family in [
"gemini",
"claude",
"claude-aws-bedrock",
"nova-lite-aws-bedrock",
]:

if model_family == "gemini" and os.getenv("GOOGLE_API_KEY") is None:
ALL_MODELS = ["gemini", "claude", "claude-aws-bedrock", "nova-lite-aws-bedrock"]
ALL_EVALS = ["allower", "orchestrator", "summary", "question", "article", "general"]

EVAL_RUNNERS = {
"allower": run_evals_allower,
"orchestrator": run_evals_orchestrator,
"summary": run_evals_summary,
"question": run_evals_question,
"article": run_evals_article,
"general": run_evals_general,
}


def parse_args():
parser = argparse.ArgumentParser(
description="Run askademic evaluation suite",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=f"""
Examples:
python evals.py # Run all evals with all models
python evals.py -m gemini # Run all evals with gemini only
python evals.py -e article question # Run article and question evals
python evals.py -m claude -e allower # Run allower eval with claude only

Available models: {', '.join(ALL_MODELS)}
Available evals: {', '.join(ALL_EVALS)}
""",
)
parser.add_argument(
"-m",
"--model",
nargs="+",
choices=ALL_MODELS,
default=ALL_MODELS,
metavar="MODEL",
help=f"Model(s) to run evals with. Choices: {', '.join(ALL_MODELS)}",
)
parser.add_argument(
"-e",
"--eval",
nargs="+",
choices=ALL_EVALS,
default=ALL_EVALS,
metavar="EVAL",
help=f"Eval(s) to run. Choices: {', '.join(ALL_EVALS)}",
)
parser.add_argument(
"--list",
action="store_true",
help="List available models and evals, then exit",
)
return parser.parse_args()


def check_model_credentials(model_family: str) -> bool:
"""Check if credentials are available for the given model family."""
if model_family == "gemini" and os.getenv("GOOGLE_API_KEY") is None:
console.print(
"[bold red]GOOGLE_API_KEY environment variable is not set. "
"Skipping gemini evals.[/bold red]"
)
return False

if model_family == "claude" and os.getenv("ANTHROPIC_API_KEY") is None:
console.print(
"[bold red]ANTHROPIC_API_KEY environment variable is not set. "
"Skipping claude evals.[/bold red]"
)
return False

if model_family in ("claude-aws-bedrock", "nova-lite-aws-bedrock"):
try:
_ = boto3.client("sts").get_caller_identity()
except (ClientError, NoCredentialsError):
console.print(
"""
[bold red]GOOGLE_API_KEY environment variable is not set.
Skipping evals.[/bold red]
"""
f"[bold red]AWS credentials are not set or invalid. "
f"Skipping {model_family} evals.[/bold red]"
)
continue
return False

if model_family == "claude" and os.getenv("ANTHROPIC_API_KEY") is None:
console.print(
"""
[bold red]ANTHROPIC_API_KEY environment variable is not set.
Skipping CLAUDE evals.[/bold red]
"""
)
continue
return True


async def main():
args = parse_args()

if model_family in ("claude-aws-bedrock", "nova-lite-aws-bedrock"):
try:
_ = boto3.client("sts").get_caller_identity()
except (ClientError, NoCredentialsError):
console.print(
f"""[bold red]AWS credentials are not set or invalid.
Skipping {model_family} evals.[/bold red]"""
)
continue
if args.list:
console.print("[bold]Available models:[/bold]")
for model in ALL_MODELS:
console.print(f" - {model}")
console.print("\n[bold]Available evals:[/bold]")
for eval_name in ALL_EVALS:
console.print(f" - {eval_name}")
return

console.print("\n[bold magenta]Running allower evals...[/bold magenta]")
await run_evals_allower(model_family)
load_dotenv()

console.print("\n[bold magenta]Running orchestrator evals...[/bold magenta]")
await run_evals_orchestrator(model_family)
models = args.model
evals = args.eval

console.print("\n[bold magenta]Running summary evals...[/bold magenta]")
await run_evals_summary(model_family)
console.print(f"[bold cyan]Models:[/bold cyan] {', '.join(models)}")
console.print(f"[bold cyan]Evals:[/bold cyan] {', '.join(evals)}\n")

console.print("\n[bold magenta]Running question evals...[/bold magenta]")
await run_evals_question(model_family)
for model_family in models:
if not check_model_credentials(model_family):
continue

console.print("\n[bold magenta]Running article evals...[/bold magenta]")
await run_evals_article(model_family)
console.print(f"\n[bold blue]===== Model: {model_family} =====[/bold blue]")

console.print("\n[bold magenta]Running general agent evals...[/bold magenta]")
await run_evals_general(model_family)
for eval_name in evals:
console.print(
f"\n[bold magenta]Running {eval_name} evals...[/bold magenta]"
)
await EVAL_RUNNERS[eval_name](model_family)


if __name__ == "__main__":
Expand Down
94 changes: 74 additions & 20 deletions evals/evals_article.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@


class ArticleResponseTestCase:
def __init__(self, request: str, article_data: str, title: str, link: str):
def __init__(
self,
request: str,
article_data: str,
title: str,
link: str,
fuzzy_keywords: list[str] = None,
):
self.request = request
self.article_data = article_data
self.title = title
self.link = link
# For fuzzy matching: if set, just check that response contains
# at least one keyword and has a valid arXiv link format
self.fuzzy_keywords = fuzzy_keywords


eval_cases = [
Expand Down Expand Up @@ -47,12 +57,14 @@ def __init__(self, request: str, article_data: str, title: str, link: str):
"THE DETERMINISTIC KERMACK-MCKENDRICK MODEL BOUNDS THE GENERAL STOCHASTIC EPIDEMIC",
"https://arxiv.org/pdf/1602.01730.pdf",
),
# not existing paper
# Fuzzy match: paper doesn't exist exactly, so we just check that a relevant
# paper is returned (contains keywords) with a valid arXiv link
ArticleResponseTestCase(
"Find this paper 'Quark Gluon plasma and AI'",
"http://arxiv.org/pdf/2412.19393v1",
"Hydrodynamic Description of the Quark-Gluon Plasma ",
"http://arxiv.org/pdf/2311.10621v2",
"",
"",
"",
fuzzy_keywords=["quark", "gluon", "plasma", "qgp"],
),
]

Expand All @@ -65,6 +77,55 @@ def __init__(self, request: str, article_data: str, title: str, link: str):
MAX_ATTEMPTS = 5


def check_fuzzy_match(case, response) -> tuple[bool, str]:
"""
Check if response passes fuzzy matching criteria.
Returns (passed, reason) tuple.
"""
title = response.output.article_title.lower()
link = response.output.article_link

# Check for valid arXiv link format
link_match = re.match(LINK_PATTERN, link)
if link_match is None:
return False, f"Invalid arXiv link format: {link}"

# Check if at least one keyword is in the title
keyword_found = any(kw.lower() in title for kw in case.fuzzy_keywords)
if not keyword_found:
return False, f"No keywords {case.fuzzy_keywords} found in title: {title}"

return True, ""


def check_exact_match(case, response) -> tuple[bool, str]:
"""
Check if response passes exact matching criteria.
Returns (passed, reason) tuple.
"""
match1 = re.match(LINK_PATTERN, case.link)
match2 = re.match(LINK_PATTERN, response.output.article_link)

# Check titles match (case insensitive)
title_matches = case.title.lower().replace(
"\n", " "
) == response.output.article_title.lower().replace("\n", " ")

# Check links regex match exists and IDs are the same
links_match = (
match1 is not None and match2 is not None and match1.group(1) == match2.group(1)
)

if not title_matches or not links_match:
reason = (
f"Got: {response.output.article_title} and {response.output.article_link}\n"
f"Expected: {case.title} and {case.link}"
)
return False, reason

return True, ""


async def run_evals(model_family: str):

model, model_settings = choose_model(model_family)
Expand All @@ -79,22 +140,15 @@ async def run_evals(model_family: str):
print(f"Evaluating case: {case.request}")
response = await article_agent.run(request=case.request)

match1 = re.match(LINK_PATTERN, case.link)
match2 = re.match(LINK_PATTERN, response.output.article_link)

# check titles match (case insensitive), links regex match exists and are the same
if (
case.title.lower().replace("\n", " ")
!= response.output.article_title.lower().replace("\n", " ")
or match1 is None
or match2 is None
or match1.group(1) != match2.group(1)
):
# Use fuzzy matching if keywords are specified, else exact match
if case.fuzzy_keywords:
passed, reason = check_fuzzy_match(case, response)
else:
passed, reason = check_exact_match(case, response)

if not passed:
print(f"Test failed for question: {case.request}")
print(
f"Got: {response.output.article_title} and {response.output.article_link}"
)
print(f"Expected: {case.title} and {case.link}")
print(reason)
print("\n")
c_failed += 1
else:
Expand Down
Loading