diff --git a/docs/indexes.md b/docs/indexes.md index a6ee3e0..f0328f0 100644 --- a/docs/indexes.md +++ b/docs/indexes.md @@ -6,23 +6,23 @@ A barebones implementation of `VectorIndex` needs to implement one method; `get_ There are two ways to use Vector Indexes. Either: -- Adding the `VectorIndexedMixin` to a Django model, which will automatically generate an Index for that model +- Adding the `PageVectorIndexedMixin` to a Wagtail Page model, or the `VectorIndexedMixin` to a plain Django model, which will automatically generate an Index for that model. - Creating your own subclass of one of the `VectorIndex` classes. -## Automatically Generating Indexes using `VectorIndexedMixin` +## Automatically Generating Indexes using `PageVectorIndexedMixin` or `VectorIndexedMixin` To generate a Vector Index based on an existing model in your application: -1. Add Wagtail AI's `VectorIndexedMixin` mixin to your model -2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings +1. Add Wagtail Vector Index's `PageVectorIndexedMixin` or `VectorIndexedMixin` mixin to your model. +2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings. ```python from django.db import models from wagtail.models import Page -from wagtail_vector_index.index import VectorIndexedMixin, EmbeddingField +from wagtail_vector_index.models import PageVectorIndexedMixin, EmbeddingField -class MyPage(VectorIndexedMixin, Page): +class MyPage(PageVectorIndexedMixin, Page): body = models.TextField() embedding_fields = [EmbeddingField("title"), EmbeddingField("body")] diff --git a/docs/quick-start.md b/docs/quick-start.md index 2b0e0db..cb35185 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -19,16 +19,16 @@ This way, when you provide a query, we can use the same model to get an embeddin To index your models: -1. Add Wagtail Vector Index's `VectorIndexedMixin` mixin to your model -2. 
Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings +1. Add Wagtail Vector Index's `PageVectorIndexedMixin` mixin to your page model. +2. Set `embedding_fields` to a list of `EmbeddingField`s representing the fields you want to be included in the embeddings. ```python from django.db import models from wagtail.models import Page -from wagtail_vector_index.models import VectorIndexedMixin, EmbeddingField +from wagtail_vector_index.models import PageVectorIndexedMixin, EmbeddingField -class MyPage(VectorIndexedMixin, Page): +class MyPage(PageVectorIndexedMixin, Page): body = models.TextField() embedding_fields = [EmbeddingField("title"), EmbeddingField("body")] diff --git a/src/wagtail_vector_index/models.py b/src/wagtail_vector_index/models.py index 6c34ede..0e63cde 100644 --- a/src/wagtail_vector_index/models.py +++ b/src/wagtail_vector_index/models.py @@ -6,7 +6,6 @@ from django.core import checks from django.core.exceptions import FieldDoesNotExist from django.db import models, transaction -from wagtail.models import Page from wagtail.search.index import BaseField from wagtail_vector_index.index.base import Document @@ -98,7 +97,7 @@ class VectorIndexedMixin(models.Model): embeddings = GenericRelation( Embedding, content_type_field="content_type", for_concrete_model=False ) - vector_index_class = None + vector_index_class = ModelVectorIndex class Meta: abstract = True @@ -240,21 +239,24 @@ def bulk_from_documents(cls, documents): yield cls.from_document(document) @classmethod - def get_vector_index(cls): + def get_vector_index(cls) -> ModelVectorIndex: """Get a vector index instance for this model""" - # If the user has specified a custom `vector_index_class`, use that - if cls.vector_index_class: - index_cls = cls.vector_index_class - # If the model is a Wagtail Page, use a special PageVectorIndex - elif issubclass(cls, Page): - index_cls = PageVectorIndex - # Otherwise use the standard 
ModelVectorIndex - else: - index_cls = ModelVectorIndex + name = f"{cls.__name__}Index" + bases = (cls.vector_index_class,) + dict_ = {"querysets": [cls.objects.all()]} + index_class = type(name, bases, dict_) + return index_class(object_type=cls) + + +class PageVectorIndexedMixin(VectorIndexedMixin): + vector_index_class = PageVectorIndex + + class Meta: + abstract = True + + @classmethod + def get_vector_index(cls) -> PageVectorIndex: + """Get a vector index instance for this model""" - return type( - f"{cls.__name__}Index", - (index_cls,), - {"querysets": [cls.objects.all()]}, - )(object_type=cls) + return super().get_vector_index() # type: ignore diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 05b5bdd..be70c60 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -6,7 +6,11 @@ PageVectorIndex, ) from wagtail_vector_index.index.registry import registry -from wagtail_vector_index.models import EmbeddingField, VectorIndexedMixin +from wagtail_vector_index.models import ( + EmbeddingField, + PageVectorIndexedMixin, + VectorIndexedMixin, +) class ExampleModel(VectorIndexedMixin, models.Model): @@ -19,7 +23,7 @@ def __str__(self): return self.title -class ExamplePage(VectorIndexedMixin, Page): +class ExamplePage(PageVectorIndexedMixin, Page): body = RichTextField() content_panels = [*Page.content_panels, FieldPanel("body")] @@ -27,7 +31,7 @@ class ExamplePage(VectorIndexedMixin, Page): embedding_fields = [EmbeddingField("title", important=True), EmbeddingField("body")] -class DifferentPage(VectorIndexedMixin, Page): +class DifferentPage(PageVectorIndexedMixin, Page): body = RichTextField() content_panels = [*Page.content_panels, FieldPanel("body")]