Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 88 additions & 3 deletions llm_inference/model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self):
self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
self.perplexity_api_key = os.getenv("PERPLEXITY_API_KEY")
self.replicate_api_key = os.getenv("REPLICATE_API_KEY")
self.groq_api_key = os.getenv("GROQ_API_KEY")
self.nvidia_api_key = os.getenv("NVIDIA_API_KEY")

# AWS credentials
self.aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
Expand Down Expand Up @@ -86,6 +88,10 @@ def infer(
return self._call_xai(model_name, prompt)
elif provider == "zhipu":
return self._call_zhipu(model_name, prompt)
elif provider == "groq":
return self._call_groq(model_name, prompt)
elif provider == "nvidia":
return self._call_nvidia(model_name, prompt)
else:
# Default to Together API for most open-source models
return self._call_together(model_name, prompt)
Expand Down Expand Up @@ -132,7 +138,8 @@ def _get_provider(self, model_name: str) -> str:
"gpt-4.1-mini": "openai",
"gpt-4.1-nano": "openai",
"gpt-4o": "openai",
"gpt-4o-mini": "openai",
"openai/gpt-4o-mini": "openrouter",
"gpt-4o-mini": "openrouter",
"gpt-4-1106-preview": "openai",
"o4-mini": "openai",
"gpt-5-chat-latest": "openai",
Expand All @@ -147,6 +154,9 @@ def _get_provider(self, model_name: str) -> str:
"gemini-2.5-pro": "google",
# Mistral models
"mistral-medium": "mistral",
"mistralai/ministral-3-14b-2512": "mistral",
"mistralai/ministral-3-8b-2512": "mistral",
"mistralai/ministral-3-3b-2512": "mistral",
"codestral-latest": "mistral",
"open-mixtral-8x7b": "mistral",
"mistral-large-latest": "mistral",
Expand All @@ -155,6 +165,9 @@ def _get_provider(self, model_name: str) -> str:
"open-mistral-7b": "mistral",
"open-mistral-nemo": "mistral",
# DeepSeek models
"deepseek/deepseek-v4-flash": "openrouter",
"deepseek-chat": "deepseek",
"deepseek-v3.1": "deepseek",
"deepseek-coder": "deepseek",
# Together AI models
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": "together",
Expand Down Expand Up @@ -189,7 +202,18 @@ def _get_provider(self, model_name: str) -> str:
"llama-3-3-70b-instruct": "aws",
"llama-3-1-405b-instruct": "aws",
# Zhipu
# Groq models (free, fast, OpenAI-compatible)
"meta-llama_llama-3.3-70b-instruct": "groq",
"meta-llama_llama-3.1-405b-instruct": "groq",
"llama-3.3-70b-versatile": "groq",
# NVIDIA NIM models (free)
"meta/llama-3.3-70b-instruct": "nvidia",
"meta/llama-3.1-8b-instruct": "nvidia",
# Zhipu / GLM
"glm-4-air": "zhipu",
"glm-4-air-250414": "zhipu",
"glm-4.5-air": "zhipu",
"glm-4.6": "zhipu",
"glm-4-flash": "zhipu",
"glm-4-plus": "zhipu",
}
Expand Down Expand Up @@ -297,7 +321,9 @@ def _call_openrouter(self, model_name: str, prompt: str) -> Dict[str, Any]:
)

response = client.chat.completions.create(
model=model_name, messages=[{"role": "user", "content": prompt}]
model=model_name,
messages=[{"role": "user", "content": prompt}],
max_tokens=2048,
)

usage = getattr(response, "usage", None)
Expand Down Expand Up @@ -462,7 +488,15 @@ def _call_mistral(self, model_name: str, prompt: str) -> Dict[str, Any]:

client = Mistral(api_key=self.mistral_api_key)

clean_model_name = model_name.replace("mistral/", "")
clean_model_name = model_name.replace("mistral/", "").replace("mistralai/", "")

# RouterArena name → Mistral API name mapping
MISTRAL_NAME_MAP = {
"ministral-3-14b-2512": "ministral-14b-2512",
"ministral-3-3b-2512": "ministral-3b-2512",
"ministral-3-8b-2512": "ministral-8b-2512",
}
clean_model_name = MISTRAL_NAME_MAP.get(clean_model_name, clean_model_name)

from typing import Any, cast

Expand Down Expand Up @@ -716,3 +750,54 @@ def _call_aws(self, model_name: str, prompt: str) -> Dict[str, Any]:
"model_used": model_name,
"provider": "aws",
}
def _call_groq(self, model_name: str, prompt: str) -> Dict[str, Any]:
    """Call the Groq chat-completions API through the OpenAI-compatible client.

    Args:
        model_name: Model identifier requested by the caller. Names listed in
            ``GROQ_MODEL_MAP`` are translated to Groq model ids; any other
            name is sent to Groq unchanged.
        prompt: Text sent as a single ``user`` message.

    Returns:
        A dict with the keys ``response`` (message content of the first
        choice), ``success`` (always ``True`` on return), ``token_usage``
        (``input_tokens`` / ``output_tokens`` / ``total_tokens``, each 0 when
        the response carries no usage data), ``model_used`` (the id actually
        sent to Groq) and ``provider`` (``"groq"``).

    Raises:
        ValueError: If ``self.groq_api_key`` is ``None`` or empty.
    """
    # Fail fast on a missing key. When ``api_key`` is None the OpenAI client
    # falls back to the OPENAI_API_KEY environment variable, which would then
    # be sent to Groq's endpoint instead of producing a clear error.
    if not self.groq_api_key:
        raise ValueError("GROQ_API_KEY is not set; cannot call the Groq API")

    import openai

    client = openai.OpenAI(
        api_key=self.groq_api_key,
        base_url="https://api.groq.com/openai/v1",
    )

    # Map RouterArena names to Groq names.
    # NOTE(review): the 405b entry resolves to an 8B model, so results recorded
    # under the 405B name come from a much smaller model -- confirm this
    # substitution is intended.
    GROQ_MODEL_MAP = {
        "meta-llama_llama-3.3-70b-instruct": "llama-3.3-70b-versatile",
        "meta-llama_llama-3.1-405b-instruct": "llama-3.1-8b-instant",
    }
    groq_model = GROQ_MODEL_MAP.get(model_name, model_name)

    response = client.chat.completions.create(
        model=groq_model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=2048,
        temperature=0.7,
    )

    # ``usage`` may be absent from the response; report zeros in that case.
    usage = getattr(response, "usage", None)
    return {
        "response": response.choices[0].message.content,
        "success": True,
        "token_usage": {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
            "total_tokens": getattr(usage, "total_tokens", 0) if usage else 0,
        },
        "model_used": groq_model,
        "provider": "groq",
    }

def _call_nvidia(self, model_name: str, prompt: str) -> Dict[str, Any]:
    """Call the NVIDIA NIM chat-completions API through the OpenAI-compatible client.

    Args:
        model_name: Model identifier, sent to the API unchanged.
        prompt: Text sent as a single ``user`` message.

    Returns:
        A dict with the keys ``response`` (message content of the first
        choice), ``success`` (always ``True`` on return), ``token_usage``
        (``input_tokens`` / ``output_tokens`` / ``total_tokens``, each 0 when
        the response carries no usage data), ``model_used`` and ``provider``
        (``"nvidia"``).

    Raises:
        ValueError: If ``self.nvidia_api_key`` is ``None`` or empty.
    """
    # Fail fast on a missing key. When ``api_key`` is None the OpenAI client
    # falls back to the OPENAI_API_KEY environment variable, which would then
    # be sent to NVIDIA's endpoint instead of producing a clear error.
    if not self.nvidia_api_key:
        raise ValueError("NVIDIA_API_KEY is not set; cannot call the NVIDIA NIM API")

    import openai

    client = openai.OpenAI(
        api_key=self.nvidia_api_key,
        base_url="https://integrate.api.nvidia.com/v1",
    )

    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=2048,
        temperature=0.7,
    )

    # ``usage`` may be absent from the response; report zeros in that case.
    usage = getattr(response, "usage", None)
    return {
        "response": response.choices[0].message.content,
        "success": True,
        "token_usage": {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
            "total_tokens": getattr(usage, "total_tokens", 0) if usage else 0,
        },
        "model_used": model_name,
        "provider": "nvidia",
    }
1 change: 1 addition & 0 deletions router_inference/config/a3m-router.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"pipeline_params":{"router_name":"a3m-router","router_cls_name":"A3MRouter","models":["deepseek-chat","mistralai/ministral-3-14b-2512","gemini-2.0-flash-001"],"description":"A3M 3-Model Query-Type Router (DeepSeek + Mistral 14B + Gemini Flash)"}}
11 changes: 10 additions & 1 deletion router_inference/generate_prediction_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,16 @@ def generate_predictions(
continue

# Use the router to get prediction (validation is handled by BaseRouter)
selected_model = router.get_prediction(prompt)
# A3M: query-type routing with global_index, fallback to prompt-only
try:
# Use getattr to avoid MyPy error on BaseRouter signature
_get_pred = getattr(router, '_get_prediction', None)
if _get_pred:
selected_model = _get_pred(prompt, global_index=global_index) # type: ignore[call-arg]
else:
selected_model = router.get_prediction(prompt)
except TypeError:
selected_model = router.get_prediction(prompt)

# Track selected model for sub_10 entries (for optimality generation)
if global_index in sub10_indices:
Expand Down
Loading
Loading