Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/build_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ jobs:
working-directory: ${{ inputs.directory }}
run: ruff check .

# - name: Run tests
# working-directory: ${{ inputs.directory }}
# run: pytest --maxfail=3 --disable-warnings --tb=short
- name: Run tests
working-directory: ${{ inputs.directory }}
run: pytest tests --maxfail=3 --disable-warnings --tb=short

- name: Save pip cache
if: ${{ !cancelled() }}
Expand Down
8 changes: 7 additions & 1 deletion ocr/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,10 @@ model_training/data_backup

model_training/input_backup

model_training/pipeline.sh
model_training/pipeline.sh

model_training/test_results

*.json

models/postprocessing_models
27 changes: 17 additions & 10 deletions ocr/app/backend_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import base64
from typing import Any

import requests

API_BASE = "/backend/api/v1"


def get_format(backend_url, auth_token, format_name=None, format_id=None, timeout=10):
def get_format(
backend_url: str,
auth_token: str | None,
format_name: str | None = None,
format_id: int | None = None,
timeout: int = 10,
) -> dict[str, Any] | None:
url = f"{backend_url.rstrip('/')}{API_BASE}/formats"
headers = {}
if auth_token:
Expand All @@ -29,15 +36,15 @@ def get_format(backend_url, auth_token, format_name=None, format_id=None, timeou


def send_file(
backend_url,
auth_token,
owner_id,
format_id,
generation,
content_bytes,
primary_file_id=None,
timeout=15,
):
backend_url: str,
auth_token: str,
owner_id: int,
format_id: int,
generation: int,
content_bytes: bytes,
primary_file_id: int | None = None,
timeout: int = 15,
) -> dict[str, Any]:
if not backend_url:
raise ValueError("backend_url is required")

Expand Down
45 changes: 31 additions & 14 deletions ocr/app/file_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import io
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import fitz
from docx import Document
Expand All @@ -20,7 +22,7 @@
OUT_DIR = SCRIPT_DIR / ".." / "temp" / "pdf_pages"


def pil_to_pixmap(image):
def pil_to_pixmap(image: Image.Image) -> fitz.Pixmap:
if image.mode == "1":
image = image.convert("L")

Expand All @@ -42,7 +44,7 @@ def pil_to_pixmap(image):
return fitz.Pixmap(colorspace, width, height, samples, alpha)


def initialize_pdf_with_image(image, visible_image=True):
def initialize_pdf_with_image(image: Image.Image, visible_image: bool = True) -> fitz.Document:
pdf_doc = fitz.open()
rect = fitz.Rect(0, 0, image.width, image.height)
page = pdf_doc.new_page(width=image.width, height=image.height)
Expand All @@ -52,14 +54,14 @@ def initialize_pdf_with_image(image, visible_image=True):
return pdf_doc


def measure_text_single_line(text, fontsize=11, fontname="helv"):
def measure_text_single_line(text: str, fontsize: int = 11, fontname: str = "helv") -> tuple[float, int]:
font = fitz.Font(fontname)
width = font.text_length(text, fontsize=fontsize)
height = fontsize
return width, height


def find_fontsize(line_height, line_width, text, fontname="helv"):
def find_fontsize(line_height: int, line_width: int, text: str, fontname: str = "helv") -> int:
min_fontsize = 1
max_fontsize = line_height

Expand All @@ -79,7 +81,13 @@ def find_fontsize(line_height, line_width, text, fontname="helv"):
return int(fontsize) - 1


def insert_text_at_bbox(pdf_doc, text, bbox, visible_image=True, draw_rect=False):
def insert_text_at_bbox(
pdf_doc: fitz.Document,
text: str,
bbox: Iterable[int],
visible_image: bool = True,
draw_rect: bool = False,
) -> None:
page = pdf_doc[0]
x0, y0, x1, y1 = bbox
rect = fitz.Rect(x0, y0, x1, y1)
Expand All @@ -98,30 +106,38 @@ def insert_text_at_bbox(pdf_doc, text, bbox, visible_image=True, draw_rect=False
page.insert_text(point, text, fontsize=fs, fontname="helv", color=0, fill_opacity=1, overlay=True)


def pdf_to_bytes(pdf_doc):
def pdf_to_bytes(pdf_doc: fitz.Document) -> bytes:
return pdf_doc.write()

def save_docx_to_path(docx_bytes, output_path) :
with open(output_path, "wb") as f :

def save_docx_to_path(docx_bytes: bytes, output_path: Path) -> None:
with open(output_path, "wb") as f:
f.write(docx_bytes)

def pdf_to_docx_bytes(pdf_doc) :

def pdf_to_docx_bytes(pdf_doc: fitz.Document) -> bytes:
pdf_bytes = pdf_to_bytes(pdf_doc)

laparams = LAParams()
text = extract_text(io.BytesIO(pdf_bytes), laparams=laparams)

doc = Document()
for line in text.splitlines() :
if line.strip() == "" :
for line in text.splitlines():
if line.strip() == "":
continue
doc.add_paragraph(line)

buf = io.BytesIO()
doc.save(buf)
return buf.getvalue()

def convert_to_png_bytes(input_bytes, input_format, debug=False, debug_indent=0):

def convert_to_png_bytes(
input_bytes: bytes,
input_format: dict[str, Any],
debug: bool = False,
debug_indent: int = 0,
) -> bytes:
if debug:
logging.debug(get_frontline(debug_indent) + f"Converting input format '{input_format}' to PNG bytes")

Expand All @@ -143,8 +159,9 @@ def convert_to_png_bytes(input_bytes, input_format, debug=False, debug_indent=0)

elif input_format["format"] in ["jpeg", "jpg", "tiff", "bmp", "gif"]:
if debug:
logging.debug(get_frontline(debug_indent) +
f"Converting image format '{input_format['format']}' to PNG using PIL")
logging.debug(
get_frontline(debug_indent) + f"Converting image format '{input_format['format']}' to PNG using PIL"
)
im = Image.open(io.BytesIO(input_bytes))
with io.BytesIO() as output:
im.save(output, format="PNG")
Expand Down
50 changes: 26 additions & 24 deletions ocr/app/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
import logging
import os
from typing import Any

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
Expand Down Expand Up @@ -38,15 +38,16 @@ class IncomingFile(BaseModel):

@field_validator("content")
@classmethod
def validate_base64(cls, v):
def validate_base64(cls, v: str) -> str:
try:
base64.b64decode(v, validate=True)
except Exception as e:
logging.error("Invalid base64 content: %s", e)
raise ValueError(f"Invalid base64 content: {e}") from e
return v

def strip_content(data):

def strip_content(data: Any) -> Any:
try:
if "content" in data:
data["content"] = "[SKIPPED]"
Expand All @@ -55,9 +56,10 @@ def strip_content(data):
data = "[NON-JSON BODY]"
return data

def find_correct_backend_url(auth_header, format_id):

def find_correct_backend_url(auth_header: str | None, format_id: int) -> str | None:
global BACKEND_URL
if BACKEND_URL is None :
if BACKEND_URL is None:
try:
backend_base_url = os.getenv("BACKEND_BASE_URL_DOCKER")
get_format(backend_base_url, auth_header, format_id=format_id)
Expand All @@ -71,31 +73,32 @@ def find_correct_backend_url(auth_header, format_id):
logging.critical("Failed to find backend URL: %s", e)
return BACKEND_URL


@app.get("/health")
def health():
def health() -> dict[str, str]:
try:
_ = get_model_list()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "detail": str(e)}


@app.post("/ocr/process")
async def handle_file(payload: IncomingFile, request: Request):

raw = await request.body()

logging.debug("Full request (content skipped): %s\n", strip_content(json.loads(raw)))

logging.debug(
"Received file: id=%s, ownerId=%s formatId=%s generation=%s primaryFileId=%s model_id=%s size_b64=%d",
payload.id,
payload.ownerId,
payload.formatId,
payload.generation,
payload.primaryFileId,
payload.processingModelId,
len(payload.content),
)
async def handle_file(payload: IncomingFile, request: Request) -> dict[str, Any]:
await request.body()

# logging.info("Full request (content skipped): %s\n", strip_content(json.loads(raw)))

# logging.info(
# "Received file: id=%s, ownerId=%s formatId=%s generation=%s primaryFileId=%s model_id=%s size_b64=%d",
# payload.id,
# payload.ownerId,
# payload.formatId,
# payload.generation,
# payload.primaryFileId,
# payload.processingModelId,
# len(payload.content),
# )

auth_header = request.headers.get("authorization")

Expand Down Expand Up @@ -134,7 +137,6 @@ async def handle_file(payload: IncomingFile, request: Request):

logging.info("Sent OCR result back to backend, got response: %s", strip_content(result))


out_docx_format = get_format(backend_base_url, auth_header, format_name="docx")
if not out_docx_format:
logging.critical("DOCX format not found in backend formats")
Expand All @@ -156,7 +158,7 @@ async def handle_file(payload: IncomingFile, request: Request):


@app.get("/ocr/available_models")
def available_models():
def available_models() -> list[dict[str, Any]]:
try:
return [{k: v for k, v in model.items() if k != "handle"} for model in get_model_list()]
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion ocr/app/module_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType


@dataclass(frozen=True)
Expand All @@ -14,7 +15,7 @@ class ModelSpec:
handle: Callable


def load_module_from_path(path):
def load_module_from_path(path: Path) -> ModuleType:
unique_name = f"_ocr_model_{path.stem}_{abs(hash(path.as_posix()))}"
spec = spec_from_file_location(unique_name, path)

Expand Down
Loading
Loading