diff --git a/.gitignore b/.gitignore
index 8b8235e6..20db02ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ util/__pycache__/
index.html?linkid=2289031
wget-log
weights/icon_caption_florence_v2/
-omnitool/gradio/uploads/
\ No newline at end of file
+omnitool/gradio/uploads/
+**/.DS_Store
\ No newline at end of file
diff --git a/gradio_demo.py b/gradio_demo.py
index 15664d31..e2f9266d 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -8,7 +8,7 @@
import base64, os
-from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img, detect_device
import torch
from PIL import Image
@@ -27,7 +27,7 @@
OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
"""
-DEVICE = torch.device('cuda')
+DEVICE = torch.device(detect_device())
# @spaces.GPU
# @torch.inference_mode()
diff --git a/omnitool/.DS_Store b/omnitool/.DS_Store
new file mode 100644
index 00000000..9ce7173f
Binary files /dev/null and b/omnitool/.DS_Store differ
diff --git a/omnitool/gradio/.DS_Store b/omnitool/gradio/.DS_Store
new file mode 100644
index 00000000..a37f1c2f
Binary files /dev/null and b/omnitool/gradio/.DS_Store differ
diff --git a/omnitool/gradio/agent/.DS_Store b/omnitool/gradio/agent/.DS_Store
new file mode 100644
index 00000000..1ef795a8
Binary files /dev/null and b/omnitool/gradio/agent/.DS_Store differ
diff --git a/omnitool/gradio/agent/anthropic_agent.py b/omnitool/gradio/agent/anthropic_agent.py
index b1c744e2..55d9b1fa 100644
--- a/omnitool/gradio/agent/anthropic_agent.py
+++ b/omnitool/gradio/agent/anthropic_agent.py
@@ -39,7 +39,7 @@ class APIProvider(StrEnum):
VERTEX = "vertex"
SYSTEM_PROMPT = f"""
-* You are utilizing a Windows system with internet access.
+* You are utilizing a {platform.system()} system with internet access.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
"""
@@ -47,6 +47,7 @@ class APIProvider(StrEnum):
class AnthropicActor:
def __init__(
self,
+ args,
model: str,
provider: APIProvider,
api_key: str,
@@ -62,7 +63,7 @@ def __init__(
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images
- self.tool_collection = ToolCollection(ComputerTool())
+ self.tool_collection = ToolCollection(ComputerTool(args=args))
self.system = SYSTEM_PROMPT
diff --git a/omnitool/gradio/agent/llm_utils/geminiclient.py b/omnitool/gradio/agent/llm_utils/geminiclient.py
new file mode 100644
index 00000000..d6bd8b51
--- /dev/null
+++ b/omnitool/gradio/agent/llm_utils/geminiclient.py
@@ -0,0 +1,72 @@
+import os
+from google import genai
+from google.genai import types
+from pydantic import BaseModel, Field
+from typing import Optional
+from PIL import Image
+from pprint import pprint
+
+from .utils import is_image_path, encode_image
+
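+# Structured-output schema for the planner reply; aliases mirror the "Reasoning" / "Next Action" / "Box ID" keys expected by the system prompt.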
+class Action(BaseModel):
+ reasoning: str = Field(..., alias="Reasoning")
+ next_action: str = Field(..., alias="Next Action")
+ box_id: str | None = Field(None, alias="Box ID")
+ value: str | None = None
+
+def run_gemini_interleaved(messages: list, system: str, model_name: str, api_key: str, max_tokens: int, temperature=0):
+ """
+ Run a chat completion through Google Gemini's API
+ """
+ api_key = api_key or os.environ.get("GEMINI_API_KEY")
+ if not api_key:
+ raise ValueError("GEMINI_API_KEY is not set")
+
+ client = genai.Client(
+ api_key=api_key,
+ )
+
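+    # Ask Gemini for JSON constrained to the Action schema so the planner output can be parsed without extra cleanup.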
+ generate_content_config = types.GenerateContentConfig(
+ temperature=temperature,
+ max_output_tokens=max_tokens,
+ response_mime_type="application/json",
+ response_schema=Action,
+ system_instruction=[
+ types.Part.from_text(text=system),
+ ],
+ )
+
+ contents = []
+
+    if isinstance(messages, list):
+        for item in messages:
+            if isinstance(item, dict):
+                for cnt in item["content"]:
+                    if isinstance(cnt, str):
+                        if is_image_path(cnt):
+                            contents.append(Image.open(cnt))
+                        else:
+                            contents.append(cnt)
+                    else:
+                        contents.append(str(cnt))
+
+            else:  # item is a plain string message
+                contents.append(str(item))
+
+    elif isinstance(messages, str):
+        contents.append(messages)
+
+ try:
+ response = client.models.generate_content(
+ model=model_name,
+ contents=contents,
+ config=generate_content_config
+ )
+ final_answer = response.text
+ token_usage = response.usage_metadata.total_token_count
+
+ return final_answer, token_usage
+ except Exception as e:
+ print(f"Error in interleaved Gemini: {e}")
+
+ return str(e), 0
diff --git a/omnitool/gradio/agent/llm_utils/oaiclient.py b/omnitool/gradio/agent/llm_utils/oaiclient.py
index ad421100..768a86e8 100644
--- a/omnitool/gradio/agent/llm_utils/oaiclient.py
+++ b/omnitool/gradio/agent/llm_utils/oaiclient.py
@@ -1,6 +1,3 @@
-import os
-import logging
-import base64
import requests
from .utils import is_image_path, encode_image
diff --git a/omnitool/gradio/agent/llm_utils/omniparserclient.py b/omnitool/gradio/agent/llm_utils/omniparserclient.py
index e90ddef8..fc6921aa 100644
--- a/omnitool/gradio/agent/llm_utils/omniparserclient.py
+++ b/omnitool/gradio/agent/llm_utils/omniparserclient.py
@@ -8,11 +8,13 @@
class OmniParserClient:
def __init__(self,
+ host_device: str,
url: str) -> None:
+ self.host_device = host_device
self.url = url
def __call__(self,):
- screenshot, screenshot_path = get_screenshot()
+ screenshot, screenshot_path = get_screenshot(host_device=self.host_device)
screenshot_path = str(screenshot_path)
image_base64 = encode_image(screenshot_path)
response = requests.post(self.url, json={"base64_image": image_base64})
diff --git a/omnitool/gradio/agent/vlm_agent.py b/omnitool/gradio/agent/vlm_agent.py
index 9f631a70..ee6f6d0e 100644
--- a/omnitool/gradio/agent/vlm_agent.py
+++ b/omnitool/gradio/agent/vlm_agent.py
@@ -5,6 +5,7 @@
from PIL import Image, ImageDraw
import base64
from io import BytesIO
+import platform
from anthropic import APIResponse
from anthropic.types import ToolResultBlockParam
@@ -12,6 +13,7 @@
from agent.llm_utils.oaiclient import run_oai_interleaved
from agent.llm_utils.groqclient import run_groq_interleaved
+from agent.llm_utils.geminiclient import run_gemini_interleaved
from agent.llm_utils.utils import is_image_path
import time
import re
@@ -49,6 +51,10 @@ def __init__(
self.model = "o1"
elif model == "omniparser + o3-mini":
self.model = "o3-mini"
+ elif model == "omniparser + gemini-2.0-flash":
+ self.model = "gemini-2.0-flash"
+ elif model == "omniparser + gemini-2.5-flash-preview-04-17":
+ self.model = "gemini-2.5-flash-preview-04-17"
else:
raise ValueError(f"Model {model} not supported")
@@ -133,6 +139,18 @@ def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
print(f"qwen token usage: {token_usage}")
self.total_token_usage += token_usage
self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a
+ elif "gemini" in self.model:
+ vlm_response, token_usage = run_gemini_interleaved(
+ messages=planner_messages,
+ system=system,
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ temperature=0,
+ )
+ print(f"gemini token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += 0 # assume using free tier
else:
raise ValueError(f"Model {self.model} not supported")
latency_vlm = time.time() - start
@@ -209,9 +227,11 @@ def _api_response_callback(self, response: APIResponse):
def _get_system_prompt(self, screen_info: str = ""):
main_section = f"""
-You are using a Windows device.
+You are using a {platform.system()} device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
-You can only interact with the desktop GUI (no terminal or application menu access).
+You can only interact with the desktop GUI (no terminal or application menu access).
+
+!!!DO NOT interact with the chatbot webpage interface that opens at 0.0.0.0:7888. You do not need to click the orange send button because the user has already clicked it!!!
You may be given some history plan and actions, this is the response from the previous loop.
You should carefully consider your plan base on the task, screenshot, and history actions.
@@ -230,6 +250,15 @@ def _get_system_prompt(self, screen_info: str = ""):
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
+Use this JSON schema:
+
+Action = {{
+ "Reasoning": str,
+ "Next Action": str,
+ "Box ID": str | None,
+ "value": str | None
+}}
+
Output format:
```json
{{
diff --git a/omnitool/gradio/agent/vlm_agent_with_orchestrator.py b/omnitool/gradio/agent/vlm_agent_with_orchestrator.py
index 74d554a8..4b5d0275 100644
--- a/omnitool/gradio/agent/vlm_agent_with_orchestrator.py
+++ b/omnitool/gradio/agent/vlm_agent_with_orchestrator.py
@@ -14,6 +14,7 @@
from agent.llm_utils.oaiclient import run_oai_interleaved
from agent.llm_utils.groqclient import run_groq_interleaved
+from agent.llm_utils.geminiclient import run_gemini_interleaved
from agent.llm_utils.utils import is_image_path
import time
import re
@@ -85,6 +86,10 @@ def __init__(
self.model = "o1"
elif model == "omniparser + o3-mini" or model == "omniparser + o3-mini-orchestrated":
self.model = "o3-mini"
+ elif model == "omniparser + gemini-2.0-flash" or model == "omniparser + gemini-2.0-flash-orchestrated":
+ self.model = "gemini-2.0-flash"
+ elif model == "omniparser + gemini-2.5-flash-preview-04-17" or model == "omniparser + gemini-2.5-flash-preview-04-17-orchestrated":
+ self.model = "gemini-2.5-flash-preview-04-17"
else:
raise ValueError(f"Model {model} not supported")
@@ -194,6 +199,18 @@ def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
print(f"qwen token usage: {token_usage}")
self.total_token_usage += token_usage
self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a
+ elif "gemini" in self.model:
+ vlm_response, token_usage = run_gemini_interleaved(
+ messages=planner_messages,
+ system=system,
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ temperature=0,
+ )
+ print(f"gemini token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += 0 # assume using free tier
else:
raise ValueError(f"Model {self.model} not supported")
latency_vlm = time.time() - start
@@ -312,6 +329,15 @@ def _get_system_prompt(self, screen_info: str = ""):
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
+Use this JSON schema:
+
+Action = {{
+ "Reasoning": str,
+ "Next Action": str,
+ "Box ID": str | None,
+ "value": str | None
+}}
+
Output format:
```json
{{
@@ -381,7 +407,9 @@ def _initialize_task(self, messages: list):
plan_prompt = self._get_plan_prompt(self._task)
input_message = copy.deepcopy(messages)
input_message.append({"role": "user", "content": plan_prompt})
- vlm_response, token_usage = run_oai_interleaved(
+
+ if "gpt" in self.model or "o1" in self.model or "o3-mini" in self.model:
+ vlm_response, token_usage = run_oai_interleaved(
messages=input_message,
system="",
model_name=self.model,
@@ -390,6 +418,53 @@ def _initialize_task(self, messages: list):
provider_base_url="https://api.openai.com/v1",
temperature=0,
)
+ print(f"oai token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ if 'gpt' in self.model:
+ self.total_cost += (token_usage * 2.5 / 1000000) # https://openai.com/api/pricing/
+ elif 'o1' in self.model:
+ self.total_cost += (token_usage * 15 / 1000000) # https://openai.com/api/pricing/
+ elif 'o3-mini' in self.model:
+ self.total_cost += (token_usage * 1.1 / 1000000) # https://openai.com/api/pricing/
+ elif "r1" in self.model:
+ vlm_response, token_usage = run_groq_interleaved(
+ messages=input_message,
+ system="",
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ )
+ print(f"groq token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += (token_usage * 0.99 / 1000000)
+ elif "qwen" in self.model:
+ vlm_response, token_usage = run_oai_interleaved(
+ messages=input_message,
+ system="",
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=min(2048, self.max_tokens),
+ provider_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+ temperature=0,
+ )
+ print(f"qwen token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a
+ elif "gemini" in self.model:
+ vlm_response, token_usage = run_gemini_interleaved(
+ messages=input_message,
+ system="",
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ temperature=0,
+ )
+ print(f"gemini token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += 0 # assume using free tier
+ else:
+ raise ValueError(f"Model {self.model} not supported")
+
plan = extract_data(vlm_response, "json")
# Create a filename with timestamp
@@ -413,7 +488,9 @@ def _update_ledger(self, messages):
update_ledger_prompt = ORCHESTRATOR_LEDGER_PROMPT.format(task=self._task)
input_message = copy.deepcopy(messages)
input_message.append({"role": "user", "content": update_ledger_prompt})
- vlm_response, token_usage = run_oai_interleaved(
+
+ if "gpt" in self.model or "o1" in self.model or "o3-mini" in self.model:
+ vlm_response, token_usage = run_oai_interleaved(
messages=input_message,
system="",
model_name=self.model,
@@ -422,6 +499,53 @@ def _update_ledger(self, messages):
provider_base_url="https://api.openai.com/v1",
temperature=0,
)
+ print(f"oai token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ if 'gpt' in self.model:
+ self.total_cost += (token_usage * 2.5 / 1000000) # https://openai.com/api/pricing/
+ elif 'o1' in self.model:
+ self.total_cost += (token_usage * 15 / 1000000) # https://openai.com/api/pricing/
+ elif 'o3-mini' in self.model:
+ self.total_cost += (token_usage * 1.1 / 1000000) # https://openai.com/api/pricing/
+ elif "r1" in self.model:
+ vlm_response, token_usage = run_groq_interleaved(
+ messages=input_message,
+ system="",
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ )
+ print(f"groq token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += (token_usage * 0.99 / 1000000)
+ elif "qwen" in self.model:
+ vlm_response, token_usage = run_oai_interleaved(
+ messages=input_message,
+ system="",
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=min(2048, self.max_tokens),
+ provider_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+ temperature=0,
+ )
+ print(f"qwen token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a
+ elif "gemini" in self.model:
+ vlm_response, token_usage = run_gemini_interleaved(
+ messages=input_message,
+ system="",
+ model_name=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ temperature=0,
+ )
+ print(f"gemini token usage: {token_usage}")
+ self.total_token_usage += token_usage
+ self.total_cost += 0 # assume using free tier
+ else:
+ raise ValueError(f"Model {self.model} not supported")
+
updated_ledger = extract_data(vlm_response, "json")
return updated_ledger
diff --git a/omnitool/gradio/app.py b/omnitool/gradio/app.py
index 54cca8a0..53dc1c0c 100644
--- a/omnitool/gradio/app.py
+++ b/omnitool/gradio/app.py
@@ -1,5 +1,6 @@
"""
-python app.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000
+python app.py --host_device omnibox_windows --windows_host_url localhost:8006 --omniparser_server_url localhost:8000
+python app.py --host_device local --omniparser_server_url localhost:8000
"""
import os
@@ -27,7 +28,7 @@
API_KEY_FILE = CONFIG_DIR / "api_key"
INTRO_TEXT = '''
-OmniParser lets you turn any vision-langauge model into an AI agent. We currently support **OpenAI (4o/o1/o3-mini), DeepSeek (R1), Qwen (2.5VL) or Anthropic Computer Use (Sonnet).**
+OmniParser lets you turn any vision-language model into an AI agent. We currently support **OpenAI (4o/o1/o3-mini), DeepSeek (R1), Qwen (2.5VL), Gemini (2.0/2.5) or Anthropic Computer Use (Sonnet).**
Type a message and press submit to start OmniTool. Press stop to pause, and press the trash icon in the chat to clear the message history.
'''
@@ -35,7 +36,8 @@
def parse_arguments():
parser = argparse.ArgumentParser(description="Gradio App")
- parser.add_argument("--windows_host_url", type=str, default='localhost:8006')
+ parser.add_argument("--host_device", type=str, choices=["omnibox_windows", "local"], default="omnibox_windows")
+ parser.add_argument("--windows_host_url", type=str, default="localhost:8006")
parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000")
return parser.parse_args()
args = parse_arguments()
@@ -189,8 +191,13 @@ def _truncate_string(s, max_length=500):
def valid_params(user_input, state):
"""Validate all requirements and return a list of error messages."""
errors = []
+
+ servers = [('OmniParser Server', args.omniparser_server_url)]
+
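+    # Only probe the Windows host when driving the omnibox VM; a local host device has no separate Windows endpoint.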
+ if args.host_device == "omnibox_windows":
+ servers.append(("Windows Host", args.windows_host_url))
- for server_name, url in [('Windows Host', 'localhost:5000'), ('OmniParser Server', args.omniparser_server_url)]:
+ for server_name, url in servers:
try:
url = f'http://{url}/probe'
response = requests.get(url, timeout=3)
@@ -233,6 +240,7 @@ def process_input(user_input, state):
# Run sampling_loop_sync with the chatbot_output_callback
for loop_msg in sampling_loop_sync(
+ args=args,
model=state["model"],
provider=state["provider"],
messages=state["messages"],
@@ -241,8 +249,7 @@ def process_input(user_input, state):
api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
api_key=state["api_key"],
only_n_most_recent_images=state["only_n_most_recent_images"],
- max_tokens=16384,
- omniparser_url=args.omniparser_server_url
+ max_tokens=16384
):
if loop_msg is None or state.get("stop"):
yield state['chatbot_messages']
@@ -302,7 +309,7 @@ def get_header_image_base64():
with gr.Column():
model = gr.Dropdown(
label="Model",
- choices=["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "claude-3-5-sonnet-20241022", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated"],
+ choices=["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "claude-3-5-sonnet-20241022", "omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"],
value="omniparser + gpt-4o",
interactive=True,
)
@@ -343,12 +350,13 @@ def get_header_image_base64():
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True, height=580)
- with gr.Column(scale=3):
- iframe = gr.HTML(
- f'',
- container=False,
- elem_classes="no-padding"
- )
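+        # The embedded Windows host view only applies when the omnibox VM is the host device.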
+ if args.host_device == "omnibox_windows":
+ with gr.Column(scale=3):
+ iframe = gr.HTML(
+ f'',
+ container=False,
+ elem_classes="no-padding"
+ )
def update_model(model_selection, state):
state["model"] = model_selection
@@ -362,6 +370,8 @@ def update_model(model_selection, state):
provider_choices = ["groq"]
elif model_selection == "omniparser + qwen2.5vl":
provider_choices = ["dashscope"]
+ elif model_selection in set(["omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"]):
+ provider_choices = ["gemini"]
else:
provider_choices = [option.value for option in APIProvider]
default_provider_value = provider_choices[0]
diff --git a/omnitool/gradio/app_new.py b/omnitool/gradio/app_new.py
index d67ae185..4bb80b16 100644
--- a/omnitool/gradio/app_new.py
+++ b/omnitool/gradio/app_new.py
@@ -3,6 +3,7 @@
- a new UI for the OmniParser AI Agent.
-
python app_new.py --windows_host_url localhost:8006 --omniparser_server_url localhost:8000
+python app_new.py --host_device local --omniparser_server_url localhost:8000
"""
import os
@@ -43,6 +44,7 @@
def parse_arguments():
parser = argparse.ArgumentParser(description="Gradio App")
+ parser.add_argument("--host_device", type=str, choices=["omnibox_windows", "local"], default="omnibox_windows")
parser.add_argument("--windows_host_url", type=str, default='localhost:8006')
parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000")
parser.add_argument("--run_folder", type=str, default="./tmp/outputs")
@@ -222,8 +224,13 @@ def _truncate_string(s, max_length=500):
def valid_params(user_input, state):
"""Validate all requirements and return a list of error messages."""
errors = []
+
+ servers = [('OmniParser Server', args.omniparser_server_url)]
+
+ if args.host_device == "omnibox_windows":
+ servers.append(("Windows Host", args.windows_host_url))
- for server_name, url in [('Windows Host', 'localhost:5000'), ('OmniParser Server', args.omniparser_server_url)]:
+ for server_name, url in servers:
try:
url = f'http://{url}/probe'
response = requests.get(url, timeout=3)
@@ -266,6 +273,7 @@ def process_input(user_input, state):
# Run sampling_loop_sync with the chatbot_output_callback
for loop_msg in sampling_loop_sync(
+ args=args,
model=state["model"],
provider=state["provider"],
messages=state["messages"],
@@ -275,7 +283,6 @@ def process_input(user_input, state):
api_key=state["api_key"],
only_n_most_recent_images=state["only_n_most_recent_images"],
max_tokens=16384,
- omniparser_url=args.omniparser_server_url,
save_folder=str(RUN_FOLDER)
):
if loop_msg is None or state.get("stop"):
diff --git a/omnitool/gradio/executor/anthropic_executor.py b/omnitool/gradio/executor/anthropic_executor.py
index f5c1a77f..7014b9fb 100644
--- a/omnitool/gradio/executor/anthropic_executor.py
+++ b/omnitool/gradio/executor/anthropic_executor.py
@@ -18,11 +18,12 @@
class AnthropicExecutor:
def __init__(
self,
+ args,
output_callback: Callable[[BetaContentBlockParam], None],
tool_output_callback: Callable[[Any, str], None],
):
self.tool_collection = ToolCollection(
- ComputerTool()
+ ComputerTool(args=args)
)
self.output_callback = output_callback
self.tool_output_callback = tool_output_callback
diff --git a/omnitool/gradio/loop.py b/omnitool/gradio/loop.py
index 9ce63169..6985c928 100644
--- a/omnitool/gradio/loop.py
+++ b/omnitool/gradio/loop.py
@@ -28,6 +28,7 @@ class APIProvider(StrEnum):
BEDROCK = "bedrock"
VERTEX = "vertex"
OPENAI = "openai"
+ GEMINI = "gemini"
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
@@ -35,10 +36,12 @@ class APIProvider(StrEnum):
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
APIProvider.OPENAI: "gpt-4o",
+ APIProvider.GEMINI: "gemini-2.0-flash"
}
def sampling_loop_sync(
*,
+ args,
model: str,
provider: APIProvider | None,
messages: list[BetaMessageParam],
@@ -48,17 +51,17 @@ def sampling_loop_sync(
api_key: str,
only_n_most_recent_images: int | None = 2,
max_tokens: int = 4096,
- omniparser_url: str,
save_folder: str = "./uploads"
):
"""
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
"""
print('in sampling_loop_sync, model:', model)
- omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
+ omniparser_client = OmniParserClient(host_device=args.host_device, url=f"http://{args.omniparser_server_url}/parse/")
if model == "claude-3-5-sonnet-20241022":
# Register Actor and Executor
actor = AnthropicActor(
+ args=args,
model=model,
provider=provider,
api_key=api_key,
@@ -66,7 +69,7 @@ def sampling_loop_sync(
max_tokens=max_tokens,
only_n_most_recent_images=only_n_most_recent_images
)
- elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]):
+ elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17"]):
actor = VLMAgent(
model=model,
provider=provider,
@@ -76,7 +79,7 @@ def sampling_loop_sync(
max_tokens=max_tokens,
only_n_most_recent_images=only_n_most_recent_images
)
- elif model in set(["omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated"]):
+ elif model in set(["omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"]):
actor = VLMOrchestratedAgent(
model=model,
provider=provider,
@@ -90,6 +93,7 @@ def sampling_loop_sync(
else:
raise ValueError(f"Model {model} not supported")
executor = AnthropicExecutor(
+ args=args,
output_callback=output_callback,
tool_output_callback=tool_output_callback,
)
@@ -115,7 +119,7 @@ def sampling_loop_sync(
messages.append({"content": tool_result_content, "role": "user"})
- elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated"]):
+    elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "omniparser + gemini-2.0-flash", "omniparser + gemini-2.5-flash-preview-04-17", "omniparser + gpt-4o-orchestrated", "omniparser + o1-orchestrated", "omniparser + o3-mini-orchestrated", "omniparser + R1-orchestrated", "omniparser + qwen2.5vl-orchestrated", "omniparser + gemini-2.0-flash-orchestrated", "omniparser + gemini-2.5-flash-preview-04-17-orchestrated"]):
while True:
parsed_screen = omniparser_client()
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
diff --git a/omnitool/gradio/tools/.DS_Store b/omnitool/gradio/tools/.DS_Store
new file mode 100644
index 00000000..c47b0595
Binary files /dev/null and b/omnitool/gradio/tools/.DS_Store differ
diff --git a/omnitool/gradio/tools/computer.py b/omnitool/gradio/tools/computer.py
index 6b91bad2..f54f0b11 100644
--- a/omnitool/gradio/tools/computer.py
+++ b/omnitool/gradio/tools/computer.py
@@ -2,6 +2,10 @@
import time
from enum import StrEnum
from typing import Literal, TypedDict
+import threading
+import shlex
+import os
+import subprocess
from PIL import Image
@@ -12,6 +16,8 @@
import requests
import re
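+# Serializes locally executed GUI commands so concurrent actions cannot interleave on the shared desktop.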
+computer_control_lock = threading.Lock()
+
OUTPUT_DIR = "./tmp/outputs"
TYPING_DELAY_MS = 12
@@ -88,9 +94,11 @@ def options(self) -> ComputerToolOptions:
def to_params(self) -> BetaToolComputerUse20241022Param:
return {"name": self.name, "type": self.api_type, **self.options}
- def __init__(self, is_scaling: bool = False):
+ def __init__(self, args, is_scaling: bool = False):
super().__init__()
+ self.args = args
+
# Get screen width and height using Windows command
self.display_num = None
self.offset_x = 0
@@ -141,11 +149,11 @@ async def __call__(
print(f"mouse move to {x}, {y}")
if action == "mouse_move":
- self.send_to_vm(f"pyautogui.moveTo({x}, {y})")
+ self.send_to_host_device(f"pyautogui.moveTo({x}, {y})")
return ToolResult(output=f"Moved mouse to ({x}, {y})")
elif action == "left_click_drag":
- current_x, current_y = self.send_to_vm("pyautogui.position()")
- self.send_to_vm(f"pyautogui.dragTo({x}, {y}, duration=0.5)")
+ current_x, current_y = self.send_to_host_device("pyautogui.position()")
+ self.send_to_host_device(f"pyautogui.dragTo({x}, {y}, duration=0.5)")
return ToolResult(output=f"Dragged mouse from ({current_x}, {current_y}) to ({x}, {y})")
if action in ("key", "type"):
@@ -162,18 +170,18 @@ async def __call__(
for key in keys:
key = self.key_conversion.get(key.strip(), key.strip())
key = key.lower()
- self.send_to_vm(f"pyautogui.keyDown('{key}')") # Press down each key
+ self.send_to_host_device(f"pyautogui.keyDown('{key}')") # Press down each key
for key in reversed(keys):
key = self.key_conversion.get(key.strip(), key.strip())
key = key.lower()
- self.send_to_vm(f"pyautogui.keyUp('{key}')") # Release each key in reverse order
+ self.send_to_host_device(f"pyautogui.keyUp('{key}')") # Release each key in reverse order
return ToolResult(output=f"Pressed keys: {text}")
elif action == "type":
# default click before type TODO: check if this is needed
- self.send_to_vm("pyautogui.click()")
- self.send_to_vm(f"pyautogui.typewrite('{text}', interval={TYPING_DELAY_MS / 1000})")
- self.send_to_vm("pyautogui.press('enter')")
+ self.send_to_host_device("pyautogui.click()")
+ self.send_to_host_device(f"pyautogui.typewrite('{text}', interval={TYPING_DELAY_MS / 1000})")
+ self.send_to_host_device("pyautogui.press('enter')")
screenshot_base64 = (await self.screenshot()).base64_image
return ToolResult(output=text, base64_image=screenshot_base64)
@@ -194,28 +202,28 @@ async def __call__(
if action == "screenshot":
return await self.screenshot()
elif action == "cursor_position":
- x, y = self.send_to_vm("pyautogui.position()")
+ x, y = self.send_to_host_device("pyautogui.position()")
x, y = self.scale_coordinates(ScalingSource.COMPUTER, x, y)
return ToolResult(output=f"X={x},Y={y}")
else:
if action == "left_click":
- self.send_to_vm("pyautogui.click()")
+ self.send_to_host_device("pyautogui.click()")
elif action == "right_click":
- self.send_to_vm("pyautogui.rightClick()")
+ self.send_to_host_device("pyautogui.rightClick()")
elif action == "middle_click":
- self.send_to_vm("pyautogui.middleClick()")
+ self.send_to_host_device("pyautogui.middleClick()")
elif action == "double_click":
- self.send_to_vm("pyautogui.doubleClick()")
+ self.send_to_host_device("pyautogui.doubleClick()")
elif action == "left_press":
- self.send_to_vm("pyautogui.mouseDown()")
+ self.send_to_host_device("pyautogui.mouseDown()")
time.sleep(1)
- self.send_to_vm("pyautogui.mouseUp()")
+ self.send_to_host_device("pyautogui.mouseUp()")
return ToolResult(output=f"Performed {action}")
if action in ("scroll_up", "scroll_down"):
if action == "scroll_up":
- self.send_to_vm("pyautogui.scroll(100)")
+ self.send_to_host_device("pyautogui.scroll(100)")
elif action == "scroll_down":
- self.send_to_vm("pyautogui.scroll(-100)")
+ self.send_to_host_device("pyautogui.scroll(-100)")
return ToolResult(output=f"Performed {action}")
if action == "hover":
return ToolResult(output=f"Performed {action}")
@@ -224,9 +232,9 @@ async def __call__(
return ToolResult(output=f"Performed {action}")
raise ToolError(f"Invalid action: {action}")
- def send_to_vm(self, action: str):
+ def send_to_host_device(self, action: str):
"""
- Executes a python command on the server. Only return tuple of x,y when action is "pyautogui.position()"
+ Executes a python command on the host device. Only return tuple of x,y when action is "pyautogui.position()"
"""
prefix = "import pyautogui; pyautogui.FAILSAFE = False;"
command_list = ["python", "-c", f"{prefix} {action}"]
@@ -236,18 +244,25 @@ def send_to_vm(self, action: str):
try:
print(f"sending to vm: {command_list}")
- response = requests.post(
- f"http://localhost:5000/execute",
- headers={'Content-Type': 'application/json'},
- json={"command": command_list},
- timeout=90
- )
+
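+            # "omnibox_windows" forwards the pyautogui snippet to the VM's HTTP executor on localhost:5000; "local" runs it directly on this machine.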
+ if self.args.host_device == "omnibox_windows":
+ response = requests.post(
+ f"http://localhost:5000/execute",
+ headers={'Content-Type': 'application/json'},
+ json={"command": command_list},
+ timeout=90
+ )
+ if response.status_code != 200:
+ raise ToolError(f"Failed to execute command. Status code: {response.status_code}")
+ output = response.json()['output'].strip()
+ elif self.args.host_device == "local":
+ response = self.execute(command_list)
+ output = response['output'].strip()
+
time.sleep(0.7) # avoid async error as actions take time to complete
print(f"action executed")
- if response.status_code != 200:
- raise ToolError(f"Failed to execute command. Status code: {response.status_code}")
+
if parse:
- output = response.json()['output'].strip()
match = re.search(r'Point\(x=(\d+),\s*y=(\d+)\)', output)
if not match:
raise ToolError(f"Could not parse coordinates from output: {output}")
@@ -255,13 +270,38 @@ def send_to_vm(self, action: str):
return x, y
except requests.exceptions.RequestException as e:
raise ToolError(f"An error occurred while trying to execute the command: {str(e)}")
+
+ def execute(self, command, shell=False):
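+        """Run a command locally (host_device == "local"), guarded by computer_control_lock; returns a dict with an 'output' field so callers can treat it like the VM executor's JSON response."""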
+ with computer_control_lock:
+ if isinstance(command, str) and not shell:
+ command = shlex.split(command)
+
+ # Expand user directory
+ for i, arg in enumerate(command):
+ if arg.startswith("~/"):
+ command[i] = os.path.expanduser(arg)
+
+ # Execute the command without any safety checks.
+ try:
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell, text=True, timeout=120)
+ return {
+ 'status': 'success',
+ 'output': result.stdout,
+ 'error': result.stderr,
+ 'returncode': result.returncode
+ }
+ except Exception as e:
+ return {
+ 'status': 'error',
+ 'message': str(e)
+ }
async def screenshot(self):
if not hasattr(self, 'target_dimension'):
screenshot = self.padding_image(screenshot)
self.target_dimension = MAX_SCALING_TARGETS["WXGA"]
width, height = self.target_dimension["width"], self.target_dimension["height"]
- screenshot, path = get_screenshot(resize=True, target_width=width, target_height=height)
+ screenshot, path = get_screenshot(host_device=self.args.host_device, resize=True, target_width=width, target_height=height)
time.sleep(0.7) # avoid async error as actions take time to complete
return ToolResult(base64_image=base64.b64encode(path.read_bytes()).decode())
@@ -310,16 +350,21 @@ def scale_coordinates(self, source: ScalingSource, x: int, y: int):
def get_screen_size(self):
"""Return width and height of the screen"""
try:
- response = requests.post(
- f"http://localhost:5000/execute",
- headers={'Content-Type': 'application/json'},
- json={"command": ["python", "-c", "import pyautogui; print(pyautogui.size())"]},
- timeout=90
- )
- if response.status_code != 200:
- raise ToolError(f"Failed to get screen size. Status code: {response.status_code}")
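+            # Mirror send_to_host_device: query the VM over HTTP for omnibox, or run pyautogui locally.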
+ if self.args.host_device == "omnibox_windows":
+ response = requests.post(
+ f"http://localhost:5000/execute",
+ headers={'Content-Type': 'application/json'},
+ json={"command": ["python", "-c", "import pyautogui; print(pyautogui.size())"]},
+ timeout=90
+ )
+
+ if response.status_code != 200:
+ raise ToolError(f"Failed to get screen size. Status code: {response.status_code}")
+ output = response.json()['output'].strip()
+ elif self.args.host_device == "local":
+ response = self.execute(["python", "-c", "import pyautogui; print(pyautogui.size())"])
+ output = response['output'].strip()
- output = response.json()['output'].strip()
match = re.search(r'Size\(width=(\d+),\s*height=(\d+)\)', output)
if not match:
raise ToolError(f"Could not parse screen size from output: {output}")
diff --git a/omnitool/gradio/tools/cursor.png b/omnitool/gradio/tools/cursor.png
new file mode 100644
index 00000000..d3a3c5bb
Binary files /dev/null and b/omnitool/gradio/tools/cursor.png differ
diff --git a/omnitool/gradio/tools/screen_capture.py b/omnitool/gradio/tools/screen_capture.py
index 1c1ad04a..b3b25daa 100644
--- a/omnitool/gradio/tools/screen_capture.py
+++ b/omnitool/gradio/tools/screen_capture.py
@@ -1,28 +1,42 @@
+import os
from pathlib import Path
from uuid import uuid4
import requests
from PIL import Image
from .base import BaseAnthropicTool, ToolError
from io import BytesIO
+import pyautogui
OUTPUT_DIR = "./tmp/outputs"
-def get_screenshot(resize: bool = False, target_width: int = 1920, target_height: int = 1080):
+def get_screenshot(host_device: str, resize: bool = False, target_width: int = 1920, target_height: int = 1080):
"""Capture screenshot by requesting from HTTP endpoint - returns native resolution unless resized"""
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"screenshot_{uuid4().hex}.png"
try:
- response = requests.get('http://localhost:5000/screenshot')
- if response.status_code != 200:
- raise ToolError(f"Failed to capture screenshot: HTTP {response.status_code}")
-
- # (1280, 800)
- screenshot = Image.open(BytesIO(response.content))
-
- if resize and screenshot.size != (target_width, target_height):
- screenshot = screenshot.resize((target_width, target_height))
+ if host_device == "omnibox_windows":
+ response = requests.get('http://localhost:5000/screenshot')
+ if response.status_code != 200:
+ raise ToolError(f"Failed to capture screenshot: HTTP {response.status_code}")
+ # (1280, 800)
+ screenshot = Image.open(BytesIO(response.content))
+ if resize and screenshot.size != (target_width, target_height):
+ screenshot = screenshot.resize((target_width, target_height))
+ elif host_device == "local":
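+            # Local capture path: grab the host screen with pyautogui and overlay a cursor image, since pyautogui screenshots do not include the mouse pointer.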
+ screenshot = pyautogui.screenshot()
+ size = pyautogui.size()
+
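+            # Resize to pyautogui's logical screen size so model-predicted coordinates line up with pyautogui's coordinate space (noticeable on HiDPI/Retina displays).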
+ screenshot = screenshot.resize((size.width, size.height))
+
+ cursor_path = os.path.join(os.path.dirname(__file__), "cursor.png")
+ cursor_x, cursor_y = pyautogui.position()
+ cursor = Image.open(cursor_path)
+ # make the cursor smaller
+ cursor = cursor.resize((int(cursor.width / 1.5), int(cursor.height / 1.5)))
+ screenshot.paste(cursor, (cursor_x, cursor_y), cursor)
+
screenshot.save(path)
return screenshot, path
except Exception as e:
diff --git a/omnitool/omniparserserver/omniparserserver.py b/omnitool/omniparserserver/omniparserserver.py
index 49fb306f..cfa50c7d 100644
--- a/omnitool/omniparserserver/omniparserserver.py
+++ b/omnitool/omniparserserver/omniparserserver.py
@@ -12,13 +12,14 @@
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(root_dir)
from util.omniparser import Omniparser
+from util.utils import detect_device
def parse_arguments():
parser = argparse.ArgumentParser(description='Omniparser API')
parser.add_argument('--som_model_path', type=str, default='../../weights/icon_detect/model.pt', help='Path to the som model')
parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
parser.add_argument('--caption_model_path', type=str, default='../../weights/icon_caption_florence', help='Path to the caption model')
- parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
+ parser.add_argument('--device', type=str, default=detect_device(), help='Device to run the model')
parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
parser.add_argument('--port', type=int, default=8000, help='Port for the API')
diff --git a/requirements.txt b/requirements.txt
index 901a27fa..b58b5423 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,4 +29,5 @@ google-auth<3,>=2
screeninfo
uiautomation
dashscope
-groq
\ No newline at end of file
+groq
+google-genai
\ No newline at end of file
diff --git a/util/omniparser.py b/util/omniparser.py
index 536385e6..5610d5d2 100644
--- a/util/omniparser.py
+++ b/util/omniparser.py
@@ -1,4 +1,4 @@
-from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
+from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box, detect_device
import torch
from PIL import Image
import io
@@ -7,7 +7,7 @@
class Omniparser(object):
def __init__(self, config: Dict):
self.config = config
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ device = detect_device()
self.som_model = get_yolo_model(model_path=config['som_model_path'])
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
diff --git a/util/utils.py b/util/utils.py
index eb7c8b25..7e91258c 100644
--- a/util/utils.py
+++ b/util/utils.py
@@ -43,10 +43,19 @@
import torchvision.transforms as T
from util.box_annotator import BoxAnnotator
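+# Pick the best available torch device: CUDA first, then Apple-Silicon MPS, otherwise CPU.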
+def detect_device() -> str:
+ if torch.cuda.is_available():
+ print("[+] Using CUDA")
+ return "cuda"
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+ print("[+] Using MPS for Apple Silicon")
+ return "mps"
+ else:
+ return "cpu"
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
if not device:
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = detect_device()
if model_name == "blip2":
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
@@ -107,7 +116,7 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
start = time.time()
batch = croped_pil_image[i:i+batch_size]
t1 = time.time()
- if model.device.type == 'cuda':
+ if model.device.type == 'cuda' or model.device.type == 'mps':
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
else:
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)