Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 63 additions & 121 deletions backend/app/rag/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,84 +18,65 @@

from app.config import get_settings

logger = logging.getLogger(__name__)
settings = get_settings()

# ── Optional OCR backend (PIL + pytesseract) ─────────────────────────────────
# Imported once at module load instead of inline on every _ocr_caption() call.
# ``HAS_OCR`` records availability so the hot path (large batch caption loops in
# generate_captions_for_chunks) can short-circuit with a cheap boolean check
# rather than re-running an import + try/except on each image.
try:
from PIL import Image
import pytesseract

HAS_OCR = True
except ImportError:
Image = None # type: ignore[assignment]
pytesseract = None # type: ignore[assignment]
HAS_OCR = False
logger.info(
"OCR backend unavailable (PIL/pytesseract not installed); "
"image captioning will fall back to placeholders."
)

# Minimum image area (pxΒ²) β€” smaller images are decorative and skipped.
_MIN_IMAGE_AREA = 1_000


# ── 1. Proximity-based caption extraction ────────────────────────────────────

def _find_caption_near_image(
page: fitz.Page,
img_bbox: fitz.Rect,
search_margin: float = 60.0,
) -> str:
"""Return the closest text block directly below (or above) an image rect."""
page_dict = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)
blocks = page_dict.get("blocks", [])

def _closest(region: fitz.Rect) -> str:
candidates = []
for block in blocks:
if block.get("type") != 0: # 0 == text block
continue
bx0, by0, bx1, by1 = block["bbox"]
if fitz.Rect(bx0, by0, bx1, by1).intersects(region):
text = " ".join(
span["text"]
for line in block.get("lines", [])
for span in line.get("spans", [])
).strip()
if text:
candidates.append((abs(by0 - img_bbox.y1), text))
if candidates:
return min(candidates, key=lambda t: t[0])[1]
return ""
from abc import ABC, abstractmethod

# Search below first, fall back to above
below = fitz.Rect(img_bbox.x0, img_bbox.y1, img_bbox.x1, img_bbox.y1 + search_margin)
caption = _closest(below)
if caption:
return caption
# --- VLM Strategy Pattern Core ---

above = fitz.Rect(img_bbox.x0, img_bbox.y0 - search_margin, img_bbox.x1, img_bbox.y0)
return _closest(above)
class BaseVisionProvider(ABC):
"""Abstract interface for all Vision-Language Model providers."""
@abstractmethod
def caption(self, image_bytes: bytes) -> str | None:
"""Takes image bytes and returns a descriptive caption string or None if it fails."""
pass


def extract_captions_from_pdf(filepath: str) -> List[Dict[str, Any]]:
"""Extract proximity-based image captions from a PDF.
class OpenAIVisionProvider(BaseVisionProvider):
"""Concrete Strategy implementing OpenAI's multimodal vision capabilities."""
def __init__(self, settings):
self.settings = settings

Returns a list of dicts ordered by (page, figure_index):
{
"page": int, # 1-based
"figure_index": int, # 0-based within the page
"caption": str, # may be empty string
"bbox": list[float], # [x0, y0, x1, y1] normalised to [0, 1]
}
"""
results: List[Dict[str, Any]] = []
doc = fitz.open(filepath)
def caption(self, image_bytes: bytes) -> str | None:
try:
import openai
import base64

api_key = getattr(self.settings, "OPENAI_API_KEY", None)
if not api_key:
return None

# Use modern client initialization or configure global API key based on project convention
openai.api_key = api_key

# Production-ready execution utilizing OpenAI's chat completions API with vision capability
base64_image = base64.b64encode(image_bytes).decode("utf-8")
model = getattr(self.settings, "LLM_MODEL", "gpt-4o")

response = openai.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image in one concise sentence."},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
],
}
],
max_tokens=100,
)
return response.choices[0].message.content.strip()
except Exception as e:
logger.debug(f"OpenAIVisionProvider execution failed: {e}")
return None


# Simply extend this dictionary registry to add future VLM engines (e.g., Gemini, Claude)
VISION_PROVIDER_REGISTRY = {
"openai": OpenAIVisionProvider,
}

try:
for page_num, page in enumerate(doc):
Expand Down Expand Up @@ -221,57 +202,18 @@ def caption_image(image_bytes: bytes | List[bytes], page: int | List[int] | None
- Fall back to local OCR (pytesseract) if available.
- Otherwise return a simple placeholder caption including the page number.
"""
if isinstance(image_bytes, list):
pages = page if isinstance(page, list) else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes))
return [caption_image(img, pg) for img, pg in zip(image_bytes, pages)]

# Placeholder for provider-based captioning (e.g., OpenAI / LLaVA hooks)
provider = getattr(settings, "VISION_PROVIDER", None)

if provider == "openai":
# Dynamically resolve and execute configured strategy from registry
provider_name = getattr(settings, "VISION_PROVIDER", None)
if provider_name and provider_name.lower() in VISION_PROVIDER_REGISTRY:
try:
import base64
from openai import OpenAI
provider_class = VISION_PROVIDER_REGISTRY[provider_name.lower()]
provider_instance = provider_class(settings)

api_key = getattr(settings, "OPENAI_API_KEY", None)
if api_key:
# Initialize modern client
client = OpenAI(api_key=api_key)

# Base64 encode the incoming image bytes
base64_image = base64.b64encode(image_bytes).decode('utf-8')

# Request a visual caption using Chat Completions payload structure
resp = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe this image in one concise sentence."
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
max_tokens=150
)

# Extract and return the caption immediately if successful
caption_text = resp.choices[0].message.content
if caption_text:
return caption_text.strip()

vlm_caption = provider_instance.caption(image_bytes)
if vlm_caption:
return vlm_caption
except Exception as e:
# Enhanced error logging to make debugging transparent
logger.warning(f"OpenAI vision provider failed: {e}, falling back to OCR")
logger.debug(f"Configured vision provider '{provider_name}' failed: {e}. Falling back to OCR.")

# Try OCR caption
ocr = _ocr_caption(image_bytes)
Expand Down
Loading