Skip to content
Merged
1 change: 0 additions & 1 deletion mle/agents/advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def __init__(self, model, console=None, mode='normal'):
"""

self.functions = [
schema_web_search,
Comment thread
syangx38 marked this conversation as resolved.
schema_search_arxiv,
schema_search_papers_with_code,
schema_preview_csv_data,
Expand Down
2 changes: 1 addition & 1 deletion mle/agents/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DebugAgent:

def __init__(self, model, console=None, analyze_only=False):
"""
DebugAgent: the agent to run the generated the code and then debug it. The return of the
DebugAgent: the agent to run the generated code and then debug it. The return of the
agent is an instruction to the user to modify the code based on the logs and web search.

Args:
Expand Down
264 changes: 138 additions & 126 deletions mle/model/gemini.py
Original file line number Diff line number Diff line change
@@ -1,175 +1,187 @@
import os
import importlib.util
import json

from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
from mle.model.common import Model

try:
from google.genai import client
from google.genai import types
except ImportError:
raise ImportError(
"It seems you didn't install `google-genai` SDK. "
"In order to enable the Gemini client related features, "
"please make sure gemini Python package has been installed. "
"More information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python"
)

class GeminiModel(Model):

def __init__(self, api_key, model, temperature=0.7):
"""
Initialize the Gemini model.
Initialize the Gemini model using the `google-genai` SDK.

Args:
api_key (str): The Gemini API key.
model (str): The model with version.
temperature (float): The temperature value.
"""
super().__init__()

dependency = "google.generativeai"
spec = importlib.util.find_spec(dependency)
if spec is not None:
self.gemini = importlib.import_module(dependency)
self.gemini.configure(api_key=api_key)
else:
raise ImportError(
"It seems you didn't install `google-generativeai`. "
"In order to enable the Gemini client related features, "
"please make sure gemini Python package has been installed. "
"More information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python"
)

self.model = model if model else 'gemini-1.5-flash'
self.model = model if model else 'gemini-2.5-flash'
self.model_type = 'Gemini'
self.temperature = temperature
self.client = client.Client(api_key=api_key)
self.func_call_history = []

def _map_chat_history_from_openai(self, chat_history):
_key_map_dict = {
"role": "role",
"content": "parts",
}
_value_map_dict = {
"system": "model",
"user": "user",
"assistant": "model",
"content": "parts",
}
return [
{
_key_map_dict.get(k, k): _value_map_dict.get(v, v)
for k, v in dict(chat).items()
} for chat in chat_history
]

def _map_functions_from_openai(self, functions):
def _mapping_type(_type: str):
if _type == "string":
return self.gemini.protos.Type.STRING
if _type == "object":
return self.gemini.protos.Type.OBJECT
if _type == "integer":
return self.gemini.protos.Type.NUMBER
if _type == "boolean":
return self.gemini.protos.Type.BOOLEAN
if _type == "array":
return self.gemini.protos.Type.ARRAY
return self.gemini.protos.Type.TYPE_UNSPECIFIED

return self.gemini.protos.Tool(function_declarations=[
self.gemini.protos.FunctionDeclaration(
name=func.get("name"),
description=func.get("description"),
parameters=self.gemini.protos.Schema(
type=_mapping_type(func.get("parameters", {}).get("type")),
def _create_gemini_tools(self, functions):
"""Converts a list of function dictionaries into a Gemini Tool object."""
if not functions:
return None

function = []
for func_dict in functions:
params = func_dict.get('parameters', {})

declaration = types.FunctionDeclaration(
name=func_dict['name'],
description=func_dict['description'],
parameters=types.Schema(
type='OBJECT',
properties={
param_name: self.gemini.protos.Schema(
type=_mapping_type(properties.get("type")),
description=properties.get("description")
)
for param_name, properties in \
func.get("parameters",{}).get("properties", {}).items()
key: types.Schema(
type=prop.get('type', 'STRING'),
description=prop.get('description')
) for key, prop in params.get('properties', {}).items()
},
required=[key for key in func.get("parameters",{}).get("properties", {}).keys()],
required=list(params.get('properties', {}).keys())
)
)
for func in functions
])
function.append(declaration)
if not function:
return None
return [types.Tool(function_declarations=function)]

def _mapping_response_format_from_openai(self, response_format):
if response_format.get("type") == "json_object":
return "application/json"
return None

def query(self, chat_history, **kwargs):
def _adapt_history_for_gemini(self, chat_history):
"""
Query the LLM model.

Adapts mle-agent's chat history format to the one required by the Gemini API,
separating the system instruction.

Args:
chat_history: The context (chat history).
chat_history (list): The conversation history in the agent's internal format.

Returns:
tuple[str, list]: A tuple containing the system instruction and conversation history.
"""
parameters = kwargs
chat_history = self._map_chat_history_from_openai(chat_history)
system_instruction = ""
prompt = []

for message in chat_history:
role = message.get("role")
content = message.get("content")

if role == "system":
system_instruction += content + "\n\n"
elif role == "user" and content:
prompt.append({'role': 'user', 'parts': [{'text': content}]})
elif role == "assistant" and content:
prompt.append({'role': 'model', 'parts': [{'text': content}]})

tools = None
if parameters.get("functions") is not None:
tools = self._map_functions_from_openai(parameters["functions"])
return system_instruction.strip(), prompt

def query(self, chat_history, **kwargs):
"""
Query the LLM model with robust tool-calling and JSON-forcing logic.
"""
MAX_TOOL_TURNS = 10
SEARCH_ATTEMPT_LIMIT = 3
self.func_call_history = []

client = self.gemini.GenerativeModel(self.model)
chat_handler = client.start_chat(history=chat_history[:-1])
system_instruction, prompt = self._adapt_history_for_gemini(chat_history)
tools = self._create_gemini_tools(kwargs.get("functions", []))

completion = chat_handler.send_message(
chat_history[-1]["parts"],
base_config = types.GenerateContentConfig(
tools=tools,
generation_config=self.gemini.types.GenerationConfig(
max_output_tokens=4096,
temperature=self.temperature,
temperature=self.temperature,
system_instruction=system_instruction
)
json_only_config = types.GenerateContentConfig(
temperature=self.temperature,
response_mime_type="application/json",
tool_config=types.ToolConfig(
# Explicitly forbid the model from calling any more tools.
function_calling_config=types.FunctionCallingConfig(mode='NONE')
),
system_instruction=system_instruction
)

function_outputs = {}
for part in completion.parts:
fn = part.function_call
if fn:
print("[MLE FUNC CALL]: ", fn.name)
# avoid the multiple search function calls
search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
if len(search_attempts) > 3:
parameters['functions'] = None
result = get_function(fn.name)(**dict(fn.args))
function_outputs[fn.name] = result

if len(function_outputs):
response_parts = [
self.gemini.protos.Part(
function_response=self.gemini.protos.FunctionResponse(
name=fn, response={"result": val}
)
)
for fn, val in function_outputs.items()
]

completion = chat_handler.send_message(
self.gemini.protos.Content(parts=response_parts),
generation_config=self.gemini.types.GenerationConfig(
max_output_tokens=4096,
temperature=self.temperature,
response_mime_type=self._mapping_response_format_from_openai(
parameters.get("response_format", {})),
),
final_response_content = None
json_output_required = False

for _ in range(MAX_TOOL_TURNS):
config = json_only_config if json_output_required else base_config
json_output_required = False

response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=config,
)

return completion.text
# The model can return multiple function calls. For now, we only process the first one,
# as agents' current logic is sequential.
# This is a potential future improvement to handle more complex tasks.
function_call = None
if response.candidates and response.candidates[0].content.parts:
for part in response.candidates[0].content.parts:
if part.function_call:
function_call = part.function_call
break
if function_call:
function_name = process_function_name(function_call.name)
args = dict(function_call.args)

self.func_call_history.append(function_name)
# Prevent infinite search loops.
search_attempts = [f for f in self.func_call_history if f in SEARCH_FUNCTIONS]
if len(search_attempts) > SEARCH_ATTEMPT_LIMIT:
final_response_content = f"[GEMINI WARNING]: Search function limit of {SEARCH_ATTEMPT_LIMIT} reached."
print(final_response_content)
break

print(f"[GEMINI FUNC CALL]: Calling {function_name} with arguments: {args}")
function_result = get_function(function_name)(**args)
function_response_part = types.Part.from_function_response(
name=function_name,
response={"result": str(function_result)}
)
prompt.append(response.candidates[0].content)
prompt.append(types.Content(role='tool', parts=[function_response_part]))
json_output_required = True
else:
final_response_content = response.text
break

if final_response_content is None:
final_response_content = f"[GEMINI WARNING]: Max tool turns of {MAX_TOOL_TURNS} reached."
print(final_response_content)

return final_response_content

def stream(self, chat_history, **kwargs):
"""
Stream the output from the LLM model.
Args:
chat_history: The context (chat history).
"""
client = self.gemini.GenerativeModel(self.model)
chat_handler = client.start_chat(history=chat_history[:-1])
prompt = self._adapt_history_for_gemini(chat_history)

completions = chat_handler.send_message(
chat_history[-1]["parts"],
stream=True,
generation_config=self.gemini.types.GenerationConfig(
response_stream = self.client.models.generate_content_stream(
model=self.model,
contents=prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
max_output_tokens=4096,
temperature=self.temperature,
),
)
)

for chunk in completions:
for chunk in response_stream:
yield chunk.text