diff --git a/.gitignore b/.gitignore
index 8b8235e6..a0c7ff8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ util/__pycache__/
 index.html?linkid=2289031
 wget-log
 weights/icon_caption_florence_v2/
-omnitool/gradio/uploads/
\ No newline at end of file
+omnitool/gradio/uploads/
+.env
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 00000000..bdae8bcd
--- /dev/null
+++ b/main.py
@@ -0,0 +1,53 @@
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
+from typing import Optional
+import torch
+from PIL import Image
+import io
+import base64
+import os
+from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+import logging
+
+# --- Basic Logging Setup ---
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+# Initialize models
+logger.info("Initializing models...")
+yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
+caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+logger.info(f"Models initialized. Using device: {DEVICE}")
+
+app = FastAPI(title="OmniParser API")
+
+def process_image(image: Image.Image, box_threshold: float, iou_threshold: float,
+                  use_paddleocr: bool, imgsz: int) -> str:
+    logger.info(f"Processing image with params: box_threshold={box_threshold}, iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")
+    ocr_bbox_rslt, _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=use_paddleocr)
+    text, ocr_bbox = ocr_bbox_rslt
+    _, _, parsed_content_list = get_som_labeled_img(
+        image, yolo_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=True,
+        ocr_bbox=ocr_bbox, draw_bbox_config={}, caption_model_processor=caption_model_processor,
+        ocr_text=text, iou_threshold=iou_threshold, imgsz=imgsz
+    )
+    return '\n'.join([f'icon {i}: ' + str(v) for i, v in enumerate(parsed_content_list)])
+
+@app.post("/process")
+async def process_endpoint(
+    file: UploadFile = File(...),
+    box_threshold: float = Form(...),
+    iou_threshold: float = Form(...),
+    use_paddleocr: str = Form(...),
+    imgsz: int = Form(...)
+):
+    use_paddleocr_bool = use_paddleocr.lower() in ('true', '1')
+    parsed_content = process_image(Image.open(file.file).convert("RGB"), box_threshold, iou_threshold, use_paddleocr_bool, imgsz)
+    return { "parsed_elements": parsed_content }
+
+if __name__ == "__main__":
+    import uvicorn
+    logger.info("Starting OmniParser API server for local testing...")
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/omnitool/omniparserserver/omniparserserver.py b/omnitool/omniparserserver/omniparserserver.py
index 045fbace..781e4b67 100644
--- a/omnitool/omniparserserver/omniparserserver.py
+++ b/omnitool/omniparserserver/omniparserserver.py
@@ -19,7 +19,7 @@ def parse_arguments():
     parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
     parser.add_argument('--caption_model_path', type=str, default='../../weights/icon_caption_florence', help='Path to the caption model')
     parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
-    parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
+    parser.add_argument('--BOX_TRESHOLD', type=float, default=0.01, help='Threshold for box detection')
     parser.add_argument('--host', type=str, default='127.0.0.1', help='Host for the API')
     parser.add_argument('--port', type=int, default=8000, help='Port for the API')
     args = parser.parse_args()
diff --git a/requirements.txt b/requirements.txt
index 901a27fa..b8b09d69 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,31 +1,31 @@
 torch
 easyocr
 torchvision
-supervision==0.18.0
-openai==1.3.5
+supervision
+openai
 transformers
-ultralytics==8.3.70
+ultralytics
 azure-identity
-numpy==1.26.4
+numpy
 opencv-python
 opencv-python-headless
 gradio
 dill
 accelerate
 timm
-einops==0.8.0
+einops
 paddlepaddle
 paddleocr
-ruff==0.6.7
-pre-commit==3.8.0
-pytest==8.3.3
-pytest-asyncio==0.23.6
-pyautogui==0.9.54
-streamlit>=1.38.0
-anthropic[bedrock,vertex]>=0.37.1
-jsonschema==4.22.0
-boto3>=1.28.57
-google-auth<3,>=2
+ruff
+pre-commit
+pytest
+pytest-asyncio
+pyautogui
+streamlit
+anthropic[bedrock,vertex]
+jsonschema
+boto3
+google-auth
 screeninfo
 uiautomation
 dashscope
diff --git a/util/utils.py b/util/utils.py
index eb7c8b25..b291c4bc 100644
--- a/util/utils.py
+++ b/util/utils.py
@@ -21,14 +21,15 @@ from paddleocr import PaddleOCR
 
 reader = easyocr.Reader(['en'])
 paddle_ocr = PaddleOCR(
-    lang='en', # other lang also available
+    # lang='en', # other lang also available
     use_angle_cls=False,
-    use_gpu=False, # using cuda will conflict with pytorch in the same process
-    show_log=False,
-    max_batch_size=1024,
-    use_dilation=True, # improves accuracy
-    det_db_score_mode='slow', # improves accuracy
-    rec_batch_num=1024)
+    # use_gpu=False, # using cuda will conflict with pytorch in the same process
+    # show_log=False,
+    # max_batch_size=1024,
+    # use_dilation=True, # improves accuracy
+    # det_db_score_mode='slow', # improves accuracy
+    # rec_batch_num=1024)
+)
 import time
 import base64
 
@@ -514,7 +515,7 @@ def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, out
             text_threshold = 0.5
         else:
             text_threshold = easyocr_args['text_threshold']
-        result = paddle_ocr.ocr(image_np, cls=False)[0]
+        result = paddle_ocr.ocr(image_np)[0]
         coord = [item[0] for item in result if item[1][1] > text_threshold]
         text = [item[1][0] for item in result if item[1][1] > text_threshold]
     else: # EasyOCR