diff --git a/docstra/core/__init__.py b/docstra/core/__init__.py index db9ac88..d4a8026 100644 --- a/docstra/core/__init__.py +++ b/docstra/core/__init__.py @@ -39,8 +39,10 @@ from docstra.core.llm.local import LocalModelClient from docstra.core.llm.ollama import OllamaClient from docstra.core.llm.openai import OpenAIClient +from docstra.core.ingestion.fts_storage import FtsStorage from docstra.core.retrieval.chroma import ChromaRetriever -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever class docstraant: @@ -86,11 +88,16 @@ def setup_components(self): ] ) + # FTS storage (shared by indexer and retriever) + self.fts_storage = FtsStorage(f"{storage_dir}/index.db") + self.fts_retriever = FtsRetriever(self.fts_storage) + # Document indexer self.document_indexer = DocumentIndexer( self.storage, self.embedding_generator, codebase_root=str(Path.cwd()), + fts_storage=self.fts_storage, ) # Code indexer @@ -110,9 +117,14 @@ def setup_components(self): codebase_root=str(Path.cwd()), ) - # Hybrid retriever - self.hybrid_retriever = HybridRetriever( - self.retriever, self.code_indexer.get_index() + # Fusion retriever + self.fusion_retriever = FusionRetriever( + dense=self.retriever, + fts=self.fts_retriever, + code_index=self.code_indexer.get_index(), + rrf_k=self.config.retrieval.rrf_k, + fts_chunks_top_k=self.config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=self.config.retrieval.fts_symbols_top_k, ) # LLM client @@ -198,6 +210,12 @@ def index_file(self, filepath: str) -> str: doc_id = self.document_indexer.index_document(document) self.code_indexer.index_document(document) + # Write symbols to FTS for this file + manifest = self.code_indexer.get_manifest() + file_symbols = [s for s in manifest.symbols if s.file_id == doc_id] + if file_symbols: + self.fts_storage.add_symbols(file_symbols) + return doc_id def document_code( @@ -287,9 +305,7 @@ def answer_question(self, question: str, n_results: int = 5) -> str: Generated answer """ # Retrieve relevant chunks - results = self.hybrid_retriever.retrieve( - query=question, n_results=n_results, use_code_context=True - ) + results = self.fusion_retriever.retrieve(query=question, n_results=n_results) # Generate answer return self._require_text_response( diff --git a/docstra/core/cli.py b/docstra/core/cli.py index 7ab8cbf..c22a02a 100644 --- a/docstra/core/cli.py +++ b/docstra/core/cli.py @@ -50,7 +50,9 @@ RetrievalEvalSummary, evaluate_retrieval_cases, ) -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.services.initialization_service import InitializationService from docstra.core.services.ingestion_service import IngestionService from docstra.core.services.query_service import QueryService @@ -1689,7 +1691,9 @@ def _get_persist_paths( def _create_retrieval_eval_runner( user_config: UserConfig, abs_codebase_path: Path ) -> Callable[[str, int], List[Dict[str, Any]]]: - _, chroma_path, index_path = _get_persist_paths(user_config, abs_codebase_path) + effective_persist_dir, chroma_path, index_path = _get_persist_paths( + user_config, abs_codebase_path + ) core_index_path = index_path / CORE_INDEX_FILENAME chroma_check_file = chroma_path / "chroma.sqlite3" legacy_index_artifacts = CodebaseIndex.legacy_artifacts_in(index_path) @@ -1730,10 +1734,19 @@ def _create_retrieval_eval_runner( code_index = code_indexer.get_index() if code_index: - hybrid_retriever = HybridRetriever(base_retriever, code_index) + fts_storage = FtsStorage(str(effective_persist_dir / "index.db")) + fts_retriever = FtsRetriever(fts_storage) + fusion_retriever = FusionRetriever( + dense=base_retriever, + fts=fts_retriever, + code_index=code_index, + rrf_k=user_config.retrieval.rrf_k, + fts_chunks_top_k=user_config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=user_config.retrieval.fts_symbols_top_k, + ) def retrieve(question: str, top_k: int) -> List[Dict[str, Any]]: - return hybrid_retriever.retrieve(question, n_results=top_k) + return fusion_retriever.retrieve(question, n_results=top_k) return retrieve diff --git a/docstra/core/config/settings.py b/docstra/core/config/settings.py index 43123c8..b991a3b 100644 --- a/docstra/core/config/settings.py +++ b/docstra/core/config/settings.py @@ -120,6 +120,18 @@ def __init__( self.exclude_patterns = exclude_patterns or [] +class RetrievalConfig: + def __init__( + self, + rrf_k: int = 60, + fts_chunks_top_k: int = 50, + fts_symbols_top_k: int = 25, + ) -> None: + self.rrf_k = rrf_k + self.fts_chunks_top_k = fts_chunks_top_k + self.fts_symbols_top_k = fts_symbols_top_k + + class ConfigManager: def __init__(self, config_path: Optional[str] = None) -> None: self.config_path = config_path or "./.docstra/config.yaml" @@ -180,6 +192,7 @@ def __init__(self) -> None: self.processing = ProcessingConfig() self.ingestion = IngestionConfig() self.documentation = DocumentationConfig() + self.retrieval = RetrievalConfig() def save_to_file(self, path: str) -> None: """Save configuration to YAML file.""" @@ -219,6 +232,11 @@ def save_to_file(self, path: str) -> None: "exclude_patterns": self.ingestion.exclude_patterns, }, "documentation": self.documentation.model_dump(), + "retrieval": { + "rrf_k": self.retrieval.rrf_k, + "fts_chunks_top_k": self.retrieval.fts_chunks_top_k, + "fts_symbols_top_k": self.retrieval.fts_symbols_top_k, + }, } # Write to YAML file @@ -283,3 +301,12 @@ def load_from_file(self, path: str) -> None: self.processing.chunk_overlap = processing_data["chunk_overlap"] if "exclude_patterns" in processing_data: self.processing.exclude_patterns = processing_data["exclude_patterns"] + + if "retrieval" in config_dict: + retrieval_data = config_dict["retrieval"] + if "rrf_k" in retrieval_data: + self.retrieval.rrf_k = retrieval_data["rrf_k"] + if "fts_chunks_top_k" in retrieval_data: + self.retrieval.fts_chunks_top_k = retrieval_data["fts_chunks_top_k"] + if "fts_symbols_top_k" in retrieval_data: + self.retrieval.fts_symbols_top_k = retrieval_data["fts_symbols_top_k"] diff --git a/docstra/core/documentation/generator.py b/docstra/core/documentation/generator.py index 7723a1b..13ab5c5 100644 --- a/docstra/core/documentation/generator.py +++ b/docstra/core/documentation/generator.py @@ -28,7 +28,9 @@ from docstra.core.document_processing.document import Document from docstra.core.indexing.repo_map import RepositoryMap from docstra.core.retrieval.chroma import ChromaRetriever -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.indexing.code_index import CodebaseIndex from docstra.core.documentation.prompts import ( EnhancedDocumentationPrompts, @@ -114,6 +116,8 @@ def __init__( max_workers: Optional[int] = None, documentation_depth: str = "comprehensive", # "overview", "standard", "comprehensive" style_guide: Optional[str] = None, + persist_directory: Optional[Union[str, Path]] = None, + user_config: Optional[Any] = None, ): """Initialize the enhanced documentation generator. @@ -129,6 +133,8 @@ def __init__( max_workers: Maximum number of worker threads documentation_depth: Level of documentation detail to generate style_guide: Custom style guide for documentation + persist_directory: Persist directory root (needed to locate index.db for FTS) + user_config: UserConfig instance for retrieval settings """ self.llm_client = llm_client self.output_dir = Path(output_dir) @@ -145,12 +151,26 @@ def __init__( # Enhanced progress reporting self.progress_reporter = DocumentationProgressReporter(self.console) - # Set up hybrid retriever if available - self.hybrid_retriever = None - if self.chroma_retriever and self.code_index: - self.hybrid_retriever = HybridRetriever( - self.chroma_retriever, self.code_index - ) + # Set up fusion retriever if chroma retriever, code index, and persist dir are available + self.fusion_retriever = None + if self.chroma_retriever and self.code_index and persist_directory: + fts_storage = FtsStorage(str(Path(persist_directory) / "index.db")) + fts_retriever = FtsRetriever(fts_storage) + if user_config and hasattr(user_config, "retrieval"): + self.fusion_retriever = FusionRetriever( + dense=self.chroma_retriever, + fts=fts_retriever, + code_index=self.code_index, + rrf_k=user_config.retrieval.rrf_k, + fts_chunks_top_k=user_config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=user_config.retrieval.fts_symbols_top_k, + ) + else: + self.fusion_retriever = FusionRetriever( + dense=self.chroma_retriever, + fts=fts_retriever, + code_index=self.code_index, + ) # Documentation state self.processed_documents: Dict[str, Document] = {} @@ -695,7 +715,7 @@ def _build_file_context(self, document: Document) -> str: ) # Add cross-references - if self.hybrid_retriever: + if self.fusion_retriever: cross_refs = self._get_file_cross_references(document) if cross_refs: context_parts.append( @@ -754,7 +774,7 @@ def _get_similar_code_examples(self, document: Document) -> List[Dict[str, Any]] def _get_file_cross_references(self, document: Document) -> List[str]: """Get cross-references for a file.""" - if not self.hybrid_retriever or not self.chroma_retriever: + if not self.chroma_retriever: return [] try: diff --git a/docstra/core/indexing/code_index.py b/docstra/core/indexing/code_index.py index 0eaf768..722ae7a 100644 --- a/docstra/core/indexing/code_index.py +++ b/docstra/core/indexing/code_index.py @@ -7,7 +7,7 @@ from collections import defaultdict import os from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union from docstra.core.document_processing.document import Document, DocumentType from docstra.core.indexing.model import ( @@ -55,6 +55,7 @@ def __init__( self._functions_by_name: Dict[str, List[IndexedSymbol]] = defaultdict(list) self._classes_by_name: Dict[str, List[IndexedSymbol]] = defaultdict(list) self._symbols_by_file: Dict[str, List[IndexedSymbol]] = defaultdict(list) + self._chunks_by_file: Dict[str, List[Tuple[str, int, int]]] = defaultdict(list) self._imports_by_source: Dict[str, List[ImportRecord]] = defaultdict(list) self._imports_by_text: Dict[str, List[str]] = defaultdict(list) self._dependencies_by_source: Dict[str, List[str]] = defaultdict(list) @@ -110,6 +111,7 @@ def _rebuild_lookups(self) -> None: self._functions_by_name = defaultdict(list) self._classes_by_name = defaultdict(list) self._symbols_by_file = defaultdict(list) + self._chunks_by_file = defaultdict(list) self._imports_by_source = defaultdict(list) self._imports_by_text = defaultdict(list) self._dependencies_by_source = defaultdict(list) @@ -123,6 +125,13 @@ def _rebuild_lookups(self) -> None: elif symbol.kind == "class": self._classes_by_name[symbol.name].append(symbol) + for chunk in self._manifest.chunks: + self._chunks_by_file[chunk.file_id].append( + (chunk.id, chunk.start_line, chunk.end_line) + ) + for chunks in self._chunks_by_file.values(): + chunks.sort(key=lambda item: item[1]) + for import_record in self._manifest.imports: self._imports_by_source[import_record.source_file_id].append(import_record) self._imports_by_text[import_record.raw_text].append( @@ -455,6 +464,15 @@ def get_related_files(self, filepath: str) -> List[str]: related_files.discard(file_id) return sorted(related_files) + def chunks_for_file(self, file_id: str) -> List[Tuple[str, int, int]]: + """Return (chunk_id, start_line, end_line) tuples for a file in line order.""" + return list(self._chunks_by_file.get(file_id, [])) + + def file_language(self, file_id: str) -> Optional[str]: + """Return the language recorded in the manifest for a file id, if any.""" + entry = self._files_by_id.get(file_id) + return entry.language if entry else None + def clear(self) -> None: """Clear the persisted manifest and in-memory lookups.""" self._manifest = CoreIndexManifest.empty( @@ -552,3 +570,7 @@ def index_documents(self, documents: List[Document]) -> None: def get_index(self) -> CodebaseIndex: """Get the underlying codebase index.""" return self.index + + def get_manifest(self) -> CoreIndexManifest: + """Return the in-memory manifest built during indexing.""" + return self.index.manifest diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py new file mode 100644 index 0000000..28e78ab --- /dev/null +++ b/docstra/core/ingestion/fts_storage.py @@ -0,0 +1,257 @@ +"""SQLite + FTS5 store for lexical retrieval over chunks and symbols.""" + +from __future__ import annotations + +import os +import re +import sqlite3 +import threading +from typing import Any, Dict, List, Optional, Sequence + +from docstra.core.indexing.model import IndexedSymbol + +_FTS_TOKEN_RE = re.compile(r"\w+", re.UNICODE) + + +def _sanitize_fts_query(query: str) -> str: + """Strip FTS5 syntax characters and lowercase the result. + + Tokenizes on word boundaries, joins with spaces, lowercases. The + lowercasing matters: FTS5 boolean operators (NOT/AND/OR) are recognized + only in uppercase, so lowercasing user queries prevents accidental + boolean parsing while leaving the unicode61 tokenizer's case-insensitive + matching unchanged. + """ + tokens = _FTS_TOKEN_RE.findall(query) + return " ".join(tokens).lower() + + +SCHEMA_VERSION = 1 + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS chunks ( + chunk_id TEXT PRIMARY KEY, + file_id TEXT NOT NULL, + language TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_chunks_file ON chunks(file_id); + +CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + content, + content='chunks', + content_rowid='rowid', + tokenize='unicode61 remove_diacritics 2' +); + +CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN + INSERT INTO chunks_fts(rowid, content) VALUES (new.rowid, new.content); +END; + +CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN + INSERT INTO chunks_fts(chunks_fts, rowid, content) VALUES('delete', old.rowid, old.content); +END; + +CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN + INSERT INTO chunks_fts(chunks_fts, rowid, content) VALUES('delete', old.rowid, old.content); + INSERT INTO chunks_fts(rowid, content) VALUES (new.rowid, new.content); +END; + +CREATE VIRTUAL TABLE IF NOT EXISTS symbols_fts USING fts5( + symbol_id UNINDEXED, + file_id UNINDEXED, + kind UNINDEXED, + name, + tokenize='unicode61 remove_diacritics 2' +); +""" + + +class FtsStorage: + """SQLite store with FTS5 indexes for chunks and symbols. + + The connection is shared across threads (check_same_thread=False) and all + statements run under an RLock; Python's sqlite3 module does not serialize + concurrent execute() calls on a shared connection on its own. + """ + + def __init__(self, db_path: str) -> None: + self.db_path = db_path + self._lock = threading.RLock() + os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True) + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.row_factory = sqlite3.Row + self._migrate() + + def _migrate(self) -> None: + with self._conn: + self._conn.executescript(_SCHEMA) + current = self._conn.execute( + "SELECT version FROM schema_version" + ).fetchone() + if current is None: + self._conn.execute( + "INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,) + ) + + def close(self) -> None: + self._conn.close() + + # --- chunks --- + + def add_chunks( + self, + *, + chunk_ids: Sequence[str], + file_ids: Sequence[str], + languages: Sequence[str], + start_lines: Sequence[int], + end_lines: Sequence[int], + contents: Sequence[str], + ) -> None: + rows = list( + zip(chunk_ids, file_ids, languages, start_lines, end_lines, contents) + ) + with self._lock, self._conn: + self._conn.executemany( + """ + INSERT INTO chunks (chunk_id, file_id, language, start_line, end_line, content) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(chunk_id) DO UPDATE SET + file_id=excluded.file_id, + language=excluded.language, + start_line=excluded.start_line, + end_line=excluded.end_line, + content=excluded.content + """, + rows, + ) + + def delete_by_file(self, file_id: str) -> None: + with self._lock, self._conn: + self._conn.execute("DELETE FROM chunks WHERE file_id = ?", (file_id,)) + self._conn.execute("DELETE FROM symbols_fts WHERE file_id = ?", (file_id,)) + + def search_chunks( + self, + query: str, + n_results: int = 50, + *, + language: Optional[str] = None, + file_id: Optional[str] = None, + ) -> List[Dict[str, Any]]: + match_query = _sanitize_fts_query(query) + if not match_query: + return [] + clauses = ["chunks_fts MATCH ?"] + params: List[Any] = [match_query] + if language is not None: + clauses.append("chunks.language = ?") + params.append(language) + if file_id is not None: + clauses.append("chunks.file_id = ?") + params.append(file_id) + sql = f""" + SELECT chunks.chunk_id, chunks.file_id, chunks.language, + chunks.start_line, chunks.end_line, chunks.content, + -bm25(chunks_fts) AS score + FROM chunks_fts + JOIN chunks ON chunks.rowid = chunks_fts.rowid + WHERE {" AND ".join(clauses)} + ORDER BY score DESC + LIMIT ? + """ + params.append(n_results) + results = [] + with self._lock: + rows = self._conn.execute(sql, params).fetchall() + for row in rows: + results.append( + { + "id": row["chunk_id"], + "chunk_id": row["chunk_id"], + "file_id": row["file_id"], + "language": row["language"], + "start_line": row["start_line"], + "end_line": row["end_line"], + "content": row["content"], + "score": row["score"], + "metadata": { + "document_id": row["file_id"], + "filepath": row["file_id"], + "start_line": row["start_line"], + "end_line": row["end_line"], + "language": row["language"], + "chunk_type": "code", + }, + } + ) + return results + + def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: + """Return the full row for a single chunk id, or None if absent.""" + sql = ( + "SELECT chunk_id, file_id, language, start_line, end_line, content " + "FROM chunks WHERE chunk_id = ? LIMIT 1" + ) + with self._lock: + row = self._conn.execute(sql, (chunk_id,)).fetchone() + return dict(row) if row else None + + # --- symbols --- + + def add_symbols(self, symbols: List[IndexedSymbol]) -> None: + if not symbols: + return + file_ids = sorted({symbol.file_id for symbol in symbols}) + rows = [(s.id, s.file_id, s.kind, s.name) for s in symbols] + with self._lock, self._conn: + placeholders = ",".join("?" * len(file_ids)) + self._conn.execute( + f"DELETE FROM symbols_fts WHERE file_id IN ({placeholders})", + file_ids, + ) + self._conn.executemany( + "INSERT INTO symbols_fts (symbol_id, file_id, kind, name) VALUES (?, ?, ?, ?)", + rows, + ) + + def search_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: + match_query = _sanitize_fts_query(query) + if not match_query: + return [] + sql = """ + SELECT symbol_id, file_id, kind, name, -bm25(symbols_fts) AS score + FROM symbols_fts + WHERE symbols_fts MATCH ? + ORDER BY score DESC + LIMIT ? + """ + results = [] + with self._lock: + rows = self._conn.execute(sql, (match_query, n_results)).fetchall() + for row in rows: + results.append( + { + "id": row["symbol_id"], + "symbol_id": row["symbol_id"], + "file_id": row["file_id"], + "kind": row["kind"], + "name": row["name"], + "score": row["score"], + "metadata": { + "document_id": row["file_id"], + "filepath": row["file_id"], + "name": row["name"], + "kind": row["kind"], + }, + } + ) + return results diff --git a/docstra/core/ingestion/storage.py b/docstra/core/ingestion/storage.py index ad8ec4e..18ef79f 100644 --- a/docstra/core/ingestion/storage.py +++ b/docstra/core/ingestion/storage.py @@ -15,6 +15,7 @@ from docstra.core.document_processing.document import Document from docstra.core.indexing.model import make_chunk_id, normalize_file_id +from docstra.core.ingestion.fts_storage import FtsStorage ChromaScalar = str | int | float | bool ChromaMetadata = Metadata @@ -399,16 +400,20 @@ def __init__( storage: ChromaDBStorage, embedding_generator: Any, codebase_root: Optional[str] = None, + fts_storage: Optional[FtsStorage] = None, ): """Initialize the document indexer. Args: storage: ChromaDB storage embedding_generator: Generator for creating embeddings + codebase_root: Root directory of the codebase + fts_storage: Optional FTS store for lexical retrieval """ self.storage = storage self.embedding_generator = embedding_generator self.codebase_root = codebase_root + self.fts_storage = fts_storage def _prepare_metadata_for_chroma(self, metadata) -> dict: """Convert document metadata to ChromaDB-compatible format. @@ -537,6 +542,17 @@ def index_document(self, document: Document) -> str: embeddings=chunk_embeddings, ) + if self.fts_storage is not None: + self.fts_storage.delete_by_file(doc_id) + self.fts_storage.add_chunks( + chunk_ids=chunk_ids, + file_ids=[doc_id] * len(chunk_ids), + languages=[str(document.metadata.language)] * len(chunk_ids), + start_lines=[chunk.start_line for chunk in document.chunks], + end_lines=[chunk.end_line for chunk in document.chunks], + contents=chunk_contents, + ) + return persisted_doc_id def index_documents(self, documents: List[Document]) -> List[str]: diff --git a/docstra/core/retrieval/context_aware.py b/docstra/core/retrieval/context_aware.py index 1b59817..4fadf2a 100644 --- a/docstra/core/retrieval/context_aware.py +++ b/docstra/core/retrieval/context_aware.py @@ -10,8 +10,7 @@ from docstra.core.indexing.code_index import CodebaseIndex from docstra.core.indexing.repo_map import RepositoryMap -from docstra.core.retrieval.chroma import ChromaRetriever -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.utils.token_counter import ContextBudgetManager @@ -63,7 +62,7 @@ class ContextAwareRetriever: def __init__( self, - base_retriever: ChromaRetriever, + base_retriever: FusionRetriever, budget_manager: ContextBudgetManager, code_index: Optional[CodebaseIndex] = None, repo_map: Optional[RepositoryMap] = None, @@ -72,12 +71,7 @@ def __init__( self.budget_manager = budget_manager self.code_index = code_index self.repo_map = repo_map - - # Create hybrid retriever if code index available - if code_index: - self.hybrid_retriever = HybridRetriever(base_retriever, code_index) - else: - self.hybrid_retriever = None + self.fusion_retriever = base_retriever def retrieve_with_budget( self, query: str, context_type: str = "query", **kwargs: Any @@ -383,15 +377,10 @@ def _get_general_context( ) -> Dict[str, Any]: """Get balanced general context for queries without clear intent.""" - # Use hybrid retrieval if available, otherwise fall back to base retriever - if self.hybrid_retriever: - results = self.hybrid_retriever.retrieve( - query=query, - n_results=10, # Get more initially for filtering - use_code_context=True, - ) - else: - results = self.base_retriever.retrieve_chunks(query=query, n_results=8) + results = self.fusion_retriever.retrieve( + query=query, + n_results=10, # Get more initially for filtering + ) # Budget-aware context assembly context_parts = {} @@ -571,12 +560,9 @@ def _get_targeted_code_samples( ) -> Optional[str]: """Get targeted code samples based on query analysis.""" - if not self.hybrid_retriever: - return None - # Get code examples using hybrid retrieval try: - examples = self.hybrid_retriever.retrieve_code_examples( + examples = self.fusion_retriever.retrieve_code_examples( query=query, n_results=3 ) diff --git a/docstra/core/retrieval/fts.py b/docstra/core/retrieval/fts.py new file mode 100644 index 0000000..527e320 --- /dev/null +++ b/docstra/core/retrieval/fts.py @@ -0,0 +1,33 @@ +"""Retriever that delegates lexical search to FtsStorage.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from docstra.core.ingestion.fts_storage import FtsStorage + + +class FtsRetriever: + """Thin wrapper exposing FtsStorage searches to higher-level retrievers.""" + + def __init__(self, storage: FtsStorage) -> None: + self.storage = storage + + def retrieve_chunks( + self, + query: str, + n_results: int = 50, + *, + language: Optional[str] = None, + file_id: Optional[str] = None, + **_unused_filters: Any, + ) -> List[Dict[str, Any]]: + return self.storage.search_chunks( + query, n_results=n_results, language=language, file_id=file_id + ) + + def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: + return self.storage.search_symbols(query, n_results=n_results) + + def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: + return self.storage.get_chunk(chunk_id) diff --git a/docstra/core/retrieval/fusion.py b/docstra/core/retrieval/fusion.py new file mode 100644 index 0000000..f898a8a --- /dev/null +++ b/docstra/core/retrieval/fusion.py @@ -0,0 +1,249 @@ +"""Reciprocal Rank Fusion over dense + lexical retrieval sources.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Protocol + +from docstra.core.indexing.code_index import CodebaseIndex + + +def rrf_score(rank: int, k: int) -> float: + """Standard RRF contribution for a single source at 1-based rank.""" + return 1.0 / (k + rank) + + +class _DenseLike(Protocol): + def retrieve_chunks( + self, query: str, n_results: int = 20, **filters + ) -> List[Dict[str, Any]]: ... + + +class _FtsLike(Protocol): + def retrieve_chunks( + self, query: str, n_results: int = 50, **filters + ) -> List[Dict[str, Any]]: ... + def retrieve_symbols( + self, query: str, n_results: int = 25 + ) -> List[Dict[str, Any]]: ... + def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: ... + + +class FusionRetriever: + """Runs dense + lexical (chunks, symbols) retrieval and fuses with RRF.""" + + def __init__( + self, + dense: _DenseLike, + fts: _FtsLike, + code_index: CodebaseIndex, + *, + rrf_k: int = 60, + fts_chunks_top_k: int = 50, + fts_symbols_top_k: int = 25, + ) -> None: + self.dense = dense + self.fts = fts + self.code_index = code_index + self.rrf_k = rrf_k + self.fts_chunks_top_k = fts_chunks_top_k + self.fts_symbols_top_k = fts_symbols_top_k + + def retrieve( + self, query: str, n_results: int = 20, **filters + ) -> List[Dict[str, Any]]: + return self.retrieve_chunks(query, n_results=n_results, **filters) + + def retrieve_chunks( + self, query: str, n_results: int = 20, **filters + ) -> List[Dict[str, Any]]: + dense_hits = self.dense.retrieve_chunks( + query, n_results=n_results * 2, **filters + ) + lex_chunk_hits = self.fts.retrieve_chunks( + query, n_results=self.fts_chunks_top_k, **filters + ) + symbol_hits = self.fts.retrieve_symbols(query, n_results=self.fts_symbols_top_k) + symbol_chunk_hits = self._symbols_to_chunks(symbol_hits, filters) + + scored: Dict[str, float] = defaultdict(float) + record: Dict[str, Dict[str, Any]] = {} + for source in (dense_hits, lex_chunk_hits, symbol_chunk_hits): + for rank, hit in enumerate(source, start=1): + chunk_id = self._chunk_id(hit) + if chunk_id is None: + continue + scored[chunk_id] += rrf_score(rank, self.rrf_k) + record.setdefault(chunk_id, hit) + + ordered = sorted(record.items(), key=lambda kv: (-scored[kv[0]], kv[0])) + return [hit for _, hit in ordered[:n_results]] + + def retrieve_by_language( + self, query: str, language: str, n_results: int = 20 + ) -> List[Dict[str, Any]]: + return self.retrieve_chunks(query, n_results=n_results, language=language) + + def retrieve_code_examples( + self, query: str, n_results: int = 10, languages: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: + """Retrieve chunks that are good examples of the queried concept. + + Args: + query: Query string + n_results: Number of results to return + languages: Optional list of languages to filter by + + Returns: + List of example chunks + """ + # Start with basic vector search + filters = {} + if languages: + # We'll retrieve for each language separately + all_results = [] + for language in languages: + results = self.retrieve_by_language( + query=query, + language=language, + n_results=max(n_results // len(languages), 1), + ) + all_results.extend(results) + + vector_results = all_results + else: + vector_results = self.retrieve_chunks( + query=query, + n_results=n_results * 2, # Get more for filtering + **filters, + ) + + # Filter for chunks that are likely to be good examples + # - Prefer complete functions/methods + # - Prefer moderately sized chunks (not too short, not too long) + # - Prefer chunks with meaningful names + good_examples = [] + + for chunk in vector_results: + chunk_type = chunk["metadata"].get("chunk_type", "") + content = chunk["content"] + + # Score the chunk as an example + example_score = 0.0 + + # Prefer functions/methods + if chunk_type in ["function", "method"]: + example_score += 1.0 + + # Check content length (not too short, not too long) + lines = content.count("\n") + 1 + if 5 <= lines <= 50: + example_score += 0.5 + + # Look for meaningful names (more than 3 characters, not generic) + symbols = chunk["metadata"].get("symbols", []) + generic_symbols = [ + "main", + "init", + "test", + "get", + "set", + "run", + "func", + "foo", + "bar", + ] + + for symbol in symbols: + if len(symbol) > 3 and symbol.lower() not in generic_symbols: + example_score += 0.3 + break + + chunk_id = self._chunk_id(chunk) + if example_score > 0: + good_examples.append( + { + "chunk_id": chunk_id, + "id": chunk_id, + "content": content, + "metadata": chunk["metadata"], + "score": example_score, + } + ) + + # Sort by combined score and return top results + sorted_examples = sorted( + good_examples, + key=lambda x: x.get("score", 0), + reverse=True, # Higher score is better + ) + + return sorted_examples[:n_results] + + def _chunk_id(self, hit: Dict[str, Any]) -> Optional[str]: + return hit.get("chunk_id") or hit.get("id") + + def _symbols_to_chunks( + self, symbol_hits: Iterable[Dict[str, Any]], filters: Dict[str, Any] + ) -> List[Dict[str, Any]]: + language = filters.get("language") + file_id_filter = filters.get("file_id") + results: List[Dict[str, Any]] = [] + seen: set[str] = set() + for symbol_hit in symbol_hits: + file_id = symbol_hit.get("file_id") + if file_id is None: + continue + if file_id_filter is not None and file_id != file_id_filter: + continue + if ( + language is not None + and self.code_index.file_language(file_id) != language + ): + continue + symbol_id = symbol_hit.get("symbol_id", "") + line = _extract_line_from_symbol_id(symbol_id) + if line is None: + continue + for chunk_id, start_line, end_line in self.code_index.chunks_for_file( + file_id + ): + if start_line <= line <= end_line and chunk_id not in seen: + seen.add(chunk_id) + chunk_row = self.fts.get_chunk(chunk_id) + content = chunk_row["content"] if chunk_row else "" + lang = chunk_row["language"] if chunk_row else "" + results.append( + { + "chunk_id": chunk_id, + "id": chunk_id, + "file_id": file_id, + "language": lang, + "start_line": start_line, + "end_line": end_line, + "content": content, + "metadata": { + "document_id": file_id, + "filepath": file_id, + "start_line": start_line, + "end_line": end_line, + "language": lang, + "chunk_type": "code", + "via_symbol": symbol_hit.get("name"), + }, + } + ) + break + return results + + +def _extract_line_from_symbol_id(symbol_id: str) -> Optional[int]: + if not symbol_id: + return None + tail = symbol_id.rsplit("::", 1)[-1] + if not tail.startswith("L"): + return None + try: + return int(tail[1:]) + except ValueError: + return None diff --git a/docstra/core/retrieval/hybrid.py b/docstra/core/retrieval/hybrid.py deleted file mode 100644 index 35d80b4..0000000 --- a/docstra/core/retrieval/hybrid.py +++ /dev/null @@ -1,531 +0,0 @@ -# File: ./docstra/core/retrieval/hybrid.py - -""" -Hybrid retrieval strategies for enhanced document search. -""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from docstra.core.indexing.code_index import CodebaseIndex -from docstra.core.retrieval.chroma import ChromaRetriever - - -class HybridRetriever: - """Hybrid retriever combining vector search with structural code information.""" - - def __init__( - self, retriever: ChromaRetriever, code_index: Optional[CodebaseIndex] = None - ): - """Initialize the hybrid retriever. - - Args: - retriever: Base retriever for vector search - code_index: Code index for structural information - """ - self.retriever = retriever - self.code_index = code_index - - def retrieve( - self, query: str, n_results: int = 20, use_code_context: bool = True, **filters - ) -> List[Dict[str, Any]]: - """Perform hybrid retrieval using both vector search and code structure. - - Args: - query: Query string - n_results: Number of results to return - use_code_context: Whether to use code context for re-ranking - **filters: Additional filters to apply - - Returns: - List of matching chunks - """ - # Start with vector search - vector_results = self.retriever.retrieve_chunks( - query=query, - n_results=n_results * 2, # Get more results for reranking - **filters, - ) - - if not use_code_context or not self.code_index: - # If code context not requested or not available, return vector results - return vector_results[:n_results] - - # Extract potential code symbols from query - potential_symbols = self._extract_potential_symbols(query) - - # Find symbols in the code index - symbol_matches = [] - for symbol in potential_symbols: - symbol_locs = self.code_index.search_symbol(symbol) - if symbol_locs: - symbol_matches.extend(symbol_locs) - - # Re-rank results based on symbol matches - reranked_results = self._rerank_with_symbol_matches( - vector_results, symbol_matches, n_results - ) - - return reranked_results - - def _extract_potential_symbols(self, query: str) -> List[str]: - """Extract potential code symbols from a query. - - Args: - query: Query string - - Returns: - List of potential symbols - """ - # This is a simplified approach. A more sophisticated approach would use - # NLP techniques to identify potential code symbols. - - # Split by whitespace and punctuation - words = query.split() - - # Filter out common words and keep only likely symbols - stop_words = { - "the", - "a", - "an", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "as", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "do", - "does", - "did", - "can", - "could", - "will", - "would", - "shall", - "should", - "may", - "might", - "must", - "how", - "when", - "where", - "why", - "what", - "who", - "whom", - "which", - "if", - "then", - "else", - "so", - "such", - "and", - "or", - "not", - "no", - "yes", - "this", - "that", - "these", - "those", - "code", - "function", - "method", - "class", - "variable", - "import", - "implement", - "define", - "declaration", - } - - symbols = [ - word.strip(",.()[]{}:;\"'") - for word in words - if word.strip(",.()[]{}:;\"'").lower() not in stop_words - and len(word.strip(",.()[]{}:;\"'")) >= 2 # Minimum length - ] - - return symbols - - def _rerank_with_symbol_matches( - self, - vector_results: List[Dict[str, Any]], - symbol_matches: List[Dict[str, Any]], - n_results: int, - ) -> List[Dict[str, Any]]: - """Re-rank results based on symbol matches. - - Args: - vector_results: Results from vector search - symbol_matches: Matches from symbol search - n_results: Number of results to return - - Returns: - Re-ranked results - """ - # Create a set of files with symbol matches - symbol_files = {match["filepath"] for match in symbol_matches} - - # Create a dictionary to track scores - result_scores: Dict[str, float] = {} - - # Score based on vector search position - for i, result in enumerate(vector_results): - chunk_id = result["id"] - # Base score from vector search (higher for early results) - score = 1.0 - (i / len(vector_results)) - - # Get document ID from chunk metadata - doc_id = result["metadata"].get("document_id", "") - - # Boost score if document contains symbol matches - if doc_id in symbol_files: - score += 0.5 - - result_scores[chunk_id] = score - - # Sort results by score - reranked_ids = sorted( - result_scores.keys(), key=lambda x: result_scores[x], reverse=True - ) - - # Create final results list - reranked_results = [] - id_to_result = {result["id"]: result for result in vector_results} - - for chunk_id in reranked_ids[:n_results]: - if chunk_id in id_to_result: - reranked_results.append(id_to_result[chunk_id]) - - return reranked_results - - def retrieve_for_function( - self, query: str, function_name: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve chunks relevant to a specific function. - - Args: - query: Query string - function_name: Name of the function - n_results: Number of results to return - - Returns: - List of matching chunks - """ - if not self.code_index: - return self.retriever.retrieve_chunks(query, n_results) - - # Find files containing the function - function_locs = self.code_index.search_function(function_name) - - if not function_locs: - return self.retriever.retrieve_chunks(query, n_results) - - # Get relevant file paths - file_paths = [loc["filepath"] for loc in function_locs] - - # Combine results from each file - all_results = [] - for filepath in file_paths: - results = self.retriever.retrieve_by_filepath( - query=query, filepath=filepath, n_results=n_results - ) - all_results.extend(results) - - # Sort by relevance and return top results - sorted_results = sorted( - all_results, - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] - - def retrieve_for_class( - self, query: str, class_name: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve chunks relevant to a specific class. - - Args: - query: Query string - class_name: Name of the class - n_results: Number of results to return - - Returns: - List of matching chunks - """ - if not self.code_index: - return self.retriever.retrieve_chunks(query, n_results) - - # Find files containing the class - class_locs = self.code_index.search_class(class_name) - - if not class_locs: - return self.retriever.retrieve_chunks(query, n_results) - - # Get relevant file paths - file_paths = [loc["filepath"] for loc in class_locs] - - # Combine results from each file - all_results = [] - for filepath in file_paths: - results = self.retriever.retrieve_by_filepath( - query=query, filepath=filepath, n_results=n_results - ) - all_results.extend(results) - - # Sort by relevance and return top results - sorted_results = sorted( - all_results, - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] - - def retrieve_related_code( - self, query: str, chunk_id: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve code chunks related to a specific chunk. - - Args: - query: Query string - chunk_id: ID of the chunk to find related code for - n_results: Number of results to return - - Returns: - List of related chunks - """ - # Extract document ID from chunk ID - if "#" in chunk_id: - document_id = chunk_id.split("#")[0] - else: - document_id = chunk_id - - # Get all chunks for the document - document_chunks = self.retriever.get_chunks_for_document(document_id) - - if not document_chunks: - return self.retriever.retrieve_chunks(query, n_results) - - # If we have a code index, also get related files - related_files = [] - if self.code_index: - related_files = self.code_index.get_related_files(document_id) - - # Combine results from document chunks and related files - all_results = [] - - # Add document chunks (with high priority) - for chunk in document_chunks: - all_results.append( - { - "id": chunk["id"], - "content": chunk["content"], - "metadata": chunk["metadata"], - "score": 0.0, # High priority - } - ) - - # Get chunks from related files - if related_files: - for filepath in related_files: - results = self.retriever.retrieve_by_filepath( - query=query, filepath=filepath, n_results=n_results - ) - all_results.extend(results) - - # Remove duplicates (keeping highest score) - unique_results = {} - for result in all_results: - result_id = result["id"] - if result_id not in unique_results or ( - result.get("score", float("inf")) - < unique_results[result_id].get("score", float("inf")) - ): - unique_results[result_id] = result - - # Sort by score and return top results - sorted_results = sorted( - unique_results.values(), - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] - - def retrieve_code_examples( - self, query: str, n_results: int = 10, languages: Optional[List[str]] = None - ) -> List[Dict[str, Any]]: - """Retrieve chunks that are good examples of the queried concept. - - Args: - query: Query string - n_results: Number of results to return - languages: Optional list of languages to filter by - - Returns: - List of example chunks - """ - # Start with basic vector search - filters = {} - if languages: - # We'll retrieve for each language separately - all_results = [] - for language in languages: - results = self.retriever.retrieve_by_language( - query=query, - language=language, - n_results=max(n_results // len(languages), 1), - ) - all_results.extend(results) - - vector_results = all_results - else: - vector_results = self.retriever.retrieve_chunks( - query=query, - n_results=n_results * 2, # Get more for filtering - **filters, - ) - - # Filter for chunks that are likely to be good examples - # - Prefer complete functions/methods - # - Prefer moderately sized chunks (not too short, not too long) - # - Prefer chunks with meaningful names - good_examples = [] - - for chunk in vector_results: - chunk_type = chunk["metadata"].get("chunk_type", "") - content = chunk["content"] - - # Score the chunk as an example - example_score = 0.0 - - # Prefer functions/methods - if chunk_type in ["function", "method"]: - example_score += 1.0 - - # Check content length (not too short, not too long) - lines = content.count("\n") + 1 - if 5 <= lines <= 50: - example_score += 0.5 - - # Look for meaningful names (more than 3 characters, not generic) - symbols = chunk["metadata"].get("symbols", []) - generic_symbols = [ - "main", - "init", - "test", - "get", - "set", - "run", - "func", - "foo", - "bar", - ] - - for symbol in symbols: - if len(symbol) > 3 and symbol.lower() not in generic_symbols: - example_score += 0.3 - break - - # Use original vector score - vector_score = chunk.get("score", 0) - if vector_score is not None: - # Combine scores (vector score is typically a distance, so lower is better) - combined_score = example_score - vector_score - else: - combined_score = example_score - - good_examples.append( - { - "id": chunk["id"], - "content": content, - "metadata": chunk["metadata"], - "score": combined_score, - "original_score": vector_score, - } - ) - - # Sort by combined score and return top results - sorted_examples = sorted( - good_examples, - key=lambda x: x.get("score", 0), - reverse=True, # Higher score is better - ) - - return sorted_examples[:n_results] - - def retrieve_implementation_details( - self, query: str, symbol: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve implementation details for a specific symbol. - - Args: - query: Query string - symbol: Symbol to find implementation details for - n_results: Number of results to return - - Returns: - List of chunks with implementation details - """ - if not self.code_index: - # Without code index, fall back to basic search - return self.retriever.retrieve_chunks(query, n_results) - - # Find symbol locations - symbol_locs = self.code_index.search_symbol(symbol) - - if not symbol_locs: - # Try function and class indexes if symbol not found - function_locs = self.code_index.search_function(symbol) - class_locs = self.code_index.search_class(symbol) - - symbol_locs = function_locs + class_locs - - if not symbol_locs: - # If still not found, fall back to basic search - return self.retriever.retrieve_chunks(query, n_results) - - # Get relevant file paths - file_paths = [loc["filepath"] for loc in symbol_locs] - - # Combine results from each file - all_results = [] - for filepath in file_paths: - results = self.retriever.retrieve_by_filepath( - query=( - query if query else symbol - ), # Use symbol as query if no query provided - filepath=filepath, - n_results=n_results, - ) - all_results.extend(results) - - # Sort by relevance and return top results - sorted_results = sorted( - all_results, - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] diff --git a/docstra/core/services/documentation_service.py b/docstra/core/services/documentation_service.py index a94cbd9..9fb8d33 100644 --- a/docstra/core/services/documentation_service.py +++ b/docstra/core/services/documentation_service.py @@ -314,6 +314,8 @@ def generate_documentation( max_workers=effective_max_workers, documentation_depth="comprehensive", style_guide=effective_llm_style_prompt, + persist_directory=abs_persist_directory, + user_config=self.user_config, ) self.console.print( diff --git a/docstra/core/services/ingestion_service.py b/docstra/core/services/ingestion_service.py index 313cfc3..b7fd934 100644 --- a/docstra/core/services/ingestion_service.py +++ b/docstra/core/services/ingestion_service.py @@ -29,6 +29,7 @@ SyntaxAwareChunking, ) from docstra.core.ingestion.embeddings import EmbeddingFactory +from docstra.core.ingestion.fts_storage import FtsStorage from docstra.core.ingestion.storage import ChromaDBStorage, DocumentIndexer from docstra.core.indexing.code_index import CodebaseIndex, CodebaseIndexer from docstra.core.indexing.model import CORE_INDEX_FILENAME @@ -90,6 +91,9 @@ def ingest_codebase( legacy_repo_map = persist_directory / "repo_map.json" if legacy_repo_map.exists(): legacy_repo_map.unlink() + index_db_path = persist_directory / "index.db" + if index_db_path.exists(): + index_db_path.unlink() index_path = persist_directory / "index" core_index_path = index_path / CORE_INDEX_FILENAME @@ -108,6 +112,9 @@ def ingest_codebase( shutil.rmtree(index_path) if legacy_repo_map.exists(): legacy_repo_map.unlink() + index_db_path = persist_directory / "index.db" + if index_db_path.exists(): + index_db_path.unlink() # Check if already indexed and not forcing if core_index_path.exists() and not force: @@ -147,11 +154,13 @@ def ingest_codebase( ) storage = ChromaDBStorage(persist_directory=str(persist_directory / "chroma")) + fts_storage = FtsStorage(str(persist_directory / "index.db")) doc_indexer = DocumentIndexer( storage, embedding_generator, codebase_root=str(codebase_path_abs), + fts_storage=fts_storage, ) code_indexer = CodebaseIndexer( @@ -272,6 +281,9 @@ def ingest_codebase( doc_indexer.index_documents(documents) code_indexer.index_documents(documents) + manifest = code_indexer.get_manifest() + fts_storage.add_symbols(list(manifest.symbols)) + progress.update( task_index, completed=True, description="[green]Indexed all documents" ) diff --git a/docstra/core/services/query_service.py b/docstra/core/services/query_service.py index 7b72551..47fc945 100644 --- a/docstra/core/services/query_service.py +++ b/docstra/core/services/query_service.py @@ -20,8 +20,10 @@ from docstra.core.retrieval.chroma import ChromaRetriever from docstra.core.indexing.code_index import CodebaseIndex, CodebaseIndexer from docstra.core.indexing.model import CORE_INDEX_FILENAME -from docstra.core.retrieval.hybrid import HybridRetriever from docstra.core.retrieval.context_aware import ContextAwareRetriever +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.utils.token_counter import get_token_counter, ContextBudgetManager @@ -94,7 +96,9 @@ def __init__( self.storage: Optional[ChromaDBStorage] = None self.retriever: Optional[ChromaRetriever] = None self.code_indexer: Optional[CodebaseIndexer] = None - self.hybrid_retriever: Optional[HybridRetriever] = None + self.fts_storage: Optional[FtsStorage] = None + self.fts_retriever: Optional[FtsRetriever] = None + self.fusion_retriever: Optional[FusionRetriever] = None self.context_aware_retriever: Optional[ContextAwareRetriever] = None self.token_counter = get_token_counter( self.user_config.model.model_name, self.user_config.model.provider @@ -135,17 +139,32 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): legacy_index_artifacts = CodebaseIndex.legacy_artifacts_in(index_path) legacy_repo_map = effective_persist_dir / "repo_map.json" - if not core_index_path.exists() or not chroma_check_file.exists(): + index_db_path = effective_persist_dir / "index.db" + if ( + not core_index_path.exists() + or not chroma_check_file.exists() + or not index_db_path.exists() + ): migration_hint = "" if legacy_index_artifacts or legacy_repo_map.exists(): migration_hint = ( " Legacy index artifacts were found. Rerun 'docstra ingest' " "to rebuild the index in the new format." ) + if ( + not index_db_path.exists() + and core_index_path.exists() + and chroma_check_file.exists() + ): + migration_hint += ( + " The lexical index (.docstra/index.db) is missing — likely an older " + "ingest. Rerun 'docstra ingest' to rebuild it." + ) error_msg = ( f"Codebase at {abs_codebase_path} not fully initialized for querying. " f"ChromaDB path: {chroma_path} (check file: {chroma_check_file}, exists: {chroma_check_file.exists()}), " - f"Core index path: {core_index_path} (exists: {core_index_path.exists()}). " + f"Core index path: {core_index_path} (exists: {core_index_path.exists()}), " + f"Lexical index: {index_db_path} (exists: {index_db_path.exists()}). " "Run 'docstra init' and 'docstra ingest' first." f"{migration_hint}" ) @@ -166,7 +185,17 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): code_index_instance = self.code_indexer.get_index() if code_index_instance is None: raise ValueError(f"Failed to load code index from {index_path}") - self.hybrid_retriever = HybridRetriever(self.retriever, code_index_instance) + + self.fts_storage = FtsStorage(str(effective_persist_dir / "index.db")) + self.fts_retriever = FtsRetriever(self.fts_storage) + self.fusion_retriever = FusionRetriever( + dense=self.retriever, + fts=self.fts_retriever, + code_index=code_index_instance, + rrf_k=self.user_config.retrieval.rrf_k, + fts_chunks_top_k=self.user_config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=self.user_config.retrieval.fts_symbols_top_k, + ) # Initialize context-aware retriever repo_map = None @@ -181,7 +210,7 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): ) self.context_aware_retriever = ContextAwareRetriever( - base_retriever=self.retriever, + base_retriever=self.fusion_retriever, budget_manager=self.budget_manager, code_index=code_index_instance, repo_map=repo_map, diff --git a/tests/test_core_index.py b/tests/test_core_index.py index 5c90a20..8c3e728 100644 --- a/tests/test_core_index.py +++ b/tests/test_core_index.py @@ -479,3 +479,74 @@ def test_codebase_index_rejects_legacy_sidecars_without_core_manifest( with pytest.raises(FileNotFoundError, match="Rerun 'docstra ingest'"): CodebaseIndex(index_directory=str(index_dir), codebase_root=str(tmp_path)) + + +def test_chunks_for_file_returns_chunks_in_line_order(tmp_path): + from docstra.core.indexing.code_index import CodebaseIndex + from docstra.core.indexing.model import CoreIndexManifest, IndexedChunk + + index = CodebaseIndex(index_directory=str(tmp_path / "index")) + index._manifest = CoreIndexManifest.empty() + index._manifest.chunks.extend( + [ + IndexedChunk( + id="a.py#L1-L10", + file_id="a.py", + language="python", + start_line=1, + end_line=10, + chunk_type="code", + ), + IndexedChunk( + id="a.py#L11-L20", + file_id="a.py", + language="python", + start_line=11, + end_line=20, + chunk_type="code", + ), + IndexedChunk( + id="b.py#L1-L5", + file_id="b.py", + language="python", + start_line=1, + end_line=5, + chunk_type="code", + ), + ] + ) + index._rebuild_lookups() + + assert index.chunks_for_file("a.py") == [ + ("a.py#L1-L10", 1, 10), + ("a.py#L11-L20", 11, 20), + ] + assert index.chunks_for_file("missing.py") == [] + + # Verify sorting is enforced even when chunks are inserted in reverse line order. + index2 = CodebaseIndex(index_directory=str(tmp_path / "index2")) + index2._manifest = CoreIndexManifest.empty() + index2._manifest.chunks.extend( + [ + IndexedChunk( + id="c.py#L100-L110", + file_id="c.py", + language="python", + start_line=100, + end_line=110, + chunk_type="code", + ), + IndexedChunk( + id="c.py#L1-L10", + file_id="c.py", + language="python", + start_line=1, + end_line=10, + chunk_type="code", + ), + ] + ) + index2._rebuild_lookups() + + result = index2.chunks_for_file("c.py") + assert result == [("c.py#L1-L10", 1, 10), ("c.py#L100-L110", 100, 110)] diff --git a/tests/test_fts_retriever.py b/tests/test_fts_retriever.py new file mode 100644 index 0000000..e08ebc3 --- /dev/null +++ b/tests/test_fts_retriever.py @@ -0,0 +1,79 @@ +"""Coverage for the FTS-backed retriever.""" + +from pathlib import Path + +from docstra.core.indexing.model import IndexedSymbol +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever + + +def test_retrieve_chunks_delegates_to_storage(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + store.add_chunks( + chunk_ids=["repo/file.py#L1-L10"], + file_ids=["repo/file.py"], + languages=["python"], + start_lines=[1], + end_lines=[10], + contents=["def make_chunk_id(): pass"], + ) + retriever = FtsRetriever(store) + hits = retriever.retrieve_chunks("make_chunk_id", n_results=5) + assert len(hits) == 1 + assert hits[0]["chunk_id"] == "repo/file.py#L1-L10" + assert hits[0]["metadata"]["document_id"] == hits[0]["file_id"] + + +def test_retrieve_symbols_delegates_to_storage(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols( + [ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ) + ] + ) + retriever = FtsRetriever(store) + hits = retriever.retrieve_symbols("foo", n_results=5) + assert len(hits) == 1 + assert hits[0]["name"] == "foo" + assert hits[0]["metadata"]["document_id"] == "x.py" + + +def test_get_chunk_returns_row_and_none(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + store.add_chunks( + chunk_ids=["repo/a.py#L1-L5"], + file_ids=["repo/a.py"], + languages=["python"], + start_lines=[1], + end_lines=[5], + contents=["def bar(): pass"], + ) + retriever = FtsRetriever(store) + row = retriever.get_chunk("repo/a.py#L1-L5") + assert row is not None + assert row["chunk_id"] == "repo/a.py#L1-L5" + assert retriever.get_chunk("missing#L1-L1") is None + + +def test_retrieve_chunks_ignores_unknown_filter_keys(tmp_path: Path): + """Unknown filter keys must not crash; FtsRetriever silently ignores them like Chroma.""" + store = FtsStorage(str(tmp_path / "index.db")) + store.add_chunks( + chunk_ids=["a.py#L1-L1"], + file_ids=["a.py"], + languages=["python"], + start_lines=[1], + end_lines=[1], + contents=["def foo(): pass"], + ) + retriever = FtsRetriever(store) + # document_id is a Chroma-style key not supported by FTS — must not raise. + hits = retriever.retrieve_chunks("foo", n_results=5, document_id="ignored") + assert len(hits) == 1 diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py new file mode 100644 index 0000000..8f25a52 --- /dev/null +++ b/tests/test_fts_storage.py @@ -0,0 +1,229 @@ +"""Coverage for the FTS5-backed lexical store.""" + +from pathlib import Path + +from docstra.core.ingestion.fts_storage import FtsStorage, _sanitize_fts_query + + +def _add_chunk(store: FtsStorage, **overrides): + payload = dict( + chunk_id="repo/file.py#L1-L10", + file_id="repo/file.py", + language="python", + start_line=1, + end_line=10, + content="def make_chunk_id(file_id, start_line, end_line):\n return f'{file_id}#L{start_line}-L{end_line}'", + ) + payload.update(overrides) + store.add_chunks( + chunk_ids=[payload["chunk_id"]], + file_ids=[payload["file_id"]], + languages=[payload["language"]], + start_lines=[payload["start_line"]], + end_lines=[payload["end_line"]], + contents=[payload["content"]], + ) + return payload + + +def test_schema_creates_on_first_open(tmp_path: Path): + db_path = tmp_path / "index.db" + FtsStorage(str(db_path)) + assert db_path.exists() + # Idempotent: opening again does not raise. + FtsStorage(str(db_path)) + + +def test_add_and_search_chunks(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + hits = store.search_chunks("make_chunk_id", n_results=5) + assert len(hits) == 1 + assert hits[0]["chunk_id"] == "repo/file.py#L1-L10" + assert hits[0]["file_id"] == "repo/file.py" + assert hits[0]["start_line"] == 1 + assert hits[0]["end_line"] == 10 + assert "make_chunk_id" in hits[0]["content"] + # Shape contract: metadata sub-dict must be present and document_id must match file_id. + assert hits[0]["metadata"]["document_id"] == hits[0]["file_id"] + + +def test_delete_by_file_removes_chunks_and_fts(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + store.delete_by_file("repo/file.py") + assert store.search_chunks("make_chunk_id", n_results=5) == [] + + +def test_search_supports_language_filter(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk( + store, + chunk_id="a.py#L1-L1", + file_id="a.py", + language="python", + content="foo bar", + ) + _add_chunk( + store, + chunk_id="b.ts#L1-L1", + file_id="b.ts", + language="typescript", + content="foo bar", + ) + hits = store.search_chunks("foo", n_results=5, language="python") + assert {h["file_id"] for h in hits} == {"a.py"} + + +def test_add_and_search_symbols(tmp_path: Path): + from docstra.core.indexing.model import IndexedSymbol + + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols( + [ + IndexedSymbol( + id="repo/file.py::function::make_chunk_id::L67", + file_id="repo/file.py", + name="make_chunk_id", + kind="function", + language="python", + line=67, + ), + IndexedSymbol( + id="repo/other.py::class::CoreIndexBuilder::L194", + file_id="repo/other.py", + name="CoreIndexBuilder", + kind="class", + language="python", + line=194, + ), + ] + ) + hits = store.search_symbols("CoreIndexBuilder", n_results=5) + assert len(hits) == 1 + assert hits[0]["name"] == "CoreIndexBuilder" + assert hits[0]["file_id"] == "repo/other.py" + assert hits[0]["id"] == hits[0]["symbol_id"] + assert hits[0]["metadata"]["document_id"] == hits[0]["file_id"] + assert hits[0]["metadata"]["name"] == "CoreIndexBuilder" + + +def test_delete_by_file_removes_symbols(tmp_path: Path): + from docstra.core.indexing.model import IndexedSymbol + + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols( + [ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ) + ] + ) + store.delete_by_file("x.py") + assert store.search_symbols("foo", n_results=5) == [] + + +def test_sanitize_fts_query_strips_special_chars(): + assert _sanitize_fts_query("What is the config?") == "what is the config" + assert _sanitize_fts_query("(foo OR bar)") == "foo or bar" + assert _sanitize_fts_query('search "exact phrase"') == "search exact phrase" + assert _sanitize_fts_query("a*b") == "a b" + assert _sanitize_fts_query("") == "" + assert _sanitize_fts_query("???") == "" + + +def test_punctuation_query_does_not_raise(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + # FTS5 would raise sqlite3.OperationalError without sanitization. + hits = store.search_chunks("What is the config?", n_results=5) + assert isinstance(hits, list) + + +def test_empty_query_returns_empty_list(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + assert store.search_chunks("???", n_results=5) == [] + assert store.search_symbols("???", n_results=5) == [] + + +def test_get_chunk_returns_row(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + payload = _add_chunk(store) + row = store.get_chunk(payload["chunk_id"]) + assert row is not None + assert row["chunk_id"] == payload["chunk_id"] + assert row["content"] == payload["content"] + + +def test_get_chunk_returns_none_for_missing(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + assert store.get_chunk("nonexistent#L1-L1") is None + + +def test_query_with_reserved_words_does_not_raise(tmp_path: Path): + """A natural-language query starting with NOT/AND/OR must not crash FTS5.""" + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store, content="def handle_not_null(value): pass") + # Each of these would otherwise be parsed as a FTS5 boolean expression. + for query in ("NOT null handling", "OR operator", "AND something", "foo AND"): + # Must not raise; result can be empty or have hits — we only care about no crash. + store.search_chunks(query, n_results=5) + + +def test_search_chunks_safe_across_threads(tmp_path: Path) -> None: + """The FtsStorage connection must be usable from worker threads.""" + from concurrent.futures import ThreadPoolExecutor + + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + + with ThreadPoolExecutor(max_workers=4) as pool: + results = list( + pool.map(lambda _: store.search_chunks("make_chunk_id", 5), range(8)) + ) + + assert all(len(hits) == 1 for hits in results) + + +def test_add_symbols_is_idempotent(tmp_path: Path): + """Calling add_symbols twice with the same input must not duplicate rows.""" + from docstra.core.indexing.model import IndexedSymbol + + store = FtsStorage(str(tmp_path / "index.db")) + symbols = [ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ), + IndexedSymbol( + id="x.py::function::bar::L10", + file_id="x.py", + name="bar", + kind="function", + language="python", + line=10, + ), + ] + store.add_symbols(symbols) + store.add_symbols(symbols) # second call must not duplicate + + foo_hits = store.search_symbols("foo", n_results=10) + bar_hits = store.search_symbols("bar", n_results=10) + assert len(foo_hits) == 1 + assert len(bar_hits) == 1 + + +def test_add_symbols_empty_list_is_no_op(tmp_path: Path): + """An empty symbol list should be a clean no-op (no IN () syntax error).""" + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols([]) # must not raise diff --git a/tests/test_fusion_retriever.py b/tests/test_fusion_retriever.py new file mode 100644 index 0000000..cc015ef --- /dev/null +++ b/tests/test_fusion_retriever.py @@ -0,0 +1,266 @@ +"""Coverage for RRF fusion of dense + lexical retrieval.""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from docstra.core.retrieval.fusion import FusionRetriever, rrf_score + + +class _FakeDense: + def __init__(self, results: List[Dict[str, Any]]): + self._results = results + + def retrieve_chunks(self, query: str, n_results: int = 20, **filters): + return self._results[:n_results] + + +class _FakeFts: + def __init__(self, chunks: List[Dict[str, Any]], symbols: List[Dict[str, Any]]): + self._chunks = chunks + self._symbols = symbols + + def retrieve_chunks(self, query: str, n_results: int = 50, **filters): + return self._chunks[:n_results] + + def retrieve_symbols(self, query: str, n_results: int = 25): + return self._symbols[:n_results] + + def get_chunk(self, chunk_id: str): + return None + + +def _chunk(chunk_id: str, file_id: str, start_line: int = 1, end_line: int = 10): + return { + "chunk_id": chunk_id, + "id": chunk_id, + "file_id": file_id, + "language": "python", + "start_line": start_line, + "end_line": end_line, + "content": f"# chunk {chunk_id}", + "metadata": { + "document_id": file_id, + "start_line": start_line, + "end_line": end_line, + }, + } + + +def _fake_code_index(chunks_by_file): + def chunks_for_file(file_id): + return chunks_by_file.get(file_id, []) + + return SimpleNamespace( + chunks_for_file=chunks_for_file, file_language=lambda fid: None + ) + + +def test_rrf_score_known_inputs(): + assert rrf_score(rank=1, k=60) == pytest.approx(1 / 61) + assert rrf_score(rank=5, k=60) == pytest.approx(1 / 65) + + +def test_fusion_orders_by_combined_rank(): + a = _chunk("repo/a.py#L1-L10", "repo/a.py") + b = _chunk("repo/b.py#L1-L10", "repo/b.py") + c = _chunk("repo/c.py#L1-L10", "repo/c.py") + + # dense: [a(rank1), b(rank2), c(rank3)] + # fts chunks: [b(rank1), c(rank2), a(rank3)] + # a: 1/61 + 1/63, b: 1/61 + 1/62, c: 1/63 + 1/62 + # b has the highest combined score, a and c tie below it + dense = _FakeDense([a, b, c]) + fts = _FakeFts(chunks=[b, c, a], symbols=[]) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=_fake_code_index({}), + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("anything", n_results=3) + ids = [h["chunk_id"] for h in hits] + assert ids[0] == "repo/b.py#L1-L10" + assert set(ids[1:]) == {"repo/a.py#L1-L10", "repo/c.py#L1-L10"} + + +def test_symbol_hit_promotes_containing_chunk(): + a = _chunk("repo/a.py#L1-L10", "repo/a.py", start_line=1, end_line=10) + b = _chunk("repo/b.py#L1-L10", "repo/b.py") + + dense = _FakeDense([b, a]) + fts = _FakeFts( + chunks=[], + symbols=[ + { + "symbol_id": "repo/a.py::function::foo::L5", + "file_id": "repo/a.py", + "name": "foo", + "kind": "function", + } + ], + ) + + code_index = _fake_code_index({"repo/a.py": [("repo/a.py#L1-L10", 1, 10)]}) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=code_index, + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("foo", n_results=2) + assert hits[0]["chunk_id"] == "repo/a.py#L1-L10" + + +def test_symbol_path_respects_language_filter(): + """Symbol-derived chunks must obey the language filter just like dense/lex chunks.""" + py_chunk = _chunk("repo/a.py#L1-L10", "repo/a.py", start_line=1, end_line=10) + + dense = _FakeDense([py_chunk]) + fts = _FakeFts( + chunks=[], + symbols=[ + { + "symbol_id": "repo/a.py::function::foo::L5", + "file_id": "repo/a.py", + "name": "foo", + "kind": "function", + }, + { + "symbol_id": "repo/b.ts::function::foo::L3", + "file_id": "repo/b.ts", + "name": "foo", + "kind": "function", + }, + ], + ) + + code_index = SimpleNamespace( + chunks_for_file=lambda fid: { + "repo/a.py": [("repo/a.py#L1-L10", 1, 10)], + "repo/b.ts": [("repo/b.ts#L1-L10", 1, 10)], + }.get(fid, []), + file_language=lambda fid: { + "repo/a.py": "python", + "repo/b.ts": "typescript", + }.get(fid), + ) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=code_index, + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + + hits = fusion.retrieve_chunks("foo", n_results=5, language="python") + ids = [h["chunk_id"] for h in hits] + assert "repo/b.ts#L1-L10" not in ids + assert "repo/a.py#L1-L10" in ids + + +def test_empty_lexical_source_does_not_break_fusion(): + a = _chunk("repo/a.py#L1-L10", "repo/a.py") + dense = _FakeDense([a]) + fts = _FakeFts(chunks=[], symbols=[]) + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=_fake_code_index({}), + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("anything", n_results=5) + assert [h["chunk_id"] for h in hits] == ["repo/a.py#L1-L10"] + + +def test_symbol_derived_chunk_carries_content(tmp_path): + """A symbol hit that promotes a chunk must surface the chunk's real content.""" + from docstra.core.ingestion.fts_storage import FtsStorage + from docstra.core.retrieval.fts import FtsRetriever + + fts_storage = FtsStorage(str(tmp_path / "index.db")) + fts_storage.add_chunks( + chunk_ids=["repo/a.py#L1-L10"], + file_ids=["repo/a.py"], + languages=["python"], + start_lines=[1], + end_lines=[10], + contents=["def foo():\n return 42\n"], + ) + fts = FtsRetriever(fts_storage) + + dense = _FakeDense([]) + + class _SymbolOnlyFts: + def __init__(self, real): + self._real = real + + def retrieve_chunks(self, *a, **kw): + return [] + + def retrieve_symbols(self, *a, **kw): + return [ + { + "symbol_id": "repo/a.py::function::foo::L2", + "file_id": "repo/a.py", + "name": "foo", + "kind": "function", + } + ] + + def get_chunk(self, chunk_id): + return self._real.get_chunk(chunk_id) + + code_index = SimpleNamespace( + chunks_for_file=lambda fid: ( + [("repo/a.py#L1-L10", 1, 10)] if fid == "repo/a.py" else [] + ), + file_language=lambda fid: "python" if fid == "repo/a.py" else None, + ) + + fusion = FusionRetriever( + dense=dense, + fts=_SymbolOnlyFts(fts), + code_index=code_index, + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("foo", n_results=5) + assert len(hits) == 1 + assert "def foo()" in hits[0]["content"] + assert hits[0]["metadata"]["via_symbol"] == "foo" + + +def test_retrieve_code_examples_prefers_function_chunks(): + func_chunk = _chunk("repo/a.py#L1-L20", "repo/a.py") + func_chunk["metadata"]["chunk_type"] = "function" + func_chunk["content"] = "def good():\n" + " pass\n" * 10 # 11 lines + tiny_chunk = _chunk("repo/b.py#L1-L2", "repo/b.py") + tiny_chunk["metadata"]["chunk_type"] = "other" + tiny_chunk["content"] = "x = 1\n" + + dense = _FakeDense([tiny_chunk, func_chunk]) + fts = _FakeFts(chunks=[], symbols=[]) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=_fake_code_index({}), + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + examples = fusion.retrieve_code_examples("anything", n_results=2) + assert examples[0]["chunk_id"] == "repo/a.py#L1-L20" diff --git a/tests/test_index_loading.py b/tests/test_index_loading.py index 082ae71..9f6c333 100644 --- a/tests/test_index_loading.py +++ b/tests/test_index_loading.py @@ -10,6 +10,7 @@ DocumentType, ) from docstra.core.indexing.model import CORE_INDEX_FILENAME, CoreIndexBuilder +from docstra.core.ingestion.fts_storage import FtsStorage from docstra.core.ingestion.storage import ChromaDBStorage from docstra.core.services.query_service import QueryService from docstra.core.services.repository_explorer_service import RepositoryExplorerService @@ -62,6 +63,7 @@ def _write_core_index(codebase_root: Path) -> None: manifest.model_dump_json(indent=2), encoding="utf-8" ) ChromaDBStorage(persist_directory=str(persist_dir / "chroma")) + FtsStorage(str(persist_dir / "index.db")) def test_query_service_initializes_from_core_index_without_repo_map( diff --git a/tests/test_ingestion_force.py b/tests/test_ingestion_force.py new file mode 100644 index 0000000..66d795f --- /dev/null +++ b/tests/test_ingestion_force.py @@ -0,0 +1,29 @@ +"""Verify that a forced reindex clears the legacy FTS database file.""" + +from pathlib import Path + +from docstra.core.config.settings import UserConfig +from docstra.core.services.ingestion_service import IngestionService + + +def test_force_reindex_removes_index_db(tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / "file.py").write_text("def hello(): return 1\n") + + persist_dir = repo / ".docstra" + persist_dir.mkdir() + stale_db = persist_dir / "index.db" + stale_db.write_bytes(b"stale-bytes-not-a-real-sqlite-file") + assert stale_db.exists() + + config = UserConfig() + config.storage.persist_directory = str(persist_dir) + service = IngestionService() + service.ingest_codebase(str(repo), config, force=True) + + # The forced reindex should have removed the stale FTS DB before rebuilding; + # the new one created by ingestion will be a valid SQLite file. + assert stale_db.exists(), "expected ingestion to recreate the FTS DB" + # And it should not be the stale bytes we put down. + assert stale_db.read_bytes()[:16] != b"stale-bytes-not-" diff --git a/tests/test_ingestion_fts.py b/tests/test_ingestion_fts.py new file mode 100644 index 0000000..207ba37 --- /dev/null +++ b/tests/test_ingestion_fts.py @@ -0,0 +1,79 @@ +"""Verify ingestion writes chunks and symbols into the FTS store alongside Chroma.""" + +from pathlib import Path +from typing import List + +from docstra.core.document_processing.document import ( + CodeChunk, + Document, + DocumentMetadata, + DocumentType, +) +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.ingestion.storage import ChromaDBStorage, DocumentIndexer + + +class _FakeEmbedder: + """Returns a deterministic small vector so Chroma is happy without a real model.""" + + def generate_embedding(self, _text: str) -> List[float]: + return [0.0, 1.0, 0.0] + + +def _make_document(filepath: str, chunk_content: str) -> Document: + metadata = DocumentMetadata( + filepath=filepath, + language=DocumentType.PYTHON, + size_bytes=len(chunk_content.encode("utf-8")), + last_modified=0.0, + line_count=chunk_content.count("\n") + 1, + ) + return Document( + content=chunk_content, + metadata=metadata, + chunks=[ + CodeChunk( + content=chunk_content, + start_line=1, + end_line=metadata.line_count, + chunk_type="function", + ) + ], + ) + + +def test_document_indexer_writes_chunks_to_fts(tmp_path: Path) -> None: + chroma = ChromaDBStorage(persist_directory=str(tmp_path / "chroma")) + fts = FtsStorage(str(tmp_path / "index.db")) + + document = _make_document("repo/foo.py", "def find_me(): pass") + + indexer = DocumentIndexer( + chroma, + embedding_generator=_FakeEmbedder(), + codebase_root=str(tmp_path), + fts_storage=fts, + ) + indexer.index_document(document) + + hits = fts.search_chunks("find_me", n_results=5) + assert len(hits) == 1 + assert hits[0]["file_id"] == "repo/foo.py" + + +def test_document_indexer_reindex_replaces_chunks(tmp_path: Path) -> None: + """Re-indexing the same file should not duplicate chunks in the FTS store.""" + chroma = ChromaDBStorage(persist_directory=str(tmp_path / "chroma")) + fts = FtsStorage(str(tmp_path / "index.db")) + indexer = DocumentIndexer( + chroma, + embedding_generator=_FakeEmbedder(), + codebase_root=str(tmp_path), + fts_storage=fts, + ) + + indexer.index_document(_make_document("repo/foo.py", "def find_me(): pass")) + indexer.index_document(_make_document("repo/foo.py", "def find_me(): pass")) + + hits = fts.search_chunks("find_me", n_results=5) + assert len(hits) == 1 diff --git a/tests/test_retrieval_config.py b/tests/test_retrieval_config.py new file mode 100644 index 0000000..2c50424 --- /dev/null +++ b/tests/test_retrieval_config.py @@ -0,0 +1,16 @@ +"""Settings coverage for the retrieval config block.""" + +from docstra.core.config.settings import RetrievalConfig, UserConfig + + +def test_retrieval_config_defaults(): + config = RetrievalConfig() + assert config.rrf_k == 60 + assert config.fts_chunks_top_k == 50 + assert config.fts_symbols_top_k == 25 + + +def test_user_config_exposes_retrieval(): + user_config = UserConfig() + assert isinstance(user_config.retrieval, RetrievalConfig) + assert user_config.retrieval.rrf_k == 60 diff --git a/tests/test_retrieval_evaluation.py b/tests/test_retrieval_evaluation.py index 83adfc3..aad3fb6 100644 --- a/tests/test_retrieval_evaluation.py +++ b/tests/test_retrieval_evaluation.py @@ -1,9 +1,13 @@ from __future__ import annotations +import hashlib import json +import struct import sys from importlib import util from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List MODULE_PATH = ( Path(__file__).resolve().parents[1] @@ -154,3 +158,249 @@ def retrieve(question: str, candidate_k: int): assert summary.results[0].passed is True assert summary.results[0].rank == 2 assert summary.results[0].retrieved_files == ["a.py", "b.py"] + + +# --------------------------------------------------------------------------- +# Integration: Chroma-only vs. FusionRetriever side-by-side eval +# --------------------------------------------------------------------------- + +_REPO_ROOT = Path(__file__).resolve().parents[1] + +# Subset of real source files that map to the eval queries in evaluation.py. +# Keep the list small so the test stays fast (no network, local embeddings only). +_CORPUS_FILES: List[Dict[str, Any]] = [ + { + "path": "docstra/core/config/settings.py", + "abs": _REPO_ROOT / "docstra/core/config/settings.py", + }, + { + "path": "docstra/core/retrieval/chroma.py", + "abs": _REPO_ROOT / "docstra/core/retrieval/chroma.py", + }, + { + "path": "docstra/core/retrieval/context_aware.py", + "abs": _REPO_ROOT / "docstra/core/retrieval/context_aware.py", + }, + { + "path": "docstra/core/ingestion/storage.py", + "abs": _REPO_ROOT / "docstra/core/ingestion/storage.py", + }, + { + "path": "docstra/core/services/ingestion_service.py", + "abs": _REPO_ROOT / "docstra/core/services/ingestion_service.py", + }, + { + "path": "docstra/core/indexing/code_index.py", + "abs": _REPO_ROOT / "docstra/core/indexing/code_index.py", + }, + { + "path": "docstra/core/cli.py", + "abs": _REPO_ROOT / "docstra/core/cli.py", + }, + { + "path": "docstra/core/utils/token_counter.py", + "abs": _REPO_ROOT / "docstra/core/utils/token_counter.py", + }, + { + "path": "docstra/core/documentation/generator.py", + "abs": _REPO_ROOT / "docstra/core/documentation/generator.py", + }, +] + +_EVAL_CASES = [ + RetrievalEvalCase( + question="How does Docstra load and save user configuration?", + expected_files=["docstra/core/config/settings.py"], + ), + RetrievalEvalCase( + question="Where are files ingested into ChromaDB with embeddings?", + expected_files=[ + "docstra/core/services/ingestion_service.py", + "docstra/core/ingestion/storage.py", + ], + ), + RetrievalEvalCase( + question="How does Chroma retrieval search document chunks?", + expected_files=["docstra/core/retrieval/chroma.py"], + ), + RetrievalEvalCase( + question="Where does Docstra classify query intent for context retrieval?", + expected_files=["docstra/core/retrieval/context_aware.py"], + ), + RetrievalEvalCase( + question="How are code symbols indexed for later search?", + expected_files=["docstra/core/indexing/code_index.py"], + ), + RetrievalEvalCase( + question="Where is the docstra query CLI command implemented?", + expected_files=["docstra/core/cli.py"], + ), + RetrievalEvalCase( + question="How is context token budget calculated and enforced?", + expected_files=["docstra/core/utils/token_counter.py"], + ), + RetrievalEvalCase( + question="Where does documentation generation assemble code context?", + expected_files=["docstra/core/documentation/generator.py"], + ), +] + + +class _DummyEmbedder: + """Returns a deterministic 8-d embedding so Chroma is happy without a real model.""" + + def generate_embedding(self, text: str) -> List[float]: + digest = hashlib.sha256(text.encode("utf-8")).digest() + vals = struct.unpack(">8I", digest[:32]) + return [v / 0xFFFFFFFF for v in vals] + + +def _split_into_chunks(content: str, chunk_size: int = 40) -> List[Dict[str, Any]]: + """Split file content into fixed-size line chunks for the eval corpus.""" + lines = content.splitlines(keepends=True) + chunks = [] + for start in range(0, len(lines), chunk_size): + end = min(start + chunk_size, len(lines)) + chunk_text = "".join(lines[start:end]) + chunks.append( + { + "start_line": start + 1, + "end_line": end, + "content": chunk_text, + } + ) + return chunks + + +def test_chroma_vs_fusion_retrieval_eval(tmp_path: Path) -> None: + """Build a small real corpus and compare Chroma-only vs. FusionRetriever recall.""" + import pytest + + # Skip if any corpus file is missing (e.g., running from a partial checkout). + missing = [f["path"] for f in _CORPUS_FILES if not f["abs"].exists()] + if missing: + pytest.skip(f"corpus files not found: {missing}") + + from docstra.core.ingestion.fts_storage import FtsStorage + from docstra.core.ingestion.storage import ChromaDBStorage + from docstra.core.retrieval.chroma import ChromaRetriever + from docstra.core.retrieval.fts import FtsRetriever + from docstra.core.retrieval.fusion import FusionRetriever + + chroma_storage = ChromaDBStorage(persist_directory=str(tmp_path / "chroma")) + fts_storage = FtsStorage(str(tmp_path / "index.db")) + embedder = _DummyEmbedder() + + # Registry for the minimal in-memory CodebaseIndex substitute. + # Maps file_id -> [(chunk_id, start_line, end_line)] + chunks_by_file: Dict[str, List] = {} + file_language: Dict[str, str] = {} + + for corpus_entry in _CORPUS_FILES: + file_id: str = corpus_entry["path"] + abs_path: Path = corpus_entry["abs"] + content = abs_path.read_text(encoding="utf-8", errors="replace") + file_chunks = _split_into_chunks(content) + language = "python" + + chunk_ids = [] + chunk_contents = [] + chunk_metadatas = [] + chunk_embeddings = [] + fts_chunk_ids = [] + fts_start_lines = [] + fts_end_lines = [] + fts_contents = [] + + file_chunk_tuples = [] + for fc in file_chunks: + chunk_id = f"{file_id}#L{fc['start_line']}-L{fc['end_line']}" + embedding = embedder.generate_embedding(fc["content"]) + + chunk_ids.append(chunk_id) + chunk_contents.append(fc["content"]) + chunk_metadatas.append( + { + "document_id": file_id, + "start_line": fc["start_line"], + "end_line": fc["end_line"], + "language": language, + "chunk_id": chunk_id, + } + ) + chunk_embeddings.append(embedding) + + fts_chunk_ids.append(chunk_id) + fts_start_lines.append(fc["start_line"]) + fts_end_lines.append(fc["end_line"]) + fts_contents.append(fc["content"]) + + file_chunk_tuples.append((chunk_id, fc["start_line"], fc["end_line"])) + + chroma_storage.add_chunks( + chunk_ids=chunk_ids, + contents=chunk_contents, + metadatas=chunk_metadatas, + embeddings=chunk_embeddings, + ) + fts_storage.add_chunks( + chunk_ids=fts_chunk_ids, + file_ids=[file_id] * len(fts_chunk_ids), + languages=[language] * len(fts_chunk_ids), + start_lines=fts_start_lines, + end_lines=fts_end_lines, + contents=fts_contents, + ) + chunks_by_file[file_id] = file_chunk_tuples + file_language[file_id] = language + + # Minimal CodebaseIndex substitute: only needs chunks_for_file + file_language. + code_index = SimpleNamespace( + chunks_for_file=lambda fid: chunks_by_file.get(fid, []), + file_language=lambda fid: file_language.get(fid), + ) + + chroma_retriever = ChromaRetriever(chroma_storage, embedder) + fts_retriever = FtsRetriever(fts_storage) + + fusion_retriever = FusionRetriever( + dense=chroma_retriever, + fts=fts_retriever, + code_index=code_index, + ) + + top_k = 5 + + def chroma_retrieve(question: str, n: int) -> List[Dict[str, Any]]: + return chroma_retriever.retrieve_chunks(question, n_results=n) + + def fusion_retrieve(question: str, n: int) -> List[Dict[str, Any]]: + return fusion_retriever.retrieve_chunks(question, n_results=n) + + chroma_summary = evaluate_retrieval_cases(_EVAL_CASES, chroma_retrieve, top_k=top_k) + fusion_summary = evaluate_retrieval_cases(_EVAL_CASES, fusion_retrieve, top_k=top_k) + + # --- Print side-by-side table (visible with pytest -s) --- + print(f"\n{'':=<72}") + print(f"Retrieval eval top_k={top_k} corpus={len(_CORPUS_FILES)} files") + print(f"{'':=<72}") + header = f"{'Query':<50} {'Chroma':>7} {'Fusion':>7}" + print(header) + print(f"{'':-<72}") + for cr, fr in zip(chroma_summary.results, fusion_summary.results): + chroma_rank = str(cr.rank) if cr.rank else "-" + fusion_rank = str(fr.rank) if fr.rank else "-" + q = cr.case.question[:48] + print(f"{q:<50} {chroma_rank:>7} {fusion_rank:>7}") + print(f"{'':-<72}") + print( + f"{'recall@k':<50} {chroma_summary.recall_at_k:>7.2f} {fusion_summary.recall_at_k:>7.2f}" + ) + print( + f"{'passed':<50} {chroma_summary.passed_count:>7} {fusion_summary.passed_count:>7}" + ) + print(f"{'':=<72}") + + # Sanity: scores must be valid floats, not errors. + assert 0.0 <= chroma_summary.recall_at_k <= 1.0 + assert 0.0 <= fusion_summary.recall_at_k <= 1.0