Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ The configuration of the server is done using environment variables:
| `QDRANT_URL` | URL of the Qdrant server | None |
| `QDRANT_API_KEY` | API key for the Qdrant server | None |
| `COLLECTION_NAME` | Name of the default collection to use. | None |
| `QDRANT_VECTOR_NAME` | Name of the vector to be used. | None |
| `QDRANT_LOCAL_PATH` | Path to the local Qdrant database (alternative to `QDRANT_URL`) | None |
| `EMBEDDING_PROVIDER` | Embedding provider to use (currently only "fastembed" is supported) | `fastembed` |
| `EMBEDDING_MODEL` | Name of the embedding model to use | `sentence-transformers/all-MiniLM-L6-v2` |
Expand All @@ -51,6 +52,9 @@ The configuration of the server is done using environment variables:

Note: You cannot provide both `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time.

> [!IMPORTANT]
> `QDRANT_VECTOR_NAME` will be used for new collections, and the unnamed default vector is used if it is not set. For existing collections the embedding model name is used if it already exists to ensure backward compatibility.

> [!IMPORTANT]
> Command-line arguments are not supported anymore! Please use environment variables for all configuration.

Expand Down
13 changes: 7 additions & 6 deletions src/mcp_server_qdrant/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ def __init__(
assert self.embedding_provider is not None, "Embedding provider is required"

self.qdrant_connector = QdrantConnector(
qdrant_settings.location,
qdrant_settings.api_key,
qdrant_settings.collection_name,
self.embedding_provider,
qdrant_settings.local_path,
make_indexes(qdrant_settings.filterable_fields_dict()),
qdrant_url=qdrant_settings.location,
qdrant_api_key=qdrant_settings.api_key,
collection_name=qdrant_settings.collection_name,
embedding_provider=self.embedding_provider,
qdrant_local_path=qdrant_settings.local_path,
vector_name=qdrant_settings.vector_name,
field_indexes=make_indexes(qdrant_settings.filterable_fields_dict()),
)

super().__init__(name=name, instructions=instructions, **settings)
Expand Down
41 changes: 30 additions & 11 deletions src/mcp_server_qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class QdrantConnector:
:param qdrant_api_key: The API key to use for the Qdrant server.
:param collection_name: The name of the default collection to use. If not provided, each tool will require
the collection name to be provided.
:param vector_name: The name of the vector to be used. If not provided, the default vector will be used.
:param embedding_provider: The embedding provider to use.
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
"""
Expand All @@ -40,6 +41,7 @@ def __init__(
qdrant_api_key: str | None,
collection_name: str | None,
embedding_provider: EmbeddingProvider,
vector_name: str | None = None,
qdrant_local_path: str | None = None,
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
):
Expand All @@ -51,6 +53,7 @@ def __init__(
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
)
self._field_indexes = field_indexes
self._vector_name = vector_name

async def get_collection_names(self) -> list[str]:
"""
Expand All @@ -75,16 +78,20 @@ async def store(self, entry: Entry, *, collection_name: str | None = None):
# ToDo: instead of embedding text explicitly, use `models.Document`,
# it should unlock usage of server-side inference.
embeddings = await self._embedding_provider.embed_documents([entry.content])
vector = (
embeddings[0]
if self._vector_name is None
else {self._vector_name: embeddings[0]}
)

# Add to Qdrant
vector_name = self._embedding_provider.get_vector_name()
payload = {"document": entry.content, METADATA_PATH: entry.metadata}
await self._client.upsert(
collection_name=collection_name,
points=[
models.PointStruct(
id=uuid.uuid4().hex,
vector={vector_name: embeddings[0]},
vector=vector,
payload=payload,
)
],
Expand Down Expand Up @@ -118,13 +125,12 @@ async def search(
# it should unlock usage of server-side inference.

query_vector = await self._embedding_provider.embed_query(query)
vector_name = self._embedding_provider.get_vector_name()

# Search in Qdrant
search_results = await self._client.query_points(
collection_name=collection_name,
query=query_vector,
using=vector_name,
using=self._vector_name,
limit=limit,
query_filter=query_filter,
)
Expand All @@ -146,17 +152,17 @@ async def _ensure_collection_exists(self, collection_name: str):
if not collection_exists:
# Create the collection with the appropriate vector size
vector_size = self._embedding_provider.get_vector_size()
vector_params = models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)

# Use the vector name as defined in the embedding provider
vector_name = self._embedding_provider.get_vector_name()
await self._client.create_collection(
collection_name=collection_name,
vectors_config={
vector_name: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
)
},
vectors_config=vector_params
if self._vector_name is None
else {self._vector_name: vector_params},
)

# Create payload indexes if configured
Expand All @@ -168,3 +174,16 @@ async def _ensure_collection_exists(self, collection_name: str):
field_name=field_name,
field_schema=field_type,
)
else:
points: list[models.ScoredPoint] = (
await self._client.query_points(
collection_name=collection_name, limit=1, with_vectors=True
)
).points
model_vector_name = self._embedding_provider.get_vector_name()
if (
len(points) > 0
and isinstance(points[0].vector, dict)
and model_vector_name in points[0].vector
):
self._vector_name = model_vector_name
1 change: 1 addition & 0 deletions src/mcp_server_qdrant/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class QdrantSettings(BaseSettings):
default=None, validation_alias="COLLECTION_NAME"
)
local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
vector_name: str | None = Field(default=None, validation_alias="QDRANT_VECTOR_NAME")
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")

Expand Down
Loading