diff --git a/README.md b/README.md
index 1110f7b..f55a086 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,31 @@ At the moment, we run them manually every time we ship a PR. They're a handful o
```
python evals.py
```
-from within the `evals/` folder. If you have suggestions we'd love to hear.
+from within the `evals/` folder.
+
+You can select which model(s) and eval(s) to run:
+```bash
+# Run all evals with all models (default)
+python evals.py
+
+# Run all evals with a specific model
+python evals.py -m gemini
+
+# Run specific evals only
+python evals.py -e article question
+
+# Combine model and eval selection
+python evals.py -m claude -e allower orchestrator
+
+# List available models and evals
+python evals.py --list
+```
+
+Available models: `gemini`, `claude`, `claude-aws-bedrock`, `nova-lite-aws-bedrock`
+
+Available evals: `allower`, `orchestrator`, `summary`, `question`, `article`, `general`
+
+If you have suggestions, we'd love to hear them.
# Acknowledgments
diff --git a/evals/evals.py b/evals/evals.py
index cf1a9b0..02c9da0 100644
--- a/evals/evals.py
+++ b/evals/evals.py
@@ -1,3 +1,4 @@
+import argparse
import asyncio
import os
@@ -14,63 +15,120 @@
console = Console()
-
-async def main():
-
- load_dotenv()
-
- for model_family in [
- "gemini",
- "claude",
- "claude-aws-bedrock",
- "nova-lite-aws-bedrock",
- ]:
-
- if model_family == "gemini" and os.getenv("GOOGLE_API_KEY") is None:
+ALL_MODELS = ["gemini", "claude", "claude-aws-bedrock", "nova-lite-aws-bedrock"]
+ALL_EVALS = ["allower", "orchestrator", "summary", "question", "article", "general"]
+
+EVAL_RUNNERS = {
+ "allower": run_evals_allower,
+ "orchestrator": run_evals_orchestrator,
+ "summary": run_evals_summary,
+ "question": run_evals_question,
+ "article": run_evals_article,
+ "general": run_evals_general,
+}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Run askademic evaluation suite",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=f"""
+Examples:
+ python evals.py # Run all evals with all models
+ python evals.py -m gemini # Run all evals with gemini only
+ python evals.py -e article question # Run article and question evals
+ python evals.py -m claude -e allower # Run allower eval with claude only
+
+Available models: {', '.join(ALL_MODELS)}
+Available evals: {', '.join(ALL_EVALS)}
+ """,
+ )
+ parser.add_argument(
+ "-m",
+ "--model",
+ nargs="+",
+ choices=ALL_MODELS,
+ default=ALL_MODELS,
+ metavar="MODEL",
+ help=f"Model(s) to run evals with. Choices: {', '.join(ALL_MODELS)}",
+ )
+ parser.add_argument(
+ "-e",
+ "--eval",
+ nargs="+",
+ choices=ALL_EVALS,
+ default=ALL_EVALS,
+ metavar="EVAL",
+ help=f"Eval(s) to run. Choices: {', '.join(ALL_EVALS)}",
+ )
+ parser.add_argument(
+ "--list",
+ action="store_true",
+ help="List available models and evals, then exit",
+ )
+ return parser.parse_args()
+
+
+def check_model_credentials(model_family: str) -> bool:
+ """Check if credentials are available for the given model family."""
+ if model_family == "gemini" and os.getenv("GOOGLE_API_KEY") is None:
+ console.print(
+ "[bold red]GOOGLE_API_KEY environment variable is not set. "
+ "Skipping gemini evals.[/bold red]"
+ )
+ return False
+
+ if model_family == "claude" and os.getenv("ANTHROPIC_API_KEY") is None:
+ console.print(
+ "[bold red]ANTHROPIC_API_KEY environment variable is not set. "
+ "Skipping claude evals.[/bold red]"
+ )
+ return False
+
+ if model_family in ("claude-aws-bedrock", "nova-lite-aws-bedrock"):
+ try:
+ _ = boto3.client("sts").get_caller_identity()
+ except (ClientError, NoCredentialsError):
console.print(
- """
- [bold red]GOOGLE_API_KEY environment variable is not set.
- Skipping evals.[/bold red]
- """
+ f"[bold red]AWS credentials are not set or invalid. "
+ f"Skipping {model_family} evals.[/bold red]"
)
- continue
+ return False
- if model_family == "claude" and os.getenv("ANTHROPIC_API_KEY") is None:
- console.print(
- """
- [bold red]ANTHROPIC_API_KEY environment variable is not set.
- Skipping CLAUDE evals.[/bold red]
- """
- )
- continue
+ return True
+
+
+async def main():
+ args = parse_args()
- if model_family in ("claude-aws-bedrock", "nova-lite-aws-bedrock"):
- try:
- _ = boto3.client("sts").get_caller_identity()
- except (ClientError, NoCredentialsError):
- console.print(
- f"""[bold red]AWS credentials are not set or invalid.
- Skipping {model_family} evals.[/bold red]"""
- )
- continue
+ if args.list:
+ console.print("[bold]Available models:[/bold]")
+ for model in ALL_MODELS:
+ console.print(f" - {model}")
+ console.print("\n[bold]Available evals:[/bold]")
+ for eval_name in ALL_EVALS:
+ console.print(f" - {eval_name}")
+ return
- console.print("\n[bold magenta]Running allower evals...[/bold magenta]")
- await run_evals_allower(model_family)
+ load_dotenv()
- console.print("\n[bold magenta]Running orchestrator evals...[/bold magenta]")
- await run_evals_orchestrator(model_family)
+ models = args.model
+ evals = args.eval
- console.print("\n[bold magenta]Running summary evals...[/bold magenta]")
- await run_evals_summary(model_family)
+ console.print(f"[bold cyan]Models:[/bold cyan] {', '.join(models)}")
+ console.print(f"[bold cyan]Evals:[/bold cyan] {', '.join(evals)}\n")
- console.print("\n[bold magenta]Running question evals...[/bold magenta]")
- await run_evals_question(model_family)
+ for model_family in models:
+ if not check_model_credentials(model_family):
+ continue
- console.print("\n[bold magenta]Running article evals...[/bold magenta]")
- await run_evals_article(model_family)
+ console.print(f"\n[bold blue]===== Model: {model_family} =====[/bold blue]")
- console.print("\n[bold magenta]Running general agent evals...[/bold magenta]")
- await run_evals_general(model_family)
+ for eval_name in evals:
+ console.print(
+ f"\n[bold magenta]Running {eval_name} evals...[/bold magenta]"
+ )
+ await EVAL_RUNNERS[eval_name](model_family)
if __name__ == "__main__":
diff --git a/evals/evals_article.py b/evals/evals_article.py
index b0f2a95..1f0c735 100644
--- a/evals/evals_article.py
+++ b/evals/evals_article.py
@@ -13,11 +13,21 @@
class ArticleResponseTestCase:
- def __init__(self, request: str, article_data: str, title: str, link: str):
+ def __init__(
+ self,
+ request: str,
+ article_data: str,
+ title: str,
+ link: str,
+ fuzzy_keywords: list[str] = None,
+ ):
self.request = request
self.article_data = article_data
self.title = title
self.link = link
+ # For fuzzy matching: if set, only check that the response title contains
+ # at least one keyword and that the link has a valid arXiv format
+ self.fuzzy_keywords = fuzzy_keywords
eval_cases = [
@@ -47,12 +57,14 @@ def __init__(self, request: str, article_data: str, title: str, link: str):
"THE DETERMINISTIC KERMACK-MCKENDRICK MODEL BOUNDS THE GENERAL STOCHASTIC EPIDEMIC",
"https://arxiv.org/pdf/1602.01730.pdf",
),
- # not existing paper
+ # Fuzzy match: no paper with this exact title exists, so we only check that the
+ # returned title contains at least one keyword and the link is a valid arXiv link
ArticleResponseTestCase(
"Find this paper 'Quark Gluon plasma and AI'",
- "http://arxiv.org/pdf/2412.19393v1",
- "Hydrodynamic Description of the Quark-Gluon Plasma ",
- "http://arxiv.org/pdf/2311.10621v2",
+ "",
+ "",
+ "",
+ fuzzy_keywords=["quark", "gluon", "plasma", "qgp"],
),
]
@@ -65,6 +77,55 @@ def __init__(self, request: str, article_data: str, title: str, link: str):
MAX_ATTEMPTS = 5
+def check_fuzzy_match(case, response) -> tuple[bool, str]:
+ """
+ Check if response passes fuzzy matching criteria.
+ Returns (passed, reason) tuple.
+ """
+ title = response.output.article_title.lower()
+ link = response.output.article_link
+
+ # Check for valid arXiv link format
+ link_match = re.match(LINK_PATTERN, link)
+ if link_match is None:
+ return False, f"Invalid arXiv link format: {link}"
+
+ # Check if at least one keyword is in the title
+ keyword_found = any(kw.lower() in title for kw in case.fuzzy_keywords)
+ if not keyword_found:
+ return False, f"No keywords {case.fuzzy_keywords} found in title: {title}"
+
+ return True, ""
+
+
+def check_exact_match(case, response) -> tuple[bool, str]:
+ """
+ Check if response passes exact matching criteria.
+ Returns (passed, reason) tuple.
+ """
+ match1 = re.match(LINK_PATTERN, case.link)
+ match2 = re.match(LINK_PATTERN, response.output.article_link)
+
+ # Check titles match (case insensitive)
+ title_matches = case.title.lower().replace(
+ "\n", " "
+ ) == response.output.article_title.lower().replace("\n", " ")
+
+ # Check links regex match exists and IDs are the same
+ links_match = (
+ match1 is not None and match2 is not None and match1.group(1) == match2.group(1)
+ )
+
+ if not title_matches or not links_match:
+ reason = (
+ f"Got: {response.output.article_title} and {response.output.article_link}\n"
+ f"Expected: {case.title} and {case.link}"
+ )
+ return False, reason
+
+ return True, ""
+
+
async def run_evals(model_family: str):
model, model_settings = choose_model(model_family)
@@ -79,22 +140,15 @@ async def run_evals(model_family: str):
print(f"Evaluating case: {case.request}")
response = await article_agent.run(request=case.request)
- match1 = re.match(LINK_PATTERN, case.link)
- match2 = re.match(LINK_PATTERN, response.output.article_link)
-
- # check titles match (case insensitive), links regex match exists and are the same
- if (
- case.title.lower().replace("\n", " ")
- != response.output.article_title.lower().replace("\n", " ")
- or match1 is None
- or match2 is None
- or match1.group(1) != match2.group(1)
- ):
+ # Use fuzzy matching if keywords are specified, else exact match
+ if case.fuzzy_keywords:
+ passed, reason = check_fuzzy_match(case, response)
+ else:
+ passed, reason = check_exact_match(case, response)
+
+ if not passed:
print(f"Test failed for question: {case.request}")
- print(
- f"Got: {response.output.article_title} and {response.output.article_link}"
- )
- print(f"Expected: {case.title} and {case.link}")
+ print(reason)
print("\n")
c_failed += 1
else:
diff --git a/src/askademic/article.py b/src/askademic/article.py
index 61582fb..56c34c1 100644
--- a/src/askademic/article.py
+++ b/src/askademic/article.py
@@ -1,18 +1,12 @@
import logging
+import re
from datetime import datetime
from pydantic import BaseModel, Field
-from pydantic_ai import Agent
+from pydantic_ai import Agent, RunContext
from pydantic_ai.settings import ModelSettings
-from askademic.prompts.general import (
- SYSTEM_PROMPT_ARTICLE,
- SYSTEM_PROMPT_ARTICLE_RETRIEVAL,
- SYSTEM_PROMPT_REQUEST_DISCRIMINATOR,
- USER_PROMPT_ARTICLE_RETRIEVAL_TEMPLATE,
- USER_PROMPT_ARTICLE_TEMPLATE,
- USER_PROMPT_REQUEST_DISCRIMINATOR_TEMPLATE,
-)
+from askademic.prompts.general import SYSTEM_PROMPT_ARTICLE_AGENT
from askademic.tools import get_article, search_articles_by_title
today = datetime.now().strftime("%Y-%m-%d")
@@ -20,20 +14,6 @@
logger = logging.getLogger(__name__)
-class ArticleRequestDiscriminatorResponse(BaseModel):
- """
- The response to the article request discriminator agent.
- It contains the type of article request and the value of the article.
- """
-
- article_type: str = Field(
- description="The type of article request: 'title', 'link', or 'error'."
- )
- article_value: str = Field(
- description="The article title, link or an emtpy string if the type is error."
- )
-
-
class ArticleResponse(BaseModel):
"""
The response to the article agent.
@@ -49,124 +29,112 @@ class ArticleResponse(BaseModel):
)
-class ArticleRetrievalResponse(BaseModel):
- """
- The response to the article retrieval agent.
- It contains the article link and title.
- """
+class ArticleAgentDeps(BaseModel):
+ """Dependencies for the article agent."""
- article_link: str = Field(description="The article link you found.")
- article_title: str = Field(description="The title of the article you found.")
+ use_cache: bool = True
class ArticleAgent:
def __init__(
self, model: str, model_settings: ModelSettings = None, use_cache: bool = True
):
-
- self._get_article = get_article
self.use_cache = use_cache
- self._search_articles_by_title = search_articles_by_title
- self._article_request_discriminator_agent = Agent(
+ self._agent = Agent(
model=model,
model_settings=model_settings,
- system_prompt=SYSTEM_PROMPT_REQUEST_DISCRIMINATOR,
- output_type=ArticleRequestDiscriminatorResponse,
- )
-
- self._article_agent = Agent(
- model=model,
- model_settings=model_settings,
- system_prompt=SYSTEM_PROMPT_ARTICLE,
+ system_prompt=SYSTEM_PROMPT_ARTICLE_AGENT,
output_type=ArticleResponse,
+ deps_type=ArticleAgentDeps,
)
- self._article_retrieval_agent = Agent(
- model=model,
- model_settings=model_settings,
- system_prompt=SYSTEM_PROMPT_ARTICLE_RETRIEVAL,
- output_type=ArticleRetrievalResponse,
- )
-
- async def _discriminate_article_request(
- self, request: str
- ) -> ArticleRequestDiscriminatorResponse:
- """
- Discriminate the type of article request.
- Args:
- request: the request to discriminate
- Returns:
- ArticleRequestDiscriminatorResponse: the response with the article type and value
+ @self._agent.tool
+ def search_by_title(ctx: RunContext[ArticleAgentDeps], title: str) -> str:
+ """
+ Search arXiv for articles matching a title.
+ Returns a JSON string with article links and abstracts.
+
+ Args:
+ title: The title or keywords to search for.
+ """
+ logger.info(f"{datetime.now()}: Searching articles by title: {title}")
+ result = search_articles_by_title(title)
+ logger.info(f"{datetime.now()}: Search results: {result[:200]}...")
+ return result
+
+ @self._agent.tool
+ def fetch_article(ctx: RunContext[ArticleAgentDeps], link: str) -> str:
+ """
+ Fetch the full content of an article from arXiv.
+
+ Args:
+ link: The arXiv link or ID (e.g., "https://arxiv.org/abs/1706.03762"
+ or "1706.03762" or "https://arxiv.org/pdf/1706.03762.pdf").
+ """
+ # Normalize the link to PDF format
+ normalized_link = self._normalize_arxiv_link(link)
+ logger.info(f"{datetime.now()}: Fetching article: {normalized_link}")
+ result = get_article(normalized_link, use_cache=ctx.deps.use_cache)
+ logger.info(f"{datetime.now()}: Article fetched, length: {len(result)}")
+ return result
+
+ def _normalize_arxiv_link(self, link: str) -> str:
"""
- return await self._article_request_discriminator_agent.run(
- USER_PROMPT_REQUEST_DISCRIMINATOR_TEMPLATE.format(request=request)
- )
+ Normalize various arXiv link formats to PDF URL.
- async def _retrieve_article(self, article_title: str) -> ArticleRetrievalResponse:
- """
- Retrieve an article link by its title.
- Args:
- article_title: the title of the article to retrieve
- Returns:
- ArticleRetrievalResponse: the response with the article link and title
+ Handles:
+ - Full PDF URL: https://arxiv.org/pdf/1706.03762.pdf
+ - Abstract URL: https://arxiv.org/abs/1706.03762
+ - Just the ID: 1706.03762 or 2401.00001
"""
- articles = self._search_articles_by_title(article_title)
+ # If it's already a PDF link, return as-is
+ if link.endswith(".pdf"):
+ return link
- return await self._article_retrieval_agent.run(
- USER_PROMPT_ARTICLE_RETRIEVAL_TEMPLATE.format(
- article_title=article_title, articles=articles
- )
- )
+ # Extract arxiv ID from various formats
+ arxiv_id = None
- async def _answer_question(
- self, request: str, article_link: str
- ) -> ArticleResponse:
- """
- Retrieve an article by its link and answer a question about it.
- Args:
- request: the question to answer
- article_link: the link to the article
- Returns:
- ArticleResponse: the response with the article content
- """
- article = self._get_article(article_link, use_cache=self.use_cache)
- return await self._article_agent.run(
- USER_PROMPT_ARTICLE_TEMPLATE.format(request=request, article=article)
- )
+ # Pattern for arXiv IDs: YYMM.NNNN (2007-2014) or YYMM.NNNNN (2015 onward)
+ id_pattern = r"(\d{4}\.\d{4,5})"
- async def run(self, request: str) -> ArticleResponse:
+ if "arxiv.org" in link:
+ # Extract ID from URL
+ match = re.search(id_pattern, link)
+ if match:
+ arxiv_id = match.group(1)
+ else:
+ # Assume it's just the ID
+ match = re.match(id_pattern, link.strip())
+ if match:
+ arxiv_id = match.group(1)
+
+ if arxiv_id:
+ return f"https://arxiv.org/pdf/{arxiv_id}.pdf"
+
+ # If we couldn't parse it, return it as-is and let downstream code handle the error
+ return link
+
+ async def run(self, request: str):
"""
Run the article agent to answer a question about an article.
+
Args:
- request: the question to answer
+ request: The question to answer, which should reference an article
+ by title, link, or arXiv ID.
+
Returns:
- ArticleResponse: the response with the article content
+ The agent result with normalized article_link in PDF format.
"""
- # Discriminate the type of article request
-
- article_request = await self._discriminate_article_request(request)
-
- logger.info(
- f"{datetime.now()}: Discriminated article request: {article_request}"
- )
-
- if article_request.output.article_type == "title":
- # Search for the article by title
- article_title = article_request.output.article_value
- retrieve_article = await self._retrieve_article(article_title)
- article_link = retrieve_article.output.article_link
- else:
- # If the article type is not title, we assume it's a link or an error
- article_link = article_request.output.article_value
+ logger.info(f"{datetime.now()}: ArticleAgent received request: {request}")
- logger.info(f"{datetime.now()}: Article link retrieved: {article_link}")
+ deps = ArticleAgentDeps(use_cache=self.use_cache)
+ result = await self._agent.run(request, deps=deps)
- if "No articles found" == article_link:
- return ArticleResponse(
- response="No articles found, the requested article is probably not in ArXiv.",
- article_title="",
- article_link="",
- )
+ # Normalize the article_link to PDF format in the output
+ if result.output and result.output.article_link:
+ normalized_link = self._normalize_arxiv_link(result.output.article_link)
+ result.output.article_link = normalized_link
- return await self._answer_question(request=request, article_link=article_link)
+ logger.info(f"{datetime.now()}: ArticleAgent completed request")
+ return result
diff --git a/src/askademic/prompts/general.py b/src/askademic/prompts/general.py
index 688de97..0096177 100644
--- a/src/askademic/prompts/general.py
+++ b/src/askademic/prompts/general.py
@@ -315,6 +315,34 @@
"""
)
+SYSTEM_PROMPT_ARTICLE_AGENT = cleandoc(
+ """
+ You are an expert in retrieving and analyzing arXiv articles.
+ You help users find specific papers and answer questions about them.
+
+ You have two tools available:
+ 1. search_by_title: Search arXiv for articles matching a title
+ 2. fetch_article: Fetch the full content of an article given its link or arXiv ID
+
+ When you receive a request:
+
+ - If the user provides an arXiv link (e.g., https://arxiv.org/abs/1706.03762)
+ or an arXiv ID (e.g., 1706.03762), use fetch_article directly.
+ - If the user provides an article title, first use search_by_title to find
+ matching articles, then use fetch_article to retrieve the best match.
+ - After fetching the article, answer the user's question based on its content.
+ - Quote relevant parts of the article in your response.
+ - If no articles are found, inform the user that the article is not available on arXiv.
+
+
+
+ - article_link MUST be in PDF format: https://arxiv.org/pdf/XXXX.XXXXX.pdf
+ - Convert /abs/ URLs to /pdf/ URLs and add .pdf extension
+ - Example: https://arxiv.org/abs/1706.03762 -> https://arxiv.org/pdf/1706.03762.pdf
+
+ """
+)
+
USER_PROMPT_ARTICLE_TEMPLATE = cleandoc(
"""
You will receive an article and a request.
diff --git a/tests/test_article.py b/tests/test_article.py
index 13fd574..08ff676 100644
--- a/tests/test_article.py
+++ b/tests/test_article.py
@@ -1,151 +1,104 @@
-import asyncio
import os
-from unittest.mock import MagicMock
import pytest
-from pydantic_ai.agent import AgentRunResult # noqa: F401
os.environ["GOOGLE_API_KEY"] = "mock"
-from askademic.article import ( # noqa: E402
- USER_PROMPT_ARTICLE_RETRIEVAL_TEMPLATE,
- USER_PROMPT_ARTICLE_TEMPLATE,
- USER_PROMPT_REQUEST_DISCRIMINATOR_TEMPLATE,
- ArticleAgent,
- ArticleRequestDiscriminatorResponse,
- ArticleResponse,
- ArticleRetrievalResponse,
-)
-
-testdata = [
- (
- "what the 'Augmented Synthetic Control' paper about?",
- ArticleRequestDiscriminatorResponse(
- article_type="title", article_value="Augmented Synthetic Control"
- ),
- ArticleRetrievalResponse(
- article_link="https://arxiv.org/abs/2401.00001",
- article_title="Augmented Synthetic Control",
- ),
- ArticleResponse(
- response="The 'Augmented Synthetic Control' paper discusses...",
- article_title="Augmented Synthetic Control",
- article_link="https://arxiv.org/abs/2401.00001",
- ),
- ),
- (
- "what is the paper with id 2401.00001 about?",
- ArticleRequestDiscriminatorResponse(
- article_type="link", article_value="https://arxiv.org/abs/2401.00001"
- ),
- ArticleRetrievalResponse(
- article_link="https://arxiv.org/abs/2401.00001",
- article_title="Augmented Synthetic Control",
- ),
- ArticleResponse(
- response="The paper with id 2401.00001 discusses...",
- article_title="Augmented Synthetic Control",
- article_link="https://arxiv.org/abs/2401.00001",
- ),
- ),
-]
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
- "question, request_discriminator_response, retrieval_response, expected_response",
- testdata,
-)
-async def test_article_agent(
- question: str,
- request_discriminator_response: ArticleRequestDiscriminatorResponse,
- retrieval_response: ArticleRetrievalResponse,
- expected_response: ArticleResponse,
-):
-
- # Mock the agents
- article_agent = ArticleAgent(
- model="google-gla:gemini-2.0-flash",
- )
- article_agent._get_article = MagicMock(
- return_value="Mocked article text for testing purposes."
- )
- article_agent._search_articles_by_title = MagicMock(
- return_value="Mocked search articles by title response."
- )
-
- request_discriminator_response_future = asyncio.Future()
- request_discriminator_response_future.set_result(
- AgentRunResult(
- output=request_discriminator_response,
- _output_tool_name=None,
- _state=None,
- _new_message_index=None,
- _traceparent_value=None,
- )
- )
- article_agent._article_request_discriminator_agent = MagicMock(
- run=MagicMock(return_value=request_discriminator_response_future)
- )
-
- retrieval_response_future = asyncio.Future()
- retrieval_response_future.set_result(
- AgentRunResult(
- output=retrieval_response,
- _output_tool_name=None,
- _state=None,
- _new_message_index=None,
- _traceparent_value=None,
- )
- )
- article_agent._article_retrieval_agent = MagicMock(
- run=MagicMock(return_value=retrieval_response_future)
- )
-
- expected_response_future = asyncio.Future()
- expected_response_future.set_result(
- AgentRunResult(
- output=expected_response,
- _output_tool_name=None,
- _state=None,
- _new_message_index=None,
- _traceparent_value=None,
- )
- )
- article_agent._article_agent = MagicMock(
- run=MagicMock(return_value=expected_response_future)
- )
-
- # Run the agent
- response = await article_agent.run(question)
- assert response.output == expected_response
-
- article_agent._article_request_discriminator_agent.run.assert_called_once_with(
- USER_PROMPT_REQUEST_DISCRIMINATOR_TEMPLATE.format(request=question)
- )
-
- # Check if the article retrieval agent was called correctly
- if request_discriminator_response.article_type == "title":
- article_agent._search_articles_by_title.assert_called_once_with(
- request_discriminator_response.article_value
- )
- article_agent._article_retrieval_agent.run.assert_called_once_with(
- USER_PROMPT_ARTICLE_RETRIEVAL_TEMPLATE.format(
- article_title=request_discriminator_response.article_value,
- articles="Mocked search articles by title response.",
- )
- )
- article_agent._get_article.assert_called_once_with(
- retrieval_response.article_link, use_cache=True
- )
+from askademic.article import ArticleAgent, ArticleResponse # noqa: E402
- else:
- article_agent._get_article.assert_called_once_with(
- request_discriminator_response.article_value, use_cache=True
- )
- article_agent._article_agent.run.assert_called_once_with(
- USER_PROMPT_ARTICLE_TEMPLATE.format(
- request=question, article="Mocked article text for testing purposes."
+class TestArticleAgent:
+ """Tests for the refactored ArticleAgent with tools."""
+
+ def test_normalize_arxiv_link_pdf_url(self):
+ """Test that PDF URLs are returned as-is."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+ url = "https://arxiv.org/pdf/1706.03762.pdf"
+ assert agent._normalize_arxiv_link(url) == url
+
+ def test_normalize_arxiv_link_abs_url(self):
+ """Test that abstract URLs are converted to PDF URLs."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+ url = "https://arxiv.org/abs/1706.03762"
+ expected = "https://arxiv.org/pdf/1706.03762.pdf"
+ assert agent._normalize_arxiv_link(url) == expected
+
+ def test_normalize_arxiv_link_id_only(self):
+ """Test that bare arXiv IDs are converted to PDF URLs."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+ arxiv_id = "1706.03762"
+ expected = "https://arxiv.org/pdf/1706.03762.pdf"
+ assert agent._normalize_arxiv_link(arxiv_id) == expected
+
+ def test_normalize_arxiv_link_new_format_id(self):
+ """Test that new format arXiv IDs (5 digits) are handled."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+ arxiv_id = "2401.00001"
+ expected = "https://arxiv.org/pdf/2401.00001.pdf"
+ assert agent._normalize_arxiv_link(arxiv_id) == expected
+
+ def test_normalize_arxiv_link_invalid(self):
+ """Test that invalid links are returned as-is."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+ invalid = "not-a-valid-link"
+ assert agent._normalize_arxiv_link(invalid) == invalid
+
+ def test_agent_has_tools(self):
+ """Test that the agent is configured with the expected tools."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+ tool_names = list(agent._agent._function_toolset.tools.keys())
+ assert "search_by_title" in tool_names
+ assert "fetch_article" in tool_names
+
+ @pytest.mark.asyncio
+ async def test_run_with_mocked_agent(self):
+ """Test the run method with a mocked internal agent."""
+ from unittest.mock import AsyncMock, MagicMock
+
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+
+ mock_response = ArticleResponse(
+ response="This paper discusses attention mechanisms.",
+ article_title="Attention Is All You Need",
+ article_link="https://arxiv.org/pdf/1706.03762.pdf",
)
- )
+
+ mock_result = MagicMock()
+ mock_result.output = mock_response
+
+ agent._agent.run = AsyncMock(return_value=mock_result)
+
+ result = await agent.run("What is the Attention paper about?")
+
+ assert result.output == mock_response
+ agent._agent.run.assert_called_once()
+
+
+class TestArticleAgentIntegration:
+ """Tests that inspect the parameter schemas of the agent's registered tools."""
+
+ @pytest.mark.asyncio
+ async def test_search_by_title_tool_schema(self):
+ """Test that search_by_title tool has the expected schema."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+
+ tools = agent._agent._function_toolset.tools
+ assert "search_by_title" in tools
+
+ # Verify the tool has the expected parameter via its function schema
+ search_tool = tools["search_by_title"]
+ json_schema = search_tool.function_schema.json_schema
+ assert "title" in json_schema.get("properties", {})
+
+ @pytest.mark.asyncio
+ async def test_fetch_article_tool_schema(self):
+ """Test that fetch_article tool has the expected schema."""
+ agent = ArticleAgent(model="google-gla:gemini-2.0-flash")
+
+ tools = agent._agent._function_toolset.tools
+ assert "fetch_article" in tools
+
+ # Verify the tool has the expected parameter via its function schema
+ fetch_tool = tools["fetch_article"]
+ json_schema = fetch_tool.function_schema.json_schema
+ assert "link" in json_schema.get("properties", {})