diff --git a/backend/.env.example b/backend/.env.example index af1be8b..902f42a 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -1,32 +1,94 @@ +# ========================================================== +# Unified OpenAI-Compatible Provider Config (Primary) +# Use these 4 keys for OpenAI / Xeon / Ollama-compatible endpoints. +# API_ENDPOINT: +# - OpenAI cloud: leave empty +# - Xeon/Gateway: https://api.example.com (without /v1; auto-appended) +# - Ollama: http://localhost:11434 (without /v1; auto-appended) +# PROVIDER_NAME is optional and only used for logs/UI labeling. +# If omitted, provider name is inferred from API_ENDPOINT. +# ========================================================== +API_ENDPOINT= +API_TOKEN=your-api-token-here +MODEL_NAME=gpt-4o-mini +PROVIDER_NAME= -# OpenAI Configuration -OPENAI_API_KEY=your_openai_api_key_here -# Legacy -OPENAI_MODEL=gpt-4o-mini + +# Simple embedding config (OpenAI-compatible) +# If embeddings use the same endpoint/token as chat, leave EMBEDDING_ENDPOINT +# and EMBEDDING_API empty. +# Example: +# EMBEDDING_ENDPOINT=http://localhost:11434 +# EMBEDDING_API=ollama +# EMBEDDING_MODEL=nomic-embed-text +EMBEDDING_ENDPOINT= +EMBEDDING_API= EMBEDDING_MODEL=text-embedding-3-small +EMBEDDING_PROVIDER_NAME= +# Optional: fail fast for slow local embedding endpoints +EMBEDDING_TIMEOUT=20 +EMBEDDING_MAX_RETRIES=0 + -RAG_CHUNK_SIZE=800 -RAG_CHUNK_OVERLAP=150 -RAG_TOP_K=5 -RAG_MAX_DOCS=25 -RAG_TTL_SECONDS=3600 +# Optional per-model context window overrides +# MODEL_CONTEXT_TOKENS_MAP format: +# {"gpt-4o-mini":128000,"llama3.2:3b":32768} +MODEL_CONTEXT_TOKENS= +MODEL_CONTEXT_TOKENS_MAP= +MIN_MODEL_CONTEXT_TOKENS=4096 +CONTEXT_RETRY_SHRINK_RATIO=0.75 +CONTEXT_RETRY_MARGIN_TOKENS=1200 +# Speed knobs (recommended for local models) +# false = heuristic section chips only (faster, fewer LLM calls) +DYNAMIC_SECTIONS_USE_LLM=false +# false = skip second "retry summary" LLM call +SUMMARY_RETRY_ENABLE=false +# ========================================================== +# LLM Tuning +# ========================================================== LLM_TEMPERATURE=0.2 LLM_MAX_TOKENS=900 -CACHE_MAX_DOCS=25 -CACHE_TTL_SECONDS=3600 - +# ========================================================== +# RAG / Chunking +# ========================================================== +RAG_CHUNK_CHARS=1400 +RAG_CHUNK_OVERLAP_CHARS=220 +# For local embedding models, start with 1-4. +RAG_EMBED_BATCH_SIZE=64 +RAG_SUMMARY_TOP_K=8 +RAG_SECTION_TOP_K=10 +RAG_MIN_SCORE=0.15 +RAG_CONTEXT_MAX_CHARS=18000 +RETRIEVAL_FIRST_ENABLE=true +# If true, request path can synchronously build index when missing (may be slow). 
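+# Recommended: keep this false; an on-demand build blocks the request while the index is created.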
+RETRIEVAL_FORCE_INDEX_ON_DEMAND=false +# ========================================================== +# Map-Reduce Summarization for Large Docs +# ========================================================== +MAP_REDUCE_ENABLE=true +MAP_REDUCE_MIN_CHARS=18000 +MAP_REDUCE_MIN_CHUNKS=8 +MAP_REDUCE_MAX_CHUNKS=24 +MAP_REDUCE_CHUNK_CHARS=2800 +MAP_REDUCE_CHUNK_OVERLAP=250 +MAP_REDUCE_BATCH_SIZE=6 +# ========================================================== +# Cache +# ========================================================== +CACHE_MAX_DOCS=25 +CACHE_TTL_SECONDS=3600 -# Service Configuration +# ========================================================== +# Service +# ========================================================== SERVICE_PORT=8000 LOG_LEVEL=INFO - -# CORS Settings -CORS_ORIGINS=* \ No newline at end of file +CORS_ORIGINS=* diff --git a/backend/api/routes.py b/backend/api/routes.py index 69a9e62..b81178b 100644 --- a/backend/api/routes.py +++ b/backend/api/routes.py @@ -4,6 +4,7 @@ """ from fastapi import APIRouter, Form, File, UploadFile, HTTPException, BackgroundTasks +from fastapi.responses import PlainTextResponse from fastapi.responses import StreamingResponse from typing import Optional import os @@ -14,7 +15,8 @@ from models import HealthResponse from services import pdf_service, llm_service -from services.rag_index_service import rag_index_service # <-- ADDED +from services.rag.rag_index_service import rag_index_service +from services.observability_service import observability_service logger = logging.getLogger(__name__) @@ -23,18 +25,26 @@ @router.get("/health", response_model=HealthResponse) async def health_check(): - """Health check endpoint - OpenAI-only""" + """Health check endpoint""" llm_health = llm_service.health_check() response = { "status": "healthy" if llm_health.get("status") == "healthy" else "unhealthy", "service": config.APP_TITLE, "version": config.APP_VERSION, - "llm_provider": "OpenAI", + "llm_provider": llm_health.get("provider", llm_service.get_provider_name()), } return response + +@router.get("/v1/observability", response_class=PlainTextResponse) +async def observability(limit: int = 100): + """ + Plain text table with LLM token observability only. 
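+    Example: GET /v1/observability?limit=50 (default limit is 100).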
+ """ + return observability_service.render_table(limit=limit, llm_only=True) + @router.get("/v1/rag/status") async def rag_status(doc_id: str): """ @@ -46,11 +56,13 @@ async def rag_status(doc_id: str): return {"doc_id": doc_id.strip(), **status} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) + + @router.post("/v1/rag/chat") async def rag_chat( doc_id: str = Form(""), message: str = Form(""), - max_tokens: int = Form(500), + max_tokens: int = Form(220), temperature: float = Form(0.2), ): """ @@ -119,7 +131,7 @@ async def delete_vectors(doc_id: str): if not doc_id_clean: raise HTTPException(status_code=400, detail="doc_id is required") - from services.vector_store import vector_store + from services.rag.vector_store import vector_store vector_store.clear_doc(doc_id_clean) return {"doc_id": doc_id_clean, "status": "deleted", "message": "Vector data cleared"} diff --git a/backend/config.py b/backend/config.py index 72dfb77..78c2703 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,5 +1,5 @@ """ -Configuration settings for Doc-Sum Application +Configuration settings for FinSights Application """ import os @@ -10,8 +10,39 @@ # OpenAI Configuration (optional) OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +# Optional fallback model when MODEL_NAME is not set OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") +# Unified Provider Configuration (primary) +API_ENDPOINT = os.getenv("API_ENDPOINT", "") +API_TOKEN = os.getenv("API_TOKEN", "") +MODEL_NAME = os.getenv("MODEL_NAME", "") +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "") +PROVIDER_NAME = os.getenv("PROVIDER_NAME", "") +VERIFY_SSL = os.getenv("VERIFY_SSL", "true") +LOCAL_URL_ENDPOINT = os.getenv("LOCAL_URL_ENDPOINT", "not-needed") +EMBEDDING_PROVIDER = os.getenv("EMBEDDING_PROVIDER", "same") +EMBEDDING_ENDPOINT = os.getenv("EMBEDDING_ENDPOINT", "") +EMBEDDING_API = os.getenv("EMBEDDING_API", "") +EMBEDDING_API_ENDPOINT = os.getenv("EMBEDDING_API_ENDPOINT", "") +EMBEDDING_API_TOKEN = os.getenv("EMBEDDING_API_TOKEN", "") +EMBEDDING_PROVIDER_NAME = os.getenv("EMBEDDING_PROVIDER_NAME", "") + +# Optional embedding overrides (advanced) +OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small") +INFERENCE_EMBEDDING_MODEL_NAME = os.getenv("INFERENCE_EMBEDDING_MODEL_NAME", "") +OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL", "") + +# Legacy compatibility (optional) +LLM_PROVIDER = os.getenv("LLM_PROVIDER", "") +INFERENCE_API_ENDPOINT = os.getenv("INFERENCE_API_ENDPOINT", "") +INFERENCE_API_TOKEN = os.getenv("INFERENCE_API_TOKEN", "") +INFERENCE_MODEL_NAME = os.getenv("INFERENCE_MODEL_NAME", "") +OLLAMA_ENDPOINT = os.getenv("OLLAMA_ENDPOINT", "") +OLLAMA_TOKEN = os.getenv("OLLAMA_TOKEN", "") +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "") +EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL_NAME", "") + # LLM Configuration (tuned for section summaries) diff --git a/backend/server.py b/backend/server.py index fa6ca10..ba5b038 100644 --- a/backend/server.py +++ b/backend/server.py @@ -1,18 +1,21 @@ """ -FastAPI server for Doc-Sum Application (OpenAI-only) +FastAPI server for Doc-Sum Application """ import logging +import time from fastapi import FastAPI +from fastapi import Request from fastapi.middleware.cors import CORSMiddleware import uvicorn import config from models import HealthResponse from api.routes import router +from services.observability_service import observability_service -# IMPORTANT: llm_service is inside the "service" package -from services.llm_service 
import llm_service +# IMPORTANT: import the llm_service singleton object directly +from services.llm.llm_service import llm_service # Configure logging logging.basicConfig( @@ -40,6 +43,22 @@ # Include API routes app.include_router(router) +@app.middleware("http") +async def request_observability_middleware(request: Request, call_next): + started = time.perf_counter() + ctx_tokens = observability_service.set_request_context(request.url.path, request.method) + status_code = 500 + try: + response = await call_next(request) + status_code = response.status_code + return response + finally: + observability_service.record_request( + status_code=status_code, + duration_ms=(time.perf_counter() - started) * 1000.0, + ) + observability_service.reset_request_context(ctx_tokens) + @app.get("/") def root(): @@ -51,9 +70,9 @@ def root(): "docs": "/docs", "health": "/health", "config": { - "llm_provider": "OpenAI", - "llm_model": config.OPENAI_MODEL, - "openai_configured": bool(config.OPENAI_API_KEY), + "llm_provider": llm_service.get_provider_name(), + "llm_model": llm_service.model, + "api_token_configured": bool(config.API_TOKEN or config.OPENAI_API_KEY), }, } return response @@ -61,7 +80,7 @@ def root(): @app.get("/health", response_model=HealthResponse) def health_check(): - """Detailed health check - OpenAI only""" + """Detailed health check""" response_data = { "status": "healthy", "service": config.APP_TITLE, @@ -70,9 +89,7 @@ def health_check(): llm_health = llm_service.health_check() - # Only set fields that likely exist in your HealthResponse model - # Keep the original fields + add llm_provider (as your old code did) - response_data["llm_provider"] = "OpenAI" + response_data["llm_provider"] = llm_health.get("provider", llm_service.get_provider_name()) # If OpenAI isn't configured or health check fails, mark unhealthy if llm_health.get("status") in ("not_configured", "unhealthy"): @@ -87,9 +104,9 @@ async def startup_event(): logger.info("=" * 60) logger.info(f"Starting {config.APP_TITLE} v{config.APP_VERSION}") logger.info("=" * 60) - logger.info("LLM Provider: OpenAI") - logger.info(f"OpenAI Configured: {bool(config.OPENAI_API_KEY)}") - logger.info(f"Model: {config.OPENAI_MODEL}") + logger.info(f"LLM Provider: {llm_service.get_provider_name()}") + logger.info(f"API Token Configured: {bool(config.API_TOKEN or config.OPENAI_API_KEY)}") + logger.info(f"Model: {llm_service.model}") logger.info(f"Port: {config.SERVICE_PORT}") logger.info("=" * 60) diff --git a/backend/services/__init__.py b/backend/services/__init__.py index b790f73..18f6a05 100644 --- a/backend/services/__init__.py +++ b/backend/services/__init__.py @@ -1,6 +1,6 @@ """Services module - Business logic layer""" -from .pdf_service import pdf_service -from .llm_service import llm_service +from .pdf import pdf_service +from .llm.llm_service import llm_service __all__ = ["pdf_service", "llm_service"] diff --git a/backend/services/llm/__init__.py b/backend/services/llm/__init__.py new file mode 100644 index 0000000..4f37b36 --- /dev/null +++ b/backend/services/llm/__init__.py @@ -0,0 +1,3 @@ +from .llm_service import llm_service + +__all__ = ["llm_service"] diff --git a/backend/services/llm/llm_provider.py b/backend/services/llm/llm_provider.py new file mode 100644 index 0000000..b6c3370 --- /dev/null +++ b/backend/services/llm/llm_provider.py @@ -0,0 +1,625 @@ +from __future__ import annotations + +import os +import time +import json +from typing import Optional, List, Dict, Any +from urllib.parse import urlparse, urlunparse + +import 
httpx +import config +from openai import OpenAI +from services.observability_service import observability_service + + +def _as_bool(value: str, default: bool = True) -> bool: + if value is None: + return default + return str(value).strip().lower() in {"1", "true", "yes", "y", "on"} + + +def _normalize_base_url(endpoint: str) -> str: + ep = (endpoint or "").strip().rstrip("/") + if not ep: + return ep + if ep.endswith("/v1"): + return ep + return f"{ep}/v1" + + +def _rewrite_local_domain(endpoint: str, local_domain: str) -> str: + ep = (endpoint or "").strip() + d = (local_domain or "").strip() + if not ep or not d or d.lower() == "not-needed": + return ep + + parsed = urlparse(ep) + if parsed.hostname and parsed.hostname.lower() == d.lower(): + netloc = parsed.netloc.replace(parsed.hostname, "host.docker.internal") + return urlunparse((parsed.scheme, netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)) + return ep + + +def _infer_provider_name(endpoint: str) -> str: + ep = (endpoint or "").strip() + if not ep: + return "openai" + parsed = urlparse(ep if "://" in ep else f"https://{ep}") + host = (parsed.hostname or "").lower() + port = parsed.port + + if "openai.com" in host: + return "openai" + if "ollama" in host: + return "ollama" + if host in {"localhost", "127.0.0.1", "host.docker.internal"} and (port == 11434 or ":11434" in ep): + return "ollama" + return "inference_api" + + +class LLMProvider: + def __init__(self) -> None: + self.timeout = float(os.getenv("OPENAI_TIMEOUT", "60")) + self.max_retries = int(os.getenv("OPENAI_MAX_RETRIES", "2")) + self.embedding_timeout = float(os.getenv("EMBEDDING_TIMEOUT", str(self.timeout))) + self.embedding_max_retries = int(os.getenv("EMBEDDING_MAX_RETRIES", str(self.max_retries))) + self.verify_ssl = _as_bool(os.getenv("VERIFY_SSL", "true"), default=True) + self.local_url_endpoint = os.getenv("LOCAL_URL_ENDPOINT", "not-needed").strip() + + # Unified config (preferred) + self.api_endpoint = os.getenv("API_ENDPOINT", "").strip() + self.api_token = os.getenv("API_TOKEN", "").strip() + self.model_name = os.getenv("MODEL_NAME", "").strip() + self.provider_name_env = os.getenv("PROVIDER_NAME", "").strip().lower() + self.embedding_model = ( + os.getenv("EMBEDDING_MODEL", "").strip() + or os.getenv("EMBEDDING_MODEL_NAME", "").strip() + or os.getenv("OPENAI_EMBEDDING_MODEL", "").strip() + or "text-embedding-3-small" + ) + + # Legacy config (fallback) + self.legacy_provider = os.getenv("LLM_PROVIDER", "openai").strip().lower() + self.openai_api_key = os.getenv("OPENAI_API_KEY", "") or getattr(config, "OPENAI_API_KEY", "") + self.openai_model = os.getenv("OPENAI_MODEL", None) or getattr(config, "OPENAI_MODEL", None) or "gpt-4o-mini" + + self.inference_endpoint = os.getenv("INFERENCE_API_ENDPOINT", "").strip() + self.inference_token = os.getenv("INFERENCE_API_TOKEN", "").strip() + self.inference_model = os.getenv("INFERENCE_MODEL_NAME", "").strip() + + self.ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434").strip() + self.ollama_token = os.getenv("OLLAMA_TOKEN", "ollama").strip() + self.ollama_model = os.getenv("OLLAMA_MODEL", "llama3.2:3b").strip() + + # Embedding behavior + self.embedding_provider = os.getenv("EMBEDDING_PROVIDER", "same").strip().lower() # same | openai + self.openai_embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "").strip() or self.embedding_model + self.inference_embedding_model = os.getenv("INFERENCE_EMBEDDING_MODEL_NAME", "").strip() or self.embedding_model + self.ollama_embedding_model = 
os.getenv("OLLAMA_EMBEDDING_MODEL", "").strip() or self.embedding_model + # Simple dedicated embedding config (preferred) + self.embedding_api_endpoint = ( + os.getenv("EMBEDDING_ENDPOINT", "").strip() + or os.getenv("EMBEDDING_API_ENDPOINT", "").strip() # backward compatible + ) + self.embedding_api_token = ( + os.getenv("EMBEDDING_API", "").strip() + or os.getenv("EMBEDDING_API_TOKEN", "").strip() # backward compatible + ) + self.embedding_model_name = ( + os.getenv("EMBEDDING_MODEL", "").strip() + or os.getenv("EMBEDDING_MODEL_NAME", "").strip() # backward compatible + or self.embedding_model + ) + self.embedding_provider_name = os.getenv("EMBEDDING_PROVIDER_NAME", "").strip().lower() + + self.client: Optional[OpenAI] = None + self.embedding_client: Optional[OpenAI] = None + self.provider_embedding_client: Optional[OpenAI] = None + self.api_key: str = "" + self._context_tokens_cache: Optional[int] = None + settings = self._resolve_settings() + self.model = settings.get("model", "") + self.provider_name = settings.get("name", "unknown") + self.base_url = settings.get("base_url", "") + self.initialized = False + + def _using_unified_config(self) -> bool: + return bool(self.api_endpoint or self.api_token or self.model_name or self.provider_name_env) + + def _resolve_settings_legacy(self) -> Dict[str, str]: + p = self.legacy_provider + if p == "openai": + return {"name": "openai", "base_url": "", "api_key": self.openai_api_key, "model": self.openai_model} + if p in {"inference_api", "xeon"}: + ep = _normalize_base_url(_rewrite_local_domain(self.inference_endpoint, self.local_url_endpoint)) + return {"name": "inference_api", "base_url": ep, "api_key": self.inference_token, "model": self.inference_model} + if p == "ollama": + ep = _normalize_base_url(_rewrite_local_domain(self.ollama_endpoint, self.local_url_endpoint)) + return {"name": "ollama", "base_url": ep, "api_key": self.ollama_token or "ollama", "model": self.ollama_model} + raise ValueError("LLM_PROVIDER must be one of: openai, inference_api, xeon, ollama") + + def _resolve_settings(self) -> Dict[str, str]: + if not self._using_unified_config(): + return self._resolve_settings_legacy() + + ep = _rewrite_local_domain(self.api_endpoint, self.local_url_endpoint) + base = _normalize_base_url(ep) if ep else "" + name = self.provider_name_env or _infer_provider_name(ep) + token = self.api_token or self.openai_api_key + model = self.model_name or self.openai_model + + return { + "name": name, + "base_url": base, + "api_key": token, + "model": model, + } + + def ensure_initialized(self) -> None: + if self.initialized: + return + + settings = self._resolve_settings() + api_key = (settings.get("api_key") or "").strip() + model = (settings.get("model") or "").strip() + base_url = (settings.get("base_url") or "").strip() + + if not api_key: + raise ValueError("Missing API token/key. Set API_TOKEN (or OPENAI_API_KEY for OpenAI).") + if not model: + raise ValueError("Missing model name. 
Set MODEL_NAME (or OPENAI_MODEL).") + + http_client = httpx.Client(verify=self.verify_ssl, timeout=self.timeout) + kwargs: Dict[str, Any] = { + "api_key": api_key, + "timeout": self.timeout, + "max_retries": self.max_retries, + "http_client": http_client, + } + if base_url: + kwargs["base_url"] = base_url + + self.client = OpenAI(**kwargs) + self.api_key = api_key + self.model = model + self.base_url = base_url + self.provider_name = settings.get("name", "unknown") + self.initialized = True + + @staticmethod + def _extract_int_candidate(payload: Any) -> int: + """ + Best-effort extractor for context window fields from model metadata. + """ + keys = { + "context_length", + "max_context_length", + "context_window", + "max_input_tokens", + "max_prompt_tokens", + "input_token_limit", + "num_ctx", + "n_ctx", + } + + def walk(obj: Any) -> int: + if isinstance(obj, dict): + for k, v in obj.items(): + lk = str(k).strip().lower() + if lk in keys: + try: + iv = int(v) + if 512 <= iv <= 10_000_000: + return iv + except Exception: + pass + nested = walk(v) + if nested > 0: + return nested + elif isinstance(obj, list): + for item in obj: + nested = walk(item) + if nested > 0: + return nested + return 0 + + return walk(payload) + + def _context_override_from_env(self) -> int: + direct = (os.getenv("MODEL_CONTEXT_TOKENS", "") or "").strip() + if direct: + try: + v = int(direct) + if v > 0: + return v + except Exception: + pass + + raw_map = (os.getenv("MODEL_CONTEXT_TOKENS_MAP", "") or "").strip() + if not raw_map: + return 0 + try: + data = json.loads(raw_map) + if not isinstance(data, dict): + return 0 + model_key = (self.model or "").strip().lower() + for k, v in data.items(): + if str(k).strip().lower() == model_key: + iv = int(v) + if iv > 0: + return iv + except Exception: + return 0 + return 0 + + def _discover_context_tokens_from_models_endpoint(self) -> int: + self.ensure_initialized() + assert self.client is not None + target_model = (self.model or "").strip().lower() + + try: + listing = self.client.models.list() + items = getattr(listing, "data", []) or [] + fallback_item: Dict[str, Any] = {} + for it in items: + if hasattr(it, "model_dump"): + obj = it.model_dump() + elif isinstance(it, dict): + obj = it + else: + obj = {} + if not isinstance(obj, dict): + continue + item_id = str(obj.get("id", "")).strip().lower() + if item_id == target_model and obj: + found = self._extract_int_candidate(obj) + if found > 0: + return found + if not fallback_item and obj: + fallback_item = obj + + if fallback_item: + found = self._extract_int_candidate(fallback_item) + if found > 0: + return found + except Exception: + pass + return 0 + + def _discover_context_tokens_from_raw_http(self) -> int: + """ + For OpenAI-compatible gateways that expose extra model metadata on /v1/models. 
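+        Looks for fields such as context_length, max_input_tokens, or num_ctx on each model entry.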
+ """ + base = (self.base_url or "").strip().rstrip("/") + if not base or not self.api_key: + return 0 + + headers = {"Authorization": f"Bearer {self.api_key}"} + try: + with httpx.Client(verify=self.verify_ssl, timeout=self.timeout) as c: + resp = c.get(f"{base}/models", headers=headers) + resp.raise_for_status() + body = resp.json() + data = body.get("data", []) if isinstance(body, dict) else [] + if not isinstance(data, list): + return 0 + target_model = (self.model or "").strip().lower() + fallback: Dict[str, Any] = {} + for item in data: + if not isinstance(item, dict): + continue + mid = str(item.get("id", "")).strip().lower() + if mid == target_model: + found = self._extract_int_candidate(item) + if found > 0: + return found + if not fallback: + fallback = item + if fallback: + found = self._extract_int_candidate(fallback) + if found > 0: + return found + except Exception: + pass + return 0 + + def resolve_model_context_tokens(self, default_tokens: int) -> int: + """ + Resolve model context window dynamically: + 1) explicit env override + 2) cached discovered value + 3) provider metadata (/models) + 4) fallback default + """ + env_override = self._context_override_from_env() + if env_override > 0: + self._context_tokens_cache = env_override + return env_override + + if self._context_tokens_cache and self._context_tokens_cache > 0: + return int(self._context_tokens_cache) + + discovered = self._discover_context_tokens_from_models_endpoint() + if discovered <= 0: + discovered = self._discover_context_tokens_from_raw_http() + if discovered > 0: + self._context_tokens_cache = int(discovered) + return int(discovered) + return int(default_tokens) + + @staticmethod + def _extract_between(text: str, start_marker: str, end_marker: str = "") -> str: + t = text or "" + start_idx = t.find(start_marker) + if start_idx < 0: + return "" + start_idx += len(start_marker) + if not end_marker: + return t[start_idx:].strip() + end_idx = t.find(end_marker, start_idx) + if end_idx < 0: + return t[start_idx:].strip() + return t[start_idx:end_idx].strip() + + def _split_chat_prompt_parts(self, user_prompt: str) -> Dict[str, str]: + """ + Best-effort parsing of our prompt templates: + - uploaded_document: large document/context block + - user_input: direct user question/input + - user_prompt: instruction wrapper around the above + """ + up = user_prompt or "" + parts = { + "uploaded_document": "", + "user_input": "", + "user_prompt": "", + } + + # RAG chat template + if "CONTEXT:\n" in up and "\n\nQUESTION:\n" in up: + ctx = self._extract_between(up, "CONTEXT:\n", "\n\nQUESTION:\n") + question = self._extract_between(up, "\n\nQUESTION:\n", "\n\nAnswer using only the context.") + parts["uploaded_document"] = ctx + parts["user_input"] = question + wrapper = up.replace(ctx, "").replace(question, "") + parts["user_prompt"] = wrapper.strip() + return parts + + # Summarization templates + if "\n\nDocument:\n" in up: + doc = self._extract_between(up, "\n\nDocument:\n") + parts["uploaded_document"] = doc + wrapper = up.replace(doc, "") + parts["user_prompt"] = wrapper.strip() + return parts + + # Map phase template + if "Chunk text:\n" in up and "\n\nReturn concise bullet points from this chunk only." 
in up: + chunk = self._extract_between(up, "Chunk text:\n", "\n\nReturn concise bullet points from this chunk only.") + parts["uploaded_document"] = chunk + wrapper = up.replace(chunk, "") + parts["user_prompt"] = wrapper.strip() + return parts + + # Reduce phase template + if "Summaries to merge:\n" in up and "\n\nMerge these into a single de-duplicated brief." in up: + merge_input = self._extract_between(up, "Summaries to merge:\n", "\n\nMerge these into a single de-duplicated brief.") + parts["uploaded_document"] = merge_input + wrapper = up.replace(merge_input, "") + parts["user_prompt"] = wrapper.strip() + return parts + + # Final phase template + if "Consolidated notes:\n" in up and "\n\nWrite the final output now." in up: + notes = self._extract_between(up, "Consolidated notes:\n", "\n\nWrite the final output now.") + parts["uploaded_document"] = notes + wrapper = up.replace(notes, "") + parts["user_prompt"] = wrapper.strip() + return parts + + # Fallback: treat full user message as direct user input + parts["user_input"] = up.strip() + return parts + + def call_chat(self, system_prompt: str, user_prompt: str, max_tokens: int, temperature: float, stream: bool = False): + self.ensure_initialized() + assert self.client is not None + parts = self._split_chat_prompt_parts(user_prompt) + started = time.perf_counter() + try: + resp = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + max_tokens=max_tokens, + temperature=temperature, + stream=stream, + ) + observability_service.record_llm_call( + event="chat", + model=self.model, + provider=self.provider_name, + duration_ms=(time.perf_counter() - started) * 1000.0, + usage=getattr(resp, "usage", None), + system_prompt_chars=len(system_prompt or ""), + user_prompt_chars=len(parts.get("user_prompt", "")), + user_input_chars=len(parts.get("user_input", "")), + uploaded_document_chars=len(parts.get("uploaded_document", "")), + success=True, + ) + return resp + except Exception as e: + observability_service.record_llm_call( + event="chat", + model=self.model, + provider=self.provider_name, + duration_ms=(time.perf_counter() - started) * 1000.0, + system_prompt_chars=len(system_prompt or ""), + user_prompt_chars=len(parts.get("user_prompt", "")), + user_input_chars=len(parts.get("user_input", "")), + uploaded_document_chars=len(parts.get("uploaded_document", "")), + success=False, + error=str(e), + ) + raise + + def _openai_embed_client(self) -> OpenAI: + if not self.openai_api_key: + raise ValueError("OPENAI_API_KEY is required for EMBEDDING_PROVIDER=openai") + return OpenAI( + api_key=self.openai_api_key, + timeout=self.embedding_timeout, + max_retries=self.embedding_max_retries, + http_client=httpx.Client(verify=self.verify_ssl, timeout=self.embedding_timeout), + ) + + def _using_dedicated_embedding_config(self) -> bool: + return bool(self.embedding_model_name and (self.embedding_api_endpoint or self.embedding_api_token)) + + def _dedicated_embedding_client(self) -> OpenAI: + if self.embedding_client is not None: + return self.embedding_client + + endpoint_src = self.embedding_api_endpoint or self.api_endpoint + endpoint = _normalize_base_url(_rewrite_local_domain(endpoint_src, self.local_url_endpoint)) + if not endpoint: + raise ValueError("Embedding endpoint missing. 
Set EMBEDDING_ENDPOINT or API_ENDPOINT.") + + token = self.embedding_api_token or self.api_token or self.openai_api_key + if not token: + raise ValueError("EMBEDDING_API_TOKEN (or API_TOKEN/OPENAI_API_KEY fallback) is required for dedicated embeddings") + + self.embedding_client = OpenAI( + api_key=token, + base_url=endpoint, + timeout=self.embedding_timeout, + max_retries=self.embedding_max_retries, + http_client=httpx.Client(verify=self.verify_ssl, timeout=self.embedding_timeout), + ) + return self.embedding_client + + def _same_provider_embed_client(self) -> OpenAI: + if self.provider_embedding_client is not None: + return self.provider_embedding_client + + self.ensure_initialized() + kwargs: Dict[str, Any] = { + "api_key": self.api_key, + "timeout": self.embedding_timeout, + "max_retries": self.embedding_max_retries, + "http_client": httpx.Client(verify=self.verify_ssl, timeout=self.embedding_timeout), + } + if self.base_url: + kwargs["base_url"] = self.base_url + self.provider_embedding_client = OpenAI(**kwargs) + return self.provider_embedding_client + + def _embedding_model_for_active_provider(self) -> str: + if self.provider_name in {"inference_api", "xeon"}: + return self.inference_embedding_model + if self.provider_name == "ollama": + return self.ollama_embedding_model + return self.openai_embedding_model + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + self.ensure_initialized() + if not texts: + return [] + started = time.perf_counter() + input_chars = sum(len(t or "") for t in texts) + + if self.embedding_provider == "openai": + emb_client = self._openai_embed_client() + resp = emb_client.embeddings.create(model=self.openai_embedding_model, input=texts) + observability_service.record_llm_call( + event="embedding", + model=self.openai_embedding_model, + provider="openai", + duration_ms=(time.perf_counter() - started) * 1000.0, + usage=getattr(resp, "usage", None), + uploaded_document_chars=input_chars, + success=True, + ) + return [d.embedding for d in resp.data] + + if self._using_dedicated_embedding_config(): + emb_client = self._dedicated_embedding_client() + model_for_custom = self.embedding_model_name + provider_for_custom = self.embedding_provider_name or _infer_provider_name(self.embedding_api_endpoint or self.api_endpoint) + resp = emb_client.embeddings.create(model=model_for_custom, input=texts) + observability_service.record_llm_call( + event="embedding", + model=model_for_custom, + provider=provider_for_custom, + duration_ms=(time.perf_counter() - started) * 1000.0, + usage=getattr(resp, "usage", None), + uploaded_document_chars=input_chars, + success=True, + ) + return [d.embedding for d in resp.data] + + emb_client = self._same_provider_embed_client() + model_for_provider = self._embedding_model_for_active_provider() + try: + resp = emb_client.embeddings.create(model=model_for_provider, input=texts) + observability_service.record_llm_call( + event="embedding", + model=model_for_provider, + provider=self.provider_name, + duration_ms=(time.perf_counter() - started) * 1000.0, + usage=getattr(resp, "usage", None), + uploaded_document_chars=input_chars, + success=True, + ) + return [d.embedding for d in resp.data] + except Exception as e: + observability_service.record_llm_call( + event="embedding", + model=model_for_provider, + provider=self.provider_name, + duration_ms=(time.perf_counter() - started) * 1000.0, + uploaded_document_chars=input_chars, + success=False, + error=str(e), + ) + if self.openai_api_key: + emb_client = 
self._openai_embed_client() + started_fallback = time.perf_counter() + resp = emb_client.embeddings.create(model=self.openai_embedding_model, input=texts) + observability_service.record_llm_call( + event="embedding_fallback", + model=self.openai_embedding_model, + provider="openai", + duration_ms=(time.perf_counter() - started_fallback) * 1000.0, + usage=getattr(resp, "usage", None), + uploaded_document_chars=input_chars, + success=True, + ) + return [d.embedding for d in resp.data] + raise + + def embed_query(self, query: str) -> List[float]: + q = (query or "").strip() + if not q: + return [] + out = self.embed_texts([q]) + return out[0] if out else [] + + def health_check(self) -> Dict[str, Any]: + try: + self.ensure_initialized() + assert self.client is not None + self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Say OK"}], + max_tokens=10, + temperature=0, + ) + return {"status": "healthy", "provider": self.provider_name, "model": self.model} + except Exception as e: + return {"status": "unhealthy", "provider": self.provider_name or "unknown", "error": str(e)} + + +OpenAIProvider = LLMProvider diff --git a/backend/services/llm/llm_retrieval_service.py b/backend/services/llm/llm_retrieval_service.py new file mode 100644 index 0000000..d025c3a --- /dev/null +++ b/backend/services/llm/llm_retrieval_service.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import time +from typing import Callable, Dict, Any, List + +from services.rag.vector_store import vector_store + + +class LLMRetrievalService: + def __init__( + self, + doc_store: Dict[str, Dict[str, Any]], + embed_query: Callable[[str], List[float]], + rag_min_score: float = 0.15, + rag_context_max_chars: int = 18000, + ) -> None: + self.doc_store = doc_store + self.embed_query = embed_query + self.rag_min_score = float(rag_min_score) + self.rag_context_max_chars = int(rag_context_max_chars) + + def is_index_ready(self, doc_id: str) -> bool: + if not doc_id: + return False + obj = self.doc_store.get(doc_id) + if not obj: + return False + obj["ts"] = time.time() + return str(obj.get("index_status", "")).lower() == "ready" + + def retrieve_context(self, doc_id: str, query: str, top_k: int) -> str: + if not doc_id or not self.is_index_ready(doc_id): + return "" + if vector_store.count(doc_id) <= 0: + return "" + + emb = self.embed_query(query) + if not emb: + return "" + + results = vector_store.query( + doc_id=doc_id, + query_embedding=emb, + top_k=max(1, int(top_k)), + min_score=self.rag_min_score, + ) + if not results: + return "" + + parts: List[str] = [] + seen_ids = set() + total_chars = 0 + for chunk, _score in results: + if chunk.chunk_id in seen_ids: + continue + txt = (chunk.text or "").strip() + if not txt: + continue + if total_chars + len(txt) > self.rag_context_max_chars and parts: + break + parts.append(txt) + seen_ids.add(chunk.chunk_id) + total_chars += len(txt) + + return "\n\n".join(parts).strip() diff --git a/backend/services/llm/llm_section_service.py b/backend/services/llm/llm_section_service.py new file mode 100644 index 0000000..013ceb7 --- /dev/null +++ b/backend/services/llm/llm_section_service.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +import json +import re +from typing import Callable, Dict, Any, List + +from .llm_text_utils import clean_text, dedupe_section_heading, extract_response_text + + +class LLMSectionService: + def __init__( + self, + call_chat: Callable[[str, str, int, float, bool], Any], + base_system_prompt: Callable[[], 
str], + normalize_section_title: Callable[[str], str], + facts_max_items: int, + anchor_max_items: int, + anchor_max_chars_each: int, + ) -> None: + self.call_chat = call_chat + self.base_system_prompt = base_system_prompt + self.normalize_section_title = normalize_section_title + self.facts_max_items = int(facts_max_items) + self.anchor_max_items = int(anchor_max_items) + self.anchor_max_chars_each = int(anchor_max_chars_each) + + def discover_dynamic_sections( + self, + fitted_text: str, + min_sections: int = 2, + max_sections: int = 5, + temperature: float = 0.2, + ) -> List[Dict[str, Any]]: + min_sections = max(1, min(int(min_sections), 6)) + max_sections = max(min_sections, min(int(max_sections), 8)) + + system_prompt = ( + self.base_system_prompt() + + "\nYou propose section options (chips) for a summarization UI.\n" + "The user can upload ANY document: invoices, payroll, tax returns, audit reports, loan documents, or general articles.\n" + "Propose only sections that are likely supported by the text.\n" + "\nCRITICAL RULES:\n" + f"- Return between {min_sections} and {max_sections} sections.\n" + "- Use short titles (2 to 5 words). No numbering.\n" + "- Provide a short hint (6 to 12 words) describing what to expect.\n" + "- Do not invent specific numbers.\n" + "- Output MUST be strict JSON only.\n" + ) + + user_prompt = f"""Read the document and propose section chips. + +Return JSON: +{{ + "sections": [ + {{ + "title": "Short Title", + "hint": "Short description" + }} + ] +}} + +Document: +{fitted_text} +""" + + items: List[Dict[str, Any]] = [] + for attempt in range(2): + prompt_to_use = user_prompt + temp = max(0.0, min(float(temperature), 0.35)) if attempt == 0 else 0.0 + max_tok = 350 if attempt == 0 else 220 + if attempt == 1: + prompt_to_use = ( + "Return ONLY strict JSON. No prose. 
No markdown.\n" + "Schema:\n" + '{"sections":[{"title":"Short Title","hint":"Short description"}]}\n\n' + f"Document:\n{fitted_text}" + ) + + resp = self.call_chat( + system_prompt, + prompt_to_use, + max_tokens=max_tok, + temperature=temp, + stream=False, + ) + raw = extract_response_text(resp).strip() + items = self._normalize_sections(self._parse_json_loose(raw).get("sections", [])) + if len(items) >= min_sections: + break + + out: List[Dict[str, Any]] = [] + seen = set() + for it in items: + key = it["title"].lower() + if key in seen: + continue + seen.add(key) + out.append(it) + + if len(out) >= min_sections: + return out[:max_sections] + + heuristic = self._heuristic_sections_from_text(fitted_text, max_sections=max_sections) + if len(heuristic) >= min_sections: + return heuristic[:max_sections] + + fallback = [ + {"title": "General Summary", "hint": "What the document is about"}, + {"title": "Key Extracts", "hint": "Important names, dates, numbers"}, + ] + return fallback[:max_sections] + + def extract_facts_with_anchors(self, section_title: str, section_hint: str, fitted_text: str) -> List[Dict[str, Any]]: + title = self.normalize_section_title(section_title) + hint = (section_hint or "").strip() + if not title: + return [] + + system_prompt = ( + self.base_system_prompt() + + "\nTask: extract readable key facts for a requested section.\n" + "Facts must be supported by the document.\n" + "\nCRITICAL RULES:\n" + "- Return STRICT JSON only.\n" + "- Facts must be written in plain English (not copied verbatim).\n" + "- Each fact MUST include at least one anchor.\n" + "- Anchors must be SHORT strings that appear verbatim in the document.\n" + "- Do NOT return long quotes.\n" + ) + + user_prompt = f"""Requested section: {title} +Section hint (if any): {hint} + +Return JSON: +{{ + "facts": [ + {{ + "point": "Readable summarized point", + "anchors": ["anchor1", "anchor2"] + }} + ] +}} + +Rules: +- Provide up to {self.facts_max_items} facts. +- Each fact must include 1 to {self.anchor_max_items} anchors. +- Anchors must be <= {self.anchor_max_chars_each} characters. +- Prefer anchors that include numbers, dates, totals, ratios, names, account labels, or table row labels. +- Do not invent any numbers. + +Document: +{fitted_text} +""" + facts: List[Dict[str, Any]] = [] + for attempt in range(2): + prompt_to_use = user_prompt + temp = 0.15 if attempt == 0 else 0.0 + max_tok = 650 if attempt == 0 else 420 + if attempt == 1: + prompt_to_use = ( + f"Requested section: {title}\n" + f"Section hint: {hint}\n" + "Return ONLY strict JSON. No prose. 
No markdown.\n" + 'Schema: {"facts":[{"point":"Readable point","anchors":["anchor1","anchor2"]}]}\n' + f"Constraints: up to {self.facts_max_items} facts, 1 to {self.anchor_max_items} anchors each.\n\n" + f"Document:\n{fitted_text}" + ) + + resp = self.call_chat(system_prompt, prompt_to_use, max_tokens=max_tok, temperature=temp, stream=False) + raw = extract_response_text(resp).strip() + facts = self._normalize_facts(self._parse_json_loose(raw).get("facts", [])) + if facts: + break + + out: List[Dict[str, Any]] = [] + seen = set() + for f in facts: + key = re.sub(r"\s+", " ", f["point"]).strip().lower() + if key in seen: + continue + seen.add(key) + out.append(f) + return out + + def _normalize_sections(self, secs: Any) -> List[Dict[str, str]]: + items: List[Dict[str, str]] = [] + if not isinstance(secs, list): + return items + for s in secs: + if not isinstance(s, dict): + continue + title = self.normalize_section_title(str(s.get("title", "")).strip()) + hint = str(s.get("hint", "") or "").strip() + hint = re.sub(r"\s{2,}", " ", hint) + if not title: + continue + if len(hint) > 90: + hint = hint[:90].rstrip() + items.append({"title": title, "hint": hint}) + return items + + def _normalize_facts(self, items: Any) -> List[Dict[str, Any]]: + facts: List[Dict[str, Any]] = [] + if not isinstance(items, list): + return facts + for it in items: + if not isinstance(it, dict): + continue + point = str(it.get("point", "") or "").strip() + anchors = it.get("anchors", []) + if not point or not isinstance(anchors, list): + continue + norm_anchors: List[str] = [] + for a in anchors[: self.anchor_max_items]: + s = str(a or "").strip() + if not s: + continue + if len(s) > self.anchor_max_chars_each: + s = s[: self.anchor_max_chars_each].rstrip() + norm_anchors.append(s) + if not norm_anchors: + continue + facts.append({"point": point, "anchors": norm_anchors}) + return facts + + @staticmethod + def validate_anchored_facts(facts: List[Dict[str, Any]], fitted_text: str) -> List[Dict[str, Any]]: + if not facts: + return [] + valid: List[Dict[str, Any]] = [] + for f in facts: + anchors = f.get("anchors", []) + if not isinstance(anchors, list): + continue + ok = False + for a in anchors: + if not a: + continue + if str(a) in fitted_text: + ok = True + break + if ok: + valid.append(f) + return valid + + def write_section_from_facts(self, section_title: str, facts: List[Dict[str, Any]], max_tokens: int, temperature: float) -> str: + title = self.normalize_section_title(section_title) + + if not facts: + return f"{title}\n- No supported information found in the text for this section.\n" + + fact_lines: List[str] = [] + for f in facts: + p = str(f.get("point", "") or "").strip() + if p: + fact_lines.append(f"- {p}") + if not fact_lines: + return f"{title}\n- No supported information found in the text for this section.\n" + + system_prompt = self.base_system_prompt() + ( + "\nWrite the requested section using ONLY the provided facts.\n" + "Do not invent.\n" + "\nOUTPUT RULES:\n" + "- Start with the heading exactly as provided.\n" + "- Keep it easy to read.\n" + "- Use bullets (recommended) or short paragraphs.\n" + "- Keep numbers when present in facts.\n" + "- If facts are weak, explicitly say it is limited.\n" + ) + user_prompt = f"""Section heading: {title} + +Facts (use only these): +{chr(10).join(fact_lines)} + +Write the section now. 
+""" + resp = self.call_chat( + system_prompt, + user_prompt, + max_tokens=max_tokens, + temperature=max(0.0, min(float(temperature), 0.3)), + stream=False, + ) + out = clean_text(extract_response_text(resp)) + return dedupe_section_heading(out, title) + + @staticmethod + def _parse_json_loose(raw: str) -> Dict[str, Any]: + txt = (raw or "").strip() + if not txt: + return {} + try: + return json.loads(txt) + except Exception: + pass + # Remove common markdown fences + txt2 = re.sub(r"^```(?:json)?\s*", "", txt, flags=re.IGNORECASE).strip() + txt2 = re.sub(r"\s*```$", "", txt2).strip() + try: + return json.loads(txt2) + except Exception: + pass + # Fallback: parse first JSON object span + s = txt2.find("{") + e = txt2.rfind("}") + if s >= 0 and e > s: + try: + return json.loads(txt2[s : e + 1]) + except Exception: + return {} + return {} + + def _heuristic_sections_from_text(self, text: str, max_sections: int) -> List[Dict[str, str]]: + t = (text or "").lower() + out: List[Dict[str, str]] = [] + + def add(title: str, hint: str) -> None: + if any(x["title"].lower() == title.lower() for x in out): + return + out.append({"title": title, "hint": hint}) + + # Prefer document-structure driven chips for weak models. + if any(k in t for k in ("balance sheet", "statement of financial position")): + add("Balance Sheet", "Assets, liabilities, and equity snapshot") + if any(k in t for k in ("income statement", "profit and loss", "statement of operations")): + add("Income Statement", "Revenue, expenses, and net income") + if any(k in t for k in ("cash flow", "cash flows")): + add("Cash Flow", "Operating, investing, and financing cash movement") + if any(k in t for k in ("retained earnings", "shareholder", "equity")): + add("Equity & Retained", "Changes in equity and retained earnings") + if any(k in t for k in ("notes", "accounting policy", "policies")): + add("Notes & Policies", "Important notes and accounting assumptions") + if any(k in t for k in ("tax", "gst", "vat", "income tax")): + add("Taxes", "Tax-related details and amounts") + if any(k in t for k in ("debt", "loan", "borrow", "interest")): + add("Debt & Financing", "Borrowings, interest, and funding details") + if any(k in t for k in ("invoice", "payable", "receivable")): + add("Billing & Payments", "Invoices, receivables, and payment status") + + if not out: + out = [ + {"title": "General Summary", "hint": "What the document is about"}, + {"title": "Key Figures", "hint": "Important amounts, dates, and ratios"}, + {"title": "Entities & Dates", "hint": "Main names, periods, and timelines"}, + ] + + return out[: max(1, int(max_sections))] diff --git a/backend/services/llm/llm_service.py b/backend/services/llm/llm_service.py new file mode 100644 index 0000000..b8a930b --- /dev/null +++ b/backend/services/llm/llm_service.py @@ -0,0 +1,1469 @@ +""" +LLM Service for Document Summarization (FinSights) +Uses an OpenAI-compatible provider (OpenAI, inference API/Xeon, or Ollama) + +This version supports: +- Dynamic section chips (2 to 5) generated from the document at initial step. +- Section-wise summaries for ANY selected section title (no pre-defined section list). +- Readable section summaries WITHOUT showing quotes. + Internally, we extract "facts" with short "anchors" that must exist in the document, + validate anchors, then generate the final section from facts only. + +Compatibility: +- initial_summary_first_chunk(doc_id) returns ONLY a summary string (same as before). +- summarize_financial(mode="financial_section") accepts any section title. 
+- doc_id flow stays fast (no file re-upload needed on chip clicks). + +Anti-hallucination strategy: +- We do NOT require exact quotes in the final output. +- We do require each extracted fact to include at least one anchor that exists in the document text. + Anchors are not shown to the user; they are only used for validation. +""" + +from typing import Iterator, Dict, Any, Optional, Union, List +import logging +import re +import os +import time +import uuid +import hashlib +import math + +import config +from services.rag.vector_store import vector_store +from .llm_text_utils import clean_text, normalize_money, dedupe_section_heading, extract_response_text +from services.rag.summarization_pipeline import SummarizationPipeline, SourceChunk +from .llm_provider import OpenAIProvider +from .llm_retrieval_service import LLMRetrievalService +from .llm_section_service import LLMSectionService +from services.observability_service import observability_service + +logger = logging.getLogger(__name__) + + +def _env_int(name: str, default: int) -> int: + raw = os.getenv(name, "") + if raw is None: + return int(default) + txt = str(raw).strip() + if txt == "": + return int(default) + try: + return int(txt) + except Exception: + return int(default) + + +def _env_float(name: str, default: float) -> float: + raw = os.getenv(name, "") + if raw is None: + return float(default) + txt = str(raw).strip() + if txt == "": + return float(default) + try: + return float(txt) + except Exception: + return float(default) + + +class LLMService: + def __init__(self): + self.provider = OpenAIProvider() + self.model = self.provider.model + self.embedding_model = self.provider.embedding_model + self._initialized = False + + # Large default; override via env if needed. + self.model_context_tokens = _env_int("MODEL_CONTEXT_TOKENS", 128000) + self.min_model_context_tokens = _env_int("MIN_MODEL_CONTEXT_TOKENS", 4096) + self.context_retry_shrink_ratio = _env_float("CONTEXT_RETRY_SHRINK_RATIO", 0.75) + self.context_retry_margin_tokens = _env_int("CONTEXT_RETRY_MARGIN_TOKENS", 1200) + + # In-memory doc store for doc_id flow + self.doc_store: Dict[str, Dict[str, Any]] = {} + self.cache_ttl_seconds = _env_int("CACHE_TTL_SECONDS", 60 * 60) # 1 hour + self.cache_max_docs = _env_int("CACHE_MAX_DOCS", 25) + + # Dynamic section discovery bounds (frontend chips) + self.dynamic_min_sections = _env_int("DYNAMIC_SECTIONS_MIN", 2) + self.dynamic_max_sections = _env_int("DYNAMIC_SECTIONS_MAX", 5) + self.dynamic_sections_use_llm = os.getenv("DYNAMIC_SECTIONS_USE_LLM", "false").strip().lower() == "true" + + # Evidence/facts extraction bounds + self.facts_max_items = _env_int("FACTS_MAX_ITEMS", 10) + self.anchor_max_items = _env_int("ANCHOR_MAX_ITEMS", 3) + self.anchor_max_chars_each = _env_int("ANCHOR_MAX_CHARS_EACH", 60) + + # Validation threshold: how many facts must be anchored to proceed + # Example: 0.6 means at least 60% of extracted facts must have >=1 valid anchor. 
+ self.min_anchored_fact_ratio = _env_float("MIN_ANCHORED_FACT_RATIO", 0.6) + self.summary_retry_enable = os.getenv("SUMMARY_RETRY_ENABLE", "false").strip().lower() == "true" + + # RAG-assisted summarization defaults for large docs + self.rag_summary_top_k = _env_int("RAG_SUMMARY_TOP_K", 8) + self.rag_section_top_k = _env_int("RAG_SECTION_TOP_K", 10) + self.rag_min_score = _env_float("RAG_MIN_SCORE", 0.15) + self.rag_context_max_chars = _env_int("RAG_CONTEXT_MAX_CHARS", 18000) + self.retrieval_first_enable = os.getenv("RETRIEVAL_FIRST_ENABLE", "true").strip().lower() == "true" + self.retrieval_force_index_on_demand = os.getenv("RETRIEVAL_FORCE_INDEX_ON_DEMAND", "false").strip().lower() == "true" + + # Map-reduce summarization for large docs + self.enable_map_reduce = os.getenv("MAP_REDUCE_ENABLE", "true").lower() == "true" + self.map_reduce_min_chars = _env_int("MAP_REDUCE_MIN_CHARS", 18000) + self.map_reduce_min_chunks = _env_int("MAP_REDUCE_MIN_CHUNKS", 8) + self.map_reduce_max_chunks = _env_int("MAP_REDUCE_MAX_CHUNKS", 24) + self.map_reduce_chunk_chars = _env_int("MAP_REDUCE_CHUNK_CHARS", 2800) + self.map_reduce_chunk_overlap = _env_int("MAP_REDUCE_CHUNK_OVERLAP", 250) + self.map_reduce_batch_size = _env_int("MAP_REDUCE_BATCH_SIZE", 6) + self.map_reduce_overhead_tokens = _env_int("MAP_REDUCE_OVERHEAD_TOKENS", 900) + self.map_reduce_target_reduce_groups = _env_int("MAP_REDUCE_TARGET_REDUCE_GROUPS", 6) + self.summary_pipeline = SummarizationPipeline( + max_chunks=self.map_reduce_max_chunks, + chunk_chars=self.map_reduce_chunk_chars, + chunk_overlap_chars=self.map_reduce_chunk_overlap, + reduce_batch_size=self.map_reduce_batch_size, + ) + self.retrieval_service = LLMRetrievalService( + doc_store=self.doc_store, + embed_query=self._embed_query, + rag_min_score=self.rag_min_score, + rag_context_max_chars=self.rag_context_max_chars, + ) + self.section_service = LLMSectionService( + call_chat=self._call_chat, + base_system_prompt=self._base_system_prompt, + normalize_section_title=self._normalize_section_title, + facts_max_items=self.facts_max_items, + anchor_max_items=self.anchor_max_items, + anchor_max_chars_each=self.anchor_max_chars_each, + ) + + def _ensure_initialized(self): + if self._initialized: + return + self.provider.ensure_initialized() + self.model = self.provider.model + self.embedding_model = self.provider.embedding_model + try: + resolved_ctx = int(self.provider.resolve_model_context_tokens(self.model_context_tokens)) + if resolved_ctx > 0: + self.model_context_tokens = resolved_ctx + except Exception: + logger.exception("Failed to auto-resolve model context window; using configured value") + self._initialized = True + + logger.info("LLM provider initialized successfully") + logger.info(f"Provider: {self.provider.provider_name or 'unknown'}") + logger.info(f"Model: {self.model}") + logger.info(f"MODEL_CONTEXT_TOKENS: {self.model_context_tokens}") + logger.info(f"MIN_MODEL_CONTEXT_TOKENS: {self.min_model_context_tokens}") + logger.info(f"CONTEXT_RETRY_SHRINK_RATIO: {self.context_retry_shrink_ratio}") + logger.info(f"CONTEXT_RETRY_MARGIN_TOKENS: {self.context_retry_margin_tokens}") + logger.info(f"CACHE_MAX_DOCS: {self.cache_max_docs}") + logger.info(f"CACHE_TTL_SECONDS: {self.cache_ttl_seconds}") + logger.info(f"DYNAMIC_SECTIONS_MIN: {self.dynamic_min_sections}") + logger.info(f"DYNAMIC_SECTIONS_MAX: {self.dynamic_max_sections}") + logger.info(f"DYNAMIC_SECTIONS_USE_LLM: {self.dynamic_sections_use_llm}") + logger.info(f"FACTS_MAX_ITEMS: {self.facts_max_items}") + 
logger.info(f"ANCHOR_MAX_ITEMS: {self.anchor_max_items}") + logger.info(f"MIN_ANCHORED_FACT_RATIO: {self.min_anchored_fact_ratio}") + logger.info(f"SUMMARY_RETRY_ENABLE: {self.summary_retry_enable}") + logger.info(f"EMBEDDING_MODEL: {self.embedding_model}") + logger.info(f"RAG_SUMMARY_TOP_K: {self.rag_summary_top_k}") + logger.info(f"RAG_SECTION_TOP_K: {self.rag_section_top_k}") + logger.info(f"RAG_MIN_SCORE: {self.rag_min_score}") + logger.info(f"RAG_CONTEXT_MAX_CHARS: {self.rag_context_max_chars}") + logger.info(f"RETRIEVAL_FIRST_ENABLE: {self.retrieval_first_enable}") + logger.info(f"RETRIEVAL_FORCE_INDEX_ON_DEMAND: {self.retrieval_force_index_on_demand}") + logger.info(f"MAP_REDUCE_ENABLE: {self.enable_map_reduce}") + logger.info(f"MAP_REDUCE_MIN_CHARS: {self.map_reduce_min_chars}") + logger.info(f"MAP_REDUCE_MIN_CHUNKS: {self.map_reduce_min_chunks}") + logger.info(f"MAP_REDUCE_MAX_CHUNKS: {self.map_reduce_max_chunks}") + logger.info(f"MAP_REDUCE_CHUNK_CHARS: {self.map_reduce_chunk_chars}") + logger.info(f"MAP_REDUCE_CHUNK_OVERLAP: {self.map_reduce_chunk_overlap}") + logger.info(f"MAP_REDUCE_BATCH_SIZE: {self.map_reduce_batch_size}") + logger.info(f"MAP_REDUCE_OVERHEAD_TOKENS: {self.map_reduce_overhead_tokens}") + logger.info(f"MAP_REDUCE_TARGET_REDUCE_GROUPS: {self.map_reduce_target_reduce_groups}") + + def get_provider_name(self) -> str: + try: + self._ensure_initialized() + return self.provider.provider_name or "unknown" + except Exception: + return self.provider.provider_name or "unknown" + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + self._ensure_initialized() + return self.provider.embed_texts(texts) + + # ---------------------------- + # Compatibility wrapper + # ---------------------------- + def summarize( + self, + text: str, + max_tokens: int = None, + temperature: float = None, + stream: bool = False, + mode: str = "financial_initial", + section: str = None, + ) -> Union[str, Iterator[str]]: + return self.summarize_financial( + text=text, + mode=mode, + section=section, + max_tokens=max_tokens, + temperature=temperature, + stream=stream, + ) + + # ---------------------------- + # doc_id based flow (compat) + # ---------------------------- + def create_doc(self, text: str) -> str: + self._ensure_initialized() + self._evict_cache_if_needed() + + if not text or not text.strip(): + raise ValueError("Empty text") + + doc_id = str(uuid.uuid4()) + dk = self._doc_key(text) + + self.doc_store[doc_id] = { + "ts": time.time(), + "text": text, + "doc_key": dk, + # list of {"title": str, "hint": str} created during initial summary + "sections": None, + "sections_status": "pending", # pending | ready | failed + + # ---- RAG indexing state (for chat) ---- + "index_status": "pending", # pending | ready | failed + "chunk_count": 0, + "index_error": "", + "index_started_at": None, + "index_finished_at": None, + } + return doc_id + + def get_doc_text(self, doc_id: str) -> str: + obj = self.doc_store.get(doc_id) + if not obj: + raise ValueError("Invalid doc_id") + obj["ts"] = time.time() + return obj.get("text", "") + + def get_doc_key(self, doc_id: str) -> str: + obj = self.doc_store.get(doc_id) + if not obj: + raise ValueError("Invalid doc_id") + obj["ts"] = time.time() + return obj.get("doc_key", "") + + def get_doc_sections(self, doc_id: str) -> List[str]: + """ + Returns discovered dynamic section titles for this doc_id. + If not discovered yet, returns []. 
+ """ + obj = self.doc_store.get(doc_id) + if not obj: + raise ValueError("Invalid doc_id") + obj["ts"] = time.time() + + secs = obj.get("sections") + if not isinstance(secs, list): + return [] + + titles: List[str] = [] + for s in secs: + if isinstance(s, dict) and s.get("title"): + titles.append(str(s["title"]).strip()) + elif isinstance(s, str): + titles.append(str(s).strip()) + return [t for t in titles if t] + + def get_doc_sections_payload(self, doc_id: str) -> Dict[str, Any]: + obj = self.doc_store.get(doc_id) + if not obj: + raise ValueError("Invalid doc_id") + obj["ts"] = time.time() + + status = str(obj.get("sections_status", "pending")).lower() + sections_raw = obj.get("sections") + sections: List[Dict[str, str]] = [] + if isinstance(sections_raw, list): + for s in sections_raw: + if isinstance(s, dict): + title = str(s.get("title", "")).strip() + hint = str(s.get("hint", "")).strip() + if title: + sections.append({"title": title, "hint": hint}) + elif isinstance(s, str): + t = str(s).strip() + if t: + sections.append({"title": t, "hint": ""}) + + return { + "doc_id": doc_id, + "status": status, + "ready": status == "ready", + "count": len(sections), + "sections": sections, + } + + def generate_doc_sections(self, doc_id: str, force: bool = False) -> None: + """ + Background-safe section discovery. + """ + self._ensure_initialized() + obj = self.doc_store.get(doc_id) + if not obj: + return + + obj["ts"] = time.time() + if not force and isinstance(obj.get("sections"), list) and obj.get("sections"): + obj["sections_status"] = "ready" + return + + obj["sections_status"] = "pending" + try: + text = str(obj.get("text", "") or "") + fitted = self._context_for_mode( + text=text, + mode="financial_initial", + max_output_tokens=240, + doc_id=doc_id, + ) + sections = self._discover_dynamic_sections( + fitted_text=fitted, + min_sections=self.dynamic_min_sections, + max_sections=self.dynamic_max_sections, + temperature=0.2, + ) + obj["sections"] = sections or [] + obj["sections_status"] = "ready" + obj["ts"] = time.time() + except Exception: + obj["sections_status"] = "failed" + obj["sections"] = [] + logger.exception("Background section discovery failed doc_id=%s", doc_id) + + def get_doc_section_hint(self, doc_id: str, section_title: str) -> str: + """ + Returns a stored hint for a discovered section (optional), else "". + Hints are short descriptions like "Totals, taxes, payment status". 
+ """ + obj = self.doc_store.get(doc_id) + if not obj: + raise ValueError("Invalid doc_id") + obj["ts"] = time.time() + + secs = obj.get("sections") + if not isinstance(secs, list): + return "" + + target = self._normalize_section_title(section_title).lower() + if not target: + return "" + + for s in secs: + if isinstance(s, dict): + t = self._normalize_section_title(str(s.get("title", ""))).lower() + if t == target: + return str(s.get("hint", "") or "").strip() + return "" + + def prefetch_doc(self, doc_id: str) -> None: + # No-op (kept so existing routes won't break) + obj = self.doc_store.get(doc_id) + if obj: + obj["ts"] = time.time() + return + + def summarize_by_doc_id( + self, + doc_id: str, + mode: str = "financial_initial", + section: str = None, + max_tokens: int = None, + temperature: float = None, + stream: bool = False, + ) -> Union[str, Iterator[str]]: + text = self.get_doc_text(doc_id) + return self.summarize_financial( + text=text, + mode=mode, + section=section, + max_tokens=max_tokens, + temperature=temperature, + stream=stream, + doc_id=doc_id, + ) + + def initial_summary_first_chunk(self, doc_id: str, max_tokens: int = 240, temperature: float = 0.25) -> str: + """ + Compatibility method expected by routes/frontend. + Uses full doc text (fitted to context) to produce 4-5 sentences. + + Updated behavior: + - Also runs dynamic section discovery and stores it in doc_store[doc_id]["sections"]. + - Still returns ONLY the summary string (so existing routes don't break). + """ + self._ensure_initialized() + text = self.get_doc_text(doc_id) + fitted = self._context_for_mode( + text=text, + mode="financial_initial", + max_output_tokens=max_tokens, + doc_id=doc_id, + ) + + # Dynamic sections discovery (best effort; never breaks summary) + try: + sections = self._discover_dynamic_sections( + fitted_text=fitted, + min_sections=self.dynamic_min_sections, + max_sections=self.dynamic_max_sections, + temperature=0.2, + ) + obj = self.doc_store.get(doc_id) + if obj is not None: + obj["sections"] = sections + obj["sections_status"] = "ready" + obj["ts"] = time.time() + except Exception: + logger.exception("Dynamic section discovery failed (continuing with summary only).") + + # Use map-reduce for large docs when enabled. + # Keep section discovery above so frontend chips still work. + if (not self.retrieval_first_enable) and self._should_use_map_reduce(text=text, doc_id=doc_id): + mr = self._map_reduce_summary( + text=text, + doc_id=doc_id, + max_tokens=max_tokens, + temperature=temperature, + style="initial", + ) + if mr: + return normalize_money(clean_text(mr)) + + # Normal initial summary (non map-reduce path) + system_prompt = self._base_system_prompt() + ( + "\nWrite a short generalized summary that tells the user what this document is about.\n" + "Focus on: document type, company/entity, reporting period/date, and purpose ONLY if explicitly stated.\n" + "Rules:\n" + "- Plain text only.\n" + "- 4 to 5 sentences.\n" + "- Prefer exact names, dates, and periods when present.\n" + "- Include key numeric highlights ONLY if explicitly present.\n" + "- Do not invent any details.\n" + ) + + user_prompt = f"""Write the generalized summary for this document. 
+ +Document: +{fitted} +""" + resp = self._call_chat( + system_prompt, + user_prompt, + max_tokens=max_tokens, + temperature=max(0.0, min(float(temperature), 0.35)), + stream=False, + ) + out = self._clean_or_retry_summary( + draft=extract_response_text(resp), + source_text=(fitted or text), + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + temperature=temperature, + ) + return normalize_money(out) + + # ---------------------------- + # Prompt building + # ---------------------------- + @staticmethod + def _base_system_prompt() -> str: + return ( + "You are an analyst AI.\n" + "You must use ONLY information supported by the provided document text.\n" + "Do not invent facts, numbers, dates, names, or events.\n" + "If you are uncertain, state it clearly.\n" + "Output must be plain text only (no markdown emphasis: no **, *, _).\n" + ) + def chat_with_context( + self, + question: str, + context: str, + max_tokens: int = 500, + temperature: float = 0.2, + ) -> str: + """ + Answer a user question using ONLY the retrieved context. + If the answer is not supported by the context, say so. + """ + self._ensure_initialized() + + q = (question or "").strip() + ctx = (context or "").strip() + + if not q: + return "Please enter a question." + + system_prompt = ( + self._base_system_prompt() + + "\nYou are answering questions about an uploaded document.\n" + "Use ONLY the provided CONTEXT.\n" + "If the answer is not in the context, say you cannot find it.\n" + "Keep the response concise and readable.\n" + ) + + user_prompt = f"""CONTEXT: +{ctx} + +QUESTION: +{q} + +Answer using only the context. +""" + + resp = self._call_chat( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + temperature=max(0.0, min(float(temperature), 0.5)), + stream=False, + ) + return normalize_money(clean_text(extract_response_text(resp))) + + + # ---------------------------- + # Core helpers + # ---------------------------- + @staticmethod + def _estimate_tokens(text: str) -> int: + return max(1, int(len(text) / 4)) if text else 0 + + @staticmethod + def _is_context_limit_error(exc: Exception) -> bool: + msg = str(exc or "").lower() + needles = ( + "context length", + "maximum context length", + "context_window_exceeded", + "context_length_exceeded", + "too many tokens", + "prompt is too long", + "token limit exceeded", + "max context", + ) + return any(n in msg for n in needles) + + def _shrink_context_window(self, reason: str) -> int: + current = int(self.model_context_tokens) + shrunk = int(max(self.min_model_context_tokens, current * self.context_retry_shrink_ratio)) + if shrunk >= current: + shrunk = max(self.min_model_context_tokens, current - 2048) + if shrunk < current: + logger.warning( + "Shrinking model context budget (%s): %d -> %d", + reason, + current, + shrunk, + ) + self.model_context_tokens = shrunk + return int(self.model_context_tokens) + + def _fit_prompts_to_context( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int, + margin_tokens: int = 900, + ) -> tuple[str, str, int]: + """ + Ensure chat prompt + completion target fits into current model_context_tokens. + Trims only the user prompt when needed. 
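+        Returns a (system_prompt, user_prompt, completion_tokens) tuple; only the user prompt may be shortened.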
+ """ + sys_txt = system_prompt or "" + usr_txt = user_prompt or "" + out_tokens = max(64, int(max_tokens)) + margin = max(256, int(margin_tokens)) + + available_prompt_tokens = max(256, int(self.model_context_tokens - out_tokens - margin)) + sys_tokens = self._estimate_tokens(sys_txt) + user_tokens = self._estimate_tokens(usr_txt) + total_prompt_tokens = sys_tokens + user_tokens + if total_prompt_tokens <= available_prompt_tokens: + return sys_txt, usr_txt, out_tokens + + available_user_tokens = max(64, available_prompt_tokens - sys_tokens) + if user_tokens <= available_user_tokens: + return sys_txt, usr_txt, out_tokens + + ratio = max(0.05, min(1.0, available_user_tokens / float(max(1, user_tokens)))) + keep_chars = max(600, int(len(usr_txt) * ratio)) + head = usr_txt[: int(keep_chars * 0.75)] + tail = usr_txt[-int(keep_chars * 0.25) :] if keep_chars > 1800 else "" + trimmed = head + if tail and tail not in head: + trimmed = head + "\n\n[...TRUNCATED FOR CONTEXT LIMIT...]\n\n" + tail + return sys_txt, trimmed, out_tokens + + @staticmethod + def _strip_summary_leadin(text: str) -> str: + t = (text or "").strip() + if not t: + return "" + patterns = [ + r"^\s*here is (a|the) (short )?(generalized )?summary[^:\n]*:\s*", + r"^\s*here'?s (a|the) summary[^:\n]*:\s*", + r"^\s*summary\s*:\s*", + ] + for p in patterns: + t = re.sub(p, "", t, flags=re.IGNORECASE) + return t.strip() + + @staticmethod + def _extract_num_tokens(text: str) -> List[str]: + if not text: + return [] + vals = re.findall(r"\b\d[\d,./:-]*\b", text) + out: List[str] = [] + seen = set() + for v in vals: + k = v.strip().lower() + if k and k not in seen: + seen.add(k) + out.append(v.strip()) + if len(out) >= 24: + break + return out + + def _looks_generic_summary(self, summary: str, source_text: str) -> bool: + s = (summary or "").strip() + if not s: + return True + low = s.lower() + generic_phrases = ( + "can be evaluated using", + "provide insights", + "stakeholders can assess", + "ability to generate revenue", + "manage costs", + "invest in growth opportunities", + ) + if any(g in low for g in generic_phrases): + return True + + src_nums = self._extract_num_tokens(source_text) + if src_nums: + # If source has many numbers, summary should preserve at least one. 
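+            # Only the first dozen extracted numeric tokens are checked, keeping this containment test cheap.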
+ if not any(n in s for n in src_nums[:12]): + return True + return False + + def _clean_or_retry_summary( + self, + draft: str, + source_text: str, + system_prompt: str, + user_prompt: str, + max_tokens: int, + temperature: float, + ) -> str: + cleaned = clean_text(self._strip_summary_leadin(draft)) + if cleaned and not self._looks_generic_summary(cleaned, source_text): + return cleaned + if not self.summary_retry_enable: + if cleaned: + return cleaned + return self._fallback_extractive_summary(source_text) + + retry_user = ( + user_prompt + + "\n\nIMPORTANT:\n" + "- Do NOT start with phrases like 'Here is a summary'.\n" + "- Use only concrete facts from the document.\n" + "- Include at least one exact date/period or numeric value if present in document.\n" + "- Output plain text only.\n" + ) + retry_resp = self._call_chat( + system_prompt, + retry_user, + max_tokens=max_tokens, + temperature=0.0, + stream=False, + ) + retry_text = clean_text(self._strip_summary_leadin(extract_response_text(retry_resp))) + if retry_text and not self._looks_generic_summary(retry_text, source_text): + return retry_text + if cleaned and not self._looks_generic_summary(cleaned, source_text): + return cleaned + return self._fallback_extractive_summary(source_text) + + @staticmethod + def _fallback_extractive_summary(source_text: str) -> str: + """ + Deterministic fallback for weak models: pull concrete sentences from source. + """ + t = (source_text or "").strip() + if not t: + return "" + + # Normalize whitespace then split into sentence-like units. + norm = re.sub(r"\s+", " ", t) + raw_sentences = re.split(r"(?<=[.!?])\s+|\s*\n+\s*", norm) + sentences = [s.strip() for s in raw_sentences if s and len(s.strip()) >= 35] + if not sentences: + return t[:600].strip() + + def score(s: str) -> int: + low = s.lower() + sc = 0 + if re.search(r"\b\d[\d,./:-]*\b", s): + sc += 3 + for k in ("ended", "year", "period", "balance sheet", "income statement", "cash flow", "revenue", "expense", "net", "total", "assets", "liabilities"): + if k in low: + sc += 1 + if 60 <= len(s) <= 220: + sc += 1 + return sc + + ranked = sorted(sentences, key=score, reverse=True) + picked: List[str] = [] + seen = set() + for s in ranked: + key = s[:120].lower() + if key in seen: + continue + seen.add(key) + picked.append(s) + if len(picked) >= 4: + break + + if not picked: + picked = sentences[:3] + + return " ".join(picked).strip() + + @staticmethod + def _doc_key(text: str) -> str: + return hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest() + + @staticmethod + def _normalize_section_title(title: str) -> str: + t = (title or "").strip() + t = re.sub(r"\s{2,}", " ", t) + t = t.strip(" -:\t\r\n") + if len(t) > 56: + t = t[:56].rstrip() + return t + + def _evict_cache_if_needed(self): + now = time.time() + + expired_docs = [k for k, v in self.doc_store.items() if now - v.get("ts", now) > self.cache_ttl_seconds] + for k in expired_docs: + self.doc_store.pop(k, None) + + if len(self.doc_store) <= self.cache_max_docs: + return + + items = sorted(self.doc_store.items(), key=lambda kv: kv[1].get("ts", 0)) + while len(items) > self.cache_max_docs: + k, _ = items.pop(0) + self.doc_store.pop(k, None) + + def _fit_text_to_context(self, text: str, max_output_tokens: int) -> str: + """ + Truncate input so input + output stays within model_context_tokens. + Keeps head + tail when truncating. 
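+        Roughly 72% of the kept characters come from the head and 28% from the tail,
+        joined by a visible [...TRUNCATED...] marker.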
+ """ + if not text: + return "" + + overhead_tokens = 900 + available_input_tokens = max(500, self.model_context_tokens - int(max_output_tokens) - overhead_tokens) + + est = self._estimate_tokens(text) + if est <= available_input_tokens: + return text + + ratio = available_input_tokens / float(est) + keep_chars = max(2500, int(len(text) * ratio)) + + head = text[: int(keep_chars * 0.72)] + tail = text[-int(keep_chars * 0.28) :] if keep_chars > 4500 else "" + + truncated = head + if tail and tail not in head: + truncated = head + "\n\n[...TRUNCATED...]\n\n" + tail + + return truncated + + def _is_index_ready(self, doc_id: Optional[str]) -> bool: + return self.retrieval_service.is_index_ready(doc_id or "") + + def _embed_query(self, query: str) -> List[float]: + self._ensure_initialized() + try: + return self.provider.embed_query(query) + except Exception: + logger.exception("Failed to embed query for retrieval") + return [] + + def _retrieve_context(self, doc_id: str, query: str, top_k: int) -> str: + try: + return self.retrieval_service.retrieve_context(doc_id, query, top_k) + except Exception: + logger.exception("Vector retrieval failed") + return "" + + def _ensure_doc_index_ready(self, doc_id: Optional[str]) -> bool: + if not doc_id: + return False + if self._is_index_ready(doc_id): + return True + if not self.retrieval_force_index_on_demand: + return False + try: + from services.rag.rag_index_service import rag_index_service + rag_index_service.index_doc(doc_id) + except Exception: + logger.exception("Failed to build RAG index on-demand doc_id=%s", doc_id) + return False + return self._is_index_ready(doc_id) + + @staticmethod + def _dedupe_text_blocks(raw: str) -> List[str]: + parts = [p.strip() for p in (raw or "").split("\n\n") if p and p.strip()] + out: List[str] = [] + seen = set() + for p in parts: + key = hashlib.sha1(p.encode("utf-8", errors="ignore")).hexdigest() + if key in seen: + continue + seen.add(key) + out.append(p) + return out + + def _retrieve_multi_query_context(self, doc_id: Optional[str], queries: List[str], top_k: int, max_chars: Optional[int] = None) -> str: + if not doc_id or not queries: + return "" + if not self._is_index_ready(doc_id) and not self._ensure_doc_index_ready(doc_id): + return "" + + limit = int(max_chars or max(4000, self.rag_context_max_chars)) + merged: List[str] = [] + seen = set() + total = 0 + for q in queries: + ctx = self._retrieve_context(doc_id, query=q, top_k=top_k) + for block in self._dedupe_text_blocks(ctx): + h = hashlib.sha1(block.encode("utf-8", errors="ignore")).hexdigest() + if h in seen: + continue + if total + len(block) > limit and merged: + return "\n\n".join(merged).strip() + merged.append(block) + seen.add(h) + total += len(block) + if total >= limit: + return "\n\n".join(merged).strip() + return "\n\n".join(merged).strip() + + def _context_for_mode( + self, + text: str, + mode: str, + max_output_tokens: int, + doc_id: Optional[str] = None, + section: Optional[str] = None, + section_hint: str = "", + ) -> str: + """ + Prefer retrieval context for large docs when index is ready. + Falls back to legacy truncation path when retrieval is unavailable. + """ + fitted = self._fit_text_to_context(text, max_output_tokens=max_output_tokens) + if not doc_id: + return fitted + if mode == "financial_section": + sec = self._normalize_section_title(section or "") + query = f"Find facts for section: {sec}. 
{section_hint}".strip() + rag_text = self._retrieve_context(doc_id, query=query, top_k=self.rag_section_top_k) + return rag_text or fitted + + if mode == "financial_initial": + queries = [ + "What is this document about? Include type, parties/entities, and purpose.", + "What reporting period/date does this document cover?", + "What are the key numeric amounts, totals, balances, or highlights?", + ] + rag_text = self._retrieve_multi_query_context( + doc_id=doc_id, + queries=queries, + top_k=self.rag_summary_top_k, + max_chars=max(8000, self.rag_context_max_chars), + ) + return rag_text or fitted + + if mode == "financial_overall": + query = "Summarize the document with key claims, names, dates, and numbers." + rag_text = self._retrieve_context(doc_id, query=query, top_k=self.rag_summary_top_k) + return rag_text or fitted + + if mode == "financial_sectionwise": + query = "Identify major themes and section-level highlights for this document." + rag_text = self._retrieve_context(doc_id, query=query, top_k=self.rag_summary_top_k) + return rag_text or fitted + + return fitted + + def _source_chunks_for_map_reduce(self, text: str, doc_id: Optional[str], max_chunks: Optional[int] = None) -> List[SourceChunk]: + if doc_id and self._is_index_ready(doc_id) and vector_store.count(doc_id) > 0: + chunks = vector_store.list_chunks(doc_id) + out: List[SourceChunk] = [] + for c in chunks: + chunk_text = (c.text or "").strip() + if not chunk_text: + continue + idx = 0 + if isinstance(c.meta, dict): + try: + idx = int(c.meta.get("index", 0)) + except Exception: + idx = 0 + out.append(SourceChunk(text=chunk_text, order=idx)) + if out: + ordered = sorted(out, key=lambda x: x.order) + if max_chunks is not None: + return self.summary_pipeline.limit_chunks_with_limit(ordered, max_chunks=max_chunks) + return ordered + + raw = (text or "").strip() + if not raw: + return [] + if max_chunks is None: + return [SourceChunk(text=raw, order=0)] + return self.summary_pipeline.build_chunks_from_text_with_limit(text=raw, max_chunks=max_chunks) + + @staticmethod + def _split_text_even(text: str, target_chunks: int, overlap_chars: int = 0) -> List[SourceChunk]: + t = (text or "").strip() + if not t: + return [] + n = len(t) + k = max(1, int(target_chunks)) + if k <= 1: + return [SourceChunk(text=t, order=0)] + + overlap = max(0, int(overlap_chars)) + out: List[SourceChunk] = [] + for i in range(k): + start = int(i * n / k) + end = int((i + 1) * n / k) + if i > 0: + start = max(0, start - overlap) + if i < k - 1: + end = min(n, end + overlap) + part = t[start:end].strip() + if part: + out.append(SourceChunk(text=part, order=i)) + return out + + @staticmethod + def _merge_source_chunks_to_target(chunks: List[SourceChunk], target_chunks: int) -> List[SourceChunk]: + ordered = sorted(chunks, key=lambda c: c.order) + n = len(ordered) + k = max(1, min(int(target_chunks), n)) + if k >= n: + return ordered + + out: List[SourceChunk] = [] + for i in range(k): + start = int(i * n / k) + end = int((i + 1) * n / k) + group = ordered[start:end] + merged_text = "\n\n".join([g.text for g in group if (g.text or "").strip()]).strip() + if not merged_text: + continue + base_order = group[0].order if group else i + out.append(SourceChunk(text=merged_text, order=base_order)) + return out + + def _adaptive_map_reduce_plan(self, text: str, final_tokens: int) -> Dict[str, int]: + est_doc_tokens = max(1, self._estimate_tokens(text)) + safe_output_tokens = max(128, int(final_tokens)) + overhead_tokens = max(256, int(self.map_reduce_overhead_tokens)) 
+ input_budget_tokens = max(700, int(self.model_context_tokens - safe_output_tokens - overhead_tokens)) + + chunks_needed = max(1, math.ceil(est_doc_tokens / float(input_budget_tokens))) + target_chunks = max(1, min(chunks_needed, int(self.map_reduce_max_chunks))) + + # Prefer fewer reduce calls by choosing a larger merge batch when safe. + target_groups = max(2, min(int(self.map_reduce_target_reduce_groups), target_chunks)) + adaptive_batch = math.ceil(target_chunks / float(target_groups)) if target_chunks > 0 else 2 + reduce_batch_size = max(2, min(target_chunks if target_chunks > 0 else 2, max(int(self.map_reduce_batch_size), adaptive_batch))) + + estimated_reduce_rounds = 0 + current = max(1, target_chunks) + while current > reduce_batch_size and estimated_reduce_rounds < 12: + current = math.ceil(current / float(reduce_batch_size)) + estimated_reduce_rounds += 1 + + return { + "estimated_doc_tokens": est_doc_tokens, + "input_budget_tokens": input_budget_tokens, + "target_chunks": int(target_chunks), + "reduce_batch_size": int(reduce_batch_size), + "estimated_reduce_rounds": int(estimated_reduce_rounds), + } + + def _should_use_map_reduce(self, text: str, doc_id: Optional[str]) -> bool: + if not self.enable_map_reduce: + return False + if not text or not text.strip(): + return False + if len(text) >= self.map_reduce_min_chars: + return True + if doc_id and self._is_index_ready(doc_id) and vector_store.count(doc_id) >= self.map_reduce_min_chunks: + return True + return False + + def _chat_text(self, system_prompt: str, user_prompt: str, max_tokens: int, temperature: float) -> str: + resp = self._call_chat( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + temperature=temperature, + stream=False, + ) + return clean_text(extract_response_text(resp)) + + def _map_reduce_summary(self, text: str, doc_id: Optional[str], max_tokens: int, temperature: float, style: str) -> str: + if style == "initial": + map_instruction = ( + "Extract only factual points from this chunk: document type, parties/entities, reporting periods, dates, " + "and key numeric values if explicitly present." + ) + reduce_instruction = "Merge chunk summaries, remove duplicates, and keep only supported facts." + final_instruction = ( + "Write a short generalized summary (3 to 5 sentences) for a non-expert user. " + "Include exact names/dates only when present. Do not invent." + ) + final_temp = max(0.0, min(float(temperature), 0.25)) + final_tokens = max(160, min(int(max_tokens), 320)) + else: + map_instruction = ( + "Extract the most important factual points from this chunk. " + "Preserve numbers, dates, and entity names when present." + ) + reduce_instruction = "Merge summaries into a concise de-duplicated brief with key facts prioritized." + final_instruction = ( + "Write a concise overall summary as 8 to 14 bullets max. " + "Use only supported information and keep it readable for non-experts." 
+ ) + final_temp = max(0.0, min(float(temperature), 0.35)) + final_tokens = int(max_tokens) + + source_chunks = self._source_chunks_for_map_reduce(text=text, doc_id=doc_id, max_chunks=None) + if not source_chunks: + return "" + last_error: Optional[Exception] = None + for attempt in range(2): + plan = self._adaptive_map_reduce_plan( + text=text, + final_tokens=final_tokens, + ) + target_chunks = max(1, int(plan.get("target_chunks", 1))) + reduce_batch_size = max(2, int(plan.get("reduce_batch_size", self.map_reduce_batch_size))) + + # Build large context-rich map chunks: + # - If we have indexed chunks, merge them into target-sized groups. + # - Otherwise split raw text evenly into target chunks. + if doc_id and self._is_index_ready(doc_id) and source_chunks: + chunks = self._merge_source_chunks_to_target(source_chunks, target_chunks=target_chunks) + else: + chunks = self._split_text_even( + text=text, + target_chunks=target_chunks, + overlap_chars=self.map_reduce_chunk_overlap, + ) + + if not chunks: + return "" + observability_service.record_map_reduce_plan( + style=style, + estimated_doc_tokens=int(plan.get("estimated_doc_tokens", 0)), + input_budget_tokens=int(plan.get("input_budget_tokens", 0)), + planned_chunks=target_chunks, + planned_reduce_batch=reduce_batch_size, + planned_reduce_rounds=int(plan.get("estimated_reduce_rounds", 0)), + ) + + try: + logger.info( + "Map-reduce summary start style=%s chunks=%d est_doc_tokens=%d input_budget_tokens=%d reduce_batch=%d est_reduce_rounds=%d attempt=%d", + style, + len(chunks), + int(plan.get("estimated_doc_tokens", 0)), + int(plan.get("input_budget_tokens", 0)), + reduce_batch_size, + int(plan.get("estimated_reduce_rounds", 0)), + attempt + 1, + ) + out = self.summary_pipeline.summarize( + chunks=chunks, + call_chat=self._chat_text, + base_system_prompt=self._base_system_prompt(), + map_instruction=map_instruction, + reduce_instruction=reduce_instruction, + final_instruction=final_instruction, + max_tokens=final_tokens, + temperature=final_temp, + max_chunks_override=target_chunks, + reduce_batch_size_override=reduce_batch_size, + ) + logger.info("Map-reduce summary done style=%s output_chars=%d", style, len(out or "")) + return out + except Exception as e: + last_error = e + if attempt == 0 and self._is_context_limit_error(e): + self._shrink_context_window(reason="map-reduce-replan") + logger.warning("Retrying map-reduce with tighter context budget") + continue + logger.exception("Map-reduce summary failed style=%s", style) + return "" + if last_error: + logger.exception("Map-reduce summary failed after retries style=%s error=%s", style, str(last_error)) + return "" + + def _call_chat(self, system_prompt: str, user_prompt: str, max_tokens: int, temperature: float, stream: bool = False): + sys_txt, usr_txt, out_tokens = self._fit_prompts_to_context( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + margin_tokens=900, + ) + try: + return self.provider.call_chat(sys_txt, usr_txt, out_tokens, temperature, stream=stream) + except Exception as e: + if self._is_context_limit_error(e): + self._shrink_context_window(reason="context-limit-error") + sys_txt2, usr_txt2, out_tokens2 = self._fit_prompts_to_context( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + margin_tokens=self.context_retry_margin_tokens, + ) + try: + logger.warning("Retrying LLM call after context-limit adjustment") + return self.provider.call_chat(sys_txt2, usr_txt2, out_tokens2, temperature, stream=stream) + 
except Exception as retry_error: + logger.exception("LLM retry failed after context-limit adjustment") + raise RuntimeError(f"LLM connection/request failed: {str(retry_error)}") + logger.exception("LLM request failed") + raise RuntimeError(f"LLM connection/request failed: {str(e)}") + + # ---------------------------- + # Dynamic section discovery (2 to 5) + # ---------------------------- + def _discover_dynamic_sections( + self, + fitted_text: str, + min_sections: int = 2, + max_sections: int = 5, + temperature: float = 0.2, + ) -> List[Dict[str, Any]]: + self._ensure_initialized() + min_sections = max(1, int(min_sections)) + max_sections = max(min_sections, int(max_sections)) + if not self.dynamic_sections_use_llm: + try: + heuristic = self.section_service._heuristic_sections_from_text( + fitted_text, + max_sections=max_sections, + ) + if len(heuristic) >= min_sections: + return heuristic[:max_sections] + except Exception: + logger.exception("Heuristic section discovery failed; using fallback.") + return [ + {"title": "General Summary", "hint": "What the document is about"}, + {"title": "Key Extracts", "hint": "Important names, dates, numbers"}, + ][:max_sections] + try: + return self.section_service.discover_dynamic_sections( + fitted_text=fitted_text, + min_sections=min_sections, + max_sections=max_sections, + temperature=temperature, + ) + except Exception: + logger.exception("Dynamic section discovery failed; using fallback.") + return [ + {"title": "General Summary", "hint": "What the document is about"}, + {"title": "Key Extracts", "hint": "Important names, dates, numbers"}, + ][: max(1, int(max_sections))] + + # ---------------------------- + # Fact extraction with anchors (internal) + # ---------------------------- + def _extract_facts_with_anchors(self, section_title: str, section_hint: str, fitted_text: str) -> List[Dict[str, Any]]: + self._ensure_initialized() + try: + return self.section_service.extract_facts_with_anchors(section_title, section_hint, fitted_text) + except Exception: + logger.exception("Failed to extract facts with anchors") + return [] + + def _validate_anchored_facts(self, facts: List[Dict[str, Any]], fitted_text: str) -> List[Dict[str, Any]]: + return self.section_service.validate_anchored_facts(facts, fitted_text) + + # ---------------------------- + # Final section writing (user-facing, no anchors shown) + # ---------------------------- + def _write_section_from_facts(self, section_title: str, facts: List[Dict[str, Any]], max_tokens: int, temperature: float) -> str: + return self.section_service.write_section_from_facts(section_title, facts, max_tokens, temperature) + + # ---------------------------- + # Public API + # ---------------------------- + def summarize_financial( + self, + text: str, + mode: str = "financial_initial", + section: str = None, + max_tokens: int = None, + temperature: float = None, + stream: bool = False, + doc_id: str = None, + ) -> Union[str, Iterator[str]]: + self._ensure_initialized() + self._evict_cache_if_needed() + + max_tokens = max_tokens or config.LLM_MAX_TOKENS + temperature = temperature if temperature is not None else config.LLM_TEMPERATURE + + if mode not in ("financial_initial", "financial_section", "financial_overall", "financial_sectionwise"): + raise ValueError( + "mode must be one of: financial_initial, financial_section, financial_overall, financial_sectionwise" + ) + + if stream and mode != "financial_overall": + raise ValueError("stream=True is only supported for mode='financial_overall'") + + if not text or 
not text.strip(): + return "No text found to summarize." + + if mode == "financial_initial": + if (not (self.retrieval_first_enable and doc_id)) and self._should_use_map_reduce(text=text, doc_id=doc_id): + mr = self._map_reduce_summary( + text=text, + doc_id=doc_id, + max_tokens=max_tokens, + temperature=temperature, + style="initial", + ) + if mr: + return normalize_money(clean_text(mr)) + mode_text = self._context_for_mode( + text=text, + mode=mode, + max_output_tokens=max_tokens, + doc_id=doc_id, + ) + out = self._financial_initial(mode_text, max_tokens=max_tokens, temperature=temperature) + return normalize_money(clean_text(out)) + + if mode == "financial_overall": + if not stream and (not (self.retrieval_first_enable and doc_id)) and self._should_use_map_reduce(text=text, doc_id=doc_id): + mr = self._map_reduce_summary( + text=text, + doc_id=doc_id, + max_tokens=max_tokens, + temperature=temperature, + style="overall", + ) + if mr: + return normalize_money(clean_text(mr)) + mode_text = self._context_for_mode( + text=text, + mode=mode, + max_output_tokens=max_tokens, + doc_id=doc_id, + ) + out = self._financial_overall(mode_text, max_tokens=max_tokens, temperature=temperature, stream=stream) + if isinstance(out, str): + return normalize_money(clean_text(out)) + return out + + if mode == "financial_section": + # Dynamic: accept ANY section title. + sec = self._normalize_section_title(section or "") + if not sec: + raise ValueError("section must be provided for mode='financial_section'") + + # Optional hint from discovery (helps retrieval) + hint = "" + try: + if doc_id: + hint = self.get_doc_section_hint(doc_id, sec) + except Exception: + hint = "" + + mode_text = self._context_for_mode( + text=text, + mode=mode, + max_output_tokens=max_tokens, + doc_id=doc_id, + section=sec, + section_hint=hint, + ) + + # 1) Extract facts + anchors (internal) + facts = self._extract_facts_with_anchors(sec, hint, mode_text) + + # 2) Validate anchors exist in document + valid_facts = self._validate_anchored_facts(facts, mode_text) + + # If too few facts are anchored, treat as unsupported + if facts: + ratio = (len(valid_facts) / float(len(facts))) if len(facts) > 0 else 0.0 + else: + ratio = 0.0 + + if not valid_facts or ratio < self.min_anchored_fact_ratio: + out = f"{sec}\n- No supported information found in the text for this section.\n" + return normalize_money(clean_text(out)) + + # 3) Write final section from validated facts (anchors NOT shown) + out = self._write_section_from_facts(sec, valid_facts, max_tokens=max_tokens, temperature=temperature) + out = normalize_money(clean_text(out)) + out = dedupe_section_heading(out, sec) + return out + + # financial_sectionwise: keep a generic structured brief (still useful for some flows) + mode_text = self._context_for_mode( + text=text, + mode=mode, + max_output_tokens=max_tokens, + doc_id=doc_id, + ) + out = self._financial_sectionwise(mode_text, max_tokens=max_tokens, temperature=temperature) + return normalize_money(clean_text(out)) + + # ---------------------------- + # Implementations + # ---------------------------- + def _financial_initial(self, text: str, max_tokens: int, temperature: float) -> str: + init_max = max(160, min(int(max_tokens), 280)) + temperature_init = max(0.0, min(float(temperature), 0.25)) + + system_prompt = self._base_system_prompt() + ( + "\nWrite a short generalized summary that tells the user what this document is about.\n" + "Focus on: document type, company/entity, reporting period/date, and purpose ONLY if explicitly 
stated.\n" + "Rules:\n" + "- Plain text only.\n" + "- 3 to 5 sentences.\n" + "- Prefer exact names, dates, and periods when present.\n" + "- Include key numeric highlights ONLY if explicitly present.\n" + ) + + user_prompt = f"""Write the generalized summary for this document. + +Document: +{text} +""" + resp = self._call_chat(system_prompt, user_prompt, max_tokens=init_max, temperature=temperature_init, stream=False) + return self._clean_or_retry_summary( + draft=extract_response_text(resp), + source_text=text, + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=init_max, + temperature=temperature_init, + ) + + def _financial_overall(self, text: str, max_tokens: int, temperature: float, stream: bool) -> Union[str, Iterator[str]]: + temperature_overall = max(0.0, min(float(temperature), 0.35)) + + system_prompt = self._base_system_prompt() + ( + "\nCreate a concise summary for a non-expert user.\n" + "Prefer concrete numbers and dates when present.\n" + "Keep it readable.\n" + "\nRules:\n" + "- Do not invent.\n" + "- If the document is an article or non-financial, summarize its key claims and any numbers.\n" + ) + + user_prompt = f"""Create a neat summary of the document below. + +Document: +{text} + +Write 8 to 14 bullet points max. +""" + resp = self._call_chat(system_prompt, user_prompt, max_tokens=max_tokens, temperature=temperature_overall, stream=stream) + if stream: + return self._stream_response(resp) + return self._clean_or_retry_summary( + draft=extract_response_text(resp), + source_text=text, + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=max_tokens, + temperature=temperature_overall, + ) + + def _financial_sectionwise(self, text: str, max_tokens: int, temperature: float) -> str: + temperature_structured = max(0.0, min(float(temperature), 0.35)) + + system_prompt = self._base_system_prompt() + ( + "\nCreate a short section-wise brief.\n" + "Since the document type may be unknown, choose sensible headings that match the content.\n" + "Do not invent.\n" + "Plain text only.\n" + ) + + user_prompt = f"""Create a short section-wise brief for this document. + +Document: +{text} + +Return the section-wise brief now. 
+""" + resp = self._call_chat(system_prompt, user_prompt, max_tokens=max_tokens, temperature=temperature_structured, stream=False) + return extract_response_text(resp) + + def _stream_response(self, response) -> Iterator[str]: + accumulated = "" + for chunk in response: + delta = chunk.choices[0].delta + if delta and getattr(delta, "content", None): + accumulated += delta.content + if accumulated.endswith((".", "!", "?", "\n")): + yield normalize_money(clean_text(accumulated)) + accumulated = "" + if accumulated: + yield normalize_money(clean_text(accumulated)) + + def health_check(self) -> Dict[str, Any]: + try: + self._ensure_initialized() + out = self.provider.health_check() + if out.get("status") == "healthy": + out["model_context_tokens"] = self.model_context_tokens + return out + except Exception as e: + logger.error(f"Health check failed: {str(e)}") + return {"status": "unhealthy", "provider": self.get_provider_name(), "error": str(e)} + + +llm_service = LLMService() diff --git a/backend/services/llm/llm_text_utils.py b/backend/services/llm/llm_text_utils.py new file mode 100644 index 0000000..fdb4614 --- /dev/null +++ b/backend/services/llm/llm_text_utils.py @@ -0,0 +1,77 @@ +import re + + +def clean_text(text: str) -> str: + if not text: + return "" + text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) + text = re.sub(r"__(.+?)__", r"\1", text) + text = re.sub(r"\*(.+?)\*", r"\1", text) + text = re.sub(r"_(.+?)_", r"\1", text) + text = re.sub(r"```.*?```", "", text, flags=re.DOTALL) + text = re.sub(r"`(.+?)`", r"\1", text) + text = re.sub(r"^#+\s+", "", text, flags=re.MULTILINE) + text = re.sub(r"\n{3,}", "\n\n", text).strip() + return text + + +def extract_response_text(resp) -> str: + """ + Best-effort extraction for OpenAI-compatible responses where some models + (e.g., reasoning models) may not populate message.content. 
+ """ + try: + choice0 = resp.choices[0] + except Exception: + return "" + + # 1) Standard chat content + msg = getattr(choice0, "message", None) + if msg is not None: + content = getattr(msg, "content", None) + if isinstance(content, str) and content.strip(): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + itype = str(item.get("type", "")).strip().lower() + if "reasoning" in itype or "think" in itype: + continue + txt = item.get("text") + if isinstance(txt, str) and txt.strip(): + parts.append(txt) + else: + itype = str(getattr(item, "type", "")).strip().lower() + if "reasoning" in itype or "think" in itype: + continue + txt = getattr(item, "text", None) + if isinstance(txt, str) and txt.strip(): + parts.append(txt) + joined = "\n".join(parts).strip() + if joined: + return joined + + # 2) Legacy text completion shape + txt = getattr(choice0, "text", None) + if isinstance(txt, str) and txt.strip(): + return txt + return "" + + +def normalize_money(text: str) -> str: + if not text: + return "" + text = re.sub(r"\bRs\.\s*", "₹ ", text, flags=re.IGNORECASE) + text = re.sub(r"\bINR\s*", "₹ ", text, flags=re.IGNORECASE) + return text + + +def dedupe_section_heading(text: str, section: str) -> str: + if not text: + return "" + t = text.replace("\r\n", "\n").strip() + lines = [ln.strip() for ln in t.split("\n") if ln.strip()] + body = [ln for ln in lines if ln.lower() != section.lower()] + out = section + "\n" + "\n".join(body) + return re.sub(r"\n{3,}", "\n\n", out).strip() diff --git a/backend/services/llm_service.py b/backend/services/llm_service.py deleted file mode 100644 index f904a0c..0000000 --- a/backend/services/llm_service.py +++ /dev/null @@ -1,902 +0,0 @@ -""" -LLM Service for Document Summarization (FinSights) -Uses OpenAI Chat Completions API (gpt-4o-mini) - -This version supports: -- Dynamic section chips (2 to 5) generated from the document at initial step. -- Section-wise summaries for ANY selected section title (no pre-defined section list). -- Readable section summaries WITHOUT showing quotes. - Internally, we extract "facts" with short "anchors" that must exist in the document, - validate anchors, then generate the final section from facts only. - -Compatibility: -- initial_summary_first_chunk(doc_id) returns ONLY a summary string (same as before). -- summarize_financial(mode="financial_section") accepts any section title. -- doc_id flow stays fast (no file re-upload needed on chip clicks). - -Anti-hallucination strategy: -- We do NOT require exact quotes in the final output. -- We do require each extracted fact to include at least one anchor that exists in the document text. - Anchors are not shown to the user; they are only used for validation. 
-""" - -from typing import Iterator, Dict, Any, Optional, Union, List -import logging -import re -import os -import time -import uuid -import hashlib -import json - -import config -from openai import OpenAI - -logger = logging.getLogger(__name__) - - -def clean_text(text: str) -> str: - if not text: - return "" - text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) - text = re.sub(r"__(.+?)__", r"\1", text) - text = re.sub(r"\*(.+?)\*", r"\1", text) - text = re.sub(r"_(.+?)_", r"\1", text) - text = re.sub(r"```.*?```", "", text, flags=re.DOTALL) - text = re.sub(r"`(.+?)`", r"\1", text) - text = re.sub(r"^#+\s+", "", text, flags=re.MULTILINE) - text = re.sub(r"\n{3,}", "\n\n", text).strip() - return text - - -def _normalize_money(text: str) -> str: - if not text: - return "" - text = re.sub(r"\bRs\.\s*", "₹ ", text, flags=re.IGNORECASE) - text = re.sub(r"\bINR\s*", "₹ ", text, flags=re.IGNORECASE) - return text - - -def _dedupe_section_heading(text: str, section: str) -> str: - if not text: - return "" - t = text.replace("\r\n", "\n").strip() - lines = [ln.strip() for ln in t.split("\n") if ln.strip()] - body = [ln for ln in lines if ln.lower() != section.lower()] - out = section + "\n" + "\n".join(body) - return re.sub(r"\n{3,}", "\n\n", out).strip() - - -class LLMService: - def __init__(self): - self.client: Optional[OpenAI] = None - self.model = os.getenv("OPENAI_MODEL", None) or getattr(config, "OPENAI_MODEL", None) or "gpt-4o-mini" - self._initialized = False - - # Large default; override via env if needed. - self.model_context_tokens = int(os.getenv("MODEL_CONTEXT_TOKENS", "128000")) - - # In-memory doc store for doc_id flow - self.doc_store: Dict[str, Dict[str, Any]] = {} - self.cache_ttl_seconds = int(os.getenv("CACHE_TTL_SECONDS", str(60 * 60))) # 1 hour - self.cache_max_docs = int(os.getenv("CACHE_MAX_DOCS", "25")) - - # Dynamic section discovery bounds (frontend chips) - self.dynamic_min_sections = int(os.getenv("DYNAMIC_SECTIONS_MIN", "2")) - self.dynamic_max_sections = int(os.getenv("DYNAMIC_SECTIONS_MAX", "5")) - - # Evidence/facts extraction bounds - self.facts_max_items = int(os.getenv("FACTS_MAX_ITEMS", "10")) - self.anchor_max_items = int(os.getenv("ANCHOR_MAX_ITEMS", "3")) - self.anchor_max_chars_each = int(os.getenv("ANCHOR_MAX_CHARS_EACH", "60")) - - # Validation threshold: how many facts must be anchored to proceed - # Example: 0.6 means at least 60% of extracted facts must have >=1 valid anchor. 
- self.min_anchored_fact_ratio = float(os.getenv("MIN_ANCHORED_FACT_RATIO", "0.6")) - - def _ensure_initialized(self): - if self._initialized: - return - if not config.OPENAI_API_KEY: - raise ValueError("OPENAI_API_KEY must be set in environment variables") - - self.client = OpenAI( - api_key=config.OPENAI_API_KEY, - timeout=float(os.getenv("OPENAI_TIMEOUT", "60")), - max_retries=int(os.getenv("OPENAI_MAX_RETRIES", "2")), - ) - self._initialized = True - - logger.info("OpenAI client initialized successfully") - logger.info(f"Model: {self.model}") - logger.info(f"MODEL_CONTEXT_TOKENS: {self.model_context_tokens}") - logger.info(f"CACHE_MAX_DOCS: {self.cache_max_docs}") - logger.info(f"CACHE_TTL_SECONDS: {self.cache_ttl_seconds}") - logger.info(f"DYNAMIC_SECTIONS_MIN: {self.dynamic_min_sections}") - logger.info(f"DYNAMIC_SECTIONS_MAX: {self.dynamic_max_sections}") - logger.info(f"FACTS_MAX_ITEMS: {self.facts_max_items}") - logger.info(f"ANCHOR_MAX_ITEMS: {self.anchor_max_items}") - logger.info(f"MIN_ANCHORED_FACT_RATIO: {self.min_anchored_fact_ratio}") - - # ---------------------------- - # Compatibility wrapper - # ---------------------------- - def summarize( - self, - text: str, - max_tokens: int = None, - temperature: float = None, - stream: bool = False, - mode: str = "financial_initial", - section: str = None, - ) -> Union[str, Iterator[str]]: - return self.summarize_financial( - text=text, - mode=mode, - section=section, - max_tokens=max_tokens, - temperature=temperature, - stream=stream, - ) - - # ---------------------------- - # doc_id based flow (compat) - # ---------------------------- - def create_doc(self, text: str) -> str: - self._ensure_initialized() - self._evict_cache_if_needed() - - if not text or not text.strip(): - raise ValueError("Empty text") - - doc_id = str(uuid.uuid4()) - dk = self._doc_key(text) - - self.doc_store[doc_id] = { - "ts": time.time(), - "text": text, - "doc_key": dk, - # list of {"title": str, "hint": str} created during initial summary - "sections": None, - - # ---- RAG indexing state (for chat) ---- - "index_status": "pending", # pending | ready | failed - "chunk_count": 0, - "index_error": "", - "index_started_at": None, - "index_finished_at": None, - } - return doc_id - - def get_doc_text(self, doc_id: str) -> str: - obj = self.doc_store.get(doc_id) - if not obj: - raise ValueError("Invalid doc_id") - obj["ts"] = time.time() - return obj.get("text", "") - - def get_doc_key(self, doc_id: str) -> str: - obj = self.doc_store.get(doc_id) - if not obj: - raise ValueError("Invalid doc_id") - obj["ts"] = time.time() - return obj.get("doc_key", "") - - def get_doc_sections(self, doc_id: str) -> List[str]: - """ - Returns discovered dynamic section titles for this doc_id. - If not discovered yet, returns []. - """ - obj = self.doc_store.get(doc_id) - if not obj: - raise ValueError("Invalid doc_id") - obj["ts"] = time.time() - - secs = obj.get("sections") - if not isinstance(secs, list): - return [] - - titles: List[str] = [] - for s in secs: - if isinstance(s, dict) and s.get("title"): - titles.append(str(s["title"]).strip()) - elif isinstance(s, str): - titles.append(str(s).strip()) - return [t for t in titles if t] - - def get_doc_section_hint(self, doc_id: str, section_title: str) -> str: - """ - Returns a stored hint for a discovered section (optional), else "". - Hints are short descriptions like "Totals, taxes, payment status". 
- """ - obj = self.doc_store.get(doc_id) - if not obj: - raise ValueError("Invalid doc_id") - obj["ts"] = time.time() - - secs = obj.get("sections") - if not isinstance(secs, list): - return "" - - target = self._normalize_section_title(section_title).lower() - if not target: - return "" - - for s in secs: - if isinstance(s, dict): - t = self._normalize_section_title(str(s.get("title", ""))).lower() - if t == target: - return str(s.get("hint", "") or "").strip() - return "" - - def prefetch_doc(self, doc_id: str) -> None: - # No-op (kept so existing routes won't break) - obj = self.doc_store.get(doc_id) - if obj: - obj["ts"] = time.time() - return - - def summarize_by_doc_id( - self, - doc_id: str, - mode: str = "financial_initial", - section: str = None, - max_tokens: int = None, - temperature: float = None, - stream: bool = False, - ) -> Union[str, Iterator[str]]: - text = self.get_doc_text(doc_id) - return self.summarize_financial( - text=text, - mode=mode, - section=section, - max_tokens=max_tokens, - temperature=temperature, - stream=stream, - doc_id=doc_id, - ) - - def initial_summary_first_chunk(self, doc_id: str, max_tokens: int = 240, temperature: float = 0.25) -> str: - """ - Compatibility method expected by routes/frontend. - Uses full doc text (fitted to context) to produce 4-5 sentences. - - Updated behavior: - - Also runs dynamic section discovery and stores it in doc_store[doc_id]["sections"]. - - Still returns ONLY the summary string (so existing routes don't break). - """ - self._ensure_initialized() - text = self.get_doc_text(doc_id) - fitted = self._fit_text_to_context(text, max_output_tokens=max_tokens) - - # Dynamic sections discovery (best effort; never breaks summary) - try: - sections = self._discover_dynamic_sections( - fitted_text=fitted, - min_sections=self.dynamic_min_sections, - max_sections=self.dynamic_max_sections, - temperature=0.2, - ) - obj = self.doc_store.get(doc_id) - if obj is not None: - obj["sections"] = sections - obj["ts"] = time.time() - except Exception: - logger.exception("Dynamic section discovery failed (continuing with summary only).") - - # Normal initial summary - system_prompt = self._base_system_prompt() + ( - "\nWrite a short generalized summary that tells the user what this document is about.\n" - "Focus on: document type, company/entity, reporting period/date, and purpose ONLY if explicitly stated.\n" - "Rules:\n" - "- Plain text only.\n" - "- 4 to 5 sentences.\n" - "- Prefer exact names, dates, and periods when present.\n" - "- Include key numeric highlights ONLY if explicitly present.\n" - "- Do not invent any details.\n" - ) - - user_prompt = f"""Write the generalized summary for this document. 
- -Document: -{fitted} -""" - resp = self._call_chat( - system_prompt, - user_prompt, - max_tokens=max_tokens, - temperature=max(0.0, min(float(temperature), 0.35)), - stream=False, - ) - return _normalize_money(clean_text(resp.choices[0].message.content or "")) - - # ---------------------------- - # Prompt building - # ---------------------------- - @staticmethod - def _base_system_prompt() -> str: - return ( - "You are an analyst AI.\n" - "You must use ONLY information supported by the provided document text.\n" - "Do not invent facts, numbers, dates, names, or events.\n" - "If you are uncertain, state it clearly.\n" - "Output must be plain text only (no markdown emphasis: no **, *, _).\n" - ) - def chat_with_context( - self, - question: str, - context: str, - max_tokens: int = 500, - temperature: float = 0.2, - ) -> str: - """ - Answer a user question using ONLY the retrieved context. - If the answer is not supported by the context, say so. - """ - self._ensure_initialized() - - q = (question or "").strip() - ctx = (context or "").strip() - - if not q: - return "Please enter a question." - - system_prompt = ( - self._base_system_prompt() - + "\nYou are answering questions about an uploaded document.\n" - "Use ONLY the provided CONTEXT.\n" - "If the answer is not in the context, say you cannot find it.\n" - "Keep the response concise and readable.\n" - ) - - user_prompt = f"""CONTEXT: -{ctx} - -QUESTION: -{q} - -Answer using only the context. -""" - - resp = self._call_chat( - system_prompt=system_prompt, - user_prompt=user_prompt, - max_tokens=max_tokens, - temperature=max(0.0, min(float(temperature), 0.5)), - stream=False, - ) - return _normalize_money(clean_text(resp.choices[0].message.content or "")) - - - # ---------------------------- - # Core helpers - # ---------------------------- - @staticmethod - def _estimate_tokens(text: str) -> int: - return max(1, int(len(text) / 4)) if text else 0 - - @staticmethod - def _doc_key(text: str) -> str: - return hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest() - - @staticmethod - def _normalize_section_title(title: str) -> str: - t = (title or "").strip() - t = re.sub(r"\s{2,}", " ", t) - t = t.strip(" -:\t\r\n") - if len(t) > 56: - t = t[:56].rstrip() - return t - - def _evict_cache_if_needed(self): - now = time.time() - - expired_docs = [k for k, v in self.doc_store.items() if now - v.get("ts", now) > self.cache_ttl_seconds] - for k in expired_docs: - self.doc_store.pop(k, None) - - if len(self.doc_store) <= self.cache_max_docs: - return - - items = sorted(self.doc_store.items(), key=lambda kv: kv[1].get("ts", 0)) - while len(items) > self.cache_max_docs: - k, _ = items.pop(0) - self.doc_store.pop(k, None) - - def _fit_text_to_context(self, text: str, max_output_tokens: int) -> str: - """ - Truncate input so input + output stays within model_context_tokens. - Keeps head + tail when truncating. 
- """ - if not text: - return "" - - overhead_tokens = 900 - available_input_tokens = max(500, self.model_context_tokens - int(max_output_tokens) - overhead_tokens) - - est = self._estimate_tokens(text) - if est <= available_input_tokens: - return text - - ratio = available_input_tokens / float(est) - keep_chars = max(2500, int(len(text) * ratio)) - - head = text[: int(keep_chars * 0.72)] - tail = text[-int(keep_chars * 0.28) :] if keep_chars > 4500 else "" - - truncated = head - if tail and tail not in head: - truncated = head + "\n\n[...TRUNCATED...]\n\n" + tail - - return truncated - - def _call_chat(self, system_prompt: str, user_prompt: str, max_tokens: int, temperature: float, stream: bool = False): - try: - return self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - max_tokens=max_tokens, - temperature=temperature, - stream=stream, - ) - except Exception as e: - logger.exception("OpenAI request failed") - raise RuntimeError(f"OpenAI connection/request failed: {str(e)}") - - # ---------------------------- - # Dynamic section discovery (2 to 5) - # ---------------------------- - def _discover_dynamic_sections( - self, - fitted_text: str, - min_sections: int = 2, - max_sections: int = 5, - temperature: float = 0.2, - ) -> List[Dict[str, Any]]: - """ - Returns 2 to 5 section items derived from the document content. - This is for UI chips. We avoid quotes here; we keep short hints only. - - Output item shape: - - {"title": "Invoice Totals", "hint": "Totals, taxes, payment status"} - """ - self._ensure_initialized() - - min_sections = max(1, min(int(min_sections), 6)) - max_sections = max(min_sections, min(int(max_sections), 8)) - - system_prompt = ( - self._base_system_prompt() - + "\nYou propose section options (chips) for a summarization UI.\n" - "The user can upload ANY document: invoices, payroll, tax returns, audit reports, loan documents, or general articles.\n" - "Propose only sections that are likely supported by the text.\n" - "\nCRITICAL RULES:\n" - f"- Return between {min_sections} and {max_sections} sections.\n" - "- Use short titles (2 to 5 words). No numbering.\n" - "- Provide a short hint (6 to 12 words) describing what to expect.\n" - "- Do not invent specific numbers.\n" - "- Output MUST be strict JSON only.\n" - ) - - user_prompt = f"""Read the document and propose section chips. 
- -Return JSON: -{{ - "sections": [ - {{ - "title": "Short Title", - "hint": "Short description" - }} - ] -}} - -Document: -{fitted_text} -""" - - resp = self._call_chat( - system_prompt, - user_prompt, - max_tokens=350, - temperature=max(0.0, min(float(temperature), 0.35)), - stream=False, - ) - raw = (resp.choices[0].message.content or "").strip() - - items: List[Dict[str, Any]] = [] - try: - data = json.loads(raw) - secs = data.get("sections", []) - if isinstance(secs, list): - for s in secs: - if not isinstance(s, dict): - continue - title = self._normalize_section_title(str(s.get("title", "")).strip()) - hint = str(s.get("hint", "") or "").strip() - hint = re.sub(r"\s{2,}", " ", hint) - if not title: - continue - if len(hint) > 90: - hint = hint[:90].rstrip() - items.append({"title": title, "hint": hint}) - except Exception: - logger.exception("Failed to parse dynamic sections JSON; using fallback.") - - # de-dupe titles - out: List[Dict[str, Any]] = [] - seen = set() - for it in items: - key = it["title"].lower() - if key in seen: - continue - seen.add(key) - out.append(it) - - if len(out) >= min_sections: - return out[:max_sections] - - # Fallback for low-signal docs - fallback = [ - {"title": "General Summary", "hint": "What the document is about"}, - {"title": "Key Extracts", "hint": "Important names, dates, numbers"}, - ] - return fallback[:max_sections] - - # ---------------------------- - # Fact extraction with anchors (internal) - # ---------------------------- - def _extract_facts_with_anchors(self, section_title: str, section_hint: str, fitted_text: str) -> List[Dict[str, Any]]: - """ - Extract readable facts (paraphrased) relevant to a requested section. - Each fact must include at least one short anchor that appears in the document text. - Anchors are used ONLY for validation and are not shown in the final user output. - """ - title = self._normalize_section_title(section_title) - hint = (section_hint or "").strip() - - if not title: - return [] - - system_prompt = ( - self._base_system_prompt() - + "\nTask: extract readable key facts for a requested section.\n" - "Facts must be supported by the document.\n" - "\nCRITICAL RULES:\n" - "- Return STRICT JSON only.\n" - "- Facts must be written in plain English (not copied verbatim).\n" - "- Each fact MUST include at least one anchor.\n" - "- Anchors must be SHORT strings that appear verbatim in the document (examples: a number, a date, an entity name, a short label).\n" - "- Do NOT return long quotes.\n" - ) - - user_prompt = f"""Requested section: {title} -Section hint (if any): {hint} - -Return JSON: -{{ - "facts": [ - {{ - "point": "Readable summarized point", - "anchors": ["anchor1", "anchor2"] - }} - ] -}} - -Rules: -- Provide up to {self.facts_max_items} facts. -- Each fact must include 1 to {self.anchor_max_items} anchors. -- Anchors must be <= {self.anchor_max_chars_each} characters. -- Prefer anchors that include numbers, dates, totals, ratios, names, account labels, or table row labels. -- Do not invent any numbers. 
- -Document: -{fitted_text} -""" - - resp = self._call_chat(system_prompt, user_prompt, max_tokens=650, temperature=0.15, stream=False) - raw = (resp.choices[0].message.content or "").strip() - - facts: List[Dict[str, Any]] = [] - try: - data = json.loads(raw) - items = data.get("facts", []) - if isinstance(items, list): - for it in items: - if not isinstance(it, dict): - continue - point = str(it.get("point", "") or "").strip() - anchors = it.get("anchors", []) - if not point or not isinstance(anchors, list): - continue - - # normalize anchors - norm_anchors: List[str] = [] - for a in anchors[: self.anchor_max_items]: - s = str(a or "").strip() - if not s: - continue - if len(s) > self.anchor_max_chars_each: - s = s[: self.anchor_max_chars_each].rstrip() - norm_anchors.append(s) - - if not norm_anchors: - continue - - facts.append({"point": point, "anchors": norm_anchors}) - except Exception: - logger.exception("Failed to parse facts JSON from extractor") - return [] - - # de-dupe similar points - out: List[Dict[str, Any]] = [] - seen = set() - for f in facts: - key = re.sub(r"\s+", " ", f["point"]).strip().lower() - if key in seen: - continue - seen.add(key) - out.append(f) - return out - - def _validate_anchored_facts(self, facts: List[Dict[str, Any]], fitted_text: str) -> List[Dict[str, Any]]: - """ - Keep only facts that have at least one anchor substring present in fitted_text. - """ - if not facts: - return [] - - valid: List[Dict[str, Any]] = [] - for f in facts: - anchors = f.get("anchors", []) - if not isinstance(anchors, list): - continue - ok = False - for a in anchors: - if not a: - continue - if str(a) in fitted_text: - ok = True - break - if ok: - valid.append(f) - return valid - - # ---------------------------- - # Final section writing (user-facing, no anchors shown) - # ---------------------------- - def _write_section_from_facts(self, section_title: str, facts: List[Dict[str, Any]], max_tokens: int, temperature: float) -> str: - title = self._normalize_section_title(section_title) - - if not facts: - return f"{title}\n- No supported information found in the text for this section.\n" - - # Build fact lines (anchors not shown) - fact_lines: List[str] = [] - for f in facts: - p = str(f.get("point", "") or "").strip() - if p: - fact_lines.append(f"- {p}") - - if not fact_lines: - return f"{title}\n- No supported information found in the text for this section.\n" - - system_prompt = self._base_system_prompt() + ( - "\nWrite the requested section using ONLY the provided facts.\n" - "Do not invent.\n" - "\nOUTPUT RULES:\n" - "- Start with the heading exactly as provided.\n" - "- Keep it easy to read.\n" - "- Use bullets (recommended) or short paragraphs.\n" - "- Keep numbers when present in facts.\n" - "- If facts are weak, explicitly say it is limited.\n" - ) - - user_prompt = f"""Section heading: {title} - -Facts (use only these): -{chr(10).join(fact_lines)} - -Write the section now. 
-""" - - resp = self._call_chat( - system_prompt, - user_prompt, - max_tokens=max_tokens, - temperature=max(0.0, min(float(temperature), 0.3)), - stream=False, - ) - out = clean_text(resp.choices[0].message.content or "") - out = _dedupe_section_heading(out, title) - return out - - # ---------------------------- - # Public API - # ---------------------------- - def summarize_financial( - self, - text: str, - mode: str = "financial_initial", - section: str = None, - max_tokens: int = None, - temperature: float = None, - stream: bool = False, - doc_id: str = None, - ) -> Union[str, Iterator[str]]: - self._ensure_initialized() - self._evict_cache_if_needed() - - max_tokens = max_tokens or config.LLM_MAX_TOKENS - temperature = temperature if temperature is not None else config.LLM_TEMPERATURE - - if mode not in ("financial_initial", "financial_section", "financial_overall", "financial_sectionwise"): - raise ValueError( - "mode must be one of: financial_initial, financial_section, financial_overall, financial_sectionwise" - ) - - if stream and mode != "financial_overall": - raise ValueError("stream=True is only supported for mode='financial_overall'") - - if not text or not text.strip(): - return "No text found to summarize." - - fitted_text = self._fit_text_to_context(text, max_output_tokens=max_tokens) - - if mode == "financial_initial": - out = self._financial_initial(fitted_text, max_tokens=max_tokens, temperature=temperature) - return _normalize_money(clean_text(out)) - - if mode == "financial_overall": - out = self._financial_overall(fitted_text, max_tokens=max_tokens, temperature=temperature, stream=stream) - if isinstance(out, str): - return _normalize_money(clean_text(out)) - return out - - if mode == "financial_section": - # Dynamic: accept ANY section title. 
- sec = self._normalize_section_title(section or "") - if not sec: - raise ValueError("section must be provided for mode='financial_section'") - - # Optional hint from discovery (helps retrieval) - hint = "" - try: - if doc_id: - hint = self.get_doc_section_hint(doc_id, sec) - except Exception: - hint = "" - - # 1) Extract facts + anchors (internal) - facts = self._extract_facts_with_anchors(sec, hint, fitted_text) - - # 2) Validate anchors exist in document - valid_facts = self._validate_anchored_facts(facts, fitted_text) - - # If too few facts are anchored, treat as unsupported - if facts: - ratio = (len(valid_facts) / float(len(facts))) if len(facts) > 0 else 0.0 - else: - ratio = 0.0 - - if not valid_facts or ratio < self.min_anchored_fact_ratio: - out = f"{sec}\n- No supported information found in the text for this section.\n" - return _normalize_money(clean_text(out)) - - # 3) Write final section from validated facts (anchors NOT shown) - out = self._write_section_from_facts(sec, valid_facts, max_tokens=max_tokens, temperature=temperature) - out = _normalize_money(clean_text(out)) - out = _dedupe_section_heading(out, sec) - return out - - # financial_sectionwise: keep a generic structured brief (still useful for some flows) - out = self._financial_sectionwise(fitted_text, max_tokens=max_tokens, temperature=temperature) - return _normalize_money(clean_text(out)) - - # ---------------------------- - # Implementations - # ---------------------------- - def _financial_initial(self, text: str, max_tokens: int, temperature: float) -> str: - init_max = max(160, min(int(max_tokens), 280)) - temperature_init = max(0.0, min(float(temperature), 0.25)) - - system_prompt = self._base_system_prompt() + ( - "\nWrite a short generalized summary that tells the user what this document is about.\n" - "Focus on: document type, company/entity, reporting period/date, and purpose ONLY if explicitly stated.\n" - "Rules:\n" - "- Plain text only.\n" - "- 3 to 5 sentences.\n" - "- Prefer exact names, dates, and periods when present.\n" - "- Include key numeric highlights ONLY if explicitly present.\n" - ) - - user_prompt = f"""Write the generalized summary for this document. - -Document: -{text} -""" - resp = self._call_chat(system_prompt, user_prompt, max_tokens=init_max, temperature=temperature_init, stream=False) - return resp.choices[0].message.content or "" - - def _financial_overall(self, text: str, max_tokens: int, temperature: float, stream: bool) -> Union[str, Iterator[str]]: - temperature_overall = max(0.0, min(float(temperature), 0.35)) - - system_prompt = self._base_system_prompt() + ( - "\nCreate a concise summary for a non-expert user.\n" - "Prefer concrete numbers and dates when present.\n" - "Keep it readable.\n" - "\nRules:\n" - "- Do not invent.\n" - "- If the document is an article or non-financial, summarize its key claims and any numbers.\n" - ) - - user_prompt = f"""Create a neat summary of the document below. - -Document: -{text} - -Write 8 to 14 bullet points max. 
-""" - resp = self._call_chat(system_prompt, user_prompt, max_tokens=max_tokens, temperature=temperature_overall, stream=stream) - if stream: - return self._stream_response(resp) - return resp.choices[0].message.content or "" - - def _financial_sectionwise(self, text: str, max_tokens: int, temperature: float) -> str: - temperature_structured = max(0.0, min(float(temperature), 0.35)) - - system_prompt = self._base_system_prompt() + ( - "\nCreate a short section-wise brief.\n" - "Since the document type may be unknown, choose sensible headings that match the content.\n" - "Do not invent.\n" - "Plain text only.\n" - ) - - user_prompt = f"""Create a short section-wise brief for this document. - -Document: -{text} - -Return the section-wise brief now. -""" - resp = self._call_chat(system_prompt, user_prompt, max_tokens=max_tokens, temperature=temperature_structured, stream=False) - return resp.choices[0].message.content or "" - - def _stream_response(self, response) -> Iterator[str]: - accumulated = "" - for chunk in response: - delta = chunk.choices[0].delta - if delta and getattr(delta, "content", None): - accumulated += delta.content - if accumulated.endswith((".", "!", "?", "\n")): - yield _normalize_money(clean_text(accumulated)) - accumulated = "" - if accumulated: - yield _normalize_money(clean_text(accumulated)) - - def health_check(self) -> Dict[str, Any]: - try: - if not config.OPENAI_API_KEY: - return {"status": "not_configured", "provider": "OpenAI", "message": "OPENAI_API_KEY not configured"} - - self._ensure_initialized() - self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": "Say OK"}], - max_tokens=10, - temperature=0, - ) - return { - "status": "healthy", - "provider": "OpenAI", - "model": self.model, - "model_context_tokens": self.model_context_tokens, - } - except Exception as e: - logger.error(f"Health check failed: {str(e)}") - return {"status": "unhealthy", "provider": "OpenAI", "error": str(e)} - - -llm_service = LLMService() diff --git a/backend/services/observability_service.py b/backend/services/observability_service.py new file mode 100644 index 0000000..38f00d0 --- /dev/null +++ b/backend/services/observability_service.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import os +import time +import threading +from collections import deque +from contextvars import ContextVar, Token +from datetime import datetime, timezone +from typing import Any, Deque, Dict, List, Optional, Tuple + +_endpoint_ctx: ContextVar[str] = ContextVar("obs_endpoint", default="-") +_method_ctx: ContextVar[str] = ContextVar("obs_method", default="-") + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _to_int(value: Any, default: int = 0) -> int: + try: + if value is None: + return default + return int(value) + except Exception: + return default + + +def _usage_to_dict(usage: Any) -> Dict[str, Any]: + if usage is None: + return {} + if isinstance(usage, dict): + return usage + if hasattr(usage, "model_dump"): + try: + return usage.model_dump() + except Exception: + return {} + out: Dict[str, Any] = {} + for key in ("prompt_tokens", "completion_tokens", "total_tokens"): + if hasattr(usage, key): + out[key] = getattr(usage, key) + return out + + +def _estimate_tokens_from_chars(char_count: int) -> int: + # Practical rough estimate for observability dashboards. 
+ if char_count <= 0: + return 0 + return max(1, char_count // 4) + + +class ObservabilityService: + def __init__(self) -> None: + self.max_rows = max(100, int(os.getenv("OBSERVABILITY_MAX_ROWS", "1000"))) + self.rows: Deque[Dict[str, Any]] = deque(maxlen=self.max_rows) + self.lock = threading.Lock() + + def set_request_context(self, endpoint: str, method: str) -> Tuple[Token, Token]: + ep = (endpoint or "-").strip() or "-" + m = (method or "-").strip().upper() or "-" + return _endpoint_ctx.set(ep), _method_ctx.set(m) + + def reset_request_context(self, tokens: Tuple[Token, Token]) -> None: + ep_token, method_token = tokens + _endpoint_ctx.reset(ep_token) + _method_ctx.reset(method_token) + + def _current_endpoint(self) -> str: + return f"{_method_ctx.get()} {_endpoint_ctx.get()}".strip() + + def _append(self, row: Dict[str, Any]) -> None: + with self.lock: + self.rows.appendleft(row) + + def record_request(self, status_code: int, duration_ms: float) -> None: + self._append( + { + "time": _now_iso(), + "endpoint": self._current_endpoint(), + "event": "request", + "model": "-", + "planned_chunks": 0, + "planned_reduce_batch": 0, + "planned_reduce_rounds": 0, + "input_budget_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "user_prompt_tokens": 0, + "system_prompt_tokens": 0, + "uploaded_document_tokens": 0, + "user_input_tokens": 0, + "duration_ms": round(float(duration_ms), 2), + "status": str(status_code), + } + ) + + def record_llm_call( + self, + *, + event: str, + model: str, + provider: str, + duration_ms: float, + usage: Any = None, + system_prompt_chars: int = 0, + user_prompt_chars: int = 0, + user_input_chars: int = 0, + uploaded_document_chars: int = 0, + success: bool = True, + error: str = "", + ) -> None: + usage_dict = _usage_to_dict(usage) + prompt_tokens = _to_int(usage_dict.get("prompt_tokens"), 0) + completion_tokens = _to_int(usage_dict.get("completion_tokens"), 0) + total_tokens = _to_int(usage_dict.get("total_tokens"), 0) + + est_system_tokens = _estimate_tokens_from_chars(system_prompt_chars) + est_user_prompt_tokens = _estimate_tokens_from_chars(user_prompt_chars) + est_user_input_tokens = _estimate_tokens_from_chars(user_input_chars) + est_uploaded_document_tokens = _estimate_tokens_from_chars(uploaded_document_chars) + est_input_tokens = est_system_tokens + est_user_prompt_tokens + est_user_input_tokens + est_uploaded_document_tokens + + if prompt_tokens <= 0: + prompt_tokens = est_input_tokens + if total_tokens <= 0: + total_tokens = max(prompt_tokens + completion_tokens, 0) + + self._append( + { + "time": _now_iso(), + "endpoint": self._current_endpoint(), + "event": f"{event}:{provider or '-'}", + "model": (model or "-").strip() or "-", + "planned_chunks": 0, + "planned_reduce_batch": 0, + "planned_reduce_rounds": 0, + "input_budget_tokens": 0, + "input_tokens": prompt_tokens, + "output_tokens": completion_tokens, + "total_tokens": total_tokens, + "uploaded_document_tokens": est_uploaded_document_tokens, + "user_prompt_tokens": est_user_prompt_tokens, + "user_input_tokens": est_user_input_tokens, + "system_prompt_tokens": est_system_tokens, + "duration_ms": round(float(duration_ms), 2), + "status": "ok" if success else f"error: {error[:120]}", + } + ) + + def record_map_reduce_plan( + self, + *, + style: str, + estimated_doc_tokens: int, + input_budget_tokens: int, + planned_chunks: int, + planned_reduce_batch: int, + planned_reduce_rounds: int, + ) -> None: + self._append( + { + "time": _now_iso(), + "endpoint": 
self._current_endpoint(), + "event": "map_reduce_plan", + "model": (style or "-").strip() or "-", + "planned_chunks": int(planned_chunks), + "planned_reduce_batch": int(planned_reduce_batch), + "planned_reduce_rounds": int(planned_reduce_rounds), + "input_budget_tokens": int(input_budget_tokens), + "input_tokens": int(estimated_doc_tokens), + "output_tokens": 0, + "total_tokens": int(estimated_doc_tokens), + "uploaded_document_tokens": int(estimated_doc_tokens), + "user_prompt_tokens": 0, + "user_input_tokens": 0, + "system_prompt_tokens": 0, + "duration_ms": 0.0, + "status": "plan", + } + ) + + def get_rows(self, limit: int = 100, llm_only: bool = False) -> List[Dict[str, Any]]: + lim = max(1, min(int(limit), self.max_rows)) + with self.lock: + rows = list(self.rows) + if llm_only: + rows = [ + r + for r in rows + if str(r.get("event", "")).startswith(("chat:", "embedding:", "embedding_fallback:", "map_reduce_plan")) + ] + return rows[:lim] + + def render_table(self, limit: int = 100, llm_only: bool = False) -> str: + rows = self.get_rows(limit=limit, llm_only=llm_only) + cols = [ + "time", + "endpoint", + "event", + "model", + "plan", + "budget", + "doc_tok", + "u_prompt", + "u_input", + "sys_tok", + "out_tok", + "total_tok", + "ms", + "status", + ] + numeric_cols = { + "budget", + "doc_tok", + "u_prompt", + "u_input", + "sys_tok", + "out_tok", + "total_tok", + "ms", + } + labels = { + "time": "time", + "endpoint": "endpoint", + "event": "event", + "model": "model", + "plan": "plan(c/b/r)", + "budget": "budget", + "doc_tok": "doc_tok", + "u_prompt": "u_prompt", + "u_input": "u_input", + "sys_tok": "sys_tok", + "out_tok": "out_tok", + "total_tok": "total_tok", + "ms": "ms", + "status": "status", + } + + if not rows: + return "No LLM observability data yet." if llm_only else "No observability data yet." + + def _short(s: str, max_len: int) -> str: + t = (s or "").strip() + if len(t) <= max_len: + return t + if max_len <= 3: + return t[:max_len] + return t[: max_len - 3] + "..." 
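+ # Build one display row per record (long strings truncated via _short), then size
+ # each column to its widest value so the plain-text table stays aligned.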
+ + table_rows: List[Dict[str, Any]] = [] + for r in rows: + planned_chunks = int(r.get("planned_chunks", 0) or 0) + planned_batch = int(r.get("planned_reduce_batch", 0) or 0) + planned_rounds = int(r.get("planned_reduce_rounds", 0) or 0) + plan_val = "-" if (planned_chunks == 0 and planned_batch == 0 and planned_rounds == 0) else f"{planned_chunks}/{planned_batch}/{planned_rounds}" + table_rows.append( + { + "time": _short(str(r.get("time", "")), 26), + "endpoint": _short(str(r.get("endpoint", "")), 18), + "event": _short(str(r.get("event", "")), 18), + "model": _short(str(r.get("model", "")), 22), + "plan": plan_val, + "budget": int(r.get("input_budget_tokens", 0) or 0), + "doc_tok": int(r.get("uploaded_document_tokens", 0) or 0), + "u_prompt": int(r.get("user_prompt_tokens", 0) or 0), + "u_input": int(r.get("user_input_tokens", 0) or 0), + "sys_tok": int(r.get("system_prompt_tokens", 0) or 0), + "out_tok": int(r.get("output_tokens", 0) or 0), + "total_tok": int(r.get("total_tokens", 0) or 0), + "ms": f"{float(r.get('duration_ms', 0.0) or 0.0):.2f}", + "status": _short(str(r.get("status", "")), 12), + } + ) + + widths: Dict[str, int] = {c: len(labels[c]) for c in cols} + for r in table_rows: + for c in cols: + widths[c] = max(widths[c], len(str(r.get(c, "")))) + + def fmt_line(values: Dict[str, Any]) -> str: + parts: List[str] = [] + for c in cols: + v = str(values.get(c, "")) + parts.append(v.rjust(widths[c]) if c in numeric_cols else v.ljust(widths[c])) + return " | ".join(parts) + + header = fmt_line({c: labels[c] for c in cols}) + sep = "-+-".join("-" * widths[c] for c in cols) + lines = [header, sep] + for r in table_rows: + lines.append(fmt_line(r)) + return "\n".join(lines) + + +observability_service = ObservabilityService() diff --git a/backend/services/pdf/__init__.py b/backend/services/pdf/__init__.py new file mode 100644 index 0000000..e4e7015 --- /dev/null +++ b/backend/services/pdf/__init__.py @@ -0,0 +1,3 @@ +from .pdf_service import pdf_service + +__all__ = ["pdf_service"] diff --git a/backend/services/pdf_service.py b/backend/services/pdf/pdf_service.py similarity index 100% rename from backend/services/pdf_service.py rename to backend/services/pdf/pdf_service.py diff --git a/backend/services/rag/__init__.py b/backend/services/rag/__init__.py new file mode 100644 index 0000000..bfb46cd --- /dev/null +++ b/backend/services/rag/__init__.py @@ -0,0 +1 @@ +"""RAG services package.""" diff --git a/backend/services/rag_index_service.py b/backend/services/rag/rag_index_service.py similarity index 87% rename from backend/services/rag_index_service.py rename to backend/services/rag/rag_index_service.py index eeb1740..6f535bc 100644 --- a/backend/services/rag_index_service.py +++ b/backend/services/rag/rag_index_service.py @@ -7,11 +7,8 @@ import logging from typing import List, Dict, Any -import config -from openai import OpenAI - -from services.llm_service import llm_service -from services.vector_store import vector_store, VectorChunk +from services.llm.llm_service import llm_service +from .vector_store import vector_store, VectorChunk logger = logging.getLogger(__name__) @@ -23,9 +20,6 @@ class RAGIndexService: """ def __init__(self) -> None: - self._client: OpenAI | None = None - self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small") - # Chunking defaults (tune later) self.chunk_chars = int(os.getenv("RAG_CHUNK_CHARS", "1400")) self.chunk_overlap_chars = int(os.getenv("RAG_CHUNK_OVERLAP_CHARS", "220")) @@ -33,17 +27,6 @@ def __init__(self) -> None: 
# Embedding batching (avoid huge requests) self.embed_batch_size = int(os.getenv("RAG_EMBED_BATCH_SIZE", "64")) - def _ensure_client(self) -> None: - if self._client is not None: - return - if not config.OPENAI_API_KEY: - raise ValueError("OPENAI_API_KEY must be set for RAG indexing") - self._client = OpenAI( - api_key=config.OPENAI_API_KEY, - timeout=float(os.getenv("OPENAI_TIMEOUT", "60")), - max_retries=int(os.getenv("OPENAI_MAX_RETRIES", "2")), - ) - def _set_doc_index_state( self, doc_id: str, @@ -110,17 +93,7 @@ def _chunk_text(self, text: str) -> List[Dict[str, Any]]: return out def _embed_texts(self, texts: List[str]) -> List[List[float]]: - self._ensure_client() - assert self._client is not None - - # OpenAI embeddings API accepts a list input. - resp = self._client.embeddings.create( - model=self.embedding_model, - input=texts, - ) - # Preserve order - embeddings: List[List[float]] = [d.embedding for d in resp.data] - return embeddings + return llm_service.embed_texts(texts) def index_doc(self, doc_id: str) -> None: """ diff --git a/backend/services/rag/summarization_pipeline.py b/backend/services/rag/summarization_pipeline.py new file mode 100644 index 0000000..90b2f47 --- /dev/null +++ b/backend/services/rag/summarization_pipeline.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, List, Sequence + + +ChatCall = Callable[[str, str, int, float], str] + + +@dataclass +class SourceChunk: + text: str + order: int + + +class SummarizationPipeline: + """ + Generic map-reduce summarization for large documents. + """ + + def __init__( + self, + max_chunks: int = 24, + chunk_chars: int = 2800, + chunk_overlap_chars: int = 250, + reduce_batch_size: int = 6, + ) -> None: + self.max_chunks = max(4, int(max_chunks)) + self.chunk_chars = max(800, int(chunk_chars)) + self.chunk_overlap_chars = max(0, min(int(chunk_overlap_chars), self.chunk_chars - 50)) + self.reduce_batch_size = max(2, int(reduce_batch_size)) + self.max_reduce_rounds = 8 + self.max_final_notes_chars = 24000 + + def build_chunks_from_text(self, text: str) -> List[SourceChunk]: + return self.build_chunks_from_text_with_limit(text=text, max_chunks=self.max_chunks) + + def build_chunks_from_text_with_limit(self, text: str, max_chunks: int) -> List[SourceChunk]: + t = (text or "").strip() + if not t: + return [] + + out: List[SourceChunk] = [] + n = len(t) + start = 0 + idx = 0 + limit = max(1, int(max_chunks)) + while start < n and len(out) < limit: + end = min(n, start + self.chunk_chars) + chunk = t[start:end].strip() + if chunk: + out.append(SourceChunk(text=chunk, order=idx)) + idx += 1 + if end >= n: + break + start = max(0, end - self.chunk_overlap_chars) + return out + + def limit_chunks(self, chunks: Sequence[SourceChunk]) -> List[SourceChunk]: + return self.limit_chunks_with_limit(chunks=chunks, max_chunks=self.max_chunks) + + def limit_chunks_with_limit(self, chunks: Sequence[SourceChunk], max_chunks: int) -> List[SourceChunk]: + if not chunks: + return [] + ordered = sorted(list(chunks), key=lambda c: c.order) + return ordered[: max(1, int(max_chunks))] + + def summarize( + self, + chunks: Sequence[SourceChunk], + call_chat: ChatCall, + base_system_prompt: str, + map_instruction: str, + reduce_instruction: str, + final_instruction: str, + max_tokens: int, + temperature: float, + max_chunks_override: int | None = None, + reduce_batch_size_override: int | None = None, + ) -> str: + max_chunks = int(max_chunks_override) if max_chunks_override is not 
None else self.max_chunks + reduce_batch_size = int(reduce_batch_size_override) if reduce_batch_size_override is not None else self.reduce_batch_size + reduce_batch_size = max(2, reduce_batch_size) + + src = self.limit_chunks_with_limit(chunks, max_chunks=max_chunks) + if not src: + return "" + + map_points: List[str] = [] + map_max_tokens = max(160, min(420, int(max_tokens // 3) if max_tokens else 260)) + map_temp = max(0.0, min(float(temperature), 0.25)) + + for ch in src: + system_prompt = base_system_prompt + "\n" + map_instruction + user_prompt = ( + "Chunk text:\n" + f"{ch.text}\n\n" + "Return concise bullet points from this chunk only." + ) + mapped = call_chat(system_prompt, user_prompt, map_max_tokens, map_temp).strip() + if mapped: + map_points.append(mapped) + + if not map_points: + return "" + + current = map_points + reduce_max_tokens = max(220, min(700, int(max_tokens // 2) if max_tokens else 420)) + rounds = 0 + while len(current) > reduce_batch_size and rounds < self.max_reduce_rounds: + rounds += 1 + nxt: List[str] = [] + for i in range(0, len(current), reduce_batch_size): + batch = current[i : i + reduce_batch_size] + system_prompt = base_system_prompt + "\n" + reduce_instruction + user_prompt = ( + "Summaries to merge:\n" + f"{chr(10).join(batch)}\n\n" + "Merge these into a single de-duplicated brief." + ) + merged = call_chat(system_prompt, user_prompt, reduce_max_tokens, map_temp).strip() + if merged: + nxt.append(merged) + # Prevent infinite loops when model returns empty reduce outputs. + if not nxt: + break + current = nxt + if len(current) <= 1: + break + + final_system = base_system_prompt + "\n" + final_instruction + final_notes = "\n".join(current).strip() + if len(final_notes) > self.max_final_notes_chars: + final_notes = final_notes[: self.max_final_notes_chars] + final_user = ( + "Consolidated notes:\n" + f"{final_notes}\n\n" + "Write the final output now." + ) + return call_chat(final_system, final_user, max_tokens, temperature).strip() diff --git a/backend/services/vector_store.py b/backend/services/rag/vector_store.py similarity index 88% rename from backend/services/vector_store.py rename to backend/services/rag/vector_store.py index f31fbfc..1d51e67 100644 --- a/backend/services/vector_store.py +++ b/backend/services/rag/vector_store.py @@ -54,6 +54,12 @@ def count(self, doc_id: str) -> int: with self._lock: return len(self._store.get(doc_id, [])) + def list_chunks(self, doc_id: str) -> List[VectorChunk]: + with self._lock: + chunks = list(self._store.get(doc_id, [])) + chunks.sort(key=lambda c: int(c.meta.get("index", 0)) if isinstance(c.meta, dict) else 0) + return chunks + def query( self, doc_id: str, diff --git a/frontend/src/pages/Generate.jsx b/frontend/src/pages/Generate.jsx index c914a21..a08d2e2 100644 --- a/frontend/src/pages/Generate.jsx +++ b/frontend/src/pages/Generate.jsx @@ -8,20 +8,30 @@ import FileUpload from '../components/FileUpload'; import { generateSummaryJson, cloneFormData, getRagStatus, ragChat, deleteVectors } from '../services/api'; // Helper function for streaming text effect -const streamText = (text, onUpdate, speedMs = 20) => { +const streamText = (text, onUpdate, targetMs = 900, tickMs = 20) => { + const full = String(text || ''); + if (!full) { + onUpdate(''); + return () => {}; + } + // Keep UI responsive: cap typing animation duration. 
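+ // Short responses (up to 180 chars) render immediately; longer ones are revealed
+ // in fixed-size slices so the full text appears within roughly targetMs.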
+ if (full.length <= 180) { + onUpdate(full); + return () => {}; + } + let index = 0; - let currentText = ''; - + const steps = Math.max(1, Math.floor(targetMs / tickMs)); + const chunkSize = Math.max(16, Math.ceil(full.length / steps)); const interval = setInterval(() => { - if (index < text.length) { - currentText += text[index]; - onUpdate(currentText); - index++; - } else { - clearInterval(interval); + if (index < full.length) { + index = Math.min(full.length, index + chunkSize); + onUpdate(full.slice(0, index)); + return; } - }, speedMs); - + clearInterval(interval); + }, tickMs); + return () => clearInterval(interval); }; @@ -344,7 +354,6 @@ export const Generate = () => { ]); // Stream the text - const lastIndex = setHistory.length - 1; streamText(text, (streamedText) => { setHistory((prev) => { const updated = [...prev]; diff --git a/frontend/src/services/api.js b/frontend/src/services/api.js index 2658d15..b17741b 100644 --- a/frontend/src/services/api.js +++ b/frontend/src/services/api.js @@ -218,7 +218,7 @@ export const getRagStatus = async (docId) => { /** * RAG chat: POST /v1/rag/chat (doc_id + message) */ -export const ragChat = async ({ docId, message, maxTokens = 500, temperature = 0.2 }) => { +export const ragChat = async ({ docId, message, maxTokens = 220, temperature = 0.2 }) => { const fd = new FormData(); fd.set('doc_id', (docId || '').trim()); fd.set('message', (message || '').trim()); @@ -260,4 +260,4 @@ export const deleteVectors = async (docId) => { console.warn(`Vector cleanup error: ${e.message}`); return { status: 'error', error: e.message }; } -}; \ No newline at end of file +};
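Usage sketch for the new map-reduce pipeline (backend/services/rag/summarization_pipeline.py). This is a minimal illustration, not part of the patch: the call_chat stub, the prompt strings, and the max_tokens/temperature values are placeholders; wire call_chat to the project's actual chat-completion call.

    from services.rag.summarization_pipeline import SummarizationPipeline

    def call_chat(system_prompt: str, user_prompt: str, max_tokens: int, temperature: float) -> str:
        # Stand-in for the real LLM client; it must return plain text.
        return "- example bullet derived from the supplied chunk"

    document_text = "<full extracted document text>"

    pipeline = SummarizationPipeline(max_chunks=24, chunk_chars=2800, chunk_overlap_chars=250, reduce_batch_size=6)
    chunks = pipeline.build_chunks_from_text(document_text)
    summary = pipeline.summarize(
        chunks,
        call_chat=call_chat,
        base_system_prompt="You summarize documents faithfully. Do not invent facts.",
        map_instruction="Extract concise bullet points from this chunk only.",
        reduce_instruction="Merge the bullet lists into a single de-duplicated brief.",
        final_instruction="Write a short, readable summary from the consolidated notes.",
        max_tokens=900,
        temperature=0.2,
    )
    print(summary)

The pipeline bounds cost regardless of document size: chunks are capped via max_chunks, reduce rounds are capped at 8, and the consolidated notes passed to the final call are truncated to 24,000 characters.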