Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env-template
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
LLM_FAMILY="gemini" or "claude" or "claude-aws-bedrock"
LLM_FAMILY="gemini" or "claude" or "claude-aws-bedrock" or "nova-pro-aws-bedrock"
GEMINI_API_KEY="your key if using gemini"
ANTHROPIC_API_KEY="your key if using Claude"
AWS_ACCESS_KEY_ID="your aws access key id if using Claude via aws bedrock"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ As its underlying LLM, you can choose to run it either with:
* Gemini (it will use 2.0 Flash) [preferred and default option]
* Claude (it will use Haiku 3.5) [experimental]
* Claude via AWS Bedrock (it will use Haiku 3.5) [experimental]
* Nova Pro via AWS Bedrock [experimental]

Gemini is preferred because:
* it has a free tier - we privilege cost-effectiveness over speed, which means for short conversations you should be within the quotas of the free tier
Expand Down
7 changes: 6 additions & 1 deletion evals/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ async def main():

load_dotenv()

for model_family in ["gemini", "claude", "claude-aws-bedrock"]:
for model_family in [
"gemini",
"claude",
"claude-aws-bedrock",
"nova-pro-aws-bedrock",
]:

if model_family == "gemini" and os.getenv("GEMINI_API_KEY") is None:
console.print(
Expand Down
6 changes: 3 additions & 3 deletions evals/evals_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, request: str, answer: list[str]):
QuestionAnswerTestCaseRange(
"What percentage of DNA has been found to be shared between \
Sapiens and Neandertals?",
["4%", "5%"],
["4%", "5%", ""],
),
]

Expand All @@ -55,9 +55,9 @@ async def run_evals(model_family: str):
question_agent = QuestionAgent(
model=model,
model_settings=model_settings,
query_list_limit=5,
query_list_limit=2,
relevance_score_threshold=0.8,
article_list_limit=3,
article_list_limit=2,
)

# single-answer ones
Expand Down
16 changes: 9 additions & 7 deletions evals/evals_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import time
from typing import List

from rich.console import Console

Expand All @@ -11,16 +12,17 @@


class SummaryTestCase:
def __init__(self, request: str, category: str):
def __init__(self, request: str, category_list: List[str]):
self.request = request
self.category = category
self.category_list = category_list


eval_cases = [
SummaryTestCase("What is the latest research on quantum field theory?", "hep-th"),
SummaryTestCase("Can you summarize the latest papers on AI?", "cs.AI"),
SummaryTestCase("What is the latest research on quantum field theory?", ["hep-th"]),
SummaryTestCase("Can you summarize the latest papers on AI?", ["cs.AI"]),
SummaryTestCase(
"Tell me all about the recent work in Bayesian statistics?", "stat.TH"
"Tell me all about the recent work in Bayesian statistics?",
["stat.TH", "stat.ME"],
),
]

Expand All @@ -43,10 +45,10 @@ async def run_evals(model_family: str):
print(f"Evaluating case: {case.request}")

response = await summary_agent(case.request)
if response.category.category_id != case.category:
if response.category.category_id not in case.category_list:
print(f"Test failed for question: {case.request}")
print(f"Got: {response.category.category_id}")
print(f"Expected: {case.category}")
print(f"Expected: {case.category_list}")
c_failed += 1
else:
c_passed += 1
Expand Down
2 changes: 1 addition & 1 deletion src/askademic/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# User instructions
INSTRUCTIONS = """
Instructions:
- Type "llm" to choose the LLM family (default is 'gemini')
- Type "reset" to reset the memory
- Type "history" to see the memory history
- Type "exit" or CTRL+D to quit
Expand All @@ -12,6 +11,7 @@
GEMINI_2_FLASH_MODEL_ID = "gemini-2.0-flash"
CLAUDE_HAIKU_3_5_MODEL_ID = "anthropic:claude-3-5-haiku-latest"
CLAUDE_HAIKU_3_5_BEDROCK_MODEL_ID = "{region}.anthropic.claude-3-5-haiku-20241022-v1:0"
NOVA_PRO_BEDROCK_MODEL_ID = "{region}.amazon.nova-pro-v1:0"
MISTRAL_LARGE_MODEL_ID = "mistral:mistral-large-latest"

# ARXIV URLS
Expand Down
16 changes: 11 additions & 5 deletions src/askademic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def check_environment_variables(user_model: str):
"[bold red]The ANTHROPIC_API_KEY environment variable is not set.[/bold red]"
)
sys.exit()
elif user_model == "claude-aws-bedrock":
elif user_model in ("claude-aws-bedrock", "nova-pro-aws-bedrock"):
try:
_ = boto3.client("sts").get_caller_identity()
except boto3.exceptions.ClientError:
Expand All @@ -74,7 +74,8 @@ async def check_environment_variables(user_model: str):
else:
console.print(
"[bold red]Invalid model family selected. "
+ "Please choose 'gemini', 'claude', or 'claude-aws-bedrock'.[/bold red]"
+ "Please choose 'gemini', 'claude', 'claude-aws-bedrock'"
+ " or 'nova-pro-aws-bedrock'.[/bold red]"
)
sys.exit()

Expand All @@ -91,7 +92,6 @@ async def ask_me():
sys.exit()

load_dotenv()

logfire_token = os.getenv("LOGFIRE_TOKEN", None)
user_model = os.getenv("LLM_FAMILY", "gemini")

Expand Down Expand Up @@ -120,10 +120,16 @@ async def ask_me():
memory = Memory(max_request_tokens=1e5)

# ask user to choose the model family (gemini by default)
while user_model not in ("gemini", "claude", "claude-aws-bedrock"):
while user_model not in (
"gemini",
"claude",
"claude-aws-bedrock",
"nova-pro-aws-bedrock",
):
console.print(
"""[bold red]Please configure the LLM family
to be either "gemini" or "claude", or "claude-aws-bedrock"):[/bold red]"""
to be either "gemini" or "claude", "claude-aws-bedrock"
or "nova-pro-aws-bedrock"):[/bold red]"""
)
return

Expand Down
4 changes: 2 additions & 2 deletions src/askademic/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OrchestratorResponse(BaseModel):
orchestrator_agent_base = Agent(
system_prompt=SYSTEM_PROMPT_ORCHESTRATOR,
output_type=OrchestratorResponse,
retries=10,
retries=20,
end_strategy="early",
)

Expand Down Expand Up @@ -70,7 +70,7 @@ async def answer_question(ctx: RunContext[Context], question: str) -> list[str]:
question_agent = QuestionAgent(
orchestrator_agent_base.model,
orchestrator_agent_base.model_settings,
query_list_limit=3,
query_list_limit=2,
relevance_score_threshold=0.8,
article_list_limit=2,
)
Expand Down
9 changes: 8 additions & 1 deletion src/askademic/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def __init__(self, model: Model, model_settings: ModelSettings = None):
output_type=Summary,
)

if "nova" in model.model_name:
self._max_results = 100
else:
self._max_results = 300

self._identify_latest_day = identify_latest_day
self._retrieve_recent_articles = retrieve_recent_articles

Expand All @@ -96,7 +101,9 @@ async def __call__(self, request: str) -> SummaryResponse:

# Get the articles
articles = self._retrieve_recent_articles(
category=category.output.category_id, latest_day=latest_day
category=category.output.category_id,
latest_day=latest_day,
max_results=self._max_results,
)

logger.info(f"Latest published day: {latest_day} - Articles #: {len(articles)}")
Expand Down
8 changes: 6 additions & 2 deletions src/askademic/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def search_articles_by_title(
def retrieve_recent_articles(
category: str = "cs.AI",
latest_day: str = "2022-01-01",
max_results: int = 300,
):
"""
Search articles on arXiv by category, filtering to the ones publishhed
Expand All @@ -198,13 +199,16 @@ def retrieve_recent_articles(
Args:
category: the category ID used for the search
latest_day: the day of publications to filter articles by
max_results: the total number of articles to retrieve. Do not change this parameter.
"""

search_query = f"cat:{category}"
# 300 is empirical: there should never be more articles in a day for a category
url = f"{ARXIV_BASE_URL}search_query={search_query}&start=0&max_results=300"
url = (
f"{ARXIV_BASE_URL}search_query={search_query}&start=0&max_results={max_results}"
)
url += "&sortBy=submittedDate&sortOrder=descending"
logger.info(f"{datetime.now()}: API URL to retrieve recent articles: {url}")
logger.info(f"{datetime.now()}: API URL to retrieve recent articles: {max_results}")

response = requests.get(url, timeout=360)
df_articles = organise_api_response_as_dataframe(response)
Expand Down
23 changes: 18 additions & 5 deletions src/askademic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CLAUDE_HAIKU_3_5_BEDROCK_MODEL_ID,
CLAUDE_HAIKU_3_5_MODEL_ID,
GEMINI_2_FLASH_MODEL_ID,
NOVA_PRO_BEDROCK_MODEL_ID,
)

today = datetime.now().strftime("%Y-%m-%d")
Expand All @@ -26,7 +27,12 @@ def choose_model(model_family: str = "gemini") -> Tuple[Model, ModelSettings]:
"""
Choose the model ID based on the given model family.
"""
if model_family not in ["gemini", "claude", "claude-aws-bedrock"]:
if model_family not in [
"gemini",
"claude",
"claude-aws-bedrock",
"nova-pro-aws-bedrock",
]:
raise ValueError(f"Invalid model family '{model_family}'.")

if model_family == "gemini":
Expand All @@ -39,17 +45,24 @@ def choose_model(model_family: str = "gemini") -> Tuple[Model, ModelSettings]:
model = AnthropicModel(model_name=model_name)
model_settings = ModelSettings(max_tokens=1000, temperature=0)
return model, model_settings
elif model_family == "claude-aws-bedrock":
elif model_family in ("claude-aws-bedrock", "nova-pro-aws-bedrock"):

model_id = (
CLAUDE_HAIKU_3_5_BEDROCK_MODEL_ID
if model_family == "claude-aws-bedrock"
else NOVA_PRO_BEDROCK_MODEL_ID
)

region = boto3.session.Session().region_name
if not region:
model_name = CLAUDE_HAIKU_3_5_BEDROCK_MODEL_ID.format(region="us")
model_name = model_id.format(region="us")
else:
region = region.split("-")[0]
model_name = CLAUDE_HAIKU_3_5_BEDROCK_MODEL_ID.format(region=region)
model_name = model_id.format(region=region)

model_settings = BedrockModelSettings(
temperature=0,
max_tokens=1000,
max_tokens=4000,
top_k=1,
bedrock_performance_configuration={"latency": "optimized"},
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from pydantic_ai.agent import AgentRunResult
from pydantic_ai.models.gemini import GeminiModel

os.environ["GEMINI_API_KEY"] = "mock"

Expand Down Expand Up @@ -56,7 +57,8 @@ async def test_summary_agent(
summary_response,
):
"""Test the SummaryAgent class."""
summary_agent = SummaryAgent("gemini-2.0-flash")
model = GeminiModel(model_name="gemini-2.0-flash")
summary_agent = SummaryAgent(model)
summary_agent._category_agent = MagicMock()
summary_agent._summary_agent = MagicMock()

Expand Down