Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions docs/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,78 @@ python manage.py update_vector_indexes
```

To skip the prompt, use the `--noinput` flag.

## Using event-stream (WagtailVectorIndexSSEConsumer)
Comment thread
Morsey187 marked this conversation as resolved.

`WagtailVectorIndexSSEConsumer` is an asynchronous HTTP consumer that uses Server-Sent Events (SSE) to stream responses to vector index queries in real time. Using the consumer requires an ASGI server ([uvicorn](https://pypi.org/project/uvicorn/), [Daphne](https://pypi.org/project/daphne/), etc.) along with [Django Channels](https://pypi.org/project/channels/).

You can configure channels using the [official guide](https://channels.readthedocs.io/en/3.x/installation.html). At a minimum, install the `channels` package and add it to `INSTALLED_APPS` in your settings file, and configure support for ASGI.

```python
Comment thread
Morsey187 marked this conversation as resolved.
# settings.py

INSTALLED_APPS = [
"channels",
# ...
]
```

Next, route requests to `WagtailVectorIndexSSEConsumer` in your ASGI application. A single consumer serves every vector index; the index to query is chosen per request with the `index` query parameter (for example `index=news.NewsPage`).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: I believe this is no longer true now we use a single consumer across all indexes?


!!! Note
Wrapping the consumer in `AuthMiddlewareStack` is required to provide user context to the consumer.


```python
# app_name/asgi.py
import os
Comment thread
Morsey187 marked this conversation as resolved.

from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.core.asgi import get_asgi_application
from django.urls import path, re_path
from wagtail_vector_index.consumers import WagtailVectorIndexSSEConsumer


# Point Django at the settings module, then build the Django ASGI application
# once so it can be reused as the fallback route below.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings.production")
django_asgi_app = get_asgi_application()


application = ProtocolTypeRouter(
    {
        "http": URLRouter(
            [
                # Streamed vector index queries are handled by the SSE consumer.
                path(
                    "chat-query-sse/",
                    AuthMiddlewareStack(WagtailVectorIndexSSEConsumer.as_asgi()),
                ),
                # Every other request goes to the Django application created
                # above, instead of constructing a second instance here.
                re_path(r"", django_asgi_app),
            ]
        ),
    }
)
```

You should now be able to query the consumer using the [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API. The snippet below is an example implementation:

```javascript
// Opens an event-stream for `query` against the vector index named `index`
// (for example "news.NewsPage") and logs each streamed message.
function chatQuery(query, index) {
  // URLSearchParams encodes both values, so characters such as "&", "#" or
  // spaces in the user's query cannot corrupt the query string. The consumer
  // reads the `query` and `index` parameters.
  const params = new URLSearchParams({ query, index });
  const es = new EventSource(`/chat-query-sse/?${params}`);

  es.onmessage = (e) => {
    console.log(e.data);
    // Do something
  };
  es.onerror = () => {
    // Ending an EventSource object from the server results in an error.
    // Close the EventSource here to prevent repeated requests.
    es.close();
  };
}
```

### Known issues

Asynchronous support in Django is fairly new, and `WagtailVectorIndexSSEConsumer` can't tell when a client disconnects from an event-stream. As a result, the server may keep processing a query after the client has gone, leaving zombie threads behind.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ dependencies = [
"aiohttp>=3.9.0b0; python_version >= '3.12'",
]
[project.optional-dependencies]
sse = [
"channels>=3.0.5",
]
numpy = [
"numpy>=1.26.0",
]
Expand Down
4 changes: 4 additions & 0 deletions src/wagtail_vector_index/ai_utils/backends/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def _get_llm_chat_model(self) -> llm.Model:
setattr(model, config_key, config_val)
return model

def can_stream(self) -> bool:
    """Report whether the configured chat model supports streaming.

    Returns:
        The `can_stream` attribute of the model returned by
        `_get_llm_chat_model()`.
    """
    return self._get_llm_chat_model().can_stream


class LLMEmbeddingBackend(BaseEmbeddingBackend[LLMEmbeddingBackendConfig]):
config: LLMEmbeddingBackendConfig
Expand Down
116 changes: 116 additions & 0 deletions src/wagtail_vector_index/consumers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
import logging

from channels.generic.http import AsyncHttpConsumer
from django import forms
from django.core.exceptions import ValidationError
from django.http import QueryDict

from .base import VectorIndexableType

# Obtain the module logger through getLogger so it is registered in the logging
# hierarchy and honours the project's logging configuration; instantiating
# logging.Logger directly creates a detached logger with no parent.
logger = logging.getLogger(__name__)


class WagtailVectorIndexQueryParamsForm(forms.Form):
    """Validates the query-string parameters of a vector index query.

    Fields:
        query: The search text (required, at most 255 characters).
        index: The key of the vector index to query (required, at most 255
            characters). It must be a key of `self.indexes`.
    """

    query = forms.CharField(required=True, max_length=255)
    index = forms.CharField(required=True, max_length=255)

    def __init__(self, *args, **kwargs):
        """Build the form and cache the available indexes on `self.indexes`."""
        super().__init__(*args, **kwargs)

        # Imported at call time rather than at module level
        # (presumably to avoid a circular import — confirm).
        from wagtail_vector_index.index import get_vector_indexes

        self.indexes = get_vector_indexes()

    def clean_index(self):
        """Return the submitted index key, rejecting keys not in `self.indexes`.

        Raises:
            forms.ValidationError: If the key is not a known index.
        """
        requested = self.cleaned_data["index"]
        if requested in self.indexes:
            return requested
        raise forms.ValidationError("Invalid index. Please choose a valid index.")


class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer):
    """
    A Django Channels consumer for handling Server-Sent Events (SSE) related to WagtailVectorIndex queries.

    Methods:
        handle: The main entry point for processing HTTP requests, including SSE connections.
        error_response: Sends a generic error event to the client.
        process_prompt: Processes the incoming prompt and sends SSE updates.

    Note:
        This consumer expects the following query parameters in the URL:
        - 'query': The search query.
        - 'index': The vector index to perform the query with.

    Example URL:
        "/chat-query-sse/?query=example&index=news.NewsPage"
    """

    async def handle(self, body: bytes) -> None:
        """
        Handles HTTP requests, sets up SSE headers, and processes prompts.

        The SSE headers are sent before the query parameters are validated, so
        any failure is reported to the client as an SSE error event (see
        `error_response`). The response is always finished with an empty body.

        Args:
            body: The request body. It is not read; the parameters come from
                the query string in `self.scope`.
        """
        # Send SSE headers
        await self.send_headers(
            headers=[
                (b"Cache-Control", b"no-cache"),
                (b"Content-Type", b"text/event-stream"),
                (b"Transfer-Encoding", b"chunked"),
            ]
        )

        try:
            query_string = self.scope["query_string"].decode("utf-8")
            query_dict = QueryDict(query_string)

            # Validate query parameters
            form = WagtailVectorIndexQueryParamsForm(query_dict)
            if not form.is_valid():
                # Ignore "TRY301 Abstract `raise` to an inner function"
                # So we can ensure the event-stream is closed and no other code is executed
                raise ValidationError("Invalid query parameters.")  # noqa: TRY301
            query = form.cleaned_data["query"]
            index = form.cleaned_data["index"]

            vector_index = form.indexes.get(index)

            if vector_index:
                await self.process_prompt(query, vector_index)

        except ValidationError:
            await self.error_response()

        except Exception:
            # Log through this module's logger rather than the root-level
            # `logging.exception`, so the record carries this module's name and
            # does not trigger the root logger's implicit basicConfig().
            logging.getLogger(__name__).exception(
                "Unexpected error in WagtailVectorIndexSSEConsumer"
            )
            await self.error_response()

        # Finish the response
        await self.send_body(b"")

    async def error_response(self) -> None:
        """Send a generic error message as an SSE event, leaving the stream open."""
        payload = "data: Error processing request, Please try again later. \n\n"
        await self.send_body(payload.encode("utf-8"), more_body=True)

    async def process_prompt(
        self, query: str, vector_index: VectorIndexableType
    ) -> None:
        """
        Processes the incoming prompt and sends SSE updates.

        Awaits `vector_index.aquery(query)` and forwards each chunk of the
        resulting response to the client as one SSE message.

        Args:
            query: The validated search query.
            vector_index: The index to query; it must provide `aquery`.
                NOTE(review): annotated as `VectorIndexableType` but used as an
                index object — confirm the intended type.

        Note:
            `asyncio.CancelledError` is caught and suppressed here, it is not
            propagated to the caller.
        """
        try:
            results = await vector_index.aquery(query)
            # NOTE(review): this is a plain synchronous `for`; if producing a
            # chunk blocks, the event loop is blocked too — confirm whether the
            # iteration needs offloading.
            for chunk in results.response:
                chunk = chunk.replace(
                    "\n", "<br/>"
                )  # Replace newlines with HTML line breaks to avoid issues with encoding.
                payload = f"data: {chunk}\n\n"  # Each message must be terminated using two newline characters.
                await self.send_body(payload.encode("utf-8"), more_body=True)
        except asyncio.CancelledError:
            # Handle disconnects if needed, can occur from a server restart.
            # Note: Django < 5 doesn't recognise client disconnects
            pass
61 changes: 61 additions & 0 deletions src/wagtail_vector_index/index/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import logging
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from typing import Generic

from asgiref.sync import sync_to_async
from channels.db import database_sync_to_async
from django.conf import settings
from llm.models import Response

from wagtail_vector_index.ai import get_chat_backend, get_embedding_backend
from wagtail_vector_index.backends import get_vector_backend

from ..ai_utils.backends.base import BaseChatBackend, BaseEmbeddingBackend
from ..base import Document, VectorIndexableType

# Obtain the module logger through getLogger so it is registered in the logging
# hierarchy and honours the project's logging configuration; instantiating
# logging.Logger directly creates a detached logger with no parent.
logger = logging.getLogger(__name__)


@dataclass
class QueryResponse(Generic[VectorIndexableType]):
Expand All @@ -21,6 +27,25 @@ class QueryResponse(Generic[VectorIndexableType]):
sources: Iterable[VectorIndexableType]


@dataclass
class AsyncQueryResponse(Generic[VectorIndexableType]):
    """Represents a response to the VectorIndex `aquery` method,
    including a response object so users can call its iterator
    and a list of sources that were used to generate the response.
    """

    # The whole response object is exposed (rather than its text) so callers
    # can iterate over the response chunks.
    # NOTE(review): open question from the PR author — should `QueryResponse`
    # also carry the whole response object (giving access to e.g. `json()`),
    # and should this field be renamed to `llm_response`?
    response: Response
    sources: Iterable[VectorIndexableType]


@database_sync_to_async
def get_metadata_from_documents_async(similar_documents):
    """Join the `metadata["content"]` of each document into one string.

    Args:
        similar_documents: An iterable of objects exposing a `metadata`
            mapping with a "content" key.

    Returns:
        The content values separated by newlines, in iteration order.
    """
    return "\n".join(doc.metadata["content"] for doc in similar_documents)


class VectorIndex(Generic[VectorIndexableType]):
"""Base class for a VectorIndex, representing some set of documents that can be queried"""

Expand Down Expand Up @@ -70,9 +95,45 @@ def query(
merged_context,
query,
]

response = self.chat_backend.chat(user_messages=user_messages)
return QueryResponse(response=response.text(), sources=sources)

async def aquery(self, query: str) -> AsyncQueryResponse[VectorIndexableType]:
    """
    Async version of the query method.

    Embeds the query, looks up similar documents, resolves them to source
    objects, and asks the chat backend to answer using the documents' content
    as context. The synchronous steps before the chat call run through
    `sync_to_async` so they do not execute on the event loop thread.

    Args:
        query: The question to answer.

    Returns:
        An `AsyncQueryResponse` holding the chat backend's response object and
        the de-duplicated sources.

    Raises:
        ValueError: If the embedding backend yields no embedding for the query.
    """
    if not self.chat_backend.can_stream():
        # NOTE(review): flagged by the PR author — callers of the async API may
        # expect streaming, and nothing else tells them the backend lacks it.
        # A warning is a placeholder; the right handling is still undecided.
        logger.warning("Chat backend does not support streaming")

    def _embed_query():
        # Convert StopIteration here: it does not propagate cleanly from a
        # worker thread back into a coroutine.
        try:
            return next(self.embedding_backend.embed([query]))
        except StopIteration as e:
            raise ValueError(
                "No embeddings were generated for the given query."
            ) from e

    def _find_similar_documents():
        # Materialise the result: it is iterated twice below (for the sources
        # and for the context), which a one-shot iterator would not survive.
        return list(self.backend_index.similarity_search(query_embedding))

    def _load_sources():
        # Resolve documents inside the worker thread as well; previously
        # `bulk_from_documents` was evaluated on the event loop thread before
        # `sync_to_async` was entered.
        return self._deduplicate_list(
            self.object_type.bulk_from_documents(similar_documents)
        )

    query_embedding = await sync_to_async(_embed_query)()
    similar_documents = await sync_to_async(_find_similar_documents)()
    sources = await sync_to_async(_load_sources)()

    merged_context = await get_metadata_from_documents_async(similar_documents)

    prompt = (
        getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None)
        or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer."
    )
    user_messages = [
        prompt,
        merged_context,
        query,
    ]

    # The response object is returned as-is so the caller can iterate over it.
    response = self.chat_backend.chat(user_messages=user_messages)
    return AsyncQueryResponse(response=response, sources=sources)

def similar(
self, object: VectorIndexableType, *, include_self: bool = False, limit: int = 5
) -> list[VectorIndexableType]:
Expand Down