From 2c3c688a75e6760de1501957ad9febaee60fd7f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Thu, 11 Jun 2026 15:36:10 +0200 Subject: [PATCH 01/20] feat(config): add RetrievalConfig with rrf_k and fts top-k knobs --- docstra/core/config/settings.py | 27 +++++++++++++++++++++++++++ tests/test_retrieval_config.py | 16 ++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 tests/test_retrieval_config.py diff --git a/docstra/core/config/settings.py b/docstra/core/config/settings.py index 43123c8..b991a3b 100644 --- a/docstra/core/config/settings.py +++ b/docstra/core/config/settings.py @@ -120,6 +120,18 @@ def __init__( self.exclude_patterns = exclude_patterns or [] +class RetrievalConfig: + def __init__( + self, + rrf_k: int = 60, + fts_chunks_top_k: int = 50, + fts_symbols_top_k: int = 25, + ) -> None: + self.rrf_k = rrf_k + self.fts_chunks_top_k = fts_chunks_top_k + self.fts_symbols_top_k = fts_symbols_top_k + + class ConfigManager: def __init__(self, config_path: Optional[str] = None) -> None: self.config_path = config_path or "./.docstra/config.yaml" @@ -180,6 +192,7 @@ def __init__(self) -> None: self.processing = ProcessingConfig() self.ingestion = IngestionConfig() self.documentation = DocumentationConfig() + self.retrieval = RetrievalConfig() def save_to_file(self, path: str) -> None: """Save configuration to YAML file.""" @@ -219,6 +232,11 @@ def save_to_file(self, path: str) -> None: "exclude_patterns": self.ingestion.exclude_patterns, }, "documentation": self.documentation.model_dump(), + "retrieval": { + "rrf_k": self.retrieval.rrf_k, + "fts_chunks_top_k": self.retrieval.fts_chunks_top_k, + "fts_symbols_top_k": self.retrieval.fts_symbols_top_k, + }, } # Write to YAML file @@ -283,3 +301,12 @@ def load_from_file(self, path: str) -> None: self.processing.chunk_overlap = processing_data["chunk_overlap"] if "exclude_patterns" in processing_data: self.processing.exclude_patterns = processing_data["exclude_patterns"] + + if "retrieval" in config_dict: + retrieval_data = config_dict["retrieval"] + if "rrf_k" in retrieval_data: + self.retrieval.rrf_k = retrieval_data["rrf_k"] + if "fts_chunks_top_k" in retrieval_data: + self.retrieval.fts_chunks_top_k = retrieval_data["fts_chunks_top_k"] + if "fts_symbols_top_k" in retrieval_data: + self.retrieval.fts_symbols_top_k = retrieval_data["fts_symbols_top_k"] diff --git a/tests/test_retrieval_config.py b/tests/test_retrieval_config.py new file mode 100644 index 0000000..2c50424 --- /dev/null +++ b/tests/test_retrieval_config.py @@ -0,0 +1,16 @@ +"""Settings coverage for the retrieval config block.""" + +from docstra.core.config.settings import RetrievalConfig, UserConfig + + +def test_retrieval_config_defaults(): + config = RetrievalConfig() + assert config.rrf_k == 60 + assert config.fts_chunks_top_k == 50 + assert config.fts_symbols_top_k == 25 + + +def test_user_config_exposes_retrieval(): + user_config = UserConfig() + assert isinstance(user_config.retrieval, RetrievalConfig) + assert user_config.retrieval.rrf_k == 60 From 016967d8d08df9a7502c8c612d84dec4c2ec9e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Thu, 11 Jun 2026 15:46:49 +0200 Subject: [PATCH 02/20] feat(fts): add FtsStorage with chunks_fts schema and CRUD --- docstra/core/ingestion/fts_storage.py | 139 ++++++++++++++++++++++++++ tests/test_fts_storage.py | 63 ++++++++++++ 2 files changed, 202 insertions(+) create mode 100644 docstra/core/ingestion/fts_storage.py create mode 100644 tests/test_fts_storage.py diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py new file mode 100644 index 0000000..9617409 --- /dev/null +++ b/docstra/core/ingestion/fts_storage.py @@ -0,0 +1,139 @@ +"""SQLite + FTS5 store for lexical retrieval over chunks and symbols.""" + +from __future__ import annotations + +import os +import sqlite3 +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence + +from docstra.core.indexing.model import IndexedSymbol + +SCHEMA_VERSION = 1 + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS chunks ( + chunk_id TEXT PRIMARY KEY, + file_id TEXT NOT NULL, + language TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_chunks_file ON chunks(file_id); + +CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + content, + content='chunks', + content_rowid='rowid', + tokenize='unicode61 remove_diacritics 2' +); + +CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN + INSERT INTO chunks_fts(rowid, content) VALUES (new.rowid, new.content); +END; + +CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN + INSERT INTO chunks_fts(chunks_fts, rowid, content) VALUES('delete', old.rowid, old.content); +END; + +CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN + INSERT INTO chunks_fts(chunks_fts, rowid, content) VALUES('delete', old.rowid, old.content); + INSERT INTO chunks_fts(rowid, content) VALUES (new.rowid, new.content); +END; + +CREATE VIRTUAL TABLE IF NOT EXISTS symbols_fts USING fts5( + symbol_id UNINDEXED, + file_id UNINDEXED, + kind UNINDEXED, + name, + tokenize='unicode61 remove_diacritics 2' +); +""" + + +class FtsStorage: + """SQLite store with FTS5 indexes for chunks and symbols.""" + + def __init__(self, db_path: str) -> None: + self.db_path = db_path + os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True) + self._conn = sqlite3.connect(db_path) + self._conn.row_factory = sqlite3.Row + self._migrate() + + def _migrate(self) -> None: + with self._conn: + self._conn.executescript(_SCHEMA) + current = self._conn.execute("SELECT version FROM schema_version").fetchone() + if current is None: + self._conn.execute( + "INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,) + ) + + def close(self) -> None: + self._conn.close() + + # --- chunks --- + + def add_chunks( + self, + *, + chunk_ids: Sequence[str], + file_ids: Sequence[str], + languages: Sequence[str], + start_lines: Sequence[int], + end_lines: Sequence[int], + contents: Sequence[str], + ) -> None: + rows = list(zip(chunk_ids, file_ids, languages, start_lines, end_lines, contents)) + with self._conn: + self._conn.executemany( + """ + INSERT INTO chunks (chunk_id, file_id, language, start_line, end_line, content) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(chunk_id) DO UPDATE SET + file_id=excluded.file_id, + language=excluded.language, + start_line=excluded.start_line, + end_line=excluded.end_line, + content=excluded.content + """, + rows, + ) + + def delete_by_file(self, file_id: str) -> None: + with self._conn: + self._conn.execute("DELETE FROM chunks WHERE file_id = ?", (file_id,)) + self._conn.execute("DELETE FROM symbols_fts WHERE file_id = ?", (file_id,)) + + def search_chunks( + self, query: str, n_results: int = 50, *, language: Optional[str] = None, file_id: Optional[str] = None + ) -> List[Dict[str, Any]]: + clauses = ["chunks_fts MATCH ?"] + params: List[Any] = [query] + if language is not None: + clauses.append("chunks.language = ?") + params.append(language) + if file_id is not None: + clauses.append("chunks.file_id = ?") + params.append(file_id) + sql = f""" + SELECT chunks.chunk_id, chunks.file_id, chunks.language, + chunks.start_line, chunks.end_line, chunks.content, + -bm25(chunks_fts) AS score + FROM chunks_fts + JOIN chunks ON chunks.rowid = chunks_fts.rowid + WHERE {' AND '.join(clauses)} + ORDER BY score DESC + LIMIT ? + """ + params.append(n_results) + return [dict(row) for row in self._conn.execute(sql, params).fetchall()] + + # --- symbols (filled in Task 2b) --- diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py new file mode 100644 index 0000000..bcc34ad --- /dev/null +++ b/tests/test_fts_storage.py @@ -0,0 +1,63 @@ +"""Coverage for the FTS5-backed lexical store.""" + +from pathlib import Path + +import pytest + +from docstra.core.ingestion.fts_storage import FtsStorage + + +def _add_chunk(store: FtsStorage, **overrides): + payload = dict( + chunk_id="repo/file.py#L1-L10", + file_id="repo/file.py", + language="python", + start_line=1, + end_line=10, + content="def make_chunk_id(file_id, start_line, end_line):\n return f'{file_id}#L{start_line}-L{end_line}'", + ) + payload.update(overrides) + store.add_chunks( + chunk_ids=[payload["chunk_id"]], + file_ids=[payload["file_id"]], + languages=[payload["language"]], + start_lines=[payload["start_line"]], + end_lines=[payload["end_line"]], + contents=[payload["content"]], + ) + return payload + + +def test_schema_creates_on_first_open(tmp_path: Path): + db_path = tmp_path / "index.db" + store = FtsStorage(str(db_path)) + assert db_path.exists() + # Idempotent: opening again does not raise. + FtsStorage(str(db_path)) + + +def test_add_and_search_chunks(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + hits = store.search_chunks("make_chunk_id", n_results=5) + assert len(hits) == 1 + assert hits[0]["chunk_id"] == "repo/file.py#L1-L10" + assert hits[0]["file_id"] == "repo/file.py" + assert hits[0]["start_line"] == 1 + assert hits[0]["end_line"] == 10 + assert "make_chunk_id" in hits[0]["content"] + + +def test_delete_by_file_removes_chunks_and_fts(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + store.delete_by_file("repo/file.py") + assert store.search_chunks("make_chunk_id", n_results=5) == [] + + +def test_search_supports_language_filter(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store, chunk_id="a.py#L1-L1", file_id="a.py", language="python", content="foo bar") + _add_chunk(store, chunk_id="b.ts#L1-L1", file_id="b.ts", language="typescript", content="foo bar") + hits = store.search_chunks("foo", n_results=5, language="python") + assert {h["file_id"] for h in hits} == {"a.py"} From 3448ab7280b107b8c80d4c22adf9d5532e0bc469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Thu, 11 Jun 2026 15:52:39 +0200 Subject: [PATCH 03/20] feat(fts): add symbol indexing and search to FtsStorage --- docstra/core/ingestion/fts_storage.py | 23 ++++++++++++-- tests/test_fts_storage.py | 46 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py index 9617409..d820300 100644 --- a/docstra/core/ingestion/fts_storage.py +++ b/docstra/core/ingestion/fts_storage.py @@ -4,7 +4,6 @@ import os import sqlite3 -from pathlib import Path from typing import Any, Dict, List, Optional, Sequence from docstra.core.indexing.model import IndexedSymbol @@ -136,4 +135,24 @@ def search_chunks( params.append(n_results) return [dict(row) for row in self._conn.execute(sql, params).fetchall()] - # --- symbols (filled in Task 2b) --- + # --- symbols --- + + def add_symbols(self, symbols: List[IndexedSymbol]) -> None: + rows = [ + (symbol.id, symbol.file_id, symbol.kind, symbol.name) for symbol in symbols + ] + with self._conn: + self._conn.executemany( + "INSERT INTO symbols_fts (symbol_id, file_id, kind, name) VALUES (?, ?, ?, ?)", + rows, + ) + + def search_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: + sql = """ + SELECT symbol_id, file_id, kind, name, -bm25(symbols_fts) AS score + FROM symbols_fts + WHERE symbols_fts MATCH ? + ORDER BY score DESC + LIMIT ? + """ + return [dict(row) for row in self._conn.execute(sql, (query, n_results)).fetchall()] diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py index bcc34ad..9fc775f 100644 --- a/tests/test_fts_storage.py +++ b/tests/test_fts_storage.py @@ -61,3 +61,49 @@ def test_search_supports_language_filter(tmp_path: Path): _add_chunk(store, chunk_id="b.ts#L1-L1", file_id="b.ts", language="typescript", content="foo bar") hits = store.search_chunks("foo", n_results=5, language="python") assert {h["file_id"] for h in hits} == {"a.py"} + + +def test_add_and_search_symbols(tmp_path: Path): + from docstra.core.indexing.model import IndexedSymbol + + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols([ + IndexedSymbol( + id="repo/file.py::function::make_chunk_id::L67", + file_id="repo/file.py", + name="make_chunk_id", + kind="function", + language="python", + line=67, + ), + IndexedSymbol( + id="repo/other.py::class::CoreIndexBuilder::L194", + file_id="repo/other.py", + name="CoreIndexBuilder", + kind="class", + language="python", + line=194, + ), + ]) + hits = store.search_symbols("CoreIndexBuilder", n_results=5) + assert len(hits) == 1 + assert hits[0]["name"] == "CoreIndexBuilder" + assert hits[0]["file_id"] == "repo/other.py" + + +def test_delete_by_file_removes_symbols(tmp_path: Path): + from docstra.core.indexing.model import IndexedSymbol + + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols([ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ) + ]) + store.delete_by_file("x.py") + assert store.search_symbols("foo", n_results=5) == [] From 8f780e0081e7b1ba83ecf995d7b9465af73ded61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Thu, 11 Jun 2026 16:00:17 +0200 Subject: [PATCH 04/20] feat(retrieval): add FtsRetriever wrapper around FtsStorage --- docstra/core/retrieval/fts.py | 29 +++++++++++++++++++++++++ tests/test_fts_retriever.py | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 docstra/core/retrieval/fts.py create mode 100644 tests/test_fts_retriever.py diff --git a/docstra/core/retrieval/fts.py b/docstra/core/retrieval/fts.py new file mode 100644 index 0000000..63d29fd --- /dev/null +++ b/docstra/core/retrieval/fts.py @@ -0,0 +1,29 @@ +"""Retriever that delegates lexical search to FtsStorage.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from docstra.core.ingestion.fts_storage import FtsStorage + + +class FtsRetriever: + """Thin wrapper exposing FtsStorage searches to higher-level retrievers.""" + + def __init__(self, storage: FtsStorage) -> None: + self.storage = storage + + def retrieve_chunks( + self, + query: str, + n_results: int = 50, + *, + language: Optional[str] = None, + file_id: Optional[str] = None, + ) -> List[Dict[str, Any]]: + return self.storage.search_chunks( + query, n_results=n_results, language=language, file_id=file_id + ) + + def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: + return self.storage.search_symbols(query, n_results=n_results) diff --git a/tests/test_fts_retriever.py b/tests/test_fts_retriever.py new file mode 100644 index 0000000..7d4df46 --- /dev/null +++ b/tests/test_fts_retriever.py @@ -0,0 +1,41 @@ +"""Coverage for the FTS-backed retriever.""" + +from pathlib import Path + +from docstra.core.indexing.model import IndexedSymbol +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever + + +def test_retrieve_chunks_delegates_to_storage(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + store.add_chunks( + chunk_ids=["repo/file.py#L1-L10"], + file_ids=["repo/file.py"], + languages=["python"], + start_lines=[1], + end_lines=[10], + contents=["def make_chunk_id(): pass"], + ) + retriever = FtsRetriever(store) + hits = retriever.retrieve_chunks("make_chunk_id", n_results=5) + assert len(hits) == 1 + assert hits[0]["chunk_id"] == "repo/file.py#L1-L10" + + +def test_retrieve_symbols_delegates_to_storage(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols([ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ) + ]) + retriever = FtsRetriever(store) + hits = retriever.retrieve_symbols("foo", n_results=5) + assert len(hits) == 1 + assert hits[0]["name"] == "foo" From c707f4c44ecfbba9b9ba1952cacbc5a28e58019f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 11:52:51 +0200 Subject: [PATCH 05/20] feat(retrieval): add FusionRetriever with RRF over dense and lexical sources --- docstra/core/indexing/code_index.py | 17 ++- docstra/core/retrieval/fusion.py | 133 ++++++++++++++++++++++++ tests/test_core_index.py | 32 ++++++ tests/test_fusion_retriever.py | 156 ++++++++++++++++++++++++++++ 4 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 docstra/core/retrieval/fusion.py create mode 100644 tests/test_fusion_retriever.py diff --git a/docstra/core/indexing/code_index.py b/docstra/core/indexing/code_index.py index 0eaf768..9263e24 100644 --- a/docstra/core/indexing/code_index.py +++ b/docstra/core/indexing/code_index.py @@ -7,7 +7,7 @@ from collections import defaultdict import os from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union from docstra.core.document_processing.document import Document, DocumentType from docstra.core.indexing.model import ( @@ -455,6 +455,21 @@ def get_related_files(self, filepath: str) -> List[str]: related_files.discard(file_id) return sorted(related_files) + def chunks_for_file(self, file_id: str) -> List[Tuple[str, int, int]]: + """Return (chunk_id, start_line, end_line) tuples for a file in line order.""" + matching = [ + (chunk.id, chunk.start_line, chunk.end_line) + for chunk in self._manifest.chunks + if chunk.file_id == file_id + ] + matching.sort(key=lambda tup: tup[1]) + return matching + + def file_language(self, file_id: str) -> Optional[str]: + """Return the language recorded in the manifest for a file id, if any.""" + entry = self._files_by_id.get(file_id) + return entry.language if entry else None + def clear(self) -> None: """Clear the persisted manifest and in-memory lookups.""" self._manifest = CoreIndexManifest.empty( diff --git a/docstra/core/retrieval/fusion.py b/docstra/core/retrieval/fusion.py new file mode 100644 index 0000000..fed9b1e --- /dev/null +++ b/docstra/core/retrieval/fusion.py @@ -0,0 +1,133 @@ +"""Reciprocal Rank Fusion over dense + lexical retrieval sources.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Protocol + +from docstra.core.indexing.code_index import CodebaseIndex + + +def rrf_score(rank: int, k: int) -> float: + """Standard RRF contribution for a single source at 1-based rank.""" + return 1.0 / (k + rank) + + +class _DenseLike(Protocol): + def retrieve_chunks(self, query: str, n_results: int = 20, **filters) -> List[Dict[str, Any]]: ... + + +class _FtsLike(Protocol): + def retrieve_chunks(self, query: str, n_results: int = 50, **filters) -> List[Dict[str, Any]]: ... + def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: ... + + +class FusionRetriever: + """Runs dense + lexical (chunks, symbols) retrieval and fuses with RRF.""" + + def __init__( + self, + dense: _DenseLike, + fts: _FtsLike, + code_index: CodebaseIndex, + *, + rrf_k: int = 60, + fts_chunks_top_k: int = 50, + fts_symbols_top_k: int = 25, + ) -> None: + self.dense = dense + self.fts = fts + self.code_index = code_index + self.rrf_k = rrf_k + self.fts_chunks_top_k = fts_chunks_top_k + self.fts_symbols_top_k = fts_symbols_top_k + + def retrieve(self, query: str, n_results: int = 20, **filters) -> List[Dict[str, Any]]: + return self.retrieve_chunks(query, n_results=n_results, **filters) + + def retrieve_chunks( + self, query: str, n_results: int = 20, **filters + ) -> List[Dict[str, Any]]: + dense_hits = self.dense.retrieve_chunks(query, n_results=n_results * 2, **filters) + lex_chunk_hits = self.fts.retrieve_chunks( + query, n_results=self.fts_chunks_top_k, **filters + ) + symbol_hits = self.fts.retrieve_symbols(query, n_results=self.fts_symbols_top_k) + symbol_chunk_hits = self._symbols_to_chunks(symbol_hits, filters) + + scored: Dict[str, float] = defaultdict(float) + record: Dict[str, Dict[str, Any]] = {} + for source in (dense_hits, lex_chunk_hits, symbol_chunk_hits): + for rank, hit in enumerate(source, start=1): + chunk_id = self._chunk_id(hit) + if chunk_id is None: + continue + scored[chunk_id] += rrf_score(rank, self.rrf_k) + record.setdefault(chunk_id, hit) + + ordered = sorted(record.items(), key=lambda kv: (-scored[kv[0]], kv[0])) + return [hit for _, hit in ordered[:n_results]] + + def retrieve_by_language( + self, query: str, language: str, n_results: int = 20 + ) -> List[Dict[str, Any]]: + return self.retrieve_chunks(query, n_results=n_results, language=language) + + def retrieve_code_examples( + self, query: str, n_results: int = 10, languages: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: + raise NotImplementedError("retrieve_code_examples is moved over in Task 7") + + def _chunk_id(self, hit: Dict[str, Any]) -> Optional[str]: + return hit.get("chunk_id") or hit.get("id") + + def _symbols_to_chunks( + self, symbol_hits: Iterable[Dict[str, Any]], filters: Dict[str, Any] + ) -> List[Dict[str, Any]]: + language = filters.get("language") + file_id_filter = filters.get("file_id") + results: List[Dict[str, Any]] = [] + seen: set[str] = set() + for symbol_hit in symbol_hits: + file_id = symbol_hit.get("file_id") + if file_id is None: + continue + if file_id_filter is not None and file_id != file_id_filter: + continue + if language is not None and self.code_index.file_language(file_id) != language: + continue + symbol_id = symbol_hit.get("symbol_id", "") + line = _extract_line_from_symbol_id(symbol_id) + if line is None: + continue + for chunk_id, start_line, end_line in self.code_index.chunks_for_file(file_id): + if start_line <= line <= end_line and chunk_id not in seen: + seen.add(chunk_id) + results.append({ + "chunk_id": chunk_id, + "id": chunk_id, + "file_id": file_id, + "start_line": start_line, + "end_line": end_line, + "content": "", + "metadata": { + "document_id": file_id, + "start_line": start_line, + "end_line": end_line, + "via_symbol": symbol_hit.get("name"), + }, + }) + break + return results + + +def _extract_line_from_symbol_id(symbol_id: str) -> Optional[int]: + if not symbol_id: + return None + tail = symbol_id.rsplit("::", 1)[-1] + if not tail.startswith("L"): + return None + try: + return int(tail[1:]) + except ValueError: + return None diff --git a/tests/test_core_index.py b/tests/test_core_index.py index 5c90a20..3cc4bc1 100644 --- a/tests/test_core_index.py +++ b/tests/test_core_index.py @@ -479,3 +479,35 @@ def test_codebase_index_rejects_legacy_sidecars_without_core_manifest( with pytest.raises(FileNotFoundError, match="Rerun 'docstra ingest'"): CodebaseIndex(index_directory=str(index_dir), codebase_root=str(tmp_path)) + + +def test_chunks_for_file_returns_chunks_in_line_order(tmp_path): + from docstra.core.indexing.code_index import CodebaseIndex + from docstra.core.indexing.model import CoreIndexManifest, IndexedChunk + + index = CodebaseIndex(index_directory=str(tmp_path / "index")) + index._manifest = CoreIndexManifest.empty() + index._manifest.chunks.extend([ + IndexedChunk(id="a.py#L1-L10", file_id="a.py", language="python", + start_line=1, end_line=10, chunk_type="code"), + IndexedChunk(id="a.py#L11-L20", file_id="a.py", language="python", + start_line=11, end_line=20, chunk_type="code"), + IndexedChunk(id="b.py#L1-L5", file_id="b.py", language="python", + start_line=1, end_line=5, chunk_type="code"), + ]) + + assert index.chunks_for_file("a.py") == [("a.py#L1-L10", 1, 10), ("a.py#L11-L20", 11, 20)] + assert index.chunks_for_file("missing.py") == [] + + # Verify sorting is enforced even when chunks are inserted in reverse line order. + index2 = CodebaseIndex(index_directory=str(tmp_path / "index2")) + index2._manifest = CoreIndexManifest.empty() + index2._manifest.chunks.extend([ + IndexedChunk(id="c.py#L100-L110", file_id="c.py", language="python", + start_line=100, end_line=110, chunk_type="code"), + IndexedChunk(id="c.py#L1-L10", file_id="c.py", language="python", + start_line=1, end_line=10, chunk_type="code"), + ]) + + result = index2.chunks_for_file("c.py") + assert result == [("c.py#L1-L10", 1, 10), ("c.py#L100-L110", 100, 110)] diff --git a/tests/test_fusion_retriever.py b/tests/test_fusion_retriever.py new file mode 100644 index 0000000..91f0ca2 --- /dev/null +++ b/tests/test_fusion_retriever.py @@ -0,0 +1,156 @@ +"""Coverage for RRF fusion of dense + lexical retrieval.""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from docstra.core.retrieval.fusion import FusionRetriever, rrf_score + + +class _FakeDense: + def __init__(self, results: List[Dict[str, Any]]): + self._results = results + + def retrieve_chunks(self, query: str, n_results: int = 20, **filters): + return self._results[:n_results] + + +class _FakeFts: + def __init__(self, chunks: List[Dict[str, Any]], symbols: List[Dict[str, Any]]): + self._chunks = chunks + self._symbols = symbols + + def retrieve_chunks(self, query: str, n_results: int = 50, **filters): + return self._chunks[:n_results] + + def retrieve_symbols(self, query: str, n_results: int = 25): + return self._symbols[:n_results] + + +def _chunk(chunk_id: str, file_id: str, start_line: int = 1, end_line: int = 10): + return { + "chunk_id": chunk_id, + "id": chunk_id, + "file_id": file_id, + "language": "python", + "start_line": start_line, + "end_line": end_line, + "content": f"# chunk {chunk_id}", + "metadata": {"document_id": file_id, "start_line": start_line, "end_line": end_line}, + } + + +def _fake_code_index(chunks_by_file): + def chunks_for_file(file_id): + return chunks_by_file.get(file_id, []) + + return SimpleNamespace(chunks_for_file=chunks_for_file, file_language=lambda fid: None) + + +def test_rrf_score_known_inputs(): + assert rrf_score(rank=1, k=60) == pytest.approx(1 / 61) + assert rrf_score(rank=5, k=60) == pytest.approx(1 / 65) + + +def test_fusion_orders_by_combined_rank(): + a = _chunk("repo/a.py#L1-L10", "repo/a.py") + b = _chunk("repo/b.py#L1-L10", "repo/b.py") + c = _chunk("repo/c.py#L1-L10", "repo/c.py") + + # dense: [a(rank1), b(rank2), c(rank3)] + # fts chunks: [b(rank1), c(rank2), a(rank3)] + # a: 1/61 + 1/63, b: 1/61 + 1/62, c: 1/63 + 1/62 + # b has the highest combined score, a and c tie below it + dense = _FakeDense([a, b, c]) + fts = _FakeFts(chunks=[b, c, a], symbols=[]) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=_fake_code_index({}), + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("anything", n_results=3) + ids = [h["chunk_id"] for h in hits] + assert ids[0] == "repo/b.py#L1-L10" + assert set(ids[1:]) == {"repo/a.py#L1-L10", "repo/c.py#L1-L10"} + + +def test_symbol_hit_promotes_containing_chunk(): + a = _chunk("repo/a.py#L1-L10", "repo/a.py", start_line=1, end_line=10) + b = _chunk("repo/b.py#L1-L10", "repo/b.py") + + dense = _FakeDense([b, a]) + fts = _FakeFts( + chunks=[], + symbols=[{"symbol_id": "repo/a.py::function::foo::L5", "file_id": "repo/a.py", "name": "foo", "kind": "function"}], + ) + + code_index = _fake_code_index({"repo/a.py": [("repo/a.py#L1-L10", 1, 10)]}) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=code_index, + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("foo", n_results=2) + assert hits[0]["chunk_id"] == "repo/a.py#L1-L10" + + +def test_symbol_path_respects_language_filter(): + """Symbol-derived chunks must obey the language filter just like dense/lex chunks.""" + py_chunk = _chunk("repo/a.py#L1-L10", "repo/a.py", start_line=1, end_line=10) + ts_chunk = _chunk("repo/b.ts#L1-L10", "repo/b.ts", start_line=1, end_line=10) + + dense = _FakeDense([py_chunk]) + fts = _FakeFts( + chunks=[], + symbols=[ + {"symbol_id": "repo/a.py::function::foo::L5", "file_id": "repo/a.py", "name": "foo", "kind": "function"}, + {"symbol_id": "repo/b.ts::function::foo::L3", "file_id": "repo/b.ts", "name": "foo", "kind": "function"}, + ], + ) + + code_index = SimpleNamespace( + chunks_for_file=lambda fid: { + "repo/a.py": [("repo/a.py#L1-L10", 1, 10)], + "repo/b.ts": [("repo/b.ts#L1-L10", 1, 10)], + }.get(fid, []), + file_language=lambda fid: {"repo/a.py": "python", "repo/b.ts": "typescript"}.get(fid), + ) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=code_index, + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + + hits = fusion.retrieve_chunks("foo", n_results=5, language="python") + ids = [h["chunk_id"] for h in hits] + assert "repo/b.ts#L1-L10" not in ids + assert "repo/a.py#L1-L10" in ids + + +def test_empty_lexical_source_does_not_break_fusion(): + a = _chunk("repo/a.py#L1-L10", "repo/a.py") + dense = _FakeDense([a]) + fts = _FakeFts(chunks=[], symbols=[]) + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=_fake_code_index({}), + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("anything", n_results=5) + assert [h["chunk_id"] for h in hits] == ["repo/a.py#L1-L10"] From b15b87ed6e3c55ef089a26dbd7933fe0c16279b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 12:08:15 +0200 Subject: [PATCH 06/20] feat(ingestion): write chunks and symbols to FtsStorage alongside Chroma --- docstra/core/indexing/code_index.py | 4 ++ docstra/core/ingestion/storage.py | 16 +++++ docstra/core/services/ingestion_service.py | 9 +++ tests/test_ingestion_fts.py | 79 ++++++++++++++++++++++ 4 files changed, 108 insertions(+) create mode 100644 tests/test_ingestion_fts.py diff --git a/docstra/core/indexing/code_index.py b/docstra/core/indexing/code_index.py index 9263e24..4d033cd 100644 --- a/docstra/core/indexing/code_index.py +++ b/docstra/core/indexing/code_index.py @@ -567,3 +567,7 @@ def index_documents(self, documents: List[Document]) -> None: def get_index(self) -> CodebaseIndex: """Get the underlying codebase index.""" return self.index + + def get_manifest(self) -> CoreIndexManifest: + """Return the in-memory manifest built during indexing.""" + return self.index.manifest diff --git a/docstra/core/ingestion/storage.py b/docstra/core/ingestion/storage.py index ad8ec4e..18ef79f 100644 --- a/docstra/core/ingestion/storage.py +++ b/docstra/core/ingestion/storage.py @@ -15,6 +15,7 @@ from docstra.core.document_processing.document import Document from docstra.core.indexing.model import make_chunk_id, normalize_file_id +from docstra.core.ingestion.fts_storage import FtsStorage ChromaScalar = str | int | float | bool ChromaMetadata = Metadata @@ -399,16 +400,20 @@ def __init__( storage: ChromaDBStorage, embedding_generator: Any, codebase_root: Optional[str] = None, + fts_storage: Optional[FtsStorage] = None, ): """Initialize the document indexer. Args: storage: ChromaDB storage embedding_generator: Generator for creating embeddings + codebase_root: Root directory of the codebase + fts_storage: Optional FTS store for lexical retrieval """ self.storage = storage self.embedding_generator = embedding_generator self.codebase_root = codebase_root + self.fts_storage = fts_storage def _prepare_metadata_for_chroma(self, metadata) -> dict: """Convert document metadata to ChromaDB-compatible format. @@ -537,6 +542,17 @@ def index_document(self, document: Document) -> str: embeddings=chunk_embeddings, ) + if self.fts_storage is not None: + self.fts_storage.delete_by_file(doc_id) + self.fts_storage.add_chunks( + chunk_ids=chunk_ids, + file_ids=[doc_id] * len(chunk_ids), + languages=[str(document.metadata.language)] * len(chunk_ids), + start_lines=[chunk.start_line for chunk in document.chunks], + end_lines=[chunk.end_line for chunk in document.chunks], + contents=chunk_contents, + ) + return persisted_doc_id def index_documents(self, documents: List[Document]) -> List[str]: diff --git a/docstra/core/services/ingestion_service.py b/docstra/core/services/ingestion_service.py index 313cfc3..287be98 100644 --- a/docstra/core/services/ingestion_service.py +++ b/docstra/core/services/ingestion_service.py @@ -29,6 +29,7 @@ SyntaxAwareChunking, ) from docstra.core.ingestion.embeddings import EmbeddingFactory +from docstra.core.ingestion.fts_storage import FtsStorage from docstra.core.ingestion.storage import ChromaDBStorage, DocumentIndexer from docstra.core.indexing.code_index import CodebaseIndex, CodebaseIndexer from docstra.core.indexing.model import CORE_INDEX_FILENAME @@ -90,6 +91,9 @@ def ingest_codebase( legacy_repo_map = persist_directory / "repo_map.json" if legacy_repo_map.exists(): legacy_repo_map.unlink() + index_db_path = persist_directory / "index.db" + if index_db_path.exists(): + index_db_path.unlink() index_path = persist_directory / "index" core_index_path = index_path / CORE_INDEX_FILENAME @@ -147,11 +151,13 @@ def ingest_codebase( ) storage = ChromaDBStorage(persist_directory=str(persist_directory / "chroma")) + fts_storage = FtsStorage(str(persist_directory / "index.db")) doc_indexer = DocumentIndexer( storage, embedding_generator, codebase_root=str(codebase_path_abs), + fts_storage=fts_storage, ) code_indexer = CodebaseIndexer( @@ -272,6 +278,9 @@ def ingest_codebase( doc_indexer.index_documents(documents) code_indexer.index_documents(documents) + manifest = code_indexer.get_manifest() + fts_storage.add_symbols(list(manifest.symbols)) + progress.update( task_index, completed=True, description="[green]Indexed all documents" ) diff --git a/tests/test_ingestion_fts.py b/tests/test_ingestion_fts.py new file mode 100644 index 0000000..207ba37 --- /dev/null +++ b/tests/test_ingestion_fts.py @@ -0,0 +1,79 @@ +"""Verify ingestion writes chunks and symbols into the FTS store alongside Chroma.""" + +from pathlib import Path +from typing import List + +from docstra.core.document_processing.document import ( + CodeChunk, + Document, + DocumentMetadata, + DocumentType, +) +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.ingestion.storage import ChromaDBStorage, DocumentIndexer + + +class _FakeEmbedder: + """Returns a deterministic small vector so Chroma is happy without a real model.""" + + def generate_embedding(self, _text: str) -> List[float]: + return [0.0, 1.0, 0.0] + + +def _make_document(filepath: str, chunk_content: str) -> Document: + metadata = DocumentMetadata( + filepath=filepath, + language=DocumentType.PYTHON, + size_bytes=len(chunk_content.encode("utf-8")), + last_modified=0.0, + line_count=chunk_content.count("\n") + 1, + ) + return Document( + content=chunk_content, + metadata=metadata, + chunks=[ + CodeChunk( + content=chunk_content, + start_line=1, + end_line=metadata.line_count, + chunk_type="function", + ) + ], + ) + + +def test_document_indexer_writes_chunks_to_fts(tmp_path: Path) -> None: + chroma = ChromaDBStorage(persist_directory=str(tmp_path / "chroma")) + fts = FtsStorage(str(tmp_path / "index.db")) + + document = _make_document("repo/foo.py", "def find_me(): pass") + + indexer = DocumentIndexer( + chroma, + embedding_generator=_FakeEmbedder(), + codebase_root=str(tmp_path), + fts_storage=fts, + ) + indexer.index_document(document) + + hits = fts.search_chunks("find_me", n_results=5) + assert len(hits) == 1 + assert hits[0]["file_id"] == "repo/foo.py" + + +def test_document_indexer_reindex_replaces_chunks(tmp_path: Path) -> None: + """Re-indexing the same file should not duplicate chunks in the FTS store.""" + chroma = ChromaDBStorage(persist_directory=str(tmp_path / "chroma")) + fts = FtsStorage(str(tmp_path / "index.db")) + indexer = DocumentIndexer( + chroma, + embedding_generator=_FakeEmbedder(), + codebase_root=str(tmp_path), + fts_storage=fts, + ) + + indexer.index_document(_make_document("repo/foo.py", "def find_me(): pass")) + indexer.index_document(_make_document("repo/foo.py", "def find_me(): pass")) + + hits = fts.search_chunks("find_me", n_results=5) + assert len(hits) == 1 From 97334d148ae9bd6d73092a4a2617fddb3680a50c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 12:29:20 +0200 Subject: [PATCH 07/20] feat(query): use FusionRetriever as the base for context-aware retrieval --- docstra/core/services/query_service.py | 33 +++++++++++++++++++++----- tests/test_index_loading.py | 2 ++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/docstra/core/services/query_service.py b/docstra/core/services/query_service.py index 7b72551..7bdf012 100644 --- a/docstra/core/services/query_service.py +++ b/docstra/core/services/query_service.py @@ -20,8 +20,10 @@ from docstra.core.retrieval.chroma import ChromaRetriever from docstra.core.indexing.code_index import CodebaseIndex, CodebaseIndexer from docstra.core.indexing.model import CORE_INDEX_FILENAME -from docstra.core.retrieval.hybrid import HybridRetriever from docstra.core.retrieval.context_aware import ContextAwareRetriever +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.utils.token_counter import get_token_counter, ContextBudgetManager @@ -94,7 +96,9 @@ def __init__( self.storage: Optional[ChromaDBStorage] = None self.retriever: Optional[ChromaRetriever] = None self.code_indexer: Optional[CodebaseIndexer] = None - self.hybrid_retriever: Optional[HybridRetriever] = None + self.fts_storage: Optional[FtsStorage] = None + self.fts_retriever: Optional[FtsRetriever] = None + self.fusion_retriever: Optional[FusionRetriever] = None self.context_aware_retriever: Optional[ContextAwareRetriever] = None self.token_counter = get_token_counter( self.user_config.model.model_name, self.user_config.model.provider @@ -135,17 +139,24 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): legacy_index_artifacts = CodebaseIndex.legacy_artifacts_in(index_path) legacy_repo_map = effective_persist_dir / "repo_map.json" - if not core_index_path.exists() or not chroma_check_file.exists(): + index_db_path = effective_persist_dir / "index.db" + if not core_index_path.exists() or not chroma_check_file.exists() or not index_db_path.exists(): migration_hint = "" if legacy_index_artifacts or legacy_repo_map.exists(): migration_hint = ( " Legacy index artifacts were found. Rerun 'docstra ingest' " "to rebuild the index in the new format." ) + if not index_db_path.exists() and core_index_path.exists() and chroma_check_file.exists(): + migration_hint += ( + " The lexical index (.docstra/index.db) is missing — likely an older " + "ingest. Rerun 'docstra ingest' to rebuild it." + ) error_msg = ( f"Codebase at {abs_codebase_path} not fully initialized for querying. " f"ChromaDB path: {chroma_path} (check file: {chroma_check_file}, exists: {chroma_check_file.exists()}), " - f"Core index path: {core_index_path} (exists: {core_index_path.exists()}). " + f"Core index path: {core_index_path} (exists: {core_index_path.exists()}), " + f"Lexical index: {index_db_path} (exists: {index_db_path.exists()}). " "Run 'docstra init' and 'docstra ingest' first." f"{migration_hint}" ) @@ -166,7 +177,17 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): code_index_instance = self.code_indexer.get_index() if code_index_instance is None: raise ValueError(f"Failed to load code index from {index_path}") - self.hybrid_retriever = HybridRetriever(self.retriever, code_index_instance) + + self.fts_storage = FtsStorage(str(effective_persist_dir / "index.db")) + self.fts_retriever = FtsRetriever(self.fts_storage) + self.fusion_retriever = FusionRetriever( + dense=self.retriever, + fts=self.fts_retriever, + code_index=code_index_instance, + rrf_k=self.user_config.retrieval.rrf_k, + fts_chunks_top_k=self.user_config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=self.user_config.retrieval.fts_symbols_top_k, + ) # Initialize context-aware retriever repo_map = None @@ -181,7 +202,7 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): ) self.context_aware_retriever = ContextAwareRetriever( - base_retriever=self.retriever, + base_retriever=self.fusion_retriever, budget_manager=self.budget_manager, code_index=code_index_instance, repo_map=repo_map, diff --git a/tests/test_index_loading.py b/tests/test_index_loading.py index 082ae71..9f6c333 100644 --- a/tests/test_index_loading.py +++ b/tests/test_index_loading.py @@ -10,6 +10,7 @@ DocumentType, ) from docstra.core.indexing.model import CORE_INDEX_FILENAME, CoreIndexBuilder +from docstra.core.ingestion.fts_storage import FtsStorage from docstra.core.ingestion.storage import ChromaDBStorage from docstra.core.services.query_service import QueryService from docstra.core.services.repository_explorer_service import RepositoryExplorerService @@ -62,6 +63,7 @@ def _write_core_index(codebase_root: Path) -> None: manifest.model_dump_json(indent=2), encoding="utf-8" ) ChromaDBStorage(persist_directory=str(persist_dir / "chroma")) + FtsStorage(str(persist_dir / "index.db")) def test_query_service_initializes_from_core_index_without_repo_map( From 1a3c6de4b7d88c5f8dba0f7a9e3037b21517deb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 12:58:54 +0200 Subject: [PATCH 08/20] refactor(retrieval): replace HybridRetriever with FusionRetriever at all call sites --- docstra/core/__init__.py | 21 ++-- docstra/core/cli.py | 19 +++- docstra/core/documentation/generator.py | 24 +++-- docstra/core/retrieval/context_aware.py | 21 ++-- docstra/core/retrieval/fusion.py | 101 +++++++++++++++++- .../core/services/documentation_service.py | 1 + tests/test_fusion_retriever.py | 23 ++++ 7 files changed, 180 insertions(+), 30 deletions(-) diff --git a/docstra/core/__init__.py b/docstra/core/__init__.py index db9ac88..a86a582 100644 --- a/docstra/core/__init__.py +++ b/docstra/core/__init__.py @@ -39,8 +39,10 @@ from docstra.core.llm.local import LocalModelClient from docstra.core.llm.ollama import OllamaClient from docstra.core.llm.openai import OpenAIClient +from docstra.core.ingestion.fts_storage import FtsStorage from docstra.core.retrieval.chroma import ChromaRetriever -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever class docstraant: @@ -110,9 +112,16 @@ def setup_components(self): codebase_root=str(Path.cwd()), ) - # Hybrid retriever - self.hybrid_retriever = HybridRetriever( - self.retriever, self.code_indexer.get_index() + # Fusion retriever + fts_storage = FtsStorage(f"{storage_dir}/index.db") + fts_retriever = FtsRetriever(fts_storage) + self.fusion_retriever = FusionRetriever( + dense=self.retriever, + fts=fts_retriever, + code_index=self.code_indexer.get_index(), + rrf_k=self.config.retrieval.rrf_k, + fts_chunks_top_k=self.config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=self.config.retrieval.fts_symbols_top_k, ) # LLM client @@ -287,8 +296,8 @@ def answer_question(self, question: str, n_results: int = 5) -> str: Generated answer """ # Retrieve relevant chunks - results = self.hybrid_retriever.retrieve( - query=question, n_results=n_results, use_code_context=True + results = self.fusion_retriever.retrieve( + query=question, n_results=n_results ) # Generate answer diff --git a/docstra/core/cli.py b/docstra/core/cli.py index 7ab8cbf..060ab26 100644 --- a/docstra/core/cli.py +++ b/docstra/core/cli.py @@ -50,7 +50,9 @@ RetrievalEvalSummary, evaluate_retrieval_cases, ) -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.services.initialization_service import InitializationService from docstra.core.services.ingestion_service import IngestionService from docstra.core.services.query_service import QueryService @@ -1689,7 +1691,7 @@ def _get_persist_paths( def _create_retrieval_eval_runner( user_config: UserConfig, abs_codebase_path: Path ) -> Callable[[str, int], List[Dict[str, Any]]]: - _, chroma_path, index_path = _get_persist_paths(user_config, abs_codebase_path) + effective_persist_dir, chroma_path, index_path = _get_persist_paths(user_config, abs_codebase_path) core_index_path = index_path / CORE_INDEX_FILENAME chroma_check_file = chroma_path / "chroma.sqlite3" legacy_index_artifacts = CodebaseIndex.legacy_artifacts_in(index_path) @@ -1730,10 +1732,19 @@ def _create_retrieval_eval_runner( code_index = code_indexer.get_index() if code_index: - hybrid_retriever = HybridRetriever(base_retriever, code_index) + fts_storage = FtsStorage(str(effective_persist_dir / "index.db")) + fts_retriever = FtsRetriever(fts_storage) + fusion_retriever = FusionRetriever( + dense=base_retriever, + fts=fts_retriever, + code_index=code_index, + rrf_k=user_config.retrieval.rrf_k, + fts_chunks_top_k=user_config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=user_config.retrieval.fts_symbols_top_k, + ) def retrieve(question: str, top_k: int) -> List[Dict[str, Any]]: - return hybrid_retriever.retrieve(question, n_results=top_k) + return fusion_retriever.retrieve(question, n_results=top_k) return retrieve diff --git a/docstra/core/documentation/generator.py b/docstra/core/documentation/generator.py index 7723a1b..ab8c860 100644 --- a/docstra/core/documentation/generator.py +++ b/docstra/core/documentation/generator.py @@ -28,7 +28,9 @@ from docstra.core.document_processing.document import Document from docstra.core.indexing.repo_map import RepositoryMap from docstra.core.retrieval.chroma import ChromaRetriever -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.retrieval.fts import FtsRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.indexing.code_index import CodebaseIndex from docstra.core.documentation.prompts import ( EnhancedDocumentationPrompts, @@ -114,6 +116,7 @@ def __init__( max_workers: Optional[int] = None, documentation_depth: str = "comprehensive", # "overview", "standard", "comprehensive" style_guide: Optional[str] = None, + persist_directory: Optional[Union[str, Path]] = None, ): """Initialize the enhanced documentation generator. @@ -129,6 +132,7 @@ def __init__( max_workers: Maximum number of worker threads documentation_depth: Level of documentation detail to generate style_guide: Custom style guide for documentation + persist_directory: Persist directory root (needed to locate index.db for FTS) """ self.llm_client = llm_client self.output_dir = Path(output_dir) @@ -145,11 +149,15 @@ def __init__( # Enhanced progress reporting self.progress_reporter = DocumentationProgressReporter(self.console) - # Set up hybrid retriever if available - self.hybrid_retriever = None - if self.chroma_retriever and self.code_index: - self.hybrid_retriever = HybridRetriever( - self.chroma_retriever, self.code_index + # Set up fusion retriever if chroma retriever, code index, and persist dir are available + self.fusion_retriever = None + if self.chroma_retriever and self.code_index and persist_directory: + fts_storage = FtsStorage(str(Path(persist_directory) / "index.db")) + fts_retriever = FtsRetriever(fts_storage) + self.fusion_retriever = FusionRetriever( + dense=self.chroma_retriever, + fts=fts_retriever, + code_index=self.code_index, ) # Documentation state @@ -695,7 +703,7 @@ def _build_file_context(self, document: Document) -> str: ) # Add cross-references - if self.hybrid_retriever: + if self.fusion_retriever: cross_refs = self._get_file_cross_references(document) if cross_refs: context_parts.append( @@ -754,7 +762,7 @@ def _get_similar_code_examples(self, document: Document) -> List[Dict[str, Any]] def _get_file_cross_references(self, document: Document) -> List[str]: """Get cross-references for a file.""" - if not self.hybrid_retriever or not self.chroma_retriever: + if not self.fusion_retriever or not self.chroma_retriever: return [] try: diff --git a/docstra/core/retrieval/context_aware.py b/docstra/core/retrieval/context_aware.py index 1b59817..64460f4 100644 --- a/docstra/core/retrieval/context_aware.py +++ b/docstra/core/retrieval/context_aware.py @@ -11,7 +11,7 @@ from docstra.core.indexing.code_index import CodebaseIndex from docstra.core.indexing.repo_map import RepositoryMap from docstra.core.retrieval.chroma import ChromaRetriever -from docstra.core.retrieval.hybrid import HybridRetriever +from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.utils.token_counter import ContextBudgetManager @@ -73,11 +73,11 @@ def __init__( self.code_index = code_index self.repo_map = repo_map - # Create hybrid retriever if code index available - if code_index: - self.hybrid_retriever = HybridRetriever(base_retriever, code_index) + # Accept a FusionRetriever if one was passed in (normal path via query_service) + if hasattr(base_retriever, "retrieve_code_examples"): + self.fusion_retriever = base_retriever else: - self.hybrid_retriever = None + self.fusion_retriever = None def retrieve_with_budget( self, query: str, context_type: str = "query", **kwargs: Any @@ -383,12 +383,11 @@ def _get_general_context( ) -> Dict[str, Any]: """Get balanced general context for queries without clear intent.""" - # Use hybrid retrieval if available, otherwise fall back to base retriever - if self.hybrid_retriever: - results = self.hybrid_retriever.retrieve( + # Use fusion retrieval if available, otherwise fall back to base retriever + if self.fusion_retriever: + results = self.fusion_retriever.retrieve( query=query, n_results=10, # Get more initially for filtering - use_code_context=True, ) else: results = self.base_retriever.retrieve_chunks(query=query, n_results=8) @@ -571,12 +570,12 @@ def _get_targeted_code_samples( ) -> Optional[str]: """Get targeted code samples based on query analysis.""" - if not self.hybrid_retriever: + if not self.fusion_retriever: return None # Get code examples using hybrid retrieval try: - examples = self.hybrid_retriever.retrieve_code_examples( + examples = self.fusion_retriever.retrieve_code_examples( query=query, n_results=3 ) diff --git a/docstra/core/retrieval/fusion.py b/docstra/core/retrieval/fusion.py index fed9b1e..2301076 100644 --- a/docstra/core/retrieval/fusion.py +++ b/docstra/core/retrieval/fusion.py @@ -76,7 +76,106 @@ def retrieve_by_language( def retrieve_code_examples( self, query: str, n_results: int = 10, languages: Optional[List[str]] = None ) -> List[Dict[str, Any]]: - raise NotImplementedError("retrieve_code_examples is moved over in Task 7") + """Retrieve chunks that are good examples of the queried concept. + + Args: + query: Query string + n_results: Number of results to return + languages: Optional list of languages to filter by + + Returns: + List of example chunks + """ + # Start with basic vector search + filters = {} + if languages: + # We'll retrieve for each language separately + all_results = [] + for language in languages: + results = self.retrieve_by_language( + query=query, + language=language, + n_results=max(n_results // len(languages), 1), + ) + all_results.extend(results) + + vector_results = all_results + else: + vector_results = self.retrieve_chunks( + query=query, + n_results=n_results * 2, # Get more for filtering + **filters, + ) + + # Filter for chunks that are likely to be good examples + # - Prefer complete functions/methods + # - Prefer moderately sized chunks (not too short, not too long) + # - Prefer chunks with meaningful names + good_examples = [] + + for chunk in vector_results: + chunk_type = chunk["metadata"].get("chunk_type", "") + content = chunk["content"] + + # Score the chunk as an example + example_score = 0.0 + + # Prefer functions/methods + if chunk_type in ["function", "method"]: + example_score += 1.0 + + # Check content length (not too short, not too long) + lines = content.count("\n") + 1 + if 5 <= lines <= 50: + example_score += 0.5 + + # Look for meaningful names (more than 3 characters, not generic) + symbols = chunk["metadata"].get("symbols", []) + generic_symbols = [ + "main", + "init", + "test", + "get", + "set", + "run", + "func", + "foo", + "bar", + ] + + for symbol in symbols: + if len(symbol) > 3 and symbol.lower() not in generic_symbols: + example_score += 0.3 + break + + # Use original vector score + vector_score = chunk.get("score", 0) + if vector_score is not None: + # Combine scores (vector score is typically a distance, so lower is better) + combined_score = example_score - vector_score + else: + combined_score = example_score + + chunk_id = self._chunk_id(chunk) + good_examples.append( + { + "chunk_id": chunk_id, + "id": chunk_id, + "content": content, + "metadata": chunk["metadata"], + "score": combined_score, + "original_score": vector_score, + } + ) + + # Sort by combined score and return top results + sorted_examples = sorted( + good_examples, + key=lambda x: x.get("score", 0), + reverse=True, # Higher score is better + ) + + return sorted_examples[:n_results] def _chunk_id(self, hit: Dict[str, Any]) -> Optional[str]: return hit.get("chunk_id") or hit.get("id") diff --git a/docstra/core/services/documentation_service.py b/docstra/core/services/documentation_service.py index a94cbd9..f37f6ab 100644 --- a/docstra/core/services/documentation_service.py +++ b/docstra/core/services/documentation_service.py @@ -314,6 +314,7 @@ def generate_documentation( max_workers=effective_max_workers, documentation_depth="comprehensive", style_guide=effective_llm_style_prompt, + persist_directory=abs_persist_directory, ) self.console.print( diff --git a/tests/test_fusion_retriever.py b/tests/test_fusion_retriever.py index 91f0ca2..75cf1a7 100644 --- a/tests/test_fusion_retriever.py +++ b/tests/test_fusion_retriever.py @@ -154,3 +154,26 @@ def test_empty_lexical_source_does_not_break_fusion(): ) hits = fusion.retrieve_chunks("anything", n_results=5) assert [h["chunk_id"] for h in hits] == ["repo/a.py#L1-L10"] + + +def test_retrieve_code_examples_prefers_function_chunks(): + func_chunk = _chunk("repo/a.py#L1-L20", "repo/a.py") + func_chunk["metadata"]["chunk_type"] = "function" + func_chunk["content"] = "def good():\n" + " pass\n" * 10 # 11 lines + tiny_chunk = _chunk("repo/b.py#L1-L2", "repo/b.py") + tiny_chunk["metadata"]["chunk_type"] = "other" + tiny_chunk["content"] = "x = 1\n" + + dense = _FakeDense([tiny_chunk, func_chunk]) + fts = _FakeFts(chunks=[], symbols=[]) + + fusion = FusionRetriever( + dense=dense, + fts=fts, + code_index=_fake_code_index({}), + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + examples = fusion.retrieve_code_examples("anything", n_results=2) + assert examples[0]["chunk_id"] == "repo/a.py#L1-L20" From 4dd4a259ce8b03d14da3a76c245f30cab1a8bb84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 13:36:58 +0200 Subject: [PATCH 09/20] refactor(retrieval): remove obsolete HybridRetriever and tidy generator config wiring --- docstra/core/documentation/generator.py | 21 +- docstra/core/retrieval/hybrid.py | 531 ------------------ .../core/services/documentation_service.py | 1 + 3 files changed, 16 insertions(+), 537 deletions(-) delete mode 100644 docstra/core/retrieval/hybrid.py diff --git a/docstra/core/documentation/generator.py b/docstra/core/documentation/generator.py index ab8c860..24fe794 100644 --- a/docstra/core/documentation/generator.py +++ b/docstra/core/documentation/generator.py @@ -117,6 +117,7 @@ def __init__( documentation_depth: str = "comprehensive", # "overview", "standard", "comprehensive" style_guide: Optional[str] = None, persist_directory: Optional[Union[str, Path]] = None, + user_config: Optional[Any] = None, ): """Initialize the enhanced documentation generator. @@ -133,6 +134,7 @@ def __init__( documentation_depth: Level of documentation detail to generate style_guide: Custom style guide for documentation persist_directory: Persist directory root (needed to locate index.db for FTS) + user_config: UserConfig instance for retrieval settings """ self.llm_client = llm_client self.output_dir = Path(output_dir) @@ -154,11 +156,18 @@ def __init__( if self.chroma_retriever and self.code_index and persist_directory: fts_storage = FtsStorage(str(Path(persist_directory) / "index.db")) fts_retriever = FtsRetriever(fts_storage) - self.fusion_retriever = FusionRetriever( - dense=self.chroma_retriever, - fts=fts_retriever, - code_index=self.code_index, - ) + kwargs = { + "dense": self.chroma_retriever, + "fts": fts_retriever, + "code_index": self.code_index, + } + if user_config and hasattr(user_config, 'retrieval'): + kwargs.update({ + "rrf_k": user_config.retrieval.rrf_k, + "fts_chunks_top_k": user_config.retrieval.fts_chunks_top_k, + "fts_symbols_top_k": user_config.retrieval.fts_symbols_top_k, + }) + self.fusion_retriever = FusionRetriever(**kwargs) # Documentation state self.processed_documents: Dict[str, Document] = {} @@ -762,7 +771,7 @@ def _get_similar_code_examples(self, document: Document) -> List[Dict[str, Any]] def _get_file_cross_references(self, document: Document) -> List[str]: """Get cross-references for a file.""" - if not self.fusion_retriever or not self.chroma_retriever: + if not self.chroma_retriever: return [] try: diff --git a/docstra/core/retrieval/hybrid.py b/docstra/core/retrieval/hybrid.py deleted file mode 100644 index 35d80b4..0000000 --- a/docstra/core/retrieval/hybrid.py +++ /dev/null @@ -1,531 +0,0 @@ -# File: ./docstra/core/retrieval/hybrid.py - -""" -Hybrid retrieval strategies for enhanced document search. -""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from docstra.core.indexing.code_index import CodebaseIndex -from docstra.core.retrieval.chroma import ChromaRetriever - - -class HybridRetriever: - """Hybrid retriever combining vector search with structural code information.""" - - def __init__( - self, retriever: ChromaRetriever, code_index: Optional[CodebaseIndex] = None - ): - """Initialize the hybrid retriever. - - Args: - retriever: Base retriever for vector search - code_index: Code index for structural information - """ - self.retriever = retriever - self.code_index = code_index - - def retrieve( - self, query: str, n_results: int = 20, use_code_context: bool = True, **filters - ) -> List[Dict[str, Any]]: - """Perform hybrid retrieval using both vector search and code structure. - - Args: - query: Query string - n_results: Number of results to return - use_code_context: Whether to use code context for re-ranking - **filters: Additional filters to apply - - Returns: - List of matching chunks - """ - # Start with vector search - vector_results = self.retriever.retrieve_chunks( - query=query, - n_results=n_results * 2, # Get more results for reranking - **filters, - ) - - if not use_code_context or not self.code_index: - # If code context not requested or not available, return vector results - return vector_results[:n_results] - - # Extract potential code symbols from query - potential_symbols = self._extract_potential_symbols(query) - - # Find symbols in the code index - symbol_matches = [] - for symbol in potential_symbols: - symbol_locs = self.code_index.search_symbol(symbol) - if symbol_locs: - symbol_matches.extend(symbol_locs) - - # Re-rank results based on symbol matches - reranked_results = self._rerank_with_symbol_matches( - vector_results, symbol_matches, n_results - ) - - return reranked_results - - def _extract_potential_symbols(self, query: str) -> List[str]: - """Extract potential code symbols from a query. - - Args: - query: Query string - - Returns: - List of potential symbols - """ - # This is a simplified approach. A more sophisticated approach would use - # NLP techniques to identify potential code symbols. - - # Split by whitespace and punctuation - words = query.split() - - # Filter out common words and keep only likely symbols - stop_words = { - "the", - "a", - "an", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "as", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "do", - "does", - "did", - "can", - "could", - "will", - "would", - "shall", - "should", - "may", - "might", - "must", - "how", - "when", - "where", - "why", - "what", - "who", - "whom", - "which", - "if", - "then", - "else", - "so", - "such", - "and", - "or", - "not", - "no", - "yes", - "this", - "that", - "these", - "those", - "code", - "function", - "method", - "class", - "variable", - "import", - "implement", - "define", - "declaration", - } - - symbols = [ - word.strip(",.()[]{}:;\"'") - for word in words - if word.strip(",.()[]{}:;\"'").lower() not in stop_words - and len(word.strip(",.()[]{}:;\"'")) >= 2 # Minimum length - ] - - return symbols - - def _rerank_with_symbol_matches( - self, - vector_results: List[Dict[str, Any]], - symbol_matches: List[Dict[str, Any]], - n_results: int, - ) -> List[Dict[str, Any]]: - """Re-rank results based on symbol matches. - - Args: - vector_results: Results from vector search - symbol_matches: Matches from symbol search - n_results: Number of results to return - - Returns: - Re-ranked results - """ - # Create a set of files with symbol matches - symbol_files = {match["filepath"] for match in symbol_matches} - - # Create a dictionary to track scores - result_scores: Dict[str, float] = {} - - # Score based on vector search position - for i, result in enumerate(vector_results): - chunk_id = result["id"] - # Base score from vector search (higher for early results) - score = 1.0 - (i / len(vector_results)) - - # Get document ID from chunk metadata - doc_id = result["metadata"].get("document_id", "") - - # Boost score if document contains symbol matches - if doc_id in symbol_files: - score += 0.5 - - result_scores[chunk_id] = score - - # Sort results by score - reranked_ids = sorted( - result_scores.keys(), key=lambda x: result_scores[x], reverse=True - ) - - # Create final results list - reranked_results = [] - id_to_result = {result["id"]: result for result in vector_results} - - for chunk_id in reranked_ids[:n_results]: - if chunk_id in id_to_result: - reranked_results.append(id_to_result[chunk_id]) - - return reranked_results - - def retrieve_for_function( - self, query: str, function_name: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve chunks relevant to a specific function. - - Args: - query: Query string - function_name: Name of the function - n_results: Number of results to return - - Returns: - List of matching chunks - """ - if not self.code_index: - return self.retriever.retrieve_chunks(query, n_results) - - # Find files containing the function - function_locs = self.code_index.search_function(function_name) - - if not function_locs: - return self.retriever.retrieve_chunks(query, n_results) - - # Get relevant file paths - file_paths = [loc["filepath"] for loc in function_locs] - - # Combine results from each file - all_results = [] - for filepath in file_paths: - results = self.retriever.retrieve_by_filepath( - query=query, filepath=filepath, n_results=n_results - ) - all_results.extend(results) - - # Sort by relevance and return top results - sorted_results = sorted( - all_results, - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] - - def retrieve_for_class( - self, query: str, class_name: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve chunks relevant to a specific class. - - Args: - query: Query string - class_name: Name of the class - n_results: Number of results to return - - Returns: - List of matching chunks - """ - if not self.code_index: - return self.retriever.retrieve_chunks(query, n_results) - - # Find files containing the class - class_locs = self.code_index.search_class(class_name) - - if not class_locs: - return self.retriever.retrieve_chunks(query, n_results) - - # Get relevant file paths - file_paths = [loc["filepath"] for loc in class_locs] - - # Combine results from each file - all_results = [] - for filepath in file_paths: - results = self.retriever.retrieve_by_filepath( - query=query, filepath=filepath, n_results=n_results - ) - all_results.extend(results) - - # Sort by relevance and return top results - sorted_results = sorted( - all_results, - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] - - def retrieve_related_code( - self, query: str, chunk_id: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve code chunks related to a specific chunk. - - Args: - query: Query string - chunk_id: ID of the chunk to find related code for - n_results: Number of results to return - - Returns: - List of related chunks - """ - # Extract document ID from chunk ID - if "#" in chunk_id: - document_id = chunk_id.split("#")[0] - else: - document_id = chunk_id - - # Get all chunks for the document - document_chunks = self.retriever.get_chunks_for_document(document_id) - - if not document_chunks: - return self.retriever.retrieve_chunks(query, n_results) - - # If we have a code index, also get related files - related_files = [] - if self.code_index: - related_files = self.code_index.get_related_files(document_id) - - # Combine results from document chunks and related files - all_results = [] - - # Add document chunks (with high priority) - for chunk in document_chunks: - all_results.append( - { - "id": chunk["id"], - "content": chunk["content"], - "metadata": chunk["metadata"], - "score": 0.0, # High priority - } - ) - - # Get chunks from related files - if related_files: - for filepath in related_files: - results = self.retriever.retrieve_by_filepath( - query=query, filepath=filepath, n_results=n_results - ) - all_results.extend(results) - - # Remove duplicates (keeping highest score) - unique_results = {} - for result in all_results: - result_id = result["id"] - if result_id not in unique_results or ( - result.get("score", float("inf")) - < unique_results[result_id].get("score", float("inf")) - ): - unique_results[result_id] = result - - # Sort by score and return top results - sorted_results = sorted( - unique_results.values(), - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] - - def retrieve_code_examples( - self, query: str, n_results: int = 10, languages: Optional[List[str]] = None - ) -> List[Dict[str, Any]]: - """Retrieve chunks that are good examples of the queried concept. - - Args: - query: Query string - n_results: Number of results to return - languages: Optional list of languages to filter by - - Returns: - List of example chunks - """ - # Start with basic vector search - filters = {} - if languages: - # We'll retrieve for each language separately - all_results = [] - for language in languages: - results = self.retriever.retrieve_by_language( - query=query, - language=language, - n_results=max(n_results // len(languages), 1), - ) - all_results.extend(results) - - vector_results = all_results - else: - vector_results = self.retriever.retrieve_chunks( - query=query, - n_results=n_results * 2, # Get more for filtering - **filters, - ) - - # Filter for chunks that are likely to be good examples - # - Prefer complete functions/methods - # - Prefer moderately sized chunks (not too short, not too long) - # - Prefer chunks with meaningful names - good_examples = [] - - for chunk in vector_results: - chunk_type = chunk["metadata"].get("chunk_type", "") - content = chunk["content"] - - # Score the chunk as an example - example_score = 0.0 - - # Prefer functions/methods - if chunk_type in ["function", "method"]: - example_score += 1.0 - - # Check content length (not too short, not too long) - lines = content.count("\n") + 1 - if 5 <= lines <= 50: - example_score += 0.5 - - # Look for meaningful names (more than 3 characters, not generic) - symbols = chunk["metadata"].get("symbols", []) - generic_symbols = [ - "main", - "init", - "test", - "get", - "set", - "run", - "func", - "foo", - "bar", - ] - - for symbol in symbols: - if len(symbol) > 3 and symbol.lower() not in generic_symbols: - example_score += 0.3 - break - - # Use original vector score - vector_score = chunk.get("score", 0) - if vector_score is not None: - # Combine scores (vector score is typically a distance, so lower is better) - combined_score = example_score - vector_score - else: - combined_score = example_score - - good_examples.append( - { - "id": chunk["id"], - "content": content, - "metadata": chunk["metadata"], - "score": combined_score, - "original_score": vector_score, - } - ) - - # Sort by combined score and return top results - sorted_examples = sorted( - good_examples, - key=lambda x: x.get("score", 0), - reverse=True, # Higher score is better - ) - - return sorted_examples[:n_results] - - def retrieve_implementation_details( - self, query: str, symbol: str, n_results: int = 10 - ) -> List[Dict[str, Any]]: - """Retrieve implementation details for a specific symbol. - - Args: - query: Query string - symbol: Symbol to find implementation details for - n_results: Number of results to return - - Returns: - List of chunks with implementation details - """ - if not self.code_index: - # Without code index, fall back to basic search - return self.retriever.retrieve_chunks(query, n_results) - - # Find symbol locations - symbol_locs = self.code_index.search_symbol(symbol) - - if not symbol_locs: - # Try function and class indexes if symbol not found - function_locs = self.code_index.search_function(symbol) - class_locs = self.code_index.search_class(symbol) - - symbol_locs = function_locs + class_locs - - if not symbol_locs: - # If still not found, fall back to basic search - return self.retriever.retrieve_chunks(query, n_results) - - # Get relevant file paths - file_paths = [loc["filepath"] for loc in symbol_locs] - - # Combine results from each file - all_results = [] - for filepath in file_paths: - results = self.retriever.retrieve_by_filepath( - query=( - query if query else symbol - ), # Use symbol as query if no query provided - filepath=filepath, - n_results=n_results, - ) - all_results.extend(results) - - # Sort by relevance and return top results - sorted_results = sorted( - all_results, - key=lambda x: ( - x.get("score", 0) if x.get("score") is not None else float("inf") - ), - ) - - return sorted_results[:n_results] diff --git a/docstra/core/services/documentation_service.py b/docstra/core/services/documentation_service.py index f37f6ab..9fb8d33 100644 --- a/docstra/core/services/documentation_service.py +++ b/docstra/core/services/documentation_service.py @@ -315,6 +315,7 @@ def generate_documentation( documentation_depth="comprehensive", style_guide=effective_llm_style_prompt, persist_directory=abs_persist_directory, + user_config=self.user_config, ) self.console.print( From 167a78d29f41e0095c6557681f33630c0e645cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 13:41:28 +0200 Subject: [PATCH 10/20] feat(ingestion): clear FTS DB on force or legacy-state reindex --- docstra/core/services/ingestion_service.py | 3 +++ tests/test_ingestion_force.py | 29 ++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 tests/test_ingestion_force.py diff --git a/docstra/core/services/ingestion_service.py b/docstra/core/services/ingestion_service.py index 287be98..b7fd934 100644 --- a/docstra/core/services/ingestion_service.py +++ b/docstra/core/services/ingestion_service.py @@ -112,6 +112,9 @@ def ingest_codebase( shutil.rmtree(index_path) if legacy_repo_map.exists(): legacy_repo_map.unlink() + index_db_path = persist_directory / "index.db" + if index_db_path.exists(): + index_db_path.unlink() # Check if already indexed and not forcing if core_index_path.exists() and not force: diff --git a/tests/test_ingestion_force.py b/tests/test_ingestion_force.py new file mode 100644 index 0000000..66d795f --- /dev/null +++ b/tests/test_ingestion_force.py @@ -0,0 +1,29 @@ +"""Verify that a forced reindex clears the legacy FTS database file.""" + +from pathlib import Path + +from docstra.core.config.settings import UserConfig +from docstra.core.services.ingestion_service import IngestionService + + +def test_force_reindex_removes_index_db(tmp_path: Path) -> None: + repo = tmp_path / "repo" + repo.mkdir() + (repo / "file.py").write_text("def hello(): return 1\n") + + persist_dir = repo / ".docstra" + persist_dir.mkdir() + stale_db = persist_dir / "index.db" + stale_db.write_bytes(b"stale-bytes-not-a-real-sqlite-file") + assert stale_db.exists() + + config = UserConfig() + config.storage.persist_directory = str(persist_dir) + service = IngestionService() + service.ingest_codebase(str(repo), config, force=True) + + # The forced reindex should have removed the stale FTS DB before rebuilding; + # the new one created by ingestion will be a valid SQLite file. + assert stale_db.exists(), "expected ingestion to recreate the FTS DB" + # And it should not be the stale bytes we put down. + assert stale_db.read_bytes()[:16] != b"stale-bytes-not-" From 630f925e445412ba180e9b943d745d8ee8c0f6b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Fri, 12 Jun 2026 14:56:45 +0200 Subject: [PATCH 11/20] test(retrieval): add fusion retriever to the evaluation harness --- tests/test_retrieval_evaluation.py | 281 +++++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) diff --git a/tests/test_retrieval_evaluation.py b/tests/test_retrieval_evaluation.py index 83adfc3..9e11a83 100644 --- a/tests/test_retrieval_evaluation.py +++ b/tests/test_retrieval_evaluation.py @@ -4,6 +4,8 @@ import sys from importlib import util from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List MODULE_PATH = ( Path(__file__).resolve().parents[1] @@ -154,3 +156,282 @@ def retrieve(question: str, candidate_k: int): assert summary.results[0].passed is True assert summary.results[0].rank == 2 assert summary.results[0].retrieved_files == ["a.py", "b.py"] + + +# --------------------------------------------------------------------------- +# Integration: Chroma-only vs. FusionRetriever side-by-side eval +# --------------------------------------------------------------------------- + +_REPO_ROOT = Path(__file__).resolve().parents[1] + +# Subset of real source files that map to the eval queries in evaluation.py. +# Keep the list small so the test stays fast (no network, local embeddings only). +_CORPUS_FILES: List[Dict[str, Any]] = [ + { + "path": "docstra/core/config/settings.py", + "abs": _REPO_ROOT / "docstra/core/config/settings.py", + }, + { + "path": "docstra/core/retrieval/chroma.py", + "abs": _REPO_ROOT / "docstra/core/retrieval/chroma.py", + }, + { + "path": "docstra/core/retrieval/context_aware.py", + "abs": _REPO_ROOT / "docstra/core/retrieval/context_aware.py", + }, + { + "path": "docstra/core/ingestion/storage.py", + "abs": _REPO_ROOT / "docstra/core/ingestion/storage.py", + }, + { + "path": "docstra/core/services/ingestion_service.py", + "abs": _REPO_ROOT / "docstra/core/services/ingestion_service.py", + }, + { + "path": "docstra/core/indexing/code_index.py", + "abs": _REPO_ROOT / "docstra/core/indexing/code_index.py", + }, + { + "path": "docstra/core/cli.py", + "abs": _REPO_ROOT / "docstra/core/cli.py", + }, + { + "path": "docstra/core/utils/token_counter.py", + "abs": _REPO_ROOT / "docstra/core/utils/token_counter.py", + }, + { + "path": "docstra/core/documentation/generator.py", + "abs": _REPO_ROOT / "docstra/core/documentation/generator.py", + }, +] + +_EVAL_CASES = [ + RetrievalEvalCase( + question="How does Docstra load and save user configuration?", + expected_files=["docstra/core/config/settings.py"], + ), + RetrievalEvalCase( + question="Where are files ingested into ChromaDB with embeddings?", + expected_files=[ + "docstra/core/services/ingestion_service.py", + "docstra/core/ingestion/storage.py", + ], + ), + RetrievalEvalCase( + question="How does Chroma retrieval search document chunks?", + expected_files=["docstra/core/retrieval/chroma.py"], + ), + RetrievalEvalCase( + question="Where does Docstra classify query intent for context retrieval?", + expected_files=["docstra/core/retrieval/context_aware.py"], + ), + RetrievalEvalCase( + question="How are code symbols indexed for later search?", + expected_files=["docstra/core/indexing/code_index.py"], + ), + RetrievalEvalCase( + question="Where is the docstra query CLI command implemented?", + expected_files=["docstra/core/cli.py"], + ), + RetrievalEvalCase( + question="How is context token budget calculated and enforced?", + expected_files=["docstra/core/utils/token_counter.py"], + ), + RetrievalEvalCase( + question="Where does documentation generation assemble code context?", + expected_files=["docstra/core/documentation/generator.py"], + ), +] + + +class _DummyEmbedder: + """Returns a deterministic 3-d embedding so Chroma is happy without a real model.""" + + def generate_embedding(self, text: str) -> List[float]: + # Simple hash-based spread so different files get different vectors. + h = hash(text[:64]) & 0xFFFFFF + a = ((h >> 16) & 0xFF) / 255.0 + b = ((h >> 8) & 0xFF) / 255.0 + c = (h & 0xFF) / 255.0 + return [a, b, c] + + +def _split_into_chunks(content: str, chunk_size: int = 40) -> List[Dict[str, Any]]: + """Split file content into fixed-size line chunks for the eval corpus.""" + lines = content.splitlines(keepends=True) + chunks = [] + for start in range(0, len(lines), chunk_size): + end = min(start + chunk_size, len(lines)) + chunk_text = "".join(lines[start:end]) + chunks.append( + { + "start_line": start + 1, + "end_line": end, + "content": chunk_text, + } + ) + return chunks + + +def test_chroma_vs_fusion_retrieval_eval(tmp_path: Path) -> None: + """Build a small real corpus and compare Chroma-only vs. FusionRetriever recall.""" + import pytest + + # Skip if any corpus file is missing (e.g., running from a partial checkout). + missing = [f["path"] for f in _CORPUS_FILES if not f["abs"].exists()] + if missing: + pytest.skip(f"corpus files not found: {missing}") + + from docstra.core.ingestion.fts_storage import FtsStorage + from docstra.core.ingestion.storage import ChromaDBStorage + from docstra.core.retrieval.chroma import ChromaRetriever + from docstra.core.retrieval.fts import FtsRetriever + from docstra.core.retrieval.fusion import FusionRetriever + + chroma_storage = ChromaDBStorage(persist_directory=str(tmp_path / "chroma")) + fts_storage = FtsStorage(str(tmp_path / "index.db")) + embedder = _DummyEmbedder() + + # Registry for the minimal in-memory CodebaseIndex substitute. + # Maps file_id -> [(chunk_id, start_line, end_line)] + chunks_by_file: Dict[str, List] = {} + file_language: Dict[str, str] = {} + + for corpus_entry in _CORPUS_FILES: + file_id: str = corpus_entry["path"] + abs_path: Path = corpus_entry["abs"] + content = abs_path.read_text(encoding="utf-8", errors="replace") + file_chunks = _split_into_chunks(content) + language = "python" + + chunk_ids = [] + chunk_contents = [] + chunk_metadatas = [] + chunk_embeddings = [] + fts_chunk_ids = [] + fts_start_lines = [] + fts_end_lines = [] + fts_contents = [] + + file_chunk_tuples = [] + for fc in file_chunks: + chunk_id = f"{file_id}#L{fc['start_line']}-L{fc['end_line']}" + embedding = embedder.generate_embedding(fc["content"]) + + chunk_ids.append(chunk_id) + chunk_contents.append(fc["content"]) + chunk_metadatas.append( + { + "document_id": file_id, + "start_line": fc["start_line"], + "end_line": fc["end_line"], + "language": language, + "chunk_id": chunk_id, + } + ) + chunk_embeddings.append(embedding) + + fts_chunk_ids.append(chunk_id) + fts_start_lines.append(fc["start_line"]) + fts_end_lines.append(fc["end_line"]) + fts_contents.append(fc["content"]) + + file_chunk_tuples.append((chunk_id, fc["start_line"], fc["end_line"])) + + chroma_storage.add_chunks( + chunk_ids=chunk_ids, + contents=chunk_contents, + metadatas=chunk_metadatas, + embeddings=chunk_embeddings, + ) + fts_storage.add_chunks( + chunk_ids=fts_chunk_ids, + file_ids=[file_id] * len(fts_chunk_ids), + languages=[language] * len(fts_chunk_ids), + start_lines=fts_start_lines, + end_lines=fts_end_lines, + contents=fts_contents, + ) + chunks_by_file[file_id] = file_chunk_tuples + file_language[file_id] = language + + # Minimal CodebaseIndex substitute: only needs chunks_for_file + file_language. + code_index = SimpleNamespace( + chunks_for_file=lambda fid: chunks_by_file.get(fid, []), + file_language=lambda fid: file_language.get(fid), + ) + + chroma_retriever = ChromaRetriever(chroma_storage, embedder) + fts_retriever = FtsRetriever(fts_storage) + + import re as _re + + def _fts_query(raw: str) -> str: + """Strip FTS5-special chars and return a space-joined word query.""" + words = _re.findall(r"[A-Za-z0-9_]+", raw) + return " ".join(words) if words else "x" + + # FTS hits lack a 'metadata' dict; wrap them so collect_retrieved_files can + # extract document_id regardless of which source produced a chunk. + # Also sanitise the query string so FTS5 doesn't choke on punctuation. + class _NormalisingFts: + def retrieve_chunks( + self, query: str, n_results: int = 50, **filters + ) -> List[Dict[str, Any]]: + hits = fts_retriever.retrieve_chunks( + _fts_query(query), n_results=n_results, **filters + ) + for hit in hits: + if "metadata" not in hit: + hit["metadata"] = { + "document_id": hit.get("file_id", ""), + "start_line": hit.get("start_line"), + "end_line": hit.get("end_line"), + } + hit["id"] = hit.get("chunk_id", "") + return hits + + def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: + return fts_retriever.retrieve_symbols(_fts_query(query), n_results=n_results) + + fusion_retriever = FusionRetriever( + dense=chroma_retriever, + fts=_NormalisingFts(), + code_index=code_index, + ) + + top_k = 5 + + def chroma_retrieve(question: str, n: int) -> List[Dict[str, Any]]: + return chroma_retriever.retrieve_chunks(question, n_results=n) + + def fusion_retrieve(question: str, n: int) -> List[Dict[str, Any]]: + return fusion_retriever.retrieve_chunks(question, n_results=n) + + chroma_summary = evaluate_retrieval_cases(_EVAL_CASES, chroma_retrieve, top_k=top_k) + fusion_summary = evaluate_retrieval_cases(_EVAL_CASES, fusion_retrieve, top_k=top_k) + + # --- Print side-by-side table (visible with pytest -s) --- + print(f"\n{'':=<72}") + print(f"Retrieval eval top_k={top_k} corpus={len(_CORPUS_FILES)} files") + print(f"{'':=<72}") + header = f"{'Query':<50} {'Chroma':>7} {'Fusion':>7}" + print(header) + print(f"{'':-<72}") + for cr, fr in zip(chroma_summary.results, fusion_summary.results): + chroma_rank = str(cr.rank) if cr.rank else "-" + fusion_rank = str(fr.rank) if fr.rank else "-" + q = cr.case.question[:48] + print(f"{q:<50} {chroma_rank:>7} {fusion_rank:>7}") + print(f"{'':-<72}") + print( + f"{'recall@k':<50} {chroma_summary.recall_at_k:>7.2f} {fusion_summary.recall_at_k:>7.2f}" + ) + print( + f"{'passed':<50} {chroma_summary.passed_count:>7} {fusion_summary.passed_count:>7}" + ) + print(f"{'':=<72}") + + # Sanity: scores must be valid floats, not errors. + assert 0.0 <= chroma_summary.recall_at_k <= 1.0 + assert 0.0 <= fusion_summary.recall_at_k <= 1.0 From 2fe650a7acc93b12e915124788992d2c3b0935d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 10:28:26 +0200 Subject: [PATCH 12/20] fix(retrieval): sanitize FTS queries and normalize hit shape with Chroma --- docstra/core/ingestion/fts_storage.py | 64 +++++++++++++++++++++++++-- tests/test_fts_retriever.py | 2 + tests/test_fts_storage.py | 41 ++++++++++++++++- tests/test_retrieval_evaluation.py | 45 +++---------------- 4 files changed, 110 insertions(+), 42 deletions(-) diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py index d820300..9787c86 100644 --- a/docstra/core/ingestion/fts_storage.py +++ b/docstra/core/ingestion/fts_storage.py @@ -3,11 +3,27 @@ from __future__ import annotations import os +import re import sqlite3 from typing import Any, Dict, List, Optional, Sequence from docstra.core.indexing.model import IndexedSymbol +_FTS_TOKEN_RE = re.compile(r"\w+", re.UNICODE) + + +def _sanitize_fts_query(query: str) -> str: + """Strip FTS5 syntax characters and lowercase the result. + + Tokenizes on word boundaries, joins with spaces, lowercases. The + lowercasing matters: FTS5 boolean operators (NOT/AND/OR) are recognized + only in uppercase, so lowercasing user queries prevents accidental + boolean parsing while leaving the unicode61 tokenizer's case-insensitive + matching unchanged. + """ + tokens = _FTS_TOKEN_RE.findall(query) + return " ".join(tokens).lower() + SCHEMA_VERSION = 1 _SCHEMA = """ @@ -114,8 +130,11 @@ def delete_by_file(self, file_id: str) -> None: def search_chunks( self, query: str, n_results: int = 50, *, language: Optional[str] = None, file_id: Optional[str] = None ) -> List[Dict[str, Any]]: + match_query = _sanitize_fts_query(query) + if not match_query: + return [] clauses = ["chunks_fts MATCH ?"] - params: List[Any] = [query] + params: List[Any] = [match_query] if language is not None: clauses.append("chunks.language = ?") params.append(language) @@ -133,7 +152,27 @@ def search_chunks( LIMIT ? """ params.append(n_results) - return [dict(row) for row in self._conn.execute(sql, params).fetchall()] + results = [] + for row in self._conn.execute(sql, params).fetchall(): + results.append({ + "id": row["chunk_id"], + "chunk_id": row["chunk_id"], + "file_id": row["file_id"], + "language": row["language"], + "start_line": row["start_line"], + "end_line": row["end_line"], + "content": row["content"], + "score": row["score"], + "metadata": { + "document_id": row["file_id"], + "filepath": row["file_id"], + "start_line": row["start_line"], + "end_line": row["end_line"], + "language": row["language"], + "chunk_type": "code", + }, + }) + return results # --- symbols --- @@ -148,6 +187,9 @@ def add_symbols(self, symbols: List[IndexedSymbol]) -> None: ) def search_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: + match_query = _sanitize_fts_query(query) + if not match_query: + return [] sql = """ SELECT symbol_id, file_id, kind, name, -bm25(symbols_fts) AS score FROM symbols_fts @@ -155,4 +197,20 @@ def search_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any] ORDER BY score DESC LIMIT ? """ - return [dict(row) for row in self._conn.execute(sql, (query, n_results)).fetchall()] + results = [] + for row in self._conn.execute(sql, (match_query, n_results)).fetchall(): + results.append({ + "id": row["symbol_id"], + "symbol_id": row["symbol_id"], + "file_id": row["file_id"], + "kind": row["kind"], + "name": row["name"], + "score": row["score"], + "metadata": { + "document_id": row["file_id"], + "filepath": row["file_id"], + "name": row["name"], + "kind": row["kind"], + }, + }) + return results diff --git a/tests/test_fts_retriever.py b/tests/test_fts_retriever.py index 7d4df46..f7af02c 100644 --- a/tests/test_fts_retriever.py +++ b/tests/test_fts_retriever.py @@ -21,6 +21,7 @@ def test_retrieve_chunks_delegates_to_storage(tmp_path: Path): hits = retriever.retrieve_chunks("make_chunk_id", n_results=5) assert len(hits) == 1 assert hits[0]["chunk_id"] == "repo/file.py#L1-L10" + assert hits[0]["metadata"]["document_id"] == hits[0]["file_id"] def test_retrieve_symbols_delegates_to_storage(tmp_path: Path): @@ -39,3 +40,4 @@ def test_retrieve_symbols_delegates_to_storage(tmp_path: Path): hits = retriever.retrieve_symbols("foo", n_results=5) assert len(hits) == 1 assert hits[0]["name"] == "foo" + assert hits[0]["metadata"]["document_id"] == "x.py" diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py index 9fc775f..616ac9e 100644 --- a/tests/test_fts_storage.py +++ b/tests/test_fts_storage.py @@ -4,7 +4,7 @@ import pytest -from docstra.core.ingestion.fts_storage import FtsStorage +from docstra.core.ingestion.fts_storage import FtsStorage, _sanitize_fts_query def _add_chunk(store: FtsStorage, **overrides): @@ -46,6 +46,8 @@ def test_add_and_search_chunks(tmp_path: Path): assert hits[0]["start_line"] == 1 assert hits[0]["end_line"] == 10 assert "make_chunk_id" in hits[0]["content"] + # Shape contract: metadata sub-dict must be present and document_id must match file_id. + assert hits[0]["metadata"]["document_id"] == hits[0]["file_id"] def test_delete_by_file_removes_chunks_and_fts(tmp_path: Path): @@ -89,6 +91,9 @@ def test_add_and_search_symbols(tmp_path: Path): assert len(hits) == 1 assert hits[0]["name"] == "CoreIndexBuilder" assert hits[0]["file_id"] == "repo/other.py" + assert hits[0]["id"] == hits[0]["symbol_id"] + assert hits[0]["metadata"]["document_id"] == hits[0]["file_id"] + assert hits[0]["metadata"]["name"] == "CoreIndexBuilder" def test_delete_by_file_removes_symbols(tmp_path: Path): @@ -107,3 +112,37 @@ def test_delete_by_file_removes_symbols(tmp_path: Path): ]) store.delete_by_file("x.py") assert store.search_symbols("foo", n_results=5) == [] + + +def test_sanitize_fts_query_strips_special_chars(): + assert _sanitize_fts_query("What is the config?") == "what is the config" + assert _sanitize_fts_query("(foo OR bar)") == "foo or bar" + assert _sanitize_fts_query('search "exact phrase"') == "search exact phrase" + assert _sanitize_fts_query("a*b") == "a b" + assert _sanitize_fts_query("") == "" + assert _sanitize_fts_query("???") == "" + + +def test_punctuation_query_does_not_raise(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + # FTS5 would raise sqlite3.OperationalError without sanitization. + hits = store.search_chunks("What is the config?", n_results=5) + assert isinstance(hits, list) + + +def test_empty_query_returns_empty_list(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + assert store.search_chunks("???", n_results=5) == [] + assert store.search_symbols("???", n_results=5) == [] + + +def test_query_with_reserved_words_does_not_raise(tmp_path: Path): + """A natural-language query starting with NOT/AND/OR must not crash FTS5.""" + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store, content="def handle_not_null(value): pass") + # Each of these would otherwise be parsed as a FTS5 boolean expression. + for query in ("NOT null handling", "OR operator", "AND something", "foo AND"): + # Must not raise; result can be empty or have hits — we only care about no crash. + store.search_chunks(query, n_results=5) diff --git a/tests/test_retrieval_evaluation.py b/tests/test_retrieval_evaluation.py index 9e11a83..aad3fb6 100644 --- a/tests/test_retrieval_evaluation.py +++ b/tests/test_retrieval_evaluation.py @@ -1,6 +1,8 @@ from __future__ import annotations +import hashlib import json +import struct import sys from importlib import util from pathlib import Path @@ -245,15 +247,12 @@ def retrieve(question: str, candidate_k: int): class _DummyEmbedder: - """Returns a deterministic 3-d embedding so Chroma is happy without a real model.""" + """Returns a deterministic 8-d embedding so Chroma is happy without a real model.""" def generate_embedding(self, text: str) -> List[float]: - # Simple hash-based spread so different files get different vectors. - h = hash(text[:64]) & 0xFFFFFF - a = ((h >> 16) & 0xFF) / 255.0 - b = ((h >> 8) & 0xFF) / 255.0 - c = (h & 0xFF) / 255.0 - return [a, b, c] + digest = hashlib.sha256(text.encode("utf-8")).digest() + vals = struct.unpack(">8I", digest[:32]) + return [v / 0xFFFFFFFF for v in vals] def _split_into_chunks(content: str, chunk_size: int = 40) -> List[Dict[str, Any]]: @@ -364,39 +363,9 @@ def test_chroma_vs_fusion_retrieval_eval(tmp_path: Path) -> None: chroma_retriever = ChromaRetriever(chroma_storage, embedder) fts_retriever = FtsRetriever(fts_storage) - import re as _re - - def _fts_query(raw: str) -> str: - """Strip FTS5-special chars and return a space-joined word query.""" - words = _re.findall(r"[A-Za-z0-9_]+", raw) - return " ".join(words) if words else "x" - - # FTS hits lack a 'metadata' dict; wrap them so collect_retrieved_files can - # extract document_id regardless of which source produced a chunk. - # Also sanitise the query string so FTS5 doesn't choke on punctuation. - class _NormalisingFts: - def retrieve_chunks( - self, query: str, n_results: int = 50, **filters - ) -> List[Dict[str, Any]]: - hits = fts_retriever.retrieve_chunks( - _fts_query(query), n_results=n_results, **filters - ) - for hit in hits: - if "metadata" not in hit: - hit["metadata"] = { - "document_id": hit.get("file_id", ""), - "start_line": hit.get("start_line"), - "end_line": hit.get("end_line"), - } - hit["id"] = hit.get("chunk_id", "") - return hits - - def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: - return fts_retriever.retrieve_symbols(_fts_query(query), n_results=n_results) - fusion_retriever = FusionRetriever( dense=chroma_retriever, - fts=_NormalisingFts(), + fts=fts_retriever, code_index=code_index, ) From 4efe9eae50e3fc22abc2d1baa310e3d8cf19a8ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 12:02:18 +0200 Subject: [PATCH 13/20] fix(retrieval): hydrate symbol-derived chunks with content and wire FTS into Docstra facade --- docstra/core/__init__.py | 15 ++++++-- docstra/core/ingestion/fts_storage.py | 9 +++++ docstra/core/retrieval/fts.py | 3 ++ docstra/core/retrieval/fusion.py | 10 ++++- tests/test_fts_retriever.py | 17 +++++++++ tests/test_fts_storage.py | 14 +++++++ tests/test_fusion_retriever.py | 53 +++++++++++++++++++++++++++ 7 files changed, 117 insertions(+), 4 deletions(-) diff --git a/docstra/core/__init__.py b/docstra/core/__init__.py index a86a582..fcc9021 100644 --- a/docstra/core/__init__.py +++ b/docstra/core/__init__.py @@ -88,11 +88,16 @@ def setup_components(self): ] ) + # FTS storage (shared by indexer and retriever) + self.fts_storage = FtsStorage(f"{storage_dir}/index.db") + self.fts_retriever = FtsRetriever(self.fts_storage) + # Document indexer self.document_indexer = DocumentIndexer( self.storage, self.embedding_generator, codebase_root=str(Path.cwd()), + fts_storage=self.fts_storage, ) # Code indexer @@ -113,11 +118,9 @@ def setup_components(self): ) # Fusion retriever - fts_storage = FtsStorage(f"{storage_dir}/index.db") - fts_retriever = FtsRetriever(fts_storage) self.fusion_retriever = FusionRetriever( dense=self.retriever, - fts=fts_retriever, + fts=self.fts_retriever, code_index=self.code_indexer.get_index(), rrf_k=self.config.retrieval.rrf_k, fts_chunks_top_k=self.config.retrieval.fts_chunks_top_k, @@ -207,6 +210,12 @@ def index_file(self, filepath: str) -> str: doc_id = self.document_indexer.index_document(document) self.code_indexer.index_document(document) + # Write symbols to FTS for this file + manifest = self.code_indexer.get_manifest() + file_symbols = [s for s in manifest.symbols if s.file_id == doc_id] + if file_symbols: + self.fts_storage.add_symbols(file_symbols) + return doc_id def document_code( diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py index 9787c86..0a384a0 100644 --- a/docstra/core/ingestion/fts_storage.py +++ b/docstra/core/ingestion/fts_storage.py @@ -174,6 +174,15 @@ def search_chunks( }) return results + def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: + """Return the full row for a single chunk id, or None if absent.""" + sql = ( + "SELECT chunk_id, file_id, language, start_line, end_line, content " + "FROM chunks WHERE chunk_id = ? LIMIT 1" + ) + row = self._conn.execute(sql, (chunk_id,)).fetchone() + return dict(row) if row else None + # --- symbols --- def add_symbols(self, symbols: List[IndexedSymbol]) -> None: diff --git a/docstra/core/retrieval/fts.py b/docstra/core/retrieval/fts.py index 63d29fd..143154a 100644 --- a/docstra/core/retrieval/fts.py +++ b/docstra/core/retrieval/fts.py @@ -27,3 +27,6 @@ def retrieve_chunks( def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: return self.storage.search_symbols(query, n_results=n_results) + + def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: + return self.storage.get_chunk(chunk_id) diff --git a/docstra/core/retrieval/fusion.py b/docstra/core/retrieval/fusion.py index 2301076..b34b078 100644 --- a/docstra/core/retrieval/fusion.py +++ b/docstra/core/retrieval/fusion.py @@ -20,6 +20,7 @@ def retrieve_chunks(self, query: str, n_results: int = 20, **filters) -> List[Di class _FtsLike(Protocol): def retrieve_chunks(self, query: str, n_results: int = 50, **filters) -> List[Dict[str, Any]]: ... def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: ... + def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: ... class FusionRetriever: @@ -202,17 +203,24 @@ def _symbols_to_chunks( for chunk_id, start_line, end_line in self.code_index.chunks_for_file(file_id): if start_line <= line <= end_line and chunk_id not in seen: seen.add(chunk_id) + chunk_row = self.fts.get_chunk(chunk_id) + content = chunk_row["content"] if chunk_row else "" + lang = chunk_row["language"] if chunk_row else "" results.append({ "chunk_id": chunk_id, "id": chunk_id, "file_id": file_id, + "language": lang, "start_line": start_line, "end_line": end_line, - "content": "", + "content": content, "metadata": { "document_id": file_id, + "filepath": file_id, "start_line": start_line, "end_line": end_line, + "language": lang, + "chunk_type": "code", "via_symbol": symbol_hit.get("name"), }, }) diff --git a/tests/test_fts_retriever.py b/tests/test_fts_retriever.py index f7af02c..1934470 100644 --- a/tests/test_fts_retriever.py +++ b/tests/test_fts_retriever.py @@ -41,3 +41,20 @@ def test_retrieve_symbols_delegates_to_storage(tmp_path: Path): assert len(hits) == 1 assert hits[0]["name"] == "foo" assert hits[0]["metadata"]["document_id"] == "x.py" + + +def test_get_chunk_returns_row_and_none(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + store.add_chunks( + chunk_ids=["repo/a.py#L1-L5"], + file_ids=["repo/a.py"], + languages=["python"], + start_lines=[1], + end_lines=[5], + contents=["def bar(): pass"], + ) + retriever = FtsRetriever(store) + row = retriever.get_chunk("repo/a.py#L1-L5") + assert row is not None + assert row["chunk_id"] == "repo/a.py#L1-L5" + assert retriever.get_chunk("missing#L1-L1") is None diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py index 616ac9e..a9579f1 100644 --- a/tests/test_fts_storage.py +++ b/tests/test_fts_storage.py @@ -138,6 +138,20 @@ def test_empty_query_returns_empty_list(tmp_path: Path): assert store.search_symbols("???", n_results=5) == [] +def test_get_chunk_returns_row(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + payload = _add_chunk(store) + row = store.get_chunk(payload["chunk_id"]) + assert row is not None + assert row["chunk_id"] == payload["chunk_id"] + assert row["content"] == payload["content"] + + +def test_get_chunk_returns_none_for_missing(tmp_path: Path): + store = FtsStorage(str(tmp_path / "index.db")) + assert store.get_chunk("nonexistent#L1-L1") is None + + def test_query_with_reserved_words_does_not_raise(tmp_path: Path): """A natural-language query starting with NOT/AND/OR must not crash FTS5.""" store = FtsStorage(str(tmp_path / "index.db")) diff --git a/tests/test_fusion_retriever.py b/tests/test_fusion_retriever.py index 75cf1a7..b2cdf6d 100644 --- a/tests/test_fusion_retriever.py +++ b/tests/test_fusion_retriever.py @@ -27,6 +27,9 @@ def retrieve_chunks(self, query: str, n_results: int = 50, **filters): def retrieve_symbols(self, query: str, n_results: int = 25): return self._symbols[:n_results] + def get_chunk(self, chunk_id: str): + return None + def _chunk(chunk_id: str, file_id: str, start_line: int = 1, end_line: int = 10): return { @@ -156,6 +159,56 @@ def test_empty_lexical_source_does_not_break_fusion(): assert [h["chunk_id"] for h in hits] == ["repo/a.py#L1-L10"] +def test_symbol_derived_chunk_carries_content(tmp_path): + """A symbol hit that promotes a chunk must surface the chunk's real content.""" + from docstra.core.ingestion.fts_storage import FtsStorage + from docstra.core.retrieval.fts import FtsRetriever + + fts_storage = FtsStorage(str(tmp_path / "index.db")) + fts_storage.add_chunks( + chunk_ids=["repo/a.py#L1-L10"], + file_ids=["repo/a.py"], + languages=["python"], + start_lines=[1], + end_lines=[10], + contents=["def foo():\n return 42\n"], + ) + fts = FtsRetriever(fts_storage) + + dense = _FakeDense([]) + + class _SymbolOnlyFts: + def __init__(self, real): + self._real = real + + def retrieve_chunks(self, *a, **kw): + return [] + + def retrieve_symbols(self, *a, **kw): + return [{"symbol_id": "repo/a.py::function::foo::L2", "file_id": "repo/a.py", "name": "foo", "kind": "function"}] + + def get_chunk(self, chunk_id): + return self._real.get_chunk(chunk_id) + + code_index = SimpleNamespace( + chunks_for_file=lambda fid: [("repo/a.py#L1-L10", 1, 10)] if fid == "repo/a.py" else [], + file_language=lambda fid: "python" if fid == "repo/a.py" else None, + ) + + fusion = FusionRetriever( + dense=dense, + fts=_SymbolOnlyFts(fts), + code_index=code_index, + rrf_k=60, + fts_chunks_top_k=10, + fts_symbols_top_k=10, + ) + hits = fusion.retrieve_chunks("foo", n_results=5) + assert len(hits) == 1 + assert "def foo()" in hits[0]["content"] + assert hits[0]["metadata"]["via_symbol"] == "foo" + + def test_retrieve_code_examples_prefers_function_chunks(): func_chunk = _chunk("repo/a.py#L1-L20", "repo/a.py") func_chunk["metadata"]["chunk_type"] = "function" From e42483293fe1817999e0427e9e4a5ef082d19fd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 13:06:17 +0200 Subject: [PATCH 14/20] fix(retrieval): accept and ignore unknown filter keys in FtsRetriever --- docstra/core/retrieval/fts.py | 1 + tests/test_fts_retriever.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/docstra/core/retrieval/fts.py b/docstra/core/retrieval/fts.py index 143154a..527e320 100644 --- a/docstra/core/retrieval/fts.py +++ b/docstra/core/retrieval/fts.py @@ -20,6 +20,7 @@ def retrieve_chunks( *, language: Optional[str] = None, file_id: Optional[str] = None, + **_unused_filters: Any, ) -> List[Dict[str, Any]]: return self.storage.search_chunks( query, n_results=n_results, language=language, file_id=file_id diff --git a/tests/test_fts_retriever.py b/tests/test_fts_retriever.py index 1934470..5e0df04 100644 --- a/tests/test_fts_retriever.py +++ b/tests/test_fts_retriever.py @@ -58,3 +58,20 @@ def test_get_chunk_returns_row_and_none(tmp_path: Path): assert row is not None assert row["chunk_id"] == "repo/a.py#L1-L5" assert retriever.get_chunk("missing#L1-L1") is None + + +def test_retrieve_chunks_ignores_unknown_filter_keys(tmp_path: Path): + """Unknown filter keys must not crash; FtsRetriever silently ignores them like Chroma.""" + store = FtsStorage(str(tmp_path / "index.db")) + store.add_chunks( + chunk_ids=["a.py#L1-L1"], + file_ids=["a.py"], + languages=["python"], + start_lines=[1], + end_lines=[1], + contents=["def foo(): pass"], + ) + retriever = FtsRetriever(store) + # document_id is a Chroma-style key not supported by FTS — must not raise. + hits = retriever.retrieve_chunks("foo", n_results=5, document_id="ignored") + assert len(hits) == 1 From 7a1b8e6c5e3eee0ab69dad479356d4b0d634528e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 13:13:34 +0200 Subject: [PATCH 15/20] fix(fts): allow FtsStorage connection across threads --- docstra/core/ingestion/fts_storage.py | 28 +++++++++++++++++++-------- tests/test_fts_storage.py | 13 +++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py index 0a384a0..075a121 100644 --- a/docstra/core/ingestion/fts_storage.py +++ b/docstra/core/ingestion/fts_storage.py @@ -5,6 +5,7 @@ import os import re import sqlite3 +import threading from typing import Any, Dict, List, Optional, Sequence from docstra.core.indexing.model import IndexedSymbol @@ -73,12 +74,18 @@ def _sanitize_fts_query(query: str) -> str: class FtsStorage: - """SQLite store with FTS5 indexes for chunks and symbols.""" + """SQLite store with FTS5 indexes for chunks and symbols. + + The connection is shared across threads (check_same_thread=False) and all + statements run under an RLock; Python's sqlite3 module does not serialize + concurrent execute() calls on a shared connection on its own. + """ def __init__(self, db_path: str) -> None: self.db_path = db_path + self._lock = threading.RLock() os.makedirs(os.path.dirname(db_path) or ".", exist_ok=True) - self._conn = sqlite3.connect(db_path) + self._conn = sqlite3.connect(db_path, check_same_thread=False) self._conn.row_factory = sqlite3.Row self._migrate() @@ -107,7 +114,7 @@ def add_chunks( contents: Sequence[str], ) -> None: rows = list(zip(chunk_ids, file_ids, languages, start_lines, end_lines, contents)) - with self._conn: + with self._lock, self._conn: self._conn.executemany( """ INSERT INTO chunks (chunk_id, file_id, language, start_line, end_line, content) @@ -123,7 +130,7 @@ def add_chunks( ) def delete_by_file(self, file_id: str) -> None: - with self._conn: + with self._lock, self._conn: self._conn.execute("DELETE FROM chunks WHERE file_id = ?", (file_id,)) self._conn.execute("DELETE FROM symbols_fts WHERE file_id = ?", (file_id,)) @@ -153,7 +160,9 @@ def search_chunks( """ params.append(n_results) results = [] - for row in self._conn.execute(sql, params).fetchall(): + with self._lock: + rows = self._conn.execute(sql, params).fetchall() + for row in rows: results.append({ "id": row["chunk_id"], "chunk_id": row["chunk_id"], @@ -180,7 +189,8 @@ def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: "SELECT chunk_id, file_id, language, start_line, end_line, content " "FROM chunks WHERE chunk_id = ? LIMIT 1" ) - row = self._conn.execute(sql, (chunk_id,)).fetchone() + with self._lock: + row = self._conn.execute(sql, (chunk_id,)).fetchone() return dict(row) if row else None # --- symbols --- @@ -189,7 +199,7 @@ def add_symbols(self, symbols: List[IndexedSymbol]) -> None: rows = [ (symbol.id, symbol.file_id, symbol.kind, symbol.name) for symbol in symbols ] - with self._conn: + with self._lock, self._conn: self._conn.executemany( "INSERT INTO symbols_fts (symbol_id, file_id, kind, name) VALUES (?, ?, ?, ?)", rows, @@ -207,7 +217,9 @@ def search_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any] LIMIT ? """ results = [] - for row in self._conn.execute(sql, (match_query, n_results)).fetchall(): + with self._lock: + rows = self._conn.execute(sql, (match_query, n_results)).fetchall() + for row in rows: results.append({ "id": row["symbol_id"], "symbol_id": row["symbol_id"], diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py index a9579f1..912378e 100644 --- a/tests/test_fts_storage.py +++ b/tests/test_fts_storage.py @@ -160,3 +160,16 @@ def test_query_with_reserved_words_does_not_raise(tmp_path: Path): for query in ("NOT null handling", "OR operator", "AND something", "foo AND"): # Must not raise; result can be empty or have hits — we only care about no crash. store.search_chunks(query, n_results=5) + + +def test_search_chunks_safe_across_threads(tmp_path: Path) -> None: + """The FtsStorage connection must be usable from worker threads.""" + from concurrent.futures import ThreadPoolExecutor + + store = FtsStorage(str(tmp_path / "index.db")) + _add_chunk(store) + + with ThreadPoolExecutor(max_workers=4) as pool: + results = list(pool.map(lambda _: store.search_chunks("make_chunk_id", 5), range(8))) + + assert all(len(hits) == 1 for hits in results) From 08cca5feb4396d000f9dbc6860efce00690b900c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 13:37:46 +0200 Subject: [PATCH 16/20] fix(fts): make add_symbols idempotent by deleting touched file_ids first --- docstra/core/ingestion/fts_storage.py | 12 ++++++--- tests/test_fts_storage.py | 38 +++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py index 075a121..13893fe 100644 --- a/docstra/core/ingestion/fts_storage.py +++ b/docstra/core/ingestion/fts_storage.py @@ -196,10 +196,16 @@ def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: # --- symbols --- def add_symbols(self, symbols: List[IndexedSymbol]) -> None: - rows = [ - (symbol.id, symbol.file_id, symbol.kind, symbol.name) for symbol in symbols - ] + if not symbols: + return + file_ids = sorted({symbol.file_id for symbol in symbols}) + rows = [(s.id, s.file_id, s.kind, s.name) for s in symbols] with self._lock, self._conn: + placeholders = ",".join("?" * len(file_ids)) + self._conn.execute( + f"DELETE FROM symbols_fts WHERE file_id IN ({placeholders})", + file_ids, + ) self._conn.executemany( "INSERT INTO symbols_fts (symbol_id, file_id, kind, name) VALUES (?, ?, ?, ?)", rows, diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py index 912378e..2adfc70 100644 --- a/tests/test_fts_storage.py +++ b/tests/test_fts_storage.py @@ -173,3 +173,41 @@ def test_search_chunks_safe_across_threads(tmp_path: Path) -> None: results = list(pool.map(lambda _: store.search_chunks("make_chunk_id", 5), range(8))) assert all(len(hits) == 1 for hits in results) + + +def test_add_symbols_is_idempotent(tmp_path: Path): + """Calling add_symbols twice with the same input must not duplicate rows.""" + from docstra.core.indexing.model import IndexedSymbol + + store = FtsStorage(str(tmp_path / "index.db")) + symbols = [ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ), + IndexedSymbol( + id="x.py::function::bar::L10", + file_id="x.py", + name="bar", + kind="function", + language="python", + line=10, + ), + ] + store.add_symbols(symbols) + store.add_symbols(symbols) # second call must not duplicate + + foo_hits = store.search_symbols("foo", n_results=10) + bar_hits = store.search_symbols("bar", n_results=10) + assert len(foo_hits) == 1 + assert len(bar_hits) == 1 + + +def test_add_symbols_empty_list_is_no_op(tmp_path: Path): + """An empty symbol list should be a clean no-op (no IN () syntax error).""" + store = FtsStorage(str(tmp_path / "index.db")) + store.add_symbols([]) # must not raise From 3876ca5f99e8b8a838ebb15c264d4aeb6b6fa51e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 13:39:19 +0200 Subject: [PATCH 17/20] perf(indexing): cache chunks-by-file lookup in CodebaseIndex --- docstra/core/indexing/code_index.py | 17 ++++++++++------- tests/test_core_index.py | 2 ++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docstra/core/indexing/code_index.py b/docstra/core/indexing/code_index.py index 4d033cd..722ae7a 100644 --- a/docstra/core/indexing/code_index.py +++ b/docstra/core/indexing/code_index.py @@ -55,6 +55,7 @@ def __init__( self._functions_by_name: Dict[str, List[IndexedSymbol]] = defaultdict(list) self._classes_by_name: Dict[str, List[IndexedSymbol]] = defaultdict(list) self._symbols_by_file: Dict[str, List[IndexedSymbol]] = defaultdict(list) + self._chunks_by_file: Dict[str, List[Tuple[str, int, int]]] = defaultdict(list) self._imports_by_source: Dict[str, List[ImportRecord]] = defaultdict(list) self._imports_by_text: Dict[str, List[str]] = defaultdict(list) self._dependencies_by_source: Dict[str, List[str]] = defaultdict(list) @@ -110,6 +111,7 @@ def _rebuild_lookups(self) -> None: self._functions_by_name = defaultdict(list) self._classes_by_name = defaultdict(list) self._symbols_by_file = defaultdict(list) + self._chunks_by_file = defaultdict(list) self._imports_by_source = defaultdict(list) self._imports_by_text = defaultdict(list) self._dependencies_by_source = defaultdict(list) @@ -123,6 +125,13 @@ def _rebuild_lookups(self) -> None: elif symbol.kind == "class": self._classes_by_name[symbol.name].append(symbol) + for chunk in self._manifest.chunks: + self._chunks_by_file[chunk.file_id].append( + (chunk.id, chunk.start_line, chunk.end_line) + ) + for chunks in self._chunks_by_file.values(): + chunks.sort(key=lambda item: item[1]) + for import_record in self._manifest.imports: self._imports_by_source[import_record.source_file_id].append(import_record) self._imports_by_text[import_record.raw_text].append( @@ -457,13 +466,7 @@ def get_related_files(self, filepath: str) -> List[str]: def chunks_for_file(self, file_id: str) -> List[Tuple[str, int, int]]: """Return (chunk_id, start_line, end_line) tuples for a file in line order.""" - matching = [ - (chunk.id, chunk.start_line, chunk.end_line) - for chunk in self._manifest.chunks - if chunk.file_id == file_id - ] - matching.sort(key=lambda tup: tup[1]) - return matching + return list(self._chunks_by_file.get(file_id, [])) def file_language(self, file_id: str) -> Optional[str]: """Return the language recorded in the manifest for a file id, if any.""" diff --git a/tests/test_core_index.py b/tests/test_core_index.py index 3cc4bc1..e75a76b 100644 --- a/tests/test_core_index.py +++ b/tests/test_core_index.py @@ -495,6 +495,7 @@ def test_chunks_for_file_returns_chunks_in_line_order(tmp_path): IndexedChunk(id="b.py#L1-L5", file_id="b.py", language="python", start_line=1, end_line=5, chunk_type="code"), ]) + index._rebuild_lookups() assert index.chunks_for_file("a.py") == [("a.py#L1-L10", 1, 10), ("a.py#L11-L20", 11, 20)] assert index.chunks_for_file("missing.py") == [] @@ -508,6 +509,7 @@ def test_chunks_for_file_returns_chunks_in_line_order(tmp_path): IndexedChunk(id="c.py#L1-L10", file_id="c.py", language="python", start_line=1, end_line=10, chunk_type="code"), ]) + index2._rebuild_lookups() result = index2.chunks_for_file("c.py") assert result == [("c.py#L1-L10", 1, 10), ("c.py#L100-L110", 100, 110)] From c64ef35073ee3218d97f36f996f26bd8d6bc4b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 13:41:29 +0200 Subject: [PATCH 18/20] fix(retrieval): drop incompatible score subtraction in retrieve_code_examples --- docstra/core/retrieval/fusion.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/docstra/core/retrieval/fusion.py b/docstra/core/retrieval/fusion.py index b34b078..abe8992 100644 --- a/docstra/core/retrieval/fusion.py +++ b/docstra/core/retrieval/fusion.py @@ -149,25 +149,17 @@ def retrieve_code_examples( example_score += 0.3 break - # Use original vector score - vector_score = chunk.get("score", 0) - if vector_score is not None: - # Combine scores (vector score is typically a distance, so lower is better) - combined_score = example_score - vector_score - else: - combined_score = example_score - chunk_id = self._chunk_id(chunk) - good_examples.append( - { - "chunk_id": chunk_id, - "id": chunk_id, - "content": content, - "metadata": chunk["metadata"], - "score": combined_score, - "original_score": vector_score, - } - ) + if example_score > 0: + good_examples.append( + { + "chunk_id": chunk_id, + "id": chunk_id, + "content": content, + "metadata": chunk["metadata"], + "score": example_score, + } + ) # Sort by combined score and return top results sorted_examples = sorted( From 99cb690e05b4a6db9eb6ec3865db5ec8f68b2bdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 13:44:14 +0200 Subject: [PATCH 19/20] refactor(retrieval): require FusionRetriever in ContextAwareRetriever --- docstra/core/retrieval/context_aware.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/docstra/core/retrieval/context_aware.py b/docstra/core/retrieval/context_aware.py index 64460f4..4fadf2a 100644 --- a/docstra/core/retrieval/context_aware.py +++ b/docstra/core/retrieval/context_aware.py @@ -10,7 +10,6 @@ from docstra.core.indexing.code_index import CodebaseIndex from docstra.core.indexing.repo_map import RepositoryMap -from docstra.core.retrieval.chroma import ChromaRetriever from docstra.core.retrieval.fusion import FusionRetriever from docstra.core.utils.token_counter import ContextBudgetManager @@ -63,7 +62,7 @@ class ContextAwareRetriever: def __init__( self, - base_retriever: ChromaRetriever, + base_retriever: FusionRetriever, budget_manager: ContextBudgetManager, code_index: Optional[CodebaseIndex] = None, repo_map: Optional[RepositoryMap] = None, @@ -72,12 +71,7 @@ def __init__( self.budget_manager = budget_manager self.code_index = code_index self.repo_map = repo_map - - # Accept a FusionRetriever if one was passed in (normal path via query_service) - if hasattr(base_retriever, "retrieve_code_examples"): - self.fusion_retriever = base_retriever - else: - self.fusion_retriever = None + self.fusion_retriever = base_retriever def retrieve_with_budget( self, query: str, context_type: str = "query", **kwargs: Any @@ -383,14 +377,10 @@ def _get_general_context( ) -> Dict[str, Any]: """Get balanced general context for queries without clear intent.""" - # Use fusion retrieval if available, otherwise fall back to base retriever - if self.fusion_retriever: - results = self.fusion_retriever.retrieve( - query=query, - n_results=10, # Get more initially for filtering - ) - else: - results = self.base_retriever.retrieve_chunks(query=query, n_results=8) + results = self.fusion_retriever.retrieve( + query=query, + n_results=10, # Get more initially for filtering + ) # Budget-aware context assembly context_parts = {} @@ -570,9 +560,6 @@ def _get_targeted_code_samples( ) -> Optional[str]: """Get targeted code samples based on query analysis.""" - if not self.fusion_retriever: - return None - # Get code examples using hybrid retrieval try: examples = self.fusion_retriever.retrieve_code_examples( From bce191397098cd99807e660d762cd3afb5b40087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Andresen=20Osberg?= Date: Mon, 15 Jun 2026 16:39:19 +0200 Subject: [PATCH 20/20] chore: fix lint, format, and type errors --- docstra/core/__init__.py | 4 +- docstra/core/cli.py | 4 +- docstra/core/documentation/generator.py | 27 ++++---- docstra/core/ingestion/fts_storage.py | 80 +++++++++++++---------- docstra/core/retrieval/fusion.py | 63 ++++++++++++------- docstra/core/services/query_service.py | 12 +++- tests/test_core_index.py | 67 +++++++++++++++----- tests/test_fts_retriever.py | 22 ++++--- tests/test_fts_storage.py | 84 +++++++++++++++---------- tests/test_fusion_retriever.py | 52 ++++++++++++--- 10 files changed, 273 insertions(+), 142 deletions(-) diff --git a/docstra/core/__init__.py b/docstra/core/__init__.py index fcc9021..d4a8026 100644 --- a/docstra/core/__init__.py +++ b/docstra/core/__init__.py @@ -305,9 +305,7 @@ def answer_question(self, question: str, n_results: int = 5) -> str: Generated answer """ # Retrieve relevant chunks - results = self.fusion_retriever.retrieve( - query=question, n_results=n_results - ) + results = self.fusion_retriever.retrieve(query=question, n_results=n_results) # Generate answer return self._require_text_response( diff --git a/docstra/core/cli.py b/docstra/core/cli.py index 060ab26..c22a02a 100644 --- a/docstra/core/cli.py +++ b/docstra/core/cli.py @@ -1691,7 +1691,9 @@ def _get_persist_paths( def _create_retrieval_eval_runner( user_config: UserConfig, abs_codebase_path: Path ) -> Callable[[str, int], List[Dict[str, Any]]]: - effective_persist_dir, chroma_path, index_path = _get_persist_paths(user_config, abs_codebase_path) + effective_persist_dir, chroma_path, index_path = _get_persist_paths( + user_config, abs_codebase_path + ) core_index_path = index_path / CORE_INDEX_FILENAME chroma_check_file = chroma_path / "chroma.sqlite3" legacy_index_artifacts = CodebaseIndex.legacy_artifacts_in(index_path) diff --git a/docstra/core/documentation/generator.py b/docstra/core/documentation/generator.py index 24fe794..13ab5c5 100644 --- a/docstra/core/documentation/generator.py +++ b/docstra/core/documentation/generator.py @@ -156,18 +156,21 @@ def __init__( if self.chroma_retriever and self.code_index and persist_directory: fts_storage = FtsStorage(str(Path(persist_directory) / "index.db")) fts_retriever = FtsRetriever(fts_storage) - kwargs = { - "dense": self.chroma_retriever, - "fts": fts_retriever, - "code_index": self.code_index, - } - if user_config and hasattr(user_config, 'retrieval'): - kwargs.update({ - "rrf_k": user_config.retrieval.rrf_k, - "fts_chunks_top_k": user_config.retrieval.fts_chunks_top_k, - "fts_symbols_top_k": user_config.retrieval.fts_symbols_top_k, - }) - self.fusion_retriever = FusionRetriever(**kwargs) + if user_config and hasattr(user_config, "retrieval"): + self.fusion_retriever = FusionRetriever( + dense=self.chroma_retriever, + fts=fts_retriever, + code_index=self.code_index, + rrf_k=user_config.retrieval.rrf_k, + fts_chunks_top_k=user_config.retrieval.fts_chunks_top_k, + fts_symbols_top_k=user_config.retrieval.fts_symbols_top_k, + ) + else: + self.fusion_retriever = FusionRetriever( + dense=self.chroma_retriever, + fts=fts_retriever, + code_index=self.code_index, + ) # Documentation state self.processed_documents: Dict[str, Document] = {} diff --git a/docstra/core/ingestion/fts_storage.py b/docstra/core/ingestion/fts_storage.py index 13893fe..28e78ab 100644 --- a/docstra/core/ingestion/fts_storage.py +++ b/docstra/core/ingestion/fts_storage.py @@ -25,6 +25,7 @@ def _sanitize_fts_query(query: str) -> str: tokens = _FTS_TOKEN_RE.findall(query) return " ".join(tokens).lower() + SCHEMA_VERSION = 1 _SCHEMA = """ @@ -92,7 +93,9 @@ def __init__(self, db_path: str) -> None: def _migrate(self) -> None: with self._conn: self._conn.executescript(_SCHEMA) - current = self._conn.execute("SELECT version FROM schema_version").fetchone() + current = self._conn.execute( + "SELECT version FROM schema_version" + ).fetchone() if current is None: self._conn.execute( "INSERT INTO schema_version (version) VALUES (?)", (SCHEMA_VERSION,) @@ -113,7 +116,9 @@ def add_chunks( end_lines: Sequence[int], contents: Sequence[str], ) -> None: - rows = list(zip(chunk_ids, file_ids, languages, start_lines, end_lines, contents)) + rows = list( + zip(chunk_ids, file_ids, languages, start_lines, end_lines, contents) + ) with self._lock, self._conn: self._conn.executemany( """ @@ -135,7 +140,12 @@ def delete_by_file(self, file_id: str) -> None: self._conn.execute("DELETE FROM symbols_fts WHERE file_id = ?", (file_id,)) def search_chunks( - self, query: str, n_results: int = 50, *, language: Optional[str] = None, file_id: Optional[str] = None + self, + query: str, + n_results: int = 50, + *, + language: Optional[str] = None, + file_id: Optional[str] = None, ) -> List[Dict[str, Any]]: match_query = _sanitize_fts_query(query) if not match_query: @@ -154,7 +164,7 @@ def search_chunks( -bm25(chunks_fts) AS score FROM chunks_fts JOIN chunks ON chunks.rowid = chunks_fts.rowid - WHERE {' AND '.join(clauses)} + WHERE {" AND ".join(clauses)} ORDER BY score DESC LIMIT ? """ @@ -163,24 +173,26 @@ def search_chunks( with self._lock: rows = self._conn.execute(sql, params).fetchall() for row in rows: - results.append({ - "id": row["chunk_id"], - "chunk_id": row["chunk_id"], - "file_id": row["file_id"], - "language": row["language"], - "start_line": row["start_line"], - "end_line": row["end_line"], - "content": row["content"], - "score": row["score"], - "metadata": { - "document_id": row["file_id"], - "filepath": row["file_id"], + results.append( + { + "id": row["chunk_id"], + "chunk_id": row["chunk_id"], + "file_id": row["file_id"], + "language": row["language"], "start_line": row["start_line"], "end_line": row["end_line"], - "language": row["language"], - "chunk_type": "code", - }, - }) + "content": row["content"], + "score": row["score"], + "metadata": { + "document_id": row["file_id"], + "filepath": row["file_id"], + "start_line": row["start_line"], + "end_line": row["end_line"], + "language": row["language"], + "chunk_type": "code", + }, + } + ) return results def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: @@ -226,18 +238,20 @@ def search_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any] with self._lock: rows = self._conn.execute(sql, (match_query, n_results)).fetchall() for row in rows: - results.append({ - "id": row["symbol_id"], - "symbol_id": row["symbol_id"], - "file_id": row["file_id"], - "kind": row["kind"], - "name": row["name"], - "score": row["score"], - "metadata": { - "document_id": row["file_id"], - "filepath": row["file_id"], - "name": row["name"], + results.append( + { + "id": row["symbol_id"], + "symbol_id": row["symbol_id"], + "file_id": row["file_id"], "kind": row["kind"], - }, - }) + "name": row["name"], + "score": row["score"], + "metadata": { + "document_id": row["file_id"], + "filepath": row["file_id"], + "name": row["name"], + "kind": row["kind"], + }, + } + ) return results diff --git a/docstra/core/retrieval/fusion.py b/docstra/core/retrieval/fusion.py index abe8992..f898a8a 100644 --- a/docstra/core/retrieval/fusion.py +++ b/docstra/core/retrieval/fusion.py @@ -14,12 +14,18 @@ def rrf_score(rank: int, k: int) -> float: class _DenseLike(Protocol): - def retrieve_chunks(self, query: str, n_results: int = 20, **filters) -> List[Dict[str, Any]]: ... + def retrieve_chunks( + self, query: str, n_results: int = 20, **filters + ) -> List[Dict[str, Any]]: ... class _FtsLike(Protocol): - def retrieve_chunks(self, query: str, n_results: int = 50, **filters) -> List[Dict[str, Any]]: ... - def retrieve_symbols(self, query: str, n_results: int = 25) -> List[Dict[str, Any]]: ... + def retrieve_chunks( + self, query: str, n_results: int = 50, **filters + ) -> List[Dict[str, Any]]: ... + def retrieve_symbols( + self, query: str, n_results: int = 25 + ) -> List[Dict[str, Any]]: ... def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]: ... @@ -43,13 +49,17 @@ def __init__( self.fts_chunks_top_k = fts_chunks_top_k self.fts_symbols_top_k = fts_symbols_top_k - def retrieve(self, query: str, n_results: int = 20, **filters) -> List[Dict[str, Any]]: + def retrieve( + self, query: str, n_results: int = 20, **filters + ) -> List[Dict[str, Any]]: return self.retrieve_chunks(query, n_results=n_results, **filters) def retrieve_chunks( self, query: str, n_results: int = 20, **filters ) -> List[Dict[str, Any]]: - dense_hits = self.dense.retrieve_chunks(query, n_results=n_results * 2, **filters) + dense_hits = self.dense.retrieve_chunks( + query, n_results=n_results * 2, **filters + ) lex_chunk_hits = self.fts.retrieve_chunks( query, n_results=self.fts_chunks_top_k, **filters ) @@ -186,36 +196,43 @@ def _symbols_to_chunks( continue if file_id_filter is not None and file_id != file_id_filter: continue - if language is not None and self.code_index.file_language(file_id) != language: + if ( + language is not None + and self.code_index.file_language(file_id) != language + ): continue symbol_id = symbol_hit.get("symbol_id", "") line = _extract_line_from_symbol_id(symbol_id) if line is None: continue - for chunk_id, start_line, end_line in self.code_index.chunks_for_file(file_id): + for chunk_id, start_line, end_line in self.code_index.chunks_for_file( + file_id + ): if start_line <= line <= end_line and chunk_id not in seen: seen.add(chunk_id) chunk_row = self.fts.get_chunk(chunk_id) content = chunk_row["content"] if chunk_row else "" lang = chunk_row["language"] if chunk_row else "" - results.append({ - "chunk_id": chunk_id, - "id": chunk_id, - "file_id": file_id, - "language": lang, - "start_line": start_line, - "end_line": end_line, - "content": content, - "metadata": { - "document_id": file_id, - "filepath": file_id, + results.append( + { + "chunk_id": chunk_id, + "id": chunk_id, + "file_id": file_id, + "language": lang, "start_line": start_line, "end_line": end_line, - "language": lang, - "chunk_type": "code", - "via_symbol": symbol_hit.get("name"), - }, - }) + "content": content, + "metadata": { + "document_id": file_id, + "filepath": file_id, + "start_line": start_line, + "end_line": end_line, + "language": lang, + "chunk_type": "code", + "via_symbol": symbol_hit.get("name"), + }, + } + ) break return results diff --git a/docstra/core/services/query_service.py b/docstra/core/services/query_service.py index 7bdf012..47fc945 100644 --- a/docstra/core/services/query_service.py +++ b/docstra/core/services/query_service.py @@ -140,14 +140,22 @@ def _ensure_retrieval_components_initialized(self, abs_codebase_path: Path): legacy_repo_map = effective_persist_dir / "repo_map.json" index_db_path = effective_persist_dir / "index.db" - if not core_index_path.exists() or not chroma_check_file.exists() or not index_db_path.exists(): + if ( + not core_index_path.exists() + or not chroma_check_file.exists() + or not index_db_path.exists() + ): migration_hint = "" if legacy_index_artifacts or legacy_repo_map.exists(): migration_hint = ( " Legacy index artifacts were found. Rerun 'docstra ingest' " "to rebuild the index in the new format." ) - if not index_db_path.exists() and core_index_path.exists() and chroma_check_file.exists(): + if ( + not index_db_path.exists() + and core_index_path.exists() + and chroma_check_file.exists() + ): migration_hint += ( " The lexical index (.docstra/index.db) is missing — likely an older " "ingest. Rerun 'docstra ingest' to rebuild it." diff --git a/tests/test_core_index.py b/tests/test_core_index.py index e75a76b..8c3e728 100644 --- a/tests/test_core_index.py +++ b/tests/test_core_index.py @@ -487,28 +487,65 @@ def test_chunks_for_file_returns_chunks_in_line_order(tmp_path): index = CodebaseIndex(index_directory=str(tmp_path / "index")) index._manifest = CoreIndexManifest.empty() - index._manifest.chunks.extend([ - IndexedChunk(id="a.py#L1-L10", file_id="a.py", language="python", - start_line=1, end_line=10, chunk_type="code"), - IndexedChunk(id="a.py#L11-L20", file_id="a.py", language="python", - start_line=11, end_line=20, chunk_type="code"), - IndexedChunk(id="b.py#L1-L5", file_id="b.py", language="python", - start_line=1, end_line=5, chunk_type="code"), - ]) + index._manifest.chunks.extend( + [ + IndexedChunk( + id="a.py#L1-L10", + file_id="a.py", + language="python", + start_line=1, + end_line=10, + chunk_type="code", + ), + IndexedChunk( + id="a.py#L11-L20", + file_id="a.py", + language="python", + start_line=11, + end_line=20, + chunk_type="code", + ), + IndexedChunk( + id="b.py#L1-L5", + file_id="b.py", + language="python", + start_line=1, + end_line=5, + chunk_type="code", + ), + ] + ) index._rebuild_lookups() - assert index.chunks_for_file("a.py") == [("a.py#L1-L10", 1, 10), ("a.py#L11-L20", 11, 20)] + assert index.chunks_for_file("a.py") == [ + ("a.py#L1-L10", 1, 10), + ("a.py#L11-L20", 11, 20), + ] assert index.chunks_for_file("missing.py") == [] # Verify sorting is enforced even when chunks are inserted in reverse line order. index2 = CodebaseIndex(index_directory=str(tmp_path / "index2")) index2._manifest = CoreIndexManifest.empty() - index2._manifest.chunks.extend([ - IndexedChunk(id="c.py#L100-L110", file_id="c.py", language="python", - start_line=100, end_line=110, chunk_type="code"), - IndexedChunk(id="c.py#L1-L10", file_id="c.py", language="python", - start_line=1, end_line=10, chunk_type="code"), - ]) + index2._manifest.chunks.extend( + [ + IndexedChunk( + id="c.py#L100-L110", + file_id="c.py", + language="python", + start_line=100, + end_line=110, + chunk_type="code", + ), + IndexedChunk( + id="c.py#L1-L10", + file_id="c.py", + language="python", + start_line=1, + end_line=10, + chunk_type="code", + ), + ] + ) index2._rebuild_lookups() result = index2.chunks_for_file("c.py") diff --git a/tests/test_fts_retriever.py b/tests/test_fts_retriever.py index 5e0df04..e08ebc3 100644 --- a/tests/test_fts_retriever.py +++ b/tests/test_fts_retriever.py @@ -26,16 +26,18 @@ def test_retrieve_chunks_delegates_to_storage(tmp_path: Path): def test_retrieve_symbols_delegates_to_storage(tmp_path: Path): store = FtsStorage(str(tmp_path / "index.db")) - store.add_symbols([ - IndexedSymbol( - id="x.py::function::foo::L1", - file_id="x.py", - name="foo", - kind="function", - language="python", - line=1, - ) - ]) + store.add_symbols( + [ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ) + ] + ) retriever = FtsRetriever(store) hits = retriever.retrieve_symbols("foo", n_results=5) assert len(hits) == 1 diff --git a/tests/test_fts_storage.py b/tests/test_fts_storage.py index 2adfc70..8f25a52 100644 --- a/tests/test_fts_storage.py +++ b/tests/test_fts_storage.py @@ -2,8 +2,6 @@ from pathlib import Path -import pytest - from docstra.core.ingestion.fts_storage import FtsStorage, _sanitize_fts_query @@ -30,7 +28,7 @@ def _add_chunk(store: FtsStorage, **overrides): def test_schema_creates_on_first_open(tmp_path: Path): db_path = tmp_path / "index.db" - store = FtsStorage(str(db_path)) + FtsStorage(str(db_path)) assert db_path.exists() # Idempotent: opening again does not raise. FtsStorage(str(db_path)) @@ -59,8 +57,20 @@ def test_delete_by_file_removes_chunks_and_fts(tmp_path: Path): def test_search_supports_language_filter(tmp_path: Path): store = FtsStorage(str(tmp_path / "index.db")) - _add_chunk(store, chunk_id="a.py#L1-L1", file_id="a.py", language="python", content="foo bar") - _add_chunk(store, chunk_id="b.ts#L1-L1", file_id="b.ts", language="typescript", content="foo bar") + _add_chunk( + store, + chunk_id="a.py#L1-L1", + file_id="a.py", + language="python", + content="foo bar", + ) + _add_chunk( + store, + chunk_id="b.ts#L1-L1", + file_id="b.ts", + language="typescript", + content="foo bar", + ) hits = store.search_chunks("foo", n_results=5, language="python") assert {h["file_id"] for h in hits} == {"a.py"} @@ -69,24 +79,26 @@ def test_add_and_search_symbols(tmp_path: Path): from docstra.core.indexing.model import IndexedSymbol store = FtsStorage(str(tmp_path / "index.db")) - store.add_symbols([ - IndexedSymbol( - id="repo/file.py::function::make_chunk_id::L67", - file_id="repo/file.py", - name="make_chunk_id", - kind="function", - language="python", - line=67, - ), - IndexedSymbol( - id="repo/other.py::class::CoreIndexBuilder::L194", - file_id="repo/other.py", - name="CoreIndexBuilder", - kind="class", - language="python", - line=194, - ), - ]) + store.add_symbols( + [ + IndexedSymbol( + id="repo/file.py::function::make_chunk_id::L67", + file_id="repo/file.py", + name="make_chunk_id", + kind="function", + language="python", + line=67, + ), + IndexedSymbol( + id="repo/other.py::class::CoreIndexBuilder::L194", + file_id="repo/other.py", + name="CoreIndexBuilder", + kind="class", + language="python", + line=194, + ), + ] + ) hits = store.search_symbols("CoreIndexBuilder", n_results=5) assert len(hits) == 1 assert hits[0]["name"] == "CoreIndexBuilder" @@ -100,16 +112,18 @@ def test_delete_by_file_removes_symbols(tmp_path: Path): from docstra.core.indexing.model import IndexedSymbol store = FtsStorage(str(tmp_path / "index.db")) - store.add_symbols([ - IndexedSymbol( - id="x.py::function::foo::L1", - file_id="x.py", - name="foo", - kind="function", - language="python", - line=1, - ) - ]) + store.add_symbols( + [ + IndexedSymbol( + id="x.py::function::foo::L1", + file_id="x.py", + name="foo", + kind="function", + language="python", + line=1, + ) + ] + ) store.delete_by_file("x.py") assert store.search_symbols("foo", n_results=5) == [] @@ -170,7 +184,9 @@ def test_search_chunks_safe_across_threads(tmp_path: Path) -> None: _add_chunk(store) with ThreadPoolExecutor(max_workers=4) as pool: - results = list(pool.map(lambda _: store.search_chunks("make_chunk_id", 5), range(8))) + results = list( + pool.map(lambda _: store.search_chunks("make_chunk_id", 5), range(8)) + ) assert all(len(hits) == 1 for hits in results) diff --git a/tests/test_fusion_retriever.py b/tests/test_fusion_retriever.py index b2cdf6d..cc015ef 100644 --- a/tests/test_fusion_retriever.py +++ b/tests/test_fusion_retriever.py @@ -40,7 +40,11 @@ def _chunk(chunk_id: str, file_id: str, start_line: int = 1, end_line: int = 10) "start_line": start_line, "end_line": end_line, "content": f"# chunk {chunk_id}", - "metadata": {"document_id": file_id, "start_line": start_line, "end_line": end_line}, + "metadata": { + "document_id": file_id, + "start_line": start_line, + "end_line": end_line, + }, } @@ -48,7 +52,9 @@ def _fake_code_index(chunks_by_file): def chunks_for_file(file_id): return chunks_by_file.get(file_id, []) - return SimpleNamespace(chunks_for_file=chunks_for_file, file_language=lambda fid: None) + return SimpleNamespace( + chunks_for_file=chunks_for_file, file_language=lambda fid: None + ) def test_rrf_score_known_inputs(): @@ -89,7 +95,14 @@ def test_symbol_hit_promotes_containing_chunk(): dense = _FakeDense([b, a]) fts = _FakeFts( chunks=[], - symbols=[{"symbol_id": "repo/a.py::function::foo::L5", "file_id": "repo/a.py", "name": "foo", "kind": "function"}], + symbols=[ + { + "symbol_id": "repo/a.py::function::foo::L5", + "file_id": "repo/a.py", + "name": "foo", + "kind": "function", + } + ], ) code_index = _fake_code_index({"repo/a.py": [("repo/a.py#L1-L10", 1, 10)]}) @@ -109,14 +122,23 @@ def test_symbol_hit_promotes_containing_chunk(): def test_symbol_path_respects_language_filter(): """Symbol-derived chunks must obey the language filter just like dense/lex chunks.""" py_chunk = _chunk("repo/a.py#L1-L10", "repo/a.py", start_line=1, end_line=10) - ts_chunk = _chunk("repo/b.ts#L1-L10", "repo/b.ts", start_line=1, end_line=10) dense = _FakeDense([py_chunk]) fts = _FakeFts( chunks=[], symbols=[ - {"symbol_id": "repo/a.py::function::foo::L5", "file_id": "repo/a.py", "name": "foo", "kind": "function"}, - {"symbol_id": "repo/b.ts::function::foo::L3", "file_id": "repo/b.ts", "name": "foo", "kind": "function"}, + { + "symbol_id": "repo/a.py::function::foo::L5", + "file_id": "repo/a.py", + "name": "foo", + "kind": "function", + }, + { + "symbol_id": "repo/b.ts::function::foo::L3", + "file_id": "repo/b.ts", + "name": "foo", + "kind": "function", + }, ], ) @@ -125,7 +147,10 @@ def test_symbol_path_respects_language_filter(): "repo/a.py": [("repo/a.py#L1-L10", 1, 10)], "repo/b.ts": [("repo/b.ts#L1-L10", 1, 10)], }.get(fid, []), - file_language=lambda fid: {"repo/a.py": "python", "repo/b.ts": "typescript"}.get(fid), + file_language=lambda fid: { + "repo/a.py": "python", + "repo/b.ts": "typescript", + }.get(fid), ) fusion = FusionRetriever( @@ -185,13 +210,22 @@ def retrieve_chunks(self, *a, **kw): return [] def retrieve_symbols(self, *a, **kw): - return [{"symbol_id": "repo/a.py::function::foo::L2", "file_id": "repo/a.py", "name": "foo", "kind": "function"}] + return [ + { + "symbol_id": "repo/a.py::function::foo::L2", + "file_id": "repo/a.py", + "name": "foo", + "kind": "function", + } + ] def get_chunk(self, chunk_id): return self._real.get_chunk(chunk_id) code_index = SimpleNamespace( - chunks_for_file=lambda fid: [("repo/a.py#L1-L10", 1, 10)] if fid == "repo/a.py" else [], + chunks_for_file=lambda fid: ( + [("repo/a.py#L1-L10", 1, 10)] if fid == "repo/a.py" else [] + ), file_language=lambda fid: "python" if fid == "repo/a.py" else None, )