diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index 0cb81c4..611b7f5 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -42,6 +42,8 @@ def upsert(self) -> None: ... def get_documents(self) -> Iterable["Document"]: ... + async def aget_documents(self) -> AsyncGenerator["Document", None]: ... + def _get_storage_provider(self) -> StorageProviderClass: ... @@ -76,6 +78,8 @@ class FromDocumentOperator(Protocol[ToObjectType]): def from_document(self, document: "Document") -> ToObjectType: ... + async def afrom_document(self, document: "Document") -> ToObjectType: ... + def bulk_from_documents( self, documents: Iterable["Document"] ) -> Generator[ToObjectType, None, None]: ... @@ -99,19 +103,30 @@ class ToDocumentOperator(Protocol[FromObjectType]): def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): ... def to_documents( - self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend + self, + objects: Iterable[FromObjectType], + *, + embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> Generator["Document", None, None]: ... - def bulk_to_documents( + async def ato_documents( self, objects: Iterable[FromObjectType], *, embedding_backend: BaseEmbeddingBackend, - ) -> Generator["Document", None, None]: ... + batch_size: int = 100, + ) -> AsyncGenerator["Document", None]: ... class DocumentConverter(ABC): - """Base class for a DocumentConverter that can convert objects to Documents and vice versa""" + """Base class for a DocumentConverter that can convert objects to Documents and vice versa + + Note on async methods: + Some async (a-prefixed) methods in this class return AsyncGenerators directly from + the methods of the to_document_operator or from_document_operator, so they aren't marked + with the 'async' keyword to prevent the methods from being wrapped in a Coroutine. + """ to_document_operator_class: Type[ToDocumentOperator] from_document_operator_class: Type[FromDocumentOperator] @@ -126,21 +141,32 @@ def from_document_operator(self) -> FromDocumentOperator: return self.from_document_operator_class() def to_documents( - self, object: object, *, embedding_backend: BaseEmbeddingBackend + self, + objects: Iterable[object], + *, + embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> Generator["Document", None, None]: return self.to_document_operator.to_documents( - object, embedding_backend=embedding_backend + objects, embedding_backend=embedding_backend, batch_size=batch_size + ) + + def ato_documents( + self, + objects: Iterable[object], + *, + embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, + ) -> AsyncGenerator["Document", None]: + return self.to_document_operator.ato_documents( + objects, embedding_backend=embedding_backend, batch_size=batch_size ) def from_document(self, document: "Document") -> object: return self.from_document_operator.from_document(document) - def bulk_to_documents( - self, objects: Iterable[object], *, embedding_backend: BaseEmbeddingBackend - ) -> Generator["Document", None, None]: - return self.to_document_operator.bulk_to_documents( - objects, embedding_backend=embedding_backend - ) + def afrom_document(self, document: "Document") -> object: + return self.from_document_operator.afrom_document(document) def bulk_from_documents( self, documents: Sequence["Document"] @@ -186,6 +212,9 @@ def get_embedding_backend(self) -> BaseEmbeddingBackend: def get_documents(self) -> Iterable["Document"]: raise NotImplementedError + async def aget_documents(self) -> AsyncGenerator["Document", None]: + raise NotImplementedError + def get_converter(self) -> DocumentConverter: raise NotImplementedError @@ -286,7 +315,7 @@ def find_similar( """Find similar objects to the given object""" converter = self.get_converter() object_documents: Generator[Document, None, None] = converter.to_documents( - object, embedding_backend=self.get_embedding_backend() + [object], embedding_backend=self.get_embedding_backend() ) similar_documents = [] for document in object_documents: @@ -300,6 +329,31 @@ def find_similar( if include_self or obj != object ] + async def afind_similar( + self, + object, + *, + include_self: bool = False, + limit: int = 5, + similarity_threshold: float = 0.0, + ) -> list: + """Find similar objects to the given object asynchronously""" + converter = self.get_converter() + similar_documents = [] + async for document in converter.ato_documents( + [object], embedding_backend=self.get_embedding_backend() + ): + similar_docs = self.aget_similar_documents( + document.vector, limit=limit, similarity_threshold=similarity_threshold + ) + similar_documents.extend([doc async for doc in similar_docs]) + + return [ + obj + async for obj in converter.abulk_from_documents(similar_documents) + if include_self or obj != object + ] + def search( self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0 ) -> list: @@ -315,6 +369,27 @@ def search( ) return list(self.get_converter().bulk_from_documents(similar_documents)) + async def asearch( + self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0 + ) -> list: + """Perform a search against the index, returning only a list of matching sources""" + try: + query_embedding = next(await self.get_embedding_backend().aembed([query])) + except StopIteration as e: + raise ValueError("No embeddings were generated for the given query.") from e + similar_documents = [ + doc + async for doc in self.aget_similar_documents( + query_embedding, limit=limit, similarity_threshold=similarity_threshold + ) + ] + return [ + obj + async for obj in self.get_converter().abulk_from_documents( + similar_documents + ) + ] + # Utilities def _get_storage_provider(self): diff --git a/src/wagtail_vector_index/storage/django.py b/src/wagtail_vector_index/storage/django.py index e26fac6..86f0c90 100644 --- a/src/wagtail_vector_index/storage/django.py +++ b/src/wagtail_vector_index/storage/django.py @@ -4,19 +4,21 @@ AsyncGenerator, Generator, Iterable, - MutableSequence, Sequence, ) +from dataclasses import dataclass, field from itertools import chain, islice from typing import ( TYPE_CHECKING, ClassVar, + Iterator, Optional, Type, TypeAlias, cast, ) +from asgiref.sync import sync_to_async from django.apps import apps from django.core import checks from django.core.exceptions import FieldDoesNotExist @@ -143,12 +145,12 @@ def _has_field(cls, name): @classmethod def _check_embedding_fields(cls, **kwargs): errors = [] - for field in cls._get_embedding_fields(): + for field_ in cls._get_embedding_fields(): message = "{model}.embedding_fields contains non-existent field '{name}'" - if not cls._has_field(field.field_name): + if not cls._has_field(field_.field_name): errors.append( checks.Warning( - message.format(model=cls.__name__, name=field.field_name), + message.format(model=cls.__name__, name=field_.field_name), obj=cls, id="wagtailai.WA001", ) @@ -174,7 +176,9 @@ def bulk_from_documents( keys_by_model_label = self._get_keys_by_model_label(documents) objects_by_key = self._get_models_by_key(keys_by_model_label) - yield from self._get_deduplicated_objects_generator(documents, objects_by_key) + yield from self._get_deduplicated_objects_generator( + documents=documents, objects_by_key=objects_by_key + ) async def abulk_from_documents( self, documents: Sequence[Document] @@ -185,7 +189,7 @@ async def abulk_from_documents( # N.B. `yield from` cannot be used in async functions, so we have to use a loop for object_from_document in self._get_deduplicated_objects_generator( - documents, objects_by_key + documents=documents, objects_by_key=objects_by_key ): yield object_from_document @@ -210,7 +214,7 @@ def _get_keys_by_model_label( @staticmethod def _get_deduplicated_objects_generator( - documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model] + *, documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model] ) -> Generator[models.Model, None, None]: seen_keys = set() # de-dupe as we go for doc in documents: @@ -257,23 +261,100 @@ async def _aget_models_by_key( return objects_by_key -class ModelToDocumentOperator(ToDocumentOperator[models.Model]): - """A class that can generate Documents from model instances""" +@dataclass +class PreparedObject: + """A class that represents a model instance and its chunks - used to persist object metadata and allow batch operations""" + + key: ModelKey + object: models.Model + chunks: list[str] + embedding_vectors: list[list[float]] | None = None + existing_documents: list[Document] | None = None + new_documents: list[Document] = field(default_factory=list) + + @property + def documents(self) -> list[Document]: + return self.new_documents or self.existing_documents or [] + + @property + def needs_updating(self) -> bool: + """Determine whether the embeddings need to be updated for this object""" + if not self.existing_documents: + return True + + document_content = {document.content for document in self.existing_documents} + + return set(self.chunks) != document_content - def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): - self.object_chunker_operator = object_chunker_operator_class() + +@dataclass +class PreparedObjectCollection: + """A collection of PreparedObjects that handles bulk operations like chunk mapping and embedding""" + + objects: list[PreparedObject] + + def __iter__(self) -> Iterator[PreparedObject]: + """Make the collection iterable, yielding PreparedObjects""" + return iter(self.objects) + + @classmethod + def prepare_objects( + cls, + objects: Iterable[models.Model], + *, + chunker_operator: ObjectChunkerOperator, + embedding_backend: BaseEmbeddingBackend, + ) -> "PreparedObjectCollection": + """Create a PreparedObjectCollection from a list of model instances""" + prepared_objects = [] + all_keys = [] + + for object in objects: + key = ModelKey.from_instance(object) + chunks = list( + chunker_operator.chunk_object( + object, chunk_size=embedding_backend.config.token_limit + ) + ) + prepared_objects.append( + PreparedObject(key=key, object=object, chunks=chunks) + ) + all_keys.append(key) + + existing_documents = Document.objects.for_keys(all_keys) + existing_documents_by_key = cls._group_documents_by_object_key( + existing_documents + ) + + for object in prepared_objects: + object.existing_documents = existing_documents_by_key[object.key] + + return cls(objects=prepared_objects) @staticmethod - def _existing_documents_match( - documents: Iterable[Document], splits: list[str] - ) -> bool: - """Determine whether the documents passed in match the text content passed in""" - if not documents: - return False + def _group_documents_by_object_key(documents) -> dict[ModelKey, list[Document]]: + """Group documents by their object key""" + documents_by_object_key = defaultdict(list) + for document in documents: + documents_by_object_key[document.object_keys[0]].append(document) + return documents_by_object_key + + @property + def objects_by_key(self) -> dict[ModelKey, PreparedObject]: + return {obj.key: obj for obj in self.objects} + + @property + def objects_needing_update(self) -> list[PreparedObject]: + """Return list of objects that need their embeddings updated""" + return [obj for obj in self.objects if obj.needs_updating] - document_content = {document.content for document in documents} + def get_all_chunks(self) -> list[str]: + """Get all chunks from objects needing updates""" + return [chunk for obj in self.objects_needing_update for chunk in obj.chunks] - return set(splits) == document_content + def get_chunk_mapping(self) -> list[ModelKey]: + """Create a mapping of chunk indices to object keys""" + return [obj.key for obj in self.objects_needing_update for _ in obj.chunks] @staticmethod def _keys_for_instance(instance: models.Model) -> list[ModelKey]: @@ -283,123 +364,176 @@ def _keys_for_instance(instance: models.Model) -> list[ModelKey]: keys = [ModelKey.from_instance(instance), *keys] return keys - @transaction.atomic - def generate_documents( - self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend - ) -> list[Document]: - """Use the AI backend to generate and store Documents for this object""" - chunks = list( - self.object_chunker_operator.chunk_object( - object, chunk_size=embedding_backend.config.token_limit - ) + def prepare_new_documents(self, embedding_vectors: Iterable[list[float]]): + """Prepare (but don't save) new documents for objects that need updating with new embeddings""" + if not embedding_vectors: + return + + chunk_mapping = self.get_chunk_mapping() + all_chunks = self.get_all_chunks() + + # Group embeddings by object + embeddings_by_key: dict[ModelKey, list[tuple[int, list[float]]]] = defaultdict( + list ) - documents = Document.objects.for_key(ModelKey(object)) - - # If the existing embeddings all match on content, we return them - # without generating new ones - if self._existing_documents_match(documents, chunks): - return list(documents) - - # Otherwise we delete all the existing Documents and get new ones - documents.delete() - - embedding_vectors = embedding_backend.embed(chunks) - generated_documents: MutableSequence[Document] = [] - for idx, returned_embedding in enumerate(embedding_vectors): - chunk = chunks[idx] - document = Document.objects.create( - object_keys=[str(key) for key in self._keys_for_instance(object)], - vector=returned_embedding, - content=chunk, - ) - generated_documents.append(document) + for idx, embedding in enumerate(embedding_vectors): + object_key = chunk_mapping[idx] + embeddings_by_key[object_key].append((idx, embedding)) - return generated_documents + # Create new documents for each object + for object_key, embeddings in embeddings_by_key.items(): + prepared_obj = self.objects_by_key[object_key] + for idx, embedding in embeddings: + chunk = all_chunks[idx] + all_keys = self._keys_for_instance(prepared_obj.object) + prepared_obj.new_documents.append( + Document( + object_keys=all_keys, + vector=embedding, + content=chunk, + ) + ) - @transaction.atomic - def bulk_generate_documents(self, objects, *, embedding_backend): - objects_by_key = {ModelKey.from_instance(obj): obj for obj in objects} - documents = Document.objects.for_keys(list(objects_by_key.keys())) + print([obj.new_documents for obj in self.objects]) - documents_by_object_key = defaultdict(list) - for document in documents: - documents_by_object_key[document.object_keys[0]].append(document) - objects_to_rebuild = {} +class ModelToDocumentOperator(ToDocumentOperator[models.Model]): + """A class that can generate Documents from model instances""" - # Maintain a list of object keys in the order they appear in the chunks - # so we can map the embeddings from the backend to the correct object - chunk_mapping = [] + def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): + self.object_chunker_operator = object_chunker_operator_class() - # Determine which objects need to be rebuilt - for key, object in objects_by_key.items(): - documents_for_object = documents_by_object_key[key] - chunks = list( - self.object_chunker_operator.chunk_object( - object, chunk_size=embedding_backend.config.token_limit - ) + @transaction.atomic + def update_documents( + self, + collection: PreparedObjectCollection, + ): + """Replace the current Documents for all objects that have new documents to save""" + replaced_keys = [str(obj.key) for obj in collection if obj.new_documents] + if replaced_keys: + Document.objects.for_keys(replaced_keys).delete() + Document.objects.bulk_create( + chain(*[obj.new_documents for obj in collection if obj.new_documents]) ) - if not self._existing_documents_match(documents_for_object, chunks): - objects_to_rebuild[key] = {"object": object, "chunks": chunks} - chunk_mapping += [key] * len(chunks) + def _update_object_collection_with_new_documents( + self, + collection: PreparedObjectCollection, + embedding_backend: BaseEmbeddingBackend, + ): + objects_to_rebuild = collection.objects_needing_update if not objects_to_rebuild: - return documents - - all_chunks = list( - chain(*[obj["chunks"] for obj in objects_to_rebuild.values()]) - ) + return list( + chain( + *[ + obj.existing_documents + for obj in collection + if obj.existing_documents + ] + ) + ) + # Get embeddings for all chunks that need updating + all_chunks = collection.get_all_chunks() embedding_vectors = list(embedding_backend.embed(all_chunks)) - documents_by_object = defaultdict(list) - for idx, embedding in enumerate(embedding_vectors): - object_key = chunk_mapping[idx] - documents_by_object[object_key].append((idx, embedding)) + # Apply the embeddings to create new documents + collection.prepare_new_documents(embedding_vectors) + # Helper methods for bulk document generation + def _delete_existing_documents(self, *, documents_by_object): existing_documents = Document.objects.for_keys(list(documents_by_object.keys())) existing_documents.delete() - for object_key, documents in documents_by_object.items(): - for idx, returned_embedding in documents: - all_keys = self._keys_for_instance(objects_by_key[object_key]) - chunk = all_chunks[idx] - Document.objects.create( - object_keys=all_keys, - vector=returned_embedding, - content=chunk, - ) - - # Return every document object, regardless of whether it was rebuilt, retaining - # the order they appeared in the original list - documents = list(Document.objects.for_keys(list(objects_by_key.keys()))) - return sorted( - documents, - key=lambda doc: list(objects_by_key.keys()).index( - ModelKey(doc.object_keys[0]) - ), + async def _adelete_existing_documents(self, *, documents_by_object): + existing_documents = Document.objects.for_keys(list(documents_by_object.keys())) + await existing_documents.adelete() + + def _to_documents_batch( + self, objects: Iterable[models.Model], embedding_backend: BaseEmbeddingBackend + ): + collection = PreparedObjectCollection.prepare_objects( + objects=objects, + chunker_operator=self.object_chunker_operator, + embedding_backend=embedding_backend, ) + self._update_object_collection_with_new_documents(collection, embedding_backend) + self.update_documents(collection) + yield from [doc for obj in collection for doc in obj.documents] + # Interface methods def to_documents( - self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend - ) -> Generator[Document, None, None]: - yield from self.generate_documents(object, embedding_backend=embedding_backend) - - def bulk_to_documents( self, objects: Iterable[models.Model], *, - batch_size: int = 100, embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, ) -> Generator[Document, None, None]: batches = list(batched(objects, batch_size)) for idx, batch in enumerate(batches): logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}") - yield from self.bulk_generate_documents( + for document in self._to_documents_batch( batch, embedding_backend=embedding_backend + ): + yield document + + async def _ato_documents_batch( + self, objects: Iterable[models.Model], embedding_backend: BaseEmbeddingBackend + ): + collection = await sync_to_async(PreparedObjectCollection.prepare_objects)( + objects=objects, + chunker_operator=self.object_chunker_operator, + embedding_backend=embedding_backend, + ) + await self._aupdate_object_collection_with_new_documents( + collection, embedding_backend + ) + # Using sync_to_async to ensure the update can happen in a transaction + await sync_to_async(self.update_documents)(collection) + return [doc for obj in collection.objects for doc in obj.documents] + + async def _aupdate_object_collection_with_new_documents( + self, + collection: PreparedObjectCollection, + embedding_backend: BaseEmbeddingBackend, + ): + """Async version of _update_object_collection_with_new_documents""" + objects_to_rebuild = collection.objects_needing_update + + if not objects_to_rebuild: + return list( + chain( + *[ + obj.existing_documents + for obj in collection + if obj.existing_documents + ] + ) ) + # Get embeddings for all chunks that need updating + all_chunks = collection.get_all_chunks() + embedding_vectors = await embedding_backend.aembed(all_chunks) + + # Apply the embeddings to create new documents + collection.prepare_new_documents(embedding_vectors) + + async def ato_documents( + self, + objects: Iterable[models.Model], + *, + embedding_backend: BaseEmbeddingBackend, + batch_size: int = 100, + ) -> AsyncGenerator[Document, None]: + batches = list(batched(objects, batch_size)) + for idx, batch in enumerate(batches): + logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}") + for document in await self._ato_documents_batch( + batch, embedding_backend=embedding_backend + ): + yield document + class EmbeddableFieldsObjectChunkerOperator( ObjectChunkerOperator[EmbeddableFieldsMixin] @@ -412,15 +546,15 @@ def chunk_object( important_content = [] embedding_fields = object._meta.model._get_embedding_fields() - for field in embedding_fields: - value = field.get_value(object) + for field_ in embedding_fields: + value = field_.get_value(object) if value is None: continue if isinstance(value, str): final_value = value else: final_value: str = "\n".join((str(v) for v in value)) - if field.important: + if field_.important: important_content.append(final_value) else: splittable_content.append(final_value) @@ -483,7 +617,7 @@ def get_documents(self) -> Iterable[Document]: # Embedding models are created, even if it is not consumed # by the caller all_documents += list( - self.get_converter().bulk_to_documents( + self.get_converter().to_documents( queryset, embedding_backend=self.get_embedding_backend() ) ) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index a9dbd37..39074e5 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -1,4 +1,5 @@ import operator +from collections.abc import AsyncGenerator from functools import reduce from typing import cast @@ -15,10 +16,31 @@ def for_key(self, object_key: str): # so we use icontains which just does a string search return self.filter(object_keys__icontains=object_key) + async def afor_key(self, object_key: str) -> AsyncGenerator["Document", None]: + if connection.vendor != "sqlite": + async for doc in self.filter(object_keys__contains=[object_key]): + yield doc + else: + # SQLite doesn't support the __contains lookup for JSON fields + # so we use icontains which just does a string search + async for doc in self.filter(object_keys__icontains=object_key): + yield doc + def for_keys(self, object_keys: list[str]): + if not object_keys: + return self.none() q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] return self.filter(reduce(operator.or_, q_objs)) + async def afor_keys( + self, object_keys: list[str] + ) -> AsyncGenerator["Document", None]: + if not object_keys: + return + q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] + async for doc in self.filter(reduce(operator.or_, q_objs)): + yield doc + @classmethod def as_manager(cls) -> "DocumentManager": return cast(DocumentManager, super().as_manager()) @@ -30,6 +52,10 @@ def for_key(self, object_key: str) -> DocumentQuerySet: ... def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... + def afor_key(self, object_key: str) -> AsyncGenerator["Document", None]: ... + + def afor_keys(self, object_keys: list[str]) -> AsyncGenerator["Document", None]: ... + class Document(models.Model): """Stores an embedding for an arbitrary chunk""" diff --git a/src/wagtail_vector_index/storage/numpy/provider.py b/src/wagtail_vector_index/storage/numpy/provider.py index 02de3f5..42f7451 100644 --- a/src/wagtail_vector_index/storage/numpy/provider.py +++ b/src/wagtail_vector_index/storage/numpy/provider.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Generator, Iterable, Sequence +from collections.abc import AsyncGenerator, Generator, Iterable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING @@ -58,6 +58,15 @@ def get_similar_documents( for document in [pair[1] for pair in sorted_similarities][:limit]: yield document + async def aget_similar_documents( + self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0 + ) -> AsyncGenerator["Document", None]: + documents = self.get_similar_documents( + query_vector, limit=limit, similarity_threshold=similarity_threshold + ) + for document in documents: + yield document + class NumpyStorageProvider(StorageProvider[ProviderConfig, NumpyIndexMixin]): config_class = ProviderConfig diff --git a/tests/async_factory.py b/tests/async_factory.py new file mode 100644 index 0000000..79dc773 --- /dev/null +++ b/tests/async_factory.py @@ -0,0 +1,177 @@ +# ruff: noqa +# From https://github.com/Andrew-Chen-Wang/factory-boy-django-async + + +import inspect + +import factory +from asgiref.sync import sync_to_async +from django.db import IntegrityError +from factory import errors +from factory.builder import BuildStep, StepBuilder, parse_declarations + + +def use_postgeneration_results(self, step, instance, results): + return self.factory._after_postgeneration( + instance, + create=step.builder.strategy == factory.enums.CREATE_STRATEGY, + results=results, + ) + + +factory.base.FactoryOptions.use_postgeneration_results = use_postgeneration_results + + +class AsyncFactory(factory.django.DjangoModelFactory): + @classmethod + async def _generate(cls, strategy, params): + if cls._meta.abstract: + raise factory.errors.FactoryError( + "Cannot generate instances of abstract factory %(f)s; " + "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract " + "is either not set or False." % dict(f=cls.__name__) + ) + + step = AsyncStepBuilder(cls._meta, params, strategy) + return await step.build() + + class Meta: + abstract = True # Optional, but explicit. + + @classmethod + async def _get_or_create(cls, model_class, *args, **kwargs): + """Create an instance of the model through objects.get_or_create.""" + manager = cls._get_manager(model_class) + + assert "defaults" not in cls._meta.django_get_or_create, ( + "'defaults' is a reserved keyword for get_or_create " + "(in %s._meta.django_get_or_create=%r)" + % (cls, cls._meta.django_get_or_create) + ) + + key_fields = {} + for field in cls._meta.django_get_or_create: + if field not in kwargs: + raise errors.FactoryError( + "django_get_or_create - " + "Unable to find initialization value for '%s' in factory %s" + % (field, cls.__name__) + ) + key_fields[field] = kwargs.pop(field) + key_fields["defaults"] = kwargs + + try: + instance, _created = await manager.aget_or_create(*args, **key_fields) + except IntegrityError as e: + get_or_create_params = { + lookup: value + for lookup, value in cls._original_params.items() + if lookup in cls._meta.django_get_or_create + } + if get_or_create_params: + try: + instance = await manager.aget(**get_or_create_params) + except manager.model.DoesNotExist: + # Original params are not a valid lookup and triggered a create(), + # that resulted in an IntegrityError. Follow Django’s behavior. + raise e + else: + raise e + + return instance + + @classmethod + async def _create(cls, model_class, *args, **kwargs): + """Create an instance of the model, and save it to the database.""" + if cls._meta.django_get_or_create: + return await cls._get_or_create(model_class, *args, **kwargs) + + manager = cls._get_manager(model_class) + return await manager.acreate(*args, **kwargs) + + @classmethod + async def create_batch(cls, size, **kwargs): + """Create a batch of instances of the model, and save them to the database.""" + return [await cls.create(**kwargs) for _ in range(size)] + + @classmethod + async def _after_postgeneration(cls, instance, create, results=None): + """Save again the instance if creating and at least one hook ran.""" + if create and results: + # Some post-generation hooks ran, and may have modified us. + if hasattr(instance, "asave"): + await instance.asave() + else: + await sync_to_async(instance.save)() + + +class AsyncBuildStep(BuildStep): + async def resolve(self, declarations): + self.stub = factory.builder.Resolver( + declarations=declarations, + step=self, + sequence=self.sequence, + ) + + for field_name in declarations: + attr = getattr(self.stub, field_name) + if inspect.isawaitable(attr): + attr = await attr + self.attributes[field_name] = attr + + +class AsyncStepBuilder(StepBuilder): + # Redefine build function that await for instance creation and awaitable postgenerations + async def build(self, parent_step=None, force_sequence=None): + """Build a factory instance.""" + # TODO: Handle "batch build" natively + pre, post = parse_declarations( + self.extras, + base_pre=self.factory_meta.pre_declarations, + base_post=self.factory_meta.post_declarations, + ) + + if force_sequence is not None: + sequence = force_sequence + elif self.force_init_sequence is not None: + sequence = self.force_init_sequence + else: + sequence = self.factory_meta.next_sequence() + + step = AsyncBuildStep( + builder=self, + sequence=sequence, + parent_step=parent_step, + ) + await step.resolve(pre) + + args, kwargs = self.factory_meta.prepare_arguments(step.attributes) + + instance = self.factory_meta.instantiate( + step=step, + args=args, + kwargs=kwargs, + ) + if inspect.isawaitable(instance): + instance = await instance + + postgen_results = {} + for declaration_name in post.sorted(): + declaration = post[declaration_name] + declaration_result = declaration.declaration.evaluate_post( + instance=instance, + step=step, + overrides=declaration.context, + ) + if inspect.isawaitable(declaration_result): + declaration_result = await declaration_result + postgen_results[declaration_name] = declaration_result + + postgen = self.factory_meta.use_postgeneration_results( + instance=instance, + step=step, + results=postgen_results, + ) + if inspect.isawaitable(postgen): + await postgen + return instance diff --git a/tests/conftest.py b/tests/conftest.py index 09e57b2..3b3bb12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,3 +65,37 @@ def use_mock_ai_backend(settings): } }, } + + +@pytest.fixture +def get_vector_for_text(): + def _get_vector_for_text(text): + if "Very similar" in text: + return [0.9, 0.1, 0.0] + elif "Somewhat similar" in text: + return [0.7, 0.3, 0.0] + elif "test" in text.lower(): + return [1.0, 0.0, 0.0] + else: + return [0.1, 0.1, 0.8] + + return _get_vector_for_text + + +@pytest.fixture +def mock_embedding_backend(get_vector_for_text): + class MockEmbeddingBackend(BaseEmbeddingBackend): + def __init__(self): + self.config = type("Config", (), {"token_limit": 100})() + + def embed(self, texts): + def embedding_generator(): + for text in texts: + yield get_vector_for_text(text) + + return embedding_generator() + + async def aembed(self, texts): + return self.embed(texts) + + return MockEmbeddingBackend() diff --git a/tests/factories.py b/tests/factories.py index fdbe608..484856d 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,5 +1,6 @@ import factory import wagtail_factories +from async_factory import AsyncFactory from faker import Faker from testapp.models import DifferentPage, ExampleModel, ExamplePage from wagtail_vector_index.storage.models import Document @@ -15,6 +16,14 @@ class Meta: model = ExampleModel +class AsyncExampleModelFactory(AsyncFactory): + title = factory.Faker("sentence") + body = factory.LazyFunction(lambda: "\n".join(fake.paragraphs())) + + class Meta: + model = ExampleModel + + class ExamplePageFactory(wagtail_factories.PageFactory): class Meta: model = ExamplePage diff --git a/tests/test_django_converter.py b/tests/test_django_converter.py index 768cd91..3d880d4 100644 --- a/tests/test_django_converter.py +++ b/tests/test_django_converter.py @@ -1,6 +1,8 @@ import factory import pytest +from asgiref.sync import sync_to_async from factories import ( + AsyncExampleModelFactory, DifferentPageFactory, DocumentFactory, ExampleModelFactory, @@ -14,14 +16,19 @@ EmbeddableFieldsObjectChunkerOperator, EmbeddingField, ModelFromDocumentOperator, + ModelKey, ModelLabel, ModelToDocumentOperator, + PreparedObject, + PreparedObjectCollection, ) +from wagtail_vector_index.storage.models import Document fake = Faker() class TestChunking: + @pytest.mark.django_db def test_get_chunks_splits_content_into_multiple_chunks( self, patch_embedding_fields ): @@ -32,6 +39,7 @@ def test_get_chunks_splits_content_into_multiple_chunks( chunks = chunker.chunk_object(instance, chunk_size=100) assert len(chunks) > 1 + @pytest.mark.django_db def test_get_chunks_adds_important_field_to_each_chunk( self, patch_embedding_fields ): @@ -181,25 +189,6 @@ def test_bulk_from_documents_returns_deduplicated_model_objects(self): class TestToDocument: - def test_existing_documents_match(self): - text_contents = ["This is a test", "Another test", "More testing content"] - documents = [ - DocumentFactory.build(content=content) for content in text_contents - ] - operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - assert operator._existing_documents_match(documents, text_contents) - - @pytest.mark.django_db - def test_keys_for_instance(self): - instance = ExamplePageFactory.create( - title="Important Title", body=fake.text(max_nb_chars=200) - ) - operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - keys = operator._keys_for_instance(instance) - assert len(keys) == 2 - assert keys[0] == f"testapp.ExamplePage:{instance.pk}" - assert keys[1] == f"wagtailcore.Page:{instance.pk}" - @pytest.mark.django_db def test_generate_documents_returns_documents(self): instance = ExamplePageFactory.create( @@ -208,7 +197,7 @@ def test_generate_documents_returns_documents(self): operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) documents = list( operator.to_documents( - instance, embedding_backend=get_embedding_backend("default") + [instance], embedding_backend=get_embedding_backend("default") ) ) assert len(documents) == 1 @@ -219,7 +208,7 @@ def test_bulk_generate_documents_returns_documents(self): instances = ExamplePageFactory.create_batch(3) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) documents = list( - operator.bulk_to_documents( + operator.to_documents( instances, embedding_backend=get_embedding_backend("default") ) ) @@ -235,7 +224,7 @@ def test_bulk_generate_documents_returns_documents_for_multiple_models(self): different_pages = DifferentPageFactory.create_batch(3) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) documents = list( - operator.bulk_to_documents( + operator.to_documents( example_pages + different_pages, embedding_backend=get_embedding_backend("default"), ) @@ -249,18 +238,18 @@ def test_bulk_generate_documents_returns_documents_for_multiple_models(self): ) @pytest.mark.django_db - def test_bulk_to_documents_batches_objects(self, mocker): + def test_to_documents_batches_objects(self, mocker): instances = ExamplePageFactory.create_batch(10) operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) - bulk_generate_mock = mocker.patch.object(operator, "bulk_generate_documents") + to_documents_batch_mock = mocker.patch.object(operator, "_to_documents_batch") list( - operator.bulk_to_documents( + operator.to_documents( instances, embedding_backend=get_embedding_backend("default"), batch_size=2, ) ) - assert bulk_generate_mock.call_count == 5 + assert to_documents_batch_mock.call_count == 5 class TestConverter: @@ -273,7 +262,7 @@ def test_returns_original_object(self, patch_embedding_fields): converter = EmbeddableFieldsDocumentConverter() document = next( converter.to_documents( - instance, embedding_backend=get_embedding_backend("default") + [instance], embedding_backend=get_embedding_backend("default") ) ) recovered_instance = converter.from_document(document) @@ -289,7 +278,7 @@ def test_convert_single_document_to_object(): ) documents = list( converter.to_documents( - instance, embedding_backend=get_embedding_backend("default") + [instance], embedding_backend=get_embedding_backend("default") ) ) recovered_instance = converter.from_document(documents[0]) @@ -305,9 +294,173 @@ def test_convert_multiple_documents_to_objects(): different_pages = DifferentPageFactory.create_batch(5) all_objects = list(example_objects + example_pages + different_pages) documents = list( - converter.bulk_to_documents( + converter.to_documents( all_objects, embedding_backend=get_embedding_backend("default") ) ) recovered_objects = list(converter.bulk_from_documents(documents)) assert recovered_objects == all_objects + + +class TestToDocumentOperatorAsync: + @pytest.mark.django_db(transaction=True) + async def test_ato_documents_batch(self, mock_embedding_backend): + instance = await AsyncExampleModelFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + + documents = await operator._ato_documents_batch( + [instance], embedding_backend=mock_embedding_backend + ) + + assert len(documents) > 0 + assert all(isinstance(doc, Document) for doc in documents) + assert all(instance.title in doc.content for doc in documents) + + @pytest.mark.django_db(transaction=True) + async def test_aupdate_object_collection_with_new_documents( + self, mock_embedding_backend + ): + instance = await AsyncExampleModelFactory.create() + collection = await sync_to_async(PreparedObjectCollection.prepare_objects)( + objects=[instance], + chunker_operator=EmbeddableFieldsObjectChunkerOperator(), + embedding_backend=mock_embedding_backend, + ) + + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + await operator._aupdate_object_collection_with_new_documents( + collection, mock_embedding_backend + ) + + assert any(obj.new_documents for obj in collection) + assert all( + isinstance(doc, Document) for obj in collection for doc in obj.new_documents + ) + + +class TestPreparedObject: + @pytest.mark.django_db + def test_needs_updating_when_no_existing_documents(self): + instance = ExamplePageFactory.build() + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + ) + assert prepared_object.needs_updating is True + + @pytest.mark.django_db + def test_needs_updating_when_chunks_match(self): + instance = ExamplePageFactory.build() + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + existing_documents=[ + DocumentFactory.build(content="chunk1"), + DocumentFactory.build(content="chunk2"), + ], + ) + assert prepared_object.needs_updating is False + + @pytest.mark.django_db + def test_needs_updating_when_chunks_differ(self): + instance = ExamplePageFactory.build() + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + existing_documents=[ + DocumentFactory.build(content="chunk1"), + DocumentFactory.build(content="different chunk"), + ], + ) + assert prepared_object.needs_updating is True + + @pytest.mark.django_db + def test_documents_returns_new_documents_when_present(self): + instance = ExamplePageFactory.build() + new_docs = [DocumentFactory.build(), DocumentFactory.build()] + existing_docs = [DocumentFactory.build(), DocumentFactory.build()] + + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1"], + new_documents=new_docs, + existing_documents=existing_docs, + ) + assert prepared_object.documents == new_docs + + @pytest.mark.django_db + def test_documents_returns_existing_documents_when_no_new_ones(self): + instance = ExamplePageFactory.build() + existing_docs = [DocumentFactory.build(), DocumentFactory.build()] + + prepared_object = PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1"], + existing_documents=existing_docs, + ) + assert prepared_object.documents == existing_docs + + +class TestPreparedObjectCollection: + @pytest.mark.django_db + def test_prepare_objects(self, patch_embedding_fields): + with patch_embedding_fields(ExamplePage, [EmbeddingField("body")]): + instances = ExamplePageFactory.create_batch(3) + chunker = EmbeddableFieldsObjectChunkerOperator() + + collection = PreparedObjectCollection.prepare_objects( + objects=instances, + chunker_operator=chunker, + embedding_backend=get_embedding_backend("default"), + ) + + assert len(collection.objects) == 3 + assert all(isinstance(obj, PreparedObject) for obj in collection) + assert all(obj.chunks for obj in collection) + + @pytest.mark.django_db + def test_get_chunk_mapping(self): + instances = ExamplePageFactory.build_batch(2) + prepared_objects = [ + PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + ) + for instance in instances + ] + collection = PreparedObjectCollection(objects=prepared_objects) + + chunk_mapping = collection.get_chunk_mapping() + assert len(chunk_mapping) == 4 # 2 instances * 2 chunks each + assert all(isinstance(key, ModelKey) for key in chunk_mapping) + + @pytest.mark.django_db + def test_prepare_new_documents(self): + instances = ExamplePageFactory.create_batch(2) + prepared_objects = [ + PreparedObject( + key=ModelKey.from_instance(instance), + object=instance, + chunks=["chunk1", "chunk2"], + ) + for instance in instances + ] + collection = PreparedObjectCollection(objects=prepared_objects) + + # Mock embedding vectors + embedding_vectors = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]] + collection.prepare_new_documents(embedding_vectors) + + assert all(len(obj.new_documents) == 2 for obj in collection) + assert all( + all(isinstance(doc, Document) for doc in obj.new_documents) + for obj in collection + ) diff --git a/tests/test_index.py b/tests/test_index.py index a9cc296..d367f19 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -1,13 +1,14 @@ import unittest import pytest -from factories import DifferentPageFactory, ExamplePageFactory -from faker import Faker -from testapp.models import DifferentPage, ExamplePage -from wagtail_vector_index.ai_utils.backends.base import BaseEmbeddingBackend -from wagtail_vector_index.storage import ( - registry, +from factories import ( + DifferentPageFactory, + ExampleModelFactory, + ExamplePageFactory, ) +from faker import Faker +from testapp.models import DifferentPage, ExampleModel, ExamplePage +from wagtail_vector_index.storage import registry from wagtail_vector_index.storage.base import VectorIndex from wagtail_vector_index.storage.django import EmbeddingField, ModelKey from wagtail_vector_index.storage.models import Document @@ -15,50 +16,25 @@ fake = Faker() -def get_vector_for_text(text): - if "Very similar" in text: - return [0.9, 0.1, 0.0] - elif "Somewhat similar" in text: - return [0.7, 0.3, 0.0] - elif "test" in text.lower(): - return [1.0, 0.0, 0.0] - else: - return [0.1, 0.1, 0.8] - - -@pytest.fixture -def mock_embedding_backend(): - class MockEmbeddingBackend(BaseEmbeddingBackend): - def embed(self, texts): - def embedding_generator(): - for text in texts: - yield get_vector_for_text(text) - - return embedding_generator() - - return MockEmbeddingBackend - - @pytest.fixture -def test_pages(): +def test_objects(): return [ - ExamplePageFactory(title="Very similar to test"), - ExamplePageFactory(title="Somewhat similar to test"), - ExamplePageFactory(title="Not similar at all"), + ExampleModelFactory.create(title="Very similar to test"), + ExampleModelFactory.create(title="Somewhat similar to test"), + ExampleModelFactory.create(title="Not similar at all"), ] @pytest.fixture -def document_generator(test_pages): +def document_generator(test_objects, get_vector_for_text): def gen_documents(cls, *args, **kwargs): - for page in test_pages: - vector = get_vector_for_text(page.title) + for obj in test_objects: + vector = get_vector_for_text(obj.title) yield Document( - object_keys=[ModelKey.from_instance(page)], + object_keys=[ModelKey.from_instance(obj)], metadata={ - "title": page.title, - "object_id": str(page.pk), - "content_type_id": str(page.get_content_type().id), + "title": obj.title, + "object_id": str(obj.pk), }, vector=vector, ) @@ -67,218 +43,254 @@ def gen_documents(cls, *args, **kwargs): @pytest.fixture -def mock_vector_index(mocker, mock_embedding_backend, document_generator): +def async_document_generator(test_objects, get_vector_for_text): + async def gen_documents(cls, *args, **kwargs): + for obj in test_objects: + vector = get_vector_for_text(obj.title) + yield Document( + object_keys=[ModelKey.from_instance(obj)], + metadata={"title": obj.title, "object_id": str(obj.pk)}, + vector=vector, + ) + + return gen_documents + + +@pytest.fixture +def mock_vector_index( + mocker, mock_embedding_backend, document_generator, async_document_generator +): vector_index = ExamplePage.vector_index - mock_backend = mock_embedding_backend(config=mocker.Mock()) mocker.patch.object( - vector_index, "get_embedding_backend", return_value=mock_backend + vector_index, "get_embedding_backend", return_value=mock_embedding_backend ) mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_to_documents", + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.to_documents", side_effect=document_generator, ) - return vector_index - - -def test_registry(): - expected_class_names = [ - "ExamplePageIndex", - "ExampleModelIndex", - "DifferentPageIndex", - "MultiplePageVectorIndex", - ] - assert set(registry._registry.keys()) == set(expected_class_names) - - -def test_indexed_model_has_vector_index(): - index = ExamplePage.vector_index - assert index.__class__.__name__ == "ExamplePageIndex" - - -def test_register_custom_vector_index(): - custom_index = type("MyVectorIndex", (VectorIndex,), {})() - registry.register_index(custom_index) - assert registry["MyVectorIndex"] == custom_index - - -def test_get_embedding_fields_count(patch_embedding_fields): - with patch_embedding_fields( - ExamplePage, [EmbeddingField("test"), EmbeddingField("another_test")] - ): - assert len(ExamplePage._get_embedding_fields()) == 2 - - -def test_embedding_fields_override(patch_embedding_fields): - # In the same vein as Wagtail's search index fields, if there are - # multiple fields of the same type with the same name, only one - # should be returned - with patch_embedding_fields( - ExamplePage, [EmbeddingField("test"), EmbeddingField("test")] - ): - assert len(ExamplePage._get_embedding_fields()) == 1 - - -def test_checking_search_fields_errors_with_invalid_field(patch_embedding_fields): - with patch_embedding_fields(ExamplePage, [EmbeddingField("foo")]): - errors = ExamplePage.check() - assert "wagtailai.WA001" in [error.id for error in errors] - - -@pytest.mark.django_db -def test_index_get_documents_returns_at_least_one_document_per_page(): - pages = ExamplePageFactory.create_batch(10) - index = registry["ExamplePageIndex"] - index.rebuild_index() - documents = index.get_documents() - found_pages = { - ModelKey(document.object_keys[0]).object_id for document in documents - } - - assert found_pages == {str(page.pk) for page in pages} - - -@pytest.mark.django_db -def test_index_with_multiple_models(): - example_pages = ExamplePageFactory.create_batch(5) - different_pages = DifferentPageFactory.create_batch(5) - index = registry["MultiplePageVectorIndex"] - index.rebuild_index() - - example_pages_ids = {str(page.pk) for page in example_pages} - different_page_ids = {str(page.pk) for page in different_pages} - found_page_ids = { - ModelKey(document.object_keys[0]).object_id - for document in index.get_documents() - } - - assert found_page_ids == example_pages_ids.union(different_page_ids) - - similar_result = list(index.find_similar(DifferentPage.objects.first())) - assert len(similar_result) > 1 - for p in similar_result: - assert isinstance(p, (ExamplePage, DifferentPage)) - - search_result = list(index.search("test")) - assert len(search_result) > 1 - for p in search_result: - assert isinstance(p, (ExamplePage, DifferentPage)) - - -@pytest.mark.django_db -def test_similar_returns_no_duplicates(mocker): - pages = ExamplePageFactory.create_batch(10) - vector_index = ExamplePage.vector_index - - def gen_pages(cls, *args, **kwargs): - yield from pages - mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", - side_effect=gen_pages, + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.ato_documents", + side_effect=async_document_generator, ) - case = unittest.TestCase() - - # We expect 9 results without the page itself. - actual = vector_index.find_similar(pages[0], limit=100, include_self=False) - case.assertCountEqual(actual, pages[1:]) - - # We expect 10 results with the page itself. - actual = vector_index.find_similar(pages[0], limit=100, include_self=True) - case.assertCountEqual(actual, pages) - - -@pytest.mark.django_db -def test_query_passes_sources_to_backend(mocker): - ExamplePageFactory.create_batch(2) - index = ExamplePage.vector_index - documents = index.get_documents()[:2] - - def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.0): - yield from documents - - query_mock = mocker.patch("conftest.ChatMockBackend.chat") - expected_content = "\n".join([doc.content for doc in documents]) - similar_documents_mock = mocker.patch.object(index, "get_similar_documents") - similar_documents_mock.side_effect = get_similar_documents - index.query("") - first_call_messages = query_mock.call_args.kwargs["messages"] - assert first_call_messages[1] == {"content": expected_content, "role": "system"} - - -@pytest.mark.django_db -def test_query_with_similarity_threshold(mocker): - ExamplePageFactory.create_batch(2) - index = ExamplePage.vector_index - documents = index.get_documents()[:2] - - def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5): - yield from documents - - query_mock = mocker.patch("conftest.ChatMockBackend.chat") - expected_content = "\n".join([doc.content for doc in documents]) - similar_documents_mock = mocker.patch.object(index, "get_similar_documents") - similar_documents_mock.side_effect = get_similar_documents - index.query("", similarity_threshold=0.5) - first_call_messages = query_mock.call_args.kwargs["messages"] - assert first_call_messages[1] == {"content": expected_content, "role": "system"} - - -@pytest.mark.django_db -def test_find_similar_with_similarity_threshold(mocker): - pages = ExamplePageFactory.create_batch(10) - vector_index = ExamplePage.vector_index + return vector_index - def gen_pages(cls, *args, **kwargs): - yield from pages - mocker.patch( - "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", - side_effect=gen_pages, - ) - - # We expect 9 results without the page itself. - actual = vector_index.find_similar( - pages[0], limit=100, include_self=False, similarity_threshold=0.5 +class TestRegistry: + def test_registry(self): + expected_class_names = [ + "ExamplePageIndex", + "ExampleModelIndex", + "DifferentPageIndex", + "MultiplePageVectorIndex", + ] + assert set(registry._registry.keys()) == set(expected_class_names) + + def test_indexed_model_has_vector_index(self): + index = ExamplePage.vector_index + assert index.__class__.__name__ == "ExamplePageIndex" + + def test_register_custom_vector_index(self): + custom_index = type("MyVectorIndex", (VectorIndex,), {})() + registry.register_index(custom_index) + assert registry["MyVectorIndex"] == custom_index + + +class TestEmbeddingFields: + def test_get_embedding_fields_count(self, patch_embedding_fields): + with patch_embedding_fields( + ExamplePage, [EmbeddingField("test"), EmbeddingField("another_test")] + ): + assert len(ExamplePage._get_embedding_fields()) == 2 + + def test_embedding_fields_override(self, patch_embedding_fields): + with patch_embedding_fields( + ExamplePage, [EmbeddingField("test"), EmbeddingField("test")] + ): + assert len(ExamplePage._get_embedding_fields()) == 1 + + def test_checking_search_fields_errors_with_invalid_field( + self, patch_embedding_fields + ): + with patch_embedding_fields(ExamplePage, [EmbeddingField("foo")]): + errors = ExamplePage.check() + assert "wagtailai.WA001" in [error.id for error in errors] + + +class TestIndexOperations: + @pytest.mark.django_db + def test_index_get_documents_returns_at_least_one_document_per_page(self): + pages = ExampleModelFactory.create_batch(10) + index = registry["ExampleModelIndex"] + index.rebuild_index() + documents = index.get_documents() + found_pages = { + ModelKey(document.object_keys[0]).object_id for document in documents + } + + assert found_pages == {str(page.pk) for page in pages} + + @pytest.mark.django_db + def test_index_with_multiple_models(self): + example_pages = ExamplePageFactory.create_batch(5) + different_pages = DifferentPageFactory.create_batch(5) + index = registry["MultiplePageVectorIndex"] + index.rebuild_index() + + example_pages_ids = {str(page.pk) for page in example_pages} + different_page_ids = {str(page.pk) for page in different_pages} + found_page_ids = { + ModelKey(document.object_keys[0]).object_id + for document in index.get_documents() + } + + assert found_page_ids == example_pages_ids.union(different_page_ids) + + similar_result = list(index.find_similar(DifferentPage.objects.first())) + assert len(similar_result) > 1 + for p in similar_result: + assert isinstance(p, (ExamplePage, DifferentPage)) + + search_result = list(index.search("test")) + assert len(search_result) > 1 + for p in search_result: + assert isinstance(p, (ExamplePage, DifferentPage)) + + +class TestSimilarityOperations: + @pytest.mark.django_db + def test_similar_returns_no_duplicates(self, mocker): + pages = ExampleModelFactory.create_batch(10) + vector_index = ExamplePage.vector_index + + def gen_pages(cls, *args, **kwargs): + yield from pages + + mocker.patch( + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", + side_effect=gen_pages, + ) + + case = unittest.TestCase() + + # We expect 9 results without the page itself. + actual = vector_index.find_similar(pages[0], limit=100, include_self=False) + case.assertCountEqual(actual, pages[1:]) + + # We expect 10 results with the page itself. + actual = vector_index.find_similar(pages[0], limit=100, include_self=True) + case.assertCountEqual(actual, pages) + + @pytest.mark.django_db + def test_find_similar_with_similarity_threshold(self, mocker): + pages = ExampleModelFactory.create_batch(10) + vector_index = ExamplePage.vector_index + + def gen_pages(cls, *args, **kwargs): + yield from pages + + mocker.patch( + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", + side_effect=gen_pages, + ) + + # We expect 9 results without the page itself. + actual = vector_index.find_similar( + pages[0], limit=100, include_self=False, similarity_threshold=0.5 + ) + assert set(actual) == set(pages[1:]), f"Expected {pages[1:]}, but got {actual}" + + # We expect 10 results with the page itself. + actual = vector_index.find_similar( + pages[0], limit=100, include_self=True, similarity_threshold=0.5 + ) + assert set(actual) == set(pages), f"Expected {pages}, but got {actual}" + + @pytest.mark.django_db(transaction=True) + async def test_afind_similar(self, mock_vector_index): + objs = [obj async for obj in ExampleModel.objects.all()] + actual = await mock_vector_index.afind_similar( + objs[0], limit=100, include_self=False + ) + assert set(actual) == set(objs[1:]), f"Expected {objs[1:]}, but got {actual}" + + +class TestQueryOperations: + @pytest.mark.django_db + def test_query_passes_sources_to_backend(self, mocker): + ExampleModelFactory.create_batch(2) + index = ExamplePage.vector_index + documents = index.get_documents()[:2] + + def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.0): + yield from documents + + query_mock = mocker.patch("conftest.ChatMockBackend.chat") + expected_content = "\n".join([doc.content for doc in documents]) + similar_documents_mock = mocker.patch.object(index, "get_similar_documents") + similar_documents_mock.side_effect = get_similar_documents + index.query("") + first_call_messages = query_mock.call_args.kwargs["messages"] + assert first_call_messages[1] == {"content": expected_content, "role": "system"} + + @pytest.mark.django_db + def test_query_with_similarity_threshold(self, mocker): + ExampleModelFactory.create_batch(2) + index = ExamplePage.vector_index + documents = index.get_documents()[:2] + + def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5): + yield from documents + + query_mock = mocker.patch("conftest.ChatMockBackend.chat") + expected_content = "\n".join([doc.content for doc in documents]) + similar_documents_mock = mocker.patch.object(index, "get_similar_documents") + similar_documents_mock.side_effect = get_similar_documents + index.query("", similarity_threshold=0.5) + first_call_messages = query_mock.call_args.kwargs["messages"] + assert first_call_messages[1] == {"content": expected_content, "role": "system"} + + +class TestSearchOperations: + @pytest.mark.django_db + @pytest.mark.parametrize( + "similarity_threshold, expected_count, expected_titles", + [ + (0.9, 0, set()), + (0.6, 1, {"Very similar to test"}), + (0.1, 2, {"Very similar to test", "Somewhat similar to test"}), + ( + None, + 3, + { + "Very similar to test", + "Somewhat similar to test", + "Not similar at all", + }, + ), + ], ) - assert set(actual) == set(pages[1:]), f"Expected {pages[1:]}, but got {actual}" + def test_search_with_similarity_threshold( + self, mock_vector_index, similarity_threshold, expected_count, expected_titles + ): + kwargs = {"limit": 100} + if similarity_threshold is not None: + kwargs["similarity_threshold"] = similarity_threshold - # We expect 10 results with the page itself. - actual = vector_index.find_similar( - pages[0], limit=100, include_self=True, similarity_threshold=0.5 - ) - assert set(actual) == set(pages), f"Expected {pages}, but got {actual}" - - -@pytest.mark.django_db -@pytest.mark.parametrize( - "similarity_threshold, expected_count, expected_titles", - [ - (0.9, 0, set()), - (0.6, 1, {"Very similar to test"}), - (0.1, 2, {"Very similar to test", "Somewhat similar to test"}), - ( - None, - 3, - {"Very similar to test", "Somewhat similar to test", "Not similar at all"}, - ), - ], -) -def test_search_with_similarity_threshold( - mock_vector_index, similarity_threshold, expected_count, expected_titles -): - kwargs = {"limit": 100} - if similarity_threshold is not None: - kwargs["similarity_threshold"] = similarity_threshold + results = list(mock_vector_index.search("test", **kwargs)) - results = list(mock_vector_index.search("test", **kwargs)) + assert ( + len(results) == expected_count + ), f"Expected {expected_count} results, got {len(results)}" - assert ( - len(results) == expected_count - ), f"Expected {expected_count} results, got {len(results)}" + if expected_count > 0: + assert {result.title for result in results} == expected_titles - if expected_count > 0: - assert {result.title for result in results} == expected_titles + @pytest.mark.django_db(transaction=True) + async def test_asearch(self, mock_vector_index): + objs = [obj async for obj in ExampleModel.objects.all()] + actual = await mock_vector_index.asearch("test", limit=100) + assert set(actual) == set(objs), f"Expected {objs}, but got {actual}" diff --git a/tests/test_model_index.py b/tests/test_model_index.py index 1755211..b50617f 100644 --- a/tests/test_model_index.py +++ b/tests/test_model_index.py @@ -42,8 +42,8 @@ def get_index(self): return registry["ExamplePageIndex"] -def test_rebuilding_model_index_creates_embeddings(): - ExamplePageFactory.create_batch(10) +def test_rebuilding_model_index_creates_documents(): + ExamplePageFactory.create_batch(10, body=fake.text(max_nb_chars=10)) index = ExamplePage.vector_index index.rebuild_index() assert Document.objects.count() == 10 diff --git a/tests/test_model_to_document_operator.py b/tests/test_model_to_document_operator.py new file mode 100644 index 0000000..e69de29