From e0b9d5154209c8caabba8f98e2afcd2599ef2574 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 13 Sep 2024 15:55:18 +0000 Subject: [PATCH 01/11] Typing fixes --- src/wagtail_vector_index/storage/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index a9dbd37..67cfa5a 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -6,7 +6,7 @@ from django.db.models import Q -class DocumentQuerySet(models.QuerySet): +class DocumentManager(models.Manager): def for_key(self, object_key: str): if connection.vendor != "sqlite": return self.filter(object_keys__contains=[object_key]) From d1f00243872fc272d276834a87122d4e4c010c88 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 13 Sep 2024 17:24:16 +0000 Subject: [PATCH 02/11] Fix for_keys when using Postgres --- src/wagtail_vector_index/storage/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index 67cfa5a..a9dbd37 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -6,7 +6,7 @@ from django.db.models import Q -class DocumentManager(models.Manager): +class DocumentQuerySet(models.QuerySet): def for_key(self, object_key: str): if connection.vendor != "sqlite": return self.filter(object_keys__contains=[object_key]) From 652ef60aa9b2e2b76d9941ced3f8491060ee53e0 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 20 Sep 2024 12:56:55 +0000 Subject: [PATCH 03/11] Revert Document manager to be derived from QuerySet Added a workaround for typing issues with django-types --- src/wagtail_vector_index/storage/models.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index a9dbd37..b2c28e2 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -31,6 +31,13 @@ def for_key(self, object_key: str) -> DocumentQuerySet: ... def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... +class DocumentManager(models.Manager["Document"]): + # Workaround for typing issues + def for_key(self, object_key: str) -> DocumentQuerySet: ... + + def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... + + class Document(models.Model): """Stores an embedding for an arbitrary chunk""" From cd4448af21420bacb869e3c3d633492d48ebe40b Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 20 Sep 2024 13:50:29 +0000 Subject: [PATCH 04/11] Add as_manager on DocumentQuerySet which casts returned type to DocumentManager --- src/wagtail_vector_index/storage/models.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index b2c28e2..a9dbd37 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -31,13 +31,6 @@ def for_key(self, object_key: str) -> DocumentQuerySet: ... def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... -class DocumentManager(models.Manager["Document"]): - # Workaround for typing issues - def for_key(self, object_key: str) -> DocumentQuerySet: ... - - def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... - - class Document(models.Model): """Stores an embedding for an arbitrary chunk""" From b65345ae7dd26b921d43745e2f5981cb9055759c Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 20 Sep 2024 09:01:26 +0000 Subject: [PATCH 05/11] Add afind_similar and other async supporting async methods --- src/wagtail_vector_index/storage/base.py | 69 +++++- src/wagtail_vector_index/storage/django.py | 229 ++++++++++++++---- src/wagtail_vector_index/storage/models.py | 18 ++ .../storage/numpy/provider.py | 11 +- 4 files changed, 276 insertions(+), 51 deletions(-) diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index 0cb81c4..79f2adf 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -42,6 +42,8 @@ def upsert(self) -> None: ... def get_documents(self) -> Iterable["Document"]: ... + async def aget_documents(self) -> AsyncGenerator["Document", None]: ... + def _get_storage_provider(self) -> StorageProviderClass: ... @@ -102,6 +104,10 @@ def to_documents( self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend ) -> Generator["Document", None, None]: ... + async def ato_documents( + self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend + ) -> AsyncGenerator["Document", None]: ... + def bulk_to_documents( self, objects: Iterable[FromObjectType], @@ -109,9 +115,22 @@ def bulk_to_documents( embedding_backend: BaseEmbeddingBackend, ) -> Generator["Document", None, None]: ... + async def abulk_to_documents( + self, + objects: Iterable[FromObjectType], + *, + embedding_backend: BaseEmbeddingBackend, + ) -> AsyncGenerator["Document", None]: ... + class DocumentConverter(ABC): - """Base class for a DocumentConverter that can convert objects to Documents and vice versa""" + """Base class for a DocumentConverter that can convert objects to Documents and vice versa + + Note on async methods: + Some async (a-prefixed) methods in this class return AsyncGenerators directly from + the methods of the to_document_operator or from_document_operator, so they aren't marked + with the 'async' keyword to prevent the methods from being wrapped in a Coroutine. + """ to_document_operator_class: Type[ToDocumentOperator] from_document_operator_class: Type[FromDocumentOperator] @@ -132,9 +151,19 @@ def to_documents( object, embedding_backend=embedding_backend ) + def ato_documents( + self, object: object, *, embedding_backend: BaseEmbeddingBackend + ) -> AsyncGenerator["Document", None]: + return self.to_document_operator.ato_documents( + object, embedding_backend=embedding_backend + ) + def from_document(self, document: "Document") -> object: return self.from_document_operator.from_document(document) + def afrom_document(self, document: "Document") -> object: + return self.from_document_operator.afrom_document(document) + def bulk_to_documents( self, objects: Iterable[object], *, embedding_backend: BaseEmbeddingBackend ) -> Generator["Document", None, None]: @@ -142,6 +171,16 @@ def bulk_to_documents( objects, embedding_backend=embedding_backend ) + def abulk_to_documents( + self, + objects: Iterable[object], + *, + embedding_backend: BaseEmbeddingBackend, + ) -> AsyncGenerator["Document", None]: + return self.to_document_operator.abulk_to_documents( + objects, embedding_backend=embedding_backend + ) + def bulk_from_documents( self, documents: Sequence["Document"] ) -> Generator[object, None, None]: @@ -186,6 +225,9 @@ def get_embedding_backend(self) -> BaseEmbeddingBackend: def get_documents(self) -> Iterable["Document"]: raise NotImplementedError + async def aget_documents(self) -> AsyncGenerator["Document", None]: + raise NotImplementedError + def get_converter(self) -> DocumentConverter: raise NotImplementedError @@ -300,6 +342,31 @@ def find_similar( if include_self or obj != object ] + async def afind_similar( + self, + object, + *, + include_self: bool = False, + limit: int = 5, + similarity_threshold: float = 0.0, + ) -> list: + """Find similar objects to the given object asynchronously""" + converter = self.get_converter() + similar_documents = [] + async for document in converter.ato_documents( + object, embedding_backend=self.get_embedding_backend() + ): + similar_docs = self.aget_similar_documents( + document.vector, limit=limit, similarity_threshold=similarity_threshold + ) + similar_documents.extend([doc async for doc in similar_docs]) + + return [ + obj + async for obj in converter.abulk_from_documents(similar_documents) + if include_self or obj != object + ] + def search( self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0 ) -> list: diff --git a/src/wagtail_vector_index/storage/django.py b/src/wagtail_vector_index/storage/django.py index e26fac6..c572f7f 100644 --- a/src/wagtail_vector_index/storage/django.py +++ b/src/wagtail_vector_index/storage/django.py @@ -4,7 +4,6 @@ AsyncGenerator, Generator, Iterable, - MutableSequence, Sequence, ) from itertools import chain, islice @@ -17,6 +16,7 @@ cast, ) +from asgiref.sync import sync_to_async from django.apps import apps from django.core import checks from django.core.exceptions import FieldDoesNotExist @@ -174,7 +174,9 @@ def bulk_from_documents( keys_by_model_label = self._get_keys_by_model_label(documents) objects_by_key = self._get_models_by_key(keys_by_model_label) - yield from self._get_deduplicated_objects_generator(documents, objects_by_key) + yield from self._get_deduplicated_objects_generator( + documents=documents, objects_by_key=objects_by_key + ) async def abulk_from_documents( self, documents: Sequence[Document] @@ -185,7 +187,7 @@ async def abulk_from_documents( # N.B. `yield from` cannot be used in async functions, so we have to use a loop for object_from_document in self._get_deduplicated_objects_generator( - documents, objects_by_key + documents=documents, objects_by_key=objects_by_key ): yield object_from_document @@ -210,7 +212,7 @@ def _get_keys_by_model_label( @staticmethod def _get_deduplicated_objects_generator( - documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model] + *, documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model] ) -> Generator[models.Model, None, None]: seen_keys = set() # de-dupe as we go for doc in documents: @@ -263,9 +265,10 @@ class ModelToDocumentOperator(ToDocumentOperator[models.Model]): def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): self.object_chunker_operator = object_chunker_operator_class() + # Utility methods @staticmethod def _existing_documents_match( - documents: Iterable[Document], splits: list[str] + documents: Iterable[Document], chunks: list[str] ) -> bool: """Determine whether the documents passed in match the text content passed in""" if not documents: @@ -273,7 +276,7 @@ def _existing_documents_match( document_content = {document.content for document in documents} - return set(splits) == document_content + return set(chunks) == document_content @staticmethod def _keys_for_instance(instance: models.Model) -> list[ModelKey]: @@ -283,84 +286,155 @@ def _keys_for_instance(instance: models.Model) -> list[ModelKey]: keys = [ModelKey.from_instance(instance), *keys] return keys - @transaction.atomic - def generate_documents( - self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend - ) -> list[Document]: - """Use the AI backend to generate and store Documents for this object""" - chunks = list( + def _get_chunks( + self, object: models.Model, embedding_backend: BaseEmbeddingBackend + ) -> list[str]: + """Get chunks of text from the object using the object chunker operator""" + return list( self.object_chunker_operator.chunk_object( object, chunk_size=embedding_backend.config.token_limit ) ) - documents = Document.objects.for_key(ModelKey(object)) - # If the existing embeddings all match on content, we return them - # without generating new ones - if self._existing_documents_match(documents, chunks): - return list(documents) + # Synchronous document generation methods + def _create_new_documents( + self, + object: models.Model, + chunks: list[str], + embedding_backend: BaseEmbeddingBackend, + ) -> list[Document]: + Document.objects.for_key(ModelKey(object)).delete() + embedding_vectors = embedding_backend.embed(chunks) + return [ + Document.objects.create( + object_keys=[str(key) for key in self._keys_for_instance(object)], + vector=embedding, + content=chunk, + ) + for chunk, embedding in zip(chunks, embedding_vectors, strict=False) + ] - # Otherwise we delete all the existing Documents and get new ones - documents.delete() + # Asynchronous document generation methods + async def _acreate_new_documents( + self, + object: models.Model, + chunks: list[str], + embedding_backend: BaseEmbeddingBackend, + ) -> AsyncGenerator[Document, None]: + await Document.objects.for_key(ModelKey(object)).adelete() embedding_vectors = embedding_backend.embed(chunks) - generated_documents: MutableSequence[Document] = [] - for idx, returned_embedding in enumerate(embedding_vectors): - chunk = chunks[idx] - document = Document.objects.create( + for chunk, embedding in zip(chunks, embedding_vectors, strict=False): + yield await Document.objects.acreate( object_keys=[str(key) for key in self._keys_for_instance(object)], - vector=returned_embedding, + vector=embedding, content=chunk, ) - generated_documents.append(document) - - return generated_documents + # Bulk document generation methods @transaction.atomic - def bulk_generate_documents(self, objects, *, embedding_backend): + def bulk_generate_documents( + self, objects, *, embedding_backend + ) -> Iterable[Document]: + """Generate documents in bulk for the given objects""" objects_by_key = {ModelKey.from_instance(obj): obj for obj in objects} documents = Document.objects.for_keys(list(objects_by_key.keys())) + documents_by_object_key = self._group_documents_by_object_key(documents) + objects_to_rebuild, chunk_mapping = self._identify_objects_to_rebuild( + objects_by_key=objects_by_key, + documents_by_object_key=documents_by_object_key, + embedding_backend=embedding_backend, + ) + if not objects_to_rebuild: + return documents + + return self._rebuild_documents( + objects_to_rebuild=objects_to_rebuild, + chunk_mapping=chunk_mapping, + objects_by_key=objects_by_key, + embedding_backend=embedding_backend, + ) + + async def abulk_generate_documents( + self, + objects: Iterable[models.Model], + *, + embedding_backend: BaseEmbeddingBackend, + ) -> AsyncGenerator[Document, None]: + """Generate documents in bulk for the given objects asynchronously""" + documents = await sync_to_async(self.bulk_generate_documents)( + objects=objects, embedding_backend=embedding_backend + ) + for document in documents: + yield document + + # Helper methods for bulk document generation + def _group_documents_by_object_key(self, documents): + """Group documents by their object key""" documents_by_object_key = defaultdict(list) for document in documents: documents_by_object_key[document.object_keys[0]].append(document) + return documents_by_object_key + def _identify_objects_to_rebuild( + self, *, objects_by_key, documents_by_object_key, embedding_backend + ): + """Identify which objects need to be rebuilt""" objects_to_rebuild = {} - - # Maintain a list of object keys in the order they appear in the chunks - # so we can map the embeddings from the backend to the correct object chunk_mapping = [] - - # Determine which objects need to be rebuilt for key, object in objects_by_key.items(): documents_for_object = documents_by_object_key[key] - chunks = list( - self.object_chunker_operator.chunk_object( - object, chunk_size=embedding_backend.config.token_limit - ) + chunks = self._get_chunks( + object=object, embedding_backend=embedding_backend ) - - if not self._existing_documents_match(documents_for_object, chunks): + if not self._existing_documents_match( + documents=documents_for_object, chunks=chunks + ): objects_to_rebuild[key] = {"object": object, "chunks": chunks} chunk_mapping += [key] * len(chunks) + return objects_to_rebuild, chunk_mapping - if not objects_to_rebuild: - return documents - + def _rebuild_documents( + self, *, objects_to_rebuild, chunk_mapping, objects_by_key, embedding_backend + ): + """Rebuild Documents for the identified objects""" all_chunks = list( chain(*[obj["chunks"] for obj in objects_to_rebuild.values()]) ) - embedding_vectors = list(embedding_backend.embed(all_chunks)) - documents_by_object = defaultdict(list) + documents_by_object = self._group_embeddings_by_object( + embedding_vectors=embedding_vectors, chunk_mapping=chunk_mapping + ) + + self._delete_existing_documents(documents_by_object=documents_by_object) + self._create_new_documents_bulk( + documents_by_object=documents_by_object, + objects_by_key=objects_by_key, + all_chunks=all_chunks, + ) + + return self._get_sorted_documents(objects_by_key=objects_by_key) + def _group_embeddings_by_object(self, *, embedding_vectors, chunk_mapping): + """Group embedding vectors by their corresponding object""" + documents_by_object = defaultdict(list) for idx, embedding in enumerate(embedding_vectors): object_key = chunk_mapping[idx] documents_by_object[object_key].append((idx, embedding)) + return documents_by_object + def _delete_existing_documents(self, *, documents_by_object): existing_documents = Document.objects.for_keys(list(documents_by_object.keys())) existing_documents.delete() + async def _adelete_existing_documents(self, *, documents_by_object): + existing_documents = Document.objects.for_keys(list(documents_by_object.keys())) + await existing_documents.adelete() + + def _create_new_documents_bulk( + self, *, documents_by_object, objects_by_key, all_chunks + ): for object_key, documents in documents_by_object.items(): for idx, returned_embedding in documents: all_keys = self._keys_for_instance(objects_by_key[object_key]) @@ -371,8 +445,21 @@ def bulk_generate_documents(self, objects, *, embedding_backend): content=chunk, ) - # Return every document object, regardless of whether it was rebuilt, retaining - # the order they appeared in the original list + async def _acreate_new_documents_bulk( + self, *, documents_by_object, objects_by_key, all_chunks + ): + for object_key, documents in documents_by_object.items(): + for idx, returned_embedding in documents: + all_keys = self._keys_for_instance(objects_by_key[object_key]) + chunk = all_chunks[idx] + await Document.objects.acreate( + object_keys=all_keys, + vector=returned_embedding, + content=chunk, + ) + + def _get_sorted_documents(self, *, objects_by_key): + """Get sorted documents for the given objects""" documents = list(Document.objects.for_keys(list(objects_by_key.keys()))) return sorted( documents, @@ -381,10 +468,36 @@ def bulk_generate_documents(self, objects, *, embedding_backend): ), ) + # Interface methods + @transaction.atomic def to_documents( self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend ) -> Generator[Document, None, None]: - yield from self.generate_documents(object, embedding_backend=embedding_backend) + """Use the AI backend to generate and store Documents for this object""" + chunks = self._get_chunks(object, embedding_backend) + documents = Document.objects.for_key(ModelKey(object)) + + if self._existing_documents_match(list(documents), chunks): + yield from documents + + yield from self._create_new_documents(object, chunks, embedding_backend) + + async def ato_documents( + self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend + ) -> AsyncGenerator[Document, None]: + """Use the AI backend to generate and store Documents for this object asynchronously""" + chunks = self._get_chunks(object, embedding_backend) + documents = Document.objects.afor_key(ModelKey(object)) + documents = [doc async for doc in documents] + + if self._existing_documents_match(documents, chunks): + for document in documents: + yield document + + async for document in self._acreate_new_documents( + object, chunks, embedding_backend + ): + yield document def bulk_to_documents( self, @@ -393,12 +506,30 @@ def bulk_to_documents( batch_size: int = 100, embedding_backend: BaseEmbeddingBackend, ) -> Generator[Document, None, None]: + """Convert multiple model instances to Documents in batches""" batches = list(batched(objects, batch_size)) for idx, batch in enumerate(batches): logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}") - yield from self.bulk_generate_documents( + for document in self.bulk_generate_documents( batch, embedding_backend=embedding_backend - ) + ): + yield document + + async def abulk_to_documents( + self, + objects: Iterable[models.Model], + *, + batch_size: int = 100, + embedding_backend: BaseEmbeddingBackend, + ) -> AsyncGenerator[Document, None]: + """Convert multiple model instances to Documents asynchronously in batches""" + batches = list(batched(objects, batch_size)) + for idx, batch in enumerate(batches): + logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}") + async for document in self.abulk_generate_documents( + objects=batch, embedding_backend=embedding_backend + ): + yield document class EmbeddableFieldsObjectChunkerOperator( diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index a9dbd37..bdbe90f 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -1,4 +1,5 @@ import operator +from collections.abc import AsyncGenerator from functools import reduce from typing import cast @@ -15,10 +16,27 @@ def for_key(self, object_key: str): # so we use icontains which just does a string search return self.filter(object_keys__icontains=object_key) + async def afor_key(self, object_key: str) -> AsyncGenerator["Document", None]: + if connection.vendor != "sqlite": + async for doc in self.filter(object_keys__contains=[object_key]): + yield doc + else: + # SQLite doesn't support the __contains lookup for JSON fields + # so we use icontains which just does a string search + async for doc in self.filter(object_keys__icontains=object_key): + yield doc + def for_keys(self, object_keys: list[str]): q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] return self.filter(reduce(operator.or_, q_objs)) + async def afor_keys( + self, object_keys: list[str] + ) -> AsyncGenerator["Document", None]: + q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] + async for doc in self.filter(reduce(operator.or_, q_objs)): + yield doc + @classmethod def as_manager(cls) -> "DocumentManager": return cast(DocumentManager, super().as_manager()) diff --git a/src/wagtail_vector_index/storage/numpy/provider.py b/src/wagtail_vector_index/storage/numpy/provider.py index 02de3f5..42f7451 100644 --- a/src/wagtail_vector_index/storage/numpy/provider.py +++ b/src/wagtail_vector_index/storage/numpy/provider.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Generator, Iterable, Sequence +from collections.abc import AsyncGenerator, Generator, Iterable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING @@ -58,6 +58,15 @@ def get_similar_documents( for document in [pair[1] for pair in sorted_similarities][:limit]: yield document + async def aget_similar_documents( + self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0 + ) -> AsyncGenerator["Document", None]: + documents = self.get_similar_documents( + query_vector, limit=limit, similarity_threshold=similarity_threshold + ) + for document in documents: + yield document + class NumpyStorageProvider(StorageProvider[ProviderConfig, NumpyIndexMixin]): config_class = ProviderConfig From c8518ff60d5b513e2714687cde0505a486b97128 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 20 Sep 2024 09:03:34 +0000 Subject: [PATCH 06/11] Add/regroup async tests --- tests/async_factory.py | 177 +++++++++ tests/conftest.py | 31 ++ tests/factories.py | 9 + tests/test_django_converter.py | 41 ++ tests/test_index.py | 478 ++++++++++++----------- tests/test_model_index.py | 4 +- tests/test_model_to_document_operator.py | 0 7 files changed, 501 insertions(+), 239 deletions(-) create mode 100644 tests/async_factory.py create mode 100644 tests/test_model_to_document_operator.py diff --git a/tests/async_factory.py b/tests/async_factory.py new file mode 100644 index 0000000..79dc773 --- /dev/null +++ b/tests/async_factory.py @@ -0,0 +1,177 @@ +# ruff: noqa +# From https://github.com/Andrew-Chen-Wang/factory-boy-django-async + + +import inspect + +import factory +from asgiref.sync import sync_to_async +from django.db import IntegrityError +from factory import errors +from factory.builder import BuildStep, StepBuilder, parse_declarations + + +def use_postgeneration_results(self, step, instance, results): + return self.factory._after_postgeneration( + instance, + create=step.builder.strategy == factory.enums.CREATE_STRATEGY, + results=results, + ) + + +factory.base.FactoryOptions.use_postgeneration_results = use_postgeneration_results + + +class AsyncFactory(factory.django.DjangoModelFactory): + @classmethod + async def _generate(cls, strategy, params): + if cls._meta.abstract: + raise factory.errors.FactoryError( + "Cannot generate instances of abstract factory %(f)s; " + "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract " + "is either not set or False." % dict(f=cls.__name__) + ) + + step = AsyncStepBuilder(cls._meta, params, strategy) + return await step.build() + + class Meta: + abstract = True # Optional, but explicit. + + @classmethod + async def _get_or_create(cls, model_class, *args, **kwargs): + """Create an instance of the model through objects.get_or_create.""" + manager = cls._get_manager(model_class) + + assert "defaults" not in cls._meta.django_get_or_create, ( + "'defaults' is a reserved keyword for get_or_create " + "(in %s._meta.django_get_or_create=%r)" + % (cls, cls._meta.django_get_or_create) + ) + + key_fields = {} + for field in cls._meta.django_get_or_create: + if field not in kwargs: + raise errors.FactoryError( + "django_get_or_create - " + "Unable to find initialization value for '%s' in factory %s" + % (field, cls.__name__) + ) + key_fields[field] = kwargs.pop(field) + key_fields["defaults"] = kwargs + + try: + instance, _created = await manager.aget_or_create(*args, **key_fields) + except IntegrityError as e: + get_or_create_params = { + lookup: value + for lookup, value in cls._original_params.items() + if lookup in cls._meta.django_get_or_create + } + if get_or_create_params: + try: + instance = await manager.aget(**get_or_create_params) + except manager.model.DoesNotExist: + # Original params are not a valid lookup and triggered a create(), + # that resulted in an IntegrityError. Follow Django’s behavior. + raise e + else: + raise e + + return instance + + @classmethod + async def _create(cls, model_class, *args, **kwargs): + """Create an instance of the model, and save it to the database.""" + if cls._meta.django_get_or_create: + return await cls._get_or_create(model_class, *args, **kwargs) + + manager = cls._get_manager(model_class) + return await manager.acreate(*args, **kwargs) + + @classmethod + async def create_batch(cls, size, **kwargs): + """Create a batch of instances of the model, and save them to the database.""" + return [await cls.create(**kwargs) for _ in range(size)] + + @classmethod + async def _after_postgeneration(cls, instance, create, results=None): + """Save again the instance if creating and at least one hook ran.""" + if create and results: + # Some post-generation hooks ran, and may have modified us. + if hasattr(instance, "asave"): + await instance.asave() + else: + await sync_to_async(instance.save)() + + +class AsyncBuildStep(BuildStep): + async def resolve(self, declarations): + self.stub = factory.builder.Resolver( + declarations=declarations, + step=self, + sequence=self.sequence, + ) + + for field_name in declarations: + attr = getattr(self.stub, field_name) + if inspect.isawaitable(attr): + attr = await attr + self.attributes[field_name] = attr + + +class AsyncStepBuilder(StepBuilder): + # Redefine build function that await for instance creation and awaitable postgenerations + async def build(self, parent_step=None, force_sequence=None): + """Build a factory instance.""" + # TODO: Handle "batch build" natively + pre, post = parse_declarations( + self.extras, + base_pre=self.factory_meta.pre_declarations, + base_post=self.factory_meta.post_declarations, + ) + + if force_sequence is not None: + sequence = force_sequence + elif self.force_init_sequence is not None: + sequence = self.force_init_sequence + else: + sequence = self.factory_meta.next_sequence() + + step = AsyncBuildStep( + builder=self, + sequence=sequence, + parent_step=parent_step, + ) + await step.resolve(pre) + + args, kwargs = self.factory_meta.prepare_arguments(step.attributes) + + instance = self.factory_meta.instantiate( + step=step, + args=args, + kwargs=kwargs, + ) + if inspect.isawaitable(instance): + instance = await instance + + postgen_results = {} + for declaration_name in post.sorted(): + declaration = post[declaration_name] + declaration_result = declaration.declaration.evaluate_post( + instance=instance, + step=step, + overrides=declaration.context, + ) + if inspect.isawaitable(declaration_result): + declaration_result = await declaration_result + postgen_results[declaration_name] = declaration_result + + postgen = self.factory_meta.use_postgeneration_results( + instance=instance, + step=step, + results=postgen_results, + ) + if inspect.isawaitable(postgen): + await postgen + return instance diff --git a/tests/conftest.py b/tests/conftest.py index 09e57b2..65f2fc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,3 +65,34 @@ def use_mock_ai_backend(settings): } }, } + + +@pytest.fixture +def get_vector_for_text(): + def _get_vector_for_text(text): + if "Very similar" in text: + return [0.9, 0.1, 0.0] + elif "Somewhat similar" in text: + return [0.7, 0.3, 0.0] + elif "test" in text.lower(): + return [1.0, 0.0, 0.0] + else: + return [0.1, 0.1, 0.8] + + return _get_vector_for_text + + +@pytest.fixture +def mock_embedding_backend(get_vector_for_text): + class MockEmbeddingBackend(BaseEmbeddingBackend): + def __init__(self): + self.config = type("Config", (), {"token_limit": 100})() + + def embed(self, texts): + def embedding_generator(): + for text in texts: + yield get_vector_for_text(text) + + return embedding_generator() + + return MockEmbeddingBackend() diff --git a/tests/factories.py b/tests/factories.py index fdbe608..484856d 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,5 +1,6 @@ import factory import wagtail_factories +from async_factory import AsyncFactory from faker import Faker from testapp.models import DifferentPage, ExampleModel, ExamplePage from wagtail_vector_index.storage.models import Document @@ -15,6 +16,14 @@ class Meta: model = ExampleModel +class AsyncExampleModelFactory(AsyncFactory): + title = factory.Faker("sentence") + body = factory.LazyFunction(lambda: "\n".join(fake.paragraphs())) + + class Meta: + model = ExampleModel + + class ExamplePageFactory(wagtail_factories.PageFactory): class Meta: model = ExamplePage diff --git a/tests/test_django_converter.py b/tests/test_django_converter.py index 768cd91..2995b3b 100644 --- a/tests/test_django_converter.py +++ b/tests/test_django_converter.py @@ -1,6 +1,7 @@ import factory import pytest from factories import ( + AsyncExampleModelFactory, DifferentPageFactory, DocumentFactory, ExampleModelFactory, @@ -17,11 +18,13 @@ ModelLabel, ModelToDocumentOperator, ) +from wagtail_vector_index.storage.models import Document fake = Faker() class TestChunking: + @pytest.mark.django_db def test_get_chunks_splits_content_into_multiple_chunks( self, patch_embedding_fields ): @@ -32,6 +35,7 @@ def test_get_chunks_splits_content_into_multiple_chunks( chunks = chunker.chunk_object(instance, chunk_size=100) assert len(chunks) > 1 + @pytest.mark.django_db def test_get_chunks_adds_important_field_to_each_chunk( self, patch_embedding_fields ): @@ -311,3 +315,40 @@ def test_convert_multiple_documents_to_objects(): ) recovered_objects = list(converter.bulk_from_documents(documents)) assert recovered_objects == all_objects + + +class TestToDocumentOperatorAsync: + @pytest.mark.django_db(transaction=True) + async def test_ato_documents(self, mock_embedding_backend): + instance = await AsyncExampleModelFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + documents = [ + doc + async for doc in operator.ato_documents( + instance, embedding_backend=mock_embedding_backend + ) + ] + + assert len(documents) > 0 + assert all(isinstance(doc, Document) for doc in documents) + assert all(instance.title in doc.content for doc in documents) + + @pytest.mark.django_db(transaction=True) + async def test_abulk_to_documents(self, mock_embedding_backend): + instances = await AsyncExampleModelFactory.create_batch(3) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + + documents = [ + doc + async for doc in operator.abulk_to_documents( + instances, embedding_backend=mock_embedding_backend + ) + ] + + assert len(documents) > 0 + assert all(isinstance(doc, Document) for doc in documents) + assert any(instances[0].title in doc.content for doc in documents) + assert any(instances[1].title in doc.content for doc in documents) + assert any(instances[2].title in doc.content for doc in documents) diff --git a/tests/test_index.py b/tests/test_index.py index a9cc296..4b1515e 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1,13 +1,14 @@ import unittest import pytest -from factories import DifferentPageFactory, ExamplePageFactory -from faker import Faker -from testapp.models import DifferentPage, ExamplePage -from wagtail_vector_index.ai_utils.backends.base import BaseEmbeddingBackend -from wagtail_vector_index.storage import ( - registry, +from factories import ( + DifferentPageFactory, + ExampleModelFactory, + ExamplePageFactory, ) +from faker import Faker +from testapp.models import DifferentPage, ExampleModel, ExamplePage +from wagtail_vector_index.storage import registry from wagtail_vector_index.storage.base import VectorIndex from wagtail_vector_index.storage.django import EmbeddingField, ModelKey from wagtail_vector_index.storage.models import Document @@ -15,50 +16,25 @@ fake = Faker() -def get_vector_for_text(text): - if "Very similar" in text: - return [0.9, 0.1, 0.0] - elif "Somewhat similar" in text: - return [0.7, 0.3, 0.0] - elif "test" in text.lower(): - return [1.0, 0.0, 0.0] - else: - return [0.1, 0.1, 0.8] - - -@pytest.fixture -def mock_embedding_backend(): - class MockEmbeddingBackend(BaseEmbeddingBackend): - def embed(self, texts): - def embedding_generator(): - for text in texts: - yield get_vector_for_text(text) - - return embedding_generator() - - return MockEmbeddingBackend - - @pytest.fixture -def test_pages(): +def test_objects(): return [ - ExamplePageFactory(title="Very similar to test"), - ExamplePageFactory(title="Somewhat similar to test"), - ExamplePageFactory(title="Not similar at all"), + ExampleModelFactory.create(title="Very similar to test"), + ExampleModelFactory.create(title="Somewhat similar to test"), + ExampleModelFactory.create(title="Not similar at all"), ] @pytest.fixture -def document_generator(test_pages): +def document_generator(test_objects, get_vector_for_text): def gen_documents(cls, *args, **kwargs): - for page in test_pages: - vector = get_vector_for_text(page.title) + for obj in test_objects: + vector = get_vector_for_text(obj.title) yield Document( - object_keys=[ModelKey.from_instance(page)], + object_keys=[ModelKey.from_instance(obj)], metadata={ - "title": page.title, - "object_id": str(page.pk), - "content_type_id": str(page.get_content_type().id), + "title": obj.title, + "object_id": str(obj.pk), }, vector=vector, ) @@ -67,12 +43,27 @@ def gen_documents(cls, *args, **kwargs): @pytest.fixture -def mock_vector_index(mocker, mock_embedding_backend, document_generator): +def async_document_generator(test_objects, get_vector_for_text): + async def gen_documents(cls, *args, **kwargs): + for obj in test_objects: + vector = get_vector_for_text(obj.title) + yield Document( + object_keys=[ModelKey.from_instance(obj)], + metadata={"title": obj.title, "object_id": str(obj.pk)}, + vector=vector, + ) + + return gen_documents + + +@pytest.fixture +def mock_vector_index( + mocker, mock_embedding_backend, document_generator, async_document_generator +): vector_index = ExamplePage.vector_index - mock_backend = mock_embedding_backend(config=mocker.Mock()) mocker.patch.object( - vector_index, "get_embedding_backend", return_value=mock_backend + vector_index, "get_embedding_backend", return_value=mock_embedding_backend ) mocker.patch( @@ -80,205 +71,218 @@ def mock_vector_index(mocker, mock_embedding_backend, document_generator): side_effect=document_generator, ) - return vector_index - - -def test_registry(): - expected_class_names = [ - "ExamplePageIndex", - "ExampleModelIndex", - "DifferentPageIndex", - "MultiplePageVectorIndex", - ] - assert set(registry._registry.keys()) == set(expected_class_names) - - -def test_indexed_model_has_vector_index(): - index = ExamplePage.vector_index - assert index.__class__.__name__ == "ExamplePageIndex" - - -def test_register_custom_vector_index(): - custom_index = type("MyVectorIndex", (VectorIndex,), {})() - registry.register_index(custom_index) - assert registry["MyVectorIndex"] == custom_index - - -def test_get_embedding_fields_count(patch_embedding_fields): - with patch_embedding_fields( - ExamplePage, [EmbeddingField("test"), EmbeddingField("another_test")] - ): - assert len(ExamplePage._get_embedding_fields()) == 2 - - -def test_embedding_fields_override(patch_embedding_fields): - # In the same vein as Wagtail's search index fields, if there are - # multiple fields of the same type with the same name, only one - # should be returned - with patch_embedding_fields( - ExamplePage, [EmbeddingField("test"), EmbeddingField("test")] - ): - assert len(ExamplePage._get_embedding_fields()) == 1 - - -def test_checking_search_fields_errors_with_invalid_field(patch_embedding_fields): - with patch_embedding_fields(ExamplePage, [EmbeddingField("foo")]): - errors = ExamplePage.check() - assert "wagtailai.WA001" in [error.id for error in errors] - - -@pytest.mark.django_db -def test_index_get_documents_returns_at_least_one_document_per_page(): - pages = ExamplePageFactory.create_batch(10) - index = registry["ExamplePageIndex"] - index.rebuild_index() - documents = index.get_documents() - found_pages = { - ModelKey(document.object_keys[0]).object_id for document in documents - } - - assert found_pages == {str(page.pk) for page in pages} - - -@pytest.mark.django_db -def test_index_with_multiple_models(): - example_pages = ExamplePageFactory.create_batch(5) - different_pages = DifferentPageFactory.create_batch(5) - index = registry["MultiplePageVectorIndex"] - index.rebuild_index() - - example_pages_ids = {str(page.pk) for page in example_pages} - different_page_ids = {str(page.pk) for page in different_pages} - found_page_ids = { - ModelKey(document.object_keys[0]).object_id - for document in index.get_documents() - } - - assert found_page_ids == example_pages_ids.union(different_page_ids) - - similar_result = list(index.find_similar(DifferentPage.objects.first())) - assert len(similar_result) > 1 - for p in similar_result: - assert isinstance(p, (ExamplePage, DifferentPage)) - - search_result = list(index.search("test")) - assert len(search_result) > 1 - for p in search_result: - assert isinstance(p, (ExamplePage, DifferentPage)) - - -@pytest.mark.django_db -def test_similar_returns_no_duplicates(mocker): - pages = ExamplePageFactory.create_batch(10) - vector_index = ExamplePage.vector_index - - def gen_pages(cls, *args, **kwargs): - yield from pages - mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", - side_effect=gen_pages, + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.abulk_to_documents", + side_effect=async_document_generator, ) - case = unittest.TestCase() - - # We expect 9 results without the page itself. - actual = vector_index.find_similar(pages[0], limit=100, include_self=False) - case.assertCountEqual(actual, pages[1:]) - - # We expect 10 results with the page itself. - actual = vector_index.find_similar(pages[0], limit=100, include_self=True) - case.assertCountEqual(actual, pages) - - -@pytest.mark.django_db -def test_query_passes_sources_to_backend(mocker): - ExamplePageFactory.create_batch(2) - index = ExamplePage.vector_index - documents = index.get_documents()[:2] - - def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.0): - yield from documents - - query_mock = mocker.patch("conftest.ChatMockBackend.chat") - expected_content = "\n".join([doc.content for doc in documents]) - similar_documents_mock = mocker.patch.object(index, "get_similar_documents") - similar_documents_mock.side_effect = get_similar_documents - index.query("") - first_call_messages = query_mock.call_args.kwargs["messages"] - assert first_call_messages[1] == {"content": expected_content, "role": "system"} - - -@pytest.mark.django_db -def test_query_with_similarity_threshold(mocker): - ExamplePageFactory.create_batch(2) - index = ExamplePage.vector_index - documents = index.get_documents()[:2] - - def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5): - yield from documents - - query_mock = mocker.patch("conftest.ChatMockBackend.chat") - expected_content = "\n".join([doc.content for doc in documents]) - similar_documents_mock = mocker.patch.object(index, "get_similar_documents") - similar_documents_mock.side_effect = get_similar_documents - index.query("", similarity_threshold=0.5) - first_call_messages = query_mock.call_args.kwargs["messages"] - assert first_call_messages[1] == {"content": expected_content, "role": "system"} - - -@pytest.mark.django_db -def test_find_similar_with_similarity_threshold(mocker): - pages = ExamplePageFactory.create_batch(10) - vector_index = ExamplePage.vector_index - - def gen_pages(cls, *args, **kwargs): - yield from pages - - mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", - side_effect=gen_pages, - ) + return vector_index - # We expect 9 results without the page itself. - actual = vector_index.find_similar( - pages[0], limit=100, include_self=False, similarity_threshold=0.5 - ) - assert set(actual) == set(pages[1:]), f"Expected {pages[1:]}, but got {actual}" - # We expect 10 results with the page itself. - actual = vector_index.find_similar( - pages[0], limit=100, include_self=True, similarity_threshold=0.5 +class TestRegistry: + def test_registry(self): + expected_class_names = [ + "ExamplePageIndex", + "ExampleModelIndex", + "DifferentPageIndex", + "MultiplePageVectorIndex", + ] + assert set(registry._registry.keys()) == set(expected_class_names) + + def test_indexed_model_has_vector_index(self): + index = ExamplePage.vector_index + assert index.__class__.__name__ == "ExamplePageIndex" + + def test_register_custom_vector_index(self): + custom_index = type("MyVectorIndex", (VectorIndex,), {})() + registry.register_index(custom_index) + assert registry["MyVectorIndex"] == custom_index + + +class TestEmbeddingFields: + def test_get_embedding_fields_count(self, patch_embedding_fields): + with patch_embedding_fields( + ExamplePage, [EmbeddingField("test"), EmbeddingField("another_test")] + ): + assert len(ExamplePage._get_embedding_fields()) == 2 + + def test_embedding_fields_override(self, patch_embedding_fields): + with patch_embedding_fields( + ExamplePage, [EmbeddingField("test"), EmbeddingField("test")] + ): + assert len(ExamplePage._get_embedding_fields()) == 1 + + def test_checking_search_fields_errors_with_invalid_field( + self, patch_embedding_fields + ): + with patch_embedding_fields(ExamplePage, [EmbeddingField("foo")]): + errors = ExamplePage.check() + assert "wagtailai.WA001" in [error.id for error in errors] + + +class TestIndexOperations: + @pytest.mark.django_db + def test_index_get_documents_returns_at_least_one_document_per_page(self): + pages = ExampleModelFactory.create_batch(10) + index = registry["ExampleModelIndex"] + index.rebuild_index() + documents = index.get_documents() + found_pages = { + ModelKey(document.object_keys[0]).object_id for document in documents + } + + assert found_pages == {str(page.pk) for page in pages} + + @pytest.mark.django_db + def test_index_with_multiple_models(self): + example_pages = ExamplePageFactory.create_batch(5) + different_pages = DifferentPageFactory.create_batch(5) + index = registry["MultiplePageVectorIndex"] + index.rebuild_index() + + example_pages_ids = {str(page.pk) for page in example_pages} + different_page_ids = {str(page.pk) for page in different_pages} + found_page_ids = { + ModelKey(document.object_keys[0]).object_id + for document in index.get_documents() + } + + assert found_page_ids == example_pages_ids.union(different_page_ids) + + similar_result = list(index.find_similar(DifferentPage.objects.first())) + assert len(similar_result) > 1 + for p in similar_result: + assert isinstance(p, (ExamplePage, DifferentPage)) + + search_result = list(index.search("test")) + assert len(search_result) > 1 + for p in search_result: + assert isinstance(p, (ExamplePage, DifferentPage)) + + +class TestSimilarityOperations: + @pytest.mark.django_db + def test_similar_returns_no_duplicates(self, mocker): + pages = ExampleModelFactory.create_batch(10) + vector_index = ExamplePage.vector_index + + def gen_pages(cls, *args, **kwargs): + yield from pages + + mocker.patch( + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", + side_effect=gen_pages, + ) + + case = unittest.TestCase() + + # We expect 9 results without the page itself. + actual = vector_index.find_similar(pages[0], limit=100, include_self=False) + case.assertCountEqual(actual, pages[1:]) + + # We expect 10 results with the page itself. + actual = vector_index.find_similar(pages[0], limit=100, include_self=True) + case.assertCountEqual(actual, pages) + + @pytest.mark.django_db + def test_find_similar_with_similarity_threshold(self, mocker): + pages = ExampleModelFactory.create_batch(10) + vector_index = ExamplePage.vector_index + + def gen_pages(cls, *args, **kwargs): + yield from pages + + mocker.patch( + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", + side_effect=gen_pages, + ) + + # We expect 9 results without the page itself. + actual = vector_index.find_similar( + pages[0], limit=100, include_self=False, similarity_threshold=0.5 + ) + assert set(actual) == set(pages[1:]), f"Expected {pages[1:]}, but got {actual}" + + # We expect 10 results with the page itself. + actual = vector_index.find_similar( + pages[0], limit=100, include_self=True, similarity_threshold=0.5 + ) + assert set(actual) == set(pages), f"Expected {pages}, but got {actual}" + + @pytest.mark.django_db(transaction=True) + async def test_afind_similar(self, mock_vector_index): + objs = [obj async for obj in ExampleModel.objects.all()] + actual = await mock_vector_index.afind_similar( + objs[0], limit=100, include_self=False + ) + assert set(actual) == set(objs[1:]), f"Expected {objs[1:]}, but got {actual}" + + +class TestQueryOperations: + @pytest.mark.django_db + def test_query_passes_sources_to_backend(self, mocker): + ExampleModelFactory.create_batch(2) + index = ExamplePage.vector_index + documents = index.get_documents()[:2] + + def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.0): + yield from documents + + query_mock = mocker.patch("conftest.ChatMockBackend.chat") + expected_content = "\n".join([doc.content for doc in documents]) + similar_documents_mock = mocker.patch.object(index, "get_similar_documents") + similar_documents_mock.side_effect = get_similar_documents + index.query("") + first_call_messages = query_mock.call_args.kwargs["messages"] + assert first_call_messages[1] == {"content": expected_content, "role": "system"} + + @pytest.mark.django_db + def test_query_with_similarity_threshold(self, mocker): + ExampleModelFactory.create_batch(2) + index = ExamplePage.vector_index + documents = index.get_documents()[:2] + + def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5): + yield from documents + + query_mock = mocker.patch("conftest.ChatMockBackend.chat") + expected_content = "\n".join([doc.content for doc in documents]) + similar_documents_mock = mocker.patch.object(index, "get_similar_documents") + similar_documents_mock.side_effect = get_similar_documents + index.query("", similarity_threshold=0.5) + first_call_messages = query_mock.call_args.kwargs["messages"] + assert first_call_messages[1] == {"content": expected_content, "role": "system"} + + @pytest.mark.django_db + @pytest.mark.parametrize( + "similarity_threshold, expected_count, expected_titles", + [ + (0.9, 0, set()), + (0.6, 1, {"Very similar to test"}), + (0.1, 2, {"Very similar to test", "Somewhat similar to test"}), + ( + None, + 3, + { + "Very similar to test", + "Somewhat similar to test", + "Not similar at all", + }, + ), + ], ) - assert set(actual) == set(pages), f"Expected {pages}, but got {actual}" - - -@pytest.mark.django_db -@pytest.mark.parametrize( - "similarity_threshold, expected_count, expected_titles", - [ - (0.9, 0, set()), - (0.6, 1, {"Very similar to test"}), - (0.1, 2, {"Very similar to test", "Somewhat similar to test"}), - ( - None, - 3, - {"Very similar to test", "Somewhat similar to test", "Not similar at all"}, - ), - ], -) -def test_search_with_similarity_threshold( - mock_vector_index, similarity_threshold, expected_count, expected_titles -): - kwargs = {"limit": 100} - if similarity_threshold is not None: - kwargs["similarity_threshold"] = similarity_threshold + def test_search_with_similarity_threshold( + self, mock_vector_index, similarity_threshold, expected_count, expected_titles + ): + kwargs = {"limit": 100} + if similarity_threshold is not None: + kwargs["similarity_threshold"] = similarity_threshold - results = list(mock_vector_index.search("test", **kwargs)) + results = list(mock_vector_index.search("test", **kwargs)) - assert ( - len(results) == expected_count - ), f"Expected {expected_count} results, got {len(results)}" + assert ( + len(results) == expected_count + ), f"Expected {expected_count} results, got {len(results)}" - if expected_count > 0: - assert {result.title for result in results} == expected_titles + if expected_count > 0: + assert {result.title for result in results} == expected_titles diff --git a/tests/test_model_index.py b/tests/test_model_index.py index 1755211..b50617f 100644 --- a/tests/test_model_index.py +++ b/tests/test_model_index.py @@ -42,8 +42,8 @@ def get_index(self): return registry["ExamplePageIndex"] -def test_rebuilding_model_index_creates_embeddings(): - ExamplePageFactory.create_batch(10) +def test_rebuilding_model_index_creates_documents(): + ExamplePageFactory.create_batch(10, body=fake.text(max_nb_chars=10)) index = ExamplePage.vector_index index.rebuild_index() assert Document.objects.count() == 10 diff --git a/tests/test_model_to_document_operator.py b/tests/test_model_to_document_operator.py new file mode 100644 index 0000000..e69de29 From ef7784f9d58715fd7f4d4550d1bf742fcccb1b96 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Thu, 26 Sep 2024 09:34:23 +0000 Subject: [PATCH 07/11] Add async search API --- src/wagtail_vector_index/storage/base.py | 21 +++++++++++++++++++++ tests/conftest.py | 3 +++ tests/test_index.py | 8 ++++++++ 3 files changed, 32 insertions(+) diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index 79f2adf..e535e55 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -382,6 +382,27 @@ def search( ) return list(self.get_converter().bulk_from_documents(similar_documents)) + async def asearch( + self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0 + ) -> list: + """Perform a search against the index, returning only a list of matching sources""" + try: + query_embedding = next(await self.get_embedding_backend().aembed([query])) + except StopIteration as e: + raise ValueError("No embeddings were generated for the given query.") from e + similar_documents = [ + doc + async for doc in self.aget_similar_documents( + query_embedding, limit=limit, similarity_threshold=similarity_threshold + ) + ] + return [ + obj + async for obj in self.get_converter().abulk_from_documents( + similar_documents + ) + ] + # Utilities def _get_storage_provider(self): diff --git a/tests/conftest.py b/tests/conftest.py index 65f2fc8..3b3bb12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,4 +95,7 @@ def embedding_generator(): return embedding_generator() + async def aembed(self, texts): + return self.embed(texts) + return MockEmbeddingBackend() diff --git a/tests/test_index.py b/tests/test_index.py index 4b1515e..ed284b4 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -253,6 +253,8 @@ def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5): first_call_messages = query_mock.call_args.kwargs["messages"] assert first_call_messages[1] == {"content": expected_content, "role": "system"} + +class TestSearchOperations: @pytest.mark.django_db @pytest.mark.parametrize( "similarity_threshold, expected_count, expected_titles", @@ -286,3 +288,9 @@ def test_search_with_similarity_threshold( if expected_count > 0: assert {result.title for result in results} == expected_titles + + @pytest.mark.django_db(transaction=True) + async def test_asearch(self, mock_vector_index): + objs = [obj async for obj in ExampleModel.objects.all()] + actual = await mock_vector_index.asearch("test", limit=100) + assert set(actual) == set(objs), f"Expected {objs}, but got {actual}" From b81f03970f16ff8552823e031eb1adb55a79a480 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Thu, 26 Sep 2024 09:45:31 +0000 Subject: [PATCH 08/11] Add type signature for async methods to DocumentManager --- src/wagtail_vector_index/storage/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index bdbe90f..8284a66 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -48,6 +48,10 @@ def for_key(self, object_key: str) -> DocumentQuerySet: ... def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... + def afor_key(self, object_key: str) -> AsyncGenerator["Document", None]: ... + + def afor_keys(self, object_keys: list[str]) -> AsyncGenerator["Document", None]: ... + class Document(models.Model): """Stores an embedding for an arbitrary chunk""" From 66c9d24be416949c44f9ea520a5053282b06391c Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Thu, 26 Sep 2024 09:46:04 +0000 Subject: [PATCH 09/11] Add afrom_document to FromDocumentOperator protocol --- src/wagtail_vector_index/storage/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index e535e55..43f3058 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -78,6 +78,8 @@ class FromDocumentOperator(Protocol[ToObjectType]): def from_document(self, document: "Document") -> ToObjectType: ... + async def afrom_document(self, document: "Document") -> ToObjectType: ... + def bulk_from_documents( self, documents: Iterable["Document"] ) -> Generator[ToObjectType, None, None]: ... From 84d429623ad944ec189cc2eff0874c4b44917185 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 22 Nov 2024 11:41:51 +0000 Subject: [PATCH 10/11] Use transaction for document generation from models --- src/wagtail_vector_index/storage/django.py | 32 ++++++++++++++-------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/wagtail_vector_index/storage/django.py b/src/wagtail_vector_index/storage/django.py index c572f7f..03ab2da 100644 --- a/src/wagtail_vector_index/storage/django.py +++ b/src/wagtail_vector_index/storage/django.py @@ -10,6 +10,7 @@ from typing import ( TYPE_CHECKING, ClassVar, + Iterator, Optional, Type, TypeAlias, @@ -297,14 +298,15 @@ def _get_chunks( ) # Synchronous document generation methods - def _create_new_documents( + @transaction.atomic + def replace_documents( self, object: models.Model, chunks: list[str], - embedding_backend: BaseEmbeddingBackend, - ) -> list[Document]: + embedding_vectors: Iterator[list[float]], + ): + """Replace the current Documents for an object with new ones within a transaction""" Document.objects.for_key(ModelKey(object)).delete() - embedding_vectors = embedding_backend.embed(chunks) return [ Document.objects.create( object_keys=[str(key) for key in self._keys_for_instance(object)], @@ -314,22 +316,28 @@ def _create_new_documents( for chunk, embedding in zip(chunks, embedding_vectors, strict=False) ] - # Asynchronous document generation methods + def _create_new_documents( + self, + object: models.Model, + chunks: list[str], + embedding_backend: BaseEmbeddingBackend, + ) -> list[Document]: + embedding_vectors = embedding_backend.embed(chunks) + return self.replace_documents(object, chunks, embedding_vectors) + # Asynchronous document generation methods async def _acreate_new_documents( self, object: models.Model, chunks: list[str], embedding_backend: BaseEmbeddingBackend, ) -> AsyncGenerator[Document, None]: - await Document.objects.for_key(ModelKey(object)).adelete() embedding_vectors = embedding_backend.embed(chunks) - for chunk, embedding in zip(chunks, embedding_vectors, strict=False): - yield await Document.objects.acreate( - object_keys=[str(key) for key in self._keys_for_instance(object)], - vector=embedding, - content=chunk, - ) + documents = await sync_to_async(self.replace_documents)( + object, chunks, embedding_vectors + ) + for document in documents: + yield document # Bulk document generation methods @transaction.atomic From 5f96824da234cd1a0b1e6e2ebc0bb75c09aa514d Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Fri, 6 Dec 2024 13:53:00 +0000 Subject: [PATCH 11/11] Move preparation of documents and saving of documents to separate stages --- src/wagtail_vector_index/storage/base.py | 49 +-- src/wagtail_vector_index/storage/django.py | 433 ++++++++++----------- src/wagtail_vector_index/storage/models.py | 4 + tests/test_django_converter.py | 206 +++++++--- tests/test_index.py | 4 +- 5 files changed, 396 insertions(+), 300 deletions(-) diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index 43f3058..611b7f5 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -103,25 +103,19 @@ class ToDocumentOperator(Protocol[FromObjectType]): def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): ... def to_documents( - self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend - ) -> Generator["Document", None, None]: ... - - async def ato_documents( - self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend - ) -> AsyncGenerator["Document", None]: ... - - def bulk_to_documents( self, objects: Iterable[FromObjectType], *, embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> Generator["Document", None, None]: ... - async def abulk_to_documents( + async def ato_documents( self, objects: Iterable[FromObjectType], *, embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> AsyncGenerator["Document", None]: ... @@ -147,17 +141,25 @@ def from_document_operator(self) -> FromDocumentOperator: return self.from_document_operator_class() def to_documents( - self, object: object, *, embedding_backend: BaseEmbeddingBackend + self, + objects: Iterable[object], + *, + embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> Generator["Document", None, None]: return self.to_document_operator.to_documents( - object, embedding_backend=embedding_backend + objects, embedding_backend=embedding_backend, batch_size=batch_size ) def ato_documents( - self, object: object, *, embedding_backend: BaseEmbeddingBackend + self, + objects: Iterable[object], + *, + embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> AsyncGenerator["Document", None]: return self.to_document_operator.ato_documents( - object, embedding_backend=embedding_backend + objects, embedding_backend=embedding_backend, batch_size=batch_size ) def from_document(self, document: "Document") -> object: @@ -166,23 +168,6 @@ def from_document(self, document: "Document") -> object: def afrom_document(self, document: "Document") -> object: return self.from_document_operator.afrom_document(document) - def bulk_to_documents( - self, objects: Iterable[object], *, embedding_backend: BaseEmbeddingBackend - ) -> Generator["Document", None, None]: - return self.to_document_operator.bulk_to_documents( - objects, embedding_backend=embedding_backend - ) - - def abulk_to_documents( - self, - objects: Iterable[object], - *, - embedding_backend: BaseEmbeddingBackend, - ) -> AsyncGenerator["Document", None]: - return self.to_document_operator.abulk_to_documents( - objects, embedding_backend=embedding_backend - ) - def bulk_from_documents( self, documents: Sequence["Document"] ) -> Generator[object, None, None]: @@ -330,7 +315,7 @@ def find_similar( """Find similar objects to the given object""" converter = self.get_converter() object_documents: Generator[Document, None, None] = converter.to_documents( - object, embedding_backend=self.get_embedding_backend() + [object], embedding_backend=self.get_embedding_backend() ) similar_documents = [] for document in object_documents: @@ -356,7 +341,7 @@ async def afind_similar( converter = self.get_converter() similar_documents = [] async for document in converter.ato_documents( - object, embedding_backend=self.get_embedding_backend() + [object], embedding_backend=self.get_embedding_backend() ): similar_docs = self.aget_similar_documents( document.vector, limit=limit, similarity_threshold=similarity_threshold diff --git a/src/wagtail_vector_index/storage/django.py b/src/wagtail_vector_index/storage/django.py index 03ab2da..86f0c90 100644 --- a/src/wagtail_vector_index/storage/django.py +++ b/src/wagtail_vector_index/storage/django.py @@ -6,6 +6,7 @@ Iterable, Sequence, ) +from dataclasses import dataclass, field from itertools import chain, islice from typing import ( TYPE_CHECKING, @@ -144,12 +145,12 @@ def _has_field(cls, name): @classmethod def _check_embedding_fields(cls, **kwargs): errors = [] - for field in cls._get_embedding_fields(): + for field_ in cls._get_embedding_fields(): message = "{model}.embedding_fields contains non-existent field '{name}'" - if not cls._has_field(field.field_name): + if not cls._has_field(field_.field_name): errors.append( checks.Warning( - message.format(model=cls.__name__, name=field.field_name), + message.format(model=cls.__name__, name=field_.field_name), obj=cls, id="wagtailai.WA001", ) @@ -260,24 +261,100 @@ async def _aget_models_by_key( return objects_by_key -class ModelToDocumentOperator(ToDocumentOperator[models.Model]): - """A class that can generate Documents from model instances""" +@dataclass +class PreparedObject: + """A class that represents a model instance and its chunks - used to persist object metadata and allow batch operations""" - def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): - self.object_chunker_operator = object_chunker_operator_class() + key: ModelKey + object: models.Model + chunks: list[str] + embedding_vectors: list[list[float]] | None = None + existing_documents: list[Document] | None = None + new_documents: list[Document] = field(default_factory=list) + + @property + def documents(self) -> list[Document]: + return self.new_documents or self.existing_documents or [] + + @property + def needs_updating(self) -> bool: + """Determine whether the embeddings need to be updated for this object""" + if not self.existing_documents: + return True + + document_content = {document.content for document in self.existing_documents} + + return set(self.chunks) != document_content + + +@dataclass +class PreparedObjectCollection: + """A collection of PreparedObjects that handles bulk operations like chunk mapping and embedding""" + + objects: list[PreparedObject] + + def __iter__(self) -> Iterator[PreparedObject]: + """Make the collection iterable, yielding PreparedObjects""" + return iter(self.objects) + + @classmethod + def prepare_objects( + cls, + objects: Iterable[models.Model], + *, + chunker_operator: ObjectChunkerOperator, + embedding_backend: BaseEmbeddingBackend, + ) -> "PreparedObjectCollection": + """Create a PreparedObjectCollection from a list of model instances""" + prepared_objects = [] + all_keys = [] + + for object in objects: + key = ModelKey.from_instance(object) + chunks = list( + chunker_operator.chunk_object( + object, chunk_size=embedding_backend.config.token_limit + ) + ) + prepared_objects.append( + PreparedObject(key=key, object=object, chunks=chunks) + ) + all_keys.append(key) + + existing_documents = Document.objects.for_keys(all_keys) + existing_documents_by_key = cls._group_documents_by_object_key( + existing_documents + ) + + for object in prepared_objects: + object.existing_documents = existing_documents_by_key[object.key] + + return cls(objects=prepared_objects) - # Utility methods @staticmethod - def _existing_documents_match( - documents: Iterable[Document], chunks: list[str] - ) -> bool: - """Determine whether the documents passed in match the text content passed in""" - if not documents: - return False + def _group_documents_by_object_key(documents) -> dict[ModelKey, list[Document]]: + """Group documents by their object key""" + documents_by_object_key = defaultdict(list) + for document in documents: + documents_by_object_key[document.object_keys[0]].append(document) + return documents_by_object_key + + @property + def objects_by_key(self) -> dict[ModelKey, PreparedObject]: + return {obj.key: obj for obj in self.objects} - document_content = {document.content for document in documents} + @property + def objects_needing_update(self) -> list[PreparedObject]: + """Return list of objects that need their embeddings updated""" + return [obj for obj in self.objects if obj.needs_updating] + + def get_all_chunks(self) -> list[str]: + """Get all chunks from objects needing updates""" + return [chunk for obj in self.objects_needing_update for chunk in obj.chunks] - return set(chunks) == document_content + def get_chunk_mapping(self) -> list[ModelKey]: + """Create a mapping of chunk indices to object keys""" + return [obj.key for obj in self.objects_needing_update for _ in obj.chunks] @staticmethod def _keys_for_instance(instance: models.Model) -> list[ModelKey]: @@ -287,151 +364,84 @@ def _keys_for_instance(instance: models.Model) -> list[ModelKey]: keys = [ModelKey.from_instance(instance), *keys] return keys - def _get_chunks( - self, object: models.Model, embedding_backend: BaseEmbeddingBackend - ) -> list[str]: - """Get chunks of text from the object using the object chunker operator""" - return list( - self.object_chunker_operator.chunk_object( - object, chunk_size=embedding_backend.config.token_limit - ) - ) + def prepare_new_documents(self, embedding_vectors: Iterable[list[float]]): + """Prepare (but don't save) new documents for objects that need updating with new embeddings""" + if not embedding_vectors: + return - # Synchronous document generation methods - @transaction.atomic - def replace_documents( - self, - object: models.Model, - chunks: list[str], - embedding_vectors: Iterator[list[float]], - ): - """Replace the current Documents for an object with new ones within a transaction""" - Document.objects.for_key(ModelKey(object)).delete() - return [ - Document.objects.create( - object_keys=[str(key) for key in self._keys_for_instance(object)], - vector=embedding, - content=chunk, - ) - for chunk, embedding in zip(chunks, embedding_vectors, strict=False) - ] + chunk_mapping = self.get_chunk_mapping() + all_chunks = self.get_all_chunks() - def _create_new_documents( - self, - object: models.Model, - chunks: list[str], - embedding_backend: BaseEmbeddingBackend, - ) -> list[Document]: - embedding_vectors = embedding_backend.embed(chunks) - return self.replace_documents(object, chunks, embedding_vectors) - - # Asynchronous document generation methods - async def _acreate_new_documents( - self, - object: models.Model, - chunks: list[str], - embedding_backend: BaseEmbeddingBackend, - ) -> AsyncGenerator[Document, None]: - embedding_vectors = embedding_backend.embed(chunks) - documents = await sync_to_async(self.replace_documents)( - object, chunks, embedding_vectors + # Group embeddings by object + embeddings_by_key: dict[ModelKey, list[tuple[int, list[float]]]] = defaultdict( + list ) - for document in documents: - yield document + for idx, embedding in enumerate(embedding_vectors): + object_key = chunk_mapping[idx] + embeddings_by_key[object_key].append((idx, embedding)) - # Bulk document generation methods - @transaction.atomic - def bulk_generate_documents( - self, objects, *, embedding_backend - ) -> Iterable[Document]: - """Generate documents in bulk for the given objects""" - objects_by_key = {ModelKey.from_instance(obj): obj for obj in objects} - documents = Document.objects.for_keys(list(objects_by_key.keys())) - documents_by_object_key = self._group_documents_by_object_key(documents) - objects_to_rebuild, chunk_mapping = self._identify_objects_to_rebuild( - objects_by_key=objects_by_key, - documents_by_object_key=documents_by_object_key, - embedding_backend=embedding_backend, - ) + # Create new documents for each object + for object_key, embeddings in embeddings_by_key.items(): + prepared_obj = self.objects_by_key[object_key] + for idx, embedding in embeddings: + chunk = all_chunks[idx] + all_keys = self._keys_for_instance(prepared_obj.object) + prepared_obj.new_documents.append( + Document( + object_keys=all_keys, + vector=embedding, + content=chunk, + ) + ) - if not objects_to_rebuild: - return documents + print([obj.new_documents for obj in self.objects]) - return self._rebuild_documents( - objects_to_rebuild=objects_to_rebuild, - chunk_mapping=chunk_mapping, - objects_by_key=objects_by_key, - embedding_backend=embedding_backend, - ) - async def abulk_generate_documents( - self, - objects: Iterable[models.Model], - *, - embedding_backend: BaseEmbeddingBackend, - ) -> AsyncGenerator[Document, None]: - """Generate documents in bulk for the given objects asynchronously""" - documents = await sync_to_async(self.bulk_generate_documents)( - objects=objects, embedding_backend=embedding_backend - ) - for document in documents: - yield document +class ModelToDocumentOperator(ToDocumentOperator[models.Model]): + """A class that can generate Documents from model instances""" - # Helper methods for bulk document generation - def _group_documents_by_object_key(self, documents): - """Group documents by their object key""" - documents_by_object_key = defaultdict(list) - for document in documents: - documents_by_object_key[document.object_keys[0]].append(document) - return documents_by_object_key + def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): + self.object_chunker_operator = object_chunker_operator_class() - def _identify_objects_to_rebuild( - self, *, objects_by_key, documents_by_object_key, embedding_backend + @transaction.atomic + def update_documents( + self, + collection: PreparedObjectCollection, ): - """Identify which objects need to be rebuilt""" - objects_to_rebuild = {} - chunk_mapping = [] - for key, object in objects_by_key.items(): - documents_for_object = documents_by_object_key[key] - chunks = self._get_chunks( - object=object, embedding_backend=embedding_backend + """Replace the current Documents for all objects that have new documents to save""" + replaced_keys = [str(obj.key) for obj in collection if obj.new_documents] + if replaced_keys: + Document.objects.for_keys(replaced_keys).delete() + Document.objects.bulk_create( + chain(*[obj.new_documents for obj in collection if obj.new_documents]) ) - if not self._existing_documents_match( - documents=documents_for_object, chunks=chunks - ): - objects_to_rebuild[key] = {"object": object, "chunks": chunks} - chunk_mapping += [key] * len(chunks) - return objects_to_rebuild, chunk_mapping - def _rebuild_documents( - self, *, objects_to_rebuild, chunk_mapping, objects_by_key, embedding_backend + def _update_object_collection_with_new_documents( + self, + collection: PreparedObjectCollection, + embedding_backend: BaseEmbeddingBackend, ): - """Rebuild Documents for the identified objects""" - all_chunks = list( - chain(*[obj["chunks"] for obj in objects_to_rebuild.values()]) - ) - embedding_vectors = list(embedding_backend.embed(all_chunks)) - documents_by_object = self._group_embeddings_by_object( - embedding_vectors=embedding_vectors, chunk_mapping=chunk_mapping - ) + objects_to_rebuild = collection.objects_needing_update - self._delete_existing_documents(documents_by_object=documents_by_object) - self._create_new_documents_bulk( - documents_by_object=documents_by_object, - objects_by_key=objects_by_key, - all_chunks=all_chunks, - ) + if not objects_to_rebuild: + return list( + chain( + *[ + obj.existing_documents + for obj in collection + if obj.existing_documents + ] + ) + ) - return self._get_sorted_documents(objects_by_key=objects_by_key) + # Get embeddings for all chunks that need updating + all_chunks = collection.get_all_chunks() + embedding_vectors = list(embedding_backend.embed(all_chunks)) - def _group_embeddings_by_object(self, *, embedding_vectors, chunk_mapping): - """Group embedding vectors by their corresponding object""" - documents_by_object = defaultdict(list) - for idx, embedding in enumerate(embedding_vectors): - object_key = chunk_mapping[idx] - documents_by_object[object_key].append((idx, embedding)) - return documents_by_object + # Apply the embeddings to create new documents + collection.prepare_new_documents(embedding_vectors) + # Helper methods for bulk document generation def _delete_existing_documents(self, *, documents_by_object): existing_documents = Document.objects.for_keys(list(documents_by_object.keys())) existing_documents.delete() @@ -440,102 +450,87 @@ async def _adelete_existing_documents(self, *, documents_by_object): existing_documents = Document.objects.for_keys(list(documents_by_object.keys())) await existing_documents.adelete() - def _create_new_documents_bulk( - self, *, documents_by_object, objects_by_key, all_chunks + def _to_documents_batch( + self, objects: Iterable[models.Model], embedding_backend: BaseEmbeddingBackend ): - for object_key, documents in documents_by_object.items(): - for idx, returned_embedding in documents: - all_keys = self._keys_for_instance(objects_by_key[object_key]) - chunk = all_chunks[idx] - Document.objects.create( - object_keys=all_keys, - vector=returned_embedding, - content=chunk, - ) - - async def _acreate_new_documents_bulk( - self, *, documents_by_object, objects_by_key, all_chunks - ): - for object_key, documents in documents_by_object.items(): - for idx, returned_embedding in documents: - all_keys = self._keys_for_instance(objects_by_key[object_key]) - chunk = all_chunks[idx] - await Document.objects.acreate( - object_keys=all_keys, - vector=returned_embedding, - content=chunk, - ) - - def _get_sorted_documents(self, *, objects_by_key): - """Get sorted documents for the given objects""" - documents = list(Document.objects.for_keys(list(objects_by_key.keys()))) - return sorted( - documents, - key=lambda doc: list(objects_by_key.keys()).index( - ModelKey(doc.object_keys[0]) - ), + collection = PreparedObjectCollection.prepare_objects( + objects=objects, + chunker_operator=self.object_chunker_operator, + embedding_backend=embedding_backend, ) + self._update_object_collection_with_new_documents(collection, embedding_backend) + self.update_documents(collection) + yield from [doc for obj in collection for doc in obj.documents] # Interface methods - @transaction.atomic def to_documents( - self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend - ) -> Generator[Document, None, None]: - """Use the AI backend to generate and store Documents for this object""" - chunks = self._get_chunks(object, embedding_backend) - documents = Document.objects.for_key(ModelKey(object)) - - if self._existing_documents_match(list(documents), chunks): - yield from documents - - yield from self._create_new_documents(object, chunks, embedding_backend) - - async def ato_documents( - self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend - ) -> AsyncGenerator[Document, None]: - """Use the AI backend to generate and store Documents for this object asynchronously""" - chunks = self._get_chunks(object, embedding_backend) - documents = Document.objects.afor_key(ModelKey(object)) - documents = [doc async for doc in documents] - - if self._existing_documents_match(documents, chunks): - for document in documents: - yield document - - async for document in self._acreate_new_documents( - object, chunks, embedding_backend - ): - yield document - - def bulk_to_documents( self, objects: Iterable[models.Model], *, - batch_size: int = 100, embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> Generator[Document, None, None]: - """Convert multiple model instances to Documents in batches""" batches = list(batched(objects, batch_size)) for idx, batch in enumerate(batches): logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}") - for document in self.bulk_generate_documents( + for document in self._to_documents_batch( batch, embedding_backend=embedding_backend ): yield document - async def abulk_to_documents( + async def _ato_documents_batch( + self, objects: Iterable[models.Model], embedding_backend: BaseEmbeddingBackend + ): + collection = await sync_to_async(PreparedObjectCollection.prepare_objects)( + objects=objects, + chunker_operator=self.object_chunker_operator, + embedding_backend=embedding_backend, + ) + await self._aupdate_object_collection_with_new_documents( + collection, embedding_backend + ) + # Using sync_to_async to ensure the update can happen in a transaction + await sync_to_async(self.update_documents)(collection) + return [doc for obj in collection.objects for doc in obj.documents] + + async def _aupdate_object_collection_with_new_documents( + self, + collection: PreparedObjectCollection, + embedding_backend: BaseEmbeddingBackend, + ): + """Async version of _update_object_collection_with_new_documents""" + objects_to_rebuild = collection.objects_needing_update + + if not objects_to_rebuild: + return list( + chain( + *[ + obj.existing_documents + for obj in collection + if obj.existing_documents + ] + ) + ) + + # Get embeddings for all chunks that need updating + all_chunks = collection.get_all_chunks() + embedding_vectors = await embedding_backend.aembed(all_chunks) + + # Apply the embeddings to create new documents + collection.prepare_new_documents(embedding_vectors) + + async def ato_documents( self, objects: Iterable[models.Model], *, - batch_size: int = 100, embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> AsyncGenerator[Document, None]: - """Convert multiple model instances to Documents asynchronously in batches""" batches = list(batched(objects, batch_size)) for idx, batch in enumerate(batches): logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}") - async for document in self.abulk_generate_documents( - objects=batch, embedding_backend=embedding_backend + for document in await self._ato_documents_batch( + batch, embedding_backend=embedding_backend ): yield document @@ -551,15 +546,15 @@ def chunk_object( important_content = [] embedding_fields = object._meta.model._get_embedding_fields() - for field in embedding_fields: - value = field.get_value(object) + for field_ in embedding_fields: + value = field_.get_value(object) if value is None: continue if isinstance(value, str): final_value = value else: final_value: str = "\n".join((str(v) for v in value)) - if field.important: + if field_.important: important_content.append(final_value) else: splittable_content.append(final_value) @@ -622,7 +617,7 @@ def get_documents(self) -> Iterable[Document]: # Embedding models are created, even if it is not consumed # by the caller all_documents += list( - self.get_converter().bulk_to_documents( + self.get_converter().to_documents( queryset, embedding_backend=self.get_embedding_backend() ) ) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index 8284a66..39074e5 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -27,12 +27,16 @@ async def afor_key(self, object_key: str) -> AsyncGenerator["Document", None]: yield doc def for_keys(self, object_keys: list[str]): + if not object_keys: + return self.none() q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] return self.filter(reduce(operator.or_, q_objs)) async def afor_keys( self, object_keys: list[str] ) -> AsyncGenerator["Document", None]: + if not object_keys: + return q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] async for doc in self.filter(reduce(operator.or_, q_objs)): yield doc diff --git a/tests/test_django_converter.py b/tests/test_django_converter.py index 2995b3b..3d880d4 100644 --- a/tests/test_django_converter.py +++ b/tests/test_django_converter.py @@ -1,5 +1,6 @@ import factory import pytest +from asgiref.sync import sync_to_async from factories import ( AsyncExampleModelFactory, DifferentPageFactory, @@ -15,8 +16,11 @@ EmbeddableFieldsObjectChunkerOperator, EmbeddingField, ModelFromDocumentOperator, + ModelKey, ModelLabel, ModelToDocumentOperator, + PreparedObject, + PreparedObjectCollection, ) from wagtail_vector_index.storage.models import Document @@ -185,25 +189,6 @@ def test_bulk_from_documents_returns_deduplicated_model_objects(self): class TestToDocument: - def test_existing_documents_match(self): - text_contents = ["This is a test", "Another test", "More testing content"] - documents = [ - DocumentFactory.build(content=content) for content in text_contents - ] - operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - assert operator._existing_documents_match(documents, text_contents) - - @pytest.mark.django_db - def test_keys_for_instance(self): - instance = ExamplePageFactory.create( - title="Important Title", body=fake.text(max_nb_chars=200) - ) - operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - keys = operator._keys_for_instance(instance) - assert len(keys) == 2 - assert keys[0] == f"testapp.ExamplePage:{instance.pk}" - assert keys[1] == f"wagtailcore.Page:{instance.pk}" - @pytest.mark.django_db def test_generate_documents_returns_documents(self): instance = ExamplePageFactory.create( @@ -212,7 +197,7 @@ def test_generate_documents_returns_documents(self): operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) documents = list( operator.to_documents( - instance, embedding_backend=get_embedding_backend("default") + [instance], embedding_backend=get_embedding_backend("default") ) ) assert len(documents) == 1 @@ -223,7 +208,7 @@ def test_bulk_generate_documents_returns_documents(self): instances = ExamplePageFactory.create_batch(3) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) documents = list( - operator.bulk_to_documents( + operator.to_documents( instances, embedding_backend=get_embedding_backend("default") ) ) @@ -239,7 +224,7 @@ def test_bulk_generate_documents_returns_documents_for_multiple_models(self): different_pages = DifferentPageFactory.create_batch(3) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) documents = list( - operator.bulk_to_documents( + operator.to_documents( example_pages + different_pages, embedding_backend=get_embedding_backend("default"), ) @@ -253,18 +238,18 @@ def test_bulk_generate_documents_returns_documents_for_multiple_models(self): ) @pytest.mark.django_db - def test_bulk_to_documents_batches_objects(self, mocker): + def test_to_documents_batches_objects(self, mocker): instances = ExamplePageFactory.create_batch(10) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - bulk_generate_mock = mocker.patch.object(operator, "bulk_generate_documents") + to_documents_batch_mock = mocker.patch.object(operator, "_to_documents_batch") list( - operator.bulk_to_documents( + operator.to_documents( instances, embedding_backend=get_embedding_backend("default"), batch_size=2, ) ) - assert bulk_generate_mock.call_count == 5 + assert to_documents_batch_mock.call_count == 5 class TestConverter: @@ -277,7 +262,7 @@ def test_returns_original_object(self, patch_embedding_fields): converter = EmbeddableFieldsDocumentConverter() document = next( converter.to_documents( - instance, embedding_backend=get_embedding_backend("default") + [instance], embedding_backend=get_embedding_backend("default") ) ) recovered_instance = converter.from_document(document) @@ -293,7 +278,7 @@ def test_convert_single_document_to_object(): ) documents = list( converter.to_documents( - instance, embedding_backend=get_embedding_backend("default") + [instance], embedding_backend=get_embedding_backend("default") ) ) recovered_instance = converter.from_document(documents[0]) @@ -309,7 +294,7 @@ def test_convert_multiple_documents_to_objects(): different_pages = DifferentPageFactory.create_batch(5) all_objects = list(example_objects + example_pages + different_pages) documents = list( - converter.bulk_to_documents( + converter.to_documents( all_objects, embedding_backend=get_embedding_backend("default") ) ) @@ -319,36 +304,163 @@ def test_convert_multiple_documents_to_objects(): class TestToDocumentOperatorAsync: @pytest.mark.django_db(transaction=True) - async def test_ato_documents(self, mock_embedding_backend): + async def test_ato_documents_batch(self, mock_embedding_backend): instance = await AsyncExampleModelFactory.create( title="Important Title", body=fake.text(max_nb_chars=200) ) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - documents = [ - doc - async for doc in operator.ato_documents( - instance, embedding_backend=mock_embedding_backend - ) - ] + + documents = await operator._ato_documents_batch( + [instance], embedding_backend=mock_embedding_backend + ) assert len(documents) > 0 assert all(isinstance(doc, Document) for doc in documents) assert all(instance.title in doc.content for doc in documents) @pytest.mark.django_db(transaction=True) - async def test_abulk_to_documents(self, mock_embedding_backend): - instances = await AsyncExampleModelFactory.create_batch(3) + async def test_aupdate_object_collection_with_new_documents( + self, mock_embedding_backend + ): + instance = await AsyncExampleModelFactory.create() + collection = await sync_to_async(PreparedObjectCollection.prepare_objects)( + objects=[instance], + chunker_operator=EmbeddableFieldsObjectChunkerOperator(), + embedding_backend=mock_embedding_backend, + ) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + await operator._aupdate_object_collection_with_new_documents( + collection, mock_embedding_backend + ) + + assert any(obj.new_documents for obj in collection) + assert all( + isinstance(doc, Document) for obj in collection for doc in obj.new_documents + ) + + +class TestPreparedObject: + @pytest.mark.django_db + def test_needs_updating_when_no_existing_documents(self): + instance = ExamplePageFactory.build() + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + ) + assert prepared_object.needs_updating is True + + @pytest.mark.django_db + def test_needs_updating_when_chunks_match(self): + instance = ExamplePageFactory.build() + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + existing_documents=[ + DocumentFactory.build(content="chunk1"), + DocumentFactory.build(content="chunk2"), + ], + ) + assert prepared_object.needs_updating is False + + @pytest.mark.django_db + def test_needs_updating_when_chunks_differ(self): + instance = ExamplePageFactory.build() + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + existing_documents=[ + DocumentFactory.build(content="chunk1"), + DocumentFactory.build(content="different chunk"), + ], + ) + assert prepared_object.needs_updating is True + + @pytest.mark.django_db + def test_documents_returns_new_documents_when_present(self): + instance = ExamplePageFactory.build() + new_docs = [DocumentFactory.build(), DocumentFactory.build()] + existing_docs = [DocumentFactory.build(), DocumentFactory.build()] + + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1"], + new_documents=new_docs, + existing_documents=existing_docs, + ) + assert prepared_object.documents == new_docs + + @pytest.mark.django_db + def test_documents_returns_existing_documents_when_no_new_ones(self): + instance = ExamplePageFactory.build() + existing_docs = [DocumentFactory.build(), DocumentFactory.build()] + + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1"], + existing_documents=existing_docs, + ) + assert prepared_object.documents == existing_docs + + +class TestPreparedObjectCollection: + @pytest.mark.django_db + def test_prepare_objects(self, patch_embedding_fields): + with patch_embedding_fields(ExamplePage, [EmbeddingField("body")]): + instances = ExamplePageFactory.create_batch(3) + chunker = EmbeddableFieldsObjectChunkerOperator() - documents = [ - doc - async for doc in operator.abulk_to_documents( - instances, embedding_backend=mock_embedding_backend + collection = PreparedObjectCollection.prepare_objects( + objects=instances, + chunker_operator=chunker, + embedding_backend=get_embedding_backend("default"), + ) + + assert len(collection.objects) == 3 + assert all(isinstance(obj, PreparedObject) for obj in collection) + assert all(obj.chunks for obj in collection) + + @pytest.mark.django_db + def test_get_chunk_mapping(self): + instances = ExamplePageFactory.build_batch(2) + prepared_objects = [ + PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], ) + for instance in instances ] + collection = PreparedObjectCollection(objects=prepared_objects) - assert len(documents) > 0 - assert all(isinstance(doc, Document) for doc in documents) - assert any(instances[0].title in doc.content for doc in documents) - assert any(instances[1].title in doc.content for doc in documents) - assert any(instances[2].title in doc.content for doc in documents) + chunk_mapping = collection.get_chunk_mapping() + assert len(chunk_mapping) == 4 # 2 instances * 2 chunks each + assert all(isinstance(key, ModelKey) for key in chunk_mapping) + + @pytest.mark.django_db + def test_prepare_new_documents(self): + instances = ExamplePageFactory.create_batch(2) + prepared_objects = [ + PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + ) + for instance in instances + ] + collection = PreparedObjectCollection(objects=prepared_objects) + + # Mock embedding vectors + embedding_vectors = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]] + collection.prepare_new_documents(embedding_vectors) + + assert all(len(obj.new_documents) == 2 for obj in collection) + assert all( + all(isinstance(doc, Document) for doc in obj.new_documents) + for obj in collection + ) diff --git a/tests/test_index.py b/tests/test_index.py index ed284b4..d367f19 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -67,12 +67,12 @@ def mock_vector_index( ) mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_to_documents", + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.to_documents", side_effect=document_generator, ) mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.abulk_to_documents", + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.ato_documents", side_effect=async_document_generator, )