Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/indexes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ A barebones implementation of `VectorIndex` needs to implement one method; `get_

There are two ways to use Vector Indexes. Either:

- Adding the `VectorIndexedMixin` to a Django model, which will automatically generate an Index for that model
- Adding the `PageVectorIndexedMixin` to a Wagtail Page model, or the `VectorIndexedMixin` to a plain Django model, which will automatically generate an Index for that model.
- Creating your own subclass of one of the `VectorIndex` classes.

## Automatically Generating Indexes using `VectorIndexedMixin`
## Automatically Generating Indexes using `PageVectorIndexedMixin` or `VectorIndexedMixin`

To generate a Vector Index based on an existing model in your application:

1. Add Wagtail AI's `VectorIndexedMixin` mixin to your model
2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings
1. Add Wagtail AI's `PageVectorIndexedMixin` or `VectorIndexedMixin` mixin to your model.
2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings.

```python
from django.db import models
from wagtail.models import Page
from wagtail_vector_index.index import VectorIndexedMixin, EmbeddingField
from wagtail_vector_index.index import PageVectorIndexedMixin, EmbeddingField


class MyPage(VectorIndexedMixin, Page):
class MyPage(PageVectorIndexedMixin, Page):
body = models.TextField()

embedding_fields = [EmbeddingField("title"), EmbeddingField("body")]
Expand Down
8 changes: 4 additions & 4 deletions docs/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ This way, when you provide a query, we can use the same model to get an embeddin

To index your models:

1. Add Wagtail Vector Index's `VectorIndexedMixin` mixin to your model
2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings
1. Add Wagtail Vector Index's `PageVectorIndexedMixin` mixin to your page model.
2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings.

```python
from django.db import models
from wagtail.models import Page
from wagtail_vector_index.models import VectorIndexedMixin, EmbeddingField
from wagtail_vector_index.models import PageVectorIndexedMixin, EmbeddingField


class MyPage(VectorIndexedMixin, Page):
class MyPage(PageVectorIndexedMixin, Page):
body = models.TextField()

embedding_fields = [EmbeddingField("title"), EmbeddingField("body")]
Expand Down
36 changes: 19 additions & 17 deletions src/wagtail_vector_index/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from django.core import checks
from django.core.exceptions import FieldDoesNotExist
from django.db import models, transaction
from wagtail.models import Page
from wagtail.search.index import BaseField

from wagtail_vector_index.index.base import Document
Expand Down Expand Up @@ -98,7 +97,7 @@ class VectorIndexedMixin(models.Model):
embeddings = GenericRelation(
Embedding, content_type_field="content_type", for_concrete_model=False
)
vector_index_class = None
vector_index_class = ModelVectorIndex

class Meta:
abstract = True
Expand Down Expand Up @@ -240,21 +239,24 @@ def bulk_from_documents(cls, documents):
yield cls.from_document(document)

@classmethod
def get_vector_index(cls):
def get_vector_index(cls) -> ModelVectorIndex:
"""Get a vector index instance for this model"""

# If the user has specified a custom `vector_index_class`, use that
if cls.vector_index_class:
index_cls = cls.vector_index_class
# If the model is a Wagtail Page, use a special PageVectorIndex
elif issubclass(cls, Page):
index_cls = PageVectorIndex
# Otherwise use the standard ModelVectorIndex
else:
index_cls = ModelVectorIndex
name = f"{cls.__name__}Index"
bases = (cls.vector_index_class,)
dict_ = {"querysets": [cls.objects.all()]}
index_class = type(name, bases, dict_)
return index_class(object_type=cls)


class PageVectorIndexedMixin(VectorIndexedMixin):
vector_index_class = PageVectorIndex

class Meta:
abstract = True

@classmethod
def get_vector_index(cls) -> PageVectorIndex:
"""Get a vector index instance for this model"""

return type(
f"{cls.__name__}Index",
(index_cls,),
{"querysets": [cls.objects.all()]},
)(object_type=cls)
return super().get_vector_index() # type: ignore
10 changes: 7 additions & 3 deletions tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
PageVectorIndex,
)
from wagtail_vector_index.index.registry import registry
from wagtail_vector_index.models import EmbeddingField, VectorIndexedMixin
from wagtail_vector_index.models import (
EmbeddingField,
PageVectorIndexedMixin,
VectorIndexedMixin,
)


class ExampleModel(VectorIndexedMixin, models.Model):
Expand All @@ -19,15 +23,15 @@ def __str__(self):
return self.title


class ExamplePage(VectorIndexedMixin, Page):
class ExamplePage(PageVectorIndexedMixin, Page):
body = RichTextField()

content_panels = [*Page.content_panels, FieldPanel("body")]

embedding_fields = [EmbeddingField("title", important=True), EmbeddingField("body")]


class DifferentPage(VectorIndexedMixin, Page):
class DifferentPage(PageVectorIndexedMixin, Page):
body = RichTextField()

content_panels = [*Page.content_panels, FieldPanel("body")]
Expand Down