Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions backend/app/rag/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import glob
import pickle
import logging
import re
from typing import List, Dict, Any, Optional

from app.config import get_settings
Expand All @@ -29,10 +30,11 @@ def get_bm25_path(user_id: str, document_id: str) -> str:
"""Get the file path for a specific document's BM25 index."""
return os.path.join(get_bm25_dir(user_id), f"{document_id}.pkl")

import re

def tokenize(text: str) -> List[str]:
    """Tokenize text for BM25 indexing and querying.

    The text is lowercased and every maximal run of word characters
    (letters, digits and underscores) becomes one token. Punctuation and
    whitespace act purely as separators, so "Hello, World!" and
    "hello world" produce the same tokens.

    Args:
        text: Raw text of a chunk or a query.

    Returns:
        The lowercase tokens in order of appearance; an empty list when
        the text contains no word characters.
    """
    # A plain whitespace split would keep punctuation attached ("world!"),
    # which prevents query terms from matching indexed terms.
    return re.findall(r'\w+', text.lower())

def store_bm25_index(chunks: List[Dict[str, Any]], document_id: str, filename: str, user_id: str):
"""
Expand Down Expand Up @@ -100,6 +102,7 @@ def _query_single_index(path: str, tokenized_query: List[str], top_k: int) -> Li
# BM25 scores are usually > 0, often 1-10.
# We keep the raw score for now, RRF will handle the ranking.
chunk["score"] = float(scores[i])
chunk['id'] = f"bm25_{chunk.get('document_id','unk')}_{chunk.get('page',0)}_{i}"
results.append(chunk)

return results
Expand Down Expand Up @@ -155,3 +158,8 @@ def delete_user_bm25_indexes(user_id: str):
logger.info(f"Deleted BM25 directory for user {user_id}")
except Exception as e:
logger.warning(f"Error deleting BM25 directory for user {user_id}: {e}")

def update_bm25_index(
    chunks: List[Dict[str, Any]],
    document_id: str,
    filename: str,
    user_id: str,
):
    """Rebuild the BM25 index of one document from a new set of chunks.

    The existing index is removed with ``delete_bm25_index`` and a new one
    is then written with ``store_bm25_index``.

    Args:
        chunks: Chunk dictionaries to index; passed through unchanged.
        document_id: Identifier of the document whose index is replaced.
        filename: File name passed through to ``store_bm25_index``.
        user_id: Owner of the index.
    """
    # Delete first, then store: the previous index is already gone by the
    # time the new one is written.
    delete_bm25_index(document_id, user_id)
    store_bm25_index(
        chunks=chunks,
        document_id=document_id,
        filename=filename,
        user_id=user_id,
    )
10 changes: 8 additions & 2 deletions backend/app/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ def __init__(self, retrievers, weights=None):

def invoke(self, query):
    """Run every wrapped retriever on ``query`` and collect all results.

    Each returned document is tagged with an identifier built from its
    ``source`` and ``page`` metadata plus the retriever index and the
    document's position in that retriever's result list.

    Args:
        query: The query forwarded unchanged to each retriever's
            ``invoke`` method.

    Returns:
        A flat list of the documents from all retrievers, in retriever
        order and then in each retriever's own result order.
    """
    docs = []
    # Query each retriever exactly once; calling it in a second loop would
    # duplicate every result.
    for i, retriever in enumerate(self.retrievers):
        for j, doc in enumerate(retriever.invoke(query)):
            meta = doc.metadata
            # Store the id in the metadata dict: the document object is
            # accessed via attributes and does not support item assignment.
            meta['id'] = f"{meta.get('source','unk')}_{meta.get('page',0)}_{i}_{j}"
            docs.append(doc)
    return docs
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document as LangchainDocument
Expand All @@ -31,6 +34,7 @@ def invoke(self, query):
from app.rag.embeddings import embed_query
from app.rag.tracing import trace_function
from app.rag.vectorstore import query_chunks
from app.rag.bm25 import store_bm25_index, query_bm25, update_bm25_index

logger = logging.getLogger(__name__)
settings = get_settings()
Expand Down Expand Up @@ -90,6 +94,8 @@ def _get_relevant_documents(
document_id=self.document_id,
top_k=self.top_k,
)
for i, c in enumerate(candidates):
c['id'] = f"bm25_{c.get('document_id','unk')}_{c.get('page',0)}_{i}"
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]


Expand Down
Loading