diff --git a/.gitignore b/.gitignore index 8b8235e6..20db02ca 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ util/__pycache__/ index.html?linkid=2289031 wget-log weights/icon_caption_florence_v2/ -omnitool/gradio/uploads/ \ No newline at end of file +omnitool/gradio/uploads/ +**/.DS_Store \ No newline at end of file diff --git a/gradio_demo.py b/gradio_demo.py index 15664d31..e2f9266d 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -8,7 +8,7 @@ import base64, os -from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img +from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img, detect_device import torch from PIL import Image @@ -27,7 +27,7 @@ OmniParser is a screen parsing tool to convert general GUI screen to structured elements. """ -DEVICE = torch.device('cuda') +DEVICE = torch.device(detect_device()) # @spaces.GPU # @torch.inference_mode() diff --git a/omnitool/.DS_Store b/omnitool/.DS_Store new file mode 100644 index 00000000..9ce7173f Binary files /dev/null and b/omnitool/.DS_Store differ diff --git a/omnitool/gradio/.DS_Store b/omnitool/gradio/.DS_Store new file mode 100644 index 00000000..a37f1c2f Binary files /dev/null and b/omnitool/gradio/.DS_Store differ diff --git a/omnitool/gradio/agent/.DS_Store b/omnitool/gradio/agent/.DS_Store new file mode 100644 index 00000000..1ef795a8 Binary files /dev/null and b/omnitool/gradio/agent/.DS_Store differ diff --git a/omnitool/gradio/agent/anthropic_agent.py b/omnitool/gradio/agent/anthropic_agent.py index b1c744e2..55d9b1fa 100644 --- a/omnitool/gradio/agent/anthropic_agent.py +++ b/omnitool/gradio/agent/anthropic_agent.py @@ -39,7 +39,7 @@ class APIProvider(StrEnum): VERTEX = "vertex" SYSTEM_PROMPT = f""" -* You are utilizing a Windows system with internet access. +* You are utilizing a {platform.system()} system with internet access. * The current date is {datetime.today().strftime('%A, %B %d, %Y')}. 
""" @@ -47,6 +47,7 @@ class APIProvider(StrEnum): class AnthropicActor: def __init__( self, + args, model: str, provider: APIProvider, api_key: str, @@ -62,7 +63,7 @@ def __init__( self.max_tokens = max_tokens self.only_n_most_recent_images = only_n_most_recent_images - self.tool_collection = ToolCollection(ComputerTool()) + self.tool_collection = ToolCollection(ComputerTool(args=args)) self.system = SYSTEM_PROMPT diff --git a/omnitool/gradio/agent/llm_utils/geminiclient.py b/omnitool/gradio/agent/llm_utils/geminiclient.py new file mode 100644 index 00000000..d6bd8b51 --- /dev/null +++ b/omnitool/gradio/agent/llm_utils/geminiclient.py @@ -0,0 +1,72 @@ +import os +from google import genai +from google.genai import types +from pydantic import BaseModel, Field +from typing import Optional +from PIL import Image +from pprint import pprint + +from .utils import is_image_path, encode_image + +class Action(BaseModel): + reasoning: str = Field(..., alias="Reasoning") + next_action: str = Field(..., alias="Next Action") + box_id: str | None = Field(None, alias="Box ID") + value: str | None = None + +def run_gemini_interleaved(messages: list, system: str, model_name: str, api_key: str, max_tokens: int, temperature=0): + """ + Run a chat completion through Google Gemini's API + """ + api_key = api_key or os.environ.get("GEMINI_API_KEY") + if not api_key: + raise ValueError("GEMINI_API_KEY is not set") + + client = genai.Client( + api_key=api_key, + ) + + generate_content_config = types.GenerateContentConfig( + temperature=temperature, + max_output_tokens=max_tokens, + response_mime_type="application/json", + response_schema=Action, + system_instruction=[ + types.Part.from_text(text=system), + ], + ) + + contents = [] + + if type(messages) == list: + for item in messages: + if isinstance(item, dict): + for cnt in item["content"]: + if isinstance(cnt, str): + if is_image_path(cnt): + contents.append(Image.open(cnt)) + else: + contents.append(cnt) + else: + contents.append(str(cnt)) + + else: # str + contents.append(str(cnt)) + + elif isinstance(messages, str): + contents.push(messages) + + try: + response = client.models.generate_content( + model=model_name, + contents=contents, + config=generate_content_config + ) + final_answer = response.text + token_usage = response.usage_metadata.total_token_count + + return final_answer, token_usage + except Exception as e: + print(f"Error in interleaved Gemini: {e}") + + return str(e), 0 diff --git a/omnitool/gradio/agent/llm_utils/oaiclient.py b/omnitool/gradio/agent/llm_utils/oaiclient.py index ad421100..768a86e8 100644 --- a/omnitool/gradio/agent/llm_utils/oaiclient.py +++ b/omnitool/gradio/agent/llm_utils/oaiclient.py @@ -1,6 +1,3 @@ -import os -import logging -import base64 import requests from .utils import is_image_path, encode_image diff --git a/omnitool/gradio/agent/llm_utils/omniparserclient.py b/omnitool/gradio/agent/llm_utils/omniparserclient.py index e90ddef8..fc6921aa 100644 --- a/omnitool/gradio/agent/llm_utils/omniparserclient.py +++ b/omnitool/gradio/agent/llm_utils/omniparserclient.py @@ -8,11 +8,13 @@ class OmniParserClient: def __init__(self, + host_device: str, url: str) -> None: + self.host_device = host_device self.url = url def __call__(self,): - screenshot, screenshot_path = get_screenshot() + screenshot, screenshot_path = get_screenshot(host_device=self.host_device) screenshot_path = str(screenshot_path) image_base64 = encode_image(screenshot_path) response = requests.post(self.url, json={"base64_image": image_base64}) diff --git 
a/omnitool/gradio/agent/vlm_agent.py b/omnitool/gradio/agent/vlm_agent.py index 9f631a70..ee6f6d0e 100644 --- a/omnitool/gradio/agent/vlm_agent.py +++ b/omnitool/gradio/agent/vlm_agent.py @@ -5,6 +5,7 @@ from PIL import Image, ImageDraw import base64 from io import BytesIO +import platform from anthropic import APIResponse from anthropic.types import ToolResultBlockParam @@ -12,6 +13,7 @@ from agent.llm_utils.oaiclient import run_oai_interleaved from agent.llm_utils.groqclient import run_groq_interleaved +from agent.llm_utils.geminiclient import run_gemini_interleaved from agent.llm_utils.utils import is_image_path import time import re @@ -49,6 +51,10 @@ def __init__( self.model = "o1" elif model == "omniparser + o3-mini": self.model = "o3-mini" + elif model == "omniparser + gemini-2.0-flash": + self.model = "gemini-2.0-flash" + elif model == "omniparser + gemini-2.5-flash-preview-04-17": + self.model = "gemini-2.5-flash-preview-04-17" else: raise ValueError(f"Model {model} not supported") @@ -133,6 +139,18 @@ def __call__(self, messages: list, parsed_screen: list[str, list, dict]): print(f"qwen token usage: {token_usage}") self.total_token_usage += token_usage self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a + elif "gemini" in self.model: + vlm_response, token_usage = run_gemini_interleaved( + messages=planner_messages, + system=system, + model_name=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + temperature=0, + ) + print(f"gemini token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += 0 # assume using free tier else: raise ValueError(f"Model {self.model} not supported") latency_vlm = time.time() - start @@ -209,9 +227,11 @@ def _api_response_callback(self, response: APIResponse): def _get_system_prompt(self, screen_info: str = ""): main_section = f""" -You are using a Windows device. +You are using a {platform.system()} device. You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot. -You can only interact with the desktop GUI (no terminal or application menu access). +You can only interact with the desktop GUI (no terminal or application menu access) + +!!!DO NOT interact with the chatbot webpage interface that opens in 0.0.0.0:7888. You don't need to click the orange send button because the user already clicked it!!! You may be given some history plan and actions, this is the response from the previous loop. You should carefully consider your plan base on the task, screenshot, and history actions. @@ -230,6 +250,15 @@ def _get_system_prompt(self, screen_info: str = ""): Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task. 
+Use this JSON schema: + +Action = {{ + "Reasoning": str, + "Next Action": str, + "Box ID": str | None, + "value": str | None +}} + Output format: ```json {{ diff --git a/omnitool/gradio/agent/vlm_agent_with_orchestrator.py b/omnitool/gradio/agent/vlm_agent_with_orchestrator.py index 74d554a8..4b5d0275 100644 --- a/omnitool/gradio/agent/vlm_agent_with_orchestrator.py +++ b/omnitool/gradio/agent/vlm_agent_with_orchestrator.py @@ -14,6 +14,7 @@ from agent.llm_utils.oaiclient import run_oai_interleaved from agent.llm_utils.groqclient import run_groq_interleaved +from agent.llm_utils.geminiclient import run_gemini_interleaved from agent.llm_utils.utils import is_image_path import time import re @@ -85,6 +86,10 @@ def __init__( self.model = "o1" elif model == "omniparser + o3-mini" or model == "omniparser + o3-mini-orchestrated": self.model = "o3-mini" + elif model == "omniparser + gemini-2.0-flash" or model == "omniparser + gemini-2.0-flash-orchestrated": + self.model = "gemini-2.0-flash" + elif model == "omniparser + gemini-2.5-flash-preview-04-17" or model == "omniparser + gemini-2.5-flash-preview-04-17-orchestrated": + self.model = "gemini-2.5-flash-preview-04-17" else: raise ValueError(f"Model {model} not supported") @@ -194,6 +199,18 @@ def __call__(self, messages: list, parsed_screen: list[str, list, dict]): print(f"qwen token usage: {token_usage}") self.total_token_usage += token_usage self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a + elif "gemini" in self.model: + vlm_response, token_usage = run_gemini_interleaved( + messages=planner_messages, + system=system, + model_name=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + temperature=0, + ) + print(f"gemini token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += 0 # assume using free tier else: raise ValueError(f"Model {self.model} not supported") latency_vlm = time.time() - start @@ -312,6 +329,15 @@ def _get_system_prompt(self, screen_info: str = ""): Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task. 
+Use this JSON schema: + +Action = {{ + "Reasoning": str, + "Next Action": str, + "Box ID": str | None, + "value": str | None +}} + Output format: ```json {{ @@ -381,7 +407,9 @@ def _initialize_task(self, messages: list): plan_prompt = self._get_plan_prompt(self._task) input_message = copy.deepcopy(messages) input_message.append({"role": "user", "content": plan_prompt}) - vlm_response, token_usage = run_oai_interleaved( + + if "gpt" in self.model or "o1" in self.model or "o3-mini" in self.model: + vlm_response, token_usage = run_oai_interleaved( messages=input_message, system="", model_name=self.model, @@ -390,6 +418,53 @@ def _initialize_task(self, messages: list): provider_base_url="https://api.openai.com/v1", temperature=0, ) + print(f"oai token usage: {token_usage}") + self.total_token_usage += token_usage + if 'gpt' in self.model: + self.total_cost += (token_usage * 2.5 / 1000000) # https://openai.com/api/pricing/ + elif 'o1' in self.model: + self.total_cost += (token_usage * 15 / 1000000) # https://openai.com/api/pricing/ + elif 'o3-mini' in self.model: + self.total_cost += (token_usage * 1.1 / 1000000) # https://openai.com/api/pricing/ + elif "r1" in self.model: + vlm_response, token_usage = run_groq_interleaved( + messages=input_message, + system="", + model_name=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + ) + print(f"groq token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += (token_usage * 0.99 / 1000000) + elif "qwen" in self.model: + vlm_response, token_usage = run_oai_interleaved( + messages=input_message, + system="", + model_name=self.model, + api_key=self.api_key, + max_tokens=min(2048, self.max_tokens), + provider_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + temperature=0, + ) + print(f"qwen token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a + elif "gemini" in self.model: + vlm_response, token_usage = run_gemini_interleaved( + messages=input_message, + system="", + model_name=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + temperature=0, + ) + print(f"gemini token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += 0 # assume using free tier + else: + raise ValueError(f"Model {self.model} not supported") + plan = extract_data(vlm_response, "json") # Create a filename with timestamp @@ -413,7 +488,9 @@ def _update_ledger(self, messages): update_ledger_prompt = ORCHESTRATOR_LEDGER_PROMPT.format(task=self._task) input_message = copy.deepcopy(messages) input_message.append({"role": "user", "content": update_ledger_prompt}) - vlm_response, token_usage = run_oai_interleaved( + + if "gpt" in self.model or "o1" in self.model or "o3-mini" in self.model: + vlm_response, token_usage = run_oai_interleaved( messages=input_message, system="", model_name=self.model, @@ -422,6 +499,53 @@ def _update_ledger(self, messages): provider_base_url="https://api.openai.com/v1", temperature=0, ) + print(f"oai token usage: {token_usage}") + self.total_token_usage += token_usage + if 'gpt' in self.model: + self.total_cost += (token_usage * 2.5 / 1000000) # https://openai.com/api/pricing/ + elif 'o1' in self.model: + self.total_cost += (token_usage * 15 / 1000000) # https://openai.com/api/pricing/ + elif 'o3-mini' in self.model: + self.total_cost += (token_usage * 1.1 / 1000000) # 
https://openai.com/api/pricing/ + elif "r1" in self.model: + vlm_response, token_usage = run_groq_interleaved( + messages=input_message, + system="", + model_name=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + ) + print(f"groq token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += (token_usage * 0.99 / 1000000) + elif "qwen" in self.model: + vlm_response, token_usage = run_oai_interleaved( + messages=input_message, + system="", + model_name=self.model, + api_key=self.api_key, + max_tokens=min(2048, self.max_tokens), + provider_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + temperature=0, + ) + print(f"qwen token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a + elif "gemini" in self.model: + vlm_response, token_usage = run_gemini_interleaved( + messages=input_message, + system="", + model_name=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + temperature=0, + ) + print(f"gemini token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += 0 # assume using free tier + else: + raise ValueError(f"Model {self.model} not supported") + updated_ledger = extract_data(vlm_response, "json") return updated_ledger diff --git a/omnitool/gradio/app.py b/omnitool/gradio/app.py index 54cca8a0..53dc1c0c 100644 --- a/omnitool/gradio/app.py +++ b/omnitool/gradio/app.py @@ -1,5 +1,6 @@ """ -python app.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000 +python app.py --host_device omnibox_windows --windows_host_url localhost:8006 --omniparser_server_url localhost:8000 +python app.py --host_device local --omniparser_server_url localhost:8000 """ import os @@ -27,7 +28,7 @@ API_KEY_FILE = CONFIG_DIR / "api_key" INTRO_TEXT = ''' -OmniParser lets you turn any vision-langauge model into an AI agent. We currently support **OpenAI (4o/o1/o3-mini), DeepSeek (R1), Qwen (2.5VL) or Anthropic Computer Use (Sonnet).** +OmniParser lets you turn any vision-langauge model into an AI agent. We currently support **OpenAI (4o/o1/o3-mini), DeepSeek (R1), Qwen (2.5VL), Gemini (2.0/2.5) or Anthropic Computer Use (Sonnet).** Type a message and press submit to start OmniTool. Press stop to pause, and press the trash icon in the chat to clear the message history. 
''' @@ -35,7 +36,8 @@ def parse_arguments(): parser = argparse.ArgumentParser(description="Gradio App") - parser.add_argument("--windows_host_url", type=str, default='localhost:8006') + parser.add_argument("--host_device", type=str, choices=["omnibox_windows", "local"], default="omnibox_windows") + parser.add_argument("--windows_host_url", type=str, default="localhost:8006") parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000") return parser.parse_args() args = parse_arguments() @@ -189,8 +191,13 @@ def _truncate_string(s, max_length=500): def valid_params(user_input, state): """Validate all requirements and return a list of error messages.""" errors = [] + + servers = [('OmniParser Server', args.omniparser_server_url)] + + if args.host_device == "omnibox_windows": + servers.append(("Windows Host", args.windows_host_url)) - for server_name, url in [('Windows Host', 'localhost:5000'), ('OmniParser Server', args.omniparser_server_url)]: + for server_name, url in servers: try: url = f'http://{url}/probe' response = requests.get(url, timeout=3) @@ -233,6 +240,7 @@ def process_input(user_input, state): # Run sampling_loop_sync with the chatbot_output_callback for loop_msg in sampling_loop_sync( + args=args, model=state["model"], provider=state["provider"], messages=state["messages"], @@ -241,8 +249,7 @@ def process_input(user_input, state): api_response_callback=partial(_api_response_callback, response_state=state["responses"]), api_key=state["api_key"], only_n_most_recent_images=state["only_n_most_recent_images"], - max_tokens=16384, - omniparser_url=args.omniparser_server_url + max_tokens=16384 ): if loop_msg is None or state.get("stop"): yield state['chatbot_messages'] @@ -302,7 +309,7 @@ def get_header_image_base64(): with gr.Column(): model = gr.Dropdown( label="Model", - choices=["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "claude-3-5-sonnet-20241022", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated"], + choices=["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "claude-3-5-sonnet-20241022", "omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"], value="omniparser + gpt-4o", interactive=True, ) @@ -343,12 +350,13 @@ def get_header_image_base64(): with gr.Row(): with gr.Column(scale=2): chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True, height=580) - with gr.Column(scale=3): - iframe = gr.HTML( - f'', - container=False, - elem_classes="no-padding" - ) + if args.host_device == "omnibox_windows": + with gr.Column(scale=3): + iframe = gr.HTML( + f'', + container=False, + elem_classes="no-padding" + ) def update_model(model_selection, state): state["model"] = model_selection @@ -362,6 +370,8 @@ def update_model(model_selection, state): provider_choices = ["groq"] elif model_selection == "omniparser + qwen2.5vl": provider_choices = ["dashscope"] + elif model_selection in set(["omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17", "omniparser + gemini-2.0-flash-orchestrated", 
"omniparser + gemini-2.5-flash-preview-04-17-orchestrated"]): + provider_choices = ["gemini"] else: provider_choices = [option.value for option in APIProvider] default_provider_value = provider_choices[0] diff --git a/omnitool/gradio/app_new.py b/omnitool/gradio/app_new.py index d67ae185..4bb80b16 100644 --- a/omnitool/gradio/app_new.py +++ b/omnitool/gradio/app_new.py @@ -3,6 +3,7 @@ - a new UI for the OmniParser AI Agent. - python app_new.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000 +python app_new.py --host_device local --omniparser_server_url localhost:8000 """ import os @@ -43,6 +44,7 @@ def parse_arguments(): parser = argparse.ArgumentParser(description="Gradio App") + parser.add_argument("--host_device", type=str, choices=["omnibox_windows", "local"], default="omnibox_windows") parser.add_argument("--windows_host_url", type=str, default='localhost:8006') parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000") parser.add_argument("--run_folder", type=str, default="./tmp/outputs") @@ -222,8 +224,13 @@ def _truncate_string(s, max_length=500): def valid_params(user_input, state): """Validate all requirements and return a list of error messages.""" errors = [] + + servers = [('OmniParser Server', args.omniparser_server_url)] + + if args.host_device == "omnibox_windows": + servers.append(("Windows Host", args.windows_host_url)) - for server_name, url in [('Windows Host', 'localhost:5000'), ('OmniParser Server', args.omniparser_server_url)]: + for server_name, url in servers: try: url = f'http://{url}/probe' response = requests.get(url, timeout=3) @@ -266,6 +273,7 @@ def process_input(user_input, state): # Run sampling_loop_sync with the chatbot_output_callback for loop_msg in sampling_loop_sync( + args=args, model=state["model"], provider=state["provider"], messages=state["messages"], @@ -275,7 +283,6 @@ def process_input(user_input, state): api_key=state["api_key"], only_n_most_recent_images=state["only_n_most_recent_images"], max_tokens=16384, - omniparser_url=args.omniparser_server_url, save_folder=str(RUN_FOLDER) ): if loop_msg is None or state.get("stop"): diff --git a/omnitool/gradio/executor/anthropic_executor.py b/omnitool/gradio/executor/anthropic_executor.py index f5c1a77f..7014b9fb 100644 --- a/omnitool/gradio/executor/anthropic_executor.py +++ b/omnitool/gradio/executor/anthropic_executor.py @@ -18,11 +18,12 @@ class AnthropicExecutor: def __init__( self, + args, output_callback: Callable[[BetaContentBlockParam], None], tool_output_callback: Callable[[Any, str], None], ): self.tool_collection = ToolCollection( - ComputerTool() + ComputerTool(args=args) ) self.output_callback = output_callback self.tool_output_callback = tool_output_callback diff --git a/omnitool/gradio/loop.py b/omnitool/gradio/loop.py index 9ce63169..6985c928 100644 --- a/omnitool/gradio/loop.py +++ b/omnitool/gradio/loop.py @@ -28,6 +28,7 @@ class APIProvider(StrEnum): BEDROCK = "bedrock" VERTEX = "vertex" OPENAI = "openai" + GEMINI = "gemini" PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = { @@ -35,10 +36,12 @@ class APIProvider(StrEnum): APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0", APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022", APIProvider.OPENAI: "gpt-4o", + APIProvider.GEMINI: "gemini-2.0-flash" } def sampling_loop_sync( *, + args, model: str, provider: APIProvider | None, messages: list[BetaMessageParam], @@ -48,17 +51,17 @@ def sampling_loop_sync( api_key: str, only_n_most_recent_images: int | None = 2, 
max_tokens: int = 4096, - omniparser_url: str, save_folder: str = "./uploads" ): """ Synchronous agentic sampling loop for the assistant/tool interaction of computer use. """ print('in sampling_loop_sync, model:', model) - omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/") + omniparser_client = OmniParserClient(host_device=args.host_device, url=f"http://{args.omniparser_server_url}/parse/") if model == "claude-3-5-sonnet-20241022": # Register Actor and Executor actor = AnthropicActor( + args=args, model=model, provider=provider, api_key=api_key, @@ -66,7 +69,7 @@ def sampling_loop_sync( max_tokens=max_tokens, only_n_most_recent_images=only_n_most_recent_images ) - elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]): + elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17"]): actor = VLMAgent( model=model, provider=provider, @@ -76,7 +79,7 @@ def sampling_loop_sync( max_tokens=max_tokens, only_n_most_recent_images=only_n_most_recent_images ) - elif model in set(["omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated"]): + elif model in set(["omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"]): actor = VLMOrchestratedAgent( model=model, provider=provider, @@ -90,6 +93,7 @@ def sampling_loop_sync( else: raise ValueError(f"Model {model} not supported") executor = AnthropicExecutor( + args=args, output_callback=output_callback, tool_output_callback=tool_output_callback, ) @@ -115,7 +119,7 @@ def sampling_loop_sync( messages.append({"content": tool_result_content, "role": "user"}) - elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated"]): + elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"]): while True: parsed_screen = omniparser_client() tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen) diff --git a/omnitool/gradio/tools/.DS_Store b/omnitool/gradio/tools/.DS_Store new file mode 100644 index 00000000..c47b0595 Binary files /dev/null and b/omnitool/gradio/tools/.DS_Store differ diff --git a/omnitool/gradio/tools/computer.py b/omnitool/gradio/tools/computer.py index 6b91bad2..f54f0b11 100644 --- a/omnitool/gradio/tools/computer.py +++ b/omnitool/gradio/tools/computer.py @@ -2,6 +2,10 @@ import time from enum import StrEnum from typing import Literal, TypedDict +import threading
+import shlex +import os +import subprocess from PIL import Image @@ -12,6 +16,8 @@ import requests import re +computer_control_lock = threading.Lock() + OUTPUT_DIR = "./tmp/outputs" TYPING_DELAY_MS = 12 @@ -88,9 +94,11 @@ def options(self) -> ComputerToolOptions: def to_params(self) -> BetaToolComputerUse20241022Param: return {"name": self.name, "type": self.api_type, **self.options} - def __init__(self, is_scaling: bool = False): + def __init__(self, args, is_scaling: bool = False): super().__init__() + self.args = args + # Get screen width and height using Windows command self.display_num = None self.offset_x = 0 @@ -141,11 +149,11 @@ async def __call__( print(f"mouse move to {x}, {y}") if action == "mouse_move": - self.send_to_vm(f"pyautogui.moveTo({x}, {y})") + self.send_to_host_device(f"pyautogui.moveTo({x}, {y})") return ToolResult(output=f"Moved mouse to ({x}, {y})") elif action == "left_click_drag": - current_x, current_y = self.send_to_vm("pyautogui.position()") - self.send_to_vm(f"pyautogui.dragTo({x}, {y}, duration=0.5)") + current_x, current_y = self.send_to_host_device("pyautogui.position()") + self.send_to_host_device(f"pyautogui.dragTo({x}, {y}, duration=0.5)") return ToolResult(output=f"Dragged mouse from ({current_x}, {current_y}) to ({x}, {y})") if action in ("key", "type"): @@ -162,18 +170,18 @@ async def __call__( for key in keys: key = self.key_conversion.get(key.strip(), key.strip()) key = key.lower() - self.send_to_vm(f"pyautogui.keyDown('{key}')") # Press down each key + self.send_to_host_device(f"pyautogui.keyDown('{key}')") # Press down each key for key in reversed(keys): key = self.key_conversion.get(key.strip(), key.strip()) key = key.lower() - self.send_to_vm(f"pyautogui.keyUp('{key}')") # Release each key in reverse order + self.send_to_host_device(f"pyautogui.keyUp('{key}')") # Release each key in reverse order return ToolResult(output=f"Pressed keys: {text}") elif action == "type": # default click before type TODO: check if this is needed - self.send_to_vm("pyautogui.click()") - self.send_to_vm(f"pyautogui.typewrite('{text}', interval={TYPING_DELAY_MS / 1000})") - self.send_to_vm("pyautogui.press('enter')") + self.send_to_host_device("pyautogui.click()") + self.send_to_host_device(f"pyautogui.typewrite('{text}', interval={TYPING_DELAY_MS / 1000})") + self.send_to_host_device("pyautogui.press('enter')") screenshot_base64 = (await self.screenshot()).base64_image return ToolResult(output=text, base64_image=screenshot_base64) @@ -194,28 +202,28 @@ async def __call__( if action == "screenshot": return await self.screenshot() elif action == "cursor_position": - x, y = self.send_to_vm("pyautogui.position()") + x, y = self.send_to_host_device("pyautogui.position()") x, y = self.scale_coordinates(ScalingSource.COMPUTER, x, y) return ToolResult(output=f"X={x},Y={y}") else: if action == "left_click": - self.send_to_vm("pyautogui.click()") + self.send_to_host_device("pyautogui.click()") elif action == "right_click": - self.send_to_vm("pyautogui.rightClick()") + self.send_to_host_device("pyautogui.rightClick()") elif action == "middle_click": - self.send_to_vm("pyautogui.middleClick()") + self.send_to_host_device("pyautogui.middleClick()") elif action == "double_click": - self.send_to_vm("pyautogui.doubleClick()") + self.send_to_host_device("pyautogui.doubleClick()") elif action == "left_press": - self.send_to_vm("pyautogui.mouseDown()") + self.send_to_host_device("pyautogui.mouseDown()") time.sleep(1) - self.send_to_vm("pyautogui.mouseUp()") + 
self.send_to_host_device("pyautogui.mouseUp()") return ToolResult(output=f"Performed {action}") if action in ("scroll_up", "scroll_down"): if action == "scroll_up": - self.send_to_vm("pyautogui.scroll(100)") + self.send_to_host_device("pyautogui.scroll(100)") elif action == "scroll_down": - self.send_to_vm("pyautogui.scroll(-100)") + self.send_to_host_device("pyautogui.scroll(-100)") return ToolResult(output=f"Performed {action}") if action == "hover": return ToolResult(output=f"Performed {action}") @@ -224,9 +232,9 @@ async def __call__( return ToolResult(output=f"Performed {action}") raise ToolError(f"Invalid action: {action}") - def send_to_vm(self, action: str): + def send_to_host_device(self, action: str): """ - Executes a python command on the server. Only return tuple of x,y when action is "pyautogui.position()" + Executes a python command on the host device. Only return tuple of x,y when action is "pyautogui.position()" """ prefix = "import pyautogui; pyautogui.FAILSAFE = False;" command_list = ["python", "-c", f"{prefix} {action}"] @@ -236,18 +244,25 @@ def send_to_vm(self, action: str): try: print(f"sending to vm: {command_list}") - response = requests.post( - f"http://localhost:5000/execute", - headers={'Content-Type': 'application/json'}, - json={"command": command_list}, - timeout=90 - ) + + if self.args.host_device == "omnibox_windows": + response = requests.post( + f"http://localhost:5000/execute", + headers={'Content-Type': 'application/json'}, + json={"command": command_list}, + timeout=90 + ) + if response.status_code != 200: + raise ToolError(f"Failed to execute command. Status code: {response.status_code}") + output = response.json()['output'].strip() + elif self.args.host_device == "local": + response = self.execute(command_list) + output = response['output'].strip() + time.sleep(0.7) # avoid async error as actions take time to complete print(f"action executed") - if response.status_code != 200: - raise ToolError(f"Failed to execute command. Status code: {response.status_code}") + if parse: - output = response.json()['output'].strip() match = re.search(r'Point\(x=(\d+),\s*y=(\d+)\)', output) if not match: raise ToolError(f"Could not parse coordinates from output: {output}") @@ -255,13 +270,38 @@ def send_to_vm(self, action: str): return x, y except requests.exceptions.RequestException as e: raise ToolError(f"An error occurred while trying to execute the command: {str(e)}") + + def execute(self, command, shell=False): + with computer_control_lock: + if isinstance(command, str) and not shell: + command = shlex.split(command) + + # Expand user directory + for i, arg in enumerate(command): + if arg.startswith("~/"): + command[i] = os.path.expanduser(arg) + + # Execute the command without any safety checks. 
+ try: + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True, timeout=120) + return { + 'status': 'success', + 'output': result.stdout, + 'error': result.stderr, + 'returncode': result.returncode + } + except Exception as e: + return { + 'status': 'error', + 'message': str(e) + } async def screenshot(self): if not hasattr(self, 'target_dimension'): screenshot = self.padding_image(screenshot) self.target_dimension = MAX_SCALING_TARGETS["WXGA"] width, height = self.target_dimension["width"], self.target_dimension["height"] - screenshot, path = get_screenshot(resize=True, target_width=width, target_height=height) + screenshot, path = get_screenshot(host_device=self.args.host_device, resize=True, target_width=width, target_height=height) time.sleep(0.7) # avoid async error as actions take time to complete return ToolResult(base64_image=base64.b64encode(path.read_bytes()).decode()) @@ -310,16 +350,21 @@ def scale_coordinates(self, source: ScalingSource, x: int, y: int): def get_screen_size(self): """Return width and height of the screen""" try: - response = requests.post( - f"http://localhost:5000/execute", - headers={'Content-Type': 'application/json'}, - json={"command": ["python", "-c", "import pyautogui; print(pyautogui.size())"]}, - timeout=90 - ) - if response.status_code != 200: - raise ToolError(f"Failed to get screen size. Status code: {response.status_code}") + if self.args.host_device == "omnibox_windows": + response = requests.post( + f"http://localhost:5000/execute", + headers={'Content-Type': 'application/json'}, + json={"command": ["python", "-c", "import pyautogui; print(pyautogui.size())"]}, + timeout=90 + ) + + if response.status_code != 200: + raise ToolError(f"Failed to get screen size. 
Status code: {response.status_code}") + output = response.json()['output'].strip() + elif self.args.host_device == "local": + response = self.execute(["python", "-c", "import pyautogui; print(pyautogui.size())"]) + output = response['output'].strip() - output = response.json()['output'].strip() match = re.search(r'Size\(width=(\d+),\s*height=(\d+)\)', output) if not match: raise ToolError(f"Could not parse screen size from output: {output}") diff --git a/omnitool/gradio/tools/cursor.png b/omnitool/gradio/tools/cursor.png new file mode 100644 index 00000000..d3a3c5bb Binary files /dev/null and b/omnitool/gradio/tools/cursor.png differ diff --git a/omnitool/gradio/tools/screen_capture.py b/omnitool/gradio/tools/screen_capture.py index 1c1ad04a..b3b25daa 100644 --- a/omnitool/gradio/tools/screen_capture.py +++ b/omnitool/gradio/tools/screen_capture.py @@ -1,28 +1,42 @@ +import os from pathlib import Path from uuid import uuid4 import requests from PIL import Image from .base import BaseAnthropicTool, ToolError from io import BytesIO +import pyautogui OUTPUT_DIR = "./tmp/outputs" -def get_screenshot(resize: bool = False, target_width: int = 1920, target_height: int = 1080): +def get_screenshot(host_device: str, resize = False, target_width: int = 1920, target_height: int = 1080): """Capture screenshot by requesting from HTTP endpoint - returns native resolution unless resized""" output_dir = Path(OUTPUT_DIR) output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"screenshot_{uuid4().hex}.png" try: - response = requests.get('http://localhost:5000/screenshot') - if response.status_code != 200: - raise ToolError(f"Failed to capture screenshot: HTTP {response.status_code}") - - # (1280, 800) - screenshot = Image.open(BytesIO(response.content)) - - if resize and screenshot.size != (target_width, target_height): - screenshot = screenshot.resize((target_width, target_height)) + if host_device == "omnibox_windows": + response = requests.get('http://localhost:5000/screenshot') + if response.status_code != 200: + raise ToolError(f"Failed to capture screenshot: HTTP {response.status_code}") + # (1280, 800) + screenshot = Image.open(BytesIO(response.content)) + if resize and screenshot.size != (target_width, target_height): + screenshot = screenshot.resize((target_width, target_height)) + elif host_device == "local": + screenshot = pyautogui.screenshot() + size = pyautogui.size() + + screenshot = screenshot.resize((size.width, size.height)) + + cursor_path = os.path.join(os.path.dirname(__file__), "cursor.png") + cursor_x, cursor_y = pyautogui.position() + cursor = Image.open(cursor_path) + # make the cursor smaller + cursor = cursor.resize((int(cursor.width / 1.5), int(cursor.height / 1.5))) + screenshot.paste(cursor, (cursor_x, cursor_y), cursor) + screenshot.save(path) return screenshot, path except Exception as e: diff --git a/omnitool/omniparserserver/omniparserserver.py b/omnitool/omniparserserver/omniparserserver.py index 49fb306f..cfa50c7d 100644 --- a/omnitool/omniparserserver/omniparserserver.py +++ b/omnitool/omniparserserver/omniparserserver.py @@ -12,13 +12,14 @@ root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(root_dir) from util.omniparser import Omniparser +from util.utils import detect_device def parse_arguments(): parser = argparse.ArgumentParser(description='Omniparser API') parser.add_argument('--som_model_path', type=str, default='../../weights/icon_detect/model.pt', help='Path to the som model') 
parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model') parser.add_argument('--caption_model_path', type=str, default='../../weights/icon_caption_florence', help='Path to the caption model') - parser.add_argument('--device', type=str, default='cpu', help='Device to run the model') + parser.add_argument('--device', type=str, default=detect_device(), help='Device to run the model') parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API') parser.add_argument('--port', type=int, default=8000, help='Port for the API') diff --git a/requirements.txt b/requirements.txt index 901a27fa..b58b5423 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,5 @@ google-auth<3,>=2 screeninfo uiautomation dashscope -groq \ No newline at end of file +groq +google-genai \ No newline at end of file diff --git a/util/omniparser.py b/util/omniparser.py index 536385e6..5610d5d2 100644 --- a/util/omniparser.py +++ b/util/omniparser.py @@ -1,4 +1,4 @@ -from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box +from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box, detect_device import torch from PIL import Image import io @@ -7,7 +7,7 @@ class Omniparser(object): def __init__(self, config: Dict): self.config = config - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = detect_device() self.som_model = get_yolo_model(model_path=config['som_model_path']) self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device) diff --git a/util/utils.py b/util/utils.py index eb7c8b25..7e91258c 100644 --- a/util/utils.py +++ b/util/utils.py @@ -43,10 +43,19 @@ import torchvision.transforms as T from util.box_annotator import BoxAnnotator +def detect_device() -> str: + if torch.cuda.is_available(): + print("[+] Using CUDA") + return "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + print("[+] Using MPS for Apple Silicon") + return "mps" + else: + return "cpu" def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None): if not device: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = detect_device() if model_name == "blip2": from transformers import Blip2Processor, Blip2ForConditionalGeneration processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") @@ -107,7 +116,7 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_ start = time.time() batch = croped_pil_image[i:i+batch_size] t1 = time.time() - if model.device.type == 'cuda': + if model.device.type == 'cuda' or model.device.type == 'mps': inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16) else: inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
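For reference, a minimal usage sketch of the `run_gemini_interleaved` helper introduced in this patch. The task text and screenshot path are illustrative placeholders; it assumes `GEMINI_API_KEY` is set, the working directory is `omnitool/gradio` (so `agent.llm_utils` is importable), and the referenced screenshot file exists.

```python
# Minimal sketch: calling the new Gemini client directly, outside the agent loop.
# The task string and screenshot path below are hypothetical placeholders.
from agent.llm_utils.geminiclient import run_gemini_interleaved

messages = [
    {"role": "user", "content": [
        "Task: open the Settings app.",          # plain text is passed through as text
        "./tmp/outputs/screenshot_example.png",  # image paths are loaded with PIL and sent as images
    ]},
]

text, tokens = run_gemini_interleaved(
    messages=messages,
    system="You are a GUI agent. Respond using the JSON Action schema.",
    model_name="gemini-2.0-flash",
    api_key="",  # empty string falls back to the GEMINI_API_KEY environment variable
    max_tokens=4096,
)
print(text)
print(f"gemini token usage: {tokens}")
```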