diff --git a/code_review_graph/embeddings.py b/code_review_graph/embeddings.py index 468753a..7324c6c 100644 --- a/code_review_graph/embeddings.py +++ b/code_review_graph/embeddings.py @@ -62,7 +62,11 @@ def _get_model(self): if self._model is None: try: from sentence_transformers import SentenceTransformer - self._model = SentenceTransformer(self._model_name) + self._model = SentenceTransformer( + self._model_name, + trust_remote_code=True, + model_kwargs={"trust_remote_code": True}, + ) except ImportError: raise ImportError( "sentence-transformers not installed. " diff --git a/code_review_graph/search.py b/code_review_graph/search.py index 425f67f..d2eb84e 100644 --- a/code_review_graph/search.py +++ b/code_review_graph/search.py @@ -168,6 +168,7 @@ def _embedding_search( store: GraphStore, query: str, limit: int = 50, + model: str | None = None, ) -> list[tuple[int, float]]: """Run a vector similarity search using the embedding store. @@ -180,7 +181,7 @@ def _embedding_search( return [] try: - emb_store = EmbeddingStore(store.db_path) + emb_store = EmbeddingStore(store.db_path, model=model) try: if not emb_store.available or emb_store.count() == 0: return [] @@ -264,6 +265,7 @@ def hybrid_search( kind: Optional[str] = None, limit: int = 20, context_files: Optional[list[str]] = None, + model: Optional[str] = None, ) -> list[dict[str, Any]]: """Hybrid search combining FTS5 BM25 and vector embeddings via RRF. @@ -301,7 +303,7 @@ def hybrid_search( logger.warning("FTS5 unavailable, will use fallback: %s", e) # Try embedding search - emb_results = _embedding_search(store, query, limit=fetch_limit) + emb_results = _embedding_search(store, query, limit=fetch_limit, model=model) # ------ Phase 2: Merge via RRF or fallback ------ if fts_results or emb_results: diff --git a/code_review_graph/tools/query.py b/code_review_graph/tools/query.py index 0330d9a..6516783 100644 --- a/code_review_graph/tools/query.py +++ b/code_review_graph/tools/query.py @@ -321,6 +321,7 @@ def semantic_search_nodes( try: results = hybrid_search( store, query, kind=kind, limit=limit, context_files=context_files, + model=model, ) search_mode = "hybrid"