-
Notifications
You must be signed in to change notification settings - Fork 105
[MRG] feat(model): Update gemini.py to use new Gemini API #309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+139
−128
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
4eaa041
[07/12] Update gemini.py to use new Gemini API
syangx39 9d7f80b
[07/12] Update gemini.py to use new Gemini API
syangx38 5e95af6
[07/12] Update gemini.py to use new Gemini API
syangx38 707a52d
[07/12] Update gemini.py to use new Gemini API
syangx38 0f0798d
[07/12] Update gemini.py to use new Gemini API
syangx38 c143ed9
[07/12] Update gemini.py to use new Gemini API
syangx38 e3bc0f9
[07/12] Update gemini.py to use new Gemini API
syangx38 2597cbe
[07/12] Update gemini.py to use new Gemini API
syangx38 4283bb1
[07/12] Update gemini.py to use new Gemini API
syangx38 448214a
[07/13] Fix trailing whitespace
syangx38 31ac6b3
[07/13] Fix typo
syangx38 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,175 +1,187 @@ | ||
| import os | ||
| import importlib.util | ||
| import json | ||
|
|
||
| from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name | ||
| from mle.model.common import Model | ||
|
|
||
| try: | ||
| from google.genai import client | ||
| from google.genai import types | ||
| except ImportError: | ||
| raise ImportError( | ||
| "It seems you didn't install `google-genai` SDK. " | ||
| "In order to enable the Gemini client related features, " | ||
| "please make sure gemini Python package has been installed. " | ||
| "More information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python" | ||
| ) | ||
|
|
||
| class GeminiModel(Model): | ||
|
|
||
| def __init__(self, api_key, model, temperature=0.7): | ||
| """ | ||
| Initialize the Gemini model. | ||
| Initialize the Gemini model using the `google-genai` SDK. | ||
|
|
||
| Args: | ||
| api_key (str): The Gemini API key. | ||
| model (str): The model with version. | ||
| temperature (float): The temperature value. | ||
| """ | ||
| super().__init__() | ||
|
|
||
| dependency = "google.generativeai" | ||
| spec = importlib.util.find_spec(dependency) | ||
| if spec is not None: | ||
| self.gemini = importlib.import_module(dependency) | ||
| self.gemini.configure(api_key=api_key) | ||
| else: | ||
| raise ImportError( | ||
| "It seems you didn't install `google-generativeai`. " | ||
| "In order to enable the Gemini client related features, " | ||
| "please make sure gemini Python package has been installed. " | ||
| "More information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python" | ||
| ) | ||
|
|
||
| self.model = model if model else 'gemini-1.5-flash' | ||
| self.model = model if model else 'gemini-2.5-flash' | ||
| self.model_type = 'Gemini' | ||
| self.temperature = temperature | ||
| self.client = client.Client(api_key=api_key) | ||
| self.func_call_history = [] | ||
|
|
||
| def _map_chat_history_from_openai(self, chat_history): | ||
| _key_map_dict = { | ||
| "role": "role", | ||
| "content": "parts", | ||
| } | ||
| _value_map_dict = { | ||
| "system": "model", | ||
| "user": "user", | ||
| "assistant": "model", | ||
| "content": "parts", | ||
| } | ||
| return [ | ||
| { | ||
| _key_map_dict.get(k, k): _value_map_dict.get(v, v) | ||
| for k, v in dict(chat).items() | ||
| } for chat in chat_history | ||
| ] | ||
|
|
||
| def _map_functions_from_openai(self, functions): | ||
| def _mapping_type(_type: str): | ||
| if _type == "string": | ||
| return self.gemini.protos.Type.STRING | ||
| if _type == "object": | ||
| return self.gemini.protos.Type.OBJECT | ||
| if _type == "integer": | ||
| return self.gemini.protos.Type.NUMBER | ||
| if _type == "boolean": | ||
| return self.gemini.protos.Type.BOOLEAN | ||
| if _type == "array": | ||
| return self.gemini.protos.Type.ARRAY | ||
| return self.gemini.protos.Type.TYPE_UNSPECIFIED | ||
|
|
||
| return self.gemini.protos.Tool(function_declarations=[ | ||
| self.gemini.protos.FunctionDeclaration( | ||
| name=func.get("name"), | ||
| description=func.get("description"), | ||
| parameters=self.gemini.protos.Schema( | ||
| type=_mapping_type(func.get("parameters", {}).get("type")), | ||
| def _create_gemini_tools(self, functions): | ||
| """Converts a list of function dictionaries into a Gemini Tool object.""" | ||
| if not functions: | ||
| return None | ||
|
|
||
| function = [] | ||
| for func_dict in functions: | ||
| params = func_dict.get('parameters', {}) | ||
|
|
||
| declaration = types.FunctionDeclaration( | ||
| name=func_dict['name'], | ||
| description=func_dict['description'], | ||
| parameters=types.Schema( | ||
| type='OBJECT', | ||
| properties={ | ||
| param_name: self.gemini.protos.Schema( | ||
| type=_mapping_type(properties.get("type")), | ||
| description=properties.get("description") | ||
| ) | ||
| for param_name, properties in \ | ||
| func.get("parameters",{}).get("properties", {}).items() | ||
| key: types.Schema( | ||
| type=prop.get('type', 'STRING'), | ||
| description=prop.get('description') | ||
| ) for key, prop in params.get('properties', {}).items() | ||
| }, | ||
| required=[key for key in func.get("parameters",{}).get("properties", {}).keys()], | ||
| required=list(params.get('properties', {}).keys()) | ||
| ) | ||
| ) | ||
| for func in functions | ||
| ]) | ||
| function.append(declaration) | ||
| if not function: | ||
| return None | ||
| return [types.Tool(function_declarations=function)] | ||
|
|
||
| def _mapping_response_format_from_openai(self, response_format): | ||
| if response_format.get("type") == "json_object": | ||
| return "application/json" | ||
| return None | ||
|
|
||
| def query(self, chat_history, **kwargs): | ||
| def _adapt_history_for_gemini(self, chat_history): | ||
| """ | ||
| Query the LLM model. | ||
|
|
||
| Adapts mle-agent's chat history format to the one required by the Gemini API, | ||
| separating the system instruction. | ||
|
|
||
| Args: | ||
| chat_history: The context (chat history). | ||
| chat_history (list): The conversation history in the agent's internal format. | ||
|
|
||
| Returns: | ||
| tuple[str, list]: A tuple containing the system instruction and conversation history. | ||
| """ | ||
| parameters = kwargs | ||
| chat_history = self._map_chat_history_from_openai(chat_history) | ||
| system_instruction = "" | ||
| prompt = [] | ||
|
|
||
| for message in chat_history: | ||
| role = message.get("role") | ||
| content = message.get("content") | ||
|
|
||
| if role == "system": | ||
| system_instruction += content + "\n\n" | ||
| elif role == "user" and content: | ||
| prompt.append({'role': 'user', 'parts': [{'text': content}]}) | ||
| elif role == "assistant" and content: | ||
| prompt.append({'role': 'model', 'parts': [{'text': content}]}) | ||
|
|
||
| tools = None | ||
| if parameters.get("functions") is not None: | ||
| tools = self._map_functions_from_openai(parameters["functions"]) | ||
| return system_instruction.strip(), prompt | ||
|
|
||
| def query(self, chat_history, **kwargs): | ||
| """ | ||
| Query the LLM model with robust tool-calling and JSON-forcing logic. | ||
| """ | ||
| MAX_TOOL_TURNS = 10 | ||
| SEARCH_ATTEMPT_LIMIT = 3 | ||
| self.func_call_history = [] | ||
|
|
||
| client = self.gemini.GenerativeModel(self.model) | ||
| chat_handler = client.start_chat(history=chat_history[:-1]) | ||
| system_instruction, prompt = self._adapt_history_for_gemini(chat_history) | ||
| tools = self._create_gemini_tools(kwargs.get("functions", [])) | ||
|
|
||
| completion = chat_handler.send_message( | ||
| chat_history[-1]["parts"], | ||
| base_config = types.GenerateContentConfig( | ||
| tools=tools, | ||
| generation_config=self.gemini.types.GenerationConfig( | ||
| max_output_tokens=4096, | ||
| temperature=self.temperature, | ||
| temperature=self.temperature, | ||
| system_instruction=system_instruction | ||
| ) | ||
| json_only_config = types.GenerateContentConfig( | ||
| temperature=self.temperature, | ||
| response_mime_type="application/json", | ||
| tool_config=types.ToolConfig( | ||
| # Explicitly forbid the model from calling any more tools. | ||
| function_calling_config=types.FunctionCallingConfig(mode='NONE') | ||
| ), | ||
| system_instruction=system_instruction | ||
| ) | ||
|
|
||
| function_outputs = {} | ||
| for part in completion.parts: | ||
| fn = part.function_call | ||
| if fn: | ||
| print("[MLE FUNC CALL]: ", fn.name) | ||
| # avoid the multiple search function calls | ||
| search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS] | ||
| if len(search_attempts) > 3: | ||
| parameters['functions'] = None | ||
| result = get_function(fn.name)(**dict(fn.args)) | ||
| function_outputs[fn.name] = result | ||
|
|
||
| if len(function_outputs): | ||
| response_parts = [ | ||
| self.gemini.protos.Part( | ||
| function_response=self.gemini.protos.FunctionResponse( | ||
| name=fn, response={"result": val} | ||
| ) | ||
| ) | ||
| for fn, val in function_outputs.items() | ||
| ] | ||
|
|
||
| completion = chat_handler.send_message( | ||
| self.gemini.protos.Content(parts=response_parts), | ||
| generation_config=self.gemini.types.GenerationConfig( | ||
| max_output_tokens=4096, | ||
| temperature=self.temperature, | ||
| response_mime_type=self._mapping_response_format_from_openai( | ||
| parameters.get("response_format", {})), | ||
| ), | ||
| final_response_content = None | ||
| json_output_required = False | ||
|
|
||
| for _ in range(MAX_TOOL_TURNS): | ||
| config = json_only_config if json_output_required else base_config | ||
| json_output_required = False | ||
|
|
||
| response = self.client.models.generate_content( | ||
| model=self.model, | ||
| contents=prompt, | ||
| config=config, | ||
| ) | ||
|
|
||
| return completion.text | ||
| # The model can return multiple function calls. For now, we only process the first one, | ||
| # as agents' current logic is sequential. | ||
| # This is a potential future improvement to handle more complex tasks. | ||
| function_call = None | ||
| if response.candidates and response.candidates[0].content.parts: | ||
| for part in response.candidates[0].content.parts: | ||
| if part.function_call: | ||
| function_call = part.function_call | ||
| break | ||
| if function_call: | ||
| function_name = process_function_name(function_call.name) | ||
| args = dict(function_call.args) | ||
|
|
||
| self.func_call_history.append(function_name) | ||
| # Prevent infinite search loops. | ||
| search_attempts = [f for f in self.func_call_history if f in SEARCH_FUNCTIONS] | ||
| if len(search_attempts) > SEARCH_ATTEMPT_LIMIT: | ||
| final_response_content = f"[GEMINI WARNING]: Search function limit of {SEARCH_ATTEMPT_LIMIT} reached." | ||
| print(final_response_content) | ||
| break | ||
|
|
||
| print(f"[GEMINI FUNC CALL]: Calling {function_name} with arguments: {args}") | ||
| function_result = get_function(function_name)(**args) | ||
| function_response_part = types.Part.from_function_response( | ||
| name=function_name, | ||
| response={"result": str(function_result)} | ||
| ) | ||
| prompt.append(response.candidates[0].content) | ||
| prompt.append(types.Content(role='tool', parts=[function_response_part])) | ||
| json_output_required = True | ||
| else: | ||
| final_response_content = response.text | ||
| break | ||
|
|
||
| if final_response_content is None: | ||
| final_response_content = f"[GEMINI WARNING]: Max tool turns of {MAX_TOOL_TURNS} reached." | ||
| print(final_response_content) | ||
|
|
||
| return final_response_content | ||
|
|
||
| def stream(self, chat_history, **kwargs): | ||
| """ | ||
| Stream the output from the LLM model. | ||
| Args: | ||
| chat_history: The context (chat history). | ||
| """ | ||
| client = self.gemini.GenerativeModel(self.model) | ||
| chat_handler = client.start_chat(history=chat_history[:-1]) | ||
| prompt = self._adapt_history_for_gemini(chat_history) | ||
|
|
||
| completions = chat_handler.send_message( | ||
| chat_history[-1]["parts"], | ||
| stream=True, | ||
| generation_config=self.gemini.types.GenerationConfig( | ||
| response_stream = self.client.models.generate_content_stream( | ||
| model=self.model, | ||
| contents=prompt, | ||
| config=types.GenerateContentConfig( | ||
| response_mime_type="application/json", | ||
| max_output_tokens=4096, | ||
| temperature=self.temperature, | ||
| ), | ||
| ) | ||
| ) | ||
|
|
||
| for chunk in completions: | ||
| for chunk in response_stream: | ||
| yield chunk.text |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.