diff --git a/README.md b/README.md index 5bb05f9a..643ac971 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ The configuration of the server is done using environment variables: | `QDRANT_URL` | URL of the Qdrant server | None | | `QDRANT_API_KEY` | API key for the Qdrant server | None | | `COLLECTION_NAME` | Name of the default collection to use. | None | +| `QDRANT_VECTOR_NAME` | Name of the vector to be used. | None | | `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None | | `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` | | `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` | @@ -51,6 +52,9 @@ The configuration of the server is done using environment variables: Note: You cannot provide both `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time. +> [!IMPORTANT] +> `QDRANT_VECTOR_NAME` is used as the vector name for new collections; if it is not set, the unnamed default vector is used. For existing collections, a vector named after the embedding model is used instead if such a vector already exists, to ensure backward compatibility. + > [!IMPORTANT] > Command-line arguments are not supported anymore! Please use environment variables for all configuration. 
diff --git a/src/mcp_server_qdrant/mcp_server.py b/src/mcp_server_qdrant/mcp_server.py index 0617b9d8..7044ff3a 100644 --- a/src/mcp_server_qdrant/mcp_server.py +++ b/src/mcp_server_qdrant/mcp_server.py @@ -66,12 +66,13 @@ def __init__( assert self.embedding_provider is not None, "Embedding provider is required" self.qdrant_connector = QdrantConnector( - qdrant_settings.location, - qdrant_settings.api_key, - qdrant_settings.collection_name, - self.embedding_provider, - qdrant_settings.local_path, - make_indexes(qdrant_settings.filterable_fields_dict()), + qdrant_url=qdrant_settings.location, + qdrant_api_key=qdrant_settings.api_key, + collection_name=qdrant_settings.collection_name, + embedding_provider=self.embedding_provider, + qdrant_local_path=qdrant_settings.local_path, + vector_name=qdrant_settings.vector_name, + field_indexes=make_indexes(qdrant_settings.filterable_fields_dict()), ) super().__init__(name=name, instructions=instructions, **settings) diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index 8d3e5aa8..16cabdd7 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -30,6 +30,7 @@ class QdrantConnector: :param qdrant_api_key: The API key to use for the Qdrant server. :param collection_name: The name of the default collection to use. If not provided, each tool will require the collection name to be provided. + :param vector_name: The name of the vector to be used. If not provided, the default vector will be used. :param embedding_provider: The embedding provider to use. :param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used. 
""" @@ -40,6 +41,7 @@ def __init__( qdrant_api_key: str | None, collection_name: str | None, embedding_provider: EmbeddingProvider, + vector_name: str | None = None, qdrant_local_path: str | None = None, field_indexes: dict[str, models.PayloadSchemaType] | None = None, ): @@ -51,6 +53,7 @@ def __init__( location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path ) self._field_indexes = field_indexes + self._vector_name = vector_name async def get_collection_names(self) -> list[str]: """ @@ -75,16 +78,20 @@ async def store(self, entry: Entry, *, collection_name: str | None = None): # ToDo: instead of embedding text explicitly, use `models.Document`, # it should unlock usage of server-side inference. embeddings = await self._embedding_provider.embed_documents([entry.content]) + vector = ( + embeddings[0] + if self._vector_name is None + else {self._vector_name: embeddings[0]} + ) # Add to Qdrant - vector_name = self._embedding_provider.get_vector_name() payload = {"document": entry.content, METADATA_PATH: entry.metadata} await self._client.upsert( collection_name=collection_name, points=[ models.PointStruct( id=uuid.uuid4().hex, - vector={vector_name: embeddings[0]}, + vector=vector, payload=payload, ) ], @@ -118,13 +125,12 @@ async def search( # it should unlock usage of server-side inference. 
query_vector = await self._embedding_provider.embed_query(query) - vector_name = self._embedding_provider.get_vector_name() # Search in Qdrant search_results = await self._client.query_points( collection_name=collection_name, query=query_vector, - using=vector_name, + using=self._vector_name, limit=limit, query_filter=query_filter, ) @@ -146,17 +152,17 @@ async def _ensure_collection_exists(self, collection_name: str): if not collection_exists: # Create the collection with the appropriate vector size vector_size = self._embedding_provider.get_vector_size() + vector_params = models.VectorParams( + size=vector_size, + distance=models.Distance.COSINE, + ) # Use the vector name as defined in the embedding provider - vector_name = self._embedding_provider.get_vector_name() await self._client.create_collection( collection_name=collection_name, - vectors_config={ - vector_name: models.VectorParams( - size=vector_size, - distance=models.Distance.COSINE, - ) - }, + vectors_config=vector_params + if self._vector_name is None + else {self._vector_name: vector_params}, ) # Create payload indexes if configured @@ -168,3 +174,16 @@ async def _ensure_collection_exists(self, collection_name: str): field_name=field_name, field_schema=field_type, ) + else: + points: list[models.ScoredPoint] = ( + await self._client.query_points( + collection_name=collection_name, limit=1, with_vectors=True + ) + ).points + model_vector_name = self._embedding_provider.get_vector_name() + if ( + len(points) > 0 + and isinstance(points[0].vector, dict) + and model_vector_name in points[0].vector + ): + self._vector_name = model_vector_name diff --git a/src/mcp_server_qdrant/settings.py b/src/mcp_server_qdrant/settings.py index e48c10d1..79ce5fc5 100644 --- a/src/mcp_server_qdrant/settings.py +++ b/src/mcp_server_qdrant/settings.py @@ -82,6 +82,7 @@ class QdrantSettings(BaseSettings): default=None, validation_alias="COLLECTION_NAME" ) local_path: str | None = Field(default=None, 
validation_alias="QDRANT_LOCAL_PATH") + vector_name: str | None = Field(default=None, validation_alias="QDRANT_VECTOR_NAME") search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT") read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")