Issue #330

3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ util/__pycache__/
index.html?linkid=2289031
wget-log
weights/icon_caption_florence_v2/
omnitool/gradio/uploads/
omnitool/gradio/uploads/
.env
53 changes: 53 additions & 0 deletions main.py
@@ -0,0 +1,53 @@
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from typing import Optional
import torch
from PIL import Image
import io
import base64
import os
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
import logging

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Initialize models
logger.info("Initializing models...")
yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Models initialized. Using device: {DEVICE}")

app = FastAPI(title="OmniParser API")

def process_image(image: Image.Image, box_threshold: float, iou_threshold: float,
                  use_paddleocr: bool, imgsz: int) -> str:
    logger.info(f"Processing image with params: box_threshold={box_threshold}, iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")
    ocr_bbox_rslt, _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=use_paddleocr)
    text, ocr_bbox = ocr_bbox_rslt
    _, _, parsed_content_list = get_som_labeled_img(
        image, yolo_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox, draw_bbox_config={}, caption_model_processor=caption_model_processor,
        ocr_text=text, iou_threshold=iou_threshold, imgsz=imgsz
    )
    return '\n'.join([f'icon {i}: ' + str(v) for i, v in enumerate(parsed_content_list)])

@app.post("/process")
async def process_endpoint(
    file: UploadFile = File(...),
    box_threshold: float = Form(...),
    iou_threshold: float = Form(...),
    use_paddleocr: str = Form(...),
    imgsz: int = Form(...)
):
    use_paddleocr_bool = use_paddleocr.lower() in ('true', '1')
    parsed_content = process_image(Image.open(file.file).convert("RGB"), box_threshold, iou_threshold, use_paddleocr_bool, imgsz)
    return {"parsed_elements": parsed_content}

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting OmniParser API server for local testing...")
    uvicorn.run(app, host="0.0.0.0", port=8000)
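
For reference, a minimal client sketch for the new /process endpoint (not part of this PR): it assumes the server from main.py is running on localhost:8000, uses the third-party requests package, and the file path and parameter values are illustrative placeholders.

import requests

# Send a screenshot to the OmniParser API and print the parsed elements.
# "screenshot.png" and the parameter values below are placeholders.
with open("screenshot.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/process",
        files={"file": ("screenshot.png", f, "image/png")},
        data={
            "box_threshold": "0.05",
            "iou_threshold": "0.1",
            "use_paddleocr": "true",  # the endpoint treats 'true' or '1' as truthy
            "imgsz": "640",
        },
    )
resp.raise_for_status()
print(resp.json()["parsed_elements"])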
2 changes: 1 addition & 1 deletion omnitool/omniparserserver/omniparserserver.py
@@ -19,7 +19,7 @@ def parse_arguments():
    parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
    parser.add_argument('--caption_model_path', type=str, default='../../weights/icon_caption_florence', help='Path to the caption model')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
    parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
    parser.add_argument('--BOX_TRESHOLD', type=float, default=0.01, help='Threshold for box detection')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='Host for the API')
    parser.add_argument('--port', type=int, default=8000, help='Port for the API')
    args = parser.parse_args()
30 changes: 15 additions & 15 deletions requirements.txt
@@ -1,31 +1,31 @@
torch
easyocr
torchvision
supervision==0.18.0
openai==1.3.5
supervision
openai
transformers
ultralytics==8.3.70
ultralytics
azure-identity
numpy==1.26.4
numpy
opencv-python
opencv-python-headless
gradio
dill
accelerate
timm
einops==0.8.0
einops
paddlepaddle
paddleocr
ruff==0.6.7
pre-commit==3.8.0
pytest==8.3.3
pytest-asyncio==0.23.6
pyautogui==0.9.54
streamlit>=1.38.0
anthropic[bedrock,vertex]>=0.37.1
jsonschema==4.22.0
boto3>=1.28.57
google-auth<3,>=2
ruff
pre-commit
pytest
pytest-asyncio
pyautogui
streamlit
anthropic[bedrock,vertex]
jsonschema
boto3
google-auth
screeninfo
uiautomation
dashscope
17 changes: 9 additions & 8 deletions util/utils.py
@@ -21,14 +21,15 @@
from paddleocr import PaddleOCR
reader = easyocr.Reader(['en'])
paddle_ocr = PaddleOCR(
    lang='en', # other lang also available
    # lang='en', # other lang also available
    use_angle_cls=False,
    use_gpu=False, # using cuda will conflict with pytorch in the same process
    show_log=False,
    max_batch_size=1024,
    use_dilation=True, # improves accuracy
    det_db_score_mode='slow', # improves accuracy
    rec_batch_num=1024)
    # use_gpu=False, # using cuda will conflict with pytorch in the same process
    # show_log=False,
    # max_batch_size=1024,
    # use_dilation=True, # improves accuracy
    # det_db_score_mode='slow', # improves accuracy
    # rec_batch_num=1024)
)
import time
import base64

@@ -514,7 +515,7 @@ def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, out
            text_threshold = 0.5
        else:
            text_threshold = easyocr_args['text_threshold']
        result = paddle_ocr.ocr(image_np, cls=False)[0]
        result = paddle_ocr.ocr(image_np)[0]
        coord = [item[0] for item in result if item[1][1] > text_threshold]
        text = [item[1][0] for item in result if item[1][1] > text_threshold]
    else: # EasyOCR
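
For context on the util/utils.py change above, a minimal standalone sketch of the updated PaddleOCR call path (not part of this PR): it assumes a paddleocr release whose ocr() no longer accepts cls= and still returns entries shaped like [box, (text, confidence)], which is what the surrounding check_ocr_box code indexes into; the image path and the 0.9 cutoff are illustrative.

import numpy as np
from PIL import Image
from paddleocr import PaddleOCR

# Mirror the patched constructor: most keyword arguments are now commented out.
paddle_ocr = PaddleOCR(use_angle_cls=False)

# Illustrative input; check_ocr_box normally passes the screenshot as a numpy array.
image_np = np.array(Image.open("screenshot.png").convert("RGB"))

# Patched call: no cls= argument.
result = paddle_ocr.ocr(image_np)[0]

# Filter by recognition confidence, as check_ocr_box does.
text_threshold = 0.9
coord = [item[0] for item in result if item[1][1] > text_threshold]
text = [item[1][0] for item in result if item[1][1] > text_threshold]
print(list(zip(text, coord)))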