diff --git a/README.md b/README.md
index ec04f505..354fc997 100644
--- a/README.md
+++ b/README.md
@@ -118,6 +118,8 @@ A separate MCP tool, `find_similar_messages`, returns nearest neighbors for a se
> **Run only one embedding process at a time.** Don't run `msgvault embeddings build`/`resume` or `repair-encoding` concurrently with a `msgvault serve` daemon — they write the same embedding state, and concurrent writers are not coordinated across processes.
+Large archives can limit an embedding generation to specific message types by setting `message_types = ["sms", "mms"]` under `[vector.embed.scope]`. When a scope is configured, vector and hybrid searches must include a matching `message_type` filter, so that a partial index is never used as if it covered the whole archive.
+
## Importing from MBOX or Apple Mail
Import email from providers that offer MBOX exports or from a local Apple Mail data directory:
diff --git a/cmd/msgvault/cmd/add_synctech_sms_drive.go b/cmd/msgvault/cmd/add_synctech_sms_drive.go
index 44d60d04..e6f3d340 100644
--- a/cmd/msgvault/cmd/add_synctech_sms_drive.go
+++ b/cmd/msgvault/cmd/add_synctech_sms_drive.go
@@ -109,10 +109,15 @@ func runConfiguredSynctechSMSSource(ctx context.Context, src config.SynctechSMSS
return err
}
defer func() { _ = st.Close() }()
+
return runConfiguredSynctechSMSSourceWithStore(ctx, st, src)
}
func runConfiguredSynctechSMSSourceWithStore(ctx context.Context, st *store.Store, src config.SynctechSMSSource) error {
+ return runConfiguredSynctechSMSSourceWithStoreDriveClient(ctx, st, src, nil) // nil client: the "drive" backend then goes through runSynctechSMSDriveSource, which builds its own Drive client
+}
+
+func runConfiguredSynctechSMSSourceWithStoreDriveClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, driveClient synctechsms.DriveClient) error {
opts := synctechImportOptions(src)
if opts.OwnerPhone == "" {
return fmt.Errorf("synctech-sms source %q owner_phone is required", src.Name)
@@ -128,7 +133,11 @@ func runConfiguredSynctechSMSSourceWithStore(ctx context.Context, st *store.Stor
}
_, err = synctechsms.NewImporter(st, opts).ImportPath(src.Path)
case "drive":
- err = runSynctechSMSDriveSource(ctx, st, src, opts)
+ if driveClient != nil {
+ _, err = runSynctechSMSDriveSourceWithClient(ctx, st, src, opts, driveClient)
+ } else {
+ _, err = runSynctechSMSDriveSource(ctx, st, src, opts)
+ }
default:
return fmt.Errorf("unsupported synctech-sms backend %q", src.Backend)
}
@@ -167,28 +176,28 @@ func validateSynctechSMSDriveSource(src config.SynctechSMSSource) error {
return nil
}
-func runSynctechSMSDriveSource(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions) error {
+func runSynctechSMSDriveSource(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions) (synctechsms.ImportSummary, error) { // validates src, builds a Drive client from it, then delegates to runSynctechSMSDriveSourceWithClient
if err := validateSynctechSMSDriveSource(src); err != nil {
- return err
+ return synctechsms.ImportSummary{}, err // validation failed before any import ran: zero summary
}
client, err := newSynctechSMSDriveClient(ctx, src)
if err != nil {
- return err
+ return synctechsms.ImportSummary{}, err // no Drive client could be created: zero summary
}
return runSynctechSMSDriveSourceWithClient(ctx, st, src, opts, client)
}
-func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions, client synctechsms.DriveClient) (retErr error) {
+func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions, client synctechsms.DriveClient) (summary synctechsms.ImportSummary, retErr error) {
if err := validateSynctechSMSDriveSource(src); err != nil {
- return err
+ return summary, err
}
source, err := ensureConfiguredSynctechSMSSource(st, src, opts)
if err != nil {
- return err
+ return summary, err
}
syncID, err := st.StartSync(source.ID, synctechsms.AdapterName)
if err != nil {
- return fmt.Errorf("start sync: %w", err)
+ return summary, fmt.Errorf("start sync: %w", err)
}
completed := false
defer func() {
@@ -204,38 +213,38 @@ func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, s
}()
files, err := client.ListBackupFiles(ctx, src.FolderID)
if err != nil {
- return fmt.Errorf("list Drive backup files: %w", err)
+ return summary, fmt.Errorf("list Drive backup files: %w", err)
}
imported, err := st.ListImportedSourceItemChecksums(source.ID, "drive")
if err != nil {
- return fmt.Errorf("list imported Drive checksums: %w", err)
+ return summary, fmt.Errorf("list imported Drive checksums: %w", err)
}
stableAfter, err := time.ParseDuration(src.StableAfter)
if err != nil {
- return fmt.Errorf("parse stable_after: %w", err)
+ return summary, fmt.Errorf("parse stable_after: %w", err)
}
selected := synctechsms.SelectStableDriveFiles(files, time.Now(), stableAfter, imported)
stagingDir := filepath.Join(cfg.Data.DataDir, "imports", "synctech-sms", src.Name)
if err := os.MkdirAll(stagingDir, 0o700); err != nil {
- return fmt.Errorf("create staging directory: %w", err)
+ return summary, fmt.Errorf("create staging directory: %w", err)
}
imp := synctechsms.NewImporter(st, opts)
- var summary synctechsms.ImportSummary
for _, file := range selected {
fileSummary, err := importOneDriveBackup(ctx, st, imp, client, source.ID, file, stagingDir)
- if err != nil {
- return err
- }
summary.FilesSeen += fileSummary.FilesSeen
summary.FilesImported += fileSummary.FilesImported
summary.SMSImported += fileSummary.SMSImported
summary.MMSImported += fileSummary.MMSImported
summary.CallsImported += fileSummary.CallsImported
summary.AttachmentsImported += fileSummary.AttachmentsImported
+ summary.MessageIDs = append(summary.MessageIDs, fileSummary.MessageIDs...)
+ if err != nil {
+ return summary, err
+ }
}
if summary.FilesImported > 0 {
if err := st.RecomputeConversationStats(source.ID); err != nil {
- return fmt.Errorf("recompute conversation stats: %w", err)
+ return summary, fmt.Errorf("recompute conversation stats: %w", err)
}
}
totalRecords := int64(summary.SMSImported + summary.MMSImported + summary.CallsImported)
@@ -243,16 +252,16 @@ func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, s
MessagesProcessed: totalRecords,
MessagesAdded: totalRecords,
}); err != nil {
- return fmt.Errorf("update sync checkpoint: %w", err)
+ return summary, fmt.Errorf("update sync checkpoint: %w", err)
}
if err := st.TouchSourceLastSyncAt(source.ID); err != nil {
- return fmt.Errorf("touch source last sync: %w", err)
+ return summary, fmt.Errorf("touch source last sync: %w", err)
}
if err := st.CompleteSync(syncID, ""); err != nil {
- return fmt.Errorf("complete sync: %w", err)
+ return summary, fmt.Errorf("complete sync: %w", err)
}
completed = true
- return nil
+ return summary, nil
}
func importOneDriveBackup(ctx context.Context, st *store.Store, imp *synctechsms.Importer, client synctechsms.DriveClient, sourceID int64, file synctechsms.DriveFile, stagingDir string) (synctechsms.ImportSummary, error) {
diff --git a/cmd/msgvault/cmd/add_synctech_sms_drive_test.go b/cmd/msgvault/cmd/add_synctech_sms_drive_test.go
index f3920850..2f98ede7 100644
--- a/cmd/msgvault/cmd/add_synctech_sms_drive_test.go
+++ b/cmd/msgvault/cmd/add_synctech_sms_drive_test.go
@@ -75,8 +75,9 @@ func TestSynctechSMSDriveRunUsesSingleOuterSyncRun(t *testing.T) {
},
}
- err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client)
+ summary, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client)
require.NoError(err, "runSynctechSMSDriveSourceWithClient")
+ require.Len(summary.MessageIDs, 1, "summary message IDs")
source := getSynctechSource(t, f.Store, src.OwnerPhone)
assert.Equal(1, countSyncRuns(t, f.Store, source.ID), "sync run count")
@@ -86,7 +87,7 @@ func TestSynctechSMSDriveRunUsesSingleOuterSyncRun(t *testing.T) {
assert.Equal(int64(1), run.MessagesAdded, "messages added")
assert.True(getSynctechSource(t, f.Store, src.OwnerPhone).LastSyncAt.Valid, "last_sync_at should be touched")
- item := getSourceImportItem(t, f.Store, source.ID, "drive", "backup-1")
+ item := getDriveSourceImportItem(t, f.Store, source.ID, "backup-1")
assert.Equal("imported", item.Status, "source import status")
assert.Equal(1, item.RecordsImported, "records imported")
assert.False(item.ErrorMessage.Valid, "source import error")
@@ -113,7 +114,7 @@ func TestSynctechSMSDriveRunSetsUpIdentityAndPostSourceMigration(t *testing.T) {
src := synctechDriveTestSource()
client := fakeSynctechDriveClient{}
- err = runSynctechSMSDriveSourceWithClient(context.Background(), st, src, synctechImportOptions(src), client)
+ _, err = runSynctechSMSDriveSourceWithClient(context.Background(), st, src, synctechImportOptions(src), client)
require.NoError(err, "runSynctechSMSDriveSourceWithClient")
synctechSource := getSynctechSource(t, st, src.OwnerPhone)
@@ -154,7 +155,7 @@ func TestSynctechSMSDriveRunRecordsZeroSelectedPoll(t *testing.T) {
}},
}
- err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client)
+ _, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client)
require.NoError(err, "runSynctechSMSDriveSourceWithClient")
source := getSynctechSource(t, f.Store, src.OwnerPhone)
@@ -188,7 +189,7 @@ func TestSynctechSMSDriveRunMarksOuterSyncFailedOnDownloadError(t *testing.T) {
downloadErr: downloadErr,
}
- err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client)
+ _, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client)
require.ErrorIs(err, downloadErr, "runSynctechSMSDriveSourceWithClient")
source := getSynctechSource(t, f.Store, src.OwnerPhone)
@@ -198,17 +199,126 @@ func TestSynctechSMSDriveRunMarksOuterSyncFailedOnDownloadError(t *testing.T) {
require.True(run.ErrorMessage.Valid, "sync error_message")
assert.Contains(run.ErrorMessage.String, downloadErr.Error(), "sync error_message")
- item := getSourceImportItem(t, f.Store, source.ID, "drive", "backup-1")
+ item := getDriveSourceImportItem(t, f.Store, source.ID, "backup-1")
assert.Equal("failed", item.Status, "source import status")
require.True(item.ErrorMessage.Valid, "source import error")
assert.Contains(item.ErrorMessage.String, downloadErr.Error(), "source import error")
}
+func TestSynctechSMSDrivePartialFailureEnqueuesImportedMessages(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ home := t.TempDir()
+ cfg = config.NewDefaultConfig()
+ cfg.HomeDir = home
+ cfg.Data.DataDir = home
+ f := storetest.New(t)
+ src := synctechDriveTestSource()
+ client := fakeSynctechDriveClient{
+ files: []synctechsms.DriveFile{
+ {
+ ID: "backup-1",
+ Name: "sms-1.xml",
+ Checksum: "sum-1",
+ Size: 128,
+ ModifiedTime: time.Now().Add(-30 * time.Minute),
+ },
+ {
+ ID: "backup-2",
+ Name: "sms-2.xml",
+ Checksum: "sum-2",
+ Size: 128,
+ ModifiedTime: time.Now().Add(-30 * time.Minute),
+ },
+ },
+ downloads: map[string]string{
+ "backup-1": `
+
+`,
+ "backup-2": `
+
+
+
+`), 0o600), "write backup")
+
+ st, err := store.Open(cfg.DatabaseDSN())
+ require.NoError(err, "open store")
+ require.NoError(st.InitSchema(), "InitSchema")
+ require.NoError(st.Close(), "close store")
+
+ src := config.SynctechSMSSource{
+ Name: "pixel-local",
+ Backend: "local",
+ Path: importDir,
+ OwnerPhone: "+15550000001",
+ IncludeSMS: true,
+ }
+ require.NoError(runConfiguredSynctechSMSSource(ctx, src), "runConfiguredSynctechSMSSource")
+
+ st, err = store.Open(cfg.DatabaseDSN())
+ require.NoError(err, "reopen store")
+ t.Cleanup(func() { _ = st.Close() })
+ source := getSynctechSource(t, st, src.OwnerPhone)
+ var unstamped int
+ require.NoError(st.DB().QueryRowContext(ctx,
+ st.Rebind(`SELECT COUNT(*) FROM messages WHERE source_id = ? AND embed_gen IS NULL`),
+ source.ID,
+ ).Scan(&unstamped), "count unstamped messages")
+ assert.Equal(1, unstamped, "manual sync message remains discoverable by scan-and-fill")
+}
+
type fakeSynctechDriveClient struct {
- files []synctechsms.DriveFile
- downloads map[string]string
- listErr error
- downloadErr error
+ files []synctechsms.DriveFile
+ downloads map[string]string
+ listErr error
+ downloadErr error // when non-nil, DownloadToFile returns it regardless of file ID
+ downloadErrByID map[string]error // per-file-ID DownloadToFile error; consulted before downloadErr
}
func (f fakeSynctechDriveClient) ListBackupFiles(context.Context, string) ([]synctechsms.DriveFile, error) {
@@ -219,6 +329,9 @@ func (f fakeSynctechDriveClient) ListBackupFiles(context.Context, string) ([]syn
}
func (f fakeSynctechDriveClient) DownloadToFile(_ context.Context, fileID, path string) error {
+ if err := f.downloadErrByID[fileID]; err != nil {
+ return err
+ }
if f.downloadErr != nil {
return f.downloadErr
}
@@ -280,9 +393,9 @@ func getOnlySyncRun(t *testing.T, st *store.Store, sourceID int64) store.SyncRun
return run
}
-func getSourceImportItem(t *testing.T, st *store.Store, sourceID int64, provider, providerID string) *store.SourceImportItem {
+func getDriveSourceImportItem(t *testing.T, st *store.Store, sourceID int64, providerID string) *store.SourceImportItem { // test helper: fetches one source import item and fails the test on lookup error
t.Helper()
- item, err := st.GetSourceImportItem(sourceID, provider, providerID)
+ item, err := st.GetSourceImportItem(sourceID, "drive", providerID) // provider is pinned to "drive"; callers no longer pass it
requirepkg.NoError(t, err, "GetSourceImportItem")
return item
}
diff --git a/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go b/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go
index 3663648e..8a815438 100644
--- a/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go
+++ b/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go
@@ -5,6 +5,7 @@ package cmd
import (
"context"
"database/sql"
+ "path/filepath"
"testing"
_ "github.com/mattn/go-sqlite3"
@@ -48,3 +49,30 @@ func TestRunEmbeddingsRetire_ForceActive(t *testing.T) {
row := mustGetEmbeddingGeneration(t.Context(), t, db, 1)
assert.Equal(vector.GenerationRetired, row.State)
}
+
+func TestFillFullCoverageUsesEmbeddingScopeForEmbeddedCount(t *testing.T) { // fillFullCoverage must report live and embedded counts for the configured scope only
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ ctx := context.Background()
+ dataDir := t.TempDir()
+ dbPath := newEmbeddingMetadataTestDBFileAt(t, filepath.Join(dataDir, "vectors.db"))
+ seedMainDBWithScopedFullCoverageMessages(t, dataDir) // message 1 is an email, message 2 an sms; both stamped embed_gen = 2
+ withEmbeddingCommandConfigDataDir(t, dbPath, dataDir)
+ cfg.Vector.Embed.Scope.MessageTypes = []string{"sms"} // restrict the embedding scope to sms
+
+ backend, closeBackend, err := openEmbeddingsBackend(ctx)
+ require.NoError(err, "open embeddings backend")
+ t.Cleanup(closeBackend)
+ require.NoError(backend.Upsert(ctx, 2, []vector.Chunk{ // store vectors under generation 2 for both messages
+ {MessageID: 1, Vector: []float32{1, 0, 0, 0}}, // the email: out of scope
+ {MessageID: 2, Vector: []float32{0, 1, 0, 0}}, // the sms: in scope
+ }), "upsert in-scope and out-of-scope vectors")
+
+ row := embeddingGenerationRow{ID: 2}
+ require.NoError(fillFullCoverage(ctx, backend, &row)) // fills the coverage counters on row for generation 2
+
+ assert.Equal(int64(1), row.LiveCount, "only sms is in scope")
+ assert.Equal(int64(1), row.EmbeddedCount, "out-of-scope email vector is excluded")
+ assert.Equal(int64(0), row.BlankCount)
+ assert.Equal(int64(0), row.MissingCount)
+}
diff --git a/cmd/msgvault/cmd/embed_test.go b/cmd/msgvault/cmd/embed_test.go
index e9af2674..f2eb3512 100644
--- a/cmd/msgvault/cmd/embed_test.go
+++ b/cmd/msgvault/cmd/embed_test.go
@@ -139,6 +139,22 @@ func TestRunEmbeddingsActivateRefusesMissingWithoutForce(t *testing.T) {
assert.Contains(err.Error(), "msgvault embeddings resume --backstop")
}
+func TestFillCoverageUsesEmbeddingScope(t *testing.T) { // fillCoverage must ignore messages outside the configured embedding scope
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ dataDir := t.TempDir()
+ dbPath := newEmbeddingMetadataTestDBFileAt(t, filepath.Join(dataDir, "vectors.db"))
+ seedMainDBWithScopedCoverageMessages(t, dataDir) // one unstamped email plus one sms stamped embed_gen = 2
+ withEmbeddingCommandConfigDataDir(t, dbPath, dataDir)
+ cfg.Vector.Embed.Scope.MessageTypes = []string{"sms"} // restrict the embedding scope to sms
+
+ row := embeddingGenerationRow{ID: 2}
+ require.NoError(fillCoverage(t.Context(), &row))
+
+ assert.Equal(int64(1), row.LiveCount) // only the sms row is counted as live
+ assert.Equal(int64(0), row.MissingCount) // the unstamped email is out of scope, so nothing is missing
+}
+
// TestRetireEmbeddingGenerationRefusesActiveWithoutForce_PreCheck pins the
// CLI UX gate that runs against the committed metadata read BEFORE any
// backend connection: retiring an active generation without --force-active
@@ -240,6 +256,38 @@ INSERT INTO messages (id, conversation_id, source_id, source_message_id, message
requirepkg.NoError(t, err)
}
+func seedMainDBWithScopedCoverageMessages(t *testing.T, dataDir string) { // test fixture: one email with embed_gen NULL and one sms with embed_gen = 2
+ t.Helper()
+ s, err := store.Open(filepath.Join(dataDir, "msgvault.db")) // main DB file lives directly inside dataDir
+ requirepkg.NoError(t, err)
+ defer func() { requirepkg.NoError(t, s.Close()) }() // a failed Close fails the test instead of being ignored
+ requirepkg.NoError(t, s.InitSchema())
+ _, err = s.DB().Exec(`
+INSERT INTO sources (id, source_type, identifier) VALUES (1, 'gmail', 'me@example.com');
+INSERT INTO conversations (id, source_id, conversation_type) VALUES (1, 1, 'email_thread'), (2, 1, 'sms_thread');
+INSERT INTO messages (id, conversation_id, source_id, source_message_id, message_type, embed_gen) VALUES
+ (1, 1, 1, 'email-missing', 'email', NULL),
+ (2, 2, 1, 'sms-stamped', 'sms', 2);
+`)
+ requirepkg.NoError(t, err)
+}
+
+func seedMainDBWithScopedFullCoverageMessages(t *testing.T, dataDir string) { // test fixture: one email and one sms, both stamped embed_gen = 2
+ t.Helper()
+ s, err := store.Open(filepath.Join(dataDir, "msgvault.db")) // main DB file lives directly inside dataDir
+ requirepkg.NoError(t, err)
+ defer func() { requirepkg.NoError(t, s.Close()) }() // a failed Close fails the test instead of being ignored
+ requirepkg.NoError(t, s.InitSchema())
+ _, err = s.DB().Exec(`
+INSERT INTO sources (id, source_type, identifier) VALUES (1, 'gmail', 'me@example.com');
+INSERT INTO conversations (id, source_id, conversation_type) VALUES (1, 1, 'email_thread'), (2, 1, 'sms_thread');
+INSERT INTO messages (id, conversation_id, source_id, source_message_id, message_type, embed_gen) VALUES
+ (1, 1, 1, 'email-stamped', 'email', 2),
+ (2, 2, 1, 'sms-stamped', 'sms', 2);
+`)
+ requirepkg.NoError(t, err)
+}
+
func withEmbeddingCommandConfig(t *testing.T, vecPath string) {
t.Helper()
oldCfg := cfg
diff --git a/cmd/msgvault/cmd/embed_vector.go b/cmd/msgvault/cmd/embed_vector.go
index 8df24cd0..f0372d67 100644
--- a/cmd/msgvault/cmd/embed_vector.go
+++ b/cmd/msgvault/cmd/embed_vector.go
@@ -57,6 +57,7 @@ func runEmbed(cmd *cobra.Command) error {
pgb, err := pgvector.Open(ctx, pgvector.Options{
DB: s.DB(),
Dimension: cfg.Vector.Embeddings.Dimension,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
SkipExtension: cfg.Vector.SkipExtensionCreate,
})
if err != nil {
@@ -76,10 +77,11 @@ func runEmbed(cmd *cobra.Command) error {
vecPath = filepath.Join(cfg.Data.DataDir, "vectors.db")
}
sb, err := sqlitevec.Open(ctx, sqlitevec.Options{
- Path: vecPath,
- MainPath: cfg.DatabaseDSN(),
- Dimension: cfg.Vector.Embeddings.Dimension,
- MainDB: s.DB(),
+ Path: vecPath,
+ MainPath: cfg.DatabaseDSN(),
+ Dimension: cfg.Vector.Embeddings.Dimension,
+ MainDB: s.DB(),
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
})
if err != nil {
return fmt.Errorf("open vectors.db: %w", err)
@@ -116,7 +118,8 @@ func runEmbed(cmd *cobra.Command) error {
// "Pending" is now the count of live messages still needing work for
// this generation (embed_gen <> gen), read from the main DB coverage
// rather than a queue table.
- missing, err := s.MissingCount(ctx, int64(gen))
+ scope := cfg.Vector.Embed.Scope.BuildScope()
+ missing, err := s.MissingCountScoped(ctx, int64(gen), scope.MessageTypes)
if err != nil {
return fmt.Errorf("coverage counts: %w", err)
}
@@ -138,6 +141,7 @@ func runEmbed(cmd *cobra.Command) error {
},
MaxInputChars: cfg.Vector.Embeddings.MaxInputChars,
BatchSize: cfg.Vector.Embeddings.BatchSize,
+ BuildScope: scope,
Rebind: rebind,
LastModifiedExpr: lastModifiedExpr,
TotalPending: totalPending,
@@ -161,7 +165,7 @@ func runEmbed(cmd *cobra.Command) error {
// worker later recovers from must not block activation, and an
// active generation must not be re-activated.
if rebuildInProgress {
- _, _, _, remaining, err := s.CoverageCounts(ctx, int64(gen))
+ _, _, _, remaining, err := s.CoverageCountsScoped(ctx, int64(gen), scope.MessageTypes)
if err != nil {
return fmt.Errorf("coverage counts: %w", err)
}
diff --git a/cmd/msgvault/cmd/embeddings_manage.go b/cmd/msgvault/cmd/embeddings_manage.go
index b3e0f04f..81525b60 100644
--- a/cmd/msgvault/cmd/embeddings_manage.go
+++ b/cmd/msgvault/cmd/embeddings_manage.go
@@ -64,7 +64,8 @@ func fillCoverage(ctx context.Context, row *embeddingGenerationRow) error {
return fmt.Errorf("open main db for coverage: %w", err)
}
defer func() { _ = s.Close() }()
- live, _, _, missing, err := s.CoverageCounts(ctx, int64(row.ID))
+ scope := cfg.Vector.Embed.Scope.BuildScope()
+ live, _, _, missing, err := s.CoverageCountsScoped(ctx, int64(row.ID), scope.MessageTypes)
if err != nil {
return err
}
@@ -87,7 +88,8 @@ func fillFullCoverage(ctx context.Context, backend vector.Backend, row *embeddin
return fmt.Errorf("open main db for coverage: %w", err)
}
defer func() { _ = s.Close() }()
- live, stamped, _, missing, err := s.CoverageCounts(ctx, int64(row.ID))
+ scope := cfg.Vector.Embed.Scope.BuildScope()
+ live, stamped, _, missing, err := s.CoverageCountsScoped(ctx, int64(row.ID), scope.MessageTypes)
if err != nil {
return err
}
@@ -426,6 +428,7 @@ func openEmbeddingsBackend(ctx context.Context) (vector.Backend, func(), error)
DB: db,
Dimension: cfg.Vector.Embeddings.Dimension,
SkipMigrate: true,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
})
if err != nil {
_ = db.Close()
@@ -460,10 +463,11 @@ func openEmbeddingsBackend(ctx context.Context) (vector.Backend, func(), error)
return nil, nil, fmt.Errorf("open main db for embeddings backend: %w", err)
}
b, err := sqlitevec.Open(ctx, sqlitevec.Options{
- Path: vecPath,
- MainPath: dsn,
- Dimension: cfg.Vector.Embeddings.Dimension,
- MainDB: mainStore.DB(),
+ Path: vecPath,
+ MainPath: dsn,
+ Dimension: cfg.Vector.Embeddings.Dimension,
+ MainDB: mainStore.DB(),
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
})
if err != nil {
_ = mainStore.Close()
diff --git a/cmd/msgvault/cmd/search_vector.go b/cmd/msgvault/cmd/search_vector.go
index 96cf957b..6f8e7ab6 100644
--- a/cmd/msgvault/cmd/search_vector.go
+++ b/cmd/msgvault/cmd/search_vector.go
@@ -79,6 +79,7 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc
pgb, err := pgvector.Open(ctx, pgvector.Options{
DB: mainDB,
Dimension: cfg.Vector.Embeddings.Dimension,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
SkipExtension: cfg.Vector.SkipExtensionCreate,
})
if err != nil {
@@ -110,10 +111,11 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc
}
sb, err := sqlitevec.Open(ctx, sqlitevec.Options{
- Path: vecDBPath,
- MainPath: dsn,
- Dimension: cfg.Vector.Embeddings.Dimension,
- MainDB: mainDB,
+ Path: vecDBPath,
+ MainPath: dsn,
+ Dimension: cfg.Vector.Embeddings.Dimension,
+ MainDB: mainDB,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
})
if err != nil {
_ = mainDB.Close()
@@ -147,6 +149,7 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc
KPerSignal: cfg.Vector.Search.KPerSignal,
SubjectBoost: cfg.Vector.Search.SubjectBoost,
Rebind: dialect.Rebind,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
})
q := search.Parse(queryStr)
diff --git a/cmd/msgvault/cmd/serve.go b/cmd/msgvault/cmd/serve.go
index fa1a41d8..29fe1f21 100644
--- a/cmd/msgvault/cmd/serve.go
+++ b/cmd/msgvault/cmd/serve.go
@@ -206,6 +206,7 @@ func runServe(cmd *cobra.Command, args []string) error {
Store: s,
Fingerprint: vf.Cfg.GenerationFingerprint(),
BackstopInterval: vf.Cfg.Embed.BackstopInterval,
+ BuildScope: vf.Cfg.Embed.Scope.BuildScope(),
Log: logger,
}
schedule := cfg.Vector.Embed.Schedule.Cron
diff --git a/cmd/msgvault/cmd/serve_vector.go b/cmd/msgvault/cmd/serve_vector.go
index 6be2e5e8..9f329e18 100644
--- a/cmd/msgvault/cmd/serve_vector.go
+++ b/cmd/msgvault/cmd/serve_vector.go
@@ -76,6 +76,7 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s
pgb, err := pgvector.Open(ctx, pgvector.Options{
DB: mainDB,
Dimension: cfg.Vector.Embeddings.Dimension,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
SkipMigrate: readOnly,
// ReadOnly MUST track readOnly here: this is the MCP read-only
// path (store.OpenReadOnly). When set, Open performs no writes —
@@ -103,10 +104,11 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s
vecPath = filepath.Join(cfg.Data.DataDir, "vectors.db")
}
sb, err := sqlitevec.Open(ctx, sqlitevec.Options{
- Path: vecPath,
- MainPath: mainPath,
- Dimension: cfg.Vector.Embeddings.Dimension,
- MainDB: mainDB,
+ Path: vecPath,
+ MainPath: mainPath,
+ Dimension: cfg.Vector.Embeddings.Dimension,
+ MainDB: mainDB,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
// Honor the read-only signal on SQLite too: when mainDB is a
// query-only handle (MCP), skip the embed_gen upgrade backfill,
// which would write through it. Migrate still runs (vectors.db
@@ -146,6 +148,7 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s
},
MaxInputChars: cfg.Vector.Embeddings.MaxInputChars,
BatchSize: cfg.Vector.Embeddings.BatchSize,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
// Rebind makes the worker's body-fetch + watermark SQL run on pgx.
// SQLiteDialect.Rebind is identity, so the SQLite path is unchanged.
Rebind: dialect.Rebind,
@@ -162,7 +165,8 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s
// placeholders. On PG those must become $N or pgx rejects them, so
// the serve/MCP hybrid engine (shared via vectorFeatures.HybridEngine)
// carries the dialect's Rebind. SQLite's Rebind is identity.
- Rebind: dialect.Rebind,
+ Rebind: dialect.Rebind,
+ BuildScope: cfg.Vector.Embed.Scope.BuildScope(),
})
// No sync-time enqueue: newly-persisted messages get embed_gen = NULL
diff --git a/internal/api/handlers.go b/internal/api/handlers.go
index c76ed247..4df02ff6 100644
--- a/internal/api/handlers.go
+++ b/internal/api/handlers.go
@@ -616,6 +616,8 @@ func (s *Server) handleHybridSearch(
case errors.Is(err, vector.ErrEmbeddingTimeout):
writeError(w, http.StatusServiceUnavailable, "embedding_timeout",
"the embedding endpoint did not respond in time; retry, or raise [vector.embeddings].timeout")
+ case errors.Is(err, vector.ErrIndexScopeMismatch):
+ writeError(w, http.StatusBadRequest, "index_scope_mismatch", err.Error())
default:
s.logger.Error("hybrid search failed", "query", q, "mode", mode, "error", err)
writeError(w, http.StatusInternalServerError, "internal_error", "search failed")
diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go
index c9a74715..1940aeab 100644
--- a/internal/mcp/handlers.go
+++ b/internal/mcp/handlers.go
@@ -118,6 +118,11 @@ func translateVectorErr(err error) *mcp.CallToolResult {
return mcp.NewToolResultError(
"index_building: the initial vector index is still being built",
)
+ case errors.Is(err, vector.ErrIndexScopeMismatch):
+ return mcp.NewToolResultError(
+ "index_scope_mismatch: the vector index scope does not cover this query; " +
+ "add a matching message_type filter or rebuild embeddings for the requested scope",
+ )
case errors.Is(err, vector.ErrNoActiveGeneration):
return mcp.NewToolResultError(
"no_active_generation: vector search has no active index yet; " +
@@ -512,25 +517,46 @@ func (h *handlers) findSimilarMessages(ctx context.Context, req mcp.CallToolRequ
return mcp.NewToolResultError(err.Error()), nil
}
- seed, err := h.backend.LoadVector(ctx, seedID)
+ active, err := h.backend.ActiveGeneration(ctx)
if err != nil {
if r := translateVectorErr(err); r != nil {
return r, nil
}
- return mcp.NewToolResultError(fmt.Sprintf("load seed vector: %v", err)), nil
+ return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil
+ }
+ if h.vectorCfg.Enabled {
+ fingerprint := h.vectorCfg.GenerationFingerprint()
+ if fingerprint != "" && active.Fingerprint != fingerprint {
+ err := fmt.Errorf("%w: active=%q configured=%q",
+ vector.ErrIndexStale, active.Fingerprint, fingerprint)
+ if r := translateVectorErr(err); r != nil {
+ return r, nil
+ }
+ return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil
+ }
}
- active, err := h.backend.ActiveGeneration(ctx)
+ seed, err := h.backend.LoadVector(ctx, seedID)
if err != nil {
if r := translateVectorErr(err); r != nil {
return r, nil
}
- return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil
+ return mcp.NewToolResultError(fmt.Sprintf("load seed vector: %v", err)), nil
+ }
+
+ if h.vectorCfg.Enabled {
+ scope := h.vectorCfg.Embed.Scope.BuildScope()
+ if !scope.IsEmpty() {
+ filter.MessageTypes = append([]string(nil), scope.MessageTypes...)
+ }
}
// +1 so we can drop the seed itself from results without coming up short.
hits, err := h.backend.Search(ctx, active.ID, seed, limit+1, filter)
if err != nil {
+ if r := translateVectorErr(err); r != nil {
+ return r, nil
+ }
return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err)), nil
}
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go
index e11cc76e..5cdd97ac 100644
--- a/internal/mcp/server_test.go
+++ b/internal/mcp/server_test.go
@@ -262,6 +262,33 @@ func TestSearchMessages_HybridErrNotEnabled(t *testing.T) {
assertpkg.Contains(t, txt, "vector_not_enabled", "expected 'vector_not_enabled' error, got: %s")
}
+// TestSearchMessages_HybridErrIndexScopeMismatch checks that a vector-mode
+// search_messages call, issued without any message_type filter against a
+// hybrid engine configured with an sms-only BuildScope, is reported to the
+// caller as an "index_scope_mismatch" tool error.
+func TestSearchMessages_HybridErrIndexScopeMismatch(t *testing.T) {
+ backend := &fakeBackend{
+ active: vector.Generation{
+ ID: 1, Model: "fake", Dimension: 4,
+ Fingerprint: "fake:4:scope=mt-sms", State: vector.GenerationActive,
+ },
+ }
+ // ExpectedFingerprint equals the active generation's fingerprint, so the
+ // only thing left to reject the query is the scope check.
+ engine := hybrid.NewEngine(backend, nil, realEmbedder{dim: 4}, hybrid.Config{
+ ExpectedFingerprint: "fake:4:scope=mt-sms",
+ RRFK: 60,
+ KPerSignal: 10,
+ BuildScope: vector.BuildScope{MessageTypes: []string{"sms"}},
+ })
+ h := &handlers{
+ engine: &querytest.MockEngine{},
+ hybridEngine: engine,
+ backend: backend,
+ }
+
+ r := runToolExpectError(t, "search_messages", h.searchMessages, map[string]any{
+ "query": "anything",
+ "mode": searchModeVector,
+ })
+ txt := resultText(t, r)
+ // Pass txt as the format argument: testify only formats the message when
+ // arguments follow it, so without txt the failure output would contain a
+ // literal "%s" instead of the text actually returned by the tool.
+ assertpkg.Contains(t, txt, "index_scope_mismatch", "expected 'index_scope_mismatch' error, got: %s", txt)
+}
+
// realEmbedder returns a deterministic vector. Used for end-to-end
// MCP hybrid tests that exercise the engine's embed → backend.Search
// path; pickEmbedGeneration tests use stubEmbedder instead.
@@ -1465,19 +1492,23 @@ func TestStageDeletion(t *testing.T) {
// optional fields so the get_stats tests can populate them. Methods
// not otherwise configured return errors and should not be called.
type fakeBackend struct {
- loadVec []float32
- loadErr error
- active vector.Generation
- activeErr error
- searchHits []vector.Hit
- searchErr error
- building *vector.Generation
- buildingErr error
- stats map[vector.GenerationID]vector.Stats
- statsErr error
+ loadVec []float32
+ loadErr error
+ loadCalls int // incremented on every LoadVector call
+ active vector.Generation
+ activeErr error
+ searchHits []vector.Hit
+ searchErr error
+ searchGen vector.GenerationID // generation passed to the most recent Search; stays zero if Search never ran
+ searchFilter vector.Filter // filter passed to the most recent Search
+ building *vector.Generation
+ buildingErr error
+ stats map[vector.GenerationID]vector.Stats
+ statsErr error
}
func (f *fakeBackend) LoadVector(_ context.Context, _ int64) ([]float32, error) {
+ f.loadCalls++ // recorded so tests can assert whether the seed lookup was reached
return f.loadVec, f.loadErr
}
func (f *fakeBackend) ResetWatermarkBelow(_ context.Context, _ int64) error {
@@ -1489,7 +1520,9 @@ func (f *fakeBackend) EmbeddedMessageCount(_ context.Context, _ vector.Generatio
func (f *fakeBackend) ActiveGeneration(_ context.Context) (vector.Generation, error) {
return f.active, f.activeErr
}
-func (f *fakeBackend) Search(_ context.Context, _ vector.GenerationID, _ []float32, _ int, _ vector.Filter) ([]vector.Hit, error) {
+func (f *fakeBackend) Search(_ context.Context, gen vector.GenerationID, _ []float32, _ int, filter vector.Filter) ([]vector.Hit, error) { // records its arguments, then returns the canned hits and error
+ f.searchGen = gen // remember which generation was searched
+ f.searchFilter = filter // remember the filter that was applied
return f.searchHits, f.searchErr
}
func (f *fakeBackend) CreateGeneration(_ context.Context, _ string, _ int, _ string) (vector.GenerationID, error) {
@@ -1623,9 +1656,98 @@ func TestFindSimilarMessages_HappyPath(t *testing.T) {
assert.Equal(int64(300), resp.Messages[1].ID, "Messages[1].ID")
}
+// TestFindSimilarMessages_RejectsStaleScopedGeneration checks that
+// find_similar_messages fails with an "index_stale" error, and never calls
+// Search, when the config scopes embeddings to "sms" but the active
+// generation carries the hard-coded fingerprint "nomic-embed:4".
+// NOTE(review): presumably that literal is the unscoped fingerprint and so
+// differs from cfg.GenerationFingerprint() — confirm against vector.Config.
+func TestFindSimilarMessages_RejectsStaleScopedGeneration(t *testing.T) {
+ assert := assertpkg.New(t)
+ seed := []float32{1, 0, 0, 0}
+ cfg := vector.Config{Enabled: true}
+ cfg.Embeddings.Model = "nomic-embed"
+ cfg.Embeddings.Dimension = 4
+ cfg.Embed.Scope.MessageTypes = []string{"sms"}
+ fb := &fakeBackend{
+ loadVec: seed,
+ active: vector.Generation{
+ ID: 7,
+ Model: "nomic-embed",
+ Dimension: 4,
+ Fingerprint: "nomic-embed:4",
+ State: vector.GenerationActive,
+ },
+ }
+ h := &handlers{engine: &querytest.MockEngine{}, backend: fb, vectorCfg: cfg}
+
+ r := runToolExpectError(t, "find_similar_messages", h.findSimilarMessages, map[string]any{
+ "message_id": float64(100),
+ })
+ txt := resultText(t, r)
+
+ assert.Contains(txt, "index_stale", "expected stale scoped index rejection, got: %s", txt)
+ // fakeBackend.Search stores its generation argument; a zero value here
+ // means Search was never called (the active generation is ID 7).
+ assert.Equal(vector.GenerationID(0), fb.searchGen, "Search should not run against a stale scoped index")
+}
+
+// TestFindSimilarMessages_ReportsStaleIndexBeforeMissingSeed checks the order
+// of validation in find_similar_messages: even though LoadVector is set up to
+// fail, the handler must report "index_stale" and must not have called
+// LoadVector at all.
+func TestFindSimilarMessages_ReportsStaleIndexBeforeMissingSeed(t *testing.T) {
+ assert := assertpkg.New(t)
+ cfg := vector.Config{Enabled: true}
+ cfg.Embeddings.Model = "nomic-embed"
+ cfg.Embeddings.Dimension = 4
+ cfg.Embed.Scope.MessageTypes = []string{"sms"}
+ fb := &fakeBackend{
+ loadErr: errors.New("seed vector missing"),
+ active: vector.Generation{
+ ID: 7,
+ Model: "nomic-embed",
+ Dimension: 4,
+ Fingerprint: "nomic-embed:4",
+ State: vector.GenerationActive,
+ },
+ }
+ h := &handlers{engine: &querytest.MockEngine{}, backend: fb, vectorCfg: cfg}
+
+ r := runToolExpectError(t, "find_similar_messages", h.findSimilarMessages, map[string]any{
+ "message_id": float64(100),
+ })
+ txt := resultText(t, r)
+
+ assert.Contains(txt, "index_stale", "expected stale index to be reported before seed lookup, got: %s", txt)
+ assert.Equal(0, fb.loadCalls, "LoadVector should not run before stale-index validation")
+}
+
+// TestFindSimilarMessages_AppliesConfiguredMessageTypeScope checks the happy
+// path for a scoped index: when the active generation's fingerprint equals
+// cfg.GenerationFingerprint(), find_similar_messages searches generation 7
+// and passes the configured message types to the backend as a filter.
+func TestFindSimilarMessages_AppliesConfiguredMessageTypeScope(t *testing.T) {
+ assert := assertpkg.New(t)
+ seed := []float32{1, 0, 0, 0}
+ cfg := vector.Config{Enabled: true}
+ cfg.Embeddings.Model = "nomic-embed"
+ cfg.Embeddings.Dimension = 4
+ cfg.Embed.Scope.MessageTypes = []string{"sms", "mms"}
+ fb := &fakeBackend{
+ loadVec: seed,
+ active: vector.Generation{
+ ID: 7,
+ Model: "nomic-embed",
+ Dimension: 4,
+ Fingerprint: cfg.GenerationFingerprint(),
+ State: vector.GenerationActive,
+ },
+ searchHits: []vector.Hit{{MessageID: 200, Score: 0.95, Rank: 1}},
+ }
+ eng := &querytest.MockEngine{
+ Messages: map[int64]*query.MessageDetail{
+ 200: testutil.NewMessageDetail(200).WithSubject("related").BuildPtr(),
+ },
+ }
+ h := &handlers{engine: eng, backend: fb, vectorCfg: cfg}
+
+ _ = runTool[similarResponse](t, "find_similar_messages", h.findSimilarMessages, map[string]any{
+ "message_id": float64(100),
+ })
+
+ assert.Equal(vector.GenerationID(7), fb.searchGen, "Search generation")
+ // Configured as {"sms", "mms"} but expected as {"mms", "sms"}.
+ // NOTE(review): presumably the scope is normalized to sorted order before
+ // it reaches the backend — confirm against vector.NewBuildScope.
+ assert.Equal([]string{"mms", "sms"}, fb.searchFilter.MessageTypes, "Search MessageTypes filter")
+}
+
func TestFindSimilarMessages_NoActiveGeneration(t *testing.T) {
+ assert := assertpkg.New(t)
fb := &fakeBackend{
- loadErr: vector.ErrNoActiveGeneration,
+ activeErr: vector.ErrNoActiveGeneration,
}
h := &handlers{engine: &querytest.MockEngine{}, backend: fb}
@@ -1633,7 +1755,8 @@ func TestFindSimilarMessages_NoActiveGeneration(t *testing.T) {
"message_id": float64(1),
})
txt := resultText(t, r)
- assertpkg.Contains(t, txt, "no_active_generation", "expected 'no_active_generation' error, got: %s")
+ assert.Contains(txt, "no_active_generation", "expected 'no_active_generation' error, got: %s", txt)
+ assert.Equal(0, fb.loadCalls, "LoadVector should not run without an active generation")
}
func TestSearchByDomains(t *testing.T) {
diff --git a/internal/query/duckdb.go b/internal/query/duckdb.go
index e4a31059..b0374614 100644
--- a/internal/query/duckdb.go
+++ b/internal/query/duckdb.go
@@ -407,7 +407,7 @@ func (e *DuckDBEngine) parquetCTEs() string {
conv AS (
%s
)
- `, msgCTE,
+ `, msgCTE,
e.parquetPath("message_recipients"),
pCTE,
e.parquetPath("labels"),
@@ -417,6 +417,45 @@ func (e *DuckDBEngine) parquetCTEs() string {
convCTE)
}
+// duckDBMessageTypeCondition builds a parenthesized SQL predicate, using `?`
+// placeholders, that matches alias.message_type against messageTypes. It
+// returns the predicate together with its bind arguments in placeholder
+// order.
+//
+// Each type is lowercased and trimmed; blank entries are skipped. The type
+// "email" also matches rows whose message_type is NULL or the empty string.
+// Every other type is matched through one IN (...) list. Duplicate entries
+// are not removed. When no usable type remains the function returns
+// ("", nil) so callers can skip the condition entirely.
+func duckDBMessageTypeCondition(alias string, messageTypes []string) (string, []any) {
+ var conditions []string
+ var args []any
+ var exact []string
+ includeEmail := false
+
+ for _, typ := range messageTypes {
+ typ = strings.TrimSpace(strings.ToLower(typ))
+ if typ == "" {
+ continue
+ }
+ // "email" is handled separately because it needs the NULL/'' branch.
+ if typ == "email" {
+ includeEmail = true
+ continue
+ }
+ exact = append(exact, typ)
+ }
+
+ col := alias + ".message_type"
+ if includeEmail {
+ conditions = append(conditions,
+ fmt.Sprintf("(%s = ? OR %s IS NULL OR %s = '')", col, col, col))
+ args = append(args, "email")
+ }
+ if len(exact) > 0 {
+ placeholders := make([]string, len(exact))
+ for i, typ := range exact {
+ placeholders[i] = "?"
+ args = append(args, typ)
+ }
+ conditions = append(conditions,
+ fmt.Sprintf("%s IN (%s)", col, strings.Join(placeholders, ",")))
+ }
+ if len(conditions) == 0 {
+ return "", nil
+ }
+ // The email branch and the IN list are alternatives, so they are OR-ed
+ // and wrapped in parentheses to compose safely with surrounding ANDs.
+ return "(" + strings.Join(conditions, " OR ") + ")", args
+}
+
// escapeILIKE escapes ILIKE wildcard characters (% and _) in user input.
func escapeILIKE(s string) string {
s = strings.ReplaceAll(s, "\\", "\\\\") // Escape backslash first
@@ -486,6 +525,14 @@ func (e *DuckDBEngine) buildNonTextSearchConditions(q *search.Query, keyColumns
var conditions []string
var args []any
+ if len(q.MessageTypes) > 0 {
+ condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes)
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
+ }
+ }
+
// from: filter - match sender email
for _, from := range q.FromAddrs {
fromPattern := "%" + escapeILIKE(from) + "%"
@@ -578,12 +625,11 @@ func (e *DuckDBEngine) buildNonTextSearchConditions(q *search.Query, keyColumns
args = append(args, *q.SmallerThan)
}
if len(q.MessageTypes) > 0 {
- placeholders := make([]string, len(q.MessageTypes))
- for i, typ := range q.MessageTypes {
- placeholders[i] = "?"
- args = append(args, typ)
+ condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes)
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
}
- conditions = append(conditions, fmt.Sprintf("msg.message_type IN (%s)", strings.Join(placeholders, ",")))
}
return conditions, args
@@ -882,8 +928,11 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []an
}
if filter.MessageType != "" {
- conditions = append(conditions, "msg.message_type = ?")
- args = append(args, filter.MessageType)
+ condition, conditionArgs := duckDBMessageTypeCondition("msg", []string{filter.MessageType})
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
+ }
}
// Sender + sender-name filters - check both message_recipients (email)
@@ -1160,10 +1209,14 @@ func (e *DuckDBEngine) GetTotalStats(ctx context.Context, opts StatsOptions) (*T
var conditions []string
var args []any
+ hasExplicitMessageTypes := false
+ if opts.SearchQuery != "" {
+ q := search.Parse(opts.SearchQuery)
+ hasExplicitMessageTypes = len(q.MessageTypes) > 0
+ }
// Restrict to email messages only; NULL and '' handle pre-message_type data.
- hasSearchMessageTypes := opts.SearchQuery != "" && len(search.Parse(opts.SearchQuery).MessageTypes) > 0
- if !hasSearchMessageTypes {
+ if !hasExplicitMessageTypes {
conditions = append(conditions, emailOnlyFilterMsg)
}
conditions = append(conditions, store.LiveMessagesWhere("msg", opts.HideDeletedFromSource))
@@ -1638,6 +1691,14 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse
}
}
+ if len(q.MessageTypes) > 0 {
+ condition, conditionArgs := duckDBMessageTypeCondition("m", q.MessageTypes)
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
+ }
+ }
+
// Has attachment filter
if q.HasAttachment != nil && *q.HasAttachment {
conditions = append(conditions, "m.has_attachments = 1")
@@ -1663,12 +1724,11 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse
args = append(args, *q.SmallerThan)
}
if len(q.MessageTypes) > 0 {
- placeholders := make([]string, len(q.MessageTypes))
- for i, typ := range q.MessageTypes {
- placeholders[i] = "?"
- args = append(args, typ)
+ condition, conditionArgs := duckDBMessageTypeCondition("m", q.MessageTypes)
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
}
- conditions = append(conditions, fmt.Sprintf("m.message_type IN (%s)", strings.Join(placeholders, ",")))
}
// Full-text search: use ILIKE fallback (FTS5 not available via sqlite_scan)
@@ -1702,7 +1762,8 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse
COALESCE(m.size_estimate, 0),
m.has_attachments,
m.attachment_count,
- m.deleted_from_source_at
+ m.deleted_from_source_at,
+ COALESCE(m.message_type, '')
FROM sqlite_db.messages m
LEFT JOIN sqlite_db.message_recipients mr_sender ON mr_sender.message_id = m.id AND mr_sender.recipient_type = 'from'
LEFT JOIN sqlite_db.participants p_sender ON p_sender.id = mr_sender.participant_id
@@ -1740,6 +1801,7 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse
&msg.HasAttachments,
&msg.AttachmentCount,
&deletedAt,
+ &msg.MessageType,
); err != nil {
return nil, fmt.Errorf("scan message: %w", err)
}
@@ -2460,11 +2522,17 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt
var args []any
// Apply basic filter conditions (ignoring join flags for search - we handle those differently)
- if len(q.MessageTypes) == 0 {
+ conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource))
+ if len(q.MessageTypes) > 0 {
+ condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes)
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
+ }
+ } else {
// Restrict to email messages only; NULL and '' handle pre-message_type data.
conditions = append(conditions, emailOnlyFilterMsg)
}
- conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource))
conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs)
if filter.After != nil {
conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)")
@@ -2617,12 +2685,11 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt
args = append(args, *q.SmallerThan)
}
if len(q.MessageTypes) > 0 {
- placeholders := make([]string, len(q.MessageTypes))
- for i, typ := range q.MessageTypes {
- placeholders[i] = "?"
- args = append(args, typ)
+ condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes)
+ if condition != "" {
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
}
- conditions = append(conditions, fmt.Sprintf("msg.message_type IN (%s)", strings.Join(placeholders, ",")))
}
// Account filter
diff --git a/internal/query/duckdb_test.go b/internal/query/duckdb_test.go
index a10e428c..ef4e41d9 100644
--- a/internal/query/duckdb_test.go
+++ b/internal/query/duckdb_test.go
@@ -41,6 +41,33 @@ func newSQLiteEngine(t *testing.T) *DuckDBEngine {
return engine
}
+// newMessageTypeParquetEngine returns a DuckDBEngine built from a
+// TestDataBuilder fixture holding one source, two participants (Alice and
+// Bob), and two messages from Alice to Bob: an "sms" message ("lunch plan",
+// size 321) and an "email" message ("lunch receipt", size 999). Both contain
+// the word "lunch", so message_type is the only thing telling them apart in
+// text searches.
+func newMessageTypeParquetEngine(t *testing.T) *DuckDBEngine {
+ t.Helper()
+ b := NewTestDataBuilder(t)
+ b.AddSource("test@example.com")
+ aliceID := b.AddParticipant("alice@example.com", "example.com", "Alice")
+ bobID := b.AddParticipant("bob@example.com", "example.com", "Bob")
+ smsID := b.AddMessage(MessageOpt{
+ Subject: "lunch plan",
+ Snippet: "sushi lunch details",
+ MessageType: "sms",
+ SentAt: time.Date(2024, 4, 10, 10, 0, 0, 0, time.UTC),
+ SizeEstimate: 321,
+ })
+ emailID := b.AddMessage(MessageOpt{
+ Subject: "lunch receipt",
+ Snippet: "email lunch details",
+ MessageType: "email",
+ SentAt: time.Date(2024, 4, 11, 10, 0, 0, 0, time.UTC),
+ SizeEstimate: 999,
+ })
+ b.AddFrom(smsID, aliceID, "Alice")
+ b.AddTo(smsID, bobID, "Bob")
+ b.AddFrom(emailID, aliceID, "Alice")
+ b.AddTo(emailID, bobID, "Bob")
+ return b.BuildEngine()
+}
+
// searchFast is a test helper that parses a query string and calls SearchFast.
func searchFast(t *testing.T, engine *DuckDBEngine, queryStr string, filter MessageFilter) []MessageSummary {
t.Helper()
@@ -1123,6 +1150,54 @@ func TestDuckDBEngine_GetTotalStats_MessageTypeFilter(t *testing.T) {
assert.Equal(int64(2000), stats.TotalSize, "TotalSize")
}
+// TestDuckDBEngine_SearchFastMessageTypeFilter checks that SearchFast honors
+// message_type:sms both as the only filter and combined with a text term.
+// In each case only the fixture's sms message ("lunch plan") may be returned,
+// even though the email message also matches "lunch".
+func TestDuckDBEngine_SearchFastMessageTypeFilter(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ engine := newMessageTypeParquetEngine(t)
+ ctx := context.Background()
+
+ filterOnly, err := engine.SearchFast(ctx, search.Parse("message_type:sms"), MessageFilter{}, 100, 0)
+ require.NoError(err, "SearchFast message_type only")
+ require.Len(filterOnly, 1, "filter-only message_type search")
+ assert.Equal("sms", filterOnly[0].MessageType)
+ assert.Equal("lunch plan", filterOnly[0].Subject)
+
+ withText, err := engine.SearchFast(ctx, search.Parse("message_type:sms lunch"), MessageFilter{}, 100, 0)
+ require.NoError(err, "SearchFast message_type with text")
+ require.Len(withText, 1, "message_type should scope text search")
+ assert.Equal("sms", withText[0].MessageType)
+ assert.Equal("lunch plan", withText[0].Subject)
+}
+
+// TestDuckDBEngine_GetTotalStatsMessageTypeSearch checks that GetTotalStats
+// with SearchQuery "message_type:sms" counts only the fixture's sms message
+// (1 message, 321 bytes) rather than the email-only default.
+func TestDuckDBEngine_GetTotalStatsMessageTypeSearch(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ engine := newMessageTypeParquetEngine(t)
+
+ stats, err := engine.GetTotalStats(context.Background(), StatsOptions{
+ SearchQuery: "message_type:sms",
+ })
+ require.NoError(err, "GetTotalStats")
+
+ assert.Equal(int64(1), stats.MessageCount, "message count")
+ assert.Equal(int64(321), stats.TotalSize, "total size")
+}
+
+// TestDuckDBEngine_AggregateMessageTypeSearch checks that Aggregate over
+// ViewTime with SearchQuery "message_type:sms" yields a single row that
+// covers only the fixture's sms message (count 1, total size 321).
+func TestDuckDBEngine_AggregateMessageTypeSearch(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ engine := newMessageTypeParquetEngine(t)
+ opts := DefaultAggregateOptions()
+ opts.SearchQuery = "message_type:sms"
+
+ rows, err := engine.Aggregate(context.Background(), ViewTime, opts)
+ require.NoError(err, "Aggregate")
+ require.Len(rows, 1, "rows")
+
+ assert.Equal(int64(1), rows[0].Count, "count")
+ assert.Equal(int64(321), rows[0].TotalSize, "total size")
+}
+
// TestDuckDBEngine_ListMessages_DateFilter verifies that After/Before date filters
// work with DuckDB's TIMESTAMP column (regression: VARCHAR params need CAST).
func TestDuckDBEngine_ListMessages_DateFilter(t *testing.T) {
@@ -3266,6 +3341,13 @@ func TestDuckDBEngine_StaleParquetSchema(t *testing.T) {
assertpkg.Equal(t, "Stale Hello", results[0].Subject)
})
+ t.Run("SearchFastMessageTypeEmail", func(t *testing.T) {
+ q := search.Parse("message_type:email Stale")
+ results, err := engine.SearchFast(ctx, q, MessageFilter{}, 100, 0)
+ requirepkg.NoError(t, err, "SearchFast message_type:email with stale Parquet schema")
+ requirepkg.Len(t, results, 2)
+ })
+
t.Run("SearchFastCount", func(t *testing.T) {
q := search.Parse("Stale")
count, err := engine.SearchFastCount(ctx, q, MessageFilter{})
@@ -3285,6 +3367,12 @@ func TestDuckDBEngine_StaleParquetSchema(t *testing.T) {
assertpkg.Equal(t, int64(2), stats.MessageCount)
})
+ t.Run("GetTotalStatsMessageTypeEmail", func(t *testing.T) {
+ stats, err := engine.GetTotalStats(ctx, StatsOptions{SearchQuery: "message_type:email"})
+ requirepkg.NoError(t, err, "GetTotalStats message_type:email with stale Parquet schema")
+ assertpkg.Equal(t, int64(2), stats.MessageCount)
+ })
+
// Verify that optionalCols correctly detected the missing columns.
t.Run("ProbeDetectedMissing", func(t *testing.T) {
for _, col := range []struct{ table, col string }{
diff --git a/internal/query/sqlite.go b/internal/query/sqlite.go
index 88d78b8d..9670df1f 100644
--- a/internal/query/sqlite.go
+++ b/internal/query/sqlite.go
@@ -1497,6 +1497,15 @@ func (e *SQLiteEngine) buildSearchQueryParts(ctx context.Context, q *search.Quer
}
}
+ if len(q.MessageTypes) > 0 {
+ placeholders := make([]string, len(q.MessageTypes))
+ for i, typ := range q.MessageTypes {
+ placeholders[i] = "?"
+ args = append(args, strings.ToLower(typ))
+ }
+ conditions = append(conditions, fmt.Sprintf("m.message_type IN (%s)", strings.Join(placeholders, ",")))
+ }
+
// Has attachment filter
if q.HasAttachment != nil && *q.HasAttachment {
conditions = append(conditions, e.dialect.BoolTrueExpr("m.has_attachments"))
diff --git a/internal/query/sqlite_crud_test.go b/internal/query/sqlite_crud_test.go
index f9d365ff..13e0ea66 100644
--- a/internal/query/sqlite_crud_test.go
+++ b/internal/query/sqlite_crud_test.go
@@ -1100,6 +1100,31 @@ func TestGetTotalStatsWithSearchQuery_Combined(t *testing.T) {
assertpkg.Equal(t, int64(2000), stats.TotalSize, "SearchQuery+WithAttachments total size")
}
+// TestSearchFastWithStats_MessageTypeStats checks that SearchFastWithStats
+// applies a message_type:sms query consistently to the message list, the
+// total count, and the accompanying stats. It adds one message, marks it as
+// sms directly in the database, and expects exactly that message with a
+// stats size of 321.
+func TestSearchFastWithStats_MessageTypeStats(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ env := newTestEnv(t)
+
+ smsID := env.AddMessage(dbtest.MessageOpts{
+ Subject: "Lunch via SMS",
+ SentAt: "2024-04-01 10:00:00",
+ SizeEstimate: 321,
+ })
+ // message_type is set with raw SQL after the insert.
+ _, err := env.DB.Exec(`UPDATE messages SET message_type = 'sms' WHERE id = ?`, smsID)
+ require.NoError(err, "set sms message_type")
+
+ q := search.Parse("message_type:sms")
+ result, err := env.Engine.SearchFastWithStats(env.Ctx, q, "message_type:sms", MessageFilter{}, ViewSenders, 100, 0)
+ require.NoError(err, "SearchFastWithStats")
+ require.NotNil(result.Stats, "stats")
+
+ require.Len(result.Messages, 1, "messages")
+ assert.Equal(smsID, result.Messages[0].ID, "message id")
+ assert.Equal(int64(1), result.TotalCount, "total count")
+ assert.Equal(int64(1), result.Stats.MessageCount, "stats message count")
+ assert.Equal(int64(321), result.Stats.TotalSize, "stats total size")
+}
+
func TestGetMessageRaw(t *testing.T) {
env := newTestEnv(t)
rawMIME := []byte("From: test@example.com\r\nSubject: Test\r\n\r\nHello")
diff --git a/internal/query/sqlite_search_test.go b/internal/query/sqlite_search_test.go
index 54c3cf42..f321e857 100644
--- a/internal/query/sqlite_search_test.go
+++ b/internal/query/sqlite_search_test.go
@@ -10,6 +10,7 @@ import (
assertpkg "github.com/stretchr/testify/assert"
requirepkg "github.com/stretchr/testify/require"
"go.kenn.io/msgvault/internal/search"
+ "go.kenn.io/msgvault/internal/testutil/dbtest"
"go.kenn.io/msgvault/internal/testutil/ptr"
)
@@ -94,20 +95,6 @@ func TestSearch_Filters(t *testing.T) {
}
}
-func TestSearch_MessageTypeFilter(t *testing.T) {
- require := requirepkg.New(t)
- assert := assertpkg.New(t)
- env := newTestEnv(t)
-
- _, err := env.DB.Exec(`UPDATE messages SET message_type = ? WHERE id = ?`, "sms", int64(2))
- require.NoError(err, "mark message as sms")
-
- results := env.MustSearch(search.Parse("message_type:sms Hello"), 100, 0)
- require.Len(results, 1, "message_type:sms should scope the text search")
- assert.Equal(int64(2), results[0].ID, "ID")
- assert.Equal("sms", results[0].MessageType, "MessageType")
-}
-
func TestSearch_CaseInsensitiveFallback(t *testing.T) {
env := newTestEnv(t)
@@ -148,6 +135,41 @@ func TestSearch_WithFTS(t *testing.T) {
assertpkg.Equal(t, "Hello World", results[0].Subject)
}
+// TestSearch_MessageTypeFilter checks that Search honors message_type:sms
+// in two modes: as the only filter, and combined with a text term after FTS
+// has been enabled. It adds two "lunch" messages between the same
+// participants, one marked sms and one marked email, and expects only the
+// sms message in both searches.
+func TestSearch_MessageTypeFilter(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ env := newTestEnv(t)
+ aliceID := env.MustLookupParticipant("alice@example.com")
+ bobID := env.MustLookupParticipant("bob@company.org")
+ smsID := env.AddMessage(dbtest.MessageOpts{
+ Subject: "lunch plan",
+ SentAt: "2024-04-10 10:00:00",
+ FromID: aliceID,
+ ToIDs: []int64{bobID},
+ })
+ emailID := env.AddMessage(dbtest.MessageOpts{
+ Subject: "lunch receipt",
+ SentAt: "2024-04-11 10:00:00",
+ FromID: aliceID,
+ ToIDs: []int64{bobID},
+ })
+ // message_type is set with raw SQL after the inserts.
+ _, err := env.DB.Exec(`UPDATE messages SET message_type = 'sms' WHERE id = ?`, smsID)
+ require.NoError(err, "set sms message_type")
+ _, err = env.DB.Exec(`UPDATE messages SET message_type = 'email' WHERE id = ?`, emailID)
+ require.NoError(err, "set email message_type")
+
+ results := env.MustSearch(search.Parse("message_type:sms"), 100, 0)
+ require.Len(results, 1, "filter-only message_type search")
+ assert.Equal(smsID, results[0].ID)
+ assert.Equal("sms", results[0].MessageType)
+
+ // Repeat with FTS enabled and a text term that matches both messages.
+ env.EnableFTS()
+ results = env.MustSearch(search.Parse("message_type:sms lunch"), 100, 0)
+ require.Len(results, 1, "message_type must scope FTS search")
+ assert.Equal(smsID, results[0].ID)
+ assert.Equal("sms", results[0].MessageType)
+}
+
// TestSearch_WithFTS_SpecialChars verifies that FTS5 special characters in
// search terms don't cause syntax errors. Without quoting, these characters
// are interpreted as FTS5 operators (- = NOT, : = column filter, () = grouping).
diff --git a/internal/remote/engine_test.go b/internal/remote/engine_test.go
index dae7242a..67cfffff 100644
--- a/internal/remote/engine_test.go
+++ b/internal/remote/engine_test.go
@@ -95,7 +95,6 @@ func TestEngineGetMessageSummariesByIDs_CarriesFromAndAttachmentCount(t *testing
assert.Equal(2, s.AttachmentCount, "AttachmentCount")
assert.True(s.HasAttachments, "HasAttachments")
}
-
func TestEngineSearchSerializesMessageTypes(t *testing.T) {
require := requirepkg.New(t)
assert := assertpkg.New(t)
@@ -125,3 +124,18 @@ func TestEngineSearchSerializesMessageTypes(t *testing.T) {
}, 10, 0)
require.NoError(err, "Search")
}
+
+// TestBuildSearchQueryStringIncludesMessageTypes checks that
+// buildSearchQueryString keeps message_type filters when serializing a
+// parsed query. The second case also pins the output order: the text term
+// comes first and the message_type filter after it, regardless of the order
+// in the input string.
+func TestBuildSearchQueryStringIncludesMessageTypes(t *testing.T) {
+ assert := assertpkg.New(t)
+
+ assert.Equal(
+ "message_type:sms",
+ buildSearchQueryString(search.Parse("message_type:sms")),
+ "filter-only message type query",
+ )
+ assert.Equal(
+ "lunch message_type:sms",
+ buildSearchQueryString(search.Parse("message_type:sms lunch")),
+ "message type with text term",
+ )
+}
diff --git a/internal/scheduler/embed_job.go b/internal/scheduler/embed_job.go
index cecea362..bd6f1b1b 100644
--- a/internal/scheduler/embed_job.go
+++ b/internal/scheduler/embed_job.go
@@ -34,6 +34,10 @@ type EmbedCoverage interface {
MissingCount(ctx context.Context, activeGen int64) (int64, error)
}
+// ScopedEmbedCoverage is an optional extension of EmbedCoverage for stores
+// that can count missing embeddings within a subset of message types.
+// EmbedJob type-asserts its Store to this interface when a build scope is
+// configured.
+type ScopedEmbedCoverage interface {
+ // MissingCountScoped returns the missing count for activeGen, limited to
+ // the given message types.
+ MissingCountScoped(ctx context.Context, activeGen int64, messageTypes []string) (int64, error)
+}
+
// Compile-time check that the production worker satisfies EmbedRunner.
var _ EmbedRunner = (*embed.Worker)(nil)
@@ -82,6 +86,10 @@ type EmbedJob struct {
// A negative value disables the auto-backstop entirely.
BackstopInterval time.Duration
+ // BuildScope limits coverage checks to the same message universe the
+ // worker scans for this generation. Empty means the full live corpus.
+ BuildScope vector.BuildScope
+
// Now returns the current time; overridable in tests to drive the
// backstop interval deterministically. nil uses time.Now.
Now func() time.Time
@@ -182,7 +190,7 @@ func (j *EmbedJob) Run(ctx context.Context) {
"gen", target)
return
}
- missing, err := j.Store.MissingCount(ctx, int64(target))
+ missing, err := j.missingCount(ctx, target)
if err != nil {
log.Warn("embed: coverage count after run failed", "gen", target, "error", err)
return
@@ -201,6 +209,17 @@ func (j *EmbedJob) Run(ctx context.Context) {
log.Info("embed: building generation activated", "gen", target)
}
+// missingCount returns the number of messages still missing an embedding
+// for target, measured over the job's BuildScope. With an empty scope it
+// delegates to Store.MissingCount. With a non-empty scope the Store must also
+// implement ScopedEmbedCoverage; otherwise an error is returned instead of
+// falling back to an unscoped count.
+func (j *EmbedJob) missingCount(ctx context.Context, target vector.GenerationID) (int64, error) {
+ // Rebuild the scope through the constructor before testing emptiness.
+ scope := vector.NewBuildScope(j.BuildScope.MessageTypes)
+ if scope.IsEmpty() {
+ return j.Store.MissingCount(ctx, int64(target))
+ }
+ if scoped, ok := j.Store.(ScopedEmbedCoverage); ok {
+ return scoped.MissingCountScoped(ctx, int64(target), scope.MessageTypes)
+ }
+ return 0, errors.New("embed coverage store does not support scoped missing counts")
+}
+
// maybeRunBackstop runs a full watermark-ignoring backstop pass on gen when
// BackstopInterval has elapsed since this generation's last one, then records
// the time. The throttle is keyed per generation so a recent backstop of one
diff --git a/internal/store/api_search_test.go b/internal/store/api_search_test.go
index 32bea8bc..d3808468 100644
--- a/internal/store/api_search_test.go
+++ b/internal/store/api_search_test.go
@@ -148,7 +148,6 @@ func TestSearchMessages_LegacyRawString(t *testing.T) {
})
}
}
-
func TestSearchMessagesQuery_MessageTypeFilter(t *testing.T) {
require := requirepkg.New(t)
assert := assertpkg.New(t)
diff --git a/internal/store/embed_gen.go b/internal/store/embed_gen.go
index e7517283..3ba69995 100644
--- a/internal/store/embed_gen.go
+++ b/internal/store/embed_gen.go
@@ -33,16 +33,27 @@ var embedGenStampChunkRows = 500
// the embeddings upsert — the worker orders the steps (upsert, then
// stamp) and relies on idempotency, see internal/vector/embed/worker.go.
func (s *Store) ScanForEmbedding(ctx context.Context, target int64, afterID int64, limit int) ([]int64, error) {
+ return s.ScanForEmbeddingScoped(ctx, target, afterID, limit, nil)
+}
+
+// ScanForEmbeddingScoped is ScanForEmbedding limited to the supplied message
+// types. An empty messageTypes slice means the full live corpus.
+func (s *Store) ScanForEmbeddingScoped(ctx context.Context, target int64, afterID int64, limit int, messageTypes []string) ([]int64, error) {
if limit <= 0 {
return nil, nil
}
+ liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes)
q := `SELECT id FROM messages
WHERE (embed_gen IS NULL OR embed_gen <> ?)
- AND ` + LiveMessagesWhere("", true) + `
+ AND ` + liveWhere + `
AND id > ?
ORDER BY id
LIMIT ?`
- rows, err := s.db.QueryContext(ctx, q, target, afterID, limit)
+ args := make([]any, 0, 3+len(liveArgs))
+ args = append(args, target)
+ args = append(args, liveArgs...)
+ args = append(args, afterID, limit)
+ rows, err := s.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, fmt.Errorf("scan for embedding: %w", err)
}
@@ -239,14 +250,22 @@ func (s *Store) ResetEmbedGen(ctx context.Context, ids []int64) error {
// activeGen == 0 means "no active/target generation"; then everything
// live is missing and stamped is 0.
func (s *Store) CoverageCounts(ctx context.Context, activeGen int64) (live, stamped, blank, missing int64, err error) {
- live, err = s.countLiveMessages(ctx)
+ return s.CoverageCountsScoped(ctx, activeGen, nil)
+}
+
+// CoverageCountsScoped is CoverageCounts limited to the supplied message types.
+// An empty messageTypes slice means the full live corpus.
+func (s *Store) CoverageCountsScoped(ctx context.Context, activeGen int64, messageTypes []string) (live, stamped, blank, missing int64, err error) {
+ live, err = s.countLiveMessagesScoped(ctx, messageTypes)
if err != nil {
return 0, 0, 0, 0, err
}
if activeGen != 0 {
+ liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes)
q := `SELECT COUNT(*) FROM messages
- WHERE embed_gen = ? AND ` + LiveMessagesWhere("", true)
- if err := s.db.QueryRowContext(ctx, q, activeGen).Scan(&stamped); err != nil {
+ WHERE embed_gen = ? AND ` + liveWhere
+ args := append([]any{activeGen}, liveArgs...)
+ if err := s.db.QueryRowContext(ctx, q, args...).Scan(&stamped); err != nil {
return 0, 0, 0, 0, fmt.Errorf("count stamped: %w", err)
}
}
@@ -259,7 +278,13 @@ func (s *Store) CoverageCounts(ctx context.Context, activeGen int64) (live, stam
// activeGen). It is a thin accessor for the scheduler/CLI activation
// gates, which only consult the missing count; missing = live - stamped.
func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error) {
- live, err := s.countLiveMessages(ctx)
+ return s.MissingCountScoped(ctx, activeGen, nil)
+}
+
+// MissingCountScoped is MissingCount limited to the supplied message types.
+// An empty messageTypes slice means the full live corpus.
+func (s *Store) MissingCountScoped(ctx context.Context, activeGen int64, messageTypes []string) (int64, error) {
+ live, err := s.countLiveMessagesScoped(ctx, messageTypes)
if err != nil {
return 0, err
}
@@ -267,9 +292,11 @@ func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error
return live, nil
}
var stamped int64
+ liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes)
q := `SELECT COUNT(*) FROM messages
- WHERE embed_gen = ? AND ` + LiveMessagesWhere("", true)
- if err := s.db.QueryRowContext(ctx, q, activeGen).Scan(&stamped); err != nil {
+ WHERE embed_gen = ? AND ` + liveWhere
+ args := append([]any{activeGen}, liveArgs...)
+ if err := s.db.QueryRowContext(ctx, q, args...).Scan(&stamped); err != nil {
return 0, fmt.Errorf("count stamped: %w", err)
}
return max(live-stamped, 0), nil
@@ -277,11 +304,48 @@ func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error
-// countLiveMessages returns the total live-message count. Shared by
-// CoverageCounts; kept separate so the live-predicate stays in one place.
-func (s *Store) countLiveMessages(ctx context.Context) (int64, error) {
+// countLiveMessagesScoped returns the live-message count, restricted to
+// messageTypes when it contains at least one non-blank type and covering the
+// full live corpus otherwise. Shared by CoverageCountsScoped and
+// MissingCountScoped; kept separate so the live-predicate stays in one place.
+func (s *Store) countLiveMessagesScoped(ctx context.Context, messageTypes []string) (int64, error) {
var n int64
- q := `SELECT COUNT(*) FROM messages WHERE ` + LiveMessagesWhere("", true)
- if err := s.db.QueryRowContext(ctx, q).Scan(&n); err != nil {
+ liveWhere, args := liveMessagesWhereWithMessageTypes(messageTypes)
+ q := `SELECT COUNT(*) FROM messages WHERE ` + liveWhere
+ if err := s.db.QueryRowContext(ctx, q, args...).Scan(&n); err != nil {
return 0, fmt.Errorf("count live messages: %w", err)
}
return n, nil
}
+
+// liveMessagesWhereWithMessageTypes returns the unaliased live-message WHERE
+// fragment from LiveMessagesWhere("", true). When messageTypes contains at
+// least one non-blank entry it appends an
+// "AND message_type IN (?, ...)" clause and returns the normalized types as
+// bind arguments in placeholder order. With no usable types it returns the
+// plain live predicate and nil arguments.
+func liveMessagesWhereWithMessageTypes(messageTypes []string) (string, []any) {
+ where := LiveMessagesWhere("", true)
+ types := normalizeMessageTypes(messageTypes)
+ if len(types) == 0 {
+ return where, nil
+ }
+ placeholders := make([]string, len(types))
+ args := make([]any, len(types))
+ for i, typ := range types {
+ placeholders[i] = "?"
+ args[i] = typ
+ }
+ where += " AND message_type IN (" + strings.Join(placeholders, ",") + ")"
+ return where, args
+}
+
+// normalizeMessageTypes lowercases and trims each entry, drops blank
+// entries, and removes duplicates while keeping first-seen order. It returns
+// nil for an empty input, and an empty non-nil slice when every entry is
+// blank.
+func normalizeMessageTypes(messageTypes []string) []string {
+ if len(messageTypes) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(messageTypes))
+ out := make([]string, 0, len(messageTypes))
+ for _, typ := range messageTypes {
+ typ = strings.TrimSpace(strings.ToLower(typ))
+ if typ == "" {
+ continue
+ }
+ // Deduplicate on the normalized form, so "SMS" and " sms " collapse.
+ if _, ok := seen[typ]; ok {
+ continue
+ }
+ seen[typ] = struct{}{}
+ out = append(out, typ)
+ }
+ return out
+}
diff --git a/internal/synctechsms/importer.go b/internal/synctechsms/importer.go
index c16141e7..ec6d5e24 100644
--- a/internal/synctechsms/importer.go
+++ b/internal/synctechsms/importer.go
@@ -98,39 +98,47 @@ func (i *Importer) importRecord(sourceID int64, record Record, summary *ImportSu
switch record.Kind {
case RecordSMS:
if i.opts.IncludeSMS && record.SMS != nil {
- if err := i.importSMS(sourceID, *record.SMS); err != nil {
+ msgID, err := i.importSMS(sourceID, *record.SMS)
+ if err != nil {
return err
}
+ summary.MessageIDs = append(summary.MessageIDs, msgID)
summary.SMSImported++
}
case RecordMMS:
if i.opts.IncludeMMS && record.MMS != nil {
- attachments, err := i.importMMS(sourceID, *record.MMS)
+ msgID, attachments, err := i.importMMS(sourceID, *record.MMS)
if err != nil {
+ if msgID != 0 {
+ summary.MessageIDs = append(summary.MessageIDs, msgID)
+ }
return err
}
+ summary.MessageIDs = append(summary.MessageIDs, msgID)
summary.MMSImported++
summary.AttachmentsImported += attachments
}
case RecordCall:
if i.opts.IncludeCalls && record.Call != nil {
- if err := i.importCall(sourceID, *record.Call); err != nil {
+ msgID, err := i.importCall(sourceID, *record.Call)
+ if err != nil {
return err
}
+ summary.MessageIDs = append(summary.MessageIDs, msgID)
summary.CallsImported++
}
}
return nil
}
-func (i *Importer) importSMS(sourceID int64, sms SMS) error {
+func (i *Importer) importSMS(sourceID int64, sms SMS) (int64, error) {
remoteID, err := i.participantID(sms.Address, sms.ContactName.String)
if err != nil {
- return err
+ return 0, err
}
ownerID, err := i.participantID(i.opts.OwnerPhone, "Me")
if err != nil {
- return err
+ return 0, err
}
// Drafts are owner-authored messages that never made it out, but
// they still belong on the owner's side of the conversation. Without
@@ -144,36 +152,36 @@ func (i *Importer) importSMS(sourceID int64, sms SMS) error {
}
convID, err := i.ensureConversation(sourceID, textConversationKey([]int64{ownerID, remoteID}), sms.ContactName.String)
if err != nil {
- return err
+ return 0, err
}
if err := i.store.EnsureConversationParticipant(convID, remoteID, "member"); err != nil {
- return err
+ return 0, err
}
if err := i.store.EnsureConversationParticipant(convID, ownerID, "member"); err != nil {
- return err
+ return 0, err
}
msgID := stableID("sms", sms.Address, sms.Timestamp.String(), fmt.Sprint(sms.Type), sms.Body)
return i.upsertTextMessage(sourceID, convID, msgID, "sms", senderID, recipientIDs, fromMe, sms.Timestamp, sms.Body, sms.Body, 0, sms)
}
-func (i *Importer) importMMS(sourceID int64, mms MMS) (int, error) {
+func (i *Importer) importMMS(sourceID int64, mms MMS) (int64, int, error) {
ownerID, err := i.participantID(i.opts.OwnerPhone, "Me")
if err != nil {
- return 0, err
+ return 0, 0, err
}
participantIDs, senderID, recipientIDs, err := i.mmsParticipants(mms, ownerID)
if err != nil {
- return 0, err
+ return 0, 0, err
}
// Drafts belong to the owner — see the matching note in importSMS.
fromMe := mms.MessageBox == MMSBoxSent || mms.MessageBox == MMSBoxOutbox || mms.MessageBox == MMSBoxDraft
convID, err := i.ensureConversation(sourceID, textConversationKey(participantIDs), mms.ContactName.String)
if err != nil {
- return 0, err
+ return 0, 0, err
}
for _, participantID := range participantIDs {
if err := i.store.EnsureConversationParticipant(convID, participantID, "member"); err != nil {
- return 0, err
+ return 0, 0, err
}
}
body := mmsText(mms)
@@ -183,21 +191,26 @@ func (i *Importer) importMMS(sourceID int64, mms MMS) (int, error) {
}
msgID := stableID("mms", srcIDPart, mms.Timestamp.String(), sortedKey(participantIDs))
attachmentCount := countImportableAttachments(mms, i.opts.IncludeAttachments)
- if err := i.upsertTextMessage(sourceID, convID, msgID, "mms", senderID, recipientIDs, fromMe, mms.Timestamp, body, mms.Subject.String, attachmentCount, mms); err != nil {
- return 0, err
+ messageID, err := i.upsertTextMessage(sourceID, convID, msgID, "mms", senderID, recipientIDs, fromMe, mms.Timestamp, body, mms.Subject.String, attachmentCount, mms)
+ if err != nil {
+ return 0, 0, err
}
- return i.importMMSAttachments(sourceID, msgID, mms)
+ imported, err := i.importMMSAttachments(sourceID, msgID, mms)
+ if err != nil {
+ return messageID, imported, err
+ }
+ return messageID, imported, nil
}
-func (i *Importer) importCall(sourceID int64, call Call) error {
+func (i *Importer) importCall(sourceID int64, call Call) (int64, error) {
remoteAddress := callParticipantAddress(call)
remoteID, err := i.participantID(remoteAddress, call.ContactName.String)
if err != nil {
- return err
+ return 0, err
}
ownerID, err := i.participantID(i.opts.OwnerPhone, "Me")
if err != nil {
- return err
+ return 0, err
}
fromMe := call.Type == CallOutgoing
senderID := remoteID
@@ -208,20 +221,20 @@ func (i *Importer) importCall(sourceID int64, call Call) error {
}
convID, err := i.ensureConversation(sourceID, "calls:"+canonicalAddress(remoteAddress), call.ContactName.String)
if err != nil {
- return err
+ return 0, err
}
if err := i.store.EnsureConversationParticipant(convID, remoteID, "member"); err != nil {
- return err
+ return 0, err
}
if err := i.store.EnsureConversationParticipant(convID, ownerID, "member"); err != nil {
- return err
+ return 0, err
}
body := fmt.Sprintf("Call %s, %d seconds", callTypeLabel(call.Type), call.DurationSeconds)
msgID := stableID("call", remoteAddress, call.Timestamp.String(), fmt.Sprint(call.Type), strconv.Itoa(call.DurationSeconds))
return i.upsertTextMessage(sourceID, convID, msgID, "synctech_sms_call", senderID, recipientIDs, fromMe, call.Timestamp, body, body, 0, call)
}
-func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, messageType string, senderID int64, recipientIDs []int64, fromMe bool, sentAt time.Time, body, subject string, attachmentCount int, raw any) error {
+func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, messageType string, senderID int64, recipientIDs []int64, fromMe bool, sentAt time.Time, body, subject string, attachmentCount int, raw any) (int64, error) {
msgID, err := i.store.UpsertMessage(&store.Message{
ConversationID: convID,
SourceID: sourceID,
@@ -237,33 +250,33 @@ func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, me
AttachmentCount: attachmentCount,
})
if err != nil {
- return fmt.Errorf("upsert message: %w", err)
+ return 0, fmt.Errorf("upsert message: %w", err)
}
if body != "" {
if err := i.store.UpsertMessageBody(msgID, sql.NullString{String: body, Valid: true}, sql.NullString{}); err != nil {
- return fmt.Errorf("upsert body: %w", err)
+ return 0, fmt.Errorf("upsert body: %w", err)
}
}
rawJSON, err := json.Marshal(raw)
if err != nil {
- return fmt.Errorf("marshal raw record: %w", err)
+ return 0, fmt.Errorf("marshal raw record: %w", err)
}
if err := i.store.UpsertMessageRawWithFormat(msgID, rawJSON, RawFormat); err != nil {
- return fmt.Errorf("upsert raw record: %w", err)
+ return 0, fmt.Errorf("upsert raw record: %w", err)
}
if err := i.store.ReplaceMessageRecipients(msgID, "from", []int64{senderID}, []string{""}); err != nil {
- return fmt.Errorf("replace from recipient: %w", err)
+ return 0, fmt.Errorf("replace from recipient: %w", err)
}
if err := i.store.ReplaceMessageRecipients(msgID, "to", recipientIDs, blankNames(len(recipientIDs))); err != nil {
- return fmt.Errorf("replace to recipient: %w", err)
+ return 0, fmt.Errorf("replace to recipient: %w", err)
}
if err := i.store.UpsertFTS(msgID, subject, body, "", "", ""); err != nil {
- // FTS is an index, not data: a failure to populate it must never abort
- // the import. Warn and continue, matching the other UpsertFTS callers
- // (sync.go, importer/ingest.go, fbmessenger, whatsapp). [C2]
- slog.Warn("failed to upsert FTS", "message", msgID, "error", err)
+ slog.Warn("failed to update Synctech message FTS index",
+ "message_id", msgID,
+ "message_type", messageType,
+ "error", err)
}
- return nil
+ return msgID, nil
}
func (i *Importer) participantID(address, displayName string) (int64, error) {
diff --git a/internal/synctechsms/importer_test.go b/internal/synctechsms/importer_test.go
index 769ebefe..ac6a1f1d 100644
--- a/internal/synctechsms/importer_test.go
+++ b/internal/synctechsms/importer_test.go
@@ -7,6 +7,7 @@ import (
assertpkg "github.com/stretchr/testify/assert"
requirepkg "github.com/stretchr/testify/require"
"go.kenn.io/msgvault/internal/store"
+ "go.kenn.io/msgvault/internal/testutil"
"go.kenn.io/msgvault/internal/testutil/storetest"
)
@@ -84,6 +85,60 @@ func TestImporterImportsCallWithBlankNumber(t *testing.T) {
assertMessageCount(t, f.Store, 1)
}
+// TestImporterContinuesWhenFTSUpsertFails drops the messages_fts table to
+// force the FTS upsert to fail, then checks that ImportPath still succeeds
+// and that the SMS is both counted in the summary and stored.
+func TestImporterContinuesWhenFTSUpsertFails(t *testing.T) {
+ testutil.SkipIfPostgres(t, "drops the SQLite FTS virtual table to force UpsertFTS failure")
+ f := storetest.New(t)
+ _, err := f.Store.DB().Exec(`DROP TABLE messages_fts`)
+ requirepkg.NoError(t, err, "drop messages_fts")
+ dir := t.TempDir()
+ // NOTE(review): the XML fixture body below is empty in this copy of the
+ // patch, presumably stripped in transit; confirm it against the
+ // repository before relying on the SMSImported == 1 assertion.
+ writeFile(t, filepath.Join(dir, "messages.xml"), `
+
+`)
+
+ imp := NewImporter(f.Store, ImportOptions{
+ OwnerPhone: "+15550000001",
+ IncludeSMS: true,
+ })
+ summary, err := imp.ImportPath(dir)
+ requirepkg.NoError(t, err, "ImportPath should tolerate FTS indexing failure")
+ assertpkg.Equal(t, 1, summary.SMSImported, "summary = %#v", summary)
+ assertMessageCount(t, f.Store, 1)
+}
+
+// TestImporterReturnsMMSMessageIDWhenAttachmentImportFails imports an MMS
+// with MaxAttachmentBytes set to 1 and expects an "exceeds maximum size"
+// error, while the message row is still stored and its ID is still reported
+// in summary.MessageIDs.
+func TestImporterReturnsMMSMessageIDWhenAttachmentImportFails(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ f := storetest.New(t)
+ dir := t.TempDir()
+ // NOTE(review): the XML fixture body below is empty in this copy of the
+ // patch, presumably stripped in transit; confirm it against the
+ // repository.
+ writeFile(t, filepath.Join(dir, "messages.xml"), `
+
+
+
+
+
+
+
+
+
+
+`)
+
+ imp := NewImporter(f.Store, ImportOptions{
+ OwnerPhone: "+15550000001",
+ AttachmentsDir: filepath.Join(dir, "attachments"),
+ MaxAttachmentBytes: 1,
+ IncludeMMS: true,
+ IncludeAttachments: true,
+ })
+ summary, err := imp.ImportPath(dir)
+ require.Error(err, "ImportPath")
+ assert.Contains(err.Error(), "exceeds maximum size", "ImportPath error")
+
+ // The import failed, yet exactly one message row and one ID must remain.
+ assertMessageCount(t, f.Store, 1)
+ require.Len(summary.MessageIDs, 1, "summary message IDs")
+ assert.Positive(summary.MessageIDs[0], "summary message id")
+}
+
// TestImporterMarksDraftsAsFromMe guards against regressing the
// draft-handling fix. SMSTypeDraft and MMSBoxDraft are owner-authored
// messages; treating them as incoming hides them on the wrong side of
diff --git a/internal/synctechsms/types.go b/internal/synctechsms/types.go
index c2d5d8eb..1ac42744 100644
--- a/internal/synctechsms/types.go
+++ b/internal/synctechsms/types.go
@@ -160,4 +160,5 @@ type ImportSummary struct {
MMSImported int
CallsImported int
AttachmentsImported int
+ MessageIDs []int64
}
diff --git a/internal/vector/backend.go b/internal/vector/backend.go
index 97ec05c8..2d8a2cb4 100644
--- a/internal/vector/backend.go
+++ b/internal/vector/backend.go
@@ -181,18 +181,18 @@ type Backend interface {
Delete(ctx context.Context, gen GenerationID, messageIDs []int64) error
Stats(ctx context.Context, gen GenerationID) (Stats, error)
- // EmbeddedMessageCount reports how many distinct LIVE, stamped
+ // EmbeddedMessageCount reports how many distinct in-scope LIVE, stamped
// (embed_gen == gen) messages actually have at least one embedding row
// for gen. This is the "embedded" leg of the coverage readout (live /
// embedded / blank / missing). It lives on the backend because the
// embeddings table is in vectors.db on SQLite (and the main DB on PG);
- // only the backend holds that handle. The live+stamped intersection is
- // REQUIRED for the coverage invariant to hold: SQLite intersects the
- // vectors.db embedding ids against a live+stamped query on the main DB
- // (cross-DB json_each, mirroring dropDeletedFromSource), while PostgreSQL
- // uses a single JOIN to messages. Distinct from Stats.EmbeddingCount only
- // in intent: this is the dedicated coverage helper and never folds the
- // aggregate (gen == 0) path.
+ // only the backend holds that handle. The live+stamped+scope intersection
+ // is REQUIRED for the coverage invariant to hold: SQLite intersects the
+ // vectors.db embedding ids against an in-scope live+stamped query on the
+ // main DB (cross-DB json_each, mirroring dropDeletedFromSource), while
+ // PostgreSQL uses a single JOIN to messages. Distinct from
+ // Stats.EmbeddingCount only in intent: this is the dedicated coverage
+ // helper and never folds the aggregate (gen == 0) path.
EmbeddedMessageCount(ctx context.Context, gen GenerationID) (int64, error)
// LoadVector returns the embedding for a specific message in the
diff --git a/internal/vector/build_scope.go b/internal/vector/build_scope.go
new file mode 100644
index 00000000..6b53cdbb
--- /dev/null
+++ b/internal/vector/build_scope.go
@@ -0,0 +1,65 @@
+package vector
+
+import (
+ "slices"
+ "sort"
+ "strings"
+)
+
+// BuildScope limits which messages are eligible for an embedding
+// generation. A zero-value scope means the full corpus.
+type BuildScope struct {
+ // MessageTypes lists the message types that are in scope. NewBuildScope
+ // stores them lowercased, trimmed, de-duplicated, and sorted.
+ MessageTypes []string
+}
+
+// NewBuildScope builds the canonical form of a scope from raw message type
+// names. Every name is lowercased and trimmed; blank names and repeats are
+// discarded, and the remainder is sorted so that fingerprints and SQL
+// bindings derived from the scope are deterministic.
+func NewBuildScope(messageTypes []string) BuildScope {
+	normalized := make([]string, 0, len(messageTypes))
+	known := make(map[string]struct{}, len(messageTypes))
+	for _, raw := range messageTypes {
+		name := strings.TrimSpace(strings.ToLower(raw))
+		// Keep only non-blank names that have not been collected yet.
+		if _, dup := known[name]; dup || name == "" {
+			continue
+		}
+		known[name] = struct{}{}
+		normalized = append(normalized, name)
+	}
+	sort.Strings(normalized)
+	return BuildScope{MessageTypes: normalized}
+}
+
+// IsEmpty reports whether the scope lists no message types.
+func (s BuildScope) IsEmpty() bool {
+ return len(s.MessageTypes) == 0
+}
+
+// Fingerprint returns "" for an empty scope and otherwise "mt-" followed by
+// the scope's message types joined with commas.
+func (s BuildScope) Fingerprint() string {
+	if !s.IsEmpty() {
+		return "mt-" + strings.Join(s.MessageTypes, ",")
+	}
+	return ""
+}
+
+// ContainsMessageType reports whether messageType, once lowercased and
+// trimmed, is one of the scope's message types.
+func (s BuildScope) ContainsMessageType(messageType string) bool {
+	wanted := strings.TrimSpace(strings.ToLower(messageType))
+	return slices.Index(s.MessageTypes, wanted) >= 0
+}
+
+// AllowsMessageTypes reports whether a query restricted to messageTypes
+// stays inside the scope. An empty scope allows everything; a non-empty
+// scope rejects an empty list and any list naming a type it does not contain.
+func (s BuildScope) AllowsMessageTypes(messageTypes []string) bool {
+	switch {
+	case s.IsEmpty():
+		return true
+	case len(messageTypes) == 0:
+		return false
+	}
+	outOfScope := func(typ string) bool { return !s.ContainsMessageType(typ) }
+	return !slices.ContainsFunc(messageTypes, outOfScope)
+}
diff --git a/internal/vector/config.go b/internal/vector/config.go
index 5dd3ebd8..23ca8448 100644
--- a/internal/vector/config.go
+++ b/internal/vector/config.go
@@ -213,7 +213,8 @@ type EmbedConfig struct {
// stragglers from repair-encoding resets, transient errors, or crashes)
// in addition to the per-tick incremental scan. Zero uses the EmbedJob
// default (24h); a negative value disables the auto-backstop.
- BackstopInterval time.Duration `toml:"backstop_interval"`
+ BackstopInterval time.Duration `toml:"backstop_interval"`
+ Scope EmbedScopeConfig `toml:"scope"`
}
// EmbedScheduleConfig controls when the embed worker runs on its own
@@ -223,6 +224,16 @@ type EmbedScheduleConfig struct {
RunAfterSync bool `toml:"run_after_sync"` // trigger after each successful sync
}
+// EmbedScopeConfig limits which messages are included in newly-built
+// embedding generations. The zero value means the full corpus.
+type EmbedScopeConfig struct {
+ // MessageTypes is read from the message_types TOML key; BuildScope
+ // normalizes the entries before they are used.
+ MessageTypes []string `toml:"message_types"`
+}
+
+// BuildScope converts the configured message types into a normalized
+// BuildScope via NewBuildScope.
+func (s EmbedScopeConfig) BuildScope() BuildScope {
+ return NewBuildScope(s.MessageTypes)
+}
+
// Fingerprint returns the ":" identifier for the
// embedding endpoint half of the policy (§6.7 of the spec). Use
// Config.GenerationFingerprint when storing or comparing index
@@ -256,8 +267,12 @@ func (e EmbeddingsConfig) Fingerprint() string {
// in, an active generation built under the old single-vector policy
// would silently accept new chunked entries from an upgraded worker.
func (c *Config) GenerationFingerprint() string {
- return fmt.Sprintf("%s:%s:c%d:e%d",
+ fp := fmt.Sprintf("%s:%s:c%d:e%d",
c.Embeddings.Fingerprint(), c.Preprocess.Fingerprint(), c.Embeddings.MaxInputChars, embedPolicyVersion)
+ // An unscoped config keeps the fingerprint as is; a scoped config appends
+ // ":s" followed by the scope fingerprint.
+ scopeFP := c.Embed.Scope.BuildScope().Fingerprint()
+ if scopeFP == "" {
+ return fp
+ }
+ return fp + ":s" + scopeFP
}
// Validate returns a descriptive error if the config is unusable.
diff --git a/internal/vector/config_test.go b/internal/vector/config_test.go
index cba34145..dd76eee4 100644
--- a/internal/vector/config_test.go
+++ b/internal/vector/config_test.go
@@ -433,3 +433,18 @@ func TestConfig_GenerationFingerprint_IncludesEmbedPolicyVersion(t *testing.T) {
suffix := fmt.Sprintf(":e%d", embedPolicyVersion)
assertpkg.True(t, strings.HasSuffix(got, suffix), "GenerationFingerprint() = %q, want suffix %q", got, suffix)
}
+
+// TestConfig_GenerationFingerprint_IncludesEmbedScope checks that adding an
+// embed scope changes the generation fingerprint and that the scope part is
+// normalized (lowercased, de-duplicated, sorted).
+func TestConfig_GenerationFingerprint_IncludesEmbedScope(t *testing.T) {
+	unscoped := Config{
+		Embeddings: EmbeddingsConfig{Model: "m", Dimension: 8, MaxInputChars: 6000},
+	}
+	scoped := unscoped
+	scoped.Embed.Scope.MessageTypes = []string{"MMS", "sms", "sms"}
+
+	baseline := unscoped.GenerationFingerprint()
+	got := scoped.GenerationFingerprint()
+
+	assertpkg.NotEqual(t, baseline, got,
+		"GenerationFingerprint should change when embedding is scoped")
+	assertpkg.Contains(t, got, ":smt-mms,sms",
+		"scope fingerprint should normalize and sort message types")
+}
diff --git a/internal/vector/embed/worker.go b/internal/vector/embed/worker.go
index 7bb0943e..e1a180d0 100644
--- a/internal/vector/embed/worker.go
+++ b/internal/vector/embed/worker.go
@@ -64,6 +64,7 @@ type WorkerDeps struct {
Store WorkStore
Client EmbeddingClient
Preprocess PreprocessConfig
+ BuildScope vector.BuildScope
MaxInputChars int
BatchSize int
// beforeSkipStamp is a test hook for read-to-stamp race coverage.
@@ -309,7 +310,7 @@ func (w *Worker) run(ctx context.Context, gen vector.GenerationID, backstop bool
return res, fmt.Errorf("RunOnce: %w", err)
}
batchStart := time.Now()
- ids, err := w.deps.Store.ScanForEmbedding(ctx, int64(gen), afterID, w.deps.BatchSize)
+ ids, err := w.scanForEmbedding(ctx, int64(gen), afterID)
if err != nil {
return res, fmt.Errorf("scan for embedding: %w", err)
}
@@ -552,6 +553,20 @@ func (w *Worker) run(ctx context.Context, gen vector.GenerationID, backstop bool
}
}
+// scanForEmbedding returns the next batch of message IDs after afterID for
+// gen. With an empty build scope it delegates to the store's
+// ScanForEmbedding; with a non-empty scope the store must also provide
+// ScanForEmbeddingScoped, otherwise an error is returned.
+func (w *Worker) scanForEmbedding(ctx context.Context, gen int64, afterID int64) ([]int64, error) {
+	// Optional store capability that is only required for scoped builds.
+	type scopedScanner interface {
+		ScanForEmbeddingScoped(ctx context.Context, target int64, afterID int64, limit int, messageTypes []string) ([]int64, error)
+	}
+
+	ws, limit := w.deps.Store, w.deps.BatchSize
+	scope := vector.NewBuildScope(w.deps.BuildScope.MessageTypes)
+	if scope.IsEmpty() {
+		return ws.ScanForEmbedding(ctx, gen, afterID, limit)
+	}
+	if scanner, ok := ws.(scopedScanner); ok {
+		return scanner.ScanForEmbeddingScoped(ctx, gen, afterID, limit, scope.MessageTypes)
+	}
+	return nil, errors.New("work store does not support scoped embedding scans")
+}
+
// advanceWatermark persists the per-gen forward-scan cursor to id after a
// batch made forward progress. The backstop never persists (it scans from
// 0 by design and must not push the optimistic watermark backward or
diff --git a/internal/vector/errors.go b/internal/vector/errors.go
index b4d5a5ec..03a7ba3e 100644
--- a/internal/vector/errors.go
+++ b/internal/vector/errors.go
@@ -64,4 +64,10 @@ var (
// "transient backend slow" response so clients can retry instead
// of treating it as a permanent failure.
ErrEmbeddingTimeout = errors.New("embedding request timed out")
+
+ // ErrIndexScopeMismatch is returned when a scoped embedding index
+ // is used without an equivalent structured filter. For example, an
+ // index built only for message_type=sms must not answer an unscoped
+ // vector query over email + SMS.
+ ErrIndexScopeMismatch = errors.New("index scope mismatch")
)
diff --git a/internal/vector/hybrid/engine.go b/internal/vector/hybrid/engine.go
index f264c1ca..412f0a46 100644
--- a/internal/vector/hybrid/engine.go
+++ b/internal/vector/hybrid/engine.go
@@ -63,7 +63,8 @@ type Config struct {
// participant/label lookup SQL that BuildFilter runs against mainDB.
// Pass PostgreSQLDialect.Rebind on PG (pgx rejects bare ?); leave nil
// (or SQLiteDialect.Rebind, which is identity) on SQLite.
- Rebind func(string) string
+ Rebind func(string) string
+ BuildScope vector.BuildScope
}
// Engine orchestrates the generation check, query embedding, and fusion
@@ -111,6 +112,9 @@ func (e *Engine) Search(ctx context.Context, req SearchRequest) ([]vector.FusedH
if err != nil {
return nil, ResultMeta{}, err
}
+ if err := e.validateBuildScope(req.Filter); err != nil {
+ return nil, ResultMeta{}, err
+ }
if req.FreeText == "" {
return nil, ResultMeta{}, errors.New("empty query")
@@ -192,6 +196,24 @@ func (e *Engine) Search(ctx context.Context, req SearchRequest) ([]vector.FusedH
}, nil
}
+// validateBuildScope rejects searches that a scoped index cannot answer.
+// It returns nil when the engine has no build scope or when every message
+// type in the filter is inside the scope; otherwise it returns an error
+// wrapping vector.ErrIndexScopeMismatch.
+func (e *Engine) validateBuildScope(filter vector.Filter) error {
+	scope := vector.NewBuildScope(e.cfg.BuildScope.MessageTypes)
+	if scope.IsEmpty() || scope.AllowsMessageTypes(filter.MessageTypes) {
+		return nil
+	}
+
+	indexed := strings.Join(scope.MessageTypes, ",")
+	// A missing filter gets its own message telling the caller what to add.
+	if len(filter.MessageTypes) == 0 {
+		return fmt.Errorf("%w: index is scoped to message_type=%s; add a matching message_type filter",
+			vector.ErrIndexScopeMismatch, indexed)
+	}
+	requested := strings.Join(vector.NewBuildScope(filter.MessageTypes).MessageTypes, ",")
+	return fmt.Errorf("%w: index is scoped to message_type=%s, query requested message_type=%s",
+		vector.ErrIndexScopeMismatch, indexed, requested)
+}
+
// vectorHitsToFused wraps pure-vector hits in the FusedHit schema.
// BM25Score and RRFScore are both set to math.NaN(): "not present in
// this signal." Pure vector mode never applies Reciprocal Rank Fusion
diff --git a/internal/vector/hybrid/engine_test.go b/internal/vector/hybrid/engine_test.go
index 3eaef6da..8e424c18 100644
--- a/internal/vector/hybrid/engine_test.go
+++ b/internal/vector/hybrid/engine_test.go
@@ -165,6 +165,35 @@ func TestEngine_Hybrid_HappyPath(t *testing.T) {
assert.Equal(len(results), meta.ReturnedCount)
}
+// TestEngine_ScopedIndexRequiresMatchingMessageTypeFilter scopes the engine
+// to sms and mms, then checks that a search with no message_type filter or
+// with an out-of-scope one fails with ErrIndexScopeMismatch, while an
+// in-scope filter succeeds.
+func TestEngine_ScopedIndexRequiresMatchingMessageTypeFilter(t *testing.T) {
+	ctx := context.Background()
+	f := newEngineFixture(t)
+	f.Engine.cfg.BuildScope = vector.NewBuildScope([]string{"sms", "mms"})
+
+	// search runs the same vector query each time, varying only the filter.
+	search := func(filter vector.Filter) error {
+		_, _, err := f.Engine.Search(ctx, SearchRequest{
+			Mode:     ModeVector,
+			FreeText: "lunch",
+			Limit:    5,
+			Filter:   filter,
+		})
+		return err
+	}
+
+	requirepkg.ErrorIs(t, search(vector.Filter{}), vector.ErrIndexScopeMismatch)
+	requirepkg.ErrorIs(t, search(vector.Filter{MessageTypes: []string{"email"}}), vector.ErrIndexScopeMismatch)
+	requirepkg.NoError(t, search(vector.Filter{MessageTypes: []string{"sms"}}))
+}
+
// TestFTSTerms covers the FreeText → dialect-neutral term-slice
// tokenizer directly (no DB needed). FreeText is split on whitespace
// and terms the FTS5/tsquery tokenizers would drop entirely
diff --git a/internal/vector/pgvector/backend.go b/internal/vector/pgvector/backend.go
index fb0782e6..be9a9a23 100644
--- a/internal/vector/pgvector/backend.go
+++ b/internal/vector/pgvector/backend.go
@@ -39,6 +39,9 @@ type Options struct {
// per-dimension HNSW index on first migration. Optional; if zero
// the index is created on first CreateGeneration.
Dimension int
+ // BuildScope limits generation coverage to matching messages. Empty
+ // means the full corpus.
+ BuildScope vector.BuildScope
// SkipMigrate suppresses the privileged CREATE EXTENSION + full
// migrate. A WRITABLE open still applies the (extension-less) schema so
// the one-time upgrade lands — read-only-ness is now signalled by
@@ -74,7 +77,8 @@ type Options struct {
// with the pgvector extension. The same *sql.DB also serves the main
// msgvault schema (messages, message_recipients, message_labels).
type Backend struct {
- db *sql.DB
+ db *sql.DB
+ // scope is the normalized build scope that Open derives from
+ // Options.BuildScope; an empty scope means no message_type restriction.
+ scope vector.BuildScope
}
// Open verifies the database is reachable, applies the embedding schema
@@ -85,7 +89,10 @@ func Open(ctx context.Context, opts Options) (*Backend, error) {
if opts.DB == nil {
return nil, errors.New("pgvector.Open: Options.DB is required")
}
- b := &Backend{db: opts.DB}
+ b := &Backend{
+ db: opts.DB,
+ scope: vector.NewBuildScope(opts.BuildScope.MessageTypes),
+ }
if !opts.SkipMigrate {
// serve / build / search: full migrate incl. CREATE EXTENSION (the
// extension step is gated by SkipExtension for managed PG). The eager
@@ -252,12 +259,22 @@ func isUniqueViolation(err error) bool {
// (embed_gen IS NULL OR embed_gen <> gen). Built once and reused by
// ActivateGeneration (in-tx, single-DB on PG) and Stats. The $N ordinal
// of the generation id is supplied by the caller.
-func missingForGenExistsClause(genArg string) string {
+// For a backend with a non-empty scope the predicate is further limited to
+// message_type IN (...), using placeholders numbered from firstScopeArg.
+// The matching values are returned as the second result so the caller can
+// append them to its argument list at that position.
+func (b *Backend) missingForGenExistsClause(genArg string, firstScopeArg int) (string, []any) {
+ where := store.LiveMessagesWhere("", true)
+ args := make([]any, 0, len(b.scope.MessageTypes))
+ if !b.scope.IsEmpty() {
+ placeholders := make([]string, len(b.scope.MessageTypes))
+ for i, typ := range b.scope.MessageTypes {
+ // One $N placeholder per scoped type, in the same order as args.
+ placeholders[i] = "$" + strconv.Itoa(firstScopeArg+i)
+ args = append(args, typ)
+ }
+ where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ","))
+ }
return fmt.Sprintf(`EXISTS (
SELECT 1 FROM messages
WHERE (embed_gen IS NULL OR embed_gen <> %s)
AND %s
- )`, genArg, store.LiveMessagesWhere("", true))
+ )`, genArg, where), args
}
// ActivateGeneration atomically retires the current active generation
@@ -321,18 +338,20 @@ func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationI
// phase, which scan-and-fill no longer has, so a legacy/crashed gen with
// seeded_at=NULL but full coverage must be activatable. Coverage
// (missing==0) is the real gate.
+ missingClause, missingArgs := b.missingForGenExistsClause("$3", 5)
+ args := append([]any{now, now, int64(gen), force}, missingArgs...)
res, err := tx.ExecContext(ctx,
`UPDATE index_generations
SET state = 'active', activated_at = $1, completed_at = COALESCE(completed_at, $2)
WHERE id = $3 AND state = 'building'
- AND ($4 OR NOT `+missingForGenExistsClause("$3")+`)`,
- now, now, int64(gen), force)
+ AND ($4 OR NOT `+missingClause+`)`,
+ args...)
if err != nil {
return fmt.Errorf("activate: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
- return activateGateError(ctx, tx, gen, force)
+ return b.activateGateError(ctx, tx, gen, force)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("commit activate generation %d: %w", gen, err)
@@ -346,7 +365,7 @@ func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationI
// also satisfies the coverage predicate (embed_gen <> gen is true for an
// unknown gen id), so checking coverage first would surface the misleading
// "messages needing embedding" error instead of the real lifecycle reason.
-func activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, force bool) error {
+func (b *Backend) activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, force bool) error {
var state vector.GenerationState
if err := tx.QueryRowContext(ctx,
`SELECT state FROM index_generations WHERE id = $1`, int64(gen)).Scan(&state); err != nil {
@@ -361,8 +380,10 @@ func activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID,
// Gen exists and is building, so the only remaining reason the gated
// promote affected zero rows is the coverage term.
var missing bool
+ missingClause, missingArgs := b.missingForGenExistsClause("$1", 2)
+ args := append([]any{int64(gen)}, missingArgs...)
if err := tx.QueryRowContext(ctx,
- `SELECT `+missingForGenExistsClause("$1"), int64(gen)).Scan(&missing); err != nil {
+ `SELECT `+missingClause, args...).Scan(&missing); err != nil {
return fmt.Errorf("check coverage for generation %d: %w", gen, err)
}
if missing && !force {
@@ -1206,9 +1227,9 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St
return s, nil
}
-// EmbeddedMessageCount returns the number of LIVE messages that are
-// stamped for gen (embed_gen = gen) AND actually have at least one vector
-// for the generation. Used by the coverage readout to split stamped
+// EmbeddedMessageCount returns the number of in-scope LIVE messages that
+// are stamped for gen (embed_gen = gen) AND actually have at least one
+// vector for the generation. Used by the coverage readout to split stamped
// messages into embedded vs blank. Counts distinct messages (not chunk
// rows) so a long, multi-chunk message counts once, matching the
// EmbeddingCount semantic elsewhere.
@@ -1226,15 +1247,25 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St
// live intersection is a single JOIN against messages, mirroring
// store.LiveMessagesWhere's predicate.
func (b *Backend) EmbeddedMessageCount(ctx context.Context, gen vector.GenerationID) (int64, error) {
+ where := `e.generation_id = $1
+ AND m.embed_gen = $1
+ AND ` + store.LiveMessagesWhere("m", true)
+ args := []any{int64(gen)}
+ if !b.scope.IsEmpty() {
+ placeholders := make([]string, len(b.scope.MessageTypes))
+ for i, typ := range b.scope.MessageTypes {
+ placeholders[i] = "$" + strconv.Itoa(2+i)
+ args = append(args, typ)
+ }
+ where += fmt.Sprintf(" AND m.message_type IN (%s)", strings.Join(placeholders, ","))
+ }
var n int64
if err := b.db.QueryRowContext(ctx,
`SELECT COUNT(DISTINCT e.message_id)
FROM embeddings e
JOIN messages m ON m.id = e.message_id
- WHERE e.generation_id = $1
- AND m.embed_gen = $1
- AND `+store.LiveMessagesWhere("m", true),
- int64(gen)).Scan(&n); err != nil {
+ WHERE `+where,
+ args...).Scan(&n); err != nil {
return 0, fmt.Errorf("count embedded messages: %w", err)
}
return n, nil
diff --git a/internal/vector/pgvector/backend_filter_test.go b/internal/vector/pgvector/backend_filter_test.go
index dfebe205..2dda1024 100644
--- a/internal/vector/pgvector/backend_filter_test.go
+++ b/internal/vector/pgvector/backend_filter_test.go
@@ -3,6 +3,7 @@
package pgvector
import (
+ "fmt"
"testing"
"time"
@@ -11,6 +12,20 @@ import (
"go.kenn.io/msgvault/internal/vector"
)
+// TestBuildPGFilterClausesMessageTypes checks that a MessageTypes filter
+// yields a single ANY($1::text[]) clause bound to one array-literal argument.
+func TestBuildPGFilterClausesMessageTypes(t *testing.T) {
+	var bound []any
+	// placeholder records each bound value and hands back its $N ordinal.
+	placeholder := func(value any) string {
+		bound = append(bound, value)
+		return fmt.Sprintf("$%d", len(bound))
+	}
+	filter := vector.Filter{MessageTypes: []string{"sms", "mms"}}
+
+	got := buildPGFilterClauses(filter, placeholder)
+
+	require.Len(t, got, 1)
+	assert.Equal(t, "m.message_type = ANY($1::text[])", got[0])
+	assert.Equal(t, []any{`{"sms","mms"}`}, bound)
+}
+
func TestBackendSearchStructuredFilters(t *testing.T) {
b, ctx, db := newBackendForTest(t)
gen := seedAndEmbed(t, b, db, map[int64][]float32{
@@ -103,6 +118,11 @@ func TestBackendSearchStructuredFilters(t *testing.T) {
filter: vector.Filter{LargerThan: &largerThan, SmallerThan: &smallerThan},
want: []int64{2},
},
+ {
+ name: "message type",
+ filter: vector.Filter{MessageTypes: []string{"sms"}},
+ want: []int64{2},
+ },
{
name: "no match sentinel",
filter: vector.Filter{SenderGroups: [][]int64{{-1}}},
@@ -129,6 +149,21 @@ func TestBackendSearchStructuredFilters(t *testing.T) {
}
}
+// TestBackendSearchMessageTypeFilter checks that Search only returns hits
+// whose message_type is named in the filter: message 1 is marked email and
+// must be excluded, messages 2 and 3 are marked sms and must be returned.
+func TestBackendSearchMessageTypeFilter(t *testing.T) {
+	b, ctx, db := newBackendForTest(t)
+	gen := seedAndEmbed(t, b, db, map[int64][]float32{
+		1: unitVec(4, 0),
+		2: unitVec(4, 1),
+		3: unitVec(4, 2),
+	})
+	_, err := db.ExecContext(ctx, `UPDATE messages SET message_type = CASE id WHEN 1 THEN 'email' ELSE 'sms' END`)
+	require.NoError(t, err, "seed message_type")
+
+	hits, err := b.Search(ctx, gen, unitVec(4, 0), 10, vector.Filter{MessageTypes: []string{"sms"}})
+	require.NoError(t, err, "Search")
+	// Compare as a set. Messages 2 and 3 presumably sit at the same distance
+	// from the query (assumes unitVec yields one-hot vectors — TODO confirm),
+	// so their relative rank is a tie-break detail that a filter test should
+	// not pin down.
+	assert.ElementsMatch(t, []int64{2, 3}, hitMessageIDs(hits))
+}
+
func hitMessageIDs(hits []vector.Hit) []int64 {
out := make([]int64, len(hits))
for i, h := range hits {
diff --git a/internal/vector/pgvector/coverage_test.go b/internal/vector/pgvector/coverage_test.go
index 2293b8d9..ff177ffd 100644
--- a/internal/vector/pgvector/coverage_test.go
+++ b/internal/vector/pgvector/coverage_test.go
@@ -90,3 +90,66 @@ func TestCoverageSplit_EmbeddedBlankMissing(t *testing.T) {
assert.Equal(live, embeddedCount+blank+missing,
"invariant: live == embedded + blank + missing")
}
+
+// TestCoverageSplit_ScopedEmbeddedHoldsInvariant opens a backend scoped to
+// sms, seeds one email and one sms message, embeds and stamps both for a
+// new generation, and checks that only the sms message counts toward
+// coverage so that live == embedded + blank + missing holds within scope.
+func TestCoverageSplit_ScopedEmbeddedHoldsInvariant(t *testing.T) {
+ require := requirepkg.New(t)
+ assert := assertpkg.New(t)
+ ctx := context.Background()
+
+ db := openPGTestDB(t) // skips when MSGVAULT_TEST_DB is unset
+ b, err := Open(ctx, Options{
+ DB: db,
+ Dimension: 8,
+ BuildScope: vector.NewBuildScope([]string{"sms"}),
+ })
+ require.NoError(err, "Open backend")
+ t.Cleanup(func() { _ = b.Close() })
+
+ // Upsert the two rows and reset their deletion and embed_gen columns so
+ // earlier state for ids 1 and 2 cannot leak into the counts.
+ _, err = db.ExecContext(ctx, `
+ INSERT INTO messages (id, message_type) VALUES
+ (1, 'email'),
+ (2, 'sms')
+ ON CONFLICT (id) DO UPDATE SET
+ message_type = EXCLUDED.message_type,
+ deleted_at = NULL,
+ deleted_from_source_at = NULL,
+ embed_gen = NULL`)
+ require.NoError(err, "seed scoped messages")
+
+ gen, err := b.CreateGeneration(ctx, "test-model", 8, "fp")
+ require.NoError(err, "CreateGeneration")
+ // Both messages get a vector and a stamp, including the out-of-scope
+ // email, so the scoped count below has something to exclude.
+ require.NoError(b.Upsert(ctx, gen, []vector.Chunk{
+ {MessageID: 1, Vector: []float32{1, 0, 0, 0, 0, 0, 0, 0}},
+ {MessageID: 2, Vector: []float32{0, 1, 0, 0, 0, 0, 0, 0}},
+ }), "Upsert embedded vectors")
+ _, err = db.ExecContext(ctx, `UPDATE messages SET embed_gen = $1 WHERE id IN (1, 2)`, int64(gen))
+ require.NoError(err, "stamp embedded")
+
+ // live and stamped are counted directly in SQL, restricted to sms.
+ var live, stamped int64
+ require.NoError(db.QueryRowContext(ctx,
+ `SELECT COUNT(*) FROM messages
+ WHERE message_type = 'sms'
+ AND deleted_at IS NULL
+ AND deleted_from_source_at IS NULL`).Scan(&live),
+ "count scoped live")
+ require.NoError(db.QueryRowContext(ctx,
+ `SELECT COUNT(*) FROM messages
+ WHERE message_type = 'sms'
+ AND embed_gen = $1
+ AND deleted_at IS NULL
+ AND deleted_from_source_at IS NULL`,
+ int64(gen)).Scan(&stamped), "count scoped stamped")
+ missing := live - stamped
+
+ embeddedCount, err := b.EmbeddedMessageCount(ctx, gen)
+ require.NoError(err, "EmbeddedMessageCount")
+ // Clamp at zero so an over-count of embedded rows cannot go negative.
+ blank := max(stamped-embeddedCount, 0)
+
+ assert.Equal(int64(1), live, "only sms is in scope")
+ assert.Equal(int64(1), stamped, "only scoped stamped messages count")
+ assert.Equal(int64(1), embeddedCount, "out-of-scope email vector excluded")
+ assert.Equal(int64(0), blank)
+ assert.Equal(int64(0), missing)
+ assert.Equal(live, embeddedCount+blank+missing,
+ "invariant: live == embedded + blank + missing")
+}
diff --git a/internal/vector/pgvector/ext_stub.go b/internal/vector/pgvector/ext_stub.go
index f82e3130..05c0980c 100644
--- a/internal/vector/pgvector/ext_stub.go
+++ b/internal/vector/pgvector/ext_stub.go
@@ -27,6 +27,7 @@ type Options struct {
SkipMigrate bool
ReadOnly bool
SkipExtension bool
+ BuildScope vector.BuildScope
}
// Backend is a placeholder type so non-pgvector builds can compile
diff --git a/internal/vector/pgvector/fused_test.go b/internal/vector/pgvector/fused_test.go
index 325228de..b4bcb97e 100644
--- a/internal/vector/pgvector/fused_test.go
+++ b/internal/vector/pgvector/fused_test.go
@@ -148,6 +148,34 @@ func TestFusedSearch_FTSOnly(t *testing.T) {
}
}
+func TestFusedSearch_MessageTypeFilter(t *testing.T) {
+ f := newFusedFixture(t)
+ base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
+ f.seedMsg(t, 1, "lunch plan", "sms lunch details", 10, base, false)
+ f.seedMsg(t, 2, "lunch receipt", "email lunch details", 10, base.Add(time.Hour), false)
+ f.seedMsg(t, 3, "dinner plan", "sms dinner details", 10, base.Add(2*time.Hour), false)
+ _, err := f.db.ExecContext(f.ctx, `UPDATE messages SET message_type = CASE id WHEN 2 THEN 'email' ELSE 'sms' END`)
+ require.NoError(t, err, "seed message_type")
+ f.embedAll(t, map[int64][]float32{
+ 1: unitVec(4, 0),
+ 2: unitVec(4, 1),
+ 3: unitVec(4, 2),
+ })
+
+ hits, saturated, err := f.b.FusedSearch(f.ctx, vector.FusedRequest{
+ FTSTerms: []string{"lunch"},
+ Generation: f.gen,
+ KPerSignal: 10,
+ Limit: 10,
+ RRFK: 60,
+ Filter: vector.Filter{MessageTypes: []string{"sms"}},
+ })
+ require.NoError(t, err, "FusedSearch")
+ assert.False(t, saturated, "saturated should be false")
+ require.Len(t, hits, 1, "message_type filter should exclude email FTS hits; hits=%+v", hits)
+ assert.Equal(t, int64(1), hits[0].MessageID)
+}
+
func TestFusedSearch_ANNOnly(t *testing.T) {
f := seedThree(t)
hits, saturated, err := f.b.FusedSearch(f.ctx, vector.FusedRequest{
diff --git a/internal/vector/pgvector/parity_test.go b/internal/vector/pgvector/parity_test.go
index 532f7372..37aca933 100644
--- a/internal/vector/pgvector/parity_test.go
+++ b/internal/vector/pgvector/parity_test.go
@@ -69,6 +69,7 @@ func buildSqlitevecParity(t *testing.T, corpus []parityDoc) (*sqlitevec.Backend,
CREATE TABLE messages (
id INTEGER PRIMARY KEY,
subject TEXT,
+ message_type TEXT NOT NULL DEFAULT 'email',
source_id INTEGER,
sender_id INTEGER,
has_attachments INTEGER DEFAULT 0,
diff --git a/internal/vector/sqlitevec/backend.go b/internal/vector/sqlitevec/backend.go
index a38c19f0..72ee9142 100644
--- a/internal/vector/sqlitevec/backend.go
+++ b/internal/vector/sqlitevec/backend.go
@@ -28,10 +28,11 @@ var _ vector.Backend = (*Backend)(nil)
// Options configures how Open establishes a Backend.
type Options struct {
- Path string
- MainPath string // filesystem path to msgvault.db; required for FusedSearch
- Dimension int // default dimension for EnsureVectorTable at open
- MainDB *sql.DB // handle to the main msgvault.db
+ Path string
+ MainPath string // filesystem path to msgvault.db; required for FusedSearch
+ Dimension int // default dimension for EnsureVectorTable at open
+ MainDB *sql.DB // handle to the main msgvault.db
+ BuildScope vector.BuildScope // empty means full corpus
// ReadOnly indicates the main DB handle (MainDB) was opened read-only
// — e.g. the MCP server's store.OpenReadOnly (_query_only=true). When
// set, Open SKIPS BackfillEmbedGenForUpgrade, which would otherwise
@@ -50,6 +51,7 @@ type Backend struct {
path string // filesystem path to vectors.db
mainPath string // filesystem path to msgvault.db (for ATTACH)
dim int
+ scope vector.BuildScope
// readOnly is true when mainDB was opened read-only (MCP). The
// one-time upgrade backfill self-guards on it so it never writes
// through the read-only main handle. See Options.ReadOnly.
@@ -76,6 +78,7 @@ func Open(ctx context.Context, opts Options) (*Backend, error) {
path: opts.Path,
mainPath: opts.MainPath,
dim: opts.Dimension,
+ scope: vector.NewBuildScope(opts.BuildScope.MessageTypes),
readOnly: opts.ReadOnly,
}
// Orphaned-stamp reset (vectors.db-recreate safety): clear embed_gen for
@@ -268,24 +271,35 @@ func isUniqueConstraintErr(err error) bool {
// needs embedding for gen (embed_gen IS NULL OR embed_gen <> gen). This is
// the scan-and-fill coverage gate. On SQLite the messages live in the main
// DB while the generation lifecycle lives in vectors.db, so the gate
-// cannot be folded into the activation tx — ActivateGeneration runs this
-// Go pre-check against b.mainDB before its vectors.db tx (mirroring the
-// intentionally-non-atomic scheduler gate). The full-scan backstop covers
-// any TOCTOU window between this read and the flip.
+// cannot be folded into the activation tx.
func (b *Backend) hasMissingForGen(ctx context.Context, gen vector.GenerationID) (bool, error) {
var exists int
+ where, args := b.missingCoverageWhere(int64(gen)) // predicate + bind args; narrowed to the backend's build scope when one is set
err := b.mainDB.QueryRowContext(ctx,
`SELECT EXISTS (
SELECT 1 FROM messages
- WHERE (embed_gen IS NULL OR embed_gen <> ?)
- AND `+store.LiveMessagesWhere("", true)+`
- )`, int64(gen)).Scan(&exists)
+ WHERE `+where+`
+ )`, args...).Scan(&exists) // args: gen first, then one value per scoped message type
if err != nil {
return false, fmt.Errorf("check missing coverage for generation %d: %w", gen, err)
}
return exists == 1, nil
}
+// missingCoverageWhere returns the WHERE predicate, and its bind arguments in
+// placeholder order, matching messages that are not stamped for gen
+// (embed_gen IS NULL OR embed_gen <> gen) and that satisfy
+// store.LiveMessagesWhere. When the backend has a non-empty build scope the
+// predicate is additionally restricted to the scoped message types.
+func (b *Backend) missingCoverageWhere(gen int64) (string, []any) {
+	clause := "(embed_gen IS NULL OR embed_gen <> ?) AND " + store.LiveMessagesWhere("", true)
+	binds := []any{gen}
+	if b.scope.IsEmpty() {
+		return clause, binds
+	}
+
+	// One "?" per scoped type, with the type values appended after gen so the
+	// argument order matches the placeholder order.
+	types := b.scope.MessageTypes
+	marks := make([]string, 0, len(types))
+	for _, typ := range types {
+		marks = append(marks, "?")
+		binds = append(binds, typ)
+	}
+	clause += " AND message_type IN (" + strings.Join(marks, ",") + ")"
+	return clause, binds
+}
+
// ActivateGeneration atomically retires the current active generation
// (if any) and promotes `gen` to active.
func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationID, force bool) error {
@@ -1375,9 +1389,9 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St
return s, nil
}
-// EmbeddedMessageCount returns the number of LIVE messages that are
-// stamped for gen (embed_gen = gen) AND actually have at least one vector
-// for the generation. Used by the coverage readout to split stamped
+// EmbeddedMessageCount returns the number of in-scope LIVE messages that
+// are stamped for gen (embed_gen = gen) AND actually have at least one
+// vector for the generation. Used by the coverage readout to split stamped
// messages into embedded vs blank. Counts distinct messages (not chunk
// rows) so a long, multi-chunk message counts once, matching the
// EmbeddingCount semantic elsewhere.
@@ -1434,18 +1448,27 @@ func (b *Backend) EmbeddedMessageCount(ctx context.Context, gen vector.Generatio
return 0, nil
}
- // Step 2 (main.db): how many of those are live AND stamped for gen.
+ // Step 2 (main.db): how many of those are in-scope, live, AND stamped for gen.
blob, err := json.Marshal(ids)
if err != nil {
return 0, fmt.Errorf("encode embedded ids: %w", err)
}
+ where := `id IN (SELECT value FROM json_each(?))
+ AND embed_gen = ?
+ AND ` + store.LiveMessagesWhere("", true)
+ args := []any{string(blob), int64(gen)}
+ if !b.scope.IsEmpty() {
+ placeholders := make([]string, len(b.scope.MessageTypes))
+ for i, typ := range b.scope.MessageTypes {
+ placeholders[i] = "?"
+ args = append(args, typ)
+ }
+ where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ","))
+ }
var n int64
if err := b.mainDB.QueryRowContext(ctx,
- `SELECT COUNT(*) FROM messages
- WHERE id IN (SELECT value FROM json_each(?))
- AND embed_gen = ?
- AND `+store.LiveMessagesWhere("", true),
- string(blob), int64(gen)).Scan(&n); err != nil {
+ `SELECT COUNT(*) FROM messages WHERE `+where,
+ args...).Scan(&n); err != nil {
return 0, fmt.Errorf("count live embedded messages: %w", err)
}
return n, nil
diff --git a/internal/vector/sqlitevec/backend_test.go b/internal/vector/sqlitevec/backend_test.go
index 697fd6d1..7e873413 100644
--- a/internal/vector/sqlitevec/backend_test.go
+++ b/internal/vector/sqlitevec/backend_test.go
@@ -292,6 +292,49 @@ func TestBackend_ActivateGeneration_NullSeededAtActivatesWithCoverage(t *testing
assertpkg.Equal(t, vector.GenerationActive, genStateSV(t, b, gen), "now active")
}
+func TestBackend_ActivateCoverageScopesMessageTypes(t *testing.T) {
+ ctx := context.Background()
+ main, err := sql.Open("sqlite3", ":memory:")
+ requirepkg.NoError(t, err, "open main")
+ t.Cleanup(func() { _ = main.Close() })
+ _, err = main.Exec(`CREATE TABLE messages (
+ id INTEGER PRIMARY KEY,
+ message_type TEXT NOT NULL,
+ embed_gen INTEGER,
+ deleted_at DATETIME,
+ deleted_from_source_at DATETIME
+ )`)
+ requirepkg.NoError(t, err, "create messages")
+ _, err = main.Exec(`
+ INSERT INTO messages (id, message_type, deleted_from_source_at) VALUES
+ (1, 'email', NULL),
+ (2, 'sms', NULL),
+ (3, 'mms', NULL),
+ (4, 'sms', CURRENT_TIMESTAMP)`)
+ requirepkg.NoError(t, err, "insert messages")
+
+ b, err := Open(ctx, Options{
+ Path: filepath.Join(t.TempDir(), "vectors.db"),
+ Dimension: 768,
+ MainDB: main,
+ BuildScope: vector.NewBuildScope([]string{"sms", "mms"}),
+ })
+ requirepkg.NoError(t, err, "Open")
+ t.Cleanup(func() { _ = b.Close() })
+
+ gid, err := b.CreateGeneration(ctx, "m", 768, "")
+ requirepkg.NoError(t, err, "Create")
+ missing, err := b.hasMissingForGen(ctx, gid)
+ requirepkg.NoError(t, err, "hasMissingForGen before stamp")
+ assertpkg.True(t, missing, "in-scope messages need embedding")
+
+ _, err = main.Exec(`UPDATE messages SET embed_gen = ? WHERE id IN (2, 3)`, int64(gid))
+ requirepkg.NoError(t, err, "stamp in-scope messages")
+ missing, err = b.hasMissingForGen(ctx, gid)
+ requirepkg.NoError(t, err, "hasMissingForGen after stamp")
+ assertpkg.False(t, missing, "out-of-scope and deleted messages must not block scoped activation")
+}
+
// TestBackend_CreateGeneration_ResumesBuilding confirms that calling
// CreateGeneration while a building row already exists with the same
// fingerprint returns the existing id instead of failing on the unique
diff --git a/internal/vector/sqlitevec/coverage_test.go b/internal/vector/sqlitevec/coverage_test.go
index 2ffcd0ac..55c6c786 100644
--- a/internal/vector/sqlitevec/coverage_test.go
+++ b/internal/vector/sqlitevec/coverage_test.go
@@ -218,3 +218,63 @@ func TestCoverageSplit_NonLiveEmbeddedHoldsInvariant(t *testing.T) {
assert.Equal(live, embedded+blank+missingCount,
"invariant: live == embedded + blank + missing")
}
+
+// TestCoverageSplit_ScopedEmbeddedHoldsInvariant stamps and embeds one email
+// and one sms message under a backend scoped to sms, then checks that the
+// scoped store counts and EmbeddedMessageCount agree: the out-of-scope email
+// is left out of every count and live == embedded + blank + missing holds.
+func TestCoverageSplit_ScopedEmbeddedHoldsInvariant(t *testing.T) {
+	must := requirepkg.New(t)
+	check := assertpkg.New(t)
+	ctx := context.Background()
+
+	st := testutil.NewSQLiteTestStore(t)
+	backend, err := Open(ctx, Options{
+		Path:       filepath.Join(t.TempDir(), "vectors.db"),
+		Dimension:  8,
+		MainDB:     st.DB(),
+		BuildScope: vector.NewBuildScope([]string{"sms"}),
+	})
+	must.NoError(err, "Open backend")
+	t.Cleanup(func() { _ = backend.Close() })
+
+	src, err := st.GetOrCreateSource("gmail", "me@example.com")
+	must.NoError(err, "GetOrCreateSource")
+	emailConv, err := st.EnsureConversationWithType(src.ID, "conv-email", "email_thread", "Email")
+	must.NoError(err, "EnsureConversationWithType email")
+	smsConv, err := st.EnsureConversationWithType(src.ID, "conv-sms", "sms_thread", "SMS")
+	must.NoError(err, "EnsureConversationWithType sms")
+
+	// insert stores one message of the given type and returns its row id.
+	insert := func(srcMsgID, typ string, convID int64) int64 {
+		id, upsertErr := st.UpsertMessage(&store.Message{
+			SourceID:        src.ID,
+			SourceMessageID: srcMsgID,
+			ConversationID:  convID,
+			MessageType:     typ,
+			Subject:         sql.NullString{String: "s-" + srcMsgID, Valid: true},
+		})
+		must.NoErrorf(upsertErr, "UpsertMessage %s", srcMsgID)
+		return id
+	}
+	emailID := insert("email-stamped", "email", emailConv)
+	smsID := insert("sms-stamped", "sms", smsConv)
+
+	// Both messages get a vector and a stamp; only the sms one is in scope.
+	gen, err := backend.CreateGeneration(ctx, "test-model", 8, "fp")
+	must.NoError(err, "CreateGeneration")
+	chunks := []vector.Chunk{
+		{MessageID: emailID, Vector: []float32{1, 0, 0, 0, 0, 0, 0, 0}},
+		{MessageID: smsID, Vector: []float32{0, 1, 0, 0, 0, 0, 0, 0}},
+	}
+	must.NoError(backend.Upsert(ctx, gen, chunks), "Upsert embedded vectors")
+	must.NoError(st.SetEmbedGen(ctx, []int64{emailID, smsID}, int64(gen)), "stamp embedded")
+
+	live, stamped, _, missingCount, err := st.CoverageCountsScoped(ctx, int64(gen), []string{"sms"})
+	must.NoError(err, "CoverageCountsScoped")
+	embedded, err := backend.EmbeddedMessageCount(ctx, gen)
+	must.NoError(err, "EmbeddedMessageCount")
+	// Stamped messages without a vector are "blank"; never report a negative.
+	blank := max(stamped-embedded, 0)
+
+	check.Equal(int64(1), live, "only sms is in scope")
+	check.Equal(int64(1), stamped, "only scoped stamped messages count")
+	check.Equal(int64(1), embedded, "out-of-scope email vector excluded")
+	check.Equal(int64(0), blank)
+	check.Equal(int64(0), missingCount)
+	check.Equal(live, embedded+blank+missingCount,
+		"invariant: live == embedded + blank + missing")
+}
diff --git a/internal/vector/sqlitevec/ext_stub.go b/internal/vector/sqlitevec/ext_stub.go
index 081a8398..55b50a81 100644
--- a/internal/vector/sqlitevec/ext_stub.go
+++ b/internal/vector/sqlitevec/ext_stub.go
@@ -32,11 +32,12 @@ func Available() bool { return false }
// can reference sqlitevec.Options without a compile error; the struct is
// never populated at runtime when the PG code path is taken.
type Options struct {
- Path string
- MainPath string
- Dimension int
- MainDB *sql.DB
- ReadOnly bool
+ Path string
+ MainPath string
+ Dimension int
+ MainDB *sql.DB
+ BuildScope vector.BuildScope // present so callers that set a build scope still compile against this stub
+ ReadOnly bool
}
// Backend is the stub backend type for builds without sqlite_vec.
diff --git a/internal/vector/sqlitevec/fused_test.go b/internal/vector/sqlitevec/fused_test.go
index 19a34f94..03d0f47a 100644
--- a/internal/vector/sqlitevec/fused_test.go
+++ b/internal/vector/sqlitevec/fused_test.go
@@ -287,11 +287,11 @@ func openFusedMainWithSchema(t *testing.T, path string) *sql.DB {
t.Cleanup(func() { _ = db.Close() })
// sent_at is DATETIME (text) to match the production schema.
schema := `
-CREATE TABLE messages (
- id INTEGER PRIMARY KEY,
- subject TEXT,
- message_type TEXT NOT NULL DEFAULT 'email',
- source_id INTEGER,
+ CREATE TABLE messages (
+ id INTEGER PRIMARY KEY,
+ subject TEXT,
+ message_type TEXT NOT NULL DEFAULT 'email',
+ source_id INTEGER,
sender_id INTEGER,
has_attachments INTEGER DEFAULT 0,
size_estimate INTEGER,
@@ -322,10 +322,10 @@ func openFusedMainWithoutMessageType(t *testing.T, path string) *sql.DB {
requirepkg.NoError(t, err, "open main")
t.Cleanup(func() { _ = db.Close() })
schema := `
-CREATE TABLE messages (
- id INTEGER PRIMARY KEY,
- subject TEXT,
- source_id INTEGER,
+ CREATE TABLE messages (
+ id INTEGER PRIMARY KEY,
+ subject TEXT,
+ source_id INTEGER,
sender_id INTEGER,
has_attachments INTEGER DEFAULT 0,
size_estimate INTEGER,