Skip to content
Open
101 changes: 88 additions & 13 deletions src/wagtail_vector_index/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def upsert(self) -> None: ...

def get_documents(self) -> Iterable["Document"]: ...

async def aget_documents(self) -> AsyncGenerator["Document", None]: ...

def _get_storage_provider(self) -> StorageProviderClass: ...


Expand Down Expand Up @@ -76,6 +78,8 @@ class FromDocumentOperator(Protocol[ToObjectType]):

def from_document(self, document: "Document") -> ToObjectType: ...

async def afrom_document(self, document: "Document") -> ToObjectType: ...

def bulk_from_documents(
self, documents: Iterable["Document"]
) -> Generator[ToObjectType, None, None]: ...
Expand All @@ -99,19 +103,30 @@ class ToDocumentOperator(Protocol[FromObjectType]):
def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): ...

def to_documents(
self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend
self,
objects: Iterable[FromObjectType],
*,
embedding_backend: BaseEmbeddingBackend,
batch_size: int = 100,
) -> Generator["Document", None, None]: ...

def bulk_to_documents(
async def ato_documents(
self,
objects: Iterable[FromObjectType],
*,
embedding_backend: BaseEmbeddingBackend,
) -> Generator["Document", None, None]: ...
batch_size: int = 100,
) -> AsyncGenerator["Document", None]: ...


class DocumentConverter(ABC):
"""Base class for a DocumentConverter that can convert objects to Documents and vice versa"""
"""Base class for a DocumentConverter that can convert objects to Documents and vice versa

Note on async methods:
Some async (a-prefixed) methods in this class return AsyncGenerators directly from
the methods of the to_document_operator or from_document_operator, so they aren't marked
with the 'async' keyword to prevent the methods from being wrapped in a Coroutine.
"""

to_document_operator_class: Type[ToDocumentOperator]
from_document_operator_class: Type[FromDocumentOperator]
Expand All @@ -126,21 +141,32 @@ def from_document_operator(self) -> FromDocumentOperator:
return self.from_document_operator_class()

def to_documents(
self, object: object, *, embedding_backend: BaseEmbeddingBackend
self,
objects: Iterable[object],
*,
embedding_backend: BaseEmbeddingBackend,
batch_size: int = 100,
) -> Generator["Document", None, None]:
return self.to_document_operator.to_documents(
object, embedding_backend=embedding_backend
objects, embedding_backend=embedding_backend, batch_size=batch_size
)

def ato_documents(
self,
objects: Iterable[object],
*,
embedding_backend: BaseEmbeddingBackend,
batch_size: int = 100,
) -> AsyncGenerator["Document", None]:
return self.to_document_operator.ato_documents(
objects, embedding_backend=embedding_backend, batch_size=batch_size
)

def from_document(self, document: "Document") -> object:
return self.from_document_operator.from_document(document)

def bulk_to_documents(
self, objects: Iterable[object], *, embedding_backend: BaseEmbeddingBackend
) -> Generator["Document", None, None]:
return self.to_document_operator.bulk_to_documents(
objects, embedding_backend=embedding_backend
)
def afrom_document(self, document: "Document") -> object:
return self.from_document_operator.afrom_document(document)

def bulk_from_documents(
self, documents: Sequence["Document"]
Expand Down Expand Up @@ -186,6 +212,9 @@ def get_embedding_backend(self) -> BaseEmbeddingBackend:
def get_documents(self) -> Iterable["Document"]:
raise NotImplementedError

async def aget_documents(self) -> AsyncGenerator["Document", None]:
raise NotImplementedError

def get_converter(self) -> DocumentConverter:
raise NotImplementedError

Expand Down Expand Up @@ -286,7 +315,7 @@ def find_similar(
"""Find similar objects to the given object"""
converter = self.get_converter()
object_documents: Generator[Document, None, None] = converter.to_documents(
object, embedding_backend=self.get_embedding_backend()
[object], embedding_backend=self.get_embedding_backend()
)
similar_documents = []
for document in object_documents:
Expand All @@ -300,6 +329,31 @@ def find_similar(
if include_self or obj != object
]

async def afind_similar(
self,
object,
*,
include_self: bool = False,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> list:
"""Find similar objects to the given object asynchronously"""
converter = self.get_converter()
similar_documents = []
async for document in converter.ato_documents(
[object], embedding_backend=self.get_embedding_backend()
):
similar_docs = self.aget_similar_documents(
document.vector, limit=limit, similarity_threshold=similarity_threshold
)
similar_documents.extend([doc async for doc in similar_docs])

return [
obj
async for obj in converter.abulk_from_documents(similar_documents)
if include_self or obj != object
]

def search(
self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0
) -> list:
Expand All @@ -315,6 +369,27 @@ def search(
)
return list(self.get_converter().bulk_from_documents(similar_documents))

async def asearch(
self, query: str, *, limit: int = 5, similarity_threshold: float = 0.0
) -> list:
"""Perform a search against the index, returning only a list of matching sources"""
try:
query_embedding = next(await self.get_embedding_backend().aembed([query]))
except StopIteration as e:
raise ValueError("No embeddings were generated for the given query.") from e
similar_documents = [
doc
async for doc in self.aget_similar_documents(
query_embedding, limit=limit, similarity_threshold=similarity_threshold
)
]
return [
obj
async for obj in self.get_converter().abulk_from_documents(
similar_documents
)
]

# Utilities

def _get_storage_provider(self):
Expand Down
Loading