From f66025e7fb580046d29c52470e241db07d4924e0 Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Tue, 19 Dec 2023 15:30:08 +0000 Subject: [PATCH 1/9] Add WagtailVectorIndexSSEConsumer and async support --- docs/quick-start.md | 93 +++++++++++++++ pyproject.toml | 1 + src/wagtail_vector_index/consumers.py | 149 +++++++++++++++++++++++++ src/wagtail_vector_index/index/base.py | 42 +++++++ 4 files changed, 285 insertions(+) create mode 100644 src/wagtail_vector_index/consumers.py diff --git a/docs/quick-start.md b/docs/quick-start.md index 2b0e0db..cce665d 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -53,3 +53,96 @@ python manage.py update_vector_indexes ``` To skip the prompt, use the `--noinput` flag. + +## Using event-stream (WagtailVectorIndexSSEConsumer) + +The WagtailVectorIndexSSEConsumer is a asynchronous HTTP consumer designed for handling Server-Sent Events (SSE), for streaming responses from queries using the vector index in real-time. Using the consumer requires ASGI (uvicorn, daphne etc.) along with configuring django-channels. + +Configure channels via this [guide](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#configuration), otherwise, simply add channels to INSTALLED_APPS. + +```python +INSTALLED_APPS = [ + "channels", + # ... +] +``` + +and set the channel layer. The [in-memory channel layer](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#in-memory-channel-layer) can be used for local envrionments, however, redis is the only offically maintained channel layer supported for production use, thus [channels_redis](https://pypi.org/project/channels-redis/) will have to be install. + +``` +CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels_redis.core.RedisChannelLayer", + "CONFIG": { + "hosts": [("127.0.0.1", 6379)], + }, + }, +} +``` + +You will now need to create an new consumer inheriting WagtailVectorIndexSSEConsumer, but, assigning a page model for the vector index you'd like to use. 
+ + +*Note: AuthMiddleware is required to provide User context to the consumer. + +```python +import os + +from channels.auth import AuthMiddlewareStack +from channels.routing import ProtocolTypeRouter, URLRouter +from django.core.asgi import get_asgi_application +from django.urls import path, re_path +from wagtail_vector_index.consumers import WagtailVectorIndexSSEConsumer + + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings.production") +django_asgi_app = get_asgi_application() + +class WikiPageChatQuerySSEConsumer(WagtailVectorIndexSSEConsumer): + page_model_name = "wiki.WikiPage" + + +application = ProtocolTypeRouter( + { + "http": URLRouter( + [ + path( + "chat-query-sse//", + AuthMiddlewareStack(WikiPageChatQuerySSEConsumer.as_asgi()), + ), + re_path(r"", get_asgi_application()), + ] + ), + } +) +``` + +You should now be able to query the consumer via the EventSource API, you can use the snippet below as a reference. + +```javascript +import { v4 as uuidv4 } from 'uuid'; + + +function chatQuery(query) { + const searchParams = new URLSearchParams({ + query, + }); + const es = new EventSource( + `/chat-query-sse/${uuidv4()}/?query=${searchParams}`, + ); + + es.onmessage = (e) => { + console.log(e.data); + // Do something + }; + es.onerror = () => { + // Ending an EventSource object from the server results in an error. + // Close the EventSource here to prevent repeated requests. + es.close(); + }; +} +``` + +### Known issues + +Asynchronous support in Django is fairly new, as a result of this WagtailVectorIndexSSEConsumer can't tell when a client disconnects from an event-stream, This may result in queries being processed by the server as zombie threads. Support for handling [client disconnects](https://docs.djangoproject.com/en/dev/topics/async/#handling-disconnects) was added in Django 5.0, although currently untested in this package. 
diff --git a/pyproject.toml b/pyproject.toml index 33d6c51..45b93a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "Django>=4.2", "Wagtail>=5.2", "aiohttp>=3.9.0b0; python_version >= '3.12'", + "channels>=3.0.5", ] [project.optional-dependencies] numpy = [ diff --git a/src/wagtail_vector_index/consumers.py b/src/wagtail_vector_index/consumers.py new file mode 100644 index 0000000..f5fe392 --- /dev/null +++ b/src/wagtail_vector_index/consumers.py @@ -0,0 +1,149 @@ +import asyncio +import logging + +from django.apps import apps +from django.core.exceptions import PermissionDenied +from channels.generic.http import AsyncHttpConsumer + +logger = logging.Logger(__name__) + + +class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer): + """ + A Django Channels consumer for handling Server-Sent Events (SSE) related to WagtailVectorIndex queries. + + Attributes: + page_instance (Model): The Wagtail page model instance for which the vector index is queried. + page_model_name (str): The name of the Wagtail page model for which the vector index is queried. + vector_index (VectorIndex): The vector index associated with the Wagtail page model. + + Methods: + handle: The main entry point for processing HTTP requests, including SSE connections. + process_prompt: Processes the incoming prompt and sends SSE updates. + check_permissions: Checks user authentication and raises PermissionDenied if not authenticated. + ratelimit_request: Placeholder for implementing rate-limiting logic. + + Usage: + - Inherit from this class and set either 'page_instance' or 'page_model_name' attribute. + - Implement custom logic within 'process_prompt' to handle vector index queries and SSE updates. + - Optionally, override 'check_permissions' and 'ratelimit_request' for additional safety checks. 
+ + Example: + ```python + class CustomSSEConsumer(WagtailVectorIndexSSEConsumer): + page_instance = YourWagtailPageModel + + async def process_prompt(self, query): + # Your custom logic to handle the query and send SSE updates + pass + ``` + """ + page_instance = None + page_model_name = None + + def __init__(self, *args, **kwargs): + """ + Initializes the consumer and checks the required attributes. + + Raises: + ValueError: If neither 'page_instance' nor 'page_model_name' is set. + ValueError: If the specified 'page_model_name' is not found. + ValueError: If the specified page model does not inherit the ModelVectorIndex mixin. + """ + super().__init__(*args, **kwargs) + + # Check if either page_model_name or page_instance is set + if self.page_model_name is None and self.page_instance is None: + raise ValueError('You must set either the page_model_name or page_instance attribute') + + if not self.page_instance: + try: + self.page_instance = apps.get_model(self.page_model_name) + except LookupError: + raise ValueError(f'Model {self.page_model_name} not found') + + # Check if the page model has the required method (ModelVectorIndex mixin) + if not hasattr(self.page_instance, 'get_vector_index') or not callable(self.page_instance.get_vector_index): + raise ValueError('Your page_model must inherit the ModelVectorIndex mixin') + + self.vector_index = self.page_instance.get_vector_index() + + async def handle(self, body): + """ + Handles HTTP requests, sets up SSE headers, and processes prompts. + + Raises: + PermissionDenied: If the user is not authenticated. 
+ """ + # Send SSE headers + await self.send_headers(headers=[ + (b"Cache-Control", b"no-cache"), + (b"Content-Type", b"text/event-stream"), + (b"Transfer-Encoding", b"chunked"), + ]) + + try: + try: + user = self.scope['user'] + except KeyError as e: + raise ValueError('User not found in scope, make sure AuthMiddlewareStack is applied correctly') from e + + await self.check_permissions(user) # Check permissions + await self.ratelimit_request() # Apply rate limiting + + # Process and reply to prompt + query = self.scope["query_string"].decode("utf-8") + await self.process_prompt(query) + + except Exception as e: + logging.exception("Unexpected error in WagtailVectorIndexSSEConsumer") + payload = "data: Error processing request, Please try again later. \n\n" + await self.send_body(payload.encode("utf-8"), more_body=True) + + # Finish the response + await self.send_body(b"") + + async def process_prompt(self, query): + """ + Processes the incoming prompt and sends SSE updates. + + Raises: + asyncio.CancelledError: If the connection is cancelled or disconnected. + """ + try: + stream_response, sources = await self.vector_index.query_async(query) + # TODO send or stream sources as characters to the client as well? + for chunk in stream_response: + if chunk.choices[0].delta.content is not None: + # TODO Remove after testing + # print(chunk.choices[0].delta.content, end="") # Uncomment to view response in terminal + # await asyncio.sleep(0.1) # Uncomment to test a more delayed response + content = chunk.choices[0].delta.content.replace('\n', '
') # Support line breaks + payload = f"data: {content or ''}\n\n" + await self.send_body(payload.encode("utf-8"), more_body=True) + except asyncio.CancelledError: + # Handle disconnects if needed, can occur from a server restart. + # Note: Django < 5 doesn't recognise client disconnects + pass + + async def check_permissions(self, user): + """ + Checks user authentication and raises PermissionDenied if not authenticated. + + Args: + user: The authenticated user. + + Raises: + PermissionDenied: If the user is not authenticated. + """ + if not user.is_authenticated: + # TODO log a 403, no way to send one via SSE, may need custom middleware? + raise PermissionDenied("Permission denided") + + async def ratelimit_request(self): + """ + Placeholder for implementing rate-limiting logic. + + Implement your custom rate-limiting logic within this method. + """ + pass diff --git a/src/wagtail_vector_index/index/base.py b/src/wagtail_vector_index/index/base.py index b0a33ea..24cdf78 100644 --- a/src/wagtail_vector_index/index/base.py +++ b/src/wagtail_vector_index/index/base.py @@ -4,6 +4,9 @@ from django.conf import settings +from asgiref.sync import sync_to_async +from typing import Generic, Iterable, List, Callable +from channels.db import database_sync_to_async from wagtail_vector_index.ai import get_chat_backend, get_embedding_backend from wagtail_vector_index.backends import get_vector_backend @@ -21,6 +24,14 @@ class QueryResponse(Generic[VectorIndexableType]): sources: Iterable[VectorIndexableType] +@database_sync_to_async +def get_metadata_from_documents_async(similar_documents): + metadata_list = [] + for doc in similar_documents: + metadata_list.append(doc.metadata["content"]) + return "\n".join(metadata_list) + + class VectorIndex(Generic[VectorIndexableType]): """Base class for a VectorIndex, representing some set of documents that can be queried""" @@ -70,9 +81,40 @@ def query( merged_context, query, ] + response = 
self.chat_backend.chat(user_messages=user_messages) return QueryResponse(response=response.text(), sources=sources) + + async def query_async(self, query: str) -> tuple[Callable, Iterable[VectorIndexableType]]: + """ + Async version of query method returning LLM response (chat) as a callable, and a list of sources + """ + try: + query_embedding = next(self.embedding_backend.embed([query])) + except StopIteration as e: + raise ValueError("No embeddings were generated for the given query.") from e + + similar_documents = await sync_to_async(self.backend_index.similarity_search)(query_embedding) + + # Add and test async _deduplicate_list method + sources = await sync_to_async(self.object_type.bulk_from_documents)(similar_documents) + merged_context = await get_metadata_from_documents_async(similar_documents) + + prompt = ( + getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None) + or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer." + ) + user_messages = [ + prompt, + merged_context, + query, + ] + return ( + self.chat_backend.chat(system_messages=[], user_messages=user_messages, stream=True), + sources + ) + def similar( self, object: VectorIndexableType, *, include_self: bool = False, limit: int = 5 ) -> list[VectorIndexableType]: From 482f90ff15a99bc1c64741219034280312fadf3c Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Thu, 21 Dec 2023 14:10:47 +0000 Subject: [PATCH 2/9] Cleanup code --- docs/quick-start.md | 35 +----- src/wagtail_vector_index/consumers.py | 158 +++++++++---------------- src/wagtail_vector_index/index/base.py | 14 ++- 3 files changed, 70 insertions(+), 137 deletions(-) diff --git a/docs/quick-start.md b/docs/quick-start.md index cce665d..e661503 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -67,23 +67,9 @@ INSTALLED_APPS = [ ] ``` -and set the channel layer. 
The [in-memory channel layer](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#in-memory-channel-layer) can be used for local envrionments, however, redis is the only offically maintained channel layer supported for production use, thus [channels_redis](https://pypi.org/project/channels-redis/) will have to be install. - -``` -CHANNEL_LAYERS = { - "default": { - "BACKEND": "channels_redis.core.RedisChannelLayer", - "CONFIG": { - "hosts": [("127.0.0.1", 6379)], - }, - }, -} -``` - You will now need to create an new consumer inheriting WagtailVectorIndexSSEConsumer, but, assigning a page model for the vector index you'd like to use. - -*Note: AuthMiddleware is required to provide User context to the consumer. +\*Note: AuthMiddleware is required to provide User context to the consumer. ```python import os @@ -98,17 +84,14 @@ from wagtail_vector_index.consumers import WagtailVectorIndexSSEConsumer os.environ.setdefault("DJANGO_SETTINGS_MODULE", "app.settings.production") django_asgi_app = get_asgi_application() -class WikiPageChatQuerySSEConsumer(WagtailVectorIndexSSEConsumer): - page_model_name = "wiki.WikiPage" - application = ProtocolTypeRouter( { "http": URLRouter( [ path( - "chat-query-sse//", - AuthMiddlewareStack(WikiPageChatQuerySSEConsumer.as_asgi()), + "chat-query-sse/", + AuthMiddlewareStack(WagtailVectorIndexSSEConsumer.as_asgi()), ), re_path(r"", get_asgi_application()), ] @@ -120,15 +103,9 @@ application = ProtocolTypeRouter( You should now be able to query the consumer via the EventSource API, you can use the snippet below as a reference. 
```javascript -import { v4 as uuidv4 } from 'uuid'; - - -function chatQuery(query) { - const searchParams = new URLSearchParams({ - query, - }); +function chatQuery(query, pageType) { const es = new EventSource( - `/chat-query-sse/${uuidv4()}/?query=${searchParams}`, + `/chat-query-sse/?query=${query}&page_type=${pageType}`, ); es.onmessage = (e) => { @@ -145,4 +122,4 @@ function chatQuery(query) { ### Known issues -Asynchronous support in Django is fairly new, as a result of this WagtailVectorIndexSSEConsumer can't tell when a client disconnects from an event-stream, This may result in queries being processed by the server as zombie threads. Support for handling [client disconnects](https://docs.djangoproject.com/en/dev/topics/async/#handling-disconnects) was added in Django 5.0, although currently untested in this package. +Asynchronous support in Django is fairly new, as a result of this WagtailVectorIndexSSEConsumer can't tell when a client disconnects from an event-stream, This may result in queries being processed by the server as zombie threads. diff --git a/src/wagtail_vector_index/consumers.py b/src/wagtail_vector_index/consumers.py index f5fe392..4c3d1b8 100644 --- a/src/wagtail_vector_index/consumers.py +++ b/src/wagtail_vector_index/consumers.py @@ -1,109 +1,85 @@ import asyncio import logging -from django.apps import apps -from django.core.exceptions import PermissionDenied from channels.generic.http import AsyncHttpConsumer +from django import forms +from django.apps import apps +from django.http import QueryDict logger = logging.Logger(__name__) +class WagtailVectorIndexQueryParamsForm(forms.Form): + """Provides a form for validating query parameters.""" + + query = forms.CharField(max_length=255, required=True) + page_type = forms.CharField(max_length=255, required=True) + + class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer): """ A Django Channels consumer for handling Server-Sent Events (SSE) related to WagtailVectorIndex queries. 
- Attributes: - page_instance (Model): The Wagtail page model instance for which the vector index is queried. - page_model_name (str): The name of the Wagtail page model for which the vector index is queried. - vector_index (VectorIndex): The vector index associated with the Wagtail page model. - Methods: handle: The main entry point for processing HTTP requests, including SSE connections. process_prompt: Processes the incoming prompt and sends SSE updates. - check_permissions: Checks user authentication and raises PermissionDenied if not authenticated. - ratelimit_request: Placeholder for implementing rate-limiting logic. - Usage: - - Inherit from this class and set either 'page_instance' or 'page_model_name' attribute. - - Implement custom logic within 'process_prompt' to handle vector index queries and SSE updates. - - Optionally, override 'check_permissions' and 'ratelimit_request' for additional safety checks. + Note: + This consumer expects the following query parameters in the URL: + - 'query': The search query. + - 'page_type': The type of Wagtail page to search. - Example: - ```python - class CustomSSEConsumer(WagtailVectorIndexSSEConsumer): - page_instance = YourWagtailPageModel - - async def process_prompt(self, query): - # Your custom logic to handle the query and send SSE updates - pass - ``` + Example URL: + "/chat-query-sse/?query=example&page_type=news.NewsPage" """ - page_instance = None - page_model_name = None - - def __init__(self, *args, **kwargs): - """ - Initializes the consumer and checks the required attributes. - - Raises: - ValueError: If neither 'page_instance' nor 'page_model_name' is set. - ValueError: If the specified 'page_model_name' is not found. - ValueError: If the specified page model does not inherit the ModelVectorIndex mixin. 
- """ - super().__init__(*args, **kwargs) - - # Check if either page_model_name or page_instance is set - if self.page_model_name is None and self.page_instance is None: - raise ValueError('You must set either the page_model_name or page_instance attribute') - - if not self.page_instance: - try: - self.page_instance = apps.get_model(self.page_model_name) - except LookupError: - raise ValueError(f'Model {self.page_model_name} not found') - - # Check if the page model has the required method (ModelVectorIndex mixin) - if not hasattr(self.page_instance, 'get_vector_index') or not callable(self.page_instance.get_vector_index): - raise ValueError('Your page_model must inherit the ModelVectorIndex mixin') - - self.vector_index = self.page_instance.get_vector_index() async def handle(self, body): """ Handles HTTP requests, sets up SSE headers, and processes prompts. - - Raises: - PermissionDenied: If the user is not authenticated. """ # Send SSE headers - await self.send_headers(headers=[ - (b"Cache-Control", b"no-cache"), - (b"Content-Type", b"text/event-stream"), - (b"Transfer-Encoding", b"chunked"), - ]) + await self.send_headers( + headers=[ + (b"Cache-Control", b"no-cache"), + (b"Content-Type", b"text/event-stream"), + (b"Transfer-Encoding", b"chunked"), + ] + ) try: - try: - user = self.scope['user'] - except KeyError as e: - raise ValueError('User not found in scope, make sure AuthMiddlewareStack is applied correctly') from e - - await self.check_permissions(user) # Check permissions - await self.ratelimit_request() # Apply rate limiting - - # Process and reply to prompt - query = self.scope["query_string"].decode("utf-8") - await self.process_prompt(query) - - except Exception as e: - logging.exception("Unexpected error in WagtailVectorIndexSSEConsumer") - payload = "data: Error processing request, Please try again later. 
\n\n" + query_string = self.scope["query_string"].decode("utf-8") + query_dict = QueryDict(query_string) + + # Validate query parameters + form = WagtailVectorIndexQueryParamsForm(query_dict) + if form.is_valid(): + query = form.cleaned_data["query"] + page_type = form.cleaned_data["page_type"] + + # Get a model class by its name + page_model = apps.get_model(page_type) + vector_index = page_model.get_vector_index() + + try: + # Process and reply to prompt + await self.process_prompt(query, vector_index) + except Exception: + logging.exception( + "Unexpected error in WagtailVectorIndexSSEConsumer" + ) + payload = ( + "data: Error processing request, Please try again later. \n\n" + ) + await self.send_body(payload.encode("utf-8"), more_body=True) + + except (ValueError, UnicodeDecodeError, KeyError, LookupError, AttributeError): + payload = "data: Error processing request. \n\n" await self.send_body(payload.encode("utf-8"), more_body=True) # Finish the response await self.send_body(b"") - async def process_prompt(self, query): + async def process_prompt(self, query, vector_index): """ Processes the incoming prompt and sends SSE updates. @@ -111,39 +87,15 @@ async def process_prompt(self, query): asyncio.CancelledError: If the connection is cancelled or disconnected. """ try: - stream_response, sources = await self.vector_index.query_async(query) - # TODO send or stream sources as characters to the client as well? + stream_response, _sources = await vector_index.aquery(query) for chunk in stream_response: if chunk.choices[0].delta.content is not None: - # TODO Remove after testing - # print(chunk.choices[0].delta.content, end="") # Uncomment to view response in terminal - # await asyncio.sleep(0.1) # Uncomment to test a more delayed response - content = chunk.choices[0].delta.content.replace('\n', '
<br>') # Support line breaks + content = chunk.choices[0].delta.content.replace( + "\n", "<br>
" + ) # Support line breaks payload = f"data: {content or ''}\n\n" await self.send_body(payload.encode("utf-8"), more_body=True) except asyncio.CancelledError: # Handle disconnects if needed, can occur from a server restart. # Note: Django < 5 doesn't recognise client disconnects pass - - async def check_permissions(self, user): - """ - Checks user authentication and raises PermissionDenied if not authenticated. - - Args: - user: The authenticated user. - - Raises: - PermissionDenied: If the user is not authenticated. - """ - if not user.is_authenticated: - # TODO log a 403, no way to send one via SSE, may need custom middleware? - raise PermissionDenied("Permission denided") - - async def ratelimit_request(self): - """ - Placeholder for implementing rate-limiting logic. - - Implement your custom rate-limiting logic within this method. - """ - pass diff --git a/src/wagtail_vector_index/index/base.py b/src/wagtail_vector_index/index/base.py index 24cdf78..aa8e6b5 100644 --- a/src/wagtail_vector_index/index/base.py +++ b/src/wagtail_vector_index/index/base.py @@ -1,6 +1,5 @@ from collections.abc import Generator, Iterable from dataclasses import dataclass -from typing import Generic from django.conf import settings @@ -85,8 +84,9 @@ def query( response = self.chat_backend.chat(user_messages=user_messages) return QueryResponse(response=response.text(), sources=sources) - - async def query_async(self, query: str) -> tuple[Callable, Iterable[VectorIndexableType]]: + async def aquery( + self, query: str + ) -> tuple[Callable, Iterable[VectorIndexableType]]: """ Async version of query method returning LLM response (chat) as a callable, and a list of sources """ @@ -95,10 +95,14 @@ async def query_async(self, query: str) -> tuple[Callable, Iterable[VectorIndexa except StopIteration as e: raise ValueError("No embeddings were generated for the given query.") from e - similar_documents = await sync_to_async(self.backend_index.similarity_search)(query_embedding) + 
similar_documents = await sync_to_async(self.backend_index.similarity_search)( + query_embedding + ) # Add and test async _deduplicate_list method - sources = await sync_to_async(self.object_type.bulk_from_documents)(similar_documents) + sources = await sync_to_async(self.object_type.bulk_from_documents)( + similar_documents + ) merged_context = await get_metadata_from_documents_async(similar_documents) prompt = ( From 982ad2d4eea01cd838265ec9148f47a355055da5 Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Fri, 22 Dec 2023 13:31:56 +0000 Subject: [PATCH 3/9] Update markdown docs --- docs/quick-start.md | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/quick-start.md b/docs/quick-start.md index e661503..67284d0 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -56,22 +56,27 @@ To skip the prompt, use the `--noinput` flag. ## Using event-stream (WagtailVectorIndexSSEConsumer) -The WagtailVectorIndexSSEConsumer is a asynchronous HTTP consumer designed for handling Server-Sent Events (SSE), for streaming responses from queries using the vector index in real-time. Using the consumer requires ASGI (uvicorn, daphne etc.) along with configuring django-channels. +`WagtailVectorIndexSSEConsumer` is an asynchronous HTTP consumer designed for handling Server-Sent Events (SSE) for streaming responses from queries using the vector index in real-time. Using the consumer requires ASGI ([uvicorn](https://pypi.org/project/uvicorn/), [Daphne](https://pypi.org/project/daphne/) etc.) along with [django-channels](https://pypi.org/project/django-channels/). -Configure channels via this [guide](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#configuration), otherwise, simply add channels to INSTALLED_APPS. +You can configure channels using the [official guide](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#configuration). At a minimum, add `channels` to `INSTALLED_APPS` in your settings file. 
```python +# settings.py + INSTALLED_APPS = [ "channels", # ... ] ``` -You will now need to create an new consumer inheriting WagtailVectorIndexSSEConsumer, but, assigning a page model for the vector index you'd like to use. +Next, you will need to define a new consumer inheriting from `WagtailVectorIndexSSEConsumer`, and assign a Wagtail page model for the vector index you'd like to use. + +!!! Note + The `AuthMiddleware` is required to provide the user context to the consumer. -\*Note: AuthMiddleware is required to provide User context to the consumer. ```python +# app_name/asgi.py import os from channels.auth import AuthMiddlewareStack @@ -100,7 +105,7 @@ application = ProtocolTypeRouter( ) ``` -You should now be able to query the consumer via the EventSource API, you can use the snippet below as a reference. +You should now be able to query the consumer using the [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) API. The snippet below is an example implementation: ```javascript function chatQuery(query, pageType) { @@ -122,4 +127,4 @@ function chatQuery(query, pageType) { ### Known issues -Asynchronous support in Django is fairly new, as a result of this WagtailVectorIndexSSEConsumer can't tell when a client disconnects from an event-stream, This may result in queries being processed by the server as zombie threads. +Asynchronous support in Django is fairly new and `WagtailVectorIndexSSEConsumer` can't tell when a client disconnects from an event-stream. This may result in queries being processed by the server as zombie threads. 
From 940b1b00020fd7d17e785174cd7deada1275498d Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Fri, 22 Dec 2023 13:32:17 +0000 Subject: [PATCH 4/9] Move channels to an optional dependency --- docs/quick-start.md | 2 +- pyproject.toml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/quick-start.md b/docs/quick-start.md index 67284d0..5307a61 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -58,7 +58,7 @@ To skip the prompt, use the `--noinput` flag. `WagtailVectorIndexSSEConsumer` is an asynchronous HTTP consumer designed for handling Server-Sent Events (SSE) for streaming responses from queries using the vector index in real-time. Using the consumer requires ASGI ([uvicorn](https://pypi.org/project/uvicorn/), [Daphne](https://pypi.org/project/daphne/) etc.) along with [django-channels](https://pypi.org/project/django-channels/). -You can configure channels using the [official guide](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#configuration). At a minimum, add `channels` to `INSTALLED_APPS` in your settings file. +You can configure channels using the [official guide](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#configuration). At a minimum, install the `channels` package and add it to `INSTALLED_APPS` in your settings file. 
```python # settings.py diff --git a/pyproject.toml b/pyproject.toml index 45b93a2..b938760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,11 @@ dependencies = [ "Django>=4.2", "Wagtail>=5.2", "aiohttp>=3.9.0b0; python_version >= '3.12'", - "channels>=3.0.5", ] [project.optional-dependencies] +sse = [ + "channels>=3.0.5", +] numpy = [ "numpy>=1.26.0", ] From 24162bffe9929670a3bdfa525c3a997bd36d58f8 Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Fri, 22 Dec 2023 14:26:39 +0000 Subject: [PATCH 5/9] Add types --- src/wagtail_vector_index/consumers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/wagtail_vector_index/consumers.py b/src/wagtail_vector_index/consumers.py index 4c3d1b8..72e30e5 100644 --- a/src/wagtail_vector_index/consumers.py +++ b/src/wagtail_vector_index/consumers.py @@ -1,11 +1,15 @@ import asyncio import logging +from typing import Type from channels.generic.http import AsyncHttpConsumer from django import forms from django.apps import apps from django.http import QueryDict +# Define type instead of importing directly to prevent AppRegistryNotReady errors +VectorIndexType = Type["wagtail_vector_index.index.VectorIndex"] # noqa + logger = logging.Logger(__name__) @@ -33,7 +37,7 @@ class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer): "/chat-query-sse/?query=example&page_type=news.NewsPage" """ - async def handle(self, body): + async def handle(self, body: bytes) -> None: """ Handles HTTP requests, sets up SSE headers, and processes prompts. """ @@ -79,7 +83,7 @@ async def handle(self, body): # Finish the response await self.send_body(b"") - async def process_prompt(self, query, vector_index): + async def process_prompt(self, query: str, vector_index: VectorIndexType) -> None: """ Processes the incoming prompt and sends SSE updates. 
From 2deafa80d846f94297cea1a8fd8d169e5635a131 Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Fri, 22 Dec 2023 16:19:44 +0000 Subject: [PATCH 6/9] Update query params to pass in index name instead of page_type --- src/wagtail_vector_index/consumers.py | 71 +++++++++++++++----------- src/wagtail_vector_index/index/base.py | 7 +-- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/src/wagtail_vector_index/consumers.py b/src/wagtail_vector_index/consumers.py index 72e30e5..fe0c545 100644 --- a/src/wagtail_vector_index/consumers.py +++ b/src/wagtail_vector_index/consumers.py @@ -1,15 +1,12 @@ import asyncio import logging -from typing import Type +from typing import Any from channels.generic.http import AsyncHttpConsumer from django import forms -from django.apps import apps +from django.core.exceptions import ValidationError from django.http import QueryDict -# Define type instead of importing directly to prevent AppRegistryNotReady errors -VectorIndexType = Type["wagtail_vector_index.index.VectorIndex"] # noqa - logger = logging.Logger(__name__) @@ -17,7 +14,20 @@ class WagtailVectorIndexQueryParamsForm(forms.Form): """Provides a form for validating query parameters.""" query = forms.CharField(max_length=255, required=True) - page_type = forms.CharField(max_length=255, required=True) + index = forms.CharField(max_length=255, required=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + from wagtail_vector_index.index import get_vector_indexes + + self.indexes = get_vector_indexes() + + def clean_index(self): + index = self.cleaned_data["index"] + if index not in self.indexes: + raise forms.ValidationError("Invalid index. Please choose a valid index.") + return index class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer): @@ -31,10 +41,10 @@ class WagtailVectorIndexSSEConsumer(AsyncHttpConsumer): Note: This consumer expects the following query parameters in the URL: - 'query': The search query. 
- 'page_type': The type of Wagtail page to search. + - 'index': The vector index to perform the query with. Example URL: - "/chat-query-sse/?query=example&page_type=news.NewsPage" + "/chat-query-sse/?query=example&index=news.NewsPage" """ async def handle(self, body: bytes) -> None: @@ -56,34 +66,33 @@ async def handle(self, body: bytes) -> None: # Validate query parameters form = WagtailVectorIndexQueryParamsForm(query_dict) - if form.is_valid(): - query = form.cleaned_data["query"] - page_type = form.cleaned_data["page_type"] - - # Get a model class by its name - page_model = apps.get_model(page_type) - vector_index = page_model.get_vector_index() - - try: - # Process and reply to prompt - await self.process_prompt(query, vector_index) - except Exception: - logging.exception( - "Unexpected error in WagtailVectorIndexSSEConsumer" - ) - payload = ( - "data: Error processing request, Please try again later. \n\n" - ) - await self.send_body(payload.encode("utf-8"), more_body=True) + if not form.is_valid(): + # Ignore "TRY301 Abstract `raise` to an inner function" + # So we can ensure the event-stream is closed and no other code is executed + raise ValidationError("Invalid query parameters.") # noqa: TRY301 + query = form.cleaned_data["query"] + index = form.cleaned_data["index"] + + vector_index = form.indexes.get(index) + + if vector_index: + await self.process_prompt(query, vector_index) - except (ValueError, UnicodeDecodeError, KeyError, LookupError, AttributeError): - payload = "data: Error processing request.
\n\n" - await self.send_body(payload.encode("utf-8"), more_body=True) + except ValidationError: + await self.error_response() + + except Exception: + logging.exception("Unexpected error in WagtailVectorIndexSSEConsumer") + await self.error_response() # Finish the response await self.send_body(b"") - async def process_prompt(self, query: str, vector_index: VectorIndexType) -> None: + async def error_response(self) -> None: + payload = "data: Error processing request, Please try again later. \n\n" + await self.send_body(payload.encode("utf-8"), more_body=True) + + async def process_prompt(self, query: str, vector_index: Any) -> None: """ Processes the incoming prompt and sends SSE updates. diff --git a/src/wagtail_vector_index/index/base.py b/src/wagtail_vector_index/index/base.py index aa8e6b5..56e8db2 100644 --- a/src/wagtail_vector_index/index/base.py +++ b/src/wagtail_vector_index/index/base.py @@ -99,10 +99,11 @@ async def aquery( similar_documents = await sync_to_async(self.backend_index.similarity_search)( query_embedding ) + sources = [] # Add and test async _deduplicate_list method - sources = await sync_to_async(self.object_type.bulk_from_documents)( - similar_documents - ) + # sources = await sync_to_async(self.object_type.bulk_from_documents)( + # similar_documents + # ) merged_context = await get_metadata_from_documents_async(similar_documents) prompt = ( From 215b22afe26c72d282dfb4b017dbbcfc3cb289c1 Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Thu, 7 Mar 2024 15:01:35 +0000 Subject: [PATCH 7/9] Update markdown link for django-channels documentation --- docs/quick-start.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/quick-start.md b/docs/quick-start.md index 5307a61..4fa6016 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -58,7 +58,7 @@ To skip the prompt, use the `--noinput` flag. 
`WagtailVectorIndexSSEConsumer` is an asynchronous HTTP consumer designed for handling Server-Sent Events (SSE) for streaming responses from queries using the vector index in real-time. Using the consumer requires ASGI ([uvicorn](https://pypi.org/project/uvicorn/), [Daphne](https://pypi.org/project/daphne/) etc.) along with [django-channels](https://pypi.org/project/django-channels/). -You can configure channels using the [official guide](https://channels.readthedocs.io/en/stable/topics/channel_layers.html#configuration). At a minimum, install the `channels` package and add it to `INSTALLED_APPS` in your settings file. +You can configure channels using the [official guide](https://channels.readthedocs.io/en/3.x/installation.html). At a minimum, install the `channels` package and add it to `INSTALLED_APPS` in your settings file, and configure support for ASGI. ```python # settings.py @@ -72,7 +72,7 @@ INSTALLED_APPS = [ Next, you will need to define a new consumer inheriting from `WagtailVectorIndexSSEConsumer`, and assign a Wagtail page model for the vector index you'd like to use. !!! Note - The `AuthMiddleware` is required to provide the user context to the consumer. + The `AuthMiddleware` is required to provide user context to the consumer. 
```python From e2441e16d6a5884c079d55589834f5254eebc28c Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Mon, 11 Mar 2024 08:17:14 +0000 Subject: [PATCH 8/9] Rebase related updates and improvements around async --- .../ai_utils/backends/llm.py | 4 ++ src/wagtail_vector_index/consumers.py | 13 +++--- src/wagtail_vector_index/index/base.py | 43 +++++++++++++------ 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/wagtail_vector_index/ai_utils/backends/llm.py b/src/wagtail_vector_index/ai_utils/backends/llm.py index 26e8e60..6f96294 100644 --- a/src/wagtail_vector_index/ai_utils/backends/llm.py +++ b/src/wagtail_vector_index/ai_utils/backends/llm.py @@ -95,6 +95,10 @@ def _get_llm_chat_model(self) -> llm.Model: setattr(model, config_key, config_val) return model + def can_stream(self) -> bool: + model = self._get_llm_chat_model() + return model.can_stream + class LLMEmbeddingBackend(BaseEmbeddingBackend[LLMEmbeddingBackendConfig]): config: LLMEmbeddingBackendConfig diff --git a/src/wagtail_vector_index/consumers.py b/src/wagtail_vector_index/consumers.py index fe0c545..8f0a26a 100644 --- a/src/wagtail_vector_index/consumers.py +++ b/src/wagtail_vector_index/consumers.py @@ -100,14 +100,11 @@ async def process_prompt(self, query: str, vector_index: Any) -> None: asyncio.CancelledError: If the connection is cancelled or disconnected. """ try: - stream_response, _sources = await vector_index.aquery(query) - for chunk in stream_response: - if chunk.choices[0].delta.content is not None: - content = chunk.choices[0].delta.content.replace( - "\n", "
<br/>" - ) # Support line breaks - payload = f"data: {content or ''}\n\n" - await self.send_body(payload.encode("utf-8"), more_body=True) + results = await vector_index.aquery(query) + for chunk in results.response: + chunk = chunk.replace('\n', '<br/>
') # Replace newlines with HTML line breaks to avoid issues with encoding. + payload = f"data: {chunk}\n\n" # Each message must be terminated using two newline characters. + await self.send_body(payload.encode("utf-8"), more_body=True) except asyncio.CancelledError: # Handle disconnects if needed, can occur from a server restart. # Note: Django < 5 doesn't recognise client disconnects diff --git a/src/wagtail_vector_index/index/base.py b/src/wagtail_vector_index/index/base.py index 56e8db2..af2866e 100644 --- a/src/wagtail_vector_index/index/base.py +++ b/src/wagtail_vector_index/index/base.py @@ -1,10 +1,11 @@ +import logging from collections.abc import Generator, Iterable from dataclasses import dataclass +from llm.models import Response from django.conf import settings - from asgiref.sync import sync_to_async -from typing import Generic, Iterable, List, Callable +from typing import Generic, Iterable, List from channels.db import database_sync_to_async from wagtail_vector_index.ai import get_chat_backend, get_embedding_backend from wagtail_vector_index.backends import get_vector_backend @@ -12,6 +13,8 @@ from ..ai_utils.backends.base import BaseChatBackend, BaseEmbeddingBackend from ..base import Document, VectorIndexableType +logger = logging.Logger(__name__) + @dataclass class QueryResponse(Generic[VectorIndexableType]): @@ -23,6 +26,17 @@ class QueryResponse(Generic[VectorIndexableType]): sources: Iterable[VectorIndexableType] +@dataclass +class AsyncQueryResponse(Generic[VectorIndexableType]): + """Represents a response to the VectorIndex `aquery` method, + including a response object so users can call it's iterator + and a list of sources that were used to generate the response + """ + + response: Response + sources: Iterable[VectorIndexableType] + + @database_sync_to_async def get_metadata_from_documents_async(similar_documents): metadata_list = [] @@ -86,24 +100,26 @@ def query( async def aquery( self, query: str - ) -> tuple[Callable, 
Iterable[VectorIndexableType]]: + ) -> AsyncQueryResponse[VectorIndexableType]: """ - Async version of query method returning LLM response (chat) as a callable, and a list of sources + Async version of the query method. """ + if not self.chat_backend.can_stream(): + logger.warning("Chat backend does not support streaming") + try: query_embedding = next(self.embedding_backend.embed([query])) except StopIteration as e: raise ValueError("No embeddings were generated for the given query.") from e - similar_documents = await sync_to_async(self.backend_index.similarity_search)( query_embedding ) - sources = [] - # Add and test async _deduplicate_list method - # sources = await sync_to_async(self.object_type.bulk_from_documents)( - # similar_documents - # ) + + sources = await sync_to_async(self._deduplicate_list)( + self.object_type.bulk_from_documents(similar_documents) + ) + merged_context = await get_metadata_from_documents_async(similar_documents) prompt = ( @@ -115,10 +131,9 @@ async def aquery( merged_context, query, ] - return ( - self.chat_backend.chat(system_messages=[], user_messages=user_messages, stream=True), - sources - ) + + response = self.chat_backend.chat(user_messages=user_messages) + return AsyncQueryResponse(response=response, sources=sources) def similar( self, object: VectorIndexableType, *, include_self: bool = False, limit: int = 5 From 6b370e9f6b9c1575d8ecc725d719e990e24bdd0f Mon Sep 17 00:00:00 2001 From: Ben Morse Date: Mon, 11 Mar 2024 09:15:35 +0000 Subject: [PATCH 9/9] Lint --- src/wagtail_vector_index/consumers.py | 13 +++++++++---- src/wagtail_vector_index/index/base.py | 13 ++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/wagtail_vector_index/consumers.py b/src/wagtail_vector_index/consumers.py index 8f0a26a..b052377 100644 --- a/src/wagtail_vector_index/consumers.py +++ b/src/wagtail_vector_index/consumers.py @@ -1,12 +1,13 @@ import asyncio import logging -from typing import Any from 
channels.generic.http import AsyncHttpConsumer from django import forms from django.core.exceptions import ValidationError from django.http import QueryDict +from .base import VectorIndexableType + logger = logging.Logger(__name__) @@ -92,7 +93,9 @@ async def error_response(self) -> None: payload = "data: Error processing request, Please try again later. \n\n" await self.send_body(payload.encode("utf-8"), more_body=True) - async def process_prompt(self, query: str, vector_index: Any) -> None: + async def process_prompt( + self, query: str, vector_index: VectorIndexableType + ) -> None: """ Processes the incoming prompt and sends SSE updates. @@ -102,8 +105,10 @@ async def process_prompt(self, query: str, vector_index: Any) -> None: try: results = await vector_index.aquery(query) for chunk in results.response: - chunk = chunk.replace('\n', '
<br/>') # Replace newlines with HTML line breaks to avoid issues with encoding. - payload = f"data: {chunk}\n\n" # Each message must be terminated using two newline characters. + chunk = chunk.replace( + "\n", "<br/>
" + ) # Replace newlines with HTML line breaks to avoid issues with encoding. + payload = f"data: {chunk}\n\n" # Each message must be terminated using two newline characters. await self.send_body(payload.encode("utf-8"), more_body=True) except asyncio.CancelledError: # Handle disconnects if needed, can occur from a server restart. diff --git a/src/wagtail_vector_index/index/base.py b/src/wagtail_vector_index/index/base.py index af2866e..3e8999e 100644 --- a/src/wagtail_vector_index/index/base.py +++ b/src/wagtail_vector_index/index/base.py @@ -1,12 +1,13 @@ import logging from collections.abc import Generator, Iterable from dataclasses import dataclass -from llm.models import Response +from typing import Generic -from django.conf import settings from asgiref.sync import sync_to_async -from typing import Generic, Iterable, List from channels.db import database_sync_to_async +from django.conf import settings +from llm.models import Response + from wagtail_vector_index.ai import get_chat_backend, get_embedding_backend from wagtail_vector_index.backends import get_vector_backend @@ -29,7 +30,7 @@ class QueryResponse(Generic[VectorIndexableType]): @dataclass class AsyncQueryResponse(Generic[VectorIndexableType]): """Represents a response to the VectorIndex `aquery` method, - including a response object so users can call it's iterator + including a response object so users can call it's iterator and a list of sources that were used to generate the response """ @@ -98,9 +99,7 @@ def query( response = self.chat_backend.chat(user_messages=user_messages) return QueryResponse(response=response.text(), sources=sources) - async def aquery( - self, query: str - ) -> AsyncQueryResponse[VectorIndexableType]: + async def aquery(self, query: str) -> AsyncQueryResponse[VectorIndexableType]: """ Async version of the query method. """