diff --git a/.env.example b/.env.example index 047ff4b1..4cdbb08a 100644 --- a/.env.example +++ b/.env.example @@ -208,6 +208,20 @@ HF_TOKEN=your_huggingface_token_here # Optional — defaults to 12 # GRAPH_MAX_RELATIONSHIPS=12 +# ── Vision / Image Captioning (VLM Providers) ────────────── +# Set VISION_PROVIDER to one of: openai | anthropic | gemini | ollama +# Leave unset to use OCR / placeholder only. +# VISION_PROVIDER=openai + +# VISION_MODEL=gpt-4o-mini # openai default +# VISION_MODEL=claude-3-haiku-20240307 # anthropic default +# VISION_MODEL=gemini-1.5-flash # gemini default +# VISION_MODEL=llava # ollama default + +# OPENAI_API_KEY=sk-... +# ANTHROPIC_API_KEY=sk-ant-... +# GOOGLE_API_KEY=AIza... +# OLLAMA_BASE_URL=http://localhost:11434 # ── ChromaDB (Vector Store) ───────────────────────────────── diff --git a/backend/app/config.py b/backend/app/config.py index c8763b39..b8da2e88 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -139,9 +139,15 @@ class Settings(BaseSettings): # ── Reranker ───────────────────────────────────────── RERANKER_MODEL: str = "BAAI/bge-reranker-v2-m3" # Lightweight 384-dim model fine-tuned for relevance ranking # ── Vision / Image captioning ───────────────────── - VISION_PROVIDER: str | None = None # e.g. 'openai' - VISION_MODEL: str | None = None + # Set to: openai | anthropic | gemini | ollama (or leave None) + VISION_PROVIDER: str | None = None + VISION_MODEL: str | None = None # overrides provider default model + + # Provider API keys — only the active provider's key is required OPENAI_API_KEY: str = "" + ANTHROPIC_API_KEY: str = "" + GOOGLE_API_KEY: str = "" + OLLAMA_BASE_URL: str = "http://localhost:11434" # ── Workspace Invitation ───────────────────────── APP_URL: str = "http://localhost:3000" diff --git a/backend/app/rag/vision.py b/backend/app/rag/vision.py index 8699e0eb..c207b6ee 100644 --- a/backend/app/rag/vision.py +++ b/backend/app/rag/vision.py @@ -11,6 +11,8 @@ """ import base64 import logging +import app.vision.providers # noqa: F401 — triggers self-registration +from app.vision.registry import get_vision_provider from io import BytesIO from typing import Any, Dict, List, Optional @@ -187,24 +189,45 @@ def _openai_caption(image_bytes: bytes) -> str: # ── Public API ─────────────────────────────────────────────────────────────── -def caption_image(image_bytes: bytes, page: Optional[int] = None) -> str: - """Generate a caption for a single image (bytes). - - Resolution order: OpenAI (if configured) → OCR → placeholder. - """ -def caption_image(image_bytes: bytes | List[bytes], page: int | List[int] | None = None) -> str | List[str]: +def caption_image( + image_bytes: "bytes | List[bytes]", + page: "int | List[int] | None" = None, +) -> "str | List[str]": """Generate a caption for a single image or a batch of images. - Order of operations: - - If a list of image bytes is passed, returns a list of captions. - - If an external VLM provider is configured, attempt to call it. - - Fall back to local OCR (pytesseract) if available. - - Otherwise return a simple placeholder caption including the page number. + Resolution order: + 1. Configured VLM provider (set VISION_PROVIDER in .env) + 2. Local OCR via pytesseract + 3. Placeholder string with page number and dimensions """ if isinstance(image_bytes, list): - pages = page if isinstance(page, list) else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes)) + pages = ( + page if isinstance(page, list) + else ([page] * len(image_bytes) if page is not None else [None] * len(image_bytes)) + ) return [caption_image(img, pg) for img, pg in zip(image_bytes, pages)] + # Strategy: try the configured VLM provider + provider = get_vision_provider(getattr(settings, "VISION_PROVIDER", None)) + if provider is not None: + result = provider.caption(image_bytes) + if result: + return result + + # Fallback 1: local OCR + ocr = _ocr_caption(image_bytes) + if ocr: + return ocr + + # Fallback 2: placeholder + try: + pix = fitz.Pixmap(image_bytes) + dims = f"{pix.width}x{pix.height} px" + except Exception: + dims = "unknown size" + + return f"Figure on page {page} ({dims})." if page else f"Figure ({dims})." + # Placeholder for provider-based captioning (e.g., OpenAI / LLaVA hooks) provider = getattr(settings, "VISION_PROVIDER", None) diff --git a/backend/tests/test_vision_providers.py b/backend/tests/test_vision_providers.py new file mode 100644 index 00000000..9ca33c6f --- /dev/null +++ b/backend/tests/test_vision_providers.py @@ -0,0 +1,93 @@ +"""Tests for the VLM provider Strategy Pattern (issue #592).""" +from unittest.mock import MagicMock, patch +import pytest + +from app.vision.base import BaseVisionProvider +from app.vision.registry import _REGISTRY, get_vision_provider, register_provider + + +class TestBaseVisionProvider: + def test_cannot_instantiate_abstract_class(self): + with pytest.raises(TypeError): + BaseVisionProvider() + + def test_concrete_subclass_works(self): + class Dummy(BaseVisionProvider): + def caption(self, image_bytes: bytes) -> str: + return "dummy" + assert Dummy().caption(b"x") == "dummy" + + +class TestRegistry: + def setup_method(self): + self._original = dict(_REGISTRY) + + def teardown_method(self): + _REGISTRY.clear() + _REGISTRY.update(self._original) + + def test_register_and_retrieve(self): + class FakeProvider(BaseVisionProvider): + def caption(self, image_bytes: bytes) -> str: + return "fake" + register_provider("fake", FakeProvider) + assert get_vision_provider("fake") is not None + + def test_case_insensitive(self): + class P(BaseVisionProvider): + def caption(self, image_bytes: bytes) -> str: + return "" + register_provider("UPPER", P) + assert get_vision_provider("upper") is not None + + def test_unknown_returns_none(self): + assert get_vision_provider("doesnotexist") is None + + def test_none_returns_none(self): + assert get_vision_provider(None) is None + + def test_broken_init_returns_none(self): + class Broken(BaseVisionProvider): + def __init__(self): raise RuntimeError("fail") + def caption(self, image_bytes: bytes) -> str: return "" + register_provider("broken", Broken) + assert get_vision_provider("broken") is None + + +class TestCaptionImage: + def test_uses_provider_when_configured(self): + from app.rag.vision import caption_image + + class StubProvider(BaseVisionProvider): + def caption(self, image_bytes: bytes) -> str: + return "stub caption" + + with patch("app.rag.vision.get_vision_provider", return_value=StubProvider()): + assert caption_image(b"img", page=1) == "stub caption" + + def test_falls_back_to_ocr(self): + from app.rag.vision import caption_image + + class EmptyProvider(BaseVisionProvider): + def caption(self, image_bytes: bytes) -> str: + return "" + + with patch("app.rag.vision.get_vision_provider", return_value=EmptyProvider()): + with patch("app.rag.vision._ocr_caption", return_value="ocr text"): + assert caption_image(b"img", page=1) == "ocr text" + + def test_falls_back_to_placeholder(self): + from app.rag.vision import caption_image + + with patch("app.rag.vision.get_vision_provider", return_value=None): + with patch("app.rag.vision._ocr_caption", return_value=""): + result = caption_image(b"img", page=3) + assert "page 3" in result + + def test_batch_mode(self): + from app.rag.vision import caption_image + + with patch("app.rag.vision.get_vision_provider", return_value=None): + with patch("app.rag.vision._ocr_caption", return_value=""): + results = caption_image([b"img1", b"img2"], page=[1, 2]) + assert isinstance(results, list) and len(results) == 2 \ No newline at end of file diff --git a/backend/vision/__init__.py b/backend/vision/__init__.py new file mode 100644 index 00000000..7a1eac67 --- /dev/null +++ b/backend/vision/__init__.py @@ -0,0 +1,5 @@ +"""Vision package: pluggable VLM provider strategy for image captioning.""" +from app.vision.registry import get_vision_provider, register_provider +from app.vision.base import BaseVisionProvider + +__all__ = ["BaseVisionProvider", "get_vision_provider", "register_provider"] \ No newline at end of file diff --git a/backend/vision/base.py b/backend/vision/base.py new file mode 100644 index 00000000..6e0462cf --- /dev/null +++ b/backend/vision/base.py @@ -0,0 +1,13 @@ +"""Abstract base class that every VLM provider must implement.""" +from abc import ABC, abstractmethod + + +class BaseVisionProvider(ABC): + """Strategy interface for Vision-Language Model providers.""" + + @abstractmethod + def caption(self, image_bytes: bytes) -> str: + """Generate a one-sentence caption for the given image. + + Returns a non-empty string, or empty string on failure (so caller can fall back). + """ \ No newline at end of file diff --git a/backend/vision/providers/__init__.py b/backend/vision/providers/__init__.py new file mode 100644 index 00000000..4cbbdabf --- /dev/null +++ b/backend/vision/providers/__init__.py @@ -0,0 +1,7 @@ +"""Auto-registers all built-in providers on import.""" +from app.vision.providers import ( # noqa: F401 + openai_provider, + anthropic_provider, + gemini_provider, + ollama_provider, +) \ No newline at end of file diff --git a/backend/vision/providers/anthropic_provider.py b/backend/vision/providers/anthropic_provider.py new file mode 100644 index 00000000..07e49dcb --- /dev/null +++ b/backend/vision/providers/anthropic_provider.py @@ -0,0 +1,61 @@ +"""Anthropic Claude Vision provider. +Activated when VISION_PROVIDER=anthropic and ANTHROPIC_API_KEY is set. +""" +import base64 +import logging + +from app.config import get_settings +from app.vision.base import BaseVisionProvider +from app.vision.registry import register_provider + +logger = logging.getLogger(__name__) +settings = get_settings() + +_CAPTION_PROMPT = ( + "Describe this figure or diagram in one concise sentence " + "suitable for use as a search index caption." +) + + +class AnthropicVisionProvider(BaseVisionProvider): + + def __init__(self) -> None: + self._api_key: str = getattr(settings, "ANTHROPIC_API_KEY", "") + self._model: str = getattr(settings, "VISION_MODEL", None) or "claude-3-haiku-20240307" + if not self._api_key: + raise ValueError( + "ANTHROPIC_API_KEY must be set when VISION_PROVIDER=anthropic." + ) + + def caption(self, image_bytes: bytes) -> str: + try: + import anthropic + except ImportError: + logger.error("Run: pip install anthropic") + return "" + + try: + client = anthropic.Anthropic(api_key=self._api_key) + b64 = base64.b64encode(image_bytes).decode("utf-8") + message = client.messages.create( + model=self._model, + max_tokens=120, + messages=[{ + "role": "user", + "content": [ + {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": b64}}, + {"type": "text", "text": _CAPTION_PROMPT}, + ], + }], + ) + content = message.content + if not content: + return "" + text_block = next((b for b in content if getattr(b, "type", None) == "text"), None) + return text_block.text.strip() if text_block else "" + except Exception as exc: + logger.debug("Anthropic vision caption failed: %s", exc) + return "" + + +register_provider("anthropic", AnthropicVisionProvider) \ No newline at end of file diff --git a/backend/vision/providers/gemini_provider.py b/backend/vision/providers/gemini_provider.py new file mode 100644 index 00000000..7d7a0805 --- /dev/null +++ b/backend/vision/providers/gemini_provider.py @@ -0,0 +1,50 @@ +"""Google Gemini Vision provider. +Activated when VISION_PROVIDER=gemini and GOOGLE_API_KEY is set. +""" +import logging + +from app.config import get_settings +from app.vision.base import BaseVisionProvider +from app.vision.registry import register_provider + +logger = logging.getLogger(__name__) +settings = get_settings() + +_CAPTION_PROMPT = ( + "Describe this figure or diagram in one concise sentence " + "suitable for use as a search index caption." +) + + +class GeminiVisionProvider(BaseVisionProvider): + + def __init__(self) -> None: + self._api_key: str = getattr(settings, "GOOGLE_API_KEY", "") + self._model: str = getattr(settings, "VISION_MODEL", None) or "gemini-1.5-flash" + if not self._api_key: + raise ValueError( + "GOOGLE_API_KEY must be set when VISION_PROVIDER=gemini." + ) + + def caption(self, image_bytes: bytes) -> str: + try: + import google.generativeai as genai + except ImportError: + logger.error("Run: pip install google-generativeai") + return "" + + try: + from io import BytesIO + import PIL.Image + genai.configure(api_key=self._api_key) + model = genai.GenerativeModel(self._model) + image = PIL.Image.open(BytesIO(image_bytes)) + response = model.generate_content([_CAPTION_PROMPT, image]) + text = getattr(response, "text", None) + return text.strip() if text else "" + except Exception as exc: + logger.debug("Gemini vision caption failed: %s", exc) + return "" + + +register_provider("gemini", GeminiVisionProvider) \ No newline at end of file diff --git a/backend/vision/providers/ollama_provider.py b/backend/vision/providers/ollama_provider.py new file mode 100644 index 00000000..e95a1f1d --- /dev/null +++ b/backend/vision/providers/ollama_provider.py @@ -0,0 +1,51 @@ +"""Ollama / LLaVA local Vision provider. +Activated when VISION_PROVIDER=ollama. No API key needed. +Make sure the model is pulled first: ollama pull llava +""" +import base64 +import logging + +from app.config import get_settings +from app.vision.base import BaseVisionProvider +from app.vision.registry import register_provider + +logger = logging.getLogger(__name__) +settings = get_settings() + +_CAPTION_PROMPT = ( + "Describe this figure or diagram in one concise sentence " + "suitable for use as a search index caption." +) + + +class OllamaVisionProvider(BaseVisionProvider): + + def __init__(self) -> None: + self._base_url: str = ( + getattr(settings, "OLLAMA_BASE_URL", None) or "http://localhost:11434" + ).rstrip("/") + self._model: str = getattr(settings, "VISION_MODEL", None) or "llava" + + def caption(self, image_bytes: bytes) -> str: + try: + import httpx + except ImportError: + logger.error("Run: pip install httpx") + return "" + + try: + b64 = base64.b64encode(image_bytes).decode("utf-8") + response = httpx.post( + f"{self._base_url}/api/generate", + json={"model": self._model, "prompt": _CAPTION_PROMPT, "images": [b64], "stream": False}, + timeout=60.0, + ) + response.raise_for_status() + text = response.json().get("response", "") + return text.strip() if text else "" + except Exception as exc: + logger.debug("Ollama vision caption failed: %s", exc) + return "" + + +register_provider("ollama", OllamaVisionProvider) \ No newline at end of file diff --git a/backend/vision/providers/openai_provider.py b/backend/vision/providers/openai_provider.py new file mode 100644 index 00000000..70a9dade --- /dev/null +++ b/backend/vision/providers/openai_provider.py @@ -0,0 +1,61 @@ +"""OpenAI GPT-4o-mini Vision provider. +Activated when VISION_PROVIDER=openai and OPENAI_API_KEY is set. +""" +import base64 +import logging + +from app.config import get_settings +from app.vision.base import BaseVisionProvider +from app.vision.registry import register_provider + +logger = logging.getLogger(__name__) +settings = get_settings() + +_CAPTION_PROMPT = ( + "Describe this figure or diagram in one concise sentence " + "suitable for use as a search index caption." +) + + +class OpenAIVisionProvider(BaseVisionProvider): + + def __init__(self) -> None: + self._api_key: str = getattr(settings, "OPENAI_API_KEY", "") + self._model: str = getattr(settings, "VISION_MODEL", None) or "gpt-4o-mini" + if not self._api_key: + raise ValueError( + "OPENAI_API_KEY must be set when VISION_PROVIDER=openai." + ) + + def caption(self, image_bytes: bytes) -> str: + try: + from openai import OpenAI + except ImportError: + logger.error("Run: pip install openai") + return "" + + try: + client = OpenAI(api_key=self._api_key) + b64 = base64.b64encode(image_bytes).decode("utf-8") + response = client.chat.completions.create( + model=self._model, + max_tokens=120, + messages=[{ + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}", "detail": "low"}}, + {"type": "text", "text": _CAPTION_PROMPT}, + ], + }], + ) + choices = response.choices + if not choices: + return "" + content = choices[0].message.content + return content.strip() if content else "" + except Exception as exc: + logger.debug("OpenAI vision caption failed: %s", exc) + return "" + + +register_provider("openai", OpenAIVisionProvider) \ No newline at end of file diff --git a/backend/vision/registry.py b/backend/vision/registry.py new file mode 100644 index 00000000..6be5a783 --- /dev/null +++ b/backend/vision/registry.py @@ -0,0 +1,34 @@ +"""Provider registry for VLM strategy lookup.""" +import logging +from typing import Dict, Optional, Type + +from app.vision.base import BaseVisionProvider + +logger = logging.getLogger(__name__) + +_REGISTRY: Dict[str, Type[BaseVisionProvider]] = {} + + +def register_provider(name: str, cls: Type[BaseVisionProvider]) -> None: + _REGISTRY[name.lower()] = cls + logger.debug("Registered VLM provider: %s → %s", name, cls.__name__) + + +def get_vision_provider(name: Optional[str]) -> Optional[BaseVisionProvider]: + if not name: + return None + + cls = _REGISTRY.get(name.lower()) + if cls is None: + logger.warning( + "Unknown VISION_PROVIDER=%r. Available: %s", + name, + list(_REGISTRY.keys()) or ["(none)"], + ) + return None + + try: + return cls() + except Exception as exc: + logger.error("Failed to instantiate VLM provider %r: %s", name, exc) + return None \ No newline at end of file