diff --git a/.github/workflows/build_python.yml b/.github/workflows/build_python.yml index 07e8651..1a1e296 100644 --- a/.github/workflows/build_python.yml +++ b/.github/workflows/build_python.yml @@ -55,9 +55,9 @@ jobs: working-directory: ${{ inputs.directory }} run: ruff check . - # - name: Run tests - # working-directory: ${{ inputs.directory }} - # run: pytest --maxfail=3 --disable-warnings --tb=short + - name: Run tests + working-directory: ${{ inputs.directory }} + run: pytest tests --maxfail=3 --disable-warnings --tb=short - name: Save pip cache if: ${{ !cancelled() }} diff --git a/ocr/.gitignore b/ocr/.gitignore index 431009f..2f1f04a 100644 --- a/ocr/.gitignore +++ b/ocr/.gitignore @@ -26,4 +26,10 @@ model_training/data_backup model_training/input_backup -model_training/pipeline.sh \ No newline at end of file +model_training/pipeline.sh + +model_training/test_results + +*.json + +models/postprocessing_models \ No newline at end of file diff --git a/ocr/app/backend_client.py b/ocr/app/backend_client.py index 55eceb7..c93f57a 100644 --- a/ocr/app/backend_client.py +++ b/ocr/app/backend_client.py @@ -1,11 +1,18 @@ import base64 +from typing import Any import requests API_BASE = "/backend/api/v1" -def get_format(backend_url, auth_token, format_name=None, format_id=None, timeout=10): +def get_format( + backend_url: str, + auth_token: str | None, + format_name: str | None = None, + format_id: int | None = None, + timeout: int = 10, +) -> dict[str, Any] | None: url = f"{backend_url.rstrip('/')}{API_BASE}/formats" headers = {} if auth_token: @@ -29,15 +36,15 @@ def get_format(backend_url, auth_token, format_name=None, format_id=None, timeou def send_file( - backend_url, - auth_token, - owner_id, - format_id, - generation, - content_bytes, - primary_file_id=None, - timeout=15, -): + backend_url: str, + auth_token: str, + owner_id: int, + format_id: int, + generation: int, + content_bytes: bytes, + primary_file_id: int | None = None, + timeout: int = 15, +) -> dict[str, Any]: if not backend_url: raise ValueError("backend_url is required") diff --git a/ocr/app/file_converter.py b/ocr/app/file_converter.py index 70051f3..cafe550 100644 --- a/ocr/app/file_converter.py +++ b/ocr/app/file_converter.py @@ -1,6 +1,8 @@ import io import logging +from collections.abc import Iterable from pathlib import Path +from typing import Any import fitz from docx import Document @@ -20,7 +22,7 @@ OUT_DIR = SCRIPT_DIR / ".." / "temp" / "pdf_pages" -def pil_to_pixmap(image): +def pil_to_pixmap(image: Image.Image) -> fitz.Pixmap: if image.mode == "1": image = image.convert("L") @@ -42,7 +44,7 @@ def pil_to_pixmap(image): return fitz.Pixmap(colorspace, width, height, samples, alpha) -def initialize_pdf_with_image(image, visible_image=True): +def initialize_pdf_with_image(image: Image.Image, visible_image: bool = True) -> fitz.Document: pdf_doc = fitz.open() rect = fitz.Rect(0, 0, image.width, image.height) page = pdf_doc.new_page(width=image.width, height=image.height) @@ -52,14 +54,14 @@ def initialize_pdf_with_image(image, visible_image=True): return pdf_doc -def measure_text_single_line(text, fontsize=11, fontname="helv"): +def measure_text_single_line(text: str, fontsize: int = 11, fontname: str = "helv") -> tuple[float, int]: font = fitz.Font(fontname) width = font.text_length(text, fontsize=fontsize) height = fontsize return width, height -def find_fontsize(line_height, line_width, text, fontname="helv"): +def find_fontsize(line_height: int, line_width: int, text: str, fontname: str = "helv") -> int: min_fontsize = 1 max_fontsize = line_height @@ -79,7 +81,13 @@ def find_fontsize(line_height, line_width, text, fontname="helv"): return int(fontsize) - 1 -def insert_text_at_bbox(pdf_doc, text, bbox, visible_image=True, draw_rect=False): +def insert_text_at_bbox( + pdf_doc: fitz.Document, + text: str, + bbox: Iterable[int], + visible_image: bool = True, + draw_rect: bool = False, +) -> None: page = pdf_doc[0] x0, y0, x1, y1 = bbox rect = fitz.Rect(x0, y0, x1, y1) @@ -98,22 +106,24 @@ def insert_text_at_bbox(pdf_doc, text, bbox, visible_image=True, draw_rect=False page.insert_text(point, text, fontsize=fs, fontname="helv", color=0, fill_opacity=1, overlay=True) -def pdf_to_bytes(pdf_doc): +def pdf_to_bytes(pdf_doc: fitz.Document) -> bytes: return pdf_doc.write() -def save_docx_to_path(docx_bytes, output_path) : - with open(output_path, "wb") as f : + +def save_docx_to_path(docx_bytes: bytes, output_path: Path) -> None: + with open(output_path, "wb") as f: f.write(docx_bytes) -def pdf_to_docx_bytes(pdf_doc) : + +def pdf_to_docx_bytes(pdf_doc: fitz.Document) -> bytes: pdf_bytes = pdf_to_bytes(pdf_doc) laparams = LAParams() text = extract_text(io.BytesIO(pdf_bytes), laparams=laparams) doc = Document() - for line in text.splitlines() : - if line.strip() == "" : + for line in text.splitlines(): + if line.strip() == "": continue doc.add_paragraph(line) @@ -121,7 +131,13 @@ def pdf_to_docx_bytes(pdf_doc) : doc.save(buf) return buf.getvalue() -def convert_to_png_bytes(input_bytes, input_format, debug=False, debug_indent=0): + +def convert_to_png_bytes( + input_bytes: bytes, + input_format: dict[str, Any], + debug: bool = False, + debug_indent: int = 0, +) -> bytes: if debug: logging.debug(get_frontline(debug_indent) + f"Converting input format '{input_format}' to PNG bytes") @@ -143,8 +159,9 @@ def convert_to_png_bytes(input_bytes, input_format, debug=False, debug_indent=0) elif input_format["format"] in ["jpeg", "jpg", "tiff", "bmp", "gif"]: if debug: - logging.debug(get_frontline(debug_indent) + - f"Converting image format '{input_format['format']}' to PNG using PIL") + logging.debug( + get_frontline(debug_indent) + f"Converting image format '{input_format['format']}' to PNG using PIL" + ) im = Image.open(io.BytesIO(input_bytes)) with io.BytesIO() as output: im.save(output, format="PNG") diff --git a/ocr/app/main.py b/ocr/app/main.py index 5967c5a..15767a9 100644 --- a/ocr/app/main.py +++ b/ocr/app/main.py @@ -1,7 +1,7 @@ import base64 -import json import logging import os +from typing import Any from dotenv import load_dotenv from fastapi import FastAPI, HTTPException, Request @@ -38,7 +38,7 @@ class IncomingFile(BaseModel): @field_validator("content") @classmethod - def validate_base64(cls, v): + def validate_base64(cls, v: str) -> str: try: base64.b64decode(v, validate=True) except Exception as e: @@ -46,7 +46,8 @@ def validate_base64(cls, v): raise ValueError(f"Invalid base64 content: {e}") from e return v -def strip_content(data): + +def strip_content(data: Any) -> Any: try: if "content" in data: data["content"] = "[SKIPPED]" @@ -55,9 +56,10 @@ def strip_content(data): data = "[NON-JSON BODY]" return data -def find_correct_backend_url(auth_header, format_id): + +def find_correct_backend_url(auth_header: str | None, format_id: int) -> str | None: global BACKEND_URL - if BACKEND_URL is None : + if BACKEND_URL is None: try: backend_base_url = os.getenv("BACKEND_BASE_URL_DOCKER") get_format(backend_base_url, auth_header, format_id=format_id) @@ -71,31 +73,32 @@ def find_correct_backend_url(auth_header, format_id): logging.critical("Failed to find backend URL: %s", e) return BACKEND_URL + @app.get("/health") -def health(): +def health() -> dict[str, str]: try: _ = get_model_list() return {"status": "ok"} except Exception as e: return {"status": "error", "detail": str(e)} + @app.post("/ocr/process") -async def handle_file(payload: IncomingFile, request: Request): - - raw = await request.body() - - logging.debug("Full request (content skipped): %s\n", strip_content(json.loads(raw))) - - logging.debug( - "Received file: id=%s, ownerId=%s formatId=%s generation=%s primaryFileId=%s model_id=%s size_b64=%d", - payload.id, - payload.ownerId, - payload.formatId, - payload.generation, - payload.primaryFileId, - payload.processingModelId, - len(payload.content), - ) +async def handle_file(payload: IncomingFile, request: Request) -> dict[str, Any]: + await request.body() + + # logging.info("Full request (content skipped): %s\n", strip_content(json.loads(raw))) + + # logging.info( + # "Received file: id=%s, ownerId=%s formatId=%s generation=%s primaryFileId=%s model_id=%s size_b64=%d", + # payload.id, + # payload.ownerId, + # payload.formatId, + # payload.generation, + # payload.primaryFileId, + # payload.processingModelId, + # len(payload.content), + # ) auth_header = request.headers.get("authorization") @@ -134,7 +137,6 @@ async def handle_file(payload: IncomingFile, request: Request): logging.info("Sent OCR result back to backend, got response: %s", strip_content(result)) - out_docx_format = get_format(backend_base_url, auth_header, format_name="docx") if not out_docx_format: logging.critical("DOCX format not found in backend formats") @@ -156,7 +158,7 @@ async def handle_file(payload: IncomingFile, request: Request): @app.get("/ocr/available_models") -def available_models(): +def available_models() -> list[dict[str, Any]]: try: return [{k: v for k, v in model.items() if k != "handle"} for model in get_model_list()] except Exception as e: diff --git a/ocr/app/module_loading.py b/ocr/app/module_loading.py index d31ce36..84f0248 100644 --- a/ocr/app/module_loading.py +++ b/ocr/app/module_loading.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path +from types import ModuleType @dataclass(frozen=True) @@ -14,7 +15,7 @@ class ModelSpec: handle: Callable -def load_module_from_path(path): +def load_module_from_path(path: Path) -> ModuleType: unique_name = f"_ocr_model_{path.stem}_{abs(hash(path.as_posix()))}" spec = spec_from_file_location(unique_name, path) diff --git a/ocr/app/ocr.py b/ocr/app/ocr.py index faabf9f..8d90771 100644 --- a/ocr/app/ocr.py +++ b/ocr/app/ocr.py @@ -2,20 +2,24 @@ import io import logging +from collections.abc import Callable from pathlib import Path +from typing import Any from PIL import Image try: from app.file_converter import initialize_pdf_with_image, insert_text_at_bbox, pdf_to_bytes, pdf_to_docx_bytes from app.module_loading import load_module_from_path - from app.segmentator import debug_save, segment + from app.postprocessor import postprocess + from app.segmentator import segment from app.utils import get_frontline except Exception: try: from file_converter import initialize_pdf_with_image, insert_text_at_bbox, pdf_to_bytes, pdf_to_docx_bytes from module_loading import load_module_from_path - from segmentator import debug_save, segment + from postprocessor import postprocess + from segmentator import segment from utils import get_frontline except Exception as e: raise ImportError("Failed to import necessary modules. Ensure the package structure is correct.") from e @@ -31,7 +35,7 @@ INVERT_THRESHOLD = 0.3 -def get_model_list(): +def get_model_list() -> list[dict[str, Any]]: global MODEL_LIST if MODEL_LIST is not None: return MODEL_LIST @@ -62,7 +66,7 @@ def get_model_list(): return models -def get_model_handler(id, debug=False, debug_indent=0): +def get_model_handler(id: int, debug: bool = False, debug_indent: int = 0) -> Callable[..., str]: global MODEL_LIST if debug: logging.debug(get_frontline(debug_indent) + f"Retrieving handler for model ID: {id}") @@ -85,7 +89,15 @@ def get_model_handler(id, debug=False, debug_indent=0): return default_handler -def run_ocr(png_bytes, model_id, image_visibility=False, one_liner=False, debug=False, debug_indent=0): +def run_ocr( + png_bytes: bytes, + model_id: int, + image_visibility: bool = False, + one_liner: bool = False, + debug: bool = False, + debug_indent: int = 0, + use_postprocessing: bool = True, +) -> tuple[bytes, bytes]: if debug: logging.debug(get_frontline(debug_indent) + f"Starting OCR with model ID: {model_id}") im = Image.open(io.BytesIO(png_bytes)) @@ -98,22 +110,22 @@ def run_ocr(png_bytes, model_id, image_visibility=False, one_liner=False, debug= logging.debug(get_frontline(debug_indent) + f"Inverting image (dark ratio: {dark_ratio:.2f})") im = Image.eval(im, lambda x: 255 - x) - if debug: - OUT_DIR.mkdir(parents=True, exist_ok=True) - (OUT_DIR / "debug_input.png").write_bytes(png_bytes) + # if debug: + # OUT_DIR.mkdir(parents=True, exist_ok=True) + # (OUT_DIR / "debug_input.png").write_bytes(png_bytes) if not one_liner: lines = segment(im, debug=debug, frontline=get_frontline(debug_indent + 1)) - if debug: - debug_save(im, lines, save_dir=OUT_DIR, frontline=get_frontline(debug_indent + 2)) - logging.debug(get_frontline(debug_indent + 1) + "Segmentation finished") + # if debug: + # debug_save(im, lines, save_dir=OUT_DIR, frontline=get_frontline(debug_indent + 2)) + # logging.debug(get_frontline(debug_indent + 1) + "Segmentation finished") else: lines = [{"bbox": (0, 0, im.width, im.height)}] ocr_handler = get_model_handler(model_id, debug=debug, debug_indent=debug_indent + 1) lines_data: list[str] = [] - pdf_path = OUT_DIR / "ocr_overlay.pdf" + # pdf_path = OUT_DIR / "ocr_overlay.pdf" if debug: logging.debug( get_frontline(debug_indent) + f"Initializing PDF document with image visibility set to: {image_visibility}" @@ -134,25 +146,36 @@ def run_ocr(png_bytes, model_id, image_visibility=False, one_liner=False, debug= lines_data.append({"text": line_txt, "bbox": item["bbox"]}) - # optional sorting for better context for postprocessing - TODO, maybe - # lines_data.sort(key=lambda x: (x["bbox"][1], x["bbox"][0])) - # implement postprocessing later - TODO - # lines_data = postprocess(lines_data) + if use_postprocessing: + lines_txt = [item["text"] for item in lines_data] + lines_txt = postprocess(lines_txt) + for i, item in enumerate(lines_data): + try: + item["text"] = lines_txt["lines"][i] + except Exception: + logging.info(f"Line {i} wasn't postprocessed") for item in lines_data: insert_text_at_bbox(pdf_doc, item["text"], item["bbox"], visible_image=image_visibility) - if debug: - pdf_doc.save(pdf_path) - logging.debug(get_frontline(debug_indent) + f"Saved OCR overlay PDF to: {pdf_path}") + # if debug: + # pdf_doc.save(pdf_path) + # logging.debug(get_frontline(debug_indent) + f"Saved OCR overlay PDF to: {pdf_path}") pdf_bytes = pdf_to_bytes(pdf_doc) docx_bytes = pdf_to_docx_bytes(pdf_doc) return pdf_bytes, docx_bytes -def test_ocr(test_image_path, model_id=1, one_liner=False, debug=True, debug_indent=0): +def test_ocr( + test_image_path: Path, + model_id: int = 1, + one_liner: bool = False, + debug: bool = True, + debug_indent: int = 0, + use_postprocessing: bool = True, +) -> None: if debug: logging.debug(get_frontline(debug_indent) + f"Testing OCR on image: {test_image_path}") png_bytes = test_image_path.read_bytes() @@ -161,10 +184,14 @@ def test_ocr(test_image_path, model_id=1, one_liner=False, debug=True, debug_ind if __name__ == "__main__": # test_ocr(Path(__file__).resolve().parent / "../model_training/data/0000.png", 1, True) + from run import setup_logging + + setup_logging() test_ocr( Path(__file__).resolve().parent / "../model_training/data/0000.png", model_id=1, # one_liner=True, debug=True, + use_postprocessing=False, ) # get_model_list() diff --git a/ocr/app/postprocessing_models_wrappers/gemma3.py b/ocr/app/postprocessing_models_wrappers/gemma3.py index a8eb0ca..bb9023e 100644 --- a/ocr/app/postprocessing_models_wrappers/gemma3.py +++ b/ocr/app/postprocessing_models_wrappers/gemma3.py @@ -1,12 +1,27 @@ +from typing import Any + from llama_cpp.llama_grammar import LlamaGrammar, json_schema_to_gbnf -from ocr.postprocessing.model.model import Model + +try: + from app.postprocessing_models_wrappers.model import Model +except Exception: + try: + from postprocessing_models_wrappers.model import Model + except Exception as e: + raise ImportError("Failed to import Model base class. Ensure the package structure is correct.") from e class Gemma3(Model): - def __init__(self, n_gpu_layers=31): + def __init__(self, n_gpu_layers: int = 10) -> None: super().__init__(filename="google_gemma-3-12b-it-qat-Q4_0.gguf", n_gpu_layers=n_gpu_layers) - def __call__(self, request, max_tokens=200, temperature=0.7, **kwargs): + def __call__( + self, + request: str, + max_tokens: int = 200, + temperature: float = 0.7, + **kwargs: Any, + ) -> Any: return self.llm.create_completion( prompt=request, max_tokens=max_tokens, diff --git a/ocr/app/postprocessing_models_wrappers/model.py b/ocr/app/postprocessing_models_wrappers/model.py index a436908..d717060 100644 --- a/ocr/app/postprocessing_models_wrappers/model.py +++ b/ocr/app/postprocessing_models_wrappers/model.py @@ -1,26 +1,30 @@ from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any from llama_cpp import Llama +PROJECT_DIR = Path(__file__).resolve().parent.parent.parent + class Model(ABC): - def __init__(self, filename, n_gpu_layers): + def __init__(self, filename: str, n_gpu_layers: int) -> None: self.llm = Llama( - model_path=str("models" / "postprocessing_models" / filename), - n_ctx=4096 * 2, + model_path=str(PROJECT_DIR / "models" / "postprocessing_models" / filename), + n_ctx=2048, n_threads=32, n_gpu_layers=n_gpu_layers, - n_batch=512, + n_batch=64, use_mmap=True, use_mlock=False, chat_format="gemma", - verbose=False, + verbose=True, stream=False, ) @abstractmethod - def __call__(self, request): + def __call__(self, request: str) -> Any: pass - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ diff --git a/ocr/app/postprocessor.py b/ocr/app/postprocessor.py index b8c9e93..b30bd77 100644 --- a/ocr/app/postprocessor.py +++ b/ocr/app/postprocessor.py @@ -1,16 +1,26 @@ import json +from typing import Any + +try: + from app.postprocessing_models_wrappers.gemma3 import Gemma3 +except Exception: + try: + from postprocessing_models_wrappers.gemma3 import Gemma3 + except Exception as e: + raise ImportError("Failed to import Gemma3 model wrapper. Ensure the package structure is correct.") from e -from postprocessing_models_wrappers.gemma3 import Gemma3 INITIAL_PROMPT = "Jesteś pomocnym asystentem, który poprawia tekst wyekstrahowany z obrazu." -def create_query_with_context(user_input: str, context: str): + +def create_query_with_context(user_input: str, context: str) -> str: query = f"user\n{INITIAL_PROMPT}\n\n \ model\n\nuser\nThis is context: {context}\n\n \ model\n\nuser\n{user_input}\n\nmodel" return query -def postprocess(lines: list[str]): + +def postprocess(lines: list[str]) -> dict[str, Any]: model = Gemma3() query = create_query_with_context( @@ -31,19 +41,12 @@ def postprocess(lines: list[str]): schema = { "type": "object", - "properties": { - "lines": { - "type": "array", - "items": {"type": "string"} - } - }, - "required": [ - "lines" - ], - "additionalProperties": False + "properties": {"lines": {"type": "array", "items": {"type": "string"}}}, + "required": ["lines"], + "additionalProperties": False, } order = ["lines"] response = model(query, schema=json.dumps(schema), order=order)["choices"][0]["text"] - return json.loads(response) \ No newline at end of file + return json.loads(response) diff --git a/ocr/app/run.py b/ocr/app/run.py index 09c07b2..f86bcb7 100644 --- a/ocr/app/run.py +++ b/ocr/app/run.py @@ -1,61 +1,67 @@ +from __future__ import annotations + import logging import uvicorn +from fastapi import FastAPI from main import app -def setup_logging() : - RESET = "\033[0m" - COLORS = { - "DEBUG": "\033[36m", # cyan - "INFO": "\033[32m", # green - "WARNING": "\033[33m", # yellow - "ERROR": "\033[31m", # red +def setup_logging() -> None: + RESET: str = "\033[0m" + COLORS: dict[str, str] = { + "DEBUG": "\033[36m", # cyan + "INFO": "\033[32m", # green + "WARNING": "\033[33m", # yellow + "ERROR": "\033[31m", # red "CRITICAL": "\033[41m\033[97m", # white on red background } - class ColorFormatter(logging.Formatter) : - def format(self, record) : - original_levelname = record.levelname + class ColorFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + original_levelname: str = record.levelname - color = COLORS.get(original_levelname, "") - padded = f"{original_levelname:<8}" + color: str = COLORS.get(original_levelname, "") + padded: str = f"{original_levelname:<8}" record.levelname = f"{color}{padded}{RESET}" - message = super().format(record) + message: str = super().format(record) record.levelname = original_levelname return message - fmt = "%(levelname)s %(asctime)s.%(msecs)03d %(message)s" - datefmt = "%Y-%m-%d %H:%M:%S" + fmt: str = "%(levelname)s %(asctime)s.%(msecs)03d %(message)s" + datefmt: str = "%Y-%m-%d %H:%M:%S" logging.basicConfig( level=logging.DEBUG, format=fmt, - datefmt=datefmt + datefmt=datefmt, ) root_logger = logging.getLogger() - for handler in root_logger.handlers : + for handler in root_logger.handlers: handler.setFormatter(ColorFormatter(fmt, datefmt)) -def main() : + +def main() -> None: setup_logging() - for name in ["uvicorn", "uvicorn.error", "uvicorn.access"] : + for name in ["uvicorn", "uvicorn.error", "uvicorn.access"]: logger = logging.getLogger(name) logger.handlers.clear() logger.propagate = True + fastapi_app: FastAPI = app + uvicorn.run( - app, + fastapi_app, host="0.0.0.0", port=8000, log_config=None, ) -if __name__ == "__main__" : +if __name__ == "__main__": main() diff --git a/ocr/app/segmentator.py b/ocr/app/segmentator.py index bbefbba..3433143 100644 --- a/ocr/app/segmentator.py +++ b/ocr/app/segmentator.py @@ -3,6 +3,7 @@ import logging from io import BytesIO from pathlib import Path +from typing import Any import numpy as np import torch @@ -11,8 +12,8 @@ from PIL import Image SCRIPT_DIR = Path(__file__).resolve().parent -MODEL_PATH = SCRIPT_DIR / ".." / "models" / "seg_best_submitted.mlmodel" -# MODEL_PATH = SCRIPT_DIR / ".." / "models" / "seg_best.mlmodel" +# MODEL_PATH = SCRIPT_DIR / ".." / "models" / "seg_best_trained.mlmodel" +MODEL_PATH = SCRIPT_DIR / ".." / "models" / "blla_submitted.mlmodel" _SEG_MODEL = None TEXT_DIRECTION = "horizontal-lr" @@ -21,7 +22,7 @@ BBOX_LINE_WIDTH = 5 -def silence_segmentation_logs(): +def silence_segmentation_logs() -> None: logging.getLogger("kraken.blla").setLevel(logging.ERROR) logging.getLogger("kraken").setLevel(logging.ERROR) @@ -31,7 +32,7 @@ def silence_segmentation_logs(): logging.getLogger("geos").setLevel(logging.ERROR) -def _ensure_pil_image(img): +def _ensure_pil_image(img: Image.Image | bytes | bytearray | memoryview) -> Image.Image: if isinstance(img, Image.Image): return img if isinstance(img, (bytes, bytearray, memoryview)): @@ -39,7 +40,7 @@ def _ensure_pil_image(img): raise TypeError("img must be a PIL.Image.Image or bytes") -def _load_seg_model(device, seg_model_path=MODEL_PATH): +def _load_seg_model(device: str | None, seg_model_path: Path = MODEL_PATH) -> Any: global _SEG_MODEL if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" @@ -54,7 +55,7 @@ def _load_seg_model(device, seg_model_path=MODEL_PATH): return _SEG_MODEL -def _bbox_from_line(line, im_w, im_h): +def _bbox_from_line(line: Any, im_w: int, im_h: int) -> tuple[int, int, int, int]: if hasattr(line, "bbox") and line.bbox is not None: x0, y0, x1, y1 = map(int, line.bbox) elif hasattr(line, "boundary") and line.boundary: @@ -83,14 +84,14 @@ def _bbox_from_line(line, im_w, im_h): def segment_lines_from_image( - img, + img: Image.Image | bytes | bytearray | memoryview, *, - device=None, - text_direction="horizontal-lr", - pad=0, - return_mode="pil", - seg_model_path=MODEL_PATH, -): + device: str | None = None, + text_direction: str = "horizontal-lr", + pad: int = 0, + return_mode: str = "pil", + seg_model_path: Path = MODEL_PATH, +) -> list[dict[str, Any]]: im = _ensure_pil_image(img).convert("RGB") if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" @@ -147,7 +148,13 @@ def segment_lines_from_image( return results -def segment(im, seg_model_path=MODEL_PATH, filter_warnings=False, debug=False, frontline=""): +def segment( + im: Image.Image | bytes | bytearray | memoryview, + seg_model_path: Path = MODEL_PATH, + filter_warnings: bool = False, + debug: bool = False, + frontline: str = "", +) -> list[dict[str, Any]]: if filter_warnings: silence_segmentation_logs() @@ -164,7 +171,12 @@ def segment(im, seg_model_path=MODEL_PATH, filter_warnings=False, debug=False, f return lines -def debug_save(im, lines, save_dir=SAVE_DIR, frontline=""): +def debug_save( + im: Image.Image, + lines: list[dict[str, Any]], + save_dir: Path = SAVE_DIR, + frontline: str = "", +) -> None: img_arr = np.array(im.convert("RGB")) logging.info(frontline + f"Found {len(lines)} lines:") @@ -234,8 +246,12 @@ def debug_save(im, lines, save_dir=SAVE_DIR, frontline=""): if __name__ == "__main__": + from run import setup_logging + + setup_logging() + print(getattr(vgsl.TorchVGSLModel.load_model(str(MODEL_PATH)).nn, "model_type", None)) - IMAGE_PATH = SCRIPT_DIR / ".." / "model_training" / "data" / "0001.png" + IMAGE_PATH = SCRIPT_DIR / ".." / "model_training" / "data" / "0000.png" if SAVE_DIR.exists(): for f in SAVE_DIR.iterdir(): @@ -249,5 +265,21 @@ def debug_save(im, lines, save_dir=SAVE_DIR, frontline=""): im = Image.open(IMAGE_PATH) + histogram = im.convert("L").histogram() + dark_ratio = sum(histogram[:128]) / sum(histogram) + + if dark_ratio > 0.3: + im = Image.eval(im, lambda x: 255 - x) + + print(im.size) + # # scale down image if too large + + # scale_factor = 50 / max(im.width, im.height) + # new_width = int(im.width * scale_factor) + # new_height = int(im.height * scale_factor) + # im = im.resize((new_width, new_height), Image.LANCZOS) + # print(f"Scaled down image to: {im.size}") + # im = im.convert("L") + lines = segment(im) debug_save(im, lines) diff --git a/ocr/app/utils.py b/ocr/app/utils.py index 8d7741e..690628a 100644 --- a/ocr/app/utils.py +++ b/ocr/app/utils.py @@ -1,5 +1,5 @@ INDENT_SIZE = 2 -def get_frontline(debug_indent): +def get_frontline(debug_indent: int) -> str: return "-" * (debug_indent * INDENT_SIZE - 1) + "> " if debug_indent > 0 else "" diff --git a/ocr/model_training/augment_data.py b/ocr/model_training/augment_data.py index e5e17a7..bb9d6be 100644 --- a/ocr/model_training/augment_data.py +++ b/ocr/model_training/augment_data.py @@ -1,6 +1,8 @@ import json import os +from collections.abc import Iterator from pathlib import Path +from typing import Any import numpy as np from PIL import Image @@ -12,7 +14,7 @@ DATA_DIR = os.path.join(SCRIPT_DIR, "data") -def get_next(mask): +def get_next(mask: list[int]) -> None: prev = -1 for i in range(len(mask) - 1, -2, -1): @@ -22,7 +24,7 @@ def get_next(mask): prev = mask.pop() -def iterate_through_masks(n): +def iterate_through_masks(n: int) -> Iterator[list[int]]: mask = [] while len(mask) <= n: yield mask.copy() @@ -30,14 +32,15 @@ def iterate_through_masks(n): break get_next(mask) -def generate_maskset(n, num_of_masks=MASKS_PER_IMAGE): + +def generate_maskset(n: int, num_of_masks: int = MASKS_PER_IMAGE) -> list[list[int]]: masks = [(mask, i) for i, mask in enumerate(iterate_through_masks(n)) if 0 < len(mask) < n] maskset = [] - for mask in masks : + for mask in masks: while len(maskset) <= (len(mask[0]) - 1): maskset.append([]) maskset[len(mask[0]) - 1].append(mask) - + out_maskset = [] masks_per_length = max(num_of_masks // (n - 1), 1) for lenset in maskset: @@ -46,20 +49,21 @@ def generate_maskset(n, num_of_masks=MASKS_PER_IMAGE): out_maskset.append(mask) if len(out_maskset) >= num_of_masks: break - + out_maskset.sort(key=lambda x: x[1]) - + out_maskset = [mask[0] for mask in out_maskset] - + return out_maskset -def handle_record(record): + +def handle_record(record: dict[str, Any]) -> int: print(f"Augmenting record: {record['name']}") output = 1 n = len(record["lines"]) print(f" - {n} lines found.") output += 1 - + image_file = os.path.join(SCRIPT_DIR, record["filepath"]) with Image.open(image_file) as im: im_arr = np.array(im) @@ -67,7 +71,6 @@ def handle_record(record): mask_number = 0 for mask in generate_maskset(n): - print(f" - Creating augmented image with lines: {mask}") output += 1 mask_number += 1 @@ -88,7 +91,6 @@ def handle_record(record): out_path = Path(DATA_DIR) / name out_im.save(out_path) - # open json file and append new entry there json_path = Path(JSON_DIR) / "dataset.json" with json_path.open("r", encoding="utf-8") as f: @@ -97,9 +99,10 @@ def handle_record(record): with json_path.open("w", encoding="utf-8") as f: json.dump(json_data, f, ensure_ascii=False, indent=4) - #return number of printed lines + # return number of printed lines return output + if __name__ == "__main__": - #test + # test print(generate_maskset(17)) diff --git a/ocr/model_training/gen_data.py b/ocr/model_training/gen_data.py index 7138f3d..acebcd7 100644 --- a/ocr/model_training/gen_data.py +++ b/ocr/model_training/gen_data.py @@ -2,6 +2,7 @@ import os import tkinter as tk from tkinter import messagebox, ttk +from typing import Any from PIL import Image, ImageTk @@ -15,7 +16,7 @@ IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"} -def collect_image_files(root_dir): +def collect_image_files(root_dir: str) -> list[str]: paths = [] for root, _, files in os.walk(root_dir): for fname in files: @@ -29,7 +30,13 @@ class ImageAnnotator: HANDLE_SIZE = 6 HANDLE_HIT = 8 - def __init__(self, master, image_paths, start_index, output_file_path=OUTPUT_FILE_PATH): + def __init__( + self, + master: tk.Tk, + image_paths: list[str], + start_index: int, + output_file_path: str = OUTPUT_FILE_PATH, + ) -> None: self.master = master self.master.title("Image annotator — Shift+Enter submit, Enter add line, drag to edit") self.image_paths = image_paths @@ -99,7 +106,7 @@ def __init__(self, master, image_paths, start_index, output_file_path=OUTPUT_FIL self._load_current() - def _fit_image(self, ow, oh, avail_w, avail_h): + def _fit_image(self, ow: int, oh: int, avail_w: int, avail_h: int) -> tuple[int, int]: if avail_w <= 0 or avail_h <= 0: return 0, 0 img_ratio = ow / max(1, oh) @@ -112,7 +119,7 @@ def _fit_image(self, ow, oh, avail_w, avail_h): new_w = int(new_h * img_ratio) return max(new_w, 1), max(new_h, 1) - def _display_image(self): + def _display_image(self) -> None: if self._orig_img is None: return cw = max(self.canvas.winfo_width(), 1) @@ -140,10 +147,10 @@ def _display_image(self): self._rect_id = None self._clicks = [] - def _on_canvas_resize(self, _): + def _on_canvas_resize(self, _: Any) -> None: self._display_image() - def _orig_to_disp(self, x, y): + def _orig_to_disp(self, x: int, y: int) -> tuple[int, int]: ow, oh = self._orig_img.size rw, rh = self._render_size off_x, off_y = self._render_offset @@ -151,7 +158,7 @@ def _orig_to_disp(self, x, y): sy = rh / oh return int(off_x + x * sx), int(off_y + y * sy) - def _disp_to_orig(self, x, y): + def _disp_to_orig(self, x: int, y: int) -> tuple[int, int]: ow, oh = self._orig_img.size rw, rh = self._render_size off_x, off_y = self._render_offset @@ -161,14 +168,14 @@ def _disp_to_orig(self, x, y): sy = oh / max(1, rh) return int(rx * sx), int(ry * sy) - def _constrain_to_image(self, x, y): + def _constrain_to_image(self, x: int, y: int) -> tuple[int, int]: off_x, off_y = self._render_offset rw, rh = self._render_size x = max(off_x, min(off_x + rw, x)) y = max(off_y, min(off_y + rh, y)) return x, y - def _draw_selection(self, x1, y1, x2, y2): + def _draw_selection(self, x1: int, y1: int, x2: int, y2: int) -> None: x1, x2 = sorted((x1, x2)) y1, y2 = sorted((y1, y2)) x1, y1 = self._constrain_to_image(x1, y1) @@ -194,7 +201,7 @@ def _draw_selection(self, x1, y1, x2, y2): self._handle_ids.append(hid) self._sel_rect = [x1, y1, x2, y2] - def _start_rubber_band(self, x, y): + def _start_rubber_band(self, x: int, y: int) -> int: self._clicks = [(x, y)] self.canvas.delete("temp_rect") @@ -203,7 +210,7 @@ def _start_rubber_band(self, x, y): ) return rect_id - def _on_canvas_motion(self, event): + def _on_canvas_motion(self, event: Any) -> None: if len(self._clicks) == 1 and self._orig_img is not None: x1, y1 = self._clicks[0] x2, y2 = self._constrain_to_image(event.x, event.y) @@ -221,7 +228,7 @@ def _on_canvas_motion(self, event): }.get(mode, "") self.canvas.configure(cursor=cursor) - def _on_canvas_click(self, event): + def _on_canvas_click(self, event: Any) -> None: if self._orig_img is None: return off_x, off_y = self._render_offset @@ -249,7 +256,7 @@ def _on_canvas_click(self, event): self.text.focus_set() self._clicks = [] - def _on_canvas_drag(self, event): + def _on_canvas_drag(self, event: Any) -> None: if self._sel_rect is None or self._drag_mode is None: return x1, y1, x2, y2 = self._sel_rect @@ -270,11 +277,11 @@ def _on_canvas_drag(self, event): x2 = ex self._draw_selection(x1, y1, x2, y2) - def _on_canvas_release(self, _event): + def _on_canvas_release(self, _event: Any) -> None: self._drag_mode = None self._drag_start = None - def _hit_test(self, x, y, x1, y1, x2, y2): + def _hit_test(self, x: int, y: int, x1: int, y1: int, x2: int, y2: int) -> str | None: x1, x2 = sorted((x1, x2)) y1, y2 = sorted((y1, y2)) h = self.HANDLE_HIT @@ -286,7 +293,7 @@ def _hit_test(self, x, y, x1, y1, x2, y2): return "move" return None - def _on_escape(self, _): + def _on_escape(self, _: Any) -> None: if self._rect_id is not None: self.canvas.delete(self._rect_id) self._rect_id = None @@ -299,7 +306,7 @@ def _on_escape(self, _): self.text.delete("1.0", "end") self.text.config(state="disabled") - def _on_text_enter(self, event): + def _on_text_enter(self, event: Any) -> str: if self._sel_rect is None: return "break" text = self.text.get("1.0", "end").strip() @@ -323,7 +330,7 @@ def _on_text_enter(self, event): self._display_image() return "break" - def _save_image_record(self): + def _save_image_record(self) -> bool: name = f"{self.current_pos:04d}" rel_filepath = f"data/{name}.png" abs_filepath = os.path.join(SCRIPT_DIR, rel_filepath) @@ -355,7 +362,7 @@ def _save_image_record(self): print(f"Saved record {name} with {len(self._lines)} line(s)") return True - def _on_submit(self, _event): + def _on_submit(self, _event: Any) -> str: try: self._on_text_enter(None) except Exception as e: @@ -367,7 +374,7 @@ def _on_submit(self, _event): self._goto_next_image() return "break" - def _goto_next_image(self): + def _goto_next_image(self) -> None: self.current_pos += 2 if self.current_pos >= len(self.image_paths): messagebox.showinfo("Done", "Reached the end of the dataset.") @@ -375,7 +382,7 @@ def _goto_next_image(self): else: self._load_current() - def _load_current(self): + def _load_current(self) -> None: self.current_image_path = self.image_paths[self.current_pos] self.path_var.set(self.current_image_path) try: @@ -395,7 +402,7 @@ def _load_current(self): self.master.after(10, self._display_image) -def main(): +def main() -> None: all_files = collect_image_files(DATA_DIR) print(f"Collected {len(all_files)} image files from dataset.") diff --git a/ocr/model_training/image_cropper.py b/ocr/model_training/image_cropper.py index 5fce2ba..dca9c7c 100644 --- a/ocr/model_training/image_cropper.py +++ b/ocr/model_training/image_cropper.py @@ -3,6 +3,7 @@ import json import os from pathlib import Path +from typing import Any import cv2 import numpy as np @@ -15,7 +16,7 @@ DOWNSCALE_FACTOR = 1 -def handle_record(item): +def handle_record(item: dict[str, Any]) -> int: filepath = os.path.join(SCRIPT_DIR, item["filepath"]) img_arr = np.array(Image.open(filepath)) @@ -27,6 +28,9 @@ def handle_record(item): max_y1 = min(max(ln["bbox"][3] for ln in lines), img_arr.shape[0] - 1) + 1 new_image = (np.random.power(50, size=img_arr.shape[:2]) * 255).astype(np.uint8) + # new_image = np.zeros_like(img_arr) + # create white new image + # new_image = np.ones_like(img_arr) * 255 new_image[min_y0:max_y1, min_x0:max_x1] = img_arr[min_y0:max_y1, min_x0:max_x1] img_crop = new_image @@ -45,7 +49,7 @@ def handle_record(item): return 0 -def main(): +def main() -> None: os.makedirs(DATA_DIR, exist_ok=True) ds = json.loads(Path(JSON_PATH).read_text(encoding="utf-8")) for item in ds: diff --git a/ocr/model_training/json_to_page.py b/ocr/model_training/json_to_page.py index e8ed41e..555a3a7 100644 --- a/ocr/model_training/json_to_page.py +++ b/ocr/model_training/json_to_page.py @@ -3,6 +3,7 @@ import json import os from pathlib import Path +from typing import Any from lxml import etree from PIL import Image @@ -15,7 +16,7 @@ NSMAP = {None: PAGE_NS} -def clamp_bbox(b, w, h): +def clamp_bbox(b: list[int], w: int, h: int) -> tuple[int, int, int, int]: x0, y0, x1, y1 = b x0 = max(0, min(x0, w)) x1 = max(0, min(x1, w)) @@ -28,7 +29,7 @@ def clamp_bbox(b, w, h): return x0, y0, x1, y1 -def make_page_xml(image_path, img_w, img_h, lines): +def make_page_xml(image_path: str, img_w: int, img_h: int, lines: list[dict[str, Any]]) -> bytes: root = etree.Element("PcGts", nsmap=NSMAP) page = etree.SubElement(root, "Page", imageFilename=image_path, imageWidth=str(img_w), imageHeight=str(img_h)) region = etree.SubElement(page, "TextRegion", id="r1") @@ -46,7 +47,7 @@ def make_page_xml(image_path, img_w, img_h, lines): x0, y0, x1, y1 = ln["bbox"] tl = etree.SubElement(region, "TextLine", id=f"l{i:04d}") etree.SubElement(tl, "Coords", points=f"{x0},{y0} {x1},{y0} {x1},{y1} {x0},{y1}") - etree.SubElement(tl, "Baseline", points=f"{x0},{y1} {x1},{y1}") + etree.SubElement(tl, "Baseline", points=f"{x0},{(y0 + y1) / 2} {x1},{(y0 + y1) / 2}") if "text" in ln and ln["text"] is not None: te = etree.SubElement(tl, "TextEquiv") uni = etree.SubElement(te, "Unicode") @@ -54,7 +55,7 @@ def make_page_xml(image_path, img_w, img_h, lines): return etree.tostring(root, pretty_print=True, xml_declaration=True, encoding="UTF-8") -def save_line_crops(base_out, name, img, lines): +def save_line_crops(base_out: Path, name: str, img: Image.Image, lines: list[dict[str, Any]]) -> None: out_dir = base_out / name out_dir.mkdir(parents=True, exist_ok=True) for i, ln in enumerate(lines): @@ -66,7 +67,7 @@ def save_line_crops(base_out, name, img, lines): txt_path.write_text(ln.get("text", ""), encoding="utf-8") -def handle_record(item): +def handle_record(item: dict[str, Any]) -> int: rel_path = item["filepath"] img_path = (Path(DATA_DIR) / Path(rel_path).name).resolve() name = item.get("name") or Path(rel_path).stem @@ -84,7 +85,7 @@ def handle_record(item): return 0 -def main(): +def main() -> None: os.makedirs(DATA_DIR, exist_ok=True) ds = json.loads(Path(JSON_PATH).read_text(encoding="utf-8")) for item in ds: diff --git a/ocr/model_training/preprocess_data.py b/ocr/model_training/preprocess_data.py index 9e84ec7..b205e29 100644 --- a/ocr/model_training/preprocess_data.py +++ b/ocr/model_training/preprocess_data.py @@ -1,4 +1,5 @@ import json +import subprocess import time from pathlib import Path @@ -11,16 +12,17 @@ JSON_PATH = (SCRIPT_DIR / "input" / "dataset.json").resolve() - -def clear_n_lines(number_of_lines): +def clear_n_lines(number_of_lines: int) -> None: for i in range(number_of_lines): print("\033[F\033[K", end="") - time.sleep(((2*i)/(number_of_lines*(number_of_lines+1))) * 0.2) + time.sleep(((2 * i) / (number_of_lines * (number_of_lines + 1))) * 0.2) -def main(): +def main() -> None: start_time = time.time() - + + subprocess.run(["./load_backups.sh"], cwd=str(SCRIPT_DIR), check=True) + with open(JSON_PATH, encoding="utf-8") as f: original_data = json.load(f) original_data.sort(key=lambda x: int(x["name"])) @@ -41,8 +43,8 @@ def main(): lines_printed += 1 with open(JSON_PATH, encoding="utf-8") as f: json_data = json.load(f) - - try : + + try: record = json_data[record_index] except IndexError: break diff --git a/ocr/model_training/seg_train_pipeline.py b/ocr/model_training/seg_train_pipeline.py index 1d82eb8..cfacff5 100644 --- a/ocr/model_training/seg_train_pipeline.py +++ b/ocr/model_training/seg_train_pipeline.py @@ -7,7 +7,7 @@ SCRIPT_DIR = Path(__file__).resolve().parent -def main(): +def main() -> None: print("Loading backups...") subprocess.run(["./load_backups.sh"], cwd=SCRIPT_DIR) print("Preprocessing data...") diff --git a/ocr/model_training/test_ocr_model.py b/ocr/model_training/test_ocr_model.py index 7895e33..cc815f6 100644 --- a/ocr/model_training/test_ocr_model.py +++ b/ocr/model_training/test_ocr_model.py @@ -2,11 +2,12 @@ import json import multiprocessing as mp -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import dataclass from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType +from typing import Any from PIL import Image from rapidfuzz.distance import Levenshtein @@ -33,32 +34,42 @@ def load_module_from_path(path: Path) -> ModuleType: return module -MODEL_HANDLER_PATH = Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "handlers" / "kraken.py" -MODEL_PATH = Path(__file__).resolve().parent / ".." / "models" / "ocr_best_ketos.mlmodel" +MODEL_HANDLER_PATH = Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "handlers" / "0kraken.py" +MODEL_PATH = Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "ocr_best_submitted.mlmodel" JSON_PATH = Path(__file__).resolve().parent / "input" / "dataset.json" +SCRIPT_DIR = Path(__file__).resolve().parent -def normalize_text(text): +OUT_FILE = SCRIPT_DIR / "test_results" / "ocr_test_results.json" +SAVED_DATA = {} + + +def normalize_text(text: str) -> str: return " ".join((text or "").split()).strip() -def load_dataset(json_path: Path = JSON_PATH): +def load_dataset(json_path: Path = JSON_PATH) -> list[dict[str, Any]]: with JSON_PATH.open("r", encoding="utf-8") as f: dataset_info = json.load(f) return dataset_info def test_model( - model_path=MODEL_PATH, model_handler=MODEL_HANDLER_PATH, tests=100, concurrently=False, dataset_info=None -): + model_path: Path = MODEL_PATH, + model_handler: Path = MODEL_HANDLER_PATH, + tests: int = 100, + concurrently: bool = False, + dataset_info: list[dict[str, Any]] | None = None, +) -> Iterator[tuple[float, float] | None]: module = load_module_from_path(model_handler) module.load(model_path) handle_func = module.handle + if dataset_info is None: dataset_info = load_dataset(JSON_PATH) counter = 0 - score = 0 + score = 0.0 log_length = 0 if not concurrently: @@ -67,6 +78,10 @@ def test_model( for item in dataset_info: img = Image.open(Path(__file__).resolve().parent / item["filepath"]) + + if "aug" in item["name"]: + continue + for line in item["lines"]: line_img = img.crop((line["bbox"][0], line["bbox"][1], line["bbox"][2], line["bbox"][3])) @@ -76,14 +91,18 @@ def test_model( got = normalize_text(result) lev_distance = float(Levenshtein.distance(expected, got)) - score += 1 - (lev_distance / max(1, len(expected))) + model_numerical_result = 1 - (lev_distance / max(1, len(expected))) + + score += model_numerical_result + running_avg = score / (counter + 1) - log = f" Test {counter + 1}, current model score: {score / (counter + 1):.4f}" + log = f" Test {counter + 1}, current model score: {running_avg:.4f}" log_length = max(log_length, len(log)) + if not concurrently: print(log + " " * max((log_length - len(log), 0)), end="\r") else: - yield score / (counter + 1) + yield (model_numerical_result, running_avg) counter += 1 @@ -96,27 +115,44 @@ def test_model( if not concurrently: print(" " * log_length, end="\r") print(f"After {counter} tests") - print(log) else: yield None -def _run_model_in_process(model_path, model_handler, tests, queue): - for result in test_model(model_path, model_handler, tests, concurrently=True): - queue.put(result) - queue.put(None) - +def _run_model_in_process( + model_path: Path, + model_handler: Path, + tests: int, + queue: mp.Queue, + model_name: str, +) -> None: + for payload in test_model( + model_path, + model_handler, + tests, + concurrently=True, + ): + queue.put((model_name, payload)) + queue.put((model_name, None)) + + +def test_models_concurrently( + model_paths: list[tuple[Path, Path]], tests_per_model: int | None = 2, one_line: bool = True +) -> None: + global SAVED_DATA + SAVED_DATA = {mp[0].name: [] for mp in model_paths} -def test_models_concurrently(model_paths, tests_per_model=2, one_line=True): model_paths.sort(key=lambda p: len(p[0].name), reverse=True) column_length = len(model_paths[0][0].name) if model_paths else 10 dataset_info = load_dataset(JSON_PATH) - lines_num = sum(len(record["lines"]) for record in dataset_info) + lines_num = sum(len(record["lines"]) for record in dataset_info if "aug" not in record["name"]) if tests_per_model is None or tests_per_model > lines_num: tests_per_model = lines_num + print(f"Each model will be tested on {tests_per_model} lines.") + print("Starting concurrent model tests...") frontline = " " * 6 header = frontline @@ -134,7 +170,10 @@ def test_models_concurrently(model_paths, tests_per_model=2, one_line=True): processes = [] for (model_path, handler), queue in zip(model_paths, queues, strict=False): - p = mp.Process(target=_run_model_in_process, args=(model_path, handler, tests_per_model, queue)) + p = mp.Process( + target=_run_model_in_process, + args=(model_path, handler, tests_per_model, queue, model_path.name), + ) p.start() processes.append(p) @@ -149,13 +188,17 @@ def test_models_concurrently(model_paths, tests_per_model=2, one_line=True): row.append(prev_results[i]) continue - r = queue.get() - if r is None: + msg_model, payload = queue.get() + + if payload is None: finished[i] = True row.append(prev_results[i]) else: - prev_results[i] = r - row.append(r) + model_numerical_result, running_avg = payload + SAVED_DATA[msg_model].append(model_numerical_result) + + prev_results[i] = running_avg + row.append(running_avg) counter += 1 line = frontline @@ -172,16 +215,25 @@ def test_models_concurrently(model_paths, tests_per_model=2, one_line=True): for p in processes: p.join() - print(line + f"| {counter} / {tests_per_model + 1}") + if one_line: + print(line + f"| {counter} / {tests_per_model + 1}") + OUT_FILE.parent.mkdir(parents=True, exist_ok=True) + with OUT_FILE.open("w", encoding="utf-8") as f: + json.dump(SAVED_DATA, f, indent=2, ensure_ascii=False) if __name__ == "__main__": test_models_concurrently( [ (MODEL_PATH, MODEL_HANDLER_PATH), - (Path(__file__).resolve().parent / ".." / "models" / "en_best.mlmodel", MODEL_HANDLER_PATH), + (Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "en_best.mlmodel", MODEL_HANDLER_PATH), (Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "kraken.mlmodel", MODEL_HANDLER_PATH), + (Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "ocr_best.mlmodel", MODEL_HANDLER_PATH), + ( + Path(__file__).resolve().parent / ".." / "models" / "ocr_models" / "ocr_best_ketos.mlmodel", + MODEL_HANDLER_PATH, + ), ], - tests_per_model=1000, + tests_per_model=float("inf"), one_line=True, ) diff --git a/ocr/model_training/test_seg_model.py b/ocr/model_training/test_seg_model.py index 363d170..3279f5d 100644 --- a/ocr/model_training/test_seg_model.py +++ b/ocr/model_training/test_seg_model.py @@ -2,11 +2,12 @@ import json import multiprocessing as mp -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import dataclass from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType +from typing import Any import numpy as np from PIL import Image @@ -23,6 +24,9 @@ class SegModelSpec: SCRIPT_DIR = Path(__file__).resolve().parent JSON_PATH = SCRIPT_DIR / "input" / "dataset.json" +OUT_FILE = SCRIPT_DIR / "test_results" / "segmentation_test_results_new.json" +SAVED_DATA = {} + def load_module_from_path(path: Path) -> ModuleType: unique_name = f"_seg_model_{path.stem}_{abs(hash(path.as_posix()))}" @@ -37,13 +41,17 @@ def load_module_from_path(path: Path) -> ModuleType: return module -def load_dataset(json_path: Path = JSON_PATH): +def load_dataset(json_path: Path = JSON_PATH) -> list[dict[str, Any]]: with json_path.open("r", encoding="utf-8") as f: data = json.load(f) return data -def segmentation_metric(gt_lines, pred_lines, image): +def segmentation_metric( + gt_lines: list[dict[str, Any]], + pred_lines: list[dict[str, Any]], + image: Image.Image, +) -> float: h, w = image.size[1], image.size[0] gt_mask = np.zeros((h, w), dtype=np.uint8) pred_mask = np.zeros((h, w), dtype=np.uint8) @@ -102,8 +110,12 @@ def _load_seg_model_in_handler(handler_module: ModuleType, model_path: Path) -> def test_seg_model( - model_path: Path, handler_path: Path, max_pages: int | None = None, concurrently: bool = False, dataset=None -): + model_path: Path, + handler_path: Path, + max_pages: int | None = None, + concurrently: bool = False, + dataset: list[dict[str, Any]] | None = None, +) -> Iterator[tuple[float, float] | None]: handler_module = load_module_from_path(handler_path) _load_seg_model_in_handler(handler_module, model_path) segment_func = _get_segment_function(handler_module) @@ -124,7 +136,6 @@ def test_seg_model( image = Image.open(img_path).convert("RGB") pred_lines = segment_func(image, seg_model_path=model_path, filter_warnings=True) - gt_lines = item["lines"] page_score = segmentation_metric(gt_lines, pred_lines, image) @@ -138,7 +149,7 @@ def test_seg_model( if not concurrently: print(log + " " * max((log_length - len(log), 0)), end="\r") else: - yield avg_score + yield (page_score, avg_score) if max_pages is not None and page_count >= max_pages: break @@ -151,19 +162,29 @@ def test_seg_model( yield None -def _run_seg_model_in_process(model_path, handler_path, max_pages, queue): - for result in test_seg_model(model_path, handler_path, max_pages, concurrently=True): - queue.put(result) - queue.put(None) +def _run_seg_model_in_process( + model_path: Path, + handler_path: Path, + max_pages: int | None, + queue: mp.Queue, + model_name: str, +) -> None: + for payload in test_seg_model(model_path, handler_path, max_pages, concurrently=True): + queue.put((model_name, payload)) + queue.put((model_name, None)) def test_seg_models_concurrently( - model_specs, + model_specs: list[tuple[Path, Path]], pages_per_model: int | None = None, one_line: bool = True, -): +) -> None: + global SAVED_DATA normalized = [(Path(m), Path(h)) for (m, h) in model_specs] normalized.sort(key=lambda p: len(p[0].name), reverse=True) + + SAVED_DATA = {model_path.name: [] for (model_path, _) in normalized} + column_length = len(normalized[0][0].name) if normalized else 10 dataset = load_dataset(JSON_PATH) @@ -190,7 +211,7 @@ def test_seg_models_concurrently( for (model_path, handler_path), queue in zip(normalized, queues, strict=False): p = mp.Process( target=_run_seg_model_in_process, - args=(model_path, handler_path, pages_per_model, queue), + args=(model_path, handler_path, pages_per_model, queue, model_path.name), ) p.start() processes.append(p) @@ -206,13 +227,17 @@ def test_seg_models_concurrently( row.append(prev_results[i]) continue - r = queue.get() - if r is None: + msg_model, payload = queue.get() + + if payload is None: finished[i] = True row.append(prev_results[i]) else: - prev_results[i] = r - row.append(r) + page_score, avg_score = payload + SAVED_DATA[msg_model].append(page_score) + + prev_results[i] = avg_score + row.append(avg_score) iteration += 1 line = frontline @@ -232,6 +257,10 @@ def test_seg_models_concurrently( if one_line: print(line + f"| {iteration} / {pages_per_model + 1}") + OUT_FILE.parent.mkdir(parents=True, exist_ok=True) + with OUT_FILE.open("w", encoding="utf-8") as f: + json.dump(SAVED_DATA, f, indent=2, ensure_ascii=False) + if __name__ == "__main__": SEG_HANDLER_PATH = SCRIPT_DIR / ".." / "app" / "segmentator.py" @@ -240,7 +269,9 @@ def test_seg_models_concurrently( [ (SCRIPT_DIR / ".." / "models" / "seg_best.mlmodel", SEG_HANDLER_PATH), (SCRIPT_DIR / ".." / "models" / "seg_best_submitted.mlmodel", SEG_HANDLER_PATH), + (SCRIPT_DIR / ".." / "models" / "seg_best_old.mlmodel", SEG_HANDLER_PATH), + (SCRIPT_DIR / ".." / "models" / "blla.mlmodel", SEG_HANDLER_PATH), ], - pages_per_model=1000, + pages_per_model=float("inf"), one_line=True, ) diff --git a/ocr/model_training/train_ocr.py b/ocr/model_training/train_ocr.py index 9c009b4..23192f0 100644 --- a/ocr/model_training/train_ocr.py +++ b/ocr/model_training/train_ocr.py @@ -6,6 +6,7 @@ import re import shutil import subprocess +from collections.abc import Iterable from pathlib import Path from utils.memory_info import main as print_memory_info @@ -35,12 +36,12 @@ MIN_EPOCHS = max(MIN_EPOCHS, EARLY_STOPPING) -def natural_epoch_sort_key(p): +def natural_epoch_sort_key(p: Path) -> int: m = re.search(r"model_(\d+)\.mlmodel$", p.name) return int(m.group(1)) if m else -1 -def promote(run_dir, dest_path): +def promote(run_dir: Path | str, dest_path: Path | str) -> Path: run_dir = Path(run_dir).resolve() dest_path = Path(dest_path).resolve() dest_path.parent.mkdir(parents=True, exist_ok=True) @@ -58,7 +59,7 @@ def promote(run_dir, dest_path): return best_path -def collect_xml_files(data_dir): +def collect_xml_files(data_dir: Path) -> list[Path]: xml_paths = sorted(glob.glob(str(data_dir / "*.xml"))) xml_paths = [p for p in xml_paths if "aug" not in Path(p).name] if not xml_paths: @@ -66,7 +67,9 @@ def collect_xml_files(data_dir): return [Path(p) for p in xml_paths] -def train_val_split(xml_files, seed=SEED, val_ratio=VAL_RATIO): +def train_val_split( + xml_files: Iterable[Path], seed: int = SEED, val_ratio: float = VAL_RATIO +) -> tuple[list[Path], list[Path]]: xml_files = list(xml_files) random.seed(seed) random.shuffle(xml_files) @@ -81,7 +84,7 @@ def train_val_split(xml_files, seed=SEED, val_ratio=VAL_RATIO): return train_xml, val_xml -def write_manifest(paths, dest): +def write_manifest(paths: Iterable[Path], dest: Path) -> Path: dest = dest.resolve() dest.parent.mkdir(parents=True, exist_ok=True) with dest.open("w", encoding="utf-8") as f: @@ -91,17 +94,17 @@ def write_manifest(paths, dest): def run_ketos_train( - data_dir=DATA_DIR, - out_root=OUT_DIR, - format_type=FORMAT_TYPE, - base_model=BASE_MODEL, - min_epochs=MIN_EPOCHS, - max_epochs=MAX_EPOCHS, - batch_size=BATCH_SIZE, - early_stopping=EARLY_STOPPING, - device=DEVICE, - val_ratio=VAL_RATIO, -): + data_dir: Path = DATA_DIR, + out_root: Path = OUT_DIR, + format_type: str = FORMAT_TYPE, + base_model: Path | None = BASE_MODEL, + min_epochs: int = MIN_EPOCHS, + max_epochs: int = MAX_EPOCHS, + batch_size: int = BATCH_SIZE, + early_stopping: int = EARLY_STOPPING, + device: str = DEVICE, + val_ratio: float = VAL_RATIO, +) -> Path: data_dir = Path(data_dir).resolve() out_root = Path(out_root).resolve() @@ -165,7 +168,7 @@ def run_ketos_train( return run_dir -def main(): +def main() -> None: print_memory_info() run_dir = run_ketos_train( diff --git a/ocr/model_training/train_segmentator.py b/ocr/model_training/train_segmentator.py index 4f64b46..dd5fa80 100644 --- a/ocr/model_training/train_segmentator.py +++ b/ocr/model_training/train_segmentator.py @@ -6,6 +6,7 @@ import re import shutil import subprocess +from collections.abc import Iterable from pathlib import Path from utils.memory_info import main as print_memory_info @@ -39,12 +40,12 @@ USE_AUGMENT = True -def natural_epoch_sort_key(p): +def natural_epoch_sort_key(p: Path) -> int: m = re.search(r"model_(\d+)\.mlmodel$", p.name) return int(m.group(1)) if m else -1 -def promote(run_dir, dest_path): +def promote(run_dir: Path | str, dest_path: Path | str) -> Path: run_dir = Path(run_dir).resolve() dest_path = Path(dest_path).resolve() dest_path.parent.mkdir(parents=True, exist_ok=True) @@ -66,14 +67,16 @@ def promote(run_dir, dest_path): return best_path -def collect_xml_files(data_dir): +def collect_xml_files(data_dir: Path) -> list[Path]: xml_paths = sorted(glob.glob(str(data_dir / "*.xml"))) if not xml_paths: raise FileNotFoundError(f"No XML files found in {data_dir}") return [Path(p) for p in xml_paths] -def train_val_split(xml_files, seed=SEED, val_ratio=VAL_RATIO): +def train_val_split( + xml_files: Iterable[Path], seed: int = SEED, val_ratio: float = VAL_RATIO +) -> tuple[list[Path], list[Path]]: xml_files = list(xml_files) random.seed(seed) random.shuffle(xml_files) @@ -88,7 +91,7 @@ def train_val_split(xml_files, seed=SEED, val_ratio=VAL_RATIO): return train_xml, val_xml -def write_manifest(paths, dest): +def write_manifest(paths: Iterable[Path], dest: Path) -> Path: dest = dest.resolve() dest.parent.mkdir(parents=True, exist_ok=True) with dest.open("w", encoding="utf-8") as f: @@ -98,19 +101,19 @@ def write_manifest(paths, dest): def run_ketos_segtrain( - data_dir=DATA_DIR, - out_root=OUT_DIR, - format_type=FORMAT_TYPE, - base_model=BASE_MODEL, - min_epochs=MIN_EPOCHS, - max_epochs=MAX_EPOCHS, - early_stopping=EARLY_STOPPING, - device=DEVICE, - val_ratio=VAL_RATIO, - schedule=SCHEDULE, - workers=WORKERS, - use_augment=USE_AUGMENT, -): + data_dir: Path = DATA_DIR, + out_root: Path = OUT_DIR, + format_type: str = FORMAT_TYPE, + base_model: Path | None = BASE_MODEL, + min_epochs: int = MIN_EPOCHS, + max_epochs: int = MAX_EPOCHS, + early_stopping: int = EARLY_STOPPING, + device: str = DEVICE, + val_ratio: float = VAL_RATIO, + schedule: str = SCHEDULE, + workers: int = WORKERS, + use_augment: bool = USE_AUGMENT, +) -> Path: data_dir = Path(data_dir).resolve() out_root = Path(out_root).resolve() @@ -175,7 +178,7 @@ def run_ketos_segtrain( return run_dir -def main(): +def main() -> None: print_memory_info() run_dir = run_ketos_segtrain( diff --git a/ocr/model_training/train_segmentator_old.py b/ocr/model_training/train_segmentator_old.py index 27ec7bb..6d662a6 100644 --- a/ocr/model_training/train_segmentator_old.py +++ b/ocr/model_training/train_segmentator_old.py @@ -2,159 +2,201 @@ import datetime import glob -import json import random import re import shutil +import subprocess +from collections.abc import Iterable from pathlib import Path -import torch -from kraken.lib.train import KrakenTrainer, SegmentationModel -from PIL import Image from utils.memory_info import main as print_memory_info -print_memory_info() - SCRIPT_DIR = Path(__file__).resolve().parent + DATA_DIR = (SCRIPT_DIR / "data").resolve() + OUT_DIR = (SCRIPT_DIR / "train_runs").resolve() OUT_DIR.mkdir(parents=True, exist_ok=True) -MIN_EPOCHS = 5 -MAX_EPOCHS = 55 +BEST_MODEL_PATH = (SCRIPT_DIR / ".." / "models" / "seg_best.mlmodel").resolve() + +FORMAT_TYPE = "page" + +BASE_MODEL = (SCRIPT_DIR / ".." / "models" / "blla.mlmodel").resolve() +# BASE_MODEL = None + +MIN_EPOCHS = 10 +MAX_EPOCHS = 100 +EARLY_STOPPING = 10 SEED = 42 +VAL_RATIO = 0.2 +DEVICE = "cuda:0" +MIN_EPOCHS = max(MIN_EPOCHS, EARLY_STOPPING) -def natural_epoch_sort_key(p): - m = re.search(r"model_(\d+)\.mlmodel$", p.name) - return int(m.group(1)) if m else -1 +SCHEDULE = "cosine" +WORKERS = 8 +USE_AUGMENT = True -def load_metric_sibling(model_path): - for suf in (".json", ".metrics.json", ".report.json"): - p = model_path.with_suffix(suf) - if p.exists(): - try: - with p.open("r", encoding="utf-8") as f: - data = json.load(f) - for key in ("val_freq_iu", "freq_iu", "val_mean_iu", "mean_iu"): - if key in data and isinstance(data[key], (int, float)): - return float(data[key]) - except Exception: - pass - return None +def natural_epoch_sort_key(p: Path) -> int: + m = re.search(r"model_(\d+)\.mlmodel$", p.name) + return int(m.group(1)) if m else -1 -def debug_dump_gt(model, run_dir): - try: - dl = model._make_dataloader(split="train", shuffle=False) - xb, yb = next(iter(dl)) - gt = yb.argmax(1)[0].cpu().numpy() - Image.fromarray((gt > 0).astype("uint8") * 255).save(run_dir / "_debug_gt.png") - print(f"[debug] GT sum (first sample): {gt.sum()}") - except Exception as e: - print(f"[debug] Failed to dump GT: {e}") - - -def train( - data_dir=DATA_DIR, - out_root=OUT_DIR, - min_epochs=MIN_EPOCHS, - max_epochs=MAX_EPOCHS, - seed=SEED, - overfit_one=False, -): - all_xml = sorted(glob.glob(str((data_dir).resolve() / "*.xml"))) - assert all_xml, f"No XML files found in {data_dir}" - random.seed(seed) - random.shuffle(all_xml) - split = max(1, int(0.1 * len(all_xml))) - val_xml = all_xml[:split] - train_xml = all_xml[split:] - if overfit_one: - val_xml = train_xml[:] +def promote(run_dir: Path | str, dest_path: Path | str) -> Path: + run_dir = Path(run_dir).resolve() + dest_path = Path(dest_path).resolve() + dest_path.parent.mkdir(parents=True, exist_ok=True) + best_path = run_dir / "model_best.mlmodel" + if not best_path.exists(): + print(f"{best_path} not found - taking latest model instead") + model_paths = sorted( + run_dir.glob("model_*.mlmodel"), + key=natural_epoch_sort_key, + reverse=True, + ) + if not model_paths: + raise FileNotFoundError(f"No model_*.mlmodel files found in {run_dir}") + best_path = model_paths[0] + + shutil.copy2(best_path, dest_path) + print(f"Promoted segmenter: {best_path.name} → {dest_path}") + return best_path + + +def collect_xml_files(data_dir: Path) -> list[Path]: + xml_paths = sorted(glob.glob(str(data_dir / "*.xml"))) + if not xml_paths: + raise FileNotFoundError(f"No XML files found in {data_dir}") + return [Path(p) for p in xml_paths] + + +def train_val_split( + xml_files: Iterable[Path], seed: int = SEED, val_ratio: float = VAL_RATIO +) -> tuple[list[Path], list[Path]]: + xml_files = list(xml_files) + random.seed(seed) + random.shuffle(xml_files) + + split_idx = max(1, int(len(xml_files) * val_ratio)) + val_xml = xml_files[:split_idx] + train_xml = xml_files[split_idx:] + + if not train_xml: + raise RuntimeError("Train split is empty – too high VAL_RATIO for this dataset") + + return train_xml, val_xml + + +def write_manifest(paths: Iterable[Path], dest: Path) -> Path: + dest = dest.resolve() + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("w", encoding="utf-8") as f: + for p in paths: + f.write(str(p) + "\n") + return dest + + +def run_ketos_segtrain( + data_dir: Path = DATA_DIR, + out_root: Path = OUT_DIR, + format_type: str = FORMAT_TYPE, + base_model: Path | None = BASE_MODEL, + min_epochs: int = MIN_EPOCHS, + max_epochs: int = MAX_EPOCHS, + early_stopping: int = EARLY_STOPPING, + device: str = DEVICE, + val_ratio: float = VAL_RATIO, + schedule: str = SCHEDULE, + workers: int = WORKERS, + use_augment: bool = USE_AUGMENT, +) -> Path: + data_dir = Path(data_dir).resolve() + out_root = Path(out_root).resolve() + + xml_files = collect_xml_files(data_dir) + print(f"Found {len(xml_files)} XML files in {data_dir}") + train_xml, val_xml = train_val_split(xml_files, val_ratio=val_ratio) print(f"train: {len(train_xml)} | val: {len(val_xml)}") timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") run_dir = (out_root / f"seg_{timestamp}").resolve() run_dir.mkdir(parents=True, exist_ok=True) - SEG_SPEC = "[1,900,0,3 Cr7,7,32,2,2 Gn16 Cr3,3,64,2,2 Gn16 Cr3,3,64 Gn16 Cr3,3,128 Gn16]" - - model = SegmentationModel( - training_data=train_xml, - evaluation_data=val_xml, - format_type="xml", - output=str(run_dir / "model"), - spec=SEG_SPEC, - num_workers=4, - ) - - trainer = KrakenTrainer( - default_root_dir=str(run_dir), - min_epochs=min_epochs, - max_epochs=max_epochs, - enable_progress_bar=True, - precision="bf16-mixed" if torch.cuda.is_available() else 32, - limit_val_batches=1, - log_every_n_steps=1, - ) - - if torch.cuda.is_available(): - torch.set_float32_matmul_precision("high") - - debug_dump_gt(model, run_dir) + train_manifest = write_manifest(train_xml, run_dir / "train.txt") + val_manifest = write_manifest(val_xml, run_dir / "val.txt") + + cmd = [ + "ketos", + "segtrain", + "-f", + format_type, + "--device", + device, + "--workers", + str(workers), + "--min-epochs", + str(min_epochs), + "-N", + str(max_epochs), + "--lag", + str(early_stopping), + "-q", + "early", + "--schedule", + schedule, + "--output", + str(run_dir / "model"), + "-t", + str(train_manifest), + "-e", + str(val_manifest), + ] + + if use_augment: + cmd.append("--augment") + + if base_model is not None: + base_model_path = Path(base_model).resolve() + if not base_model_path.exists(): + raise FileNotFoundError(f"Base segmentation model not found: {base_model_path}") + cmd += ["-i", str(base_model_path), "--resize", "both"] + + print("Running ketos segtrain:") + print(" " + " ".join(cmd)) - trainer.fit(model) + try: + subprocess.run(cmd, check=True) + except KeyboardInterrupt: + print("Segmentation training interrupted by user") + return run_dir - print("Finished training, check:", run_dir) + print(f"Finished ketos segtrain, run dir: {run_dir}") return run_dir -def promote(run_dir, dest_path, prefer_metric=True): - run_dir = Path(run_dir).resolve() - dest_path = Path(dest_path).resolve() - dest_path.parent.mkdir(parents=True, exist_ok=True) - - candidates = sorted(run_dir.glob("model_*.mlmodel"), key=natural_epoch_sort_key) - if not candidates: - raise FileNotFoundError(f"No model_*.mlmodel files found in {run_dir}") - - selected = None - best_metric = None - - if prefer_metric: - scored = [] - for m in candidates: - sc = load_metric_sibling(m) - if sc is not None: - scored.append((m, sc)) - if scored: - scored.sort(key=lambda x: x[1], reverse=True) - selected, best_metric = scored[0] - else: - selected = max(candidates, key=natural_epoch_sort_key) - else: - selected = max(candidates, key=natural_epoch_sort_key) - - shutil.copy2(selected, dest_path) - print(f"Promoted: {selected.name} → {dest_path} (metric={best_metric})") - return selected, best_metric - - -def main(): +def main() -> None: print_memory_info() - run_dir = train( + + run_dir = run_ketos_segtrain( data_dir=DATA_DIR, out_root=OUT_DIR, + format_type=FORMAT_TYPE, + base_model=BASE_MODEL, min_epochs=MIN_EPOCHS, max_epochs=MAX_EPOCHS, - overfit_one=False, + early_stopping=EARLY_STOPPING, + device=DEVICE, + val_ratio=VAL_RATIO, + schedule=SCHEDULE, + workers=WORKERS, + use_augment=USE_AUGMENT, ) - BEST_PATH = (SCRIPT_DIR / ".." / "models" / "seg_best.mlmodel").resolve() - promote(run_dir, BEST_PATH, prefer_metric=True) + + promote(run_dir, BEST_MODEL_PATH) if __name__ == "__main__": diff --git a/ocr/model_training/utils/memory_info.py b/ocr/model_training/utils/memory_info.py index 07cae05..313c323 100644 --- a/ocr/model_training/utils/memory_info.py +++ b/ocr/model_training/utils/memory_info.py @@ -6,7 +6,7 @@ raise SystemExit("Missing module pynvml. Install it with: pip install nvidia-ml-py3") from exc -def human_readable_bytes(num_bytes): +def human_readable_bytes(num_bytes: float) -> str: for unit in ("B", "KB", "MB", "GB", "TB"): if num_bytes < 1024: return f"{num_bytes:.2f} {unit}" @@ -14,7 +14,7 @@ def human_readable_bytes(num_bytes): return f"{num_bytes:.2f} PB" -def main(): +def main() -> None: ram = psutil.virtual_memory() print(f"Free RAM: {human_readable_bytes(ram.available)} out of {human_readable_bytes(ram.total)}") diff --git a/ocr/models/blla_submitted.mlmodel b/ocr/models/blla_submitted.mlmodel new file mode 100644 index 0000000..f64b630 Binary files /dev/null and b/ocr/models/blla_submitted.mlmodel differ diff --git a/ocr/models/ocr_models/en_best.mlmodel b/ocr/models/ocr_models/en_best.mlmodel new file mode 100644 index 0000000..7ee7c52 Binary files /dev/null and b/ocr/models/ocr_models/en_best.mlmodel differ diff --git a/ocr/models/ocr_models/handlers/0kraken.py b/ocr/models/ocr_models/handlers/0_kraken.py similarity index 90% rename from ocr/models/ocr_models/handlers/0kraken.py rename to ocr/models/ocr_models/handlers/0_kraken.py index e52c0ae..d728ded 100644 --- a/ocr/models/ocr_models/handlers/0kraken.py +++ b/ocr/models/ocr_models/handlers/0_kraken.py @@ -1,12 +1,13 @@ import logging from pathlib import Path +from typing import Any from kraken import binarization, containers, pageseg, rpred from kraken.lib.models import load_any from PIL import Image NAME = "Kraken OCR Model" -DESCRIPTION = """Kraken model""" +DESCRIPTION = """Kraken model - Accuracy ~ 85%""" MODEL = None @@ -14,13 +15,19 @@ TEXT_DIRECTION = "horizontal-lr" -def load(model_path=MODEL_PATH): +def load(model_path: Path = MODEL_PATH) -> Any: global MODEL MODEL = load_any(str(model_path), device="cpu") return MODEL -def handle(image, seg_info=None, debug=False, frontline="", filter_warnings=False): +def handle( + image: Image.Image, + seg_info: dict[str, Any] | None = None, + debug: bool = False, + frontline: str = "", + filter_warnings: bool = False, +) -> str: global MODEL if MODEL is None: load() diff --git a/ocr/models/ocr_models/handlers/1_kraken.py b/ocr/models/ocr_models/handlers/1_kraken.py new file mode 100644 index 0000000..187ded0 --- /dev/null +++ b/ocr/models/ocr_models/handlers/1_kraken.py @@ -0,0 +1,97 @@ +import logging +from pathlib import Path +from typing import Any + +from kraken import binarization, containers, pageseg, rpred +from kraken.lib.models import load_any +from PIL import Image + +NAME = "Kraken OCR Model" +DESCRIPTION = """Kraken model - Accuracy ~ 83%""" + +MODEL = None + +MODEL_PATH = Path(__file__).resolve().parent / ".." / "ocr_best.mlmodel" +TEXT_DIRECTION = "horizontal-lr" + + +def load(model_path: Path = MODEL_PATH) -> Any: + global MODEL + MODEL = load_any(str(model_path), device="cpu") + return MODEL + + +def handle( + image: Image.Image, + seg_info: dict[str, Any] | None = None, + debug: bool = False, + frontline: str = "", + filter_warnings: bool = False, +) -> str: + global MODEL + if MODEL is None: + load() + + if debug: + logging.debug(frontline + f"Using model: {MODEL_PATH.name}") + if filter_warnings: + import warnings + + warnings.filterwarnings( + "ignore", message="Using legacy polygon extractor, as the model was not trained with the new method." + ) + logging.getLogger("kraken").setLevel(logging.ERROR) + + margin_percentage = 0 + margin = min(image.width, image.height) * margin_percentage + new_width = int(image.width + (2 * margin)) + new_height = int(image.height + (2 * margin)) + new_image = image.resize((new_width, new_height), Image.LANCZOS) + bin_im = binarization.nlbin(new_image) + output = "" + try: + output = "".join(rec.prediction for rec in list(rpred.rpred(MODEL, bin_im, pageseg.segment(bin_im)))) + if len(output) == 0: + raise Exception("Empty output from automatic segmentation") + else: + if debug: + logging.debug(frontline + "Used automatic segmentation") + except Exception: + if seg_info is None: + seg_info = {} + output = "".join( + rec.prediction + for rec in list( + rpred.rpred( + MODEL, + bin_im, + containers.Segmentation( + type="baselines", + imagename="", + text_direction=TEXT_DIRECTION, + script_detection=False, + lines=[ + containers.BaselineLine( + id=str(seg_info.get("index", 0)), + baseline=[ + (margin, new_image.height - margin - 1), + (new_image.width - margin - 1, new_image.height - margin - 1), + ], + boundary=[ + (0, 0), + (new_image.width - 1, 0), + (new_image.width - 1, new_image.height - 1), + (0, new_image.height - 1), + ], + ) + ], + regions={}, + line_orders=[[0]], + ), + ) + ) + ) + if debug: + logging.debug(frontline + "Used provided segmentation") + + return output diff --git a/ocr/models/ocr_models/handlers/2_kraken.py b/ocr/models/ocr_models/handlers/2_kraken.py new file mode 100644 index 0000000..b17bd0f --- /dev/null +++ b/ocr/models/ocr_models/handlers/2_kraken.py @@ -0,0 +1,97 @@ +import logging +from pathlib import Path +from typing import Any + +from kraken import binarization, containers, pageseg, rpred +from kraken.lib.models import load_any +from PIL import Image + +NAME = "Kraken OCR Model" +DESCRIPTION = """Kraken model - Accuracy ~ 60%""" + +MODEL = None + +MODEL_PATH = Path(__file__).resolve().parent / ".." / "ocr_first.mlmodel" +TEXT_DIRECTION = "horizontal-lr" + + +def load(model_path: Path = MODEL_PATH) -> Any: + global MODEL + MODEL = load_any(str(model_path), device="cpu") + return MODEL + + +def handle( + image: Image.Image, + seg_info: dict[str, Any] | None = None, + debug: bool = False, + frontline: str = "", + filter_warnings: bool = False, +) -> str: + global MODEL + if MODEL is None: + load() + + if debug: + logging.debug(frontline + f"Using model: {MODEL_PATH.name}") + if filter_warnings: + import warnings + + warnings.filterwarnings( + "ignore", message="Using legacy polygon extractor, as the model was not trained with the new method." + ) + logging.getLogger("kraken").setLevel(logging.ERROR) + + margin_percentage = 0 + margin = min(image.width, image.height) * margin_percentage + new_width = int(image.width + (2 * margin)) + new_height = int(image.height + (2 * margin)) + new_image = image.resize((new_width, new_height), Image.LANCZOS) + bin_im = binarization.nlbin(new_image) + output = "" + try: + output = "".join(rec.prediction for rec in list(rpred.rpred(MODEL, bin_im, pageseg.segment(bin_im)))) + if len(output) == 0: + raise Exception("Empty output from automatic segmentation") + else: + if debug: + logging.debug(frontline + "Used automatic segmentation") + except Exception: + if seg_info is None: + seg_info = {} + output = "".join( + rec.prediction + for rec in list( + rpred.rpred( + MODEL, + bin_im, + containers.Segmentation( + type="baselines", + imagename="", + text_direction=TEXT_DIRECTION, + script_detection=False, + lines=[ + containers.BaselineLine( + id=str(seg_info.get("index", 0)), + baseline=[ + (margin, new_image.height - margin - 1), + (new_image.width - margin - 1, new_image.height - margin - 1), + ], + boundary=[ + (0, 0), + (new_image.width - 1, 0), + (new_image.width - 1, new_image.height - 1), + (0, new_image.height - 1), + ], + ) + ], + regions={}, + line_orders=[[0]], + ), + ) + ) + ) + if debug: + logging.debug(frontline + "Used provided segmentation") + + return output diff --git a/ocr/models/ocr_models/handlers/1kraken.py b/ocr/models/ocr_models/handlers/3_kraken.py similarity index 90% rename from ocr/models/ocr_models/handlers/1kraken.py rename to ocr/models/ocr_models/handlers/3_kraken.py index 20aaa73..1886f90 100644 --- a/ocr/models/ocr_models/handlers/1kraken.py +++ b/ocr/models/ocr_models/handlers/3_kraken.py @@ -1,26 +1,33 @@ import logging from pathlib import Path +from typing import Any from kraken import binarization, containers, pageseg, rpred from kraken.lib.models import load_any from PIL import Image NAME = "Kraken OCR Model" -DESCRIPTION = """Totally different Kraken model""" +DESCRIPTION = """Kraken model - Accuracy ~ 35%""" MODEL = None -MODEL_PATH = Path(__file__).resolve().parent / ".." / "ocr_best_submitted.mlmodel" +MODEL_PATH = Path(__file__).resolve().parent / ".." / "ocr_best_ketos.mlmodel" TEXT_DIRECTION = "horizontal-lr" -def load(model_path=MODEL_PATH): +def load(model_path: Path = MODEL_PATH) -> Any: global MODEL MODEL = load_any(str(model_path), device="cpu") return MODEL -def handle(image, seg_info=None, debug=False, frontline="", filter_warnings=False): +def handle( + image: Image.Image, + seg_info: dict[str, Any] | None = None, + debug: bool = False, + frontline: str = "", + filter_warnings: bool = False, +) -> str: global MODEL if MODEL is None: load() diff --git a/ocr/models/ocr_models/handlers/4_kraken.py b/ocr/models/ocr_models/handlers/4_kraken.py new file mode 100644 index 0000000..f90778f --- /dev/null +++ b/ocr/models/ocr_models/handlers/4_kraken.py @@ -0,0 +1,97 @@ +import logging +from pathlib import Path +from typing import Any + +from kraken import binarization, containers, pageseg, rpred +from kraken.lib.models import load_any +from PIL import Image + +NAME = "Kraken OCR Model" +DESCRIPTION = """Kraken model - Accuracy ~ 27%""" + +MODEL = None + +MODEL_PATH = Path(__file__).resolve().parent / ".." / "kraken_basic.mlmodel" +TEXT_DIRECTION = "horizontal-lr" + + +def load(model_path: Path = MODEL_PATH) -> Any: + global MODEL + MODEL = load_any(str(model_path), device="cpu") + return MODEL + + +def handle( + image: Image.Image, + seg_info: dict[str, Any] | None = None, + debug: bool = False, + frontline: str = "", + filter_warnings: bool = False, +) -> str: + global MODEL + if MODEL is None: + load() + + if debug: + logging.debug(frontline + f"Using model: {MODEL_PATH.name}") + if filter_warnings: + import warnings + + warnings.filterwarnings( + "ignore", message="Using legacy polygon extractor, as the model was not trained with the new method." + ) + logging.getLogger("kraken").setLevel(logging.ERROR) + + margin_percentage = 0 + margin = min(image.width, image.height) * margin_percentage + new_width = int(image.width + (2 * margin)) + new_height = int(image.height + (2 * margin)) + new_image = image.resize((new_width, new_height), Image.LANCZOS) + bin_im = binarization.nlbin(new_image) + output = "" + try: + output = "".join(rec.prediction for rec in list(rpred.rpred(MODEL, bin_im, pageseg.segment(bin_im)))) + if len(output) == 0: + raise Exception("Empty output from automatic segmentation") + else: + if debug: + logging.debug(frontline + "Used automatic segmentation") + except Exception: + if seg_info is None: + seg_info = {} + output = "".join( + rec.prediction + for rec in list( + rpred.rpred( + MODEL, + bin_im, + containers.Segmentation( + type="baselines", + imagename="", + text_direction=TEXT_DIRECTION, + script_detection=False, + lines=[ + containers.BaselineLine( + id=str(seg_info.get("index", 0)), + baseline=[ + (margin, new_image.height - margin - 1), + (new_image.width - margin - 1, new_image.height - margin - 1), + ], + boundary=[ + (0, 0), + (new_image.width - 1, 0), + (new_image.width - 1, new_image.height - 1), + (0, new_image.height - 1), + ], + ) + ], + regions={}, + line_orders=[[0]], + ), + ) + ) + ) + if debug: + logging.debug(frontline + "Used provided segmentation") + + return output diff --git a/ocr/models/ocr_models/ocr_best.mlmodel b/ocr/models/ocr_models/ocr_best.mlmodel new file mode 100644 index 0000000..7a609a2 Binary files /dev/null and b/ocr/models/ocr_models/ocr_best.mlmodel differ diff --git a/ocr/models/ocr_models/ocr_best_ketos.mlmodel b/ocr/models/ocr_models/ocr_best_ketos.mlmodel new file mode 100644 index 0000000..f94665e Binary files /dev/null and b/ocr/models/ocr_models/ocr_best_ketos.mlmodel differ diff --git a/ocr/models/ocr_models/ocr_best_submitted.mlmodel b/ocr/models/ocr_models/ocr_best_submitted.mlmodel index 7a609a2..5e72a1f 100644 Binary files a/ocr/models/ocr_models/ocr_best_submitted.mlmodel and b/ocr/models/ocr_models/ocr_best_submitted.mlmodel differ diff --git a/ocr/models/seg_best.mlmodel b/ocr/models/seg_best.mlmodel deleted file mode 100644 index 65e9af9..0000000 Binary files a/ocr/models/seg_best.mlmodel and /dev/null differ diff --git a/ocr/models/seg_best_old.mlmodel (2) b/ocr/models/seg_best_old.mlmodel (2) new file mode 100644 index 0000000..1a47b20 Binary files /dev/null and b/ocr/models/seg_best_old.mlmodel (2) differ diff --git a/ocr/models/seg_best_old.mlmodel (3) b/ocr/models/seg_best_old.mlmodel (3) new file mode 100644 index 0000000..456ce33 Binary files /dev/null and b/ocr/models/seg_best_old.mlmodel (3) differ diff --git a/ocr/models/seg_best_submitted.mlmodel b/ocr/models/seg_best_submitted.mlmodel deleted file mode 100644 index a1bb2f4..0000000 Binary files a/ocr/models/seg_best_submitted.mlmodel and /dev/null differ diff --git a/ocr/tests/files/0.png b/ocr/tests/files/0.png deleted file mode 100644 index b53d27d..0000000 Binary files a/ocr/tests/files/0.png and /dev/null differ diff --git a/ocr/tests/files/0.txt b/ocr/tests/files/0.txt deleted file mode 100644 index 1e9f139..0000000 --- a/ocr/tests/files/0.txt +++ /dev/null @@ -1 +0,0 @@ -military aides to come up with the answer \ No newline at end of file diff --git a/ocr/tests/files/1.png b/ocr/tests/files/1.png deleted file mode 100644 index bb53cdb..0000000 Binary files a/ocr/tests/files/1.png and /dev/null differ diff --git a/ocr/tests/files/1.txt b/ocr/tests/files/1.txt deleted file mode 100644 index abf8f5b..0000000 --- a/ocr/tests/files/1.txt +++ /dev/null @@ -1 +0,0 @@ -on February 20. \ No newline at end of file diff --git a/ocr/tests/test_backend_client.py b/ocr/tests/test_backend_client.py new file mode 100644 index 0000000..f58f516 --- /dev/null +++ b/ocr/tests/test_backend_client.py @@ -0,0 +1,106 @@ +import base64 +from typing import Any + +import pytest + +from app import backend_client + + +class _DummyResponse: + def __init__(self, payload: Any, status_code: int = 200) -> None: + self._payload = payload + self.status_code = status_code + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> Any: + return self._payload + + +def test_get_format_by_name(monkeypatch: pytest.MonkeyPatch) -> None: + expected = [ + {"format": "PDF", "id": 2}, + {"format": "png", "id": 1}, + ] + + def fake_get(url: str, headers: dict[str, str], timeout: int) -> _DummyResponse: + assert url.endswith("/backend/api/v1/formats") + assert headers["Authorization"] == "token" + assert timeout == 10 + return _DummyResponse(expected) + + monkeypatch.setattr(backend_client.requests, "get", fake_get) + + result = backend_client.get_format( + backend_url="http://example.com", + auth_token="token", + format_name="pdf", + ) + assert result == expected[0] + + +def test_get_format_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + expected = [ + {"format": "PDF", "id": 2}, + {"format": "png", "id": 1}, + ] + + def fake_get(url: str, headers: dict[str, str], timeout: int) -> _DummyResponse: + return _DummyResponse(expected) + + monkeypatch.setattr(backend_client.requests, "get", fake_get) + + result = backend_client.get_format( + backend_url="http://example.com", + auth_token=None, + format_id=1, + ) + assert result == expected[1] + + +def test_get_format_invalid_payload(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_get(url: str, headers: dict[str, str], timeout: int) -> _DummyResponse: + return _DummyResponse({"unexpected": "payload"}) + + monkeypatch.setattr(backend_client.requests, "get", fake_get) + + with pytest.raises(ValueError, match="Unexpected formats response payload"): + backend_client.get_format("http://example.com", None) + + +def test_send_file_posts_expected_payload(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + def fake_post(url: str, headers: dict[str, str], json: dict[str, Any], timeout: int) -> _DummyResponse: + captured["url"] = url + captured["headers"] = headers + captured["json"] = json + captured["timeout"] = timeout + return _DummyResponse({"ok": True}) + + monkeypatch.setattr(backend_client.requests, "post", fake_post) + + content = b"hello world" + result = backend_client.send_file( + backend_url="http://example.com", + auth_token="token", + owner_id=7, + format_id=9, + generation=3, + content_bytes=content, + primary_file_id=5, + timeout=20, + ) + + assert result == {"ok": True} + assert captured["url"].endswith("/backend/api/v1/stored_files") + assert captured["headers"]["Authorization"] == "token" + assert captured["headers"]["Content-Type"] == "application/json" + assert captured["timeout"] == 20 + assert captured["json"]["ownerId"] == 7 + assert captured["json"]["formatId"] == 9 + assert captured["json"]["generation"] == 4 + assert captured["json"]["primaryFileId"] == 5 + assert base64.b64decode(captured["json"]["content"]) == content diff --git a/ocr/tests/test_file_converter.py b/ocr/tests/test_file_converter.py new file mode 100644 index 0000000..b601bf8 --- /dev/null +++ b/ocr/tests/test_file_converter.py @@ -0,0 +1,75 @@ +import io +from pathlib import Path + +import pytest +from PIL import Image + +from app import file_converter + + +def test_convert_to_png_bytes_passthrough() -> None: + payload = b"raw_png_bytes" + result = file_converter.convert_to_png_bytes(payload, {"format": "png"}) + assert result == payload + + +def test_convert_to_png_bytes_image() -> None: + image = Image.new("RGB", (4, 4), (120, 5, 200)) + buf = io.BytesIO() + image.save(buf, format="JPEG") + jpeg_bytes = buf.getvalue() + + result = file_converter.convert_to_png_bytes(jpeg_bytes, {"format": "jpeg"}) + assert result.startswith(b"\x89PNG\r\n\x1a\n") + + +def test_convert_to_png_bytes_pdf(monkeypatch: pytest.MonkeyPatch) -> None: + class _DummyPixmap: + def tobytes(self, fmt: str) -> bytes: + assert fmt == "png" + return b"png-bytes" + + class _DummyPage: + def get_pixmap(self) -> _DummyPixmap: + return _DummyPixmap() + + class _DummyDoc: + def load_page(self, index: int) -> _DummyPage: + assert index == 0 + return _DummyPage() + + def fake_open(stream: bytes, filetype: str) -> _DummyDoc: + assert stream == b"%PDF" + assert filetype == "pdf" + return _DummyDoc() + + monkeypatch.setattr(file_converter.fitz, "open", fake_open) + + result = file_converter.convert_to_png_bytes(b"%PDF", {"format": "pdf"}) + assert result == b"png-bytes" + + +def test_convert_to_png_bytes_unsupported() -> None: + with pytest.raises(ValueError, match="Unsupported input format"): + file_converter.convert_to_png_bytes(b"data", {"format": "zip"}) + + +def test_find_fontsize_uses_line_bounds(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[str, float, str]] = [] + + def fake_measure(text: str, fontsize: float, fontname: str = "helv") -> tuple[float, int]: + calls.append((text, fontsize, fontname)) + return fontsize * 2, int(fontsize) + + monkeypatch.setattr(file_converter, "measure_text_single_line", fake_measure) + + result = file_converter.find_fontsize(line_height=12, line_width=10, text="abc") + assert 0 <= result < 12 + assert calls + + +def test_save_docx_to_path(tmp_path: Path) -> None: + out_path = tmp_path / "out.docx" + payload = b"docx-data" + file_converter.save_docx_to_path(payload, out_path) + assert out_path.read_bytes() == payload diff --git a/ocr/tests/test_main.py b/ocr/tests/test_main.py new file mode 100644 index 0000000..d4c1e35 --- /dev/null +++ b/ocr/tests/test_main.py @@ -0,0 +1,37 @@ +from typing import Any + +import pytest + +from app import main as main_module + + +def test_strip_content_replaces_payload() -> None: + data = {"content": "secret", "other": 123} + result = main_module.strip_content(data) + assert result["content"] == "[SKIPPED]" + assert result["other"] == 123 + + +def test_strip_content_handles_non_dict() -> None: + result = main_module.strip_content("raw-body") + assert result == "raw-body" + + +def test_find_correct_backend_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[str | None, str | None, int]] = [] + + def fake_get_format(url: str | None, auth: str | None, format_id: int) -> dict[str, Any]: + calls.append((url, auth, format_id)) + if url == "http://docker": + raise RuntimeError("not reachable") + return {"id": format_id} + + monkeypatch.setattr(main_module, "get_format", fake_get_format) + monkeypatch.setenv("BACKEND_BASE_URL_DOCKER", "http://docker") + monkeypatch.setenv("BACKEND_BASE_URL", "http://backend") + monkeypatch.setattr(main_module, "BACKEND_URL", None) + + result = main_module.find_correct_backend_url(auth_header="token", format_id=5) + + assert result == "http://backend" + assert calls == [("http://docker", "token", 5), ("http://backend", "token", 5)] diff --git a/ocr/tests/test_module_loading.py b/ocr/tests/test_module_loading.py new file mode 100644 index 0000000..2d4b7e7 --- /dev/null +++ b/ocr/tests/test_module_loading.py @@ -0,0 +1,22 @@ +from pathlib import Path + +import pytest + +from app.module_loading import load_module_from_path + + +def test_load_module_from_path(tmp_path: Path) -> None: + module_path = tmp_path / "temp_module.py" + module_path.write_text("VALUE = 123\n", encoding="utf-8") + + module = load_module_from_path(module_path) + + assert module.VALUE == 123 + + +def test_load_module_from_path_invalid_module(tmp_path: Path) -> None: + module_path = tmp_path / "bad_module.py" + module_path.write_text("def broken(:\n", encoding="utf-8") + + with pytest.raises(ImportError): + load_module_from_path(module_path) diff --git a/ocr/tests/test_ocr.py b/ocr/tests/test_ocr.py deleted file mode 100644 index 52d63f7..0000000 --- a/ocr/tests/test_ocr.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -from pathlib import Path - -import pytest - -import app.ocr as ocr_module - -DIR_PATH = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "files"))) -LEV_PERCENTAGE = 0.35 - - -LEV_AVAILABLE = True -LEV_SOURCE = None -try: - from rapidfuzz.distance import Levenshtein as _Lev - - def lev_distance(a: str, b: str) -> int: - return _Lev.distance(a, b) - - LEV_SOURCE = "rapidfuzz.distance.Levenshtein" - pass -except Exception: - try: - from Levenshtein import distance as _lev_distance # type: ignore - - def lev_distance(a: str, b: str) -> int: - return _lev_distance(a, b) - - LEV_SOURCE = "python-Levenshtein" - except Exception: - LEV_AVAILABLE = False - LEV_SOURCE = "not available" - - -def _normalize_text(s: str) -> str: - return " ".join((s or "").split()).strip().lower() - - -def _collect_cases(data_dir: Path): - if not data_dir.exists(): - return [pytest.param(None, None, marks=pytest.mark.skip(reason="DIR_PATH does not exist."))] - - pngs = sorted(data_dir.glob("*.png")) - if not pngs: - return [pytest.param(None, None, marks=pytest.mark.skip(reason="No .png files found in DIR_PATH."))] - - params = [] - for png in pngs: - txt = png.with_suffix(".txt") - if not txt.exists(): - params.append( - pytest.param( - png, - None, - marks=pytest.mark.skip(reason=f"Missing ground-truth TXT for {png.name}. Expecting {txt.name}."), - ) - ) - else: - params.append(pytest.param(png, txt, id=png.name)) - return params - - -PARAMS = _collect_cases(DIR_PATH) - - -@pytest.fixture(autouse=True) -def reset_model_cache(): - if hasattr(ocr_module, "_MODEL"): - ocr_module._MODEL = None - yield - if hasattr(ocr_module, "_MODEL"): - ocr_module._MODEL = None - - -@pytest.mark.parametrize("png_path, gt_path", PARAMS) -def test_ocr_quality_dynamic_threshold(png_path: Path, gt_path: Path, capsys): - if png_path is None or gt_path is None: - pytest.skip("Invalid parameters for this test case.") - - if not LEV_AVAILABLE: - pytest.skip("Levenshtein implementation is not available (install rapidfuzz or python-Levenshtein).") - - img_bytes = png_path.read_bytes() - expected_text = gt_path.read_text(encoding="utf-8", errors="ignore") - - got_text = ocr_module.ocr_png_bytes(img_bytes) - - got_norm = _normalize_text(got_text) - exp_norm = _normalize_text(expected_text) - - dist = lev_distance(got_norm, exp_norm) - allowed = max(1, int(len(exp_norm) * LEV_PERCENTAGE)) - - capsys.readouterr() - - assert dist <= allowed, ( - f"Levenshtein distance ({dist}) exceeds allowed threshold ({allowed}) for file {png_path.name}" - ) diff --git a/ocr/tests/test_ocr_module.py b/ocr/tests/test_ocr_module.py new file mode 100644 index 0000000..565ddce --- /dev/null +++ b/ocr/tests/test_ocr_module.py @@ -0,0 +1,48 @@ +import io + +import pytest +from PIL import Image + +from app import ocr as ocr_module + + +def _png_bytes() -> bytes: + image = Image.new("RGB", (4, 4), (0, 0, 0)) + buf = io.BytesIO() + image.save(buf, format="PNG") + return buf.getvalue() + + +def test_get_model_handler_returns_specific(monkeypatch: pytest.MonkeyPatch) -> None: + def handler_1(*args, **kwargs) -> str: + return "h1" + def handler_2(*args, **kwargs) -> str: + return "h2" + monkeypatch.setattr( + ocr_module, + "MODEL_LIST", + [ + {"id": 1, "handle": handler_1}, + {"id": 2, "handle": handler_2}, + ], + ) + + assert ocr_module.get_model_handler(2) is handler_2 + + +def test_get_model_handler_fallback_to_id_1(monkeypatch: pytest.MonkeyPatch) -> None: + def handler_1(*args, **kwargs) -> str: + return "h1" + def handler_2(*args, **kwargs) -> str: + return "h2" + monkeypatch.setattr( + ocr_module, + "MODEL_LIST", + [ + {"id": 1, "handle": handler_1}, + {"id": 2, "handle": handler_2}, + ], + ) + + assert ocr_module.get_model_handler(999) is handler_1 + diff --git a/ocr/tests/test_segmentator.py b/ocr/tests/test_segmentator.py new file mode 100644 index 0000000..34f2a4e --- /dev/null +++ b/ocr/tests/test_segmentator.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import contextlib +from typing import Any + +import numpy as np +import pytest +from PIL import Image + +from app import segmentator as segmentator_module + + +def test_bbox_from_line_bbox() -> None: + class Line: + bbox = (1, 2, 3, 4) + + assert segmentator_module._bbox_from_line(Line(), 10, 10) == (1, 2, 3, 4) + + +def test_bbox_from_line_boundary() -> None: + class Line: + boundary = [(2, 3), (4, 6), (1, 5)] + + assert segmentator_module._bbox_from_line(Line(), 10, 10) == (1, 3, 4, 6) + + +def test_bbox_from_line_baseline() -> None: + class Line: + baseline = [(1, 2), (3, 4)] + + assert segmentator_module._bbox_from_line(Line(), 10, 100) == (1, 0, 3, 9) + + +def test_segment_lines_from_image_return_modes(monkeypatch: pytest.MonkeyPatch) -> None: + class Line: + bbox = (1, 1, 3, 3) + baseline = [(1, 2), (2, 2)] + boundary = [(1, 1), (3, 1), (3, 3), (1, 3)] + tags = ["tag"] + regions = ["region"] + type = "line" + + class DummyBounds: + lines = [Line()] + + def fake_load_seg_model(device: str | None, seg_model_path: Any) -> object: + return object() + + def fake_segment(im: Image.Image, model: object, device: str, text_direction: str) -> DummyBounds: + return DummyBounds() + + @contextlib.contextmanager + def fake_inference_mode(): + yield + + monkeypatch.setattr(segmentator_module, "_load_seg_model", fake_load_seg_model) + monkeypatch.setattr(segmentator_module.blla, "segment", fake_segment) + monkeypatch.setattr(segmentator_module.torch, "inference_mode", fake_inference_mode) + + im = Image.new("RGB", (10, 10), (0, 0, 0)) + + array_results = segmentator_module.segment_lines_from_image(im, return_mode="array") + assert array_results[0]["bbox"] == (1, 1, 3, 3) + assert isinstance(array_results[0]["array"], np.ndarray) + + pil_results = segmentator_module.segment_lines_from_image(im, return_mode="pil") + assert isinstance(pil_results[0]["pil_image"], Image.Image)