Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ util/__pycache__/
index.html?linkid=2289031
wget-log
weights/icon_caption_florence_v2/
omnitool/gradio/uploads/
omnitool/gradio/uploads/
**/.DS_Store
4 changes: 2 additions & 2 deletions gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


import base64, os
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img, detect_device
import torch
from PIL import Image

Expand All @@ -27,7 +27,7 @@
OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
"""

DEVICE = torch.device('cuda')
DEVICE = torch.device(detect_device())

# @spaces.GPU
# @torch.inference_mode()
Expand Down
Binary file added omnitool/.DS_Store
Binary file not shown.
Binary file added omnitool/gradio/.DS_Store
Binary file not shown.
Binary file added omnitool/gradio/agent/.DS_Store
Binary file not shown.
5 changes: 3 additions & 2 deletions omnitool/gradio/agent/anthropic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ class APIProvider(StrEnum):
VERTEX = "vertex"

SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilizing a Windows system with internet access.
* You are utilizing a {platform.system()} system with internet access.
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
</SYSTEM_CAPABILITY>
"""

class AnthropicActor:
def __init__(
self,
args,
model: str,
provider: APIProvider,
api_key: str,
Expand All @@ -62,7 +63,7 @@ def __init__(
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images

self.tool_collection = ToolCollection(ComputerTool())
self.tool_collection = ToolCollection(ComputerTool(args=args))

self.system = SYSTEM_PROMPT

Expand Down
72 changes: 72 additions & 0 deletions omnitool/gradio/agent/llm_utils/geminiclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
from google import genai
from google.genai import types
from pydantic import BaseModel, Field
from typing import Optional
from PIL import Image
from pprint import pprint

from .utils import is_image_path, encode_image

class Action(BaseModel):
    # Structured output schema passed to Gemini as `response_schema` so the
    # model must reply with this JSON shape. The aliases ("Reasoning",
    # "Next Action", "Box ID") match the key names the planner's system
    # prompt instructs the model to emit.
    reasoning: str = Field(..., alias="Reasoning")  # model's explanation for the chosen step
    next_action: str = Field(..., alias="Next Action")  # action name, e.g. "type", "left_click"
    box_id: str | None = Field(None, alias="Box ID")  # target element ID; None for actions without a target box
    value: str | None = None  # text payload; presumably only set when next_action is "type" — TODO confirm

def run_gemini_interleaved(messages: list, system: str, model_name: str, api_key: str, max_tokens: int, temperature=0):
    """
    Run a chat completion through Google Gemini's API.

    Args:
        messages: Either a plain string prompt, or a list whose items are
            dicts carrying a "content" list (strings that may be image
            paths, plain text, or arbitrary objects) or bare non-dict
            items that get stringified.
        system: System-instruction text sent with the request.
        model_name: Gemini model identifier (e.g. "gemini-2.0-flash").
        api_key: Gemini API key; falls back to the GEMINI_API_KEY env var.
        max_tokens: Cap on generated output tokens.
        temperature: Sampling temperature (default 0 for determinism).

    Returns:
        (final_answer, token_usage): the model's JSON text and the total
        token count. On any API error, returns (str(exception), 0).

    Raises:
        ValueError: if no API key is given and GEMINI_API_KEY is unset.
    """
    api_key = api_key or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY is not set")

    client = genai.Client(
        api_key=api_key,
    )

    # Force structured JSON output matching the Action schema.
    generate_content_config = types.GenerateContentConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
        response_mime_type="application/json",
        response_schema=Action,
        system_instruction=[
            types.Part.from_text(text=system),
        ],
    )

    contents = []

    if isinstance(messages, list):
        for item in messages:
            if isinstance(item, dict):
                for cnt in item["content"]:
                    if isinstance(cnt, str):
                        if is_image_path(cnt):
                            # Image paths are loaded so the SDK uploads the
                            # image content rather than the literal path.
                            contents.append(Image.open(cnt))
                        else:
                            contents.append(cnt)
                    else:
                        contents.append(str(cnt))
            else:
                # BUG FIX: the original appended str(cnt) here, referencing
                # the inner loop variable from a *previous* dict item — a
                # stale value, or a NameError if the first item is non-dict.
                contents.append(str(item))
    elif isinstance(messages, str):
        # BUG FIX: the original called contents.push(), which does not
        # exist on Python lists (AttributeError); append() is correct.
        contents.append(messages)

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=generate_content_config
        )
        final_answer = response.text
        token_usage = response.usage_metadata.total_token_count

        return final_answer, token_usage
    except Exception as e:
        # Best-effort fallback: surface the error text to the caller
        # instead of raising, matching the other llm_utils clients.
        print(f"Error in interleaved Gemini: {e}")
        return str(e), 0
3 changes: 0 additions & 3 deletions omnitool/gradio/agent/llm_utils/oaiclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
import logging
import base64
import requests
from .utils import is_image_path, encode_image

Expand Down
4 changes: 3 additions & 1 deletion omnitool/gradio/agent/llm_utils/omniparserclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

class OmniParserClient:
def __init__(self,
host_device: str,
url: str) -> None:
self.host_device = host_device
self.url = url

def __call__(self,):
screenshot, screenshot_path = get_screenshot()
screenshot, screenshot_path = get_screenshot(host_device=self.host_device)
screenshot_path = str(screenshot_path)
image_base64 = encode_image(screenshot_path)
response = requests.post(self.url, json={"base64_image": image_base64})
Expand Down
33 changes: 31 additions & 2 deletions omnitool/gradio/agent/vlm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from PIL import Image, ImageDraw
import base64
from io import BytesIO
import platform

from anthropic import APIResponse
from anthropic.types import ToolResultBlockParam
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage

from agent.llm_utils.oaiclient import run_oai_interleaved
from agent.llm_utils.groqclient import run_groq_interleaved
from agent.llm_utils.geminiclient import run_gemini_interleaved
from agent.llm_utils.utils import is_image_path
import time
import re
Expand Down Expand Up @@ -49,6 +51,10 @@ def __init__(
self.model = "o1"
elif model == "omniparser + o3-mini":
self.model = "o3-mini"
elif model == "omniparser + gemini-2.0-flash":
self.model = "gemini-2.0-flash"
elif model == "omniparser + gemini-2.5-flash-preview-04-17":
self.model = "gemini-2.5-flash-preview-04-17"
else:
raise ValueError(f"Model {model} not supported")

Expand Down Expand Up @@ -133,6 +139,18 @@ def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
print(f"qwen token usage: {token_usage}")
self.total_token_usage += token_usage
self.total_cost += (token_usage * 2.2 / 1000000) # https://help.aliyun.com/zh/model-studio/getting-started/models?spm=a2c4g.11186623.0.0.74b04823CGnPv7#fe96cfb1a422a
elif "gemini" in self.model:
vlm_response, token_usage = run_gemini_interleaved(
messages=planner_messages,
system=system,
model_name=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
temperature=0,
)
print(f"gemini token usage: {token_usage}")
self.total_token_usage += token_usage
self.total_cost += 0 # assume using free tier
else:
raise ValueError(f"Model {self.model} not supported")
latency_vlm = time.time() - start
Expand Down Expand Up @@ -209,9 +227,11 @@ def _api_response_callback(self, response: APIResponse):

def _get_system_prompt(self, screen_info: str = ""):
main_section = f"""
You are using a Windows device.
You are using a {platform.system()} device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
You can only interact with the desktop GUI (no terminal or application menu access).
You can only interact with the desktop GUI (no terminal or application menu access)

!!!DO NOT interact with the chatbot webpage interface that opens in 0.0.0.0:7888. You don't need to click the orange send button because the user already clicked it!!!

You may be given some history plan and actions, this is the response from the previous loop.
You should carefully consider your plan base on the task, screenshot, and history actions.
Expand All @@ -230,6 +250,15 @@ def _get_system_prompt(self, screen_info: str = ""):

Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.

Use this JSON schema:

Action = {{
"Reasoning": str,
"Next Action": str,
"Box ID": str | None,
"value": str | None
}}

Output format:
```json
{{
Expand Down
Loading