diff --git a/backend/app/rag/agent.py b/backend/app/rag/agent.py index b7e91d5..b1f0f04 100644 --- a/backend/app/rag/agent.py +++ b/backend/app/rag/agent.py @@ -15,7 +15,7 @@ from app.config import get_settings from app.rag.retriever import retrieve from app.rag.graph_retriever import get_entity_context -from app.rag.prompts import AGENT_SYSTEM_PROMPT +from app.rag.prompts import AGENT_SYSTEM_PROMPT, MULTI_DOC_COMPARISON_GUIDANCE from app.exceptions import ExternalServiceException from app.rag.security import MALFORMED_OUTPUT_MESSAGE, OutputParserError, parse_agent_output from app.rag.tools import PDFSearchTool, MathTool, WebSearchTool @@ -61,6 +61,7 @@ def _format_chat_history(messages: List[Dict[str, str]]) -> str: def get_agent_executor( user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, @@ -68,7 +69,7 @@ def get_agent_executor( """Initialize the LangChain ReAct agent executor.""" # Initialize tools - pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k) + pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, document_ids=document_ids, top_k=top_k) tools = [pdf_tool, MathTool(), WebSearchTool()] # Initialize LLM @@ -90,7 +91,14 @@ def get_agent_executor( chat_llm = ChatHuggingFace(llm=llm) # Setup Agent - prompt = PromptTemplate.from_template(AGENT_SYSTEM_PROMPT) + agent_prompt_text = AGENT_SYSTEM_PROMPT + if document_ids and len(document_ids) > 1: + agent_prompt_text = agent_prompt_text.replace( + "Begin!", + MULTI_DOC_COMPARISON_GUIDANCE + "\nBegin!", + 1, + ) + prompt = PromptTemplate.from_template(agent_prompt_text) agent = create_react_agent(chat_llm, tools, prompt) executor = AgentExecutor( @@ -127,6 +135,7 @@ def generate_answer( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, @@ -154,7 +163,7 @@ def generate_answer( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, document_ids, hf_token, top_k, chat_history) result = executor.invoke({"input": question, "chat_history": formatted_history}) raw_answer = result.get("output", "") @@ -199,6 +208,7 @@ def generate_answer_stream( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, @@ -227,7 +237,7 @@ def generate_answer_stream( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, document_ids, hf_token, top_k, chat_history) sources_sent = False diff --git a/backend/app/rag/prompts.py b/backend/app/rag/prompts.py index 42bcc31..d5250c9 100644 --- a/backend/app/rag/prompts.py +++ b/backend/app/rag/prompts.py @@ -87,3 +87,11 @@ {chat_history} Question: {input} Thought: {agent_scratchpad}""" + +MULTI_DOC_COMPARISON_GUIDANCE = """ +MULTI-DOCUMENT MODE: +You are answering across multiple documents at once. When findings differ or overlap between documents: +- Attribute each finding to its specific source document using [Source: filename, Page X]. +- Explicitly note where documents agree and where they disagree or report different figures. +- Do not blend numbers or claims from different documents without making the source of each clear. +""" \ No newline at end of file diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 09610c6..cf16c32 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -211,7 +211,7 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @trace_function( "retrieve", - metadata_factory=lambda query, user_id, document_id=None, top_k=None: { + metadata_factory=lambda query, user_id, document_id=None, document_ids=None, top_k=None: { "user_id": user_id, "document_id": document_id, "embedding_model": settings.EMBEDDING_MODEL, @@ -226,6 +226,7 @@ def retrieve( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, top_k: Optional[int] = None, ) -> List[Dict[str, Any]]: """ @@ -250,6 +251,7 @@ def retrieve( query_embedding=query_vector, user_id=user_id, document_id=document_id, + document_ids=document_ids, top_k=effective_top_k, ) @@ -260,6 +262,7 @@ def retrieve( query=search_query, user_id=user_id, document_id=document_id, + document_ids=document_ids, top_k=effective_top_k, ) except Exception as exc: diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 0381375..01542a9 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -156,6 +156,7 @@ class PDFSearchTool(BaseTool): user_id: str document_id: Optional[str] = None + document_ids: Optional[List[str]] = None top_k: Optional[int] = None # We'll store sources here to retrieve them after agent execution last_sources: List[Dict[str, Any]] = [] @@ -167,6 +168,7 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) diff --git a/backend/app/routes/chat.py b/backend/app/routes/chat.py index 8957b45..83ae991 100644 --- a/backend/app/routes/chat.py +++ b/backend/app/routes/chat.py @@ -472,6 +472,7 @@ def generate_answer( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, @@ -482,6 +483,7 @@ def generate_answer( question=question, user_id=user_id, document_id=document_id, + document_ids=document_ids, hf_token=hf_token, top_k=top_k, chat_history=chat_history, @@ -492,6 +494,7 @@ def generate_answer_stream( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, @@ -502,6 +505,7 @@ def generate_answer_stream( question=question, user_id=user_id, document_id=document_id, + document_ids=document_ids, hf_token=hf_token, top_k=top_k, chat_history=chat_history, @@ -569,6 +573,33 @@ def ask_question( doc.last_accessed_at = datetime.now(timezone.utc) db.commit() + # Validate documents if multiple specified + elif payload.document_ids: + docs = ( + db.query(Document) + .filter( + Document.id.in_(payload.document_ids), + Document.user_id == user.id, + Document.is_deleted.is_(False), + ) + .all() + ) + + found_ids = {doc.id for doc in docs} + missing = [doc_id for doc_id in payload.document_ids if doc_id not in found_ids] + if missing: + raise NotFoundException("Document") + + not_ready = [doc.original_name for doc in docs if doc.status != "ready"] + if not_ready: + raise ValidationException( + f"Some documents are still processing: {', '.join(not_ready)}. Please wait." + ) + + for doc in docs: + doc.last_accessed_at = datetime.now(timezone.utc) + db.commit() + # Resolve or create session session_id = payload.session_id if not session_id: @@ -594,10 +625,13 @@ def ask_question( recent_messages.reverse() chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] + cache_doc_key = str(payload.document_id or "") + if payload.document_ids: + cache_doc_key = "multi:" + ",".join(sorted(payload.document_ids)) # Cache check — return instantly if this (user, document, question) was answered before cached_answer = get_cached_response( user_id=user.id, - document_id=str(payload.document_id or ""), + document_id=cache_doc_key, question=payload.question, ) if cached_answer is not None: @@ -612,6 +646,7 @@ def ask_question( question=payload.question, user_id=user.id, document_id=payload.document_id, + document_ids=payload.document_ids, hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, @@ -628,15 +663,15 @@ def ask_question( # Store result in cache for future identical questions set_cached_response( user_id=user.id, - document_id=str(payload.document_id or ""), + document_id=cache_doc_key, question=payload.question, answer=result["answer"], ) # Save to chat history - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) + _save_message(db, user.id, cache_doc_key, "user", payload.question, session_id=session_id) _save_message( - db, user.id, payload.document_id, "assistant", result["answer"], result["sources"], session_id=session_id + db, user.id, cache_doc_key, "assistant", result["answer"], result["sources"], session_id=session_id ) return ChatResponse( @@ -705,6 +740,33 @@ def ask_question_stream( doc.last_accessed_at = datetime.now(timezone.utc) db.commit() + # Validate documents if multiple specified + elif payload.document_ids: + docs = ( + db.query(Document) + .filter( + Document.id.in_(payload.document_ids), + Document.user_id == user.id, + Document.is_deleted.is_(False), + ) + .all() + ) + + found_ids = {doc.id for doc in docs} + missing = [doc_id for doc_id in payload.document_ids if doc_id not in found_ids] + if missing: + raise NotFoundException("Document") + + not_ready = [doc.original_name for doc in docs if doc.status != "ready"] + if not_ready: + raise ValidationException( + f"Some documents are still processing: {', '.join(not_ready)}. Please wait." + ) + + for doc in docs: + doc.last_accessed_at = datetime.now(timezone.utc) + db.commit() + started_at = time.perf_counter() # Resolve or create session @@ -732,13 +794,17 @@ def ask_question_stream( recent_messages.reverse() chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] + cache_doc_key = str(payload.document_id or "") + if payload.document_ids: + cache_doc_key = "multi:" + ",".join(sorted(payload.document_ids)) + # Save user message immediately - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) + _save_message(db, user.id, cache_doc_key, "user", payload.question, session_id=session_id) # Cache check before starting the stream cached_answer = get_cached_response( - user_id=user.d, - document_id=str(payload.document_id or ""), + user_id=user.id, + document_id=cache_doc_key, question=payload.question, ) if cached_answer is not None: @@ -769,7 +835,8 @@ def event_stream(): for chunk in generate_answer_stream( question=payload.question, user_id=user.id, - document_id=payload.document_id, + document_id=cache_doc_key, + document_ids=payload.document_ids, hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, @@ -791,7 +858,7 @@ def event_stream(): if full_answer: set_cached_response( user_id=user.id, - document_id=str(payload.document_id or ""), + document_id=cache_doc_key, question=payload.question, answer=full_answer, ) @@ -801,7 +868,7 @@ def event_stream(): with get_db_session() as save_db: _save_message( - save_db, user.id, payload.document_id, "assistant", full_answer, sources, session_id=session_id + save_db, user.id, cache_doc_key, "assistant", full_answer, sources, session_id=session_id ) # Log streaming response RAG completion diff --git a/backend/tests/test_multi_document_chat.py b/backend/tests/test_multi_document_chat.py new file mode 100644 index 0000000..cce84d4 --- /dev/null +++ b/backend/tests/test_multi_document_chat.py @@ -0,0 +1,164 @@ +from unittest.mock import MagicMock + +from app.rag import retriever +from app.models import Document + + +# ── Retrieval: document_ids reaches both vector and BM25 (dev's direct-call retrieve) ── + +def _mock_db(monkeypatch, doc_rows): + mock_db = MagicMock() + mock_db.__enter__.return_value = mock_db + mock_query = MagicMock() + mock_db.query.return_value = mock_query + mock_query.filter.return_value.all.return_value = doc_rows + monkeypatch.setattr("app.database.SessionLocal", lambda: mock_db) + + +def test_retrieve_forwards_document_ids_to_vector_and_bm25(monkeypatch): + _mock_db(monkeypatch, [("doc-a",), ("doc-b",)]) + + seen = {"vector": "unset", "bm25": "unset"} + + monkeypatch.setattr(retriever, "transform_query", lambda _q: ["q"]) + monkeypatch.setattr(retriever, "embed_query", lambda q: f"embedding:{q}") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + seen["vector"] = document_ids + return [{"id": "v1", "text": "vec", "filename": "a.pdf", "page": 1, "score": 0.5}] + + def fake_query_bm25(query, user_id, document_id=None, document_ids=None, top_k=10): + seen["bm25"] = document_ids + return [{"id": "b1", "text": "bm", "filename": "b.pdf", "page": 1, "score": 0.5}] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + monkeypatch.setattr("app.rag.bm25.query_bm25", fake_query_bm25) + + retriever.retrieve("question", user_id="user-1", document_ids=["doc-a", "doc-b"]) + + assert seen["vector"] == ["doc-a", "doc-b"] + # bm25 only runs when hybrid search is enabled; if it ran, it must have received the ids + if seen["bm25"] != "unset": + assert seen["bm25"] == ["doc-a", "doc-b"] + + +def test_retrieve_single_document_leaves_document_ids_none(monkeypatch): + _mock_db(monkeypatch, [("doc-a",)]) + + seen = {"vector_id": "unset", "vector_ids": "unset"} + + monkeypatch.setattr(retriever, "transform_query", lambda _q: ["q"]) + monkeypatch.setattr(retriever, "embed_query", lambda q: f"embedding:{q}") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + seen["vector_id"] = document_id + seen["vector_ids"] = document_ids + return [{"id": "v1", "text": "vec", "filename": "a.pdf", "page": 1, "score": 0.5}] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + + retriever.retrieve("question", user_id="user-1", document_id="doc-a") + + assert seen["vector_id"] == "doc-a" + assert seen["vector_ids"] is None + + +# ── Prompt: comparison guidance only when more than one document ── + +def test_comparison_guidance_present_only_for_multiple_documents(monkeypatch): + from app.rag import agent + from app.rag.prompts import MULTI_DOC_COMPARISON_GUIDANCE + + captured = {} + + class FakeLLM: + def __init__(self, *a, **k): + pass + + def capture_prompt(llm, tools, prompt): + captured["template"] = prompt.template + return "agent" + + monkeypatch.setattr(agent, "get_llm_client", lambda hf_token=None: FakeLLM()) + monkeypatch.setattr(agent, "create_react_agent", capture_prompt) + monkeypatch.setattr(agent, "AgentExecutor", lambda **kwargs: kwargs) + + agent.get_agent_executor(user_id="user-1", document_ids=["doc-a", "doc-b"]) + assert MULTI_DOC_COMPARISON_GUIDANCE.strip() in captured["template"] + + captured.clear() + agent.get_agent_executor(user_id="user-1", document_id="doc-a") + assert MULTI_DOC_COMPARISON_GUIDANCE.strip() not in captured["template"] + + +# ── Route guard: ownership + readiness for document_ids ── + +def test_chat_ask_multi_doc_success(client, auth_headers, ready_document, db_session, user, monkeypatch): + second = Document( + user_id=user.id, + filename="second.txt", + original_name="second.txt", + file_size=128, + status="ready", + ) + db_session.add(second) + db_session.commit() + db_session.refresh(second) + + monkeypatch.setattr( + "app.routes.chat.generate_answer", + lambda question, user_id, document_id=None, document_ids=None, **kwargs: { + "answer": "Across both docs", + "sources": [], + }, + ) + + response = client.post( + "/api/v1/chat/ask", + headers=auth_headers, + json={"question": "Compare them", "document_ids": [ready_document.id, second.id]}, + ) + + assert response.status_code == 200 + assert response.json()["answer"] == "Across both docs" + + +def test_chat_ask_multi_doc_rejects_missing_document(client, auth_headers, ready_document): + response = client.post( + "/api/v1/chat/ask", + headers=auth_headers, + json={"question": "Compare", "document_ids": [ready_document.id, "missing-doc-id"]}, + ) + assert response.status_code == 404 + + +def test_chat_ask_multi_doc_rejects_not_ready_document(client, auth_headers, ready_document, pending_document): + response = client.post( + "/api/v1/chat/ask", + headers=auth_headers, + json={"question": "Compare", "document_ids": [ready_document.id, pending_document.id]}, + ) + assert response.status_code == 400 + + +def test_chat_ask_multi_doc_rejects_other_users_document(client, auth_headers, ready_document, db_session, other_user): + other_doc = Document( + user_id=other_user.id, + filename="other.txt", + original_name="other.txt", + file_size=64, + status="ready", + ) + db_session.add(other_doc) + db_session.commit() + db_session.refresh(other_doc) + + response = client.post( + "/api/v1/chat/ask", + headers=auth_headers, + json={"question": "Compare", "document_ids": [ready_document.id, other_doc.id]}, + ) + # not owned -> treated as missing + assert response.status_code == 404 \ No newline at end of file diff --git a/backend/tests/test_rag_tools.py b/backend/tests/test_rag_tools.py index 30bbc9f..9726783 100644 --- a/backend/tests/test_rag_tools.py +++ b/backend/tests/test_rag_tools.py @@ -154,7 +154,7 @@ def test_pdf_search_tool_formats_chunks_and_graph_context(monkeypatch): retrieve_calls = [] graph_calls = [] - def fake_retrieve(query, user_id, document_id=None, top_k=None): + def fake_retrieve(query, user_id, document_id=None, document_ids=None, top_k=None): retrieve_calls.append((query, user_id, document_id)) return chunks