Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions backend/app/rag/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from app.config import get_settings
from app.rag.retriever import retrieve
from app.rag.graph_retriever import get_entity_context
from app.rag.prompts import AGENT_SYSTEM_PROMPT
from app.rag.prompts import AGENT_SYSTEM_PROMPT, MULTI_DOC_COMPARISON_GUIDANCE
from app.exceptions import ExternalServiceException
from app.rag.security import MALFORMED_OUTPUT_MESSAGE, OutputParserError, parse_agent_output
from app.rag.tools import PDFSearchTool, MathTool, WebSearchTool
Expand Down Expand Up @@ -61,14 +61,15 @@ def _format_chat_history(messages: List[Dict[str, str]]) -> str:
def get_agent_executor(
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
):
"""Initialize the LangChain ReAct agent executor."""

# Initialize tools
pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k)
pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, document_ids=document_ids, top_k=top_k)
tools = [pdf_tool, MathTool(), WebSearchTool()]

# Initialize LLM
Expand All @@ -90,7 +91,14 @@ def get_agent_executor(
chat_llm = ChatHuggingFace(llm=llm)

# Setup Agent
prompt = PromptTemplate.from_template(AGENT_SYSTEM_PROMPT)
agent_prompt_text = AGENT_SYSTEM_PROMPT
if document_ids and len(document_ids) > 1:
agent_prompt_text = agent_prompt_text.replace(
"Begin!",
MULTI_DOC_COMPARISON_GUIDANCE + "\nBegin!",
1,
)
prompt = PromptTemplate.from_template(agent_prompt_text)
agent = create_react_agent(chat_llm, tools, prompt)

executor = AgentExecutor(
Expand Down Expand Up @@ -127,6 +135,7 @@ def generate_answer(
question: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
Expand Down Expand Up @@ -154,7 +163,7 @@ def generate_answer(

# ── Run Agent ────────────────────────────────────
try:
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history)
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, document_ids, hf_token, top_k, chat_history)
result = executor.invoke({"input": question, "chat_history": formatted_history})

raw_answer = result.get("output", "")
Expand Down Expand Up @@ -199,6 +208,7 @@ def generate_answer_stream(
question: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
Expand Down Expand Up @@ -227,7 +237,7 @@ def generate_answer_stream(

# ── Run Agent ────────────────────────────────────
try:
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history)
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, document_ids, hf_token, top_k, chat_history)

sources_sent = False

Expand Down
8 changes: 8 additions & 0 deletions backend/app/rag/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,11 @@
{chat_history}
Question: {input}
Thought: {agent_scratchpad}"""

MULTI_DOC_COMPARISON_GUIDANCE = """
MULTI-DOCUMENT MODE:
You are answering across multiple documents at once. When findings differ or overlap between documents:
- Attribute each finding to its specific source document using [Source: filename, Page X].
- Explicitly note where documents agree and where they disagree or report different figures.
- Do not blend numbers or claims from different documents without making the source of each clear.
"""
5 changes: 4 additions & 1 deletion backend/app/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

@trace_function(
"retrieve",
metadata_factory=lambda query, user_id, document_id=None, top_k=None: {
metadata_factory=lambda query, user_id, document_id=None, document_ids=None, top_k=None: {
"user_id": user_id,
"document_id": document_id,
"embedding_model": settings.EMBEDDING_MODEL,
Expand All @@ -226,6 +226,7 @@ def retrieve(
query: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
top_k: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Expand All @@ -250,6 +251,7 @@ def retrieve(
query_embedding=query_vector,
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
top_k=effective_top_k,
)

Expand All @@ -260,6 +262,7 @@ def retrieve(
query=search_query,
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
top_k=effective_top_k,
)
except Exception as exc:
Expand Down
2 changes: 2 additions & 0 deletions backend/app/rag/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class PDFSearchTool(BaseTool):

user_id: str
document_id: Optional[str] = None
document_ids: Optional[List[str]] = None
top_k: Optional[int] = None
# We'll store sources here to retrieve them after agent execution
last_sources: List[Dict[str, Any]] = []
Expand All @@ -167,6 +168,7 @@ def _run(self, query: str) -> str:
query=query,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)

Expand Down
87 changes: 77 additions & 10 deletions backend/app/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def generate_answer(
question: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[list] = None,
Expand All @@ -482,6 +483,7 @@ def generate_answer(
question=question,
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
hf_token=hf_token,
top_k=top_k,
chat_history=chat_history,
Expand All @@ -492,6 +494,7 @@ def generate_answer_stream(
question: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[list] = None,
Expand All @@ -502,6 +505,7 @@ def generate_answer_stream(
question=question,
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
hf_token=hf_token,
top_k=top_k,
chat_history=chat_history,
Expand Down Expand Up @@ -569,6 +573,33 @@ def ask_question(
doc.last_accessed_at = datetime.now(timezone.utc)
db.commit()

# Validate documents if multiple specified
elif payload.document_ids:
docs = (
db.query(Document)
.filter(
Document.id.in_(payload.document_ids),
Document.user_id == user.id,
Document.is_deleted.is_(False),
)
.all()
)

found_ids = {doc.id for doc in docs}
missing = [doc_id for doc_id in payload.document_ids if doc_id not in found_ids]
if missing:
raise NotFoundException("Document")

not_ready = [doc.original_name for doc in docs if doc.status != "ready"]
if not_ready:
raise ValidationException(
f"Some documents are still processing: {', '.join(not_ready)}. Please wait."
)

for doc in docs:
doc.last_accessed_at = datetime.now(timezone.utc)
db.commit()

# Resolve or create session
session_id = payload.session_id
if not session_id:
Expand All @@ -594,10 +625,13 @@ def ask_question(
recent_messages.reverse()
chat_history = [{"role": m.role, "content": m.content} for m in recent_messages]

cache_doc_key = str(payload.document_id or "")
if payload.document_ids:
cache_doc_key = "multi:" + ",".join(sorted(payload.document_ids))
# Cache check β€” return instantly if this (user, document, question) was answered before
cached_answer = get_cached_response(
user_id=user.id,
document_id=str(payload.document_id or ""),
document_id=cache_doc_key,
question=payload.question,
)
if cached_answer is not None:
Expand All @@ -612,6 +646,7 @@ def ask_question(
question=payload.question,
user_id=user.id,
document_id=payload.document_id,
document_ids=payload.document_ids,
hf_token=user.hf_token,
top_k=payload.top_k,
chat_history=chat_history,
Expand All @@ -628,15 +663,15 @@ def ask_question(
# Store result in cache for future identical questions
set_cached_response(
user_id=user.id,
document_id=str(payload.document_id or ""),
document_id=cache_doc_key,
question=payload.question,
answer=result["answer"],
)

# Save to chat history
_save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id)
_save_message(db, user.id, cache_doc_key, "user", payload.question, session_id=session_id)
_save_message(
db, user.id, payload.document_id, "assistant", result["answer"], result["sources"], session_id=session_id
db, user.id, cache_doc_key, "assistant", result["answer"], result["sources"], session_id=session_id
)

return ChatResponse(
Expand Down Expand Up @@ -705,6 +740,33 @@ def ask_question_stream(
doc.last_accessed_at = datetime.now(timezone.utc)
db.commit()

# Validate documents if multiple specified
elif payload.document_ids:
docs = (
db.query(Document)
.filter(
Document.id.in_(payload.document_ids),
Document.user_id == user.id,
Document.is_deleted.is_(False),
)
.all()
)

found_ids = {doc.id for doc in docs}
missing = [doc_id for doc_id in payload.document_ids if doc_id not in found_ids]
if missing:
raise NotFoundException("Document")

not_ready = [doc.original_name for doc in docs if doc.status != "ready"]
if not_ready:
raise ValidationException(
f"Some documents are still processing: {', '.join(not_ready)}. Please wait."
)

for doc in docs:
doc.last_accessed_at = datetime.now(timezone.utc)
db.commit()

started_at = time.perf_counter()

# Resolve or create session
Expand Down Expand Up @@ -732,13 +794,17 @@ def ask_question_stream(
recent_messages.reverse()
chat_history = [{"role": m.role, "content": m.content} for m in recent_messages]

cache_doc_key = str(payload.document_id or "")
if payload.document_ids:
cache_doc_key = "multi:" + ",".join(sorted(payload.document_ids))

# Save user message immediately
_save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id)
_save_message(db, user.id, cache_doc_key, "user", payload.question, session_id=session_id)

# Cache check before starting the stream
cached_answer = get_cached_response(
user_id=user.d,
document_id=str(payload.document_id or ""),
user_id=user.id,
document_id=cache_doc_key,
question=payload.question,
)
if cached_answer is not None:
Expand Down Expand Up @@ -769,7 +835,8 @@ def event_stream():
for chunk in generate_answer_stream(
question=payload.question,
user_id=user.id,
document_id=payload.document_id,
document_id=cache_doc_key,
document_ids=payload.document_ids,
hf_token=user.hf_token,
top_k=payload.top_k,
chat_history=chat_history,
Expand All @@ -791,7 +858,7 @@ def event_stream():
if full_answer:
set_cached_response(
user_id=user.id,
document_id=str(payload.document_id or ""),
document_id=cache_doc_key,
question=payload.question,
answer=full_answer,
)
Expand All @@ -801,7 +868,7 @@ def event_stream():

with get_db_session() as save_db:
_save_message(
save_db, user.id, payload.document_id, "assistant", full_answer, sources, session_id=session_id
save_db, user.id, cache_doc_key, "assistant", full_answer, sources, session_id=session_id
)

# Log streaming response RAG completion
Expand Down
Loading
Loading