-
Notifications
You must be signed in to change notification settings - Fork 17
Add support for SSE (WagtailVectorIndexSSEConsumer) and async querying #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f66025e
482f90f
982ad2d
940b1b0
24162bf
2deafa8
215b22a
e2441e1
6b370e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,3 +53,78 @@ python manage.py update_vector_indexes | |
| ``` | ||
|
|
||
| To skip the prompt, use the `--noinput` flag. | ||
|
|
||
| ## Using event-stream (WagtailVectorIndexSSEConsumer) | ||
|
|
||
| `WagtailVectorIndexSSEConsumer` is an asynchronous HTTP consumer designed for handling Server-Sent Events (SSE) for streaming responses from queries using the vector index in real-time. Using the consumer requires ASGI ([uvicorn](https://pypi.org/project/uvicorn/), [Daphne](https://pypi.org/project/daphne/) etc.) along with [django-channels](https://pypi.org/project/django-channels/). | ||
|
|
||
| You can configure channels using the [official guide](https://channels.readthedocs.io/en/3.x/installation.html). At a minimum, install the `channels` package and add it to `INSTALLED_APPS` in your settings file, and configure support for ASGI. | ||
|
|
||
| ```python | ||
|
Morsey187 marked this conversation as resolved.
|
||
| # settings.py | ||
|
|
||
| INSTALLED_APPS = [ | ||
| "channels", | ||
| # ... | ||
| ] | ||
| ``` | ||
|
|
||
| Next, you will need to define a new consumer inheriting from `WagtailVectorIndexSSEConsumer`, and assign a Wagtail page model for the vector index you'd like to use. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: I believe this is no longer true now we use a single consumer across all indexes? |
||
|
|
||
| !!! Note | ||
| The `AuthMiddleware` is required to provide user context to the consumer. | ||
|
|
||
|
|
||
| ```python | ||
| # app_name/asgi.py | ||
| import os | ||
|
Morsey187 marked this conversation as resolved.
|
||
|
|
||
| from channels.auth import AuthMiddlewareStack | ||
| from channels.routing import ProtocolTypeRouter, URLRouter | ||
| from django.core.asgi import get_asgi_application | ||
| from django.urls import path, re_path | ||
| from wagtail_vector_index.consumers import WagtailVectorIndexSSEConsumer | ||
|
|
||
|
|
||
| os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings.production") | ||
| django_asgi_app = get_asgi_application() | ||
|
|
||
|
|
||
| application = ProtocolTypeRouter( | ||
| { | ||
| "http": URLRouter( | ||
| [ | ||
| path( | ||
| "chat-query-sse/", | ||
| AuthMiddlewareStack(WagtailVectorIndexSSEConsumer.as_asgi()), | ||
| ), | ||
| re_path(r"", get_asgi_application()), | ||
| ] | ||
| ), | ||
| } | ||
| ) | ||
| ``` | ||
|
|
||
| You should now be able to query the consumer using the [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API. The snippet below is an example implementation: | ||
|
|
||
| ```javascript | ||
| function chatQuery(query, pageType) { | ||
| const es = new EventSource( | ||
| `/chat-query-sse/?query=${query}&page_type=${pageType}`, | ||
| ); | ||
|
|
||
| es.onmessage = (e) => { | ||
| console.log(e.data); | ||
| // Do something | ||
| }; | ||
| es.onerror = () => { | ||
| // Ending an EventSource object from the server results in an error. | ||
| // Close the EventSource here to prevent repeated requests. | ||
| es.close(); | ||
| }; | ||
| } | ||
| ``` | ||
|
|
||
| ### Known issues | ||
|
|
||
| Asynchronous support in Django is fairly new and `WagtailVectorIndexSSEConsumer` can't tell when a client disconnects from an event-stream. This may result in queries being processed by the server as zombie threads. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| import asyncio | ||
| import logging | ||
|
|
||
| from channels.generic.http import AsyncHttpConsumer | ||
| from django import forms | ||
| from django.core.exceptions import ValidationError | ||
| from django.http import QueryDict | ||
|
|
||
| from .base import VectorIndexableType | ||
|
|
||
| logger = logging.Logger(__name__) | ||
|
|
||
|
|
||
| class WagtailVectorIndexQueryParamsForm(forms.Form): | ||
| """Provides a form for validating query parameters.""" | ||
|
|
||
| query = forms.CharField(max_length=255, required=True) | ||
| index = forms.CharField(max_length=255, required=True) | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
|
|
||
| from wagtail_vector_index.index import get_vector_indexes | ||
|
|
||
| self.indexes = get_vector_indexes() | ||
|
|
||
| def clean_index(self): | ||
| index = self.cleaned_data["index"] | ||
| if index not in self.indexes: | ||
| raise forms.ValidationError("Invalid index. Please choose a valid index.") | ||
| return index | ||
|
|
||
|
|
||
| class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer): | ||
| """ | ||
| A Django Channels consumer for handling Server-Sent Events (SSE) related to WagtailVectorIndex queries. | ||
|
|
||
| Methods: | ||
| handle: The main entry point for processing HTTP requests, including SSE connections. | ||
| process_prompt: Processes the incoming prompt and sends SSE updates. | ||
|
|
||
| Note: | ||
| This consumer expects the following query parameters in the URL: | ||
| - 'query': The search query. | ||
| - 'index': The vector index to perform the query with. | ||
|
|
||
| Example URL: | ||
| "/chat-query-sse/?query=example&index=news.NewsPage" | ||
| """ | ||
|
|
||
| async def handle(self, body: bytes) -> None: | ||
| """ | ||
| Handles HTTP requests, sets up SSE headers, and processes prompts. | ||
| """ | ||
| # Send SSE headers | ||
| await self.send_headers( | ||
| headers=[ | ||
| (b"Cache-Control", b"no-cache"), | ||
| (b"Content-Type", b"text/event-stream"), | ||
| (b"Transfer-Encoding", b"chunked"), | ||
| ] | ||
| ) | ||
|
|
||
| try: | ||
| query_string = self.scope["query_string"].decode("utf-8") | ||
| query_dict = QueryDict(query_string) | ||
|
|
||
| # Validate query parameters | ||
| form = WagtailVectorIndexQueryParamsForm(query_dict) | ||
| if not form.is_valid(): | ||
| # Ignore "TRY301 Abstract `raise` to an inner function" | ||
| # So we can insure the event-stream is closed and no other code is executed | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick (non-blocking): insure -> ensure |
||
| raise ValidationError("Invalid query parameters.") # noqa: TRY301 | ||
| query = form.cleaned_data["query"] | ||
| index = form.cleaned_data["index"] | ||
|
|
||
| vector_index = form.indexes.get(index) | ||
|
|
||
| if vector_index: | ||
| await self.process_prompt(query, vector_index) | ||
|
|
||
| except ValidationError: | ||
| await self.error_response() | ||
|
|
||
| except Exception: | ||
| logging.exception("Unexpected error in WagtailVectorIndexSSEConsumer") | ||
| await self.error_response() | ||
|
|
||
| # Finish the response | ||
| await self.send_body(b"") | ||
|
|
||
| async def error_response(self) -> None: | ||
| payload = "data: Error processing request, Please try again later. \n\n" | ||
| await self.send_body(payload.encode("utf-8"), more_body=True) | ||
|
|
||
| async def process_prompt( | ||
| self, query: str, vector_index: VectorIndexableType | ||
| ) -> None: | ||
| """ | ||
| Processes the incoming prompt and sends SSE updates. | ||
|
|
||
| Raises: | ||
| asyncio.CancelledError: If the connection is cancelled or disconnected. | ||
| """ | ||
| try: | ||
| results = await vector_index.aquery(query) | ||
| for chunk in results.response: | ||
| chunk = chunk.replace( | ||
| "\n", "<br/>" | ||
| ) # Replace newlines with HTML line breaks to avoid issues with encoding. | ||
| payload = f"data: {chunk}\n\n" # Each message must be terminated using two newline characters. | ||
| await self.send_body(payload.encode("utf-8"), more_body=True) | ||
| except asyncio.CancelledError: | ||
| # Handle disconnects if needed, can occur from a server restart. | ||
| # Note: Django < 5 doesn't recognise client disconnects | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,15 +1,21 @@ | ||
| import logging | ||
| from collections.abc import Generator, Iterable | ||
| from dataclasses import dataclass | ||
| from typing import Generic | ||
|
|
||
| from asgiref.sync import sync_to_async | ||
| from channels.db import database_sync_to_async | ||
| from django.conf import settings | ||
| from llm.models import Response | ||
|
|
||
| from wagtail_vector_index.ai import get_chat_backend, get_embedding_backend | ||
| from wagtail_vector_index.backends import get_vector_backend | ||
|
|
||
| from ..ai_utils.backends.base import BaseChatBackend, BaseEmbeddingBackend | ||
| from ..base import Document, VectorIndexableType | ||
|
|
||
| logger = logging.Logger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class QueryResponse(Generic[VectorIndexableType]): | ||
|
|
@@ -21,6 +27,25 @@ class QueryResponse(Generic[VectorIndexableType]): | |
| sources: Iterable[VectorIndexableType] | ||
|
|
||
|
|
||
| @dataclass | ||
| class AsyncQueryResponse(Generic[VectorIndexableType]): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: I have to return the whole response object to access it's iterator (response chunks). I’d prefer to avoid this and instead update QueryResponse to return the whole response instead of a string too. Theres also the case that users might want to access the response object as a whole and for other attributes and methods like If so maybe we can rename response as "llm_response"? |
||
| """Represents a response to the VectorIndex `aquery` method, | ||
| including a response object so users can call it's iterator | ||
| and a list of sources that were used to generate the response | ||
| """ | ||
|
|
||
| response: Response | ||
| sources: Iterable[VectorIndexableType] | ||
|
|
||
|
|
||
| @database_sync_to_async | ||
| def get_metadata_from_documents_async(similar_documents): | ||
| metadata_list = [] | ||
| for doc in similar_documents: | ||
| metadata_list.append(doc.metadata["content"]) | ||
| return "\n".join(metadata_list) | ||
|
|
||
|
|
||
| class VectorIndex(Generic[VectorIndexableType]): | ||
| """Base class for a VectorIndex, representing some set of documents that can be queried""" | ||
|
|
||
|
|
@@ -70,9 +95,45 @@ def query( | |
| merged_context, | ||
| query, | ||
| ] | ||
|
|
||
| response = self.chat_backend.chat(user_messages=user_messages) | ||
| return QueryResponse(response=response.text(), sources=sources) | ||
|
|
||
| async def aquery(self, query: str) -> AsyncQueryResponse[VectorIndexableType]: | ||
| """ | ||
| Async version of the query method. | ||
| """ | ||
| if not self.chat_backend.can_stream(): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: I've added this here as I think there will likely be an issue with developers expecting streaming support when using async, especially since we're abstracting from the LLM package, so theres no obvious way for them to know without tracing errors back. I don't think this is the correct solution if it is a problem, but have added it here to highlight the potential issue and get some thoughts? |
||
| logger.warning("Chat backend does not support streaming") | ||
|
|
||
| try: | ||
| query_embedding = next(self.embedding_backend.embed([query])) | ||
| except StopIteration as e: | ||
| raise ValueError("No embeddings were generated for the given query.") from e | ||
|
|
||
| similar_documents = await sync_to_async(self.backend_index.similarity_search)( | ||
| query_embedding | ||
| ) | ||
|
|
||
| sources = await sync_to_async(self._deduplicate_list)( | ||
| self.object_type.bulk_from_documents(similar_documents) | ||
| ) | ||
|
|
||
| merged_context = await get_metadata_from_documents_async(similar_documents) | ||
|
|
||
| prompt = ( | ||
| getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None) | ||
| or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer." | ||
| ) | ||
| user_messages = [ | ||
| prompt, | ||
| merged_context, | ||
| query, | ||
| ] | ||
|
|
||
| response = self.chat_backend.chat(user_messages=user_messages) | ||
| return AsyncQueryResponse(response=response, sources=sources) | ||
|
|
||
| def similar( | ||
| self, object: VectorIndexableType, *, include_self: bool = False, limit: int = 5 | ||
| ) -> list[VectorIndexableType]: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.