diff --git a/README.md b/README.md index ec04f505..354fc997 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,8 @@ A separate MCP tool, `find_similar_messages`, returns nearest neighbors for a se > **Run only one embedding process at a time.** Don't run `msgvault embeddings build`/`resume` or `repair-encoding` concurrently with a `msgvault serve` daemon — they write the same embedding state, and concurrent writers are not coordinated across processes. +Large archives can scope an embedding generation with `[vector.embed.scope] message_types = ["sms", "mms"]`. Scoped vector and hybrid searches must include a matching `message_type` filter so a partial index is never used as if it covered the whole archive. + ## Importing from MBOX or Apple Mail Import email from providers that offer MBOX exports or from a local Apple Mail data directory: diff --git a/cmd/msgvault/cmd/add_synctech_sms_drive.go b/cmd/msgvault/cmd/add_synctech_sms_drive.go index 44d60d04..e6f3d340 100644 --- a/cmd/msgvault/cmd/add_synctech_sms_drive.go +++ b/cmd/msgvault/cmd/add_synctech_sms_drive.go @@ -109,10 +109,15 @@ func runConfiguredSynctechSMSSource(ctx context.Context, src config.SynctechSMSS return err } defer func() { _ = st.Close() }() + return runConfiguredSynctechSMSSourceWithStore(ctx, st, src) } func runConfiguredSynctechSMSSourceWithStore(ctx context.Context, st *store.Store, src config.SynctechSMSSource) error { + return runConfiguredSynctechSMSSourceWithStoreDriveClient(ctx, st, src, nil) +} + +func runConfiguredSynctechSMSSourceWithStoreDriveClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, driveClient synctechsms.DriveClient) error { opts := synctechImportOptions(src) if opts.OwnerPhone == "" { return fmt.Errorf("synctech-sms source %q owner_phone is required", src.Name) @@ -128,7 +133,11 @@ func runConfiguredSynctechSMSSourceWithStore(ctx context.Context, st *store.Stor } _, err = synctechsms.NewImporter(st, opts).ImportPath(src.Path) case "drive": - err = runSynctechSMSDriveSource(ctx, st, src, opts) + if driveClient != nil { + _, err = runSynctechSMSDriveSourceWithClient(ctx, st, src, opts, driveClient) + } else { + _, err = runSynctechSMSDriveSource(ctx, st, src, opts) + } default: return fmt.Errorf("unsupported synctech-sms backend %q", src.Backend) } @@ -167,28 +176,28 @@ func validateSynctechSMSDriveSource(src config.SynctechSMSSource) error { return nil } -func runSynctechSMSDriveSource(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions) error { +func runSynctechSMSDriveSource(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions) (synctechsms.ImportSummary, error) { if err := validateSynctechSMSDriveSource(src); err != nil { - return err + return synctechsms.ImportSummary{}, err } client, err := newSynctechSMSDriveClient(ctx, src) if err != nil { - return err + return synctechsms.ImportSummary{}, err } return runSynctechSMSDriveSourceWithClient(ctx, st, src, opts, client) } -func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions, client synctechsms.DriveClient) (retErr error) { +func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions, client synctechsms.DriveClient) (summary synctechsms.ImportSummary, retErr error) { if err := validateSynctechSMSDriveSource(src); err != nil { - return err + return summary, err } source, err := ensureConfiguredSynctechSMSSource(st, src, opts) if err != nil { - return err + return summary, err } syncID, err := st.StartSync(source.ID, synctechsms.AdapterName) if err != nil { - return fmt.Errorf("start sync: %w", err) + return summary, fmt.Errorf("start sync: %w", err) } completed := false defer func() { @@ -204,38 +213,38 @@ func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, s }() files, err := client.ListBackupFiles(ctx, src.FolderID) if err != nil { - return fmt.Errorf("list Drive backup files: %w", err) + return summary, fmt.Errorf("list Drive backup files: %w", err) } imported, err := st.ListImportedSourceItemChecksums(source.ID, "drive") if err != nil { - return fmt.Errorf("list imported Drive checksums: %w", err) + return summary, fmt.Errorf("list imported Drive checksums: %w", err) } stableAfter, err := time.ParseDuration(src.StableAfter) if err != nil { - return fmt.Errorf("parse stable_after: %w", err) + return summary, fmt.Errorf("parse stable_after: %w", err) } selected := synctechsms.SelectStableDriveFiles(files, time.Now(), stableAfter, imported) stagingDir := filepath.Join(cfg.Data.DataDir, "imports", "synctech-sms", src.Name) if err := os.MkdirAll(stagingDir, 0o700); err != nil { - return fmt.Errorf("create staging directory: %w", err) + return summary, fmt.Errorf("create staging directory: %w", err) } imp := synctechsms.NewImporter(st, opts) - var summary synctechsms.ImportSummary for _, file := range selected { fileSummary, err := importOneDriveBackup(ctx, st, imp, client, source.ID, file, stagingDir) - if err != nil { - return err - } summary.FilesSeen += fileSummary.FilesSeen summary.FilesImported += fileSummary.FilesImported summary.SMSImported += fileSummary.SMSImported summary.MMSImported += fileSummary.MMSImported summary.CallsImported += fileSummary.CallsImported summary.AttachmentsImported += fileSummary.AttachmentsImported + summary.MessageIDs = append(summary.MessageIDs, fileSummary.MessageIDs...) + if err != nil { + return summary, err + } } if summary.FilesImported > 0 { if err := st.RecomputeConversationStats(source.ID); err != nil { - return fmt.Errorf("recompute conversation stats: %w", err) + return summary, fmt.Errorf("recompute conversation stats: %w", err) } } totalRecords := int64(summary.SMSImported + summary.MMSImported + summary.CallsImported) @@ -243,16 +252,16 @@ func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, s MessagesProcessed: totalRecords, MessagesAdded: totalRecords, }); err != nil { - return fmt.Errorf("update sync checkpoint: %w", err) + return summary, fmt.Errorf("update sync checkpoint: %w", err) } if err := st.TouchSourceLastSyncAt(source.ID); err != nil { - return fmt.Errorf("touch source last sync: %w", err) + return summary, fmt.Errorf("touch source last sync: %w", err) } if err := st.CompleteSync(syncID, ""); err != nil { - return fmt.Errorf("complete sync: %w", err) + return summary, fmt.Errorf("complete sync: %w", err) } completed = true - return nil + return summary, nil } func importOneDriveBackup(ctx context.Context, st *store.Store, imp *synctechsms.Importer, client synctechsms.DriveClient, sourceID int64, file synctechsms.DriveFile, stagingDir string) (synctechsms.ImportSummary, error) { diff --git a/cmd/msgvault/cmd/add_synctech_sms_drive_test.go b/cmd/msgvault/cmd/add_synctech_sms_drive_test.go index f3920850..2f98ede7 100644 --- a/cmd/msgvault/cmd/add_synctech_sms_drive_test.go +++ b/cmd/msgvault/cmd/add_synctech_sms_drive_test.go @@ -75,8 +75,9 @@ func TestSynctechSMSDriveRunUsesSingleOuterSyncRun(t *testing.T) { }, } - err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) + summary, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) require.NoError(err, "runSynctechSMSDriveSourceWithClient") + require.Len(summary.MessageIDs, 1, "summary message IDs") source := getSynctechSource(t, f.Store, src.OwnerPhone) assert.Equal(1, countSyncRuns(t, f.Store, source.ID), "sync run count") @@ -86,7 +87,7 @@ func TestSynctechSMSDriveRunUsesSingleOuterSyncRun(t *testing.T) { assert.Equal(int64(1), run.MessagesAdded, "messages added") assert.True(getSynctechSource(t, f.Store, src.OwnerPhone).LastSyncAt.Valid, "last_sync_at should be touched") - item := getSourceImportItem(t, f.Store, source.ID, "drive", "backup-1") + item := getDriveSourceImportItem(t, f.Store, source.ID, "backup-1") assert.Equal("imported", item.Status, "source import status") assert.Equal(1, item.RecordsImported, "records imported") assert.False(item.ErrorMessage.Valid, "source import error") @@ -113,7 +114,7 @@ func TestSynctechSMSDriveRunSetsUpIdentityAndPostSourceMigration(t *testing.T) { src := synctechDriveTestSource() client := fakeSynctechDriveClient{} - err = runSynctechSMSDriveSourceWithClient(context.Background(), st, src, synctechImportOptions(src), client) + _, err = runSynctechSMSDriveSourceWithClient(context.Background(), st, src, synctechImportOptions(src), client) require.NoError(err, "runSynctechSMSDriveSourceWithClient") synctechSource := getSynctechSource(t, st, src.OwnerPhone) @@ -154,7 +155,7 @@ func TestSynctechSMSDriveRunRecordsZeroSelectedPoll(t *testing.T) { }}, } - err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) + _, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) require.NoError(err, "runSynctechSMSDriveSourceWithClient") source := getSynctechSource(t, f.Store, src.OwnerPhone) @@ -188,7 +189,7 @@ func TestSynctechSMSDriveRunMarksOuterSyncFailedOnDownloadError(t *testing.T) { downloadErr: downloadErr, } - err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) + _, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) require.ErrorIs(err, downloadErr, "runSynctechSMSDriveSourceWithClient") source := getSynctechSource(t, f.Store, src.OwnerPhone) @@ -198,17 +199,126 @@ func TestSynctechSMSDriveRunMarksOuterSyncFailedOnDownloadError(t *testing.T) { require.True(run.ErrorMessage.Valid, "sync error_message") assert.Contains(run.ErrorMessage.String, downloadErr.Error(), "sync error_message") - item := getSourceImportItem(t, f.Store, source.ID, "drive", "backup-1") + item := getDriveSourceImportItem(t, f.Store, source.ID, "backup-1") assert.Equal("failed", item.Status, "source import status") require.True(item.ErrorMessage.Valid, "source import error") assert.Contains(item.ErrorMessage.String, downloadErr.Error(), "source import error") } +func TestSynctechSMSDrivePartialFailureEnqueuesImportedMessages(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + home := t.TempDir() + cfg = config.NewDefaultConfig() + cfg.HomeDir = home + cfg.Data.DataDir = home + f := storetest.New(t) + src := synctechDriveTestSource() + client := fakeSynctechDriveClient{ + files: []synctechsms.DriveFile{ + { + ID: "backup-1", + Name: "sms-1.xml", + Checksum: "sum-1", + Size: 128, + ModifiedTime: time.Now().Add(-30 * time.Minute), + }, + { + ID: "backup-2", + Name: "sms-2.xml", + Checksum: "sum-2", + Size: 128, + ModifiedTime: time.Now().Add(-30 * time.Minute), + }, + }, + downloads: map[string]string{ + "backup-1": ` + +`, + "backup-2": ` + + + +`), 0o600), "write backup") + + st, err := store.Open(cfg.DatabaseDSN()) + require.NoError(err, "open store") + require.NoError(st.InitSchema(), "InitSchema") + require.NoError(st.Close(), "close store") + + src := config.SynctechSMSSource{ + Name: "pixel-local", + Backend: "local", + Path: importDir, + OwnerPhone: "+15550000001", + IncludeSMS: true, + } + require.NoError(runConfiguredSynctechSMSSource(ctx, src), "runConfiguredSynctechSMSSource") + + st, err = store.Open(cfg.DatabaseDSN()) + require.NoError(err, "reopen store") + t.Cleanup(func() { _ = st.Close() }) + source := getSynctechSource(t, st, src.OwnerPhone) + var unstamped int + require.NoError(st.DB().QueryRowContext(ctx, + st.Rebind(`SELECT COUNT(*) FROM messages WHERE source_id = ? AND embed_gen IS NULL`), + source.ID, + ).Scan(&unstamped), "count unstamped messages") + assert.Equal(1, unstamped, "manual sync message remains discoverable by scan-and-fill") +} + type fakeSynctechDriveClient struct { - files []synctechsms.DriveFile - downloads map[string]string - listErr error - downloadErr error + files []synctechsms.DriveFile + downloads map[string]string + listErr error + downloadErr error + downloadErrByID map[string]error } func (f fakeSynctechDriveClient) ListBackupFiles(context.Context, string) ([]synctechsms.DriveFile, error) { @@ -219,6 +329,9 @@ func (f fakeSynctechDriveClient) ListBackupFiles(context.Context, string) ([]syn } func (f fakeSynctechDriveClient) DownloadToFile(_ context.Context, fileID, path string) error { + if err := f.downloadErrByID[fileID]; err != nil { + return err + } if f.downloadErr != nil { return f.downloadErr } @@ -280,9 +393,9 @@ func getOnlySyncRun(t *testing.T, st *store.Store, sourceID int64) store.SyncRun return run } -func getSourceImportItem(t *testing.T, st *store.Store, sourceID int64, provider, providerID string) *store.SourceImportItem { +func getDriveSourceImportItem(t *testing.T, st *store.Store, sourceID int64, providerID string) *store.SourceImportItem { t.Helper() - item, err := st.GetSourceImportItem(sourceID, provider, providerID) + item, err := st.GetSourceImportItem(sourceID, "drive", providerID) requirepkg.NoError(t, err, "GetSourceImportItem") return item } diff --git a/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go b/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go index 3663648e..8a815438 100644 --- a/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go +++ b/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go @@ -5,6 +5,7 @@ package cmd import ( "context" "database/sql" + "path/filepath" "testing" _ "github.com/mattn/go-sqlite3" @@ -48,3 +49,30 @@ func TestRunEmbeddingsRetire_ForceActive(t *testing.T) { row := mustGetEmbeddingGeneration(t.Context(), t, db, 1) assert.Equal(vector.GenerationRetired, row.State) } + +func TestFillFullCoverageUsesEmbeddingScopeForEmbeddedCount(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + ctx := context.Background() + dataDir := t.TempDir() + dbPath := newEmbeddingMetadataTestDBFileAt(t, filepath.Join(dataDir, "vectors.db")) + seedMainDBWithScopedFullCoverageMessages(t, dataDir) + withEmbeddingCommandConfigDataDir(t, dbPath, dataDir) + cfg.Vector.Embed.Scope.MessageTypes = []string{"sms"} + + backend, closeBackend, err := openEmbeddingsBackend(ctx) + require.NoError(err, "open embeddings backend") + t.Cleanup(closeBackend) + require.NoError(backend.Upsert(ctx, 2, []vector.Chunk{ + {MessageID: 1, Vector: []float32{1, 0, 0, 0}}, + {MessageID: 2, Vector: []float32{0, 1, 0, 0}}, + }), "upsert in-scope and out-of-scope vectors") + + row := embeddingGenerationRow{ID: 2} + require.NoError(fillFullCoverage(ctx, backend, &row)) + + assert.Equal(int64(1), row.LiveCount, "only sms is in scope") + assert.Equal(int64(1), row.EmbeddedCount, "out-of-scope email vector is excluded") + assert.Equal(int64(0), row.BlankCount) + assert.Equal(int64(0), row.MissingCount) +} diff --git a/cmd/msgvault/cmd/embed_test.go b/cmd/msgvault/cmd/embed_test.go index e9af2674..f2eb3512 100644 --- a/cmd/msgvault/cmd/embed_test.go +++ b/cmd/msgvault/cmd/embed_test.go @@ -139,6 +139,22 @@ func TestRunEmbeddingsActivateRefusesMissingWithoutForce(t *testing.T) { assert.Contains(err.Error(), "msgvault embeddings resume --backstop") } +func TestFillCoverageUsesEmbeddingScope(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + dataDir := t.TempDir() + dbPath := newEmbeddingMetadataTestDBFileAt(t, filepath.Join(dataDir, "vectors.db")) + seedMainDBWithScopedCoverageMessages(t, dataDir) + withEmbeddingCommandConfigDataDir(t, dbPath, dataDir) + cfg.Vector.Embed.Scope.MessageTypes = []string{"sms"} + + row := embeddingGenerationRow{ID: 2} + require.NoError(fillCoverage(t.Context(), &row)) + + assert.Equal(int64(1), row.LiveCount) + assert.Equal(int64(0), row.MissingCount) +} + // TestRetireEmbeddingGenerationRefusesActiveWithoutForce_PreCheck pins the // CLI UX gate that runs against the committed metadata read BEFORE any // backend connection: retiring an active generation without --force-active @@ -240,6 +256,38 @@ INSERT INTO messages (id, conversation_id, source_id, source_message_id, message requirepkg.NoError(t, err) } +func seedMainDBWithScopedCoverageMessages(t *testing.T, dataDir string) { + t.Helper() + s, err := store.Open(filepath.Join(dataDir, "msgvault.db")) + requirepkg.NoError(t, err) + defer func() { requirepkg.NoError(t, s.Close()) }() + requirepkg.NoError(t, s.InitSchema()) + _, err = s.DB().Exec(` +INSERT INTO sources (id, source_type, identifier) VALUES (1, 'gmail', 'me@example.com'); +INSERT INTO conversations (id, source_id, conversation_type) VALUES (1, 1, 'email_thread'), (2, 1, 'sms_thread'); +INSERT INTO messages (id, conversation_id, source_id, source_message_id, message_type, embed_gen) VALUES + (1, 1, 1, 'email-missing', 'email', NULL), + (2, 2, 1, 'sms-stamped', 'sms', 2); +`) + requirepkg.NoError(t, err) +} + +func seedMainDBWithScopedFullCoverageMessages(t *testing.T, dataDir string) { + t.Helper() + s, err := store.Open(filepath.Join(dataDir, "msgvault.db")) + requirepkg.NoError(t, err) + defer func() { requirepkg.NoError(t, s.Close()) }() + requirepkg.NoError(t, s.InitSchema()) + _, err = s.DB().Exec(` +INSERT INTO sources (id, source_type, identifier) VALUES (1, 'gmail', 'me@example.com'); +INSERT INTO conversations (id, source_id, conversation_type) VALUES (1, 1, 'email_thread'), (2, 1, 'sms_thread'); +INSERT INTO messages (id, conversation_id, source_id, source_message_id, message_type, embed_gen) VALUES + (1, 1, 1, 'email-stamped', 'email', 2), + (2, 2, 1, 'sms-stamped', 'sms', 2); +`) + requirepkg.NoError(t, err) +} + func withEmbeddingCommandConfig(t *testing.T, vecPath string) { t.Helper() oldCfg := cfg diff --git a/cmd/msgvault/cmd/embed_vector.go b/cmd/msgvault/cmd/embed_vector.go index 8df24cd0..f0372d67 100644 --- a/cmd/msgvault/cmd/embed_vector.go +++ b/cmd/msgvault/cmd/embed_vector.go @@ -57,6 +57,7 @@ func runEmbed(cmd *cobra.Command) error { pgb, err := pgvector.Open(ctx, pgvector.Options{ DB: s.DB(), Dimension: cfg.Vector.Embeddings.Dimension, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), SkipExtension: cfg.Vector.SkipExtensionCreate, }) if err != nil { @@ -76,10 +77,11 @@ func runEmbed(cmd *cobra.Command) error { vecPath = filepath.Join(cfg.Data.DataDir, "vectors.db") } sb, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecPath, - MainPath: cfg.DatabaseDSN(), - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: s.DB(), + Path: vecPath, + MainPath: cfg.DatabaseDSN(), + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: s.DB(), + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { return fmt.Errorf("open vectors.db: %w", err) @@ -116,7 +118,8 @@ func runEmbed(cmd *cobra.Command) error { // "Pending" is now the count of live messages still needing work for // this generation (embed_gen <> gen), read from the main DB coverage // rather than a queue table. - missing, err := s.MissingCount(ctx, int64(gen)) + scope := cfg.Vector.Embed.Scope.BuildScope() + missing, err := s.MissingCountScoped(ctx, int64(gen), scope.MessageTypes) if err != nil { return fmt.Errorf("coverage counts: %w", err) } @@ -138,6 +141,7 @@ func runEmbed(cmd *cobra.Command) error { }, MaxInputChars: cfg.Vector.Embeddings.MaxInputChars, BatchSize: cfg.Vector.Embeddings.BatchSize, + BuildScope: scope, Rebind: rebind, LastModifiedExpr: lastModifiedExpr, TotalPending: totalPending, @@ -161,7 +165,7 @@ func runEmbed(cmd *cobra.Command) error { // worker later recovers from must not block activation, and an // active generation must not be re-activated. if rebuildInProgress { - _, _, _, remaining, err := s.CoverageCounts(ctx, int64(gen)) + _, _, _, remaining, err := s.CoverageCountsScoped(ctx, int64(gen), scope.MessageTypes) if err != nil { return fmt.Errorf("coverage counts: %w", err) } diff --git a/cmd/msgvault/cmd/embeddings_manage.go b/cmd/msgvault/cmd/embeddings_manage.go index b3e0f04f..81525b60 100644 --- a/cmd/msgvault/cmd/embeddings_manage.go +++ b/cmd/msgvault/cmd/embeddings_manage.go @@ -64,7 +64,8 @@ func fillCoverage(ctx context.Context, row *embeddingGenerationRow) error { return fmt.Errorf("open main db for coverage: %w", err) } defer func() { _ = s.Close() }() - live, _, _, missing, err := s.CoverageCounts(ctx, int64(row.ID)) + scope := cfg.Vector.Embed.Scope.BuildScope() + live, _, _, missing, err := s.CoverageCountsScoped(ctx, int64(row.ID), scope.MessageTypes) if err != nil { return err } @@ -87,7 +88,8 @@ func fillFullCoverage(ctx context.Context, backend vector.Backend, row *embeddin return fmt.Errorf("open main db for coverage: %w", err) } defer func() { _ = s.Close() }() - live, stamped, _, missing, err := s.CoverageCounts(ctx, int64(row.ID)) + scope := cfg.Vector.Embed.Scope.BuildScope() + live, stamped, _, missing, err := s.CoverageCountsScoped(ctx, int64(row.ID), scope.MessageTypes) if err != nil { return err } @@ -426,6 +428,7 @@ func openEmbeddingsBackend(ctx context.Context) (vector.Backend, func(), error) DB: db, Dimension: cfg.Vector.Embeddings.Dimension, SkipMigrate: true, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { _ = db.Close() @@ -460,10 +463,11 @@ func openEmbeddingsBackend(ctx context.Context) (vector.Backend, func(), error) return nil, nil, fmt.Errorf("open main db for embeddings backend: %w", err) } b, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecPath, - MainPath: dsn, - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: mainStore.DB(), + Path: vecPath, + MainPath: dsn, + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: mainStore.DB(), + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { _ = mainStore.Close() diff --git a/cmd/msgvault/cmd/search_vector.go b/cmd/msgvault/cmd/search_vector.go index 96cf957b..6f8e7ab6 100644 --- a/cmd/msgvault/cmd/search_vector.go +++ b/cmd/msgvault/cmd/search_vector.go @@ -79,6 +79,7 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc pgb, err := pgvector.Open(ctx, pgvector.Options{ DB: mainDB, Dimension: cfg.Vector.Embeddings.Dimension, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), SkipExtension: cfg.Vector.SkipExtensionCreate, }) if err != nil { @@ -110,10 +111,11 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc } sb, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecDBPath, - MainPath: dsn, - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: mainDB, + Path: vecDBPath, + MainPath: dsn, + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: mainDB, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { _ = mainDB.Close() @@ -147,6 +149,7 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc KPerSignal: cfg.Vector.Search.KPerSignal, SubjectBoost: cfg.Vector.Search.SubjectBoost, Rebind: dialect.Rebind, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) q := search.Parse(queryStr) diff --git a/cmd/msgvault/cmd/serve.go b/cmd/msgvault/cmd/serve.go index fa1a41d8..29fe1f21 100644 --- a/cmd/msgvault/cmd/serve.go +++ b/cmd/msgvault/cmd/serve.go @@ -206,6 +206,7 @@ func runServe(cmd *cobra.Command, args []string) error { Store: s, Fingerprint: vf.Cfg.GenerationFingerprint(), BackstopInterval: vf.Cfg.Embed.BackstopInterval, + BuildScope: vf.Cfg.Embed.Scope.BuildScope(), Log: logger, } schedule := cfg.Vector.Embed.Schedule.Cron diff --git a/cmd/msgvault/cmd/serve_vector.go b/cmd/msgvault/cmd/serve_vector.go index 6be2e5e8..9f329e18 100644 --- a/cmd/msgvault/cmd/serve_vector.go +++ b/cmd/msgvault/cmd/serve_vector.go @@ -76,6 +76,7 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s pgb, err := pgvector.Open(ctx, pgvector.Options{ DB: mainDB, Dimension: cfg.Vector.Embeddings.Dimension, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), SkipMigrate: readOnly, // ReadOnly MUST track readOnly here: this is the MCP read-only // path (store.OpenReadOnly). When set, Open performs no writes — @@ -103,10 +104,11 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s vecPath = filepath.Join(cfg.Data.DataDir, "vectors.db") } sb, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecPath, - MainPath: mainPath, - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: mainDB, + Path: vecPath, + MainPath: mainPath, + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: mainDB, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), // Honor the read-only signal on SQLite too: when mainDB is a // query-only handle (MCP), skip the embed_gen upgrade backfill, // which would write through it. Migrate still runs (vectors.db @@ -146,6 +148,7 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s }, MaxInputChars: cfg.Vector.Embeddings.MaxInputChars, BatchSize: cfg.Vector.Embeddings.BatchSize, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), // Rebind makes the worker's body-fetch + watermark SQL run on pgx. // SQLiteDialect.Rebind is identity, so the SQLite path is unchanged. Rebind: dialect.Rebind, @@ -162,7 +165,8 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s // placeholders. On PG those must become $N or pgx rejects them, so // the serve/MCP hybrid engine (shared via vectorFeatures.HybridEngine) // carries the dialect's Rebind. SQLite's Rebind is identity. - Rebind: dialect.Rebind, + Rebind: dialect.Rebind, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) // No sync-time enqueue: newly-persisted messages get embed_gen = NULL diff --git a/internal/api/handlers.go b/internal/api/handlers.go index c76ed247..4df02ff6 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -616,6 +616,8 @@ func (s *Server) handleHybridSearch( case errors.Is(err, vector.ErrEmbeddingTimeout): writeError(w, http.StatusServiceUnavailable, "embedding_timeout", "the embedding endpoint did not respond in time; retry, or raise [vector.embeddings].timeout") + case errors.Is(err, vector.ErrIndexScopeMismatch): + writeError(w, http.StatusBadRequest, "index_scope_mismatch", err.Error()) default: s.logger.Error("hybrid search failed", "query", q, "mode", mode, "error", err) writeError(w, http.StatusInternalServerError, "internal_error", "search failed") diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index c9a74715..1940aeab 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -118,6 +118,11 @@ func translateVectorErr(err error) *mcp.CallToolResult { return mcp.NewToolResultError( "index_building: the initial vector index is still being built", ) + case errors.Is(err, vector.ErrIndexScopeMismatch): + return mcp.NewToolResultError( + "index_scope_mismatch: the vector index scope does not cover this query; " + + "add a matching message_type filter or rebuild embeddings for the requested scope", + ) case errors.Is(err, vector.ErrNoActiveGeneration): return mcp.NewToolResultError( "no_active_generation: vector search has no active index yet; " + @@ -512,25 +517,46 @@ func (h *handlers) findSimilarMessages(ctx context.Context, req mcp.CallToolRequ return mcp.NewToolResultError(err.Error()), nil } - seed, err := h.backend.LoadVector(ctx, seedID) + active, err := h.backend.ActiveGeneration(ctx) if err != nil { if r := translateVectorErr(err); r != nil { return r, nil } - return mcp.NewToolResultError(fmt.Sprintf("load seed vector: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil + } + if h.vectorCfg.Enabled { + fingerprint := h.vectorCfg.GenerationFingerprint() + if fingerprint != "" && active.Fingerprint != fingerprint { + err := fmt.Errorf("%w: active=%q configured=%q", + vector.ErrIndexStale, active.Fingerprint, fingerprint) + if r := translateVectorErr(err); r != nil { + return r, nil + } + return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil + } } - active, err := h.backend.ActiveGeneration(ctx) + seed, err := h.backend.LoadVector(ctx, seedID) if err != nil { if r := translateVectorErr(err); r != nil { return r, nil } - return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("load seed vector: %v", err)), nil + } + + if h.vectorCfg.Enabled { + scope := h.vectorCfg.Embed.Scope.BuildScope() + if !scope.IsEmpty() { + filter.MessageTypes = append([]string(nil), scope.MessageTypes...) + } } // +1 so we can drop the seed itself from results without coming up short. hits, err := h.backend.Search(ctx, active.ID, seed, limit+1, filter) if err != nil { + if r := translateVectorErr(err); r != nil { + return r, nil + } return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err)), nil } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index e11cc76e..5cdd97ac 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -262,6 +262,33 @@ func TestSearchMessages_HybridErrNotEnabled(t *testing.T) { assertpkg.Contains(t, txt, "vector_not_enabled", "expected 'vector_not_enabled' error, got: %s") } +func TestSearchMessages_HybridErrIndexScopeMismatch(t *testing.T) { + backend := &fakeBackend{ + active: vector.Generation{ + ID: 1, Model: "fake", Dimension: 4, + Fingerprint: "fake:4:scope=mt-sms", State: vector.GenerationActive, + }, + } + engine := hybrid.NewEngine(backend, nil, realEmbedder{dim: 4}, hybrid.Config{ + ExpectedFingerprint: "fake:4:scope=mt-sms", + RRFK: 60, + KPerSignal: 10, + BuildScope: vector.BuildScope{MessageTypes: []string{"sms"}}, + }) + h := &handlers{ + engine: &querytest.MockEngine{}, + hybridEngine: engine, + backend: backend, + } + + r := runToolExpectError(t, "search_messages", h.searchMessages, map[string]any{ + "query": "anything", + "mode": searchModeVector, + }) + txt := resultText(t, r) + assertpkg.Contains(t, txt, "index_scope_mismatch", "expected 'index_scope_mismatch' error, got: %s") +} + // realEmbedder returns a deterministic vector. Used for end-to-end // MCP hybrid tests that exercise the engine's embed → backend.Search // path; pickEmbedGeneration tests use stubEmbedder instead. @@ -1465,19 +1492,23 @@ func TestStageDeletion(t *testing.T) { // optional fields so the get_stats tests can populate them. Methods // not otherwise configured return errors and should not be called. type fakeBackend struct { - loadVec []float32 - loadErr error - active vector.Generation - activeErr error - searchHits []vector.Hit - searchErr error - building *vector.Generation - buildingErr error - stats map[vector.GenerationID]vector.Stats - statsErr error + loadVec []float32 + loadErr error + loadCalls int + active vector.Generation + activeErr error + searchHits []vector.Hit + searchErr error + searchGen vector.GenerationID + searchFilter vector.Filter + building *vector.Generation + buildingErr error + stats map[vector.GenerationID]vector.Stats + statsErr error } func (f *fakeBackend) LoadVector(_ context.Context, _ int64) ([]float32, error) { + f.loadCalls++ return f.loadVec, f.loadErr } func (f *fakeBackend) ResetWatermarkBelow(_ context.Context, _ int64) error { @@ -1489,7 +1520,9 @@ func (f *fakeBackend) EmbeddedMessageCount(_ context.Context, _ vector.Generatio func (f *fakeBackend) ActiveGeneration(_ context.Context) (vector.Generation, error) { return f.active, f.activeErr } -func (f *fakeBackend) Search(_ context.Context, _ vector.GenerationID, _ []float32, _ int, _ vector.Filter) ([]vector.Hit, error) { +func (f *fakeBackend) Search(_ context.Context, gen vector.GenerationID, _ []float32, _ int, filter vector.Filter) ([]vector.Hit, error) { + f.searchGen = gen + f.searchFilter = filter return f.searchHits, f.searchErr } func (f *fakeBackend) CreateGeneration(_ context.Context, _ string, _ int, _ string) (vector.GenerationID, error) { @@ -1623,9 +1656,98 @@ func TestFindSimilarMessages_HappyPath(t *testing.T) { assert.Equal(int64(300), resp.Messages[1].ID, "Messages[1].ID") } +func TestFindSimilarMessages_RejectsStaleScopedGeneration(t *testing.T) { + assert := assertpkg.New(t) + seed := []float32{1, 0, 0, 0} + cfg := vector.Config{Enabled: true} + cfg.Embeddings.Model = "nomic-embed" + cfg.Embeddings.Dimension = 4 + cfg.Embed.Scope.MessageTypes = []string{"sms"} + fb := &fakeBackend{ + loadVec: seed, + active: vector.Generation{ + ID: 7, + Model: "nomic-embed", + Dimension: 4, + Fingerprint: "nomic-embed:4", + State: vector.GenerationActive, + }, + } + h := &handlers{engine: &querytest.MockEngine{}, backend: fb, vectorCfg: cfg} + + r := runToolExpectError(t, "find_similar_messages", h.findSimilarMessages, map[string]any{ + "message_id": float64(100), + }) + txt := resultText(t, r) + + assert.Contains(txt, "index_stale", "expected stale scoped index rejection, got: %s", txt) + assert.Equal(vector.GenerationID(0), fb.searchGen, "Search should not run against a stale scoped index") +} + +func TestFindSimilarMessages_ReportsStaleIndexBeforeMissingSeed(t *testing.T) { + assert := assertpkg.New(t) + cfg := vector.Config{Enabled: true} + cfg.Embeddings.Model = "nomic-embed" + cfg.Embeddings.Dimension = 4 + cfg.Embed.Scope.MessageTypes = []string{"sms"} + fb := &fakeBackend{ + loadErr: errors.New("seed vector missing"), + active: vector.Generation{ + ID: 7, + Model: "nomic-embed", + Dimension: 4, + Fingerprint: "nomic-embed:4", + State: vector.GenerationActive, + }, + } + h := &handlers{engine: &querytest.MockEngine{}, backend: fb, vectorCfg: cfg} + + r := runToolExpectError(t, "find_similar_messages", h.findSimilarMessages, map[string]any{ + "message_id": float64(100), + }) + txt := resultText(t, r) + + assert.Contains(txt, "index_stale", "expected stale index to be reported before seed lookup, got: %s", txt) + assert.Equal(0, fb.loadCalls, "LoadVector should not run before stale-index validation") +} + +func TestFindSimilarMessages_AppliesConfiguredMessageTypeScope(t *testing.T) { + assert := assertpkg.New(t) + seed := []float32{1, 0, 0, 0} + cfg := vector.Config{Enabled: true} + cfg.Embeddings.Model = "nomic-embed" + cfg.Embeddings.Dimension = 4 + cfg.Embed.Scope.MessageTypes = []string{"sms", "mms"} + fb := &fakeBackend{ + loadVec: seed, + active: vector.Generation{ + ID: 7, + Model: "nomic-embed", + Dimension: 4, + Fingerprint: cfg.GenerationFingerprint(), + State: vector.GenerationActive, + }, + searchHits: []vector.Hit{{MessageID: 200, Score: 0.95, Rank: 1}}, + } + eng := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 200: testutil.NewMessageDetail(200).WithSubject("related").BuildPtr(), + }, + } + h := &handlers{engine: eng, backend: fb, vectorCfg: cfg} + + _ = runTool[similarResponse](t, "find_similar_messages", h.findSimilarMessages, map[string]any{ + "message_id": float64(100), + }) + + assert.Equal(vector.GenerationID(7), fb.searchGen, "Search generation") + assert.Equal([]string{"mms", "sms"}, fb.searchFilter.MessageTypes, "Search MessageTypes filter") +} + func TestFindSimilarMessages_NoActiveGeneration(t *testing.T) { + assert := assertpkg.New(t) fb := &fakeBackend{ - loadErr: vector.ErrNoActiveGeneration, + activeErr: vector.ErrNoActiveGeneration, } h := &handlers{engine: &querytest.MockEngine{}, backend: fb} @@ -1633,7 +1755,8 @@ func TestFindSimilarMessages_NoActiveGeneration(t *testing.T) { "message_id": float64(1), }) txt := resultText(t, r) - assertpkg.Contains(t, txt, "no_active_generation", "expected 'no_active_generation' error, got: %s") + assert.Contains(txt, "no_active_generation", "expected 'no_active_generation' error, got: %s", txt) + assert.Equal(0, fb.loadCalls, "LoadVector should not run without an active generation") } func TestSearchByDomains(t *testing.T) { diff --git a/internal/query/duckdb.go b/internal/query/duckdb.go index e4a31059..b0374614 100644 --- a/internal/query/duckdb.go +++ b/internal/query/duckdb.go @@ -407,7 +407,7 @@ func (e *DuckDBEngine) parquetCTEs() string { conv AS ( %s ) - `, msgCTE, + `, msgCTE, e.parquetPath("message_recipients"), pCTE, e.parquetPath("labels"), @@ -417,6 +417,45 @@ func (e *DuckDBEngine) parquetCTEs() string { convCTE) } +func duckDBMessageTypeCondition(alias string, messageTypes []string) (string, []any) { + var conditions []string + var args []any + var exact []string + includeEmail := false + + for _, typ := range messageTypes { + typ = strings.TrimSpace(strings.ToLower(typ)) + if typ == "" { + continue + } + if typ == "email" { + includeEmail = true + continue + } + exact = append(exact, typ) + } + + col := alias + ".message_type" + if includeEmail { + conditions = append(conditions, + fmt.Sprintf("(%s = ? OR %s IS NULL OR %s = '')", col, col, col)) + args = append(args, "email") + } + if len(exact) > 0 { + placeholders := make([]string, len(exact)) + for i, typ := range exact { + placeholders[i] = "?" + args = append(args, typ) + } + conditions = append(conditions, + fmt.Sprintf("%s IN (%s)", col, strings.Join(placeholders, ","))) + } + if len(conditions) == 0 { + return "", nil + } + return "(" + strings.Join(conditions, " OR ") + ")", args +} + // escapeILIKE escapes ILIKE wildcard characters (% and _) in user input. func escapeILIKE(s string) string { s = strings.ReplaceAll(s, "\\", "\\\\") // Escape backslash first @@ -486,6 +525,14 @@ func (e *DuckDBEngine) buildNonTextSearchConditions(q *search.Query, keyColumns var conditions []string var args []any + if len(q.MessageTypes) > 0 { + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } + } + // from: filter - match sender email for _, from := range q.FromAddrs { fromPattern := "%" + escapeILIKE(from) + "%" @@ -578,12 +625,11 @@ func (e *DuckDBEngine) buildNonTextSearchConditions(q *search.Query, keyColumns args = append(args, *q.SmallerThan) } if len(q.MessageTypes) > 0 { - placeholders := make([]string, len(q.MessageTypes)) - for i, typ := range q.MessageTypes { - placeholders[i] = "?" - args = append(args, typ) + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) } - conditions = append(conditions, fmt.Sprintf("msg.message_type IN (%s)", strings.Join(placeholders, ","))) } return conditions, args @@ -882,8 +928,11 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []an } if filter.MessageType != "" { - conditions = append(conditions, "msg.message_type = ?") - args = append(args, filter.MessageType) + condition, conditionArgs := duckDBMessageTypeCondition("msg", []string{filter.MessageType}) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } } // Sender + sender-name filters - check both message_recipients (email) @@ -1160,10 +1209,14 @@ func (e *DuckDBEngine) GetTotalStats(ctx context.Context, opts StatsOptions) (*T var conditions []string var args []any + hasExplicitMessageTypes := false + if opts.SearchQuery != "" { + q := search.Parse(opts.SearchQuery) + hasExplicitMessageTypes = len(q.MessageTypes) > 0 + } // Restrict to email messages only; NULL and '' handle pre-message_type data. - hasSearchMessageTypes := opts.SearchQuery != "" && len(search.Parse(opts.SearchQuery).MessageTypes) > 0 - if !hasSearchMessageTypes { + if !hasExplicitMessageTypes { conditions = append(conditions, emailOnlyFilterMsg) } conditions = append(conditions, store.LiveMessagesWhere("msg", opts.HideDeletedFromSource)) @@ -1638,6 +1691,14 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse } } + if len(q.MessageTypes) > 0 { + condition, conditionArgs := duckDBMessageTypeCondition("m", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } + } + // Has attachment filter if q.HasAttachment != nil && *q.HasAttachment { conditions = append(conditions, "m.has_attachments = 1") @@ -1663,12 +1724,11 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse args = append(args, *q.SmallerThan) } if len(q.MessageTypes) > 0 { - placeholders := make([]string, len(q.MessageTypes)) - for i, typ := range q.MessageTypes { - placeholders[i] = "?" - args = append(args, typ) + condition, conditionArgs := duckDBMessageTypeCondition("m", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) } - conditions = append(conditions, fmt.Sprintf("m.message_type IN (%s)", strings.Join(placeholders, ","))) } // Full-text search: use ILIKE fallback (FTS5 not available via sqlite_scan) @@ -1702,7 +1762,8 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse COALESCE(m.size_estimate, 0), m.has_attachments, m.attachment_count, - m.deleted_from_source_at + m.deleted_from_source_at, + COALESCE(m.message_type, '') FROM sqlite_db.messages m LEFT JOIN sqlite_db.message_recipients mr_sender ON mr_sender.message_id = m.id AND mr_sender.recipient_type = 'from' LEFT JOIN sqlite_db.participants p_sender ON p_sender.id = mr_sender.participant_id @@ -1740,6 +1801,7 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse &msg.HasAttachments, &msg.AttachmentCount, &deletedAt, + &msg.MessageType, ); err != nil { return nil, fmt.Errorf("scan message: %w", err) } @@ -2460,11 +2522,17 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt var args []any // Apply basic filter conditions (ignoring join flags for search - we handle those differently) - if len(q.MessageTypes) == 0 { + conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource)) + if len(q.MessageTypes) > 0 { + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } + } else { // Restrict to email messages only; NULL and '' handle pre-message_type data. conditions = append(conditions, emailOnlyFilterMsg) } - conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource)) conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs) if filter.After != nil { conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)") @@ -2617,12 +2685,11 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt args = append(args, *q.SmallerThan) } if len(q.MessageTypes) > 0 { - placeholders := make([]string, len(q.MessageTypes)) - for i, typ := range q.MessageTypes { - placeholders[i] = "?" - args = append(args, typ) + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) } - conditions = append(conditions, fmt.Sprintf("msg.message_type IN (%s)", strings.Join(placeholders, ","))) } // Account filter diff --git a/internal/query/duckdb_test.go b/internal/query/duckdb_test.go index a10e428c..ef4e41d9 100644 --- a/internal/query/duckdb_test.go +++ b/internal/query/duckdb_test.go @@ -41,6 +41,33 @@ func newSQLiteEngine(t *testing.T) *DuckDBEngine { return engine } +func newMessageTypeParquetEngine(t *testing.T) *DuckDBEngine { + t.Helper() + b := NewTestDataBuilder(t) + b.AddSource("test@example.com") + aliceID := b.AddParticipant("alice@example.com", "example.com", "Alice") + bobID := b.AddParticipant("bob@example.com", "example.com", "Bob") + smsID := b.AddMessage(MessageOpt{ + Subject: "lunch plan", + Snippet: "sushi lunch details", + MessageType: "sms", + SentAt: time.Date(2024, 4, 10, 10, 0, 0, 0, time.UTC), + SizeEstimate: 321, + }) + emailID := b.AddMessage(MessageOpt{ + Subject: "lunch receipt", + Snippet: "email lunch details", + MessageType: "email", + SentAt: time.Date(2024, 4, 11, 10, 0, 0, 0, time.UTC), + SizeEstimate: 999, + }) + b.AddFrom(smsID, aliceID, "Alice") + b.AddTo(smsID, bobID, "Bob") + b.AddFrom(emailID, aliceID, "Alice") + b.AddTo(emailID, bobID, "Bob") + return b.BuildEngine() +} + // searchFast is a test helper that parses a query string and calls SearchFast. func searchFast(t *testing.T, engine *DuckDBEngine, queryStr string, filter MessageFilter) []MessageSummary { t.Helper() @@ -1123,6 +1150,54 @@ func TestDuckDBEngine_GetTotalStats_MessageTypeFilter(t *testing.T) { assert.Equal(int64(2000), stats.TotalSize, "TotalSize") } +func TestDuckDBEngine_SearchFastMessageTypeFilter(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + engine := newMessageTypeParquetEngine(t) + ctx := context.Background() + + filterOnly, err := engine.SearchFast(ctx, search.Parse("message_type:sms"), MessageFilter{}, 100, 0) + require.NoError(err, "SearchFast message_type only") + require.Len(filterOnly, 1, "filter-only message_type search") + assert.Equal("sms", filterOnly[0].MessageType) + assert.Equal("lunch plan", filterOnly[0].Subject) + + withText, err := engine.SearchFast(ctx, search.Parse("message_type:sms lunch"), MessageFilter{}, 100, 0) + require.NoError(err, "SearchFast message_type with text") + require.Len(withText, 1, "message_type should scope text search") + assert.Equal("sms", withText[0].MessageType) + assert.Equal("lunch plan", withText[0].Subject) +} + +func TestDuckDBEngine_GetTotalStatsMessageTypeSearch(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + engine := newMessageTypeParquetEngine(t) + + stats, err := engine.GetTotalStats(context.Background(), StatsOptions{ + SearchQuery: "message_type:sms", + }) + require.NoError(err, "GetTotalStats") + + assert.Equal(int64(1), stats.MessageCount, "message count") + assert.Equal(int64(321), stats.TotalSize, "total size") +} + +func TestDuckDBEngine_AggregateMessageTypeSearch(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + engine := newMessageTypeParquetEngine(t) + opts := DefaultAggregateOptions() + opts.SearchQuery = "message_type:sms" + + rows, err := engine.Aggregate(context.Background(), ViewTime, opts) + require.NoError(err, "Aggregate") + require.Len(rows, 1, "rows") + + assert.Equal(int64(1), rows[0].Count, "count") + assert.Equal(int64(321), rows[0].TotalSize, "total size") +} + // TestDuckDBEngine_ListMessages_DateFilter verifies that After/Before date filters // work with DuckDB's TIMESTAMP column (regression: VARCHAR params need CAST). func TestDuckDBEngine_ListMessages_DateFilter(t *testing.T) { @@ -3266,6 +3341,13 @@ func TestDuckDBEngine_StaleParquetSchema(t *testing.T) { assertpkg.Equal(t, "Stale Hello", results[0].Subject) }) + t.Run("SearchFastMessageTypeEmail", func(t *testing.T) { + q := search.Parse("message_type:email Stale") + results, err := engine.SearchFast(ctx, q, MessageFilter{}, 100, 0) + requirepkg.NoError(t, err, "SearchFast message_type:email with stale Parquet schema") + requirepkg.Len(t, results, 2) + }) + t.Run("SearchFastCount", func(t *testing.T) { q := search.Parse("Stale") count, err := engine.SearchFastCount(ctx, q, MessageFilter{}) @@ -3285,6 +3367,12 @@ func TestDuckDBEngine_StaleParquetSchema(t *testing.T) { assertpkg.Equal(t, int64(2), stats.MessageCount) }) + t.Run("GetTotalStatsMessageTypeEmail", func(t *testing.T) { + stats, err := engine.GetTotalStats(ctx, StatsOptions{SearchQuery: "message_type:email"}) + requirepkg.NoError(t, err, "GetTotalStats message_type:email with stale Parquet schema") + assertpkg.Equal(t, int64(2), stats.MessageCount) + }) + // Verify that optionalCols correctly detected the missing columns. t.Run("ProbeDetectedMissing", func(t *testing.T) { for _, col := range []struct{ table, col string }{ diff --git a/internal/query/sqlite.go b/internal/query/sqlite.go index 88d78b8d..9670df1f 100644 --- a/internal/query/sqlite.go +++ b/internal/query/sqlite.go @@ -1497,6 +1497,15 @@ func (e *SQLiteEngine) buildSearchQueryParts(ctx context.Context, q *search.Quer } } + if len(q.MessageTypes) > 0 { + placeholders := make([]string, len(q.MessageTypes)) + for i, typ := range q.MessageTypes { + placeholders[i] = "?" + args = append(args, strings.ToLower(typ)) + } + conditions = append(conditions, fmt.Sprintf("m.message_type IN (%s)", strings.Join(placeholders, ","))) + } + // Has attachment filter if q.HasAttachment != nil && *q.HasAttachment { conditions = append(conditions, e.dialect.BoolTrueExpr("m.has_attachments")) diff --git a/internal/query/sqlite_crud_test.go b/internal/query/sqlite_crud_test.go index f9d365ff..13e0ea66 100644 --- a/internal/query/sqlite_crud_test.go +++ b/internal/query/sqlite_crud_test.go @@ -1100,6 +1100,31 @@ func TestGetTotalStatsWithSearchQuery_Combined(t *testing.T) { assertpkg.Equal(t, int64(2000), stats.TotalSize, "SearchQuery+WithAttachments total size") } +func TestSearchFastWithStats_MessageTypeStats(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + env := newTestEnv(t) + + smsID := env.AddMessage(dbtest.MessageOpts{ + Subject: "Lunch via SMS", + SentAt: "2024-04-01 10:00:00", + SizeEstimate: 321, + }) + _, err := env.DB.Exec(`UPDATE messages SET message_type = 'sms' WHERE id = ?`, smsID) + require.NoError(err, "set sms message_type") + + q := search.Parse("message_type:sms") + result, err := env.Engine.SearchFastWithStats(env.Ctx, q, "message_type:sms", MessageFilter{}, ViewSenders, 100, 0) + require.NoError(err, "SearchFastWithStats") + require.NotNil(result.Stats, "stats") + + require.Len(result.Messages, 1, "messages") + assert.Equal(smsID, result.Messages[0].ID, "message id") + assert.Equal(int64(1), result.TotalCount, "total count") + assert.Equal(int64(1), result.Stats.MessageCount, "stats message count") + assert.Equal(int64(321), result.Stats.TotalSize, "stats total size") +} + func TestGetMessageRaw(t *testing.T) { env := newTestEnv(t) rawMIME := []byte("From: test@example.com\r\nSubject: Test\r\n\r\nHello") diff --git a/internal/query/sqlite_search_test.go b/internal/query/sqlite_search_test.go index 54c3cf42..f321e857 100644 --- a/internal/query/sqlite_search_test.go +++ b/internal/query/sqlite_search_test.go @@ -10,6 +10,7 @@ import ( assertpkg "github.com/stretchr/testify/assert" requirepkg "github.com/stretchr/testify/require" "go.kenn.io/msgvault/internal/search" + "go.kenn.io/msgvault/internal/testutil/dbtest" "go.kenn.io/msgvault/internal/testutil/ptr" ) @@ -94,20 +95,6 @@ func TestSearch_Filters(t *testing.T) { } } -func TestSearch_MessageTypeFilter(t *testing.T) { - require := requirepkg.New(t) - assert := assertpkg.New(t) - env := newTestEnv(t) - - _, err := env.DB.Exec(`UPDATE messages SET message_type = ? WHERE id = ?`, "sms", int64(2)) - require.NoError(err, "mark message as sms") - - results := env.MustSearch(search.Parse("message_type:sms Hello"), 100, 0) - require.Len(results, 1, "message_type:sms should scope the text search") - assert.Equal(int64(2), results[0].ID, "ID") - assert.Equal("sms", results[0].MessageType, "MessageType") -} - func TestSearch_CaseInsensitiveFallback(t *testing.T) { env := newTestEnv(t) @@ -148,6 +135,41 @@ func TestSearch_WithFTS(t *testing.T) { assertpkg.Equal(t, "Hello World", results[0].Subject) } +func TestSearch_MessageTypeFilter(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + env := newTestEnv(t) + aliceID := env.MustLookupParticipant("alice@example.com") + bobID := env.MustLookupParticipant("bob@company.org") + smsID := env.AddMessage(dbtest.MessageOpts{ + Subject: "lunch plan", + SentAt: "2024-04-10 10:00:00", + FromID: aliceID, + ToIDs: []int64{bobID}, + }) + emailID := env.AddMessage(dbtest.MessageOpts{ + Subject: "lunch receipt", + SentAt: "2024-04-11 10:00:00", + FromID: aliceID, + ToIDs: []int64{bobID}, + }) + _, err := env.DB.Exec(`UPDATE messages SET message_type = 'sms' WHERE id = ?`, smsID) + require.NoError(err, "set sms message_type") + _, err = env.DB.Exec(`UPDATE messages SET message_type = 'email' WHERE id = ?`, emailID) + require.NoError(err, "set email message_type") + + results := env.MustSearch(search.Parse("message_type:sms"), 100, 0) + require.Len(results, 1, "filter-only message_type search") + assert.Equal(smsID, results[0].ID) + assert.Equal("sms", results[0].MessageType) + + env.EnableFTS() + results = env.MustSearch(search.Parse("message_type:sms lunch"), 100, 0) + require.Len(results, 1, "message_type must scope FTS search") + assert.Equal(smsID, results[0].ID) + assert.Equal("sms", results[0].MessageType) +} + // TestSearch_WithFTS_SpecialChars verifies that FTS5 special characters in // search terms don't cause syntax errors. Without quoting, these characters // are interpreted as FTS5 operators (- = NOT, : = column filter, () = grouping). diff --git a/internal/remote/engine_test.go b/internal/remote/engine_test.go index dae7242a..67cfffff 100644 --- a/internal/remote/engine_test.go +++ b/internal/remote/engine_test.go @@ -95,7 +95,6 @@ func TestEngineGetMessageSummariesByIDs_CarriesFromAndAttachmentCount(t *testing assert.Equal(2, s.AttachmentCount, "AttachmentCount") assert.True(s.HasAttachments, "HasAttachments") } - func TestEngineSearchSerializesMessageTypes(t *testing.T) { require := requirepkg.New(t) assert := assertpkg.New(t) @@ -125,3 +124,18 @@ func TestEngineSearchSerializesMessageTypes(t *testing.T) { }, 10, 0) require.NoError(err, "Search") } + +func TestBuildSearchQueryStringIncludesMessageTypes(t *testing.T) { + assert := assertpkg.New(t) + + assert.Equal( + "message_type:sms", + buildSearchQueryString(search.Parse("message_type:sms")), + "filter-only message type query", + ) + assert.Equal( + "lunch message_type:sms", + buildSearchQueryString(search.Parse("message_type:sms lunch")), + "message type with text term", + ) +} diff --git a/internal/scheduler/embed_job.go b/internal/scheduler/embed_job.go index cecea362..bd6f1b1b 100644 --- a/internal/scheduler/embed_job.go +++ b/internal/scheduler/embed_job.go @@ -34,6 +34,10 @@ type EmbedCoverage interface { MissingCount(ctx context.Context, activeGen int64) (int64, error) } +type ScopedEmbedCoverage interface { + MissingCountScoped(ctx context.Context, activeGen int64, messageTypes []string) (int64, error) +} + // Compile-time check that the production worker satisfies EmbedRunner. var _ EmbedRunner = (*embed.Worker)(nil) @@ -82,6 +86,10 @@ type EmbedJob struct { // A negative value disables the auto-backstop entirely. BackstopInterval time.Duration + // BuildScope limits coverage checks to the same message universe the + // worker scans for this generation. Empty means the full live corpus. + BuildScope vector.BuildScope + // Now returns the current time; overridable in tests to drive the // backstop interval deterministically. nil uses time.Now. Now func() time.Time @@ -182,7 +190,7 @@ func (j *EmbedJob) Run(ctx context.Context) { "gen", target) return } - missing, err := j.Store.MissingCount(ctx, int64(target)) + missing, err := j.missingCount(ctx, target) if err != nil { log.Warn("embed: coverage count after run failed", "gen", target, "error", err) return @@ -201,6 +209,17 @@ func (j *EmbedJob) Run(ctx context.Context) { log.Info("embed: building generation activated", "gen", target) } +func (j *EmbedJob) missingCount(ctx context.Context, target vector.GenerationID) (int64, error) { + scope := vector.NewBuildScope(j.BuildScope.MessageTypes) + if scope.IsEmpty() { + return j.Store.MissingCount(ctx, int64(target)) + } + if scoped, ok := j.Store.(ScopedEmbedCoverage); ok { + return scoped.MissingCountScoped(ctx, int64(target), scope.MessageTypes) + } + return 0, errors.New("embed coverage store does not support scoped missing counts") +} + // maybeRunBackstop runs a full watermark-ignoring backstop pass on gen when // BackstopInterval has elapsed since this generation's last one, then records // the time. The throttle is keyed per generation so a recent backstop of one diff --git a/internal/store/api_search_test.go b/internal/store/api_search_test.go index 32bea8bc..d3808468 100644 --- a/internal/store/api_search_test.go +++ b/internal/store/api_search_test.go @@ -148,7 +148,6 @@ func TestSearchMessages_LegacyRawString(t *testing.T) { }) } } - func TestSearchMessagesQuery_MessageTypeFilter(t *testing.T) { require := requirepkg.New(t) assert := assertpkg.New(t) diff --git a/internal/store/embed_gen.go b/internal/store/embed_gen.go index e7517283..3ba69995 100644 --- a/internal/store/embed_gen.go +++ b/internal/store/embed_gen.go @@ -33,16 +33,27 @@ var embedGenStampChunkRows = 500 // the embeddings upsert — the worker orders the steps (upsert, then // stamp) and relies on idempotency, see internal/vector/embed/worker.go. func (s *Store) ScanForEmbedding(ctx context.Context, target int64, afterID int64, limit int) ([]int64, error) { + return s.ScanForEmbeddingScoped(ctx, target, afterID, limit, nil) +} + +// ScanForEmbeddingScoped is ScanForEmbedding limited to the supplied message +// types. An empty messageTypes slice means the full live corpus. +func (s *Store) ScanForEmbeddingScoped(ctx context.Context, target int64, afterID int64, limit int, messageTypes []string) ([]int64, error) { if limit <= 0 { return nil, nil } + liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes) q := `SELECT id FROM messages WHERE (embed_gen IS NULL OR embed_gen <> ?) - AND ` + LiveMessagesWhere("", true) + ` + AND ` + liveWhere + ` AND id > ? ORDER BY id LIMIT ?` - rows, err := s.db.QueryContext(ctx, q, target, afterID, limit) + args := make([]any, 0, 3+len(liveArgs)) + args = append(args, target) + args = append(args, liveArgs...) + args = append(args, afterID, limit) + rows, err := s.db.QueryContext(ctx, q, args...) if err != nil { return nil, fmt.Errorf("scan for embedding: %w", err) } @@ -239,14 +250,22 @@ func (s *Store) ResetEmbedGen(ctx context.Context, ids []int64) error { // activeGen == 0 means "no active/target generation"; then everything // live is missing and stamped is 0. func (s *Store) CoverageCounts(ctx context.Context, activeGen int64) (live, stamped, blank, missing int64, err error) { - live, err = s.countLiveMessages(ctx) + return s.CoverageCountsScoped(ctx, activeGen, nil) +} + +// CoverageCountsScoped is CoverageCounts limited to the supplied message types. +// An empty messageTypes slice means the full live corpus. +func (s *Store) CoverageCountsScoped(ctx context.Context, activeGen int64, messageTypes []string) (live, stamped, blank, missing int64, err error) { + live, err = s.countLiveMessagesScoped(ctx, messageTypes) if err != nil { return 0, 0, 0, 0, err } if activeGen != 0 { + liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes) q := `SELECT COUNT(*) FROM messages - WHERE embed_gen = ? AND ` + LiveMessagesWhere("", true) - if err := s.db.QueryRowContext(ctx, q, activeGen).Scan(&stamped); err != nil { + WHERE embed_gen = ? AND ` + liveWhere + args := append([]any{activeGen}, liveArgs...) + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&stamped); err != nil { return 0, 0, 0, 0, fmt.Errorf("count stamped: %w", err) } } @@ -259,7 +278,13 @@ func (s *Store) CoverageCounts(ctx context.Context, activeGen int64) (live, stam // activeGen). It is a thin accessor for the scheduler/CLI activation // gates, which only consult the missing count; missing = live - stamped. func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error) { - live, err := s.countLiveMessages(ctx) + return s.MissingCountScoped(ctx, activeGen, nil) +} + +// MissingCountScoped is MissingCount limited to the supplied message types. +// An empty messageTypes slice means the full live corpus. +func (s *Store) MissingCountScoped(ctx context.Context, activeGen int64, messageTypes []string) (int64, error) { + live, err := s.countLiveMessagesScoped(ctx, messageTypes) if err != nil { return 0, err } @@ -267,9 +292,11 @@ func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error return live, nil } var stamped int64 + liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes) q := `SELECT COUNT(*) FROM messages - WHERE embed_gen = ? AND ` + LiveMessagesWhere("", true) - if err := s.db.QueryRowContext(ctx, q, activeGen).Scan(&stamped); err != nil { + WHERE embed_gen = ? AND ` + liveWhere + args := append([]any{activeGen}, liveArgs...) + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&stamped); err != nil { return 0, fmt.Errorf("count stamped: %w", err) } return max(live-stamped, 0), nil @@ -277,11 +304,48 @@ func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error // countLiveMessages returns the total live-message count. Shared by // CoverageCounts; kept separate so the live-predicate stays in one place. -func (s *Store) countLiveMessages(ctx context.Context) (int64, error) { +func (s *Store) countLiveMessagesScoped(ctx context.Context, messageTypes []string) (int64, error) { var n int64 - q := `SELECT COUNT(*) FROM messages WHERE ` + LiveMessagesWhere("", true) - if err := s.db.QueryRowContext(ctx, q).Scan(&n); err != nil { + liveWhere, args := liveMessagesWhereWithMessageTypes(messageTypes) + q := `SELECT COUNT(*) FROM messages WHERE ` + liveWhere + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&n); err != nil { return 0, fmt.Errorf("count live messages: %w", err) } return n, nil } + +func liveMessagesWhereWithMessageTypes(messageTypes []string) (string, []any) { + where := LiveMessagesWhere("", true) + types := normalizeMessageTypes(messageTypes) + if len(types) == 0 { + return where, nil + } + placeholders := make([]string, len(types)) + args := make([]any, len(types)) + for i, typ := range types { + placeholders[i] = "?" + args[i] = typ + } + where += " AND message_type IN (" + strings.Join(placeholders, ",") + ")" + return where, args +} + +func normalizeMessageTypes(messageTypes []string) []string { + if len(messageTypes) == 0 { + return nil + } + seen := make(map[string]struct{}, len(messageTypes)) + out := make([]string, 0, len(messageTypes)) + for _, typ := range messageTypes { + typ = strings.TrimSpace(strings.ToLower(typ)) + if typ == "" { + continue + } + if _, ok := seen[typ]; ok { + continue + } + seen[typ] = struct{}{} + out = append(out, typ) + } + return out +} diff --git a/internal/synctechsms/importer.go b/internal/synctechsms/importer.go index c16141e7..ec6d5e24 100644 --- a/internal/synctechsms/importer.go +++ b/internal/synctechsms/importer.go @@ -98,39 +98,47 @@ func (i *Importer) importRecord(sourceID int64, record Record, summary *ImportSu switch record.Kind { case RecordSMS: if i.opts.IncludeSMS && record.SMS != nil { - if err := i.importSMS(sourceID, *record.SMS); err != nil { + msgID, err := i.importSMS(sourceID, *record.SMS) + if err != nil { return err } + summary.MessageIDs = append(summary.MessageIDs, msgID) summary.SMSImported++ } case RecordMMS: if i.opts.IncludeMMS && record.MMS != nil { - attachments, err := i.importMMS(sourceID, *record.MMS) + msgID, attachments, err := i.importMMS(sourceID, *record.MMS) if err != nil { + if msgID != 0 { + summary.MessageIDs = append(summary.MessageIDs, msgID) + } return err } + summary.MessageIDs = append(summary.MessageIDs, msgID) summary.MMSImported++ summary.AttachmentsImported += attachments } case RecordCall: if i.opts.IncludeCalls && record.Call != nil { - if err := i.importCall(sourceID, *record.Call); err != nil { + msgID, err := i.importCall(sourceID, *record.Call) + if err != nil { return err } + summary.MessageIDs = append(summary.MessageIDs, msgID) summary.CallsImported++ } } return nil } -func (i *Importer) importSMS(sourceID int64, sms SMS) error { +func (i *Importer) importSMS(sourceID int64, sms SMS) (int64, error) { remoteID, err := i.participantID(sms.Address, sms.ContactName.String) if err != nil { - return err + return 0, err } ownerID, err := i.participantID(i.opts.OwnerPhone, "Me") if err != nil { - return err + return 0, err } // Drafts are owner-authored messages that never made it out, but // they still belong on the owner's side of the conversation. Without @@ -144,36 +152,36 @@ func (i *Importer) importSMS(sourceID int64, sms SMS) error { } convID, err := i.ensureConversation(sourceID, textConversationKey([]int64{ownerID, remoteID}), sms.ContactName.String) if err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, remoteID, "member"); err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, ownerID, "member"); err != nil { - return err + return 0, err } msgID := stableID("sms", sms.Address, sms.Timestamp.String(), fmt.Sprint(sms.Type), sms.Body) return i.upsertTextMessage(sourceID, convID, msgID, "sms", senderID, recipientIDs, fromMe, sms.Timestamp, sms.Body, sms.Body, 0, sms) } -func (i *Importer) importMMS(sourceID int64, mms MMS) (int, error) { +func (i *Importer) importMMS(sourceID int64, mms MMS) (int64, int, error) { ownerID, err := i.participantID(i.opts.OwnerPhone, "Me") if err != nil { - return 0, err + return 0, 0, err } participantIDs, senderID, recipientIDs, err := i.mmsParticipants(mms, ownerID) if err != nil { - return 0, err + return 0, 0, err } // Drafts belong to the owner — see the matching note in importSMS. fromMe := mms.MessageBox == MMSBoxSent || mms.MessageBox == MMSBoxOutbox || mms.MessageBox == MMSBoxDraft convID, err := i.ensureConversation(sourceID, textConversationKey(participantIDs), mms.ContactName.String) if err != nil { - return 0, err + return 0, 0, err } for _, participantID := range participantIDs { if err := i.store.EnsureConversationParticipant(convID, participantID, "member"); err != nil { - return 0, err + return 0, 0, err } } body := mmsText(mms) @@ -183,21 +191,26 @@ func (i *Importer) importMMS(sourceID int64, mms MMS) (int, error) { } msgID := stableID("mms", srcIDPart, mms.Timestamp.String(), sortedKey(participantIDs)) attachmentCount := countImportableAttachments(mms, i.opts.IncludeAttachments) - if err := i.upsertTextMessage(sourceID, convID, msgID, "mms", senderID, recipientIDs, fromMe, mms.Timestamp, body, mms.Subject.String, attachmentCount, mms); err != nil { - return 0, err + messageID, err := i.upsertTextMessage(sourceID, convID, msgID, "mms", senderID, recipientIDs, fromMe, mms.Timestamp, body, mms.Subject.String, attachmentCount, mms) + if err != nil { + return 0, 0, err } - return i.importMMSAttachments(sourceID, msgID, mms) + imported, err := i.importMMSAttachments(sourceID, msgID, mms) + if err != nil { + return messageID, imported, err + } + return messageID, imported, nil } -func (i *Importer) importCall(sourceID int64, call Call) error { +func (i *Importer) importCall(sourceID int64, call Call) (int64, error) { remoteAddress := callParticipantAddress(call) remoteID, err := i.participantID(remoteAddress, call.ContactName.String) if err != nil { - return err + return 0, err } ownerID, err := i.participantID(i.opts.OwnerPhone, "Me") if err != nil { - return err + return 0, err } fromMe := call.Type == CallOutgoing senderID := remoteID @@ -208,20 +221,20 @@ func (i *Importer) importCall(sourceID int64, call Call) error { } convID, err := i.ensureConversation(sourceID, "calls:"+canonicalAddress(remoteAddress), call.ContactName.String) if err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, remoteID, "member"); err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, ownerID, "member"); err != nil { - return err + return 0, err } body := fmt.Sprintf("Call %s, %d seconds", callTypeLabel(call.Type), call.DurationSeconds) msgID := stableID("call", remoteAddress, call.Timestamp.String(), fmt.Sprint(call.Type), strconv.Itoa(call.DurationSeconds)) return i.upsertTextMessage(sourceID, convID, msgID, "synctech_sms_call", senderID, recipientIDs, fromMe, call.Timestamp, body, body, 0, call) } -func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, messageType string, senderID int64, recipientIDs []int64, fromMe bool, sentAt time.Time, body, subject string, attachmentCount int, raw any) error { +func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, messageType string, senderID int64, recipientIDs []int64, fromMe bool, sentAt time.Time, body, subject string, attachmentCount int, raw any) (int64, error) { msgID, err := i.store.UpsertMessage(&store.Message{ ConversationID: convID, SourceID: sourceID, @@ -237,33 +250,33 @@ func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, me AttachmentCount: attachmentCount, }) if err != nil { - return fmt.Errorf("upsert message: %w", err) + return 0, fmt.Errorf("upsert message: %w", err) } if body != "" { if err := i.store.UpsertMessageBody(msgID, sql.NullString{String: body, Valid: true}, sql.NullString{}); err != nil { - return fmt.Errorf("upsert body: %w", err) + return 0, fmt.Errorf("upsert body: %w", err) } } rawJSON, err := json.Marshal(raw) if err != nil { - return fmt.Errorf("marshal raw record: %w", err) + return 0, fmt.Errorf("marshal raw record: %w", err) } if err := i.store.UpsertMessageRawWithFormat(msgID, rawJSON, RawFormat); err != nil { - return fmt.Errorf("upsert raw record: %w", err) + return 0, fmt.Errorf("upsert raw record: %w", err) } if err := i.store.ReplaceMessageRecipients(msgID, "from", []int64{senderID}, []string{""}); err != nil { - return fmt.Errorf("replace from recipient: %w", err) + return 0, fmt.Errorf("replace from recipient: %w", err) } if err := i.store.ReplaceMessageRecipients(msgID, "to", recipientIDs, blankNames(len(recipientIDs))); err != nil { - return fmt.Errorf("replace to recipient: %w", err) + return 0, fmt.Errorf("replace to recipient: %w", err) } if err := i.store.UpsertFTS(msgID, subject, body, "", "", ""); err != nil { - // FTS is an index, not data: a failure to populate it must never abort - // the import. Warn and continue, matching the other UpsertFTS callers - // (sync.go, importer/ingest.go, fbmessenger, whatsapp). [C2] - slog.Warn("failed to upsert FTS", "message", msgID, "error", err) + slog.Warn("failed to update Synctech message FTS index", + "message_id", msgID, + "message_type", messageType, + "error", err) } - return nil + return msgID, nil } func (i *Importer) participantID(address, displayName string) (int64, error) { diff --git a/internal/synctechsms/importer_test.go b/internal/synctechsms/importer_test.go index 769ebefe..ac6a1f1d 100644 --- a/internal/synctechsms/importer_test.go +++ b/internal/synctechsms/importer_test.go @@ -7,6 +7,7 @@ import ( assertpkg "github.com/stretchr/testify/assert" requirepkg "github.com/stretchr/testify/require" "go.kenn.io/msgvault/internal/store" + "go.kenn.io/msgvault/internal/testutil" "go.kenn.io/msgvault/internal/testutil/storetest" ) @@ -84,6 +85,60 @@ func TestImporterImportsCallWithBlankNumber(t *testing.T) { assertMessageCount(t, f.Store, 1) } +func TestImporterContinuesWhenFTSUpsertFails(t *testing.T) { + testutil.SkipIfPostgres(t, "drops the SQLite FTS virtual table to force UpsertFTS failure") + f := storetest.New(t) + _, err := f.Store.DB().Exec(`DROP TABLE messages_fts`) + requirepkg.NoError(t, err, "drop messages_fts") + dir := t.TempDir() + writeFile(t, filepath.Join(dir, "messages.xml"), ` + +`) + + imp := NewImporter(f.Store, ImportOptions{ + OwnerPhone: "+15550000001", + IncludeSMS: true, + }) + summary, err := imp.ImportPath(dir) + requirepkg.NoError(t, err, "ImportPath should tolerate FTS indexing failure") + assertpkg.Equal(t, 1, summary.SMSImported, "summary = %#v", summary) + assertMessageCount(t, f.Store, 1) +} + +func TestImporterReturnsMMSMessageIDWhenAttachmentImportFails(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + f := storetest.New(t) + dir := t.TempDir() + writeFile(t, filepath.Join(dir, "messages.xml"), ` + + + + + + + + + + +`) + + imp := NewImporter(f.Store, ImportOptions{ + OwnerPhone: "+15550000001", + AttachmentsDir: filepath.Join(dir, "attachments"), + MaxAttachmentBytes: 1, + IncludeMMS: true, + IncludeAttachments: true, + }) + summary, err := imp.ImportPath(dir) + require.Error(err, "ImportPath") + assert.Contains(err.Error(), "exceeds maximum size", "ImportPath error") + + assertMessageCount(t, f.Store, 1) + require.Len(summary.MessageIDs, 1, "summary message IDs") + assert.Positive(summary.MessageIDs[0], "summary message id") +} + // TestImporterMarksDraftsAsFromMe guards against regressing the // draft-handling fix. SMSTypeDraft and MMSBoxDraft are owner-authored // messages; treating them as incoming hides them on the wrong side of diff --git a/internal/synctechsms/types.go b/internal/synctechsms/types.go index c2d5d8eb..1ac42744 100644 --- a/internal/synctechsms/types.go +++ b/internal/synctechsms/types.go @@ -160,4 +160,5 @@ type ImportSummary struct { MMSImported int CallsImported int AttachmentsImported int + MessageIDs []int64 } diff --git a/internal/vector/backend.go b/internal/vector/backend.go index 97ec05c8..2d8a2cb4 100644 --- a/internal/vector/backend.go +++ b/internal/vector/backend.go @@ -181,18 +181,18 @@ type Backend interface { Delete(ctx context.Context, gen GenerationID, messageIDs []int64) error Stats(ctx context.Context, gen GenerationID) (Stats, error) - // EmbeddedMessageCount reports how many distinct LIVE, stamped + // EmbeddedMessageCount reports how many distinct in-scope LIVE, stamped // (embed_gen == gen) messages actually have at least one embedding row // for gen. This is the "embedded" leg of the coverage readout (live / // embedded / blank / missing). It lives on the backend because the // embeddings table is in vectors.db on SQLite (and the main DB on PG); - // only the backend holds that handle. The live+stamped intersection is - // REQUIRED for the coverage invariant to hold: SQLite intersects the - // vectors.db embedding ids against a live+stamped query on the main DB - // (cross-DB json_each, mirroring dropDeletedFromSource), while PostgreSQL - // uses a single JOIN to messages. Distinct from Stats.EmbeddingCount only - // in intent: this is the dedicated coverage helper and never folds the - // aggregate (gen == 0) path. + // only the backend holds that handle. The live+stamped+scope intersection + // is REQUIRED for the coverage invariant to hold: SQLite intersects the + // vectors.db embedding ids against an in-scope live+stamped query on the + // main DB (cross-DB json_each, mirroring dropDeletedFromSource), while + // PostgreSQL uses a single JOIN to messages. Distinct from + // Stats.EmbeddingCount only in intent: this is the dedicated coverage + // helper and never folds the aggregate (gen == 0) path. EmbeddedMessageCount(ctx context.Context, gen GenerationID) (int64, error) // LoadVector returns the embedding for a specific message in the diff --git a/internal/vector/build_scope.go b/internal/vector/build_scope.go new file mode 100644 index 00000000..6b53cdbb --- /dev/null +++ b/internal/vector/build_scope.go @@ -0,0 +1,65 @@ +package vector + +import ( + "slices" + "sort" + "strings" +) + +// BuildScope limits which messages are eligible for an embedding +// generation. A zero-value scope means the full corpus. +type BuildScope struct { + MessageTypes []string +} + +// NewBuildScope returns a normalized, stable scope. Message types are +// lowercase, trimmed, de-duplicated, and sorted so fingerprints and SQL +// bindings are deterministic. +func NewBuildScope(messageTypes []string) BuildScope { + seen := make(map[string]struct{}, len(messageTypes)) + out := make([]string, 0, len(messageTypes)) + for _, typ := range messageTypes { + typ = strings.TrimSpace(strings.ToLower(typ)) + if typ == "" { + continue + } + if _, ok := seen[typ]; ok { + continue + } + seen[typ] = struct{}{} + out = append(out, typ) + } + sort.Strings(out) + return BuildScope{MessageTypes: out} +} + +func (s BuildScope) IsEmpty() bool { + return len(s.MessageTypes) == 0 +} + +func (s BuildScope) Fingerprint() string { + if s.IsEmpty() { + return "" + } + return "mt-" + strings.Join(s.MessageTypes, ",") +} + +func (s BuildScope) ContainsMessageType(messageType string) bool { + messageType = strings.TrimSpace(strings.ToLower(messageType)) + return slices.Contains(s.MessageTypes, messageType) +} + +func (s BuildScope) AllowsMessageTypes(messageTypes []string) bool { + if s.IsEmpty() { + return true + } + if len(messageTypes) == 0 { + return false + } + for _, typ := range messageTypes { + if !s.ContainsMessageType(typ) { + return false + } + } + return true +} diff --git a/internal/vector/config.go b/internal/vector/config.go index 5dd3ebd8..23ca8448 100644 --- a/internal/vector/config.go +++ b/internal/vector/config.go @@ -213,7 +213,8 @@ type EmbedConfig struct { // stragglers from repair-encoding resets, transient errors, or crashes) // in addition to the per-tick incremental scan. Zero uses the EmbedJob // default (24h); a negative value disables the auto-backstop. - BackstopInterval time.Duration `toml:"backstop_interval"` + BackstopInterval time.Duration `toml:"backstop_interval"` + Scope EmbedScopeConfig `toml:"scope"` } // EmbedScheduleConfig controls when the embed worker runs on its own @@ -223,6 +224,16 @@ type EmbedScheduleConfig struct { RunAfterSync bool `toml:"run_after_sync"` // trigger after each successful sync } +// EmbedScopeConfig limits which messages are included in newly-built +// embedding generations. The zero value means the full corpus. +type EmbedScopeConfig struct { + MessageTypes []string `toml:"message_types"` +} + +func (s EmbedScopeConfig) BuildScope() BuildScope { + return NewBuildScope(s.MessageTypes) +} + // Fingerprint returns the ":" identifier for the // embedding endpoint half of the policy (§6.7 of the spec). Use // Config.GenerationFingerprint when storing or comparing index @@ -256,8 +267,12 @@ func (e EmbeddingsConfig) Fingerprint() string { // in, an active generation built under the old single-vector policy // would silently accept new chunked entries from an upgraded worker. func (c *Config) GenerationFingerprint() string { - return fmt.Sprintf("%s:%s:c%d:e%d", + fp := fmt.Sprintf("%s:%s:c%d:e%d", c.Embeddings.Fingerprint(), c.Preprocess.Fingerprint(), c.Embeddings.MaxInputChars, embedPolicyVersion) + if scopeFP := c.Embed.Scope.BuildScope().Fingerprint(); scopeFP != "" { + fp = fmt.Sprintf("%s:s%s", fp, scopeFP) + } + return fp } // Validate returns a descriptive error if the config is unusable. diff --git a/internal/vector/config_test.go b/internal/vector/config_test.go index cba34145..dd76eee4 100644 --- a/internal/vector/config_test.go +++ b/internal/vector/config_test.go @@ -433,3 +433,18 @@ func TestConfig_GenerationFingerprint_IncludesEmbedPolicyVersion(t *testing.T) { suffix := fmt.Sprintf(":e%d", embedPolicyVersion) assertpkg.True(t, strings.HasSuffix(got, suffix), "GenerationFingerprint() = %q, want suffix %q", got, suffix) } + +func TestConfig_GenerationFingerprint_IncludesEmbedScope(t *testing.T) { + base := Config{ + Embeddings: EmbeddingsConfig{Model: "m", Dimension: 8, MaxInputChars: 6000}, + } + baseline := base.GenerationFingerprint() + + scoped := base + scoped.Embed.Scope.MessageTypes = []string{"MMS", "sms", "sms"} + + assertpkg.NotEqual(t, baseline, scoped.GenerationFingerprint(), + "GenerationFingerprint should change when embedding is scoped") + assertpkg.Contains(t, scoped.GenerationFingerprint(), ":smt-mms,sms", + "scope fingerprint should normalize and sort message types") +} diff --git a/internal/vector/embed/worker.go b/internal/vector/embed/worker.go index 7bb0943e..e1a180d0 100644 --- a/internal/vector/embed/worker.go +++ b/internal/vector/embed/worker.go @@ -64,6 +64,7 @@ type WorkerDeps struct { Store WorkStore Client EmbeddingClient Preprocess PreprocessConfig + BuildScope vector.BuildScope MaxInputChars int BatchSize int // beforeSkipStamp is a test hook for read-to-stamp race coverage. @@ -309,7 +310,7 @@ func (w *Worker) run(ctx context.Context, gen vector.GenerationID, backstop bool return res, fmt.Errorf("RunOnce: %w", err) } batchStart := time.Now() - ids, err := w.deps.Store.ScanForEmbedding(ctx, int64(gen), afterID, w.deps.BatchSize) + ids, err := w.scanForEmbedding(ctx, int64(gen), afterID) if err != nil { return res, fmt.Errorf("scan for embedding: %w", err) } @@ -552,6 +553,20 @@ func (w *Worker) run(ctx context.Context, gen vector.GenerationID, backstop bool } } +func (w *Worker) scanForEmbedding(ctx context.Context, gen int64, afterID int64) ([]int64, error) { + scope := vector.NewBuildScope(w.deps.BuildScope.MessageTypes) + if scope.IsEmpty() { + return w.deps.Store.ScanForEmbedding(ctx, gen, afterID, w.deps.BatchSize) + } + scoped, ok := w.deps.Store.(interface { + ScanForEmbeddingScoped(ctx context.Context, target int64, afterID int64, limit int, messageTypes []string) ([]int64, error) + }) + if !ok { + return nil, errors.New("work store does not support scoped embedding scans") + } + return scoped.ScanForEmbeddingScoped(ctx, gen, afterID, w.deps.BatchSize, scope.MessageTypes) +} + // advanceWatermark persists the per-gen forward-scan cursor to id after a // batch made forward progress. The backstop never persists (it scans from // 0 by design and must not push the optimistic watermark backward or diff --git a/internal/vector/errors.go b/internal/vector/errors.go index b4d5a5ec..03a7ba3e 100644 --- a/internal/vector/errors.go +++ b/internal/vector/errors.go @@ -64,4 +64,10 @@ var ( // "transient backend slow" response so clients can retry instead // of treating it as a permanent failure. ErrEmbeddingTimeout = errors.New("embedding request timed out") + + // ErrIndexScopeMismatch is returned when a scoped embedding index + // is used without an equivalent structured filter. For example, an + // index built only for message_type=sms must not answer an unscoped + // vector query over email + SMS. + ErrIndexScopeMismatch = errors.New("index scope mismatch") ) diff --git a/internal/vector/hybrid/engine.go b/internal/vector/hybrid/engine.go index f264c1ca..412f0a46 100644 --- a/internal/vector/hybrid/engine.go +++ b/internal/vector/hybrid/engine.go @@ -63,7 +63,8 @@ type Config struct { // participant/label lookup SQL that BuildFilter runs against mainDB. // Pass PostgreSQLDialect.Rebind on PG (pgx rejects bare ?); leave nil // (or SQLiteDialect.Rebind, which is identity) on SQLite. - Rebind func(string) string + Rebind func(string) string + BuildScope vector.BuildScope } // Engine orchestrates the generation check, query embedding, and fusion @@ -111,6 +112,9 @@ func (e *Engine) Search(ctx context.Context, req SearchRequest) ([]vector.FusedH if err != nil { return nil, ResultMeta{}, err } + if err := e.validateBuildScope(req.Filter); err != nil { + return nil, ResultMeta{}, err + } if req.FreeText == "" { return nil, ResultMeta{}, errors.New("empty query") @@ -192,6 +196,24 @@ func (e *Engine) Search(ctx context.Context, req SearchRequest) ([]vector.FusedH }, nil } +func (e *Engine) validateBuildScope(filter vector.Filter) error { + scope := vector.NewBuildScope(e.cfg.BuildScope.MessageTypes) + if scope.IsEmpty() { + return nil + } + if len(filter.MessageTypes) == 0 { + return fmt.Errorf("%w: index is scoped to message_type=%s; add a matching message_type filter", + vector.ErrIndexScopeMismatch, strings.Join(scope.MessageTypes, ",")) + } + if !scope.AllowsMessageTypes(filter.MessageTypes) { + return fmt.Errorf("%w: index is scoped to message_type=%s, query requested message_type=%s", + vector.ErrIndexScopeMismatch, + strings.Join(scope.MessageTypes, ","), + strings.Join(vector.NewBuildScope(filter.MessageTypes).MessageTypes, ",")) + } + return nil +} + // vectorHitsToFused wraps pure-vector hits in the FusedHit schema. // BM25Score and RRFScore are both set to math.NaN(): "not present in // this signal." Pure vector mode never applies Reciprocal Rank Fusion diff --git a/internal/vector/hybrid/engine_test.go b/internal/vector/hybrid/engine_test.go index 3eaef6da..8e424c18 100644 --- a/internal/vector/hybrid/engine_test.go +++ b/internal/vector/hybrid/engine_test.go @@ -165,6 +165,35 @@ func TestEngine_Hybrid_HappyPath(t *testing.T) { assert.Equal(len(results), meta.ReturnedCount) } +func TestEngine_ScopedIndexRequiresMatchingMessageTypeFilter(t *testing.T) { + ctx := context.Background() + f := newEngineFixture(t) + f.Engine.cfg.BuildScope = vector.NewBuildScope([]string{"sms", "mms"}) + + _, _, err := f.Engine.Search(ctx, SearchRequest{ + Mode: ModeVector, + FreeText: "lunch", + Limit: 5, + }) + requirepkg.ErrorIs(t, err, vector.ErrIndexScopeMismatch) + + _, _, err = f.Engine.Search(ctx, SearchRequest{ + Mode: ModeVector, + FreeText: "lunch", + Limit: 5, + Filter: vector.Filter{MessageTypes: []string{"email"}}, + }) + requirepkg.ErrorIs(t, err, vector.ErrIndexScopeMismatch) + + _, _, err = f.Engine.Search(ctx, SearchRequest{ + Mode: ModeVector, + FreeText: "lunch", + Limit: 5, + Filter: vector.Filter{MessageTypes: []string{"sms"}}, + }) + requirepkg.NoError(t, err) +} + // TestFTSTerms covers the FreeText → dialect-neutral term-slice // tokenizer directly (no DB needed). FreeText is split on whitespace // and terms the FTS5/tsquery tokenizers would drop entirely diff --git a/internal/vector/pgvector/backend.go b/internal/vector/pgvector/backend.go index fb0782e6..be9a9a23 100644 --- a/internal/vector/pgvector/backend.go +++ b/internal/vector/pgvector/backend.go @@ -39,6 +39,9 @@ type Options struct { // per-dimension HNSW index on first migration. Optional; if zero // the index is created on first CreateGeneration. Dimension int + // BuildScope limits generation coverage to matching messages. Empty + // means the full corpus. + BuildScope vector.BuildScope // SkipMigrate suppresses the privileged CREATE EXTENSION + full // migrate. A WRITABLE open still applies the (extension-less) schema so // the one-time upgrade lands — read-only-ness is now signalled by @@ -74,7 +77,8 @@ type Options struct { // with the pgvector extension. The same *sql.DB also serves the main // msgvault schema (messages, message_recipients, message_labels). type Backend struct { - db *sql.DB + db *sql.DB + scope vector.BuildScope } // Open verifies the database is reachable, applies the embedding schema @@ -85,7 +89,10 @@ func Open(ctx context.Context, opts Options) (*Backend, error) { if opts.DB == nil { return nil, errors.New("pgvector.Open: Options.DB is required") } - b := &Backend{db: opts.DB} + b := &Backend{ + db: opts.DB, + scope: vector.NewBuildScope(opts.BuildScope.MessageTypes), + } if !opts.SkipMigrate { // serve / build / search: full migrate incl. CREATE EXTENSION (the // extension step is gated by SkipExtension for managed PG). The eager @@ -252,12 +259,22 @@ func isUniqueViolation(err error) bool { // (embed_gen IS NULL OR embed_gen <> gen). Built once and reused by // ActivateGeneration (in-tx, single-DB on PG) and Stats. The $N ordinal // of the generation id is supplied by the caller. -func missingForGenExistsClause(genArg string) string { +func (b *Backend) missingForGenExistsClause(genArg string, firstScopeArg int) (string, []any) { + where := store.LiveMessagesWhere("", true) + args := make([]any, 0, len(b.scope.MessageTypes)) + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "$" + strconv.Itoa(firstScopeArg+i) + args = append(args, typ) + } + where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ",")) + } return fmt.Sprintf(`EXISTS ( SELECT 1 FROM messages WHERE (embed_gen IS NULL OR embed_gen <> %s) AND %s - )`, genArg, store.LiveMessagesWhere("", true)) + )`, genArg, where), args } // ActivateGeneration atomically retires the current active generation @@ -321,18 +338,20 @@ func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationI // phase, which scan-and-fill no longer has, so a legacy/crashed gen with // seeded_at=NULL but full coverage must be activatable. Coverage // (missing==0) is the real gate. + missingClause, missingArgs := b.missingForGenExistsClause("$3", 5) + args := append([]any{now, now, int64(gen), force}, missingArgs...) res, err := tx.ExecContext(ctx, `UPDATE index_generations SET state = 'active', activated_at = $1, completed_at = COALESCE(completed_at, $2) WHERE id = $3 AND state = 'building' - AND ($4 OR NOT `+missingForGenExistsClause("$3")+`)`, - now, now, int64(gen), force) + AND ($4 OR NOT `+missingClause+`)`, + args...) if err != nil { return fmt.Errorf("activate: %w", err) } n, _ := res.RowsAffected() if n == 0 { - return activateGateError(ctx, tx, gen, force) + return b.activateGateError(ctx, tx, gen, force) } if err := tx.Commit(); err != nil { return fmt.Errorf("commit activate generation %d: %w", gen, err) @@ -346,7 +365,7 @@ func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationI // also satisfies the coverage predicate (embed_gen <> gen is true for an // unknown gen id), so checking coverage first would surface the misleading // "messages needing embedding" error instead of the real lifecycle reason. -func activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, force bool) error { +func (b *Backend) activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, force bool) error { var state vector.GenerationState if err := tx.QueryRowContext(ctx, `SELECT state FROM index_generations WHERE id = $1`, int64(gen)).Scan(&state); err != nil { @@ -361,8 +380,10 @@ func activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, // Gen exists and is building, so the only remaining reason the gated // promote affected zero rows is the coverage term. var missing bool + missingClause, missingArgs := b.missingForGenExistsClause("$1", 2) + args := append([]any{int64(gen)}, missingArgs...) if err := tx.QueryRowContext(ctx, - `SELECT `+missingForGenExistsClause("$1"), int64(gen)).Scan(&missing); err != nil { + `SELECT `+missingClause, args...).Scan(&missing); err != nil { return fmt.Errorf("check coverage for generation %d: %w", gen, err) } if missing && !force { @@ -1206,9 +1227,9 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St return s, nil } -// EmbeddedMessageCount returns the number of LIVE messages that are -// stamped for gen (embed_gen = gen) AND actually have at least one vector -// for the generation. Used by the coverage readout to split stamped +// EmbeddedMessageCount returns the number of in-scope LIVE messages that +// are stamped for gen (embed_gen = gen) AND actually have at least one +// vector for the generation. Used by the coverage readout to split stamped // messages into embedded vs blank. Counts distinct messages (not chunk // rows) so a long, multi-chunk message counts once, matching the // EmbeddingCount semantic elsewhere. @@ -1226,15 +1247,25 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St // live intersection is a single JOIN against messages, mirroring // store.LiveMessagesWhere's predicate. func (b *Backend) EmbeddedMessageCount(ctx context.Context, gen vector.GenerationID) (int64, error) { + where := `e.generation_id = $1 + AND m.embed_gen = $1 + AND ` + store.LiveMessagesWhere("m", true) + args := []any{int64(gen)} + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "$" + strconv.Itoa(2+i) + args = append(args, typ) + } + where += fmt.Sprintf(" AND m.message_type IN (%s)", strings.Join(placeholders, ",")) + } var n int64 if err := b.db.QueryRowContext(ctx, `SELECT COUNT(DISTINCT e.message_id) FROM embeddings e JOIN messages m ON m.id = e.message_id - WHERE e.generation_id = $1 - AND m.embed_gen = $1 - AND `+store.LiveMessagesWhere("m", true), - int64(gen)).Scan(&n); err != nil { + WHERE `+where, + args...).Scan(&n); err != nil { return 0, fmt.Errorf("count embedded messages: %w", err) } return n, nil diff --git a/internal/vector/pgvector/backend_filter_test.go b/internal/vector/pgvector/backend_filter_test.go index dfebe205..2dda1024 100644 --- a/internal/vector/pgvector/backend_filter_test.go +++ b/internal/vector/pgvector/backend_filter_test.go @@ -3,6 +3,7 @@ package pgvector import ( + "fmt" "testing" "time" @@ -11,6 +12,20 @@ import ( "go.kenn.io/msgvault/internal/vector" ) +func TestBuildPGFilterClausesMessageTypes(t *testing.T) { + var args []any + bind := func(v any) string { + args = append(args, v) + return fmt.Sprintf("$%d", len(args)) + } + + clauses := buildPGFilterClauses(vector.Filter{MessageTypes: []string{"sms", "mms"}}, bind) + + require.Len(t, clauses, 1) + assert.Equal(t, "m.message_type = ANY($1::text[])", clauses[0]) + assert.Equal(t, []any{`{"sms","mms"}`}, args) +} + func TestBackendSearchStructuredFilters(t *testing.T) { b, ctx, db := newBackendForTest(t) gen := seedAndEmbed(t, b, db, map[int64][]float32{ @@ -103,6 +118,11 @@ func TestBackendSearchStructuredFilters(t *testing.T) { filter: vector.Filter{LargerThan: &largerThan, SmallerThan: &smallerThan}, want: []int64{2}, }, + { + name: "message type", + filter: vector.Filter{MessageTypes: []string{"sms"}}, + want: []int64{2}, + }, { name: "no match sentinel", filter: vector.Filter{SenderGroups: [][]int64{{-1}}}, @@ -129,6 +149,21 @@ func TestBackendSearchStructuredFilters(t *testing.T) { } } +func TestBackendSearchMessageTypeFilter(t *testing.T) { + b, ctx, db := newBackendForTest(t) + gen := seedAndEmbed(t, b, db, map[int64][]float32{ + 1: unitVec(4, 0), + 2: unitVec(4, 1), + 3: unitVec(4, 2), + }) + _, err := db.ExecContext(ctx, `UPDATE messages SET message_type = CASE id WHEN 1 THEN 'email' ELSE 'sms' END`) + require.NoError(t, err, "seed message_type") + + hits, err := b.Search(ctx, gen, unitVec(4, 0), 10, vector.Filter{MessageTypes: []string{"sms"}}) + require.NoError(t, err, "Search") + assert.Equal(t, []int64{2, 3}, hitMessageIDs(hits)) +} + func hitMessageIDs(hits []vector.Hit) []int64 { out := make([]int64, len(hits)) for i, h := range hits { diff --git a/internal/vector/pgvector/coverage_test.go b/internal/vector/pgvector/coverage_test.go index 2293b8d9..ff177ffd 100644 --- a/internal/vector/pgvector/coverage_test.go +++ b/internal/vector/pgvector/coverage_test.go @@ -90,3 +90,66 @@ func TestCoverageSplit_EmbeddedBlankMissing(t *testing.T) { assert.Equal(live, embeddedCount+blank+missing, "invariant: live == embedded + blank + missing") } + +func TestCoverageSplit_ScopedEmbeddedHoldsInvariant(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + ctx := context.Background() + + db := openPGTestDB(t) // skips when MSGVAULT_TEST_DB is unset + b, err := Open(ctx, Options{ + DB: db, + Dimension: 8, + BuildScope: vector.NewBuildScope([]string{"sms"}), + }) + require.NoError(err, "Open backend") + t.Cleanup(func() { _ = b.Close() }) + + _, err = db.ExecContext(ctx, ` + INSERT INTO messages (id, message_type) VALUES + (1, 'email'), + (2, 'sms') + ON CONFLICT (id) DO UPDATE SET + message_type = EXCLUDED.message_type, + deleted_at = NULL, + deleted_from_source_at = NULL, + embed_gen = NULL`) + require.NoError(err, "seed scoped messages") + + gen, err := b.CreateGeneration(ctx, "test-model", 8, "fp") + require.NoError(err, "CreateGeneration") + require.NoError(b.Upsert(ctx, gen, []vector.Chunk{ + {MessageID: 1, Vector: []float32{1, 0, 0, 0, 0, 0, 0, 0}}, + {MessageID: 2, Vector: []float32{0, 1, 0, 0, 0, 0, 0, 0}}, + }), "Upsert embedded vectors") + _, err = db.ExecContext(ctx, `UPDATE messages SET embed_gen = $1 WHERE id IN (1, 2)`, int64(gen)) + require.NoError(err, "stamp embedded") + + var live, stamped int64 + require.NoError(db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM messages + WHERE message_type = 'sms' + AND deleted_at IS NULL + AND deleted_from_source_at IS NULL`).Scan(&live), + "count scoped live") + require.NoError(db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM messages + WHERE message_type = 'sms' + AND embed_gen = $1 + AND deleted_at IS NULL + AND deleted_from_source_at IS NULL`, + int64(gen)).Scan(&stamped), "count scoped stamped") + missing := live - stamped + + embeddedCount, err := b.EmbeddedMessageCount(ctx, gen) + require.NoError(err, "EmbeddedMessageCount") + blank := max(stamped-embeddedCount, 0) + + assert.Equal(int64(1), live, "only sms is in scope") + assert.Equal(int64(1), stamped, "only scoped stamped messages count") + assert.Equal(int64(1), embeddedCount, "out-of-scope email vector excluded") + assert.Equal(int64(0), blank) + assert.Equal(int64(0), missing) + assert.Equal(live, embeddedCount+blank+missing, + "invariant: live == embedded + blank + missing") +} diff --git a/internal/vector/pgvector/ext_stub.go b/internal/vector/pgvector/ext_stub.go index f82e3130..05c0980c 100644 --- a/internal/vector/pgvector/ext_stub.go +++ b/internal/vector/pgvector/ext_stub.go @@ -27,6 +27,7 @@ type Options struct { SkipMigrate bool ReadOnly bool SkipExtension bool + BuildScope vector.BuildScope } // Backend is a placeholder type so non-pgvector builds can compile diff --git a/internal/vector/pgvector/fused_test.go b/internal/vector/pgvector/fused_test.go index 325228de..b4bcb97e 100644 --- a/internal/vector/pgvector/fused_test.go +++ b/internal/vector/pgvector/fused_test.go @@ -148,6 +148,34 @@ func TestFusedSearch_FTSOnly(t *testing.T) { } } +func TestFusedSearch_MessageTypeFilter(t *testing.T) { + f := newFusedFixture(t) + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + f.seedMsg(t, 1, "lunch plan", "sms lunch details", 10, base, false) + f.seedMsg(t, 2, "lunch receipt", "email lunch details", 10, base.Add(time.Hour), false) + f.seedMsg(t, 3, "dinner plan", "sms dinner details", 10, base.Add(2*time.Hour), false) + _, err := f.db.ExecContext(f.ctx, `UPDATE messages SET message_type = CASE id WHEN 2 THEN 'email' ELSE 'sms' END`) + require.NoError(t, err, "seed message_type") + f.embedAll(t, map[int64][]float32{ + 1: unitVec(4, 0), + 2: unitVec(4, 1), + 3: unitVec(4, 2), + }) + + hits, saturated, err := f.b.FusedSearch(f.ctx, vector.FusedRequest{ + FTSTerms: []string{"lunch"}, + Generation: f.gen, + KPerSignal: 10, + Limit: 10, + RRFK: 60, + Filter: vector.Filter{MessageTypes: []string{"sms"}}, + }) + require.NoError(t, err, "FusedSearch") + assert.False(t, saturated, "saturated should be false") + require.Len(t, hits, 1, "message_type filter should exclude email FTS hits; hits=%+v", hits) + assert.Equal(t, int64(1), hits[0].MessageID) +} + func TestFusedSearch_ANNOnly(t *testing.T) { f := seedThree(t) hits, saturated, err := f.b.FusedSearch(f.ctx, vector.FusedRequest{ diff --git a/internal/vector/pgvector/parity_test.go b/internal/vector/pgvector/parity_test.go index 532f7372..37aca933 100644 --- a/internal/vector/pgvector/parity_test.go +++ b/internal/vector/pgvector/parity_test.go @@ -69,6 +69,7 @@ func buildSqlitevecParity(t *testing.T, corpus []parityDoc) (*sqlitevec.Backend, CREATE TABLE messages ( id INTEGER PRIMARY KEY, subject TEXT, + message_type TEXT NOT NULL DEFAULT 'email', source_id INTEGER, sender_id INTEGER, has_attachments INTEGER DEFAULT 0, diff --git a/internal/vector/sqlitevec/backend.go b/internal/vector/sqlitevec/backend.go index a38c19f0..72ee9142 100644 --- a/internal/vector/sqlitevec/backend.go +++ b/internal/vector/sqlitevec/backend.go @@ -28,10 +28,11 @@ var _ vector.Backend = (*Backend)(nil) // Options configures how Open establishes a Backend. type Options struct { - Path string - MainPath string // filesystem path to msgvault.db; required for FusedSearch - Dimension int // default dimension for EnsureVectorTable at open - MainDB *sql.DB // handle to the main msgvault.db + Path string + MainPath string // filesystem path to msgvault.db; required for FusedSearch + Dimension int // default dimension for EnsureVectorTable at open + MainDB *sql.DB // handle to the main msgvault.db + BuildScope vector.BuildScope // empty means full corpus // ReadOnly indicates the main DB handle (MainDB) was opened read-only // — e.g. the MCP server's store.OpenReadOnly (_query_only=true). When // set, Open SKIPS BackfillEmbedGenForUpgrade, which would otherwise @@ -50,6 +51,7 @@ type Backend struct { path string // filesystem path to vectors.db mainPath string // filesystem path to msgvault.db (for ATTACH) dim int + scope vector.BuildScope // readOnly is true when mainDB was opened read-only (MCP). The // one-time upgrade backfill self-guards on it so it never writes // through the read-only main handle. See Options.ReadOnly. @@ -76,6 +78,7 @@ func Open(ctx context.Context, opts Options) (*Backend, error) { path: opts.Path, mainPath: opts.MainPath, dim: opts.Dimension, + scope: vector.NewBuildScope(opts.BuildScope.MessageTypes), readOnly: opts.ReadOnly, } // Orphaned-stamp reset (vectors.db-recreate safety): clear embed_gen for @@ -268,24 +271,35 @@ func isUniqueConstraintErr(err error) bool { // needs embedding for gen (embed_gen IS NULL OR embed_gen <> gen). This is // the scan-and-fill coverage gate. On SQLite the messages live in the main // DB while the generation lifecycle lives in vectors.db, so the gate -// cannot be folded into the activation tx — ActivateGeneration runs this -// Go pre-check against b.mainDB before its vectors.db tx (mirroring the -// intentionally-non-atomic scheduler gate). The full-scan backstop covers -// any TOCTOU window between this read and the flip. +// cannot be folded into the activation tx. func (b *Backend) hasMissingForGen(ctx context.Context, gen vector.GenerationID) (bool, error) { var exists int + where, args := b.missingCoverageWhere(int64(gen)) err := b.mainDB.QueryRowContext(ctx, `SELECT EXISTS ( SELECT 1 FROM messages - WHERE (embed_gen IS NULL OR embed_gen <> ?) - AND `+store.LiveMessagesWhere("", true)+` - )`, int64(gen)).Scan(&exists) + WHERE `+where+` + )`, args...).Scan(&exists) if err != nil { return false, fmt.Errorf("check missing coverage for generation %d: %w", gen, err) } return exists == 1, nil } +func (b *Backend) missingCoverageWhere(gen int64) (string, []any) { + where := "(embed_gen IS NULL OR embed_gen <> ?) AND " + store.LiveMessagesWhere("", true) + args := []any{gen} + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "?" + args = append(args, typ) + } + where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ",")) + } + return where, args +} + // ActivateGeneration atomically retires the current active generation // (if any) and promotes `gen` to active. func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationID, force bool) error { @@ -1375,9 +1389,9 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St return s, nil } -// EmbeddedMessageCount returns the number of LIVE messages that are -// stamped for gen (embed_gen = gen) AND actually have at least one vector -// for the generation. Used by the coverage readout to split stamped +// EmbeddedMessageCount returns the number of in-scope LIVE messages that +// are stamped for gen (embed_gen = gen) AND actually have at least one +// vector for the generation. Used by the coverage readout to split stamped // messages into embedded vs blank. Counts distinct messages (not chunk // rows) so a long, multi-chunk message counts once, matching the // EmbeddingCount semantic elsewhere. @@ -1434,18 +1448,27 @@ func (b *Backend) EmbeddedMessageCount(ctx context.Context, gen vector.Generatio return 0, nil } - // Step 2 (main.db): how many of those are live AND stamped for gen. + // Step 2 (main.db): how many of those are in-scope, live, AND stamped for gen. blob, err := json.Marshal(ids) if err != nil { return 0, fmt.Errorf("encode embedded ids: %w", err) } + where := `id IN (SELECT value FROM json_each(?)) + AND embed_gen = ? + AND ` + store.LiveMessagesWhere("", true) + args := []any{string(blob), int64(gen)} + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "?" + args = append(args, typ) + } + where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ",")) + } var n int64 if err := b.mainDB.QueryRowContext(ctx, - `SELECT COUNT(*) FROM messages - WHERE id IN (SELECT value FROM json_each(?)) - AND embed_gen = ? - AND `+store.LiveMessagesWhere("", true), - string(blob), int64(gen)).Scan(&n); err != nil { + `SELECT COUNT(*) FROM messages WHERE `+where, + args...).Scan(&n); err != nil { return 0, fmt.Errorf("count live embedded messages: %w", err) } return n, nil diff --git a/internal/vector/sqlitevec/backend_test.go b/internal/vector/sqlitevec/backend_test.go index 697fd6d1..7e873413 100644 --- a/internal/vector/sqlitevec/backend_test.go +++ b/internal/vector/sqlitevec/backend_test.go @@ -292,6 +292,49 @@ func TestBackend_ActivateGeneration_NullSeededAtActivatesWithCoverage(t *testing assertpkg.Equal(t, vector.GenerationActive, genStateSV(t, b, gen), "now active") } +func TestBackend_ActivateCoverageScopesMessageTypes(t *testing.T) { + ctx := context.Background() + main, err := sql.Open("sqlite3", ":memory:") + requirepkg.NoError(t, err, "open main") + t.Cleanup(func() { _ = main.Close() }) + _, err = main.Exec(`CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + message_type TEXT NOT NULL, + embed_gen INTEGER, + deleted_at DATETIME, + deleted_from_source_at DATETIME + )`) + requirepkg.NoError(t, err, "create messages") + _, err = main.Exec(` + INSERT INTO messages (id, message_type, deleted_from_source_at) VALUES + (1, 'email', NULL), + (2, 'sms', NULL), + (3, 'mms', NULL), + (4, 'sms', CURRENT_TIMESTAMP)`) + requirepkg.NoError(t, err, "insert messages") + + b, err := Open(ctx, Options{ + Path: filepath.Join(t.TempDir(), "vectors.db"), + Dimension: 768, + MainDB: main, + BuildScope: vector.NewBuildScope([]string{"sms", "mms"}), + }) + requirepkg.NoError(t, err, "Open") + t.Cleanup(func() { _ = b.Close() }) + + gid, err := b.CreateGeneration(ctx, "m", 768, "") + requirepkg.NoError(t, err, "Create") + missing, err := b.hasMissingForGen(ctx, gid) + requirepkg.NoError(t, err, "hasMissingForGen before stamp") + assertpkg.True(t, missing, "in-scope messages need embedding") + + _, err = main.Exec(`UPDATE messages SET embed_gen = ? WHERE id IN (2, 3)`, int64(gid)) + requirepkg.NoError(t, err, "stamp in-scope messages") + missing, err = b.hasMissingForGen(ctx, gid) + requirepkg.NoError(t, err, "hasMissingForGen after stamp") + assertpkg.False(t, missing, "out-of-scope and deleted messages must not block scoped activation") +} + // TestBackend_CreateGeneration_ResumesBuilding confirms that calling // CreateGeneration while a building row already exists with the same // fingerprint returns the existing id instead of failing on the unique diff --git a/internal/vector/sqlitevec/coverage_test.go b/internal/vector/sqlitevec/coverage_test.go index 2ffcd0ac..55c6c786 100644 --- a/internal/vector/sqlitevec/coverage_test.go +++ b/internal/vector/sqlitevec/coverage_test.go @@ -218,3 +218,63 @@ func TestCoverageSplit_NonLiveEmbeddedHoldsInvariant(t *testing.T) { assert.Equal(live, embedded+blank+missingCount, "invariant: live == embedded + blank + missing") } + +func TestCoverageSplit_ScopedEmbeddedHoldsInvariant(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + ctx := context.Background() + + st := testutil.NewSQLiteTestStore(t) + b, err := Open(ctx, Options{ + Path: filepath.Join(t.TempDir(), "vectors.db"), + Dimension: 8, + MainDB: st.DB(), + BuildScope: vector.NewBuildScope([]string{"sms"}), + }) + require.NoError(err, "Open backend") + t.Cleanup(func() { _ = b.Close() }) + + source, err := st.GetOrCreateSource("gmail", "me@example.com") + require.NoError(err, "GetOrCreateSource") + emailConvID, err := st.EnsureConversationWithType(source.ID, "conv-email", "email_thread", "Email") + require.NoError(err, "EnsureConversationWithType email") + smsConvID, err := st.EnsureConversationWithType(source.ID, "conv-sms", "sms_thread", "SMS") + require.NoError(err, "EnsureConversationWithType sms") + + makeMsg := func(srcMsgID, typ string, convID int64) int64 { + m := &store.Message{ + SourceID: source.ID, + SourceMessageID: srcMsgID, + ConversationID: convID, + MessageType: typ, + Subject: sql.NullString{String: "s-" + srcMsgID, Valid: true}, + } + id, err := st.UpsertMessage(m) + require.NoErrorf(err, "UpsertMessage %s", srcMsgID) + return id + } + outOfScopeEmail := makeMsg("email-stamped", "email", emailConvID) + inScopeSMS := makeMsg("sms-stamped", "sms", smsConvID) + + gen, err := b.CreateGeneration(ctx, "test-model", 8, "fp") + require.NoError(err, "CreateGeneration") + require.NoError(b.Upsert(ctx, gen, []vector.Chunk{ + {MessageID: outOfScopeEmail, Vector: []float32{1, 0, 0, 0, 0, 0, 0, 0}}, + {MessageID: inScopeSMS, Vector: []float32{0, 1, 0, 0, 0, 0, 0, 0}}, + }), "Upsert embedded vectors") + require.NoError(st.SetEmbedGen(ctx, []int64{outOfScopeEmail, inScopeSMS}, int64(gen)), "stamp embedded") + + live, stamped, _, missingCount, err := st.CoverageCountsScoped(ctx, int64(gen), []string{"sms"}) + require.NoError(err, "CoverageCountsScoped") + embedded, err := b.EmbeddedMessageCount(ctx, gen) + require.NoError(err, "EmbeddedMessageCount") + blank := max(stamped-embedded, 0) + + assert.Equal(int64(1), live, "only sms is in scope") + assert.Equal(int64(1), stamped, "only scoped stamped messages count") + assert.Equal(int64(1), embedded, "out-of-scope email vector excluded") + assert.Equal(int64(0), blank) + assert.Equal(int64(0), missingCount) + assert.Equal(live, embedded+blank+missingCount, + "invariant: live == embedded + blank + missing") +} diff --git a/internal/vector/sqlitevec/ext_stub.go b/internal/vector/sqlitevec/ext_stub.go index 081a8398..55b50a81 100644 --- a/internal/vector/sqlitevec/ext_stub.go +++ b/internal/vector/sqlitevec/ext_stub.go @@ -32,11 +32,12 @@ func Available() bool { return false } // can reference sqlitevec.Options without a compile error; the struct is // never populated at runtime when the PG code path is taken. type Options struct { - Path string - MainPath string - Dimension int - MainDB *sql.DB - ReadOnly bool + Path string + MainPath string + Dimension int + MainDB *sql.DB + BuildScope vector.BuildScope + ReadOnly bool } // Backend is the stub backend type for builds without sqlite_vec. diff --git a/internal/vector/sqlitevec/fused_test.go b/internal/vector/sqlitevec/fused_test.go index 19a34f94..03d0f47a 100644 --- a/internal/vector/sqlitevec/fused_test.go +++ b/internal/vector/sqlitevec/fused_test.go @@ -287,11 +287,11 @@ func openFusedMainWithSchema(t *testing.T, path string) *sql.DB { t.Cleanup(func() { _ = db.Close() }) // sent_at is DATETIME (text) to match the production schema. schema := ` -CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - subject TEXT, - message_type TEXT NOT NULL DEFAULT 'email', - source_id INTEGER, + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + subject TEXT, + message_type TEXT NOT NULL DEFAULT 'email', + source_id INTEGER, sender_id INTEGER, has_attachments INTEGER DEFAULT 0, size_estimate INTEGER, @@ -322,10 +322,10 @@ func openFusedMainWithoutMessageType(t *testing.T, path string) *sql.DB { requirepkg.NoError(t, err, "open main") t.Cleanup(func() { _ = db.Close() }) schema := ` -CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - subject TEXT, - source_id INTEGER, + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + subject TEXT, + source_id INTEGER, sender_id INTEGER, has_attachments INTEGER DEFAULT 0, size_estimate INTEGER,