From f1983c510833d5c070ec70b6df127ea481b471b2 Mon Sep 17 00:00:00 2001 From: Dan Shapiro Date: Thu, 25 Jun 2026 16:38:04 -0500 Subject: [PATCH 1/4] feat(search): support message_type filters Add parsed message_type filters through the search, vector, remote, and MCP paths so scoped SMS/MMS/email queries do not silently widen after parsing. This keeps result rows, stats, DuckDB fast paths, SQLite FTS, pgvector filters, hybrid search, and remote query reconstruction aligned around the same user-visible operator. The embedding scope work also needs durable enqueue behavior across Synctech imports and generation transitions. Preserve best-effort import semantics, enqueue already-persisted Synctech messages even after partial failures, wire the manual Synctech sync path into vector enqueueing, and filter pending embeddings per generation scope so full-corpus and scoped generations remain independently complete. Included follow-up fixes: - feat(vector): support scoped embedding builds - fix(vector): enqueue synctech sms imports - fix(search): honor message_type scopes - fix(vector): close scoped search gaps - fix(mcp): validate similar index before seed load - fix(mcp): preserve no-active similar error - fix(vector): align pg parity sqlite fixture - fix(search): keep scoped stats consistent - fix(search): complete scoped stats coverage - fix(remote): preserve message type searches - fix(vector): preserve scoped error contracts - fix(vector): enqueue manual synctech syncs - fix(vector): enqueue per generation scope Generated with Codex Co-authored-by: Wes McKinney --- README.md | 2 + cmd/msgvault/cmd/add_synctech_sms_drive.go | 51 +++--- .../cmd/add_synctech_sms_drive_test.go | 137 ++++++++++++++-- cmd/msgvault/cmd/embed_vector.go | 16 +- cmd/msgvault/cmd/search_vector.go | 11 +- cmd/msgvault/cmd/serve.go | 1 + cmd/msgvault/cmd/serve_vector.go | 14 +- internal/api/handlers.go | 2 + internal/mcp/handlers.go | 34 +++- internal/mcp/server_test.go | 149 ++++++++++++++++-- internal/query/duckdb.go | 113 ++++++++++--- internal/query/duckdb_test.go | 88 +++++++++++ internal/query/sqlite.go | 9 ++ internal/query/sqlite_crud_test.go | 25 +++ internal/query/sqlite_search_test.go | 50 ++++-- internal/remote/engine_test.go | 16 +- internal/scheduler/embed_job.go | 21 ++- internal/store/api_search_test.go | 1 - internal/store/embed_gen.go | 86 ++++++++-- internal/synctechsms/importer.go | 83 ++++++---- internal/synctechsms/importer_test.go | 55 +++++++ internal/synctechsms/types.go | 1 + internal/vector/build_scope.go | 65 ++++++++ internal/vector/config.go | 19 ++- internal/vector/config_test.go | 15 ++ internal/vector/embed/worker.go | 17 +- internal/vector/errors.go | 6 + internal/vector/hybrid/engine.go | 24 ++- internal/vector/hybrid/engine_test.go | 29 ++++ internal/vector/pgvector/backend.go | 39 +++-- .../vector/pgvector/backend_filter_test.go | 41 ++++- internal/vector/pgvector/ext_stub.go | 1 + internal/vector/pgvector/fused_test.go | 30 ++++ internal/vector/pgvector/parity_test.go | 1 + internal/vector/sqlitevec/backend.go | 36 +++-- internal/vector/sqlitevec/backend_test.go | 43 +++++ internal/vector/sqlitevec/ext_stub.go | 11 +- internal/vector/sqlitevec/fused_test.go | 18 +-- 38 files changed, 1170 insertions(+), 190 deletions(-) create mode 100644 internal/vector/build_scope.go diff --git a/README.md b/README.md index ec04f505e..354fc9972 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,8 @@ A separate MCP tool, `find_similar_messages`, returns nearest neighbors for a se > **Run only one embedding process at a time.** Don't run `msgvault embeddings build`/`resume` or `repair-encoding` concurrently with a `msgvault serve` daemon — they write the same embedding state, and concurrent writers are not coordinated across processes. +Large archives can scope an embedding generation with `[vector.embed.scope] message_types = ["sms", "mms"]`. Scoped vector and hybrid searches must include a matching `message_type` filter so a partial index is never used as if it covered the whole archive. + ## Importing from MBOX or Apple Mail Import email from providers that offer MBOX exports or from a local Apple Mail data directory: diff --git a/cmd/msgvault/cmd/add_synctech_sms_drive.go b/cmd/msgvault/cmd/add_synctech_sms_drive.go index 44d60d04c..e6f3d340e 100644 --- a/cmd/msgvault/cmd/add_synctech_sms_drive.go +++ b/cmd/msgvault/cmd/add_synctech_sms_drive.go @@ -109,10 +109,15 @@ func runConfiguredSynctechSMSSource(ctx context.Context, src config.SynctechSMSS return err } defer func() { _ = st.Close() }() + return runConfiguredSynctechSMSSourceWithStore(ctx, st, src) } func runConfiguredSynctechSMSSourceWithStore(ctx context.Context, st *store.Store, src config.SynctechSMSSource) error { + return runConfiguredSynctechSMSSourceWithStoreDriveClient(ctx, st, src, nil) +} + +func runConfiguredSynctechSMSSourceWithStoreDriveClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, driveClient synctechsms.DriveClient) error { opts := synctechImportOptions(src) if opts.OwnerPhone == "" { return fmt.Errorf("synctech-sms source %q owner_phone is required", src.Name) @@ -128,7 +133,11 @@ func runConfiguredSynctechSMSSourceWithStore(ctx context.Context, st *store.Stor } _, err = synctechsms.NewImporter(st, opts).ImportPath(src.Path) case "drive": - err = runSynctechSMSDriveSource(ctx, st, src, opts) + if driveClient != nil { + _, err = runSynctechSMSDriveSourceWithClient(ctx, st, src, opts, driveClient) + } else { + _, err = runSynctechSMSDriveSource(ctx, st, src, opts) + } default: return fmt.Errorf("unsupported synctech-sms backend %q", src.Backend) } @@ -167,28 +176,28 @@ func validateSynctechSMSDriveSource(src config.SynctechSMSSource) error { return nil } -func runSynctechSMSDriveSource(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions) error { +func runSynctechSMSDriveSource(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions) (synctechsms.ImportSummary, error) { if err := validateSynctechSMSDriveSource(src); err != nil { - return err + return synctechsms.ImportSummary{}, err } client, err := newSynctechSMSDriveClient(ctx, src) if err != nil { - return err + return synctechsms.ImportSummary{}, err } return runSynctechSMSDriveSourceWithClient(ctx, st, src, opts, client) } -func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions, client synctechsms.DriveClient) (retErr error) { +func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, src config.SynctechSMSSource, opts synctechsms.ImportOptions, client synctechsms.DriveClient) (summary synctechsms.ImportSummary, retErr error) { if err := validateSynctechSMSDriveSource(src); err != nil { - return err + return summary, err } source, err := ensureConfiguredSynctechSMSSource(st, src, opts) if err != nil { - return err + return summary, err } syncID, err := st.StartSync(source.ID, synctechsms.AdapterName) if err != nil { - return fmt.Errorf("start sync: %w", err) + return summary, fmt.Errorf("start sync: %w", err) } completed := false defer func() { @@ -204,38 +213,38 @@ func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, s }() files, err := client.ListBackupFiles(ctx, src.FolderID) if err != nil { - return fmt.Errorf("list Drive backup files: %w", err) + return summary, fmt.Errorf("list Drive backup files: %w", err) } imported, err := st.ListImportedSourceItemChecksums(source.ID, "drive") if err != nil { - return fmt.Errorf("list imported Drive checksums: %w", err) + return summary, fmt.Errorf("list imported Drive checksums: %w", err) } stableAfter, err := time.ParseDuration(src.StableAfter) if err != nil { - return fmt.Errorf("parse stable_after: %w", err) + return summary, fmt.Errorf("parse stable_after: %w", err) } selected := synctechsms.SelectStableDriveFiles(files, time.Now(), stableAfter, imported) stagingDir := filepath.Join(cfg.Data.DataDir, "imports", "synctech-sms", src.Name) if err := os.MkdirAll(stagingDir, 0o700); err != nil { - return fmt.Errorf("create staging directory: %w", err) + return summary, fmt.Errorf("create staging directory: %w", err) } imp := synctechsms.NewImporter(st, opts) - var summary synctechsms.ImportSummary for _, file := range selected { fileSummary, err := importOneDriveBackup(ctx, st, imp, client, source.ID, file, stagingDir) - if err != nil { - return err - } summary.FilesSeen += fileSummary.FilesSeen summary.FilesImported += fileSummary.FilesImported summary.SMSImported += fileSummary.SMSImported summary.MMSImported += fileSummary.MMSImported summary.CallsImported += fileSummary.CallsImported summary.AttachmentsImported += fileSummary.AttachmentsImported + summary.MessageIDs = append(summary.MessageIDs, fileSummary.MessageIDs...) + if err != nil { + return summary, err + } } if summary.FilesImported > 0 { if err := st.RecomputeConversationStats(source.ID); err != nil { - return fmt.Errorf("recompute conversation stats: %w", err) + return summary, fmt.Errorf("recompute conversation stats: %w", err) } } totalRecords := int64(summary.SMSImported + summary.MMSImported + summary.CallsImported) @@ -243,16 +252,16 @@ func runSynctechSMSDriveSourceWithClient(ctx context.Context, st *store.Store, s MessagesProcessed: totalRecords, MessagesAdded: totalRecords, }); err != nil { - return fmt.Errorf("update sync checkpoint: %w", err) + return summary, fmt.Errorf("update sync checkpoint: %w", err) } if err := st.TouchSourceLastSyncAt(source.ID); err != nil { - return fmt.Errorf("touch source last sync: %w", err) + return summary, fmt.Errorf("touch source last sync: %w", err) } if err := st.CompleteSync(syncID, ""); err != nil { - return fmt.Errorf("complete sync: %w", err) + return summary, fmt.Errorf("complete sync: %w", err) } completed = true - return nil + return summary, nil } func importOneDriveBackup(ctx context.Context, st *store.Store, imp *synctechsms.Importer, client synctechsms.DriveClient, sourceID int64, file synctechsms.DriveFile, stagingDir string) (synctechsms.ImportSummary, error) { diff --git a/cmd/msgvault/cmd/add_synctech_sms_drive_test.go b/cmd/msgvault/cmd/add_synctech_sms_drive_test.go index f39208500..2f98ede79 100644 --- a/cmd/msgvault/cmd/add_synctech_sms_drive_test.go +++ b/cmd/msgvault/cmd/add_synctech_sms_drive_test.go @@ -75,8 +75,9 @@ func TestSynctechSMSDriveRunUsesSingleOuterSyncRun(t *testing.T) { }, } - err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) + summary, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) require.NoError(err, "runSynctechSMSDriveSourceWithClient") + require.Len(summary.MessageIDs, 1, "summary message IDs") source := getSynctechSource(t, f.Store, src.OwnerPhone) assert.Equal(1, countSyncRuns(t, f.Store, source.ID), "sync run count") @@ -86,7 +87,7 @@ func TestSynctechSMSDriveRunUsesSingleOuterSyncRun(t *testing.T) { assert.Equal(int64(1), run.MessagesAdded, "messages added") assert.True(getSynctechSource(t, f.Store, src.OwnerPhone).LastSyncAt.Valid, "last_sync_at should be touched") - item := getSourceImportItem(t, f.Store, source.ID, "drive", "backup-1") + item := getDriveSourceImportItem(t, f.Store, source.ID, "backup-1") assert.Equal("imported", item.Status, "source import status") assert.Equal(1, item.RecordsImported, "records imported") assert.False(item.ErrorMessage.Valid, "source import error") @@ -113,7 +114,7 @@ func TestSynctechSMSDriveRunSetsUpIdentityAndPostSourceMigration(t *testing.T) { src := synctechDriveTestSource() client := fakeSynctechDriveClient{} - err = runSynctechSMSDriveSourceWithClient(context.Background(), st, src, synctechImportOptions(src), client) + _, err = runSynctechSMSDriveSourceWithClient(context.Background(), st, src, synctechImportOptions(src), client) require.NoError(err, "runSynctechSMSDriveSourceWithClient") synctechSource := getSynctechSource(t, st, src.OwnerPhone) @@ -154,7 +155,7 @@ func TestSynctechSMSDriveRunRecordsZeroSelectedPoll(t *testing.T) { }}, } - err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) + _, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) require.NoError(err, "runSynctechSMSDriveSourceWithClient") source := getSynctechSource(t, f.Store, src.OwnerPhone) @@ -188,7 +189,7 @@ func TestSynctechSMSDriveRunMarksOuterSyncFailedOnDownloadError(t *testing.T) { downloadErr: downloadErr, } - err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) + _, err := runSynctechSMSDriveSourceWithClient(context.Background(), f.Store, src, synctechImportOptions(src), client) require.ErrorIs(err, downloadErr, "runSynctechSMSDriveSourceWithClient") source := getSynctechSource(t, f.Store, src.OwnerPhone) @@ -198,17 +199,126 @@ func TestSynctechSMSDriveRunMarksOuterSyncFailedOnDownloadError(t *testing.T) { require.True(run.ErrorMessage.Valid, "sync error_message") assert.Contains(run.ErrorMessage.String, downloadErr.Error(), "sync error_message") - item := getSourceImportItem(t, f.Store, source.ID, "drive", "backup-1") + item := getDriveSourceImportItem(t, f.Store, source.ID, "backup-1") assert.Equal("failed", item.Status, "source import status") require.True(item.ErrorMessage.Valid, "source import error") assert.Contains(item.ErrorMessage.String, downloadErr.Error(), "source import error") } +func TestSynctechSMSDrivePartialFailureEnqueuesImportedMessages(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + home := t.TempDir() + cfg = config.NewDefaultConfig() + cfg.HomeDir = home + cfg.Data.DataDir = home + f := storetest.New(t) + src := synctechDriveTestSource() + client := fakeSynctechDriveClient{ + files: []synctechsms.DriveFile{ + { + ID: "backup-1", + Name: "sms-1.xml", + Checksum: "sum-1", + Size: 128, + ModifiedTime: time.Now().Add(-30 * time.Minute), + }, + { + ID: "backup-2", + Name: "sms-2.xml", + Checksum: "sum-2", + Size: 128, + ModifiedTime: time.Now().Add(-30 * time.Minute), + }, + }, + downloads: map[string]string{ + "backup-1": ` + +`, + "backup-2": ` + + + +`), 0o600), "write backup") + + st, err := store.Open(cfg.DatabaseDSN()) + require.NoError(err, "open store") + require.NoError(st.InitSchema(), "InitSchema") + require.NoError(st.Close(), "close store") + + src := config.SynctechSMSSource{ + Name: "pixel-local", + Backend: "local", + Path: importDir, + OwnerPhone: "+15550000001", + IncludeSMS: true, + } + require.NoError(runConfiguredSynctechSMSSource(ctx, src), "runConfiguredSynctechSMSSource") + + st, err = store.Open(cfg.DatabaseDSN()) + require.NoError(err, "reopen store") + t.Cleanup(func() { _ = st.Close() }) + source := getSynctechSource(t, st, src.OwnerPhone) + var unstamped int + require.NoError(st.DB().QueryRowContext(ctx, + st.Rebind(`SELECT COUNT(*) FROM messages WHERE source_id = ? AND embed_gen IS NULL`), + source.ID, + ).Scan(&unstamped), "count unstamped messages") + assert.Equal(1, unstamped, "manual sync message remains discoverable by scan-and-fill") +} + type fakeSynctechDriveClient struct { - files []synctechsms.DriveFile - downloads map[string]string - listErr error - downloadErr error + files []synctechsms.DriveFile + downloads map[string]string + listErr error + downloadErr error + downloadErrByID map[string]error } func (f fakeSynctechDriveClient) ListBackupFiles(context.Context, string) ([]synctechsms.DriveFile, error) { @@ -219,6 +329,9 @@ func (f fakeSynctechDriveClient) ListBackupFiles(context.Context, string) ([]syn } func (f fakeSynctechDriveClient) DownloadToFile(_ context.Context, fileID, path string) error { + if err := f.downloadErrByID[fileID]; err != nil { + return err + } if f.downloadErr != nil { return f.downloadErr } @@ -280,9 +393,9 @@ func getOnlySyncRun(t *testing.T, st *store.Store, sourceID int64) store.SyncRun return run } -func getSourceImportItem(t *testing.T, st *store.Store, sourceID int64, provider, providerID string) *store.SourceImportItem { +func getDriveSourceImportItem(t *testing.T, st *store.Store, sourceID int64, providerID string) *store.SourceImportItem { t.Helper() - item, err := st.GetSourceImportItem(sourceID, provider, providerID) + item, err := st.GetSourceImportItem(sourceID, "drive", providerID) requirepkg.NoError(t, err, "GetSourceImportItem") return item } diff --git a/cmd/msgvault/cmd/embed_vector.go b/cmd/msgvault/cmd/embed_vector.go index 8df24cd03..f0372d67a 100644 --- a/cmd/msgvault/cmd/embed_vector.go +++ b/cmd/msgvault/cmd/embed_vector.go @@ -57,6 +57,7 @@ func runEmbed(cmd *cobra.Command) error { pgb, err := pgvector.Open(ctx, pgvector.Options{ DB: s.DB(), Dimension: cfg.Vector.Embeddings.Dimension, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), SkipExtension: cfg.Vector.SkipExtensionCreate, }) if err != nil { @@ -76,10 +77,11 @@ func runEmbed(cmd *cobra.Command) error { vecPath = filepath.Join(cfg.Data.DataDir, "vectors.db") } sb, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecPath, - MainPath: cfg.DatabaseDSN(), - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: s.DB(), + Path: vecPath, + MainPath: cfg.DatabaseDSN(), + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: s.DB(), + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { return fmt.Errorf("open vectors.db: %w", err) @@ -116,7 +118,8 @@ func runEmbed(cmd *cobra.Command) error { // "Pending" is now the count of live messages still needing work for // this generation (embed_gen <> gen), read from the main DB coverage // rather than a queue table. - missing, err := s.MissingCount(ctx, int64(gen)) + scope := cfg.Vector.Embed.Scope.BuildScope() + missing, err := s.MissingCountScoped(ctx, int64(gen), scope.MessageTypes) if err != nil { return fmt.Errorf("coverage counts: %w", err) } @@ -138,6 +141,7 @@ func runEmbed(cmd *cobra.Command) error { }, MaxInputChars: cfg.Vector.Embeddings.MaxInputChars, BatchSize: cfg.Vector.Embeddings.BatchSize, + BuildScope: scope, Rebind: rebind, LastModifiedExpr: lastModifiedExpr, TotalPending: totalPending, @@ -161,7 +165,7 @@ func runEmbed(cmd *cobra.Command) error { // worker later recovers from must not block activation, and an // active generation must not be re-activated. if rebuildInProgress { - _, _, _, remaining, err := s.CoverageCounts(ctx, int64(gen)) + _, _, _, remaining, err := s.CoverageCountsScoped(ctx, int64(gen), scope.MessageTypes) if err != nil { return fmt.Errorf("coverage counts: %w", err) } diff --git a/cmd/msgvault/cmd/search_vector.go b/cmd/msgvault/cmd/search_vector.go index 96cf957bb..6f8e7ab64 100644 --- a/cmd/msgvault/cmd/search_vector.go +++ b/cmd/msgvault/cmd/search_vector.go @@ -79,6 +79,7 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc pgb, err := pgvector.Open(ctx, pgvector.Options{ DB: mainDB, Dimension: cfg.Vector.Embeddings.Dimension, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), SkipExtension: cfg.Vector.SkipExtensionCreate, }) if err != nil { @@ -110,10 +111,11 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc } sb, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecDBPath, - MainPath: dsn, - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: mainDB, + Path: vecDBPath, + MainPath: dsn, + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: mainDB, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { _ = mainDB.Close() @@ -147,6 +149,7 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, sc KPerSignal: cfg.Vector.Search.KPerSignal, SubjectBoost: cfg.Vector.Search.SubjectBoost, Rebind: dialect.Rebind, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) q := search.Parse(queryStr) diff --git a/cmd/msgvault/cmd/serve.go b/cmd/msgvault/cmd/serve.go index fa1a41d8d..29fe1f21e 100644 --- a/cmd/msgvault/cmd/serve.go +++ b/cmd/msgvault/cmd/serve.go @@ -206,6 +206,7 @@ func runServe(cmd *cobra.Command, args []string) error { Store: s, Fingerprint: vf.Cfg.GenerationFingerprint(), BackstopInterval: vf.Cfg.Embed.BackstopInterval, + BuildScope: vf.Cfg.Embed.Scope.BuildScope(), Log: logger, } schedule := cfg.Vector.Embed.Schedule.Cron diff --git a/cmd/msgvault/cmd/serve_vector.go b/cmd/msgvault/cmd/serve_vector.go index 6be2e5e8c..9f329e182 100644 --- a/cmd/msgvault/cmd/serve_vector.go +++ b/cmd/msgvault/cmd/serve_vector.go @@ -76,6 +76,7 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s pgb, err := pgvector.Open(ctx, pgvector.Options{ DB: mainDB, Dimension: cfg.Vector.Embeddings.Dimension, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), SkipMigrate: readOnly, // ReadOnly MUST track readOnly here: this is the MCP read-only // path (store.OpenReadOnly). When set, Open performs no writes — @@ -103,10 +104,11 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s vecPath = filepath.Join(cfg.Data.DataDir, "vectors.db") } sb, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecPath, - MainPath: mainPath, - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: mainDB, + Path: vecPath, + MainPath: mainPath, + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: mainDB, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), // Honor the read-only signal on SQLite too: when mainDB is a // query-only handle (MCP), skip the embed_gen upgrade backfill, // which would write through it. Migrate still runs (vectors.db @@ -146,6 +148,7 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s }, MaxInputChars: cfg.Vector.Embeddings.MaxInputChars, BatchSize: cfg.Vector.Embeddings.BatchSize, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), // Rebind makes the worker's body-fetch + watermark SQL run on pgx. // SQLiteDialect.Rebind is identity, so the SQLite path is unchanged. Rebind: dialect.Rebind, @@ -162,7 +165,8 @@ func setupVectorFeatures(ctx context.Context, mainStore *store.Store, mainPath s // placeholders. On PG those must become $N or pgx rejects them, so // the serve/MCP hybrid engine (shared via vectorFeatures.HybridEngine) // carries the dialect's Rebind. SQLite's Rebind is identity. - Rebind: dialect.Rebind, + Rebind: dialect.Rebind, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) // No sync-time enqueue: newly-persisted messages get embed_gen = NULL diff --git a/internal/api/handlers.go b/internal/api/handlers.go index c76ed2473..4df02ff66 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -616,6 +616,8 @@ func (s *Server) handleHybridSearch( case errors.Is(err, vector.ErrEmbeddingTimeout): writeError(w, http.StatusServiceUnavailable, "embedding_timeout", "the embedding endpoint did not respond in time; retry, or raise [vector.embeddings].timeout") + case errors.Is(err, vector.ErrIndexScopeMismatch): + writeError(w, http.StatusBadRequest, "index_scope_mismatch", err.Error()) default: s.logger.Error("hybrid search failed", "query", q, "mode", mode, "error", err) writeError(w, http.StatusInternalServerError, "internal_error", "search failed") diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index c9a747154..1940aeab0 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -118,6 +118,11 @@ func translateVectorErr(err error) *mcp.CallToolResult { return mcp.NewToolResultError( "index_building: the initial vector index is still being built", ) + case errors.Is(err, vector.ErrIndexScopeMismatch): + return mcp.NewToolResultError( + "index_scope_mismatch: the vector index scope does not cover this query; " + + "add a matching message_type filter or rebuild embeddings for the requested scope", + ) case errors.Is(err, vector.ErrNoActiveGeneration): return mcp.NewToolResultError( "no_active_generation: vector search has no active index yet; " + @@ -512,25 +517,46 @@ func (h *handlers) findSimilarMessages(ctx context.Context, req mcp.CallToolRequ return mcp.NewToolResultError(err.Error()), nil } - seed, err := h.backend.LoadVector(ctx, seedID) + active, err := h.backend.ActiveGeneration(ctx) if err != nil { if r := translateVectorErr(err); r != nil { return r, nil } - return mcp.NewToolResultError(fmt.Sprintf("load seed vector: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil + } + if h.vectorCfg.Enabled { + fingerprint := h.vectorCfg.GenerationFingerprint() + if fingerprint != "" && active.Fingerprint != fingerprint { + err := fmt.Errorf("%w: active=%q configured=%q", + vector.ErrIndexStale, active.Fingerprint, fingerprint) + if r := translateVectorErr(err); r != nil { + return r, nil + } + return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil + } } - active, err := h.backend.ActiveGeneration(ctx) + seed, err := h.backend.LoadVector(ctx, seedID) if err != nil { if r := translateVectorErr(err); r != nil { return r, nil } - return mcp.NewToolResultError(fmt.Sprintf("active generation: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("load seed vector: %v", err)), nil + } + + if h.vectorCfg.Enabled { + scope := h.vectorCfg.Embed.Scope.BuildScope() + if !scope.IsEmpty() { + filter.MessageTypes = append([]string(nil), scope.MessageTypes...) + } } // +1 so we can drop the seed itself from results without coming up short. hits, err := h.backend.Search(ctx, active.ID, seed, limit+1, filter) if err != nil { + if r := translateVectorErr(err); r != nil { + return r, nil + } return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err)), nil } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index e11cc76e4..5cdd97ac1 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -262,6 +262,33 @@ func TestSearchMessages_HybridErrNotEnabled(t *testing.T) { assertpkg.Contains(t, txt, "vector_not_enabled", "expected 'vector_not_enabled' error, got: %s") } +func TestSearchMessages_HybridErrIndexScopeMismatch(t *testing.T) { + backend := &fakeBackend{ + active: vector.Generation{ + ID: 1, Model: "fake", Dimension: 4, + Fingerprint: "fake:4:scope=mt-sms", State: vector.GenerationActive, + }, + } + engine := hybrid.NewEngine(backend, nil, realEmbedder{dim: 4}, hybrid.Config{ + ExpectedFingerprint: "fake:4:scope=mt-sms", + RRFK: 60, + KPerSignal: 10, + BuildScope: vector.BuildScope{MessageTypes: []string{"sms"}}, + }) + h := &handlers{ + engine: &querytest.MockEngine{}, + hybridEngine: engine, + backend: backend, + } + + r := runToolExpectError(t, "search_messages", h.searchMessages, map[string]any{ + "query": "anything", + "mode": searchModeVector, + }) + txt := resultText(t, r) + assertpkg.Contains(t, txt, "index_scope_mismatch", "expected 'index_scope_mismatch' error, got: %s") +} + // realEmbedder returns a deterministic vector. Used for end-to-end // MCP hybrid tests that exercise the engine's embed → backend.Search // path; pickEmbedGeneration tests use stubEmbedder instead. @@ -1465,19 +1492,23 @@ func TestStageDeletion(t *testing.T) { // optional fields so the get_stats tests can populate them. Methods // not otherwise configured return errors and should not be called. type fakeBackend struct { - loadVec []float32 - loadErr error - active vector.Generation - activeErr error - searchHits []vector.Hit - searchErr error - building *vector.Generation - buildingErr error - stats map[vector.GenerationID]vector.Stats - statsErr error + loadVec []float32 + loadErr error + loadCalls int + active vector.Generation + activeErr error + searchHits []vector.Hit + searchErr error + searchGen vector.GenerationID + searchFilter vector.Filter + building *vector.Generation + buildingErr error + stats map[vector.GenerationID]vector.Stats + statsErr error } func (f *fakeBackend) LoadVector(_ context.Context, _ int64) ([]float32, error) { + f.loadCalls++ return f.loadVec, f.loadErr } func (f *fakeBackend) ResetWatermarkBelow(_ context.Context, _ int64) error { @@ -1489,7 +1520,9 @@ func (f *fakeBackend) EmbeddedMessageCount(_ context.Context, _ vector.Generatio func (f *fakeBackend) ActiveGeneration(_ context.Context) (vector.Generation, error) { return f.active, f.activeErr } -func (f *fakeBackend) Search(_ context.Context, _ vector.GenerationID, _ []float32, _ int, _ vector.Filter) ([]vector.Hit, error) { +func (f *fakeBackend) Search(_ context.Context, gen vector.GenerationID, _ []float32, _ int, filter vector.Filter) ([]vector.Hit, error) { + f.searchGen = gen + f.searchFilter = filter return f.searchHits, f.searchErr } func (f *fakeBackend) CreateGeneration(_ context.Context, _ string, _ int, _ string) (vector.GenerationID, error) { @@ -1623,9 +1656,98 @@ func TestFindSimilarMessages_HappyPath(t *testing.T) { assert.Equal(int64(300), resp.Messages[1].ID, "Messages[1].ID") } +func TestFindSimilarMessages_RejectsStaleScopedGeneration(t *testing.T) { + assert := assertpkg.New(t) + seed := []float32{1, 0, 0, 0} + cfg := vector.Config{Enabled: true} + cfg.Embeddings.Model = "nomic-embed" + cfg.Embeddings.Dimension = 4 + cfg.Embed.Scope.MessageTypes = []string{"sms"} + fb := &fakeBackend{ + loadVec: seed, + active: vector.Generation{ + ID: 7, + Model: "nomic-embed", + Dimension: 4, + Fingerprint: "nomic-embed:4", + State: vector.GenerationActive, + }, + } + h := &handlers{engine: &querytest.MockEngine{}, backend: fb, vectorCfg: cfg} + + r := runToolExpectError(t, "find_similar_messages", h.findSimilarMessages, map[string]any{ + "message_id": float64(100), + }) + txt := resultText(t, r) + + assert.Contains(txt, "index_stale", "expected stale scoped index rejection, got: %s", txt) + assert.Equal(vector.GenerationID(0), fb.searchGen, "Search should not run against a stale scoped index") +} + +func TestFindSimilarMessages_ReportsStaleIndexBeforeMissingSeed(t *testing.T) { + assert := assertpkg.New(t) + cfg := vector.Config{Enabled: true} + cfg.Embeddings.Model = "nomic-embed" + cfg.Embeddings.Dimension = 4 + cfg.Embed.Scope.MessageTypes = []string{"sms"} + fb := &fakeBackend{ + loadErr: errors.New("seed vector missing"), + active: vector.Generation{ + ID: 7, + Model: "nomic-embed", + Dimension: 4, + Fingerprint: "nomic-embed:4", + State: vector.GenerationActive, + }, + } + h := &handlers{engine: &querytest.MockEngine{}, backend: fb, vectorCfg: cfg} + + r := runToolExpectError(t, "find_similar_messages", h.findSimilarMessages, map[string]any{ + "message_id": float64(100), + }) + txt := resultText(t, r) + + assert.Contains(txt, "index_stale", "expected stale index to be reported before seed lookup, got: %s", txt) + assert.Equal(0, fb.loadCalls, "LoadVector should not run before stale-index validation") +} + +func TestFindSimilarMessages_AppliesConfiguredMessageTypeScope(t *testing.T) { + assert := assertpkg.New(t) + seed := []float32{1, 0, 0, 0} + cfg := vector.Config{Enabled: true} + cfg.Embeddings.Model = "nomic-embed" + cfg.Embeddings.Dimension = 4 + cfg.Embed.Scope.MessageTypes = []string{"sms", "mms"} + fb := &fakeBackend{ + loadVec: seed, + active: vector.Generation{ + ID: 7, + Model: "nomic-embed", + Dimension: 4, + Fingerprint: cfg.GenerationFingerprint(), + State: vector.GenerationActive, + }, + searchHits: []vector.Hit{{MessageID: 200, Score: 0.95, Rank: 1}}, + } + eng := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 200: testutil.NewMessageDetail(200).WithSubject("related").BuildPtr(), + }, + } + h := &handlers{engine: eng, backend: fb, vectorCfg: cfg} + + _ = runTool[similarResponse](t, "find_similar_messages", h.findSimilarMessages, map[string]any{ + "message_id": float64(100), + }) + + assert.Equal(vector.GenerationID(7), fb.searchGen, "Search generation") + assert.Equal([]string{"mms", "sms"}, fb.searchFilter.MessageTypes, "Search MessageTypes filter") +} + func TestFindSimilarMessages_NoActiveGeneration(t *testing.T) { + assert := assertpkg.New(t) fb := &fakeBackend{ - loadErr: vector.ErrNoActiveGeneration, + activeErr: vector.ErrNoActiveGeneration, } h := &handlers{engine: &querytest.MockEngine{}, backend: fb} @@ -1633,7 +1755,8 @@ func TestFindSimilarMessages_NoActiveGeneration(t *testing.T) { "message_id": float64(1), }) txt := resultText(t, r) - assertpkg.Contains(t, txt, "no_active_generation", "expected 'no_active_generation' error, got: %s") + assert.Contains(txt, "no_active_generation", "expected 'no_active_generation' error, got: %s", txt) + assert.Equal(0, fb.loadCalls, "LoadVector should not run without an active generation") } func TestSearchByDomains(t *testing.T) { diff --git a/internal/query/duckdb.go b/internal/query/duckdb.go index e4a31059c..b0374614a 100644 --- a/internal/query/duckdb.go +++ b/internal/query/duckdb.go @@ -407,7 +407,7 @@ func (e *DuckDBEngine) parquetCTEs() string { conv AS ( %s ) - `, msgCTE, + `, msgCTE, e.parquetPath("message_recipients"), pCTE, e.parquetPath("labels"), @@ -417,6 +417,45 @@ func (e *DuckDBEngine) parquetCTEs() string { convCTE) } +func duckDBMessageTypeCondition(alias string, messageTypes []string) (string, []any) { + var conditions []string + var args []any + var exact []string + includeEmail := false + + for _, typ := range messageTypes { + typ = strings.TrimSpace(strings.ToLower(typ)) + if typ == "" { + continue + } + if typ == "email" { + includeEmail = true + continue + } + exact = append(exact, typ) + } + + col := alias + ".message_type" + if includeEmail { + conditions = append(conditions, + fmt.Sprintf("(%s = ? OR %s IS NULL OR %s = '')", col, col, col)) + args = append(args, "email") + } + if len(exact) > 0 { + placeholders := make([]string, len(exact)) + for i, typ := range exact { + placeholders[i] = "?" + args = append(args, typ) + } + conditions = append(conditions, + fmt.Sprintf("%s IN (%s)", col, strings.Join(placeholders, ","))) + } + if len(conditions) == 0 { + return "", nil + } + return "(" + strings.Join(conditions, " OR ") + ")", args +} + // escapeILIKE escapes ILIKE wildcard characters (% and _) in user input. func escapeILIKE(s string) string { s = strings.ReplaceAll(s, "\\", "\\\\") // Escape backslash first @@ -486,6 +525,14 @@ func (e *DuckDBEngine) buildNonTextSearchConditions(q *search.Query, keyColumns var conditions []string var args []any + if len(q.MessageTypes) > 0 { + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } + } + // from: filter - match sender email for _, from := range q.FromAddrs { fromPattern := "%" + escapeILIKE(from) + "%" @@ -578,12 +625,11 @@ func (e *DuckDBEngine) buildNonTextSearchConditions(q *search.Query, keyColumns args = append(args, *q.SmallerThan) } if len(q.MessageTypes) > 0 { - placeholders := make([]string, len(q.MessageTypes)) - for i, typ := range q.MessageTypes { - placeholders[i] = "?" - args = append(args, typ) + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) } - conditions = append(conditions, fmt.Sprintf("msg.message_type IN (%s)", strings.Join(placeholders, ","))) } return conditions, args @@ -882,8 +928,11 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []an } if filter.MessageType != "" { - conditions = append(conditions, "msg.message_type = ?") - args = append(args, filter.MessageType) + condition, conditionArgs := duckDBMessageTypeCondition("msg", []string{filter.MessageType}) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } } // Sender + sender-name filters - check both message_recipients (email) @@ -1160,10 +1209,14 @@ func (e *DuckDBEngine) GetTotalStats(ctx context.Context, opts StatsOptions) (*T var conditions []string var args []any + hasExplicitMessageTypes := false + if opts.SearchQuery != "" { + q := search.Parse(opts.SearchQuery) + hasExplicitMessageTypes = len(q.MessageTypes) > 0 + } // Restrict to email messages only; NULL and '' handle pre-message_type data. - hasSearchMessageTypes := opts.SearchQuery != "" && len(search.Parse(opts.SearchQuery).MessageTypes) > 0 - if !hasSearchMessageTypes { + if !hasExplicitMessageTypes { conditions = append(conditions, emailOnlyFilterMsg) } conditions = append(conditions, store.LiveMessagesWhere("msg", opts.HideDeletedFromSource)) @@ -1638,6 +1691,14 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse } } + if len(q.MessageTypes) > 0 { + condition, conditionArgs := duckDBMessageTypeCondition("m", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } + } + // Has attachment filter if q.HasAttachment != nil && *q.HasAttachment { conditions = append(conditions, "m.has_attachments = 1") @@ -1663,12 +1724,11 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse args = append(args, *q.SmallerThan) } if len(q.MessageTypes) > 0 { - placeholders := make([]string, len(q.MessageTypes)) - for i, typ := range q.MessageTypes { - placeholders[i] = "?" - args = append(args, typ) + condition, conditionArgs := duckDBMessageTypeCondition("m", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) } - conditions = append(conditions, fmt.Sprintf("m.message_type IN (%s)", strings.Join(placeholders, ","))) } // Full-text search: use ILIKE fallback (FTS5 not available via sqlite_scan) @@ -1702,7 +1762,8 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse COALESCE(m.size_estimate, 0), m.has_attachments, m.attachment_count, - m.deleted_from_source_at + m.deleted_from_source_at, + COALESCE(m.message_type, '') FROM sqlite_db.messages m LEFT JOIN sqlite_db.message_recipients mr_sender ON mr_sender.message_id = m.id AND mr_sender.recipient_type = 'from' LEFT JOIN sqlite_db.participants p_sender ON p_sender.id = mr_sender.participant_id @@ -1740,6 +1801,7 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse &msg.HasAttachments, &msg.AttachmentCount, &deletedAt, + &msg.MessageType, ); err != nil { return nil, fmt.Errorf("scan message: %w", err) } @@ -2460,11 +2522,17 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt var args []any // Apply basic filter conditions (ignoring join flags for search - we handle those differently) - if len(q.MessageTypes) == 0 { + conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource)) + if len(q.MessageTypes) > 0 { + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + } + } else { // Restrict to email messages only; NULL and '' handle pre-message_type data. conditions = append(conditions, emailOnlyFilterMsg) } - conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource)) conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs) if filter.After != nil { conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)") @@ -2617,12 +2685,11 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt args = append(args, *q.SmallerThan) } if len(q.MessageTypes) > 0 { - placeholders := make([]string, len(q.MessageTypes)) - for i, typ := range q.MessageTypes { - placeholders[i] = "?" - args = append(args, typ) + condition, conditionArgs := duckDBMessageTypeCondition("msg", q.MessageTypes) + if condition != "" { + conditions = append(conditions, condition) + args = append(args, conditionArgs...) } - conditions = append(conditions, fmt.Sprintf("msg.message_type IN (%s)", strings.Join(placeholders, ","))) } // Account filter diff --git a/internal/query/duckdb_test.go b/internal/query/duckdb_test.go index a10e428c9..ef4e41d93 100644 --- a/internal/query/duckdb_test.go +++ b/internal/query/duckdb_test.go @@ -41,6 +41,33 @@ func newSQLiteEngine(t *testing.T) *DuckDBEngine { return engine } +func newMessageTypeParquetEngine(t *testing.T) *DuckDBEngine { + t.Helper() + b := NewTestDataBuilder(t) + b.AddSource("test@example.com") + aliceID := b.AddParticipant("alice@example.com", "example.com", "Alice") + bobID := b.AddParticipant("bob@example.com", "example.com", "Bob") + smsID := b.AddMessage(MessageOpt{ + Subject: "lunch plan", + Snippet: "sushi lunch details", + MessageType: "sms", + SentAt: time.Date(2024, 4, 10, 10, 0, 0, 0, time.UTC), + SizeEstimate: 321, + }) + emailID := b.AddMessage(MessageOpt{ + Subject: "lunch receipt", + Snippet: "email lunch details", + MessageType: "email", + SentAt: time.Date(2024, 4, 11, 10, 0, 0, 0, time.UTC), + SizeEstimate: 999, + }) + b.AddFrom(smsID, aliceID, "Alice") + b.AddTo(smsID, bobID, "Bob") + b.AddFrom(emailID, aliceID, "Alice") + b.AddTo(emailID, bobID, "Bob") + return b.BuildEngine() +} + // searchFast is a test helper that parses a query string and calls SearchFast. func searchFast(t *testing.T, engine *DuckDBEngine, queryStr string, filter MessageFilter) []MessageSummary { t.Helper() @@ -1123,6 +1150,54 @@ func TestDuckDBEngine_GetTotalStats_MessageTypeFilter(t *testing.T) { assert.Equal(int64(2000), stats.TotalSize, "TotalSize") } +func TestDuckDBEngine_SearchFastMessageTypeFilter(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + engine := newMessageTypeParquetEngine(t) + ctx := context.Background() + + filterOnly, err := engine.SearchFast(ctx, search.Parse("message_type:sms"), MessageFilter{}, 100, 0) + require.NoError(err, "SearchFast message_type only") + require.Len(filterOnly, 1, "filter-only message_type search") + assert.Equal("sms", filterOnly[0].MessageType) + assert.Equal("lunch plan", filterOnly[0].Subject) + + withText, err := engine.SearchFast(ctx, search.Parse("message_type:sms lunch"), MessageFilter{}, 100, 0) + require.NoError(err, "SearchFast message_type with text") + require.Len(withText, 1, "message_type should scope text search") + assert.Equal("sms", withText[0].MessageType) + assert.Equal("lunch plan", withText[0].Subject) +} + +func TestDuckDBEngine_GetTotalStatsMessageTypeSearch(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + engine := newMessageTypeParquetEngine(t) + + stats, err := engine.GetTotalStats(context.Background(), StatsOptions{ + SearchQuery: "message_type:sms", + }) + require.NoError(err, "GetTotalStats") + + assert.Equal(int64(1), stats.MessageCount, "message count") + assert.Equal(int64(321), stats.TotalSize, "total size") +} + +func TestDuckDBEngine_AggregateMessageTypeSearch(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + engine := newMessageTypeParquetEngine(t) + opts := DefaultAggregateOptions() + opts.SearchQuery = "message_type:sms" + + rows, err := engine.Aggregate(context.Background(), ViewTime, opts) + require.NoError(err, "Aggregate") + require.Len(rows, 1, "rows") + + assert.Equal(int64(1), rows[0].Count, "count") + assert.Equal(int64(321), rows[0].TotalSize, "total size") +} + // TestDuckDBEngine_ListMessages_DateFilter verifies that After/Before date filters // work with DuckDB's TIMESTAMP column (regression: VARCHAR params need CAST). func TestDuckDBEngine_ListMessages_DateFilter(t *testing.T) { @@ -3266,6 +3341,13 @@ func TestDuckDBEngine_StaleParquetSchema(t *testing.T) { assertpkg.Equal(t, "Stale Hello", results[0].Subject) }) + t.Run("SearchFastMessageTypeEmail", func(t *testing.T) { + q := search.Parse("message_type:email Stale") + results, err := engine.SearchFast(ctx, q, MessageFilter{}, 100, 0) + requirepkg.NoError(t, err, "SearchFast message_type:email with stale Parquet schema") + requirepkg.Len(t, results, 2) + }) + t.Run("SearchFastCount", func(t *testing.T) { q := search.Parse("Stale") count, err := engine.SearchFastCount(ctx, q, MessageFilter{}) @@ -3285,6 +3367,12 @@ func TestDuckDBEngine_StaleParquetSchema(t *testing.T) { assertpkg.Equal(t, int64(2), stats.MessageCount) }) + t.Run("GetTotalStatsMessageTypeEmail", func(t *testing.T) { + stats, err := engine.GetTotalStats(ctx, StatsOptions{SearchQuery: "message_type:email"}) + requirepkg.NoError(t, err, "GetTotalStats message_type:email with stale Parquet schema") + assertpkg.Equal(t, int64(2), stats.MessageCount) + }) + // Verify that optionalCols correctly detected the missing columns. t.Run("ProbeDetectedMissing", func(t *testing.T) { for _, col := range []struct{ table, col string }{ diff --git a/internal/query/sqlite.go b/internal/query/sqlite.go index 88d78b8de..9670df1ff 100644 --- a/internal/query/sqlite.go +++ b/internal/query/sqlite.go @@ -1497,6 +1497,15 @@ func (e *SQLiteEngine) buildSearchQueryParts(ctx context.Context, q *search.Quer } } + if len(q.MessageTypes) > 0 { + placeholders := make([]string, len(q.MessageTypes)) + for i, typ := range q.MessageTypes { + placeholders[i] = "?" + args = append(args, strings.ToLower(typ)) + } + conditions = append(conditions, fmt.Sprintf("m.message_type IN (%s)", strings.Join(placeholders, ","))) + } + // Has attachment filter if q.HasAttachment != nil && *q.HasAttachment { conditions = append(conditions, e.dialect.BoolTrueExpr("m.has_attachments")) diff --git a/internal/query/sqlite_crud_test.go b/internal/query/sqlite_crud_test.go index f9d365ffa..13e0ea665 100644 --- a/internal/query/sqlite_crud_test.go +++ b/internal/query/sqlite_crud_test.go @@ -1100,6 +1100,31 @@ func TestGetTotalStatsWithSearchQuery_Combined(t *testing.T) { assertpkg.Equal(t, int64(2000), stats.TotalSize, "SearchQuery+WithAttachments total size") } +func TestSearchFastWithStats_MessageTypeStats(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + env := newTestEnv(t) + + smsID := env.AddMessage(dbtest.MessageOpts{ + Subject: "Lunch via SMS", + SentAt: "2024-04-01 10:00:00", + SizeEstimate: 321, + }) + _, err := env.DB.Exec(`UPDATE messages SET message_type = 'sms' WHERE id = ?`, smsID) + require.NoError(err, "set sms message_type") + + q := search.Parse("message_type:sms") + result, err := env.Engine.SearchFastWithStats(env.Ctx, q, "message_type:sms", MessageFilter{}, ViewSenders, 100, 0) + require.NoError(err, "SearchFastWithStats") + require.NotNil(result.Stats, "stats") + + require.Len(result.Messages, 1, "messages") + assert.Equal(smsID, result.Messages[0].ID, "message id") + assert.Equal(int64(1), result.TotalCount, "total count") + assert.Equal(int64(1), result.Stats.MessageCount, "stats message count") + assert.Equal(int64(321), result.Stats.TotalSize, "stats total size") +} + func TestGetMessageRaw(t *testing.T) { env := newTestEnv(t) rawMIME := []byte("From: test@example.com\r\nSubject: Test\r\n\r\nHello") diff --git a/internal/query/sqlite_search_test.go b/internal/query/sqlite_search_test.go index 54c3cf423..f321e857f 100644 --- a/internal/query/sqlite_search_test.go +++ b/internal/query/sqlite_search_test.go @@ -10,6 +10,7 @@ import ( assertpkg "github.com/stretchr/testify/assert" requirepkg "github.com/stretchr/testify/require" "go.kenn.io/msgvault/internal/search" + "go.kenn.io/msgvault/internal/testutil/dbtest" "go.kenn.io/msgvault/internal/testutil/ptr" ) @@ -94,20 +95,6 @@ func TestSearch_Filters(t *testing.T) { } } -func TestSearch_MessageTypeFilter(t *testing.T) { - require := requirepkg.New(t) - assert := assertpkg.New(t) - env := newTestEnv(t) - - _, err := env.DB.Exec(`UPDATE messages SET message_type = ? WHERE id = ?`, "sms", int64(2)) - require.NoError(err, "mark message as sms") - - results := env.MustSearch(search.Parse("message_type:sms Hello"), 100, 0) - require.Len(results, 1, "message_type:sms should scope the text search") - assert.Equal(int64(2), results[0].ID, "ID") - assert.Equal("sms", results[0].MessageType, "MessageType") -} - func TestSearch_CaseInsensitiveFallback(t *testing.T) { env := newTestEnv(t) @@ -148,6 +135,41 @@ func TestSearch_WithFTS(t *testing.T) { assertpkg.Equal(t, "Hello World", results[0].Subject) } +func TestSearch_MessageTypeFilter(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + env := newTestEnv(t) + aliceID := env.MustLookupParticipant("alice@example.com") + bobID := env.MustLookupParticipant("bob@company.org") + smsID := env.AddMessage(dbtest.MessageOpts{ + Subject: "lunch plan", + SentAt: "2024-04-10 10:00:00", + FromID: aliceID, + ToIDs: []int64{bobID}, + }) + emailID := env.AddMessage(dbtest.MessageOpts{ + Subject: "lunch receipt", + SentAt: "2024-04-11 10:00:00", + FromID: aliceID, + ToIDs: []int64{bobID}, + }) + _, err := env.DB.Exec(`UPDATE messages SET message_type = 'sms' WHERE id = ?`, smsID) + require.NoError(err, "set sms message_type") + _, err = env.DB.Exec(`UPDATE messages SET message_type = 'email' WHERE id = ?`, emailID) + require.NoError(err, "set email message_type") + + results := env.MustSearch(search.Parse("message_type:sms"), 100, 0) + require.Len(results, 1, "filter-only message_type search") + assert.Equal(smsID, results[0].ID) + assert.Equal("sms", results[0].MessageType) + + env.EnableFTS() + results = env.MustSearch(search.Parse("message_type:sms lunch"), 100, 0) + require.Len(results, 1, "message_type must scope FTS search") + assert.Equal(smsID, results[0].ID) + assert.Equal("sms", results[0].MessageType) +} + // TestSearch_WithFTS_SpecialChars verifies that FTS5 special characters in // search terms don't cause syntax errors. Without quoting, these characters // are interpreted as FTS5 operators (- = NOT, : = column filter, () = grouping). diff --git a/internal/remote/engine_test.go b/internal/remote/engine_test.go index dae7242a6..67cfffffc 100644 --- a/internal/remote/engine_test.go +++ b/internal/remote/engine_test.go @@ -95,7 +95,6 @@ func TestEngineGetMessageSummariesByIDs_CarriesFromAndAttachmentCount(t *testing assert.Equal(2, s.AttachmentCount, "AttachmentCount") assert.True(s.HasAttachments, "HasAttachments") } - func TestEngineSearchSerializesMessageTypes(t *testing.T) { require := requirepkg.New(t) assert := assertpkg.New(t) @@ -125,3 +124,18 @@ func TestEngineSearchSerializesMessageTypes(t *testing.T) { }, 10, 0) require.NoError(err, "Search") } + +func TestBuildSearchQueryStringIncludesMessageTypes(t *testing.T) { + assert := assertpkg.New(t) + + assert.Equal( + "message_type:sms", + buildSearchQueryString(search.Parse("message_type:sms")), + "filter-only message type query", + ) + assert.Equal( + "lunch message_type:sms", + buildSearchQueryString(search.Parse("message_type:sms lunch")), + "message type with text term", + ) +} diff --git a/internal/scheduler/embed_job.go b/internal/scheduler/embed_job.go index cecea3622..bd6f1b1b7 100644 --- a/internal/scheduler/embed_job.go +++ b/internal/scheduler/embed_job.go @@ -34,6 +34,10 @@ type EmbedCoverage interface { MissingCount(ctx context.Context, activeGen int64) (int64, error) } +type ScopedEmbedCoverage interface { + MissingCountScoped(ctx context.Context, activeGen int64, messageTypes []string) (int64, error) +} + // Compile-time check that the production worker satisfies EmbedRunner. var _ EmbedRunner = (*embed.Worker)(nil) @@ -82,6 +86,10 @@ type EmbedJob struct { // A negative value disables the auto-backstop entirely. BackstopInterval time.Duration + // BuildScope limits coverage checks to the same message universe the + // worker scans for this generation. Empty means the full live corpus. + BuildScope vector.BuildScope + // Now returns the current time; overridable in tests to drive the // backstop interval deterministically. nil uses time.Now. Now func() time.Time @@ -182,7 +190,7 @@ func (j *EmbedJob) Run(ctx context.Context) { "gen", target) return } - missing, err := j.Store.MissingCount(ctx, int64(target)) + missing, err := j.missingCount(ctx, target) if err != nil { log.Warn("embed: coverage count after run failed", "gen", target, "error", err) return @@ -201,6 +209,17 @@ func (j *EmbedJob) Run(ctx context.Context) { log.Info("embed: building generation activated", "gen", target) } +func (j *EmbedJob) missingCount(ctx context.Context, target vector.GenerationID) (int64, error) { + scope := vector.NewBuildScope(j.BuildScope.MessageTypes) + if scope.IsEmpty() { + return j.Store.MissingCount(ctx, int64(target)) + } + if scoped, ok := j.Store.(ScopedEmbedCoverage); ok { + return scoped.MissingCountScoped(ctx, int64(target), scope.MessageTypes) + } + return 0, errors.New("embed coverage store does not support scoped missing counts") +} + // maybeRunBackstop runs a full watermark-ignoring backstop pass on gen when // BackstopInterval has elapsed since this generation's last one, then records // the time. The throttle is keyed per generation so a recent backstop of one diff --git a/internal/store/api_search_test.go b/internal/store/api_search_test.go index 32bea8bc2..d3808468c 100644 --- a/internal/store/api_search_test.go +++ b/internal/store/api_search_test.go @@ -148,7 +148,6 @@ func TestSearchMessages_LegacyRawString(t *testing.T) { }) } } - func TestSearchMessagesQuery_MessageTypeFilter(t *testing.T) { require := requirepkg.New(t) assert := assertpkg.New(t) diff --git a/internal/store/embed_gen.go b/internal/store/embed_gen.go index e7517283a..3ba69995f 100644 --- a/internal/store/embed_gen.go +++ b/internal/store/embed_gen.go @@ -33,16 +33,27 @@ var embedGenStampChunkRows = 500 // the embeddings upsert — the worker orders the steps (upsert, then // stamp) and relies on idempotency, see internal/vector/embed/worker.go. func (s *Store) ScanForEmbedding(ctx context.Context, target int64, afterID int64, limit int) ([]int64, error) { + return s.ScanForEmbeddingScoped(ctx, target, afterID, limit, nil) +} + +// ScanForEmbeddingScoped is ScanForEmbedding limited to the supplied message +// types. An empty messageTypes slice means the full live corpus. +func (s *Store) ScanForEmbeddingScoped(ctx context.Context, target int64, afterID int64, limit int, messageTypes []string) ([]int64, error) { if limit <= 0 { return nil, nil } + liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes) q := `SELECT id FROM messages WHERE (embed_gen IS NULL OR embed_gen <> ?) - AND ` + LiveMessagesWhere("", true) + ` + AND ` + liveWhere + ` AND id > ? ORDER BY id LIMIT ?` - rows, err := s.db.QueryContext(ctx, q, target, afterID, limit) + args := make([]any, 0, 3+len(liveArgs)) + args = append(args, target) + args = append(args, liveArgs...) + args = append(args, afterID, limit) + rows, err := s.db.QueryContext(ctx, q, args...) if err != nil { return nil, fmt.Errorf("scan for embedding: %w", err) } @@ -239,14 +250,22 @@ func (s *Store) ResetEmbedGen(ctx context.Context, ids []int64) error { // activeGen == 0 means "no active/target generation"; then everything // live is missing and stamped is 0. func (s *Store) CoverageCounts(ctx context.Context, activeGen int64) (live, stamped, blank, missing int64, err error) { - live, err = s.countLiveMessages(ctx) + return s.CoverageCountsScoped(ctx, activeGen, nil) +} + +// CoverageCountsScoped is CoverageCounts limited to the supplied message types. +// An empty messageTypes slice means the full live corpus. +func (s *Store) CoverageCountsScoped(ctx context.Context, activeGen int64, messageTypes []string) (live, stamped, blank, missing int64, err error) { + live, err = s.countLiveMessagesScoped(ctx, messageTypes) if err != nil { return 0, 0, 0, 0, err } if activeGen != 0 { + liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes) q := `SELECT COUNT(*) FROM messages - WHERE embed_gen = ? AND ` + LiveMessagesWhere("", true) - if err := s.db.QueryRowContext(ctx, q, activeGen).Scan(&stamped); err != nil { + WHERE embed_gen = ? AND ` + liveWhere + args := append([]any{activeGen}, liveArgs...) + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&stamped); err != nil { return 0, 0, 0, 0, fmt.Errorf("count stamped: %w", err) } } @@ -259,7 +278,13 @@ func (s *Store) CoverageCounts(ctx context.Context, activeGen int64) (live, stam // activeGen). It is a thin accessor for the scheduler/CLI activation // gates, which only consult the missing count; missing = live - stamped. func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error) { - live, err := s.countLiveMessages(ctx) + return s.MissingCountScoped(ctx, activeGen, nil) +} + +// MissingCountScoped is MissingCount limited to the supplied message types. +// An empty messageTypes slice means the full live corpus. +func (s *Store) MissingCountScoped(ctx context.Context, activeGen int64, messageTypes []string) (int64, error) { + live, err := s.countLiveMessagesScoped(ctx, messageTypes) if err != nil { return 0, err } @@ -267,9 +292,11 @@ func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error return live, nil } var stamped int64 + liveWhere, liveArgs := liveMessagesWhereWithMessageTypes(messageTypes) q := `SELECT COUNT(*) FROM messages - WHERE embed_gen = ? AND ` + LiveMessagesWhere("", true) - if err := s.db.QueryRowContext(ctx, q, activeGen).Scan(&stamped); err != nil { + WHERE embed_gen = ? AND ` + liveWhere + args := append([]any{activeGen}, liveArgs...) + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&stamped); err != nil { return 0, fmt.Errorf("count stamped: %w", err) } return max(live-stamped, 0), nil @@ -277,11 +304,48 @@ func (s *Store) MissingCount(ctx context.Context, activeGen int64) (int64, error // countLiveMessages returns the total live-message count. Shared by // CoverageCounts; kept separate so the live-predicate stays in one place. -func (s *Store) countLiveMessages(ctx context.Context) (int64, error) { +func (s *Store) countLiveMessagesScoped(ctx context.Context, messageTypes []string) (int64, error) { var n int64 - q := `SELECT COUNT(*) FROM messages WHERE ` + LiveMessagesWhere("", true) - if err := s.db.QueryRowContext(ctx, q).Scan(&n); err != nil { + liveWhere, args := liveMessagesWhereWithMessageTypes(messageTypes) + q := `SELECT COUNT(*) FROM messages WHERE ` + liveWhere + if err := s.db.QueryRowContext(ctx, q, args...).Scan(&n); err != nil { return 0, fmt.Errorf("count live messages: %w", err) } return n, nil } + +func liveMessagesWhereWithMessageTypes(messageTypes []string) (string, []any) { + where := LiveMessagesWhere("", true) + types := normalizeMessageTypes(messageTypes) + if len(types) == 0 { + return where, nil + } + placeholders := make([]string, len(types)) + args := make([]any, len(types)) + for i, typ := range types { + placeholders[i] = "?" + args[i] = typ + } + where += " AND message_type IN (" + strings.Join(placeholders, ",") + ")" + return where, args +} + +func normalizeMessageTypes(messageTypes []string) []string { + if len(messageTypes) == 0 { + return nil + } + seen := make(map[string]struct{}, len(messageTypes)) + out := make([]string, 0, len(messageTypes)) + for _, typ := range messageTypes { + typ = strings.TrimSpace(strings.ToLower(typ)) + if typ == "" { + continue + } + if _, ok := seen[typ]; ok { + continue + } + seen[typ] = struct{}{} + out = append(out, typ) + } + return out +} diff --git a/internal/synctechsms/importer.go b/internal/synctechsms/importer.go index c16141e7e..ec6d5e24f 100644 --- a/internal/synctechsms/importer.go +++ b/internal/synctechsms/importer.go @@ -98,39 +98,47 @@ func (i *Importer) importRecord(sourceID int64, record Record, summary *ImportSu switch record.Kind { case RecordSMS: if i.opts.IncludeSMS && record.SMS != nil { - if err := i.importSMS(sourceID, *record.SMS); err != nil { + msgID, err := i.importSMS(sourceID, *record.SMS) + if err != nil { return err } + summary.MessageIDs = append(summary.MessageIDs, msgID) summary.SMSImported++ } case RecordMMS: if i.opts.IncludeMMS && record.MMS != nil { - attachments, err := i.importMMS(sourceID, *record.MMS) + msgID, attachments, err := i.importMMS(sourceID, *record.MMS) if err != nil { + if msgID != 0 { + summary.MessageIDs = append(summary.MessageIDs, msgID) + } return err } + summary.MessageIDs = append(summary.MessageIDs, msgID) summary.MMSImported++ summary.AttachmentsImported += attachments } case RecordCall: if i.opts.IncludeCalls && record.Call != nil { - if err := i.importCall(sourceID, *record.Call); err != nil { + msgID, err := i.importCall(sourceID, *record.Call) + if err != nil { return err } + summary.MessageIDs = append(summary.MessageIDs, msgID) summary.CallsImported++ } } return nil } -func (i *Importer) importSMS(sourceID int64, sms SMS) error { +func (i *Importer) importSMS(sourceID int64, sms SMS) (int64, error) { remoteID, err := i.participantID(sms.Address, sms.ContactName.String) if err != nil { - return err + return 0, err } ownerID, err := i.participantID(i.opts.OwnerPhone, "Me") if err != nil { - return err + return 0, err } // Drafts are owner-authored messages that never made it out, but // they still belong on the owner's side of the conversation. Without @@ -144,36 +152,36 @@ func (i *Importer) importSMS(sourceID int64, sms SMS) error { } convID, err := i.ensureConversation(sourceID, textConversationKey([]int64{ownerID, remoteID}), sms.ContactName.String) if err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, remoteID, "member"); err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, ownerID, "member"); err != nil { - return err + return 0, err } msgID := stableID("sms", sms.Address, sms.Timestamp.String(), fmt.Sprint(sms.Type), sms.Body) return i.upsertTextMessage(sourceID, convID, msgID, "sms", senderID, recipientIDs, fromMe, sms.Timestamp, sms.Body, sms.Body, 0, sms) } -func (i *Importer) importMMS(sourceID int64, mms MMS) (int, error) { +func (i *Importer) importMMS(sourceID int64, mms MMS) (int64, int, error) { ownerID, err := i.participantID(i.opts.OwnerPhone, "Me") if err != nil { - return 0, err + return 0, 0, err } participantIDs, senderID, recipientIDs, err := i.mmsParticipants(mms, ownerID) if err != nil { - return 0, err + return 0, 0, err } // Drafts belong to the owner — see the matching note in importSMS. fromMe := mms.MessageBox == MMSBoxSent || mms.MessageBox == MMSBoxOutbox || mms.MessageBox == MMSBoxDraft convID, err := i.ensureConversation(sourceID, textConversationKey(participantIDs), mms.ContactName.String) if err != nil { - return 0, err + return 0, 0, err } for _, participantID := range participantIDs { if err := i.store.EnsureConversationParticipant(convID, participantID, "member"); err != nil { - return 0, err + return 0, 0, err } } body := mmsText(mms) @@ -183,21 +191,26 @@ func (i *Importer) importMMS(sourceID int64, mms MMS) (int, error) { } msgID := stableID("mms", srcIDPart, mms.Timestamp.String(), sortedKey(participantIDs)) attachmentCount := countImportableAttachments(mms, i.opts.IncludeAttachments) - if err := i.upsertTextMessage(sourceID, convID, msgID, "mms", senderID, recipientIDs, fromMe, mms.Timestamp, body, mms.Subject.String, attachmentCount, mms); err != nil { - return 0, err + messageID, err := i.upsertTextMessage(sourceID, convID, msgID, "mms", senderID, recipientIDs, fromMe, mms.Timestamp, body, mms.Subject.String, attachmentCount, mms) + if err != nil { + return 0, 0, err } - return i.importMMSAttachments(sourceID, msgID, mms) + imported, err := i.importMMSAttachments(sourceID, msgID, mms) + if err != nil { + return messageID, imported, err + } + return messageID, imported, nil } -func (i *Importer) importCall(sourceID int64, call Call) error { +func (i *Importer) importCall(sourceID int64, call Call) (int64, error) { remoteAddress := callParticipantAddress(call) remoteID, err := i.participantID(remoteAddress, call.ContactName.String) if err != nil { - return err + return 0, err } ownerID, err := i.participantID(i.opts.OwnerPhone, "Me") if err != nil { - return err + return 0, err } fromMe := call.Type == CallOutgoing senderID := remoteID @@ -208,20 +221,20 @@ func (i *Importer) importCall(sourceID int64, call Call) error { } convID, err := i.ensureConversation(sourceID, "calls:"+canonicalAddress(remoteAddress), call.ContactName.String) if err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, remoteID, "member"); err != nil { - return err + return 0, err } if err := i.store.EnsureConversationParticipant(convID, ownerID, "member"); err != nil { - return err + return 0, err } body := fmt.Sprintf("Call %s, %d seconds", callTypeLabel(call.Type), call.DurationSeconds) msgID := stableID("call", remoteAddress, call.Timestamp.String(), fmt.Sprint(call.Type), strconv.Itoa(call.DurationSeconds)) return i.upsertTextMessage(sourceID, convID, msgID, "synctech_sms_call", senderID, recipientIDs, fromMe, call.Timestamp, body, body, 0, call) } -func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, messageType string, senderID int64, recipientIDs []int64, fromMe bool, sentAt time.Time, body, subject string, attachmentCount int, raw any) error { +func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, messageType string, senderID int64, recipientIDs []int64, fromMe bool, sentAt time.Time, body, subject string, attachmentCount int, raw any) (int64, error) { msgID, err := i.store.UpsertMessage(&store.Message{ ConversationID: convID, SourceID: sourceID, @@ -237,33 +250,33 @@ func (i *Importer) upsertTextMessage(sourceID, convID int64, sourceMessageID, me AttachmentCount: attachmentCount, }) if err != nil { - return fmt.Errorf("upsert message: %w", err) + return 0, fmt.Errorf("upsert message: %w", err) } if body != "" { if err := i.store.UpsertMessageBody(msgID, sql.NullString{String: body, Valid: true}, sql.NullString{}); err != nil { - return fmt.Errorf("upsert body: %w", err) + return 0, fmt.Errorf("upsert body: %w", err) } } rawJSON, err := json.Marshal(raw) if err != nil { - return fmt.Errorf("marshal raw record: %w", err) + return 0, fmt.Errorf("marshal raw record: %w", err) } if err := i.store.UpsertMessageRawWithFormat(msgID, rawJSON, RawFormat); err != nil { - return fmt.Errorf("upsert raw record: %w", err) + return 0, fmt.Errorf("upsert raw record: %w", err) } if err := i.store.ReplaceMessageRecipients(msgID, "from", []int64{senderID}, []string{""}); err != nil { - return fmt.Errorf("replace from recipient: %w", err) + return 0, fmt.Errorf("replace from recipient: %w", err) } if err := i.store.ReplaceMessageRecipients(msgID, "to", recipientIDs, blankNames(len(recipientIDs))); err != nil { - return fmt.Errorf("replace to recipient: %w", err) + return 0, fmt.Errorf("replace to recipient: %w", err) } if err := i.store.UpsertFTS(msgID, subject, body, "", "", ""); err != nil { - // FTS is an index, not data: a failure to populate it must never abort - // the import. Warn and continue, matching the other UpsertFTS callers - // (sync.go, importer/ingest.go, fbmessenger, whatsapp). [C2] - slog.Warn("failed to upsert FTS", "message", msgID, "error", err) + slog.Warn("failed to update Synctech message FTS index", + "message_id", msgID, + "message_type", messageType, + "error", err) } - return nil + return msgID, nil } func (i *Importer) participantID(address, displayName string) (int64, error) { diff --git a/internal/synctechsms/importer_test.go b/internal/synctechsms/importer_test.go index 769ebefe9..ac6a1f1de 100644 --- a/internal/synctechsms/importer_test.go +++ b/internal/synctechsms/importer_test.go @@ -7,6 +7,7 @@ import ( assertpkg "github.com/stretchr/testify/assert" requirepkg "github.com/stretchr/testify/require" "go.kenn.io/msgvault/internal/store" + "go.kenn.io/msgvault/internal/testutil" "go.kenn.io/msgvault/internal/testutil/storetest" ) @@ -84,6 +85,60 @@ func TestImporterImportsCallWithBlankNumber(t *testing.T) { assertMessageCount(t, f.Store, 1) } +func TestImporterContinuesWhenFTSUpsertFails(t *testing.T) { + testutil.SkipIfPostgres(t, "drops the SQLite FTS virtual table to force UpsertFTS failure") + f := storetest.New(t) + _, err := f.Store.DB().Exec(`DROP TABLE messages_fts`) + requirepkg.NoError(t, err, "drop messages_fts") + dir := t.TempDir() + writeFile(t, filepath.Join(dir, "messages.xml"), ` + +`) + + imp := NewImporter(f.Store, ImportOptions{ + OwnerPhone: "+15550000001", + IncludeSMS: true, + }) + summary, err := imp.ImportPath(dir) + requirepkg.NoError(t, err, "ImportPath should tolerate FTS indexing failure") + assertpkg.Equal(t, 1, summary.SMSImported, "summary = %#v", summary) + assertMessageCount(t, f.Store, 1) +} + +func TestImporterReturnsMMSMessageIDWhenAttachmentImportFails(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + f := storetest.New(t) + dir := t.TempDir() + writeFile(t, filepath.Join(dir, "messages.xml"), ` + + + + + + + + + + +`) + + imp := NewImporter(f.Store, ImportOptions{ + OwnerPhone: "+15550000001", + AttachmentsDir: filepath.Join(dir, "attachments"), + MaxAttachmentBytes: 1, + IncludeMMS: true, + IncludeAttachments: true, + }) + summary, err := imp.ImportPath(dir) + require.Error(err, "ImportPath") + assert.Contains(err.Error(), "exceeds maximum size", "ImportPath error") + + assertMessageCount(t, f.Store, 1) + require.Len(summary.MessageIDs, 1, "summary message IDs") + assert.Positive(summary.MessageIDs[0], "summary message id") +} + // TestImporterMarksDraftsAsFromMe guards against regressing the // draft-handling fix. SMSTypeDraft and MMSBoxDraft are owner-authored // messages; treating them as incoming hides them on the wrong side of diff --git a/internal/synctechsms/types.go b/internal/synctechsms/types.go index c2d5d8eb3..1ac427444 100644 --- a/internal/synctechsms/types.go +++ b/internal/synctechsms/types.go @@ -160,4 +160,5 @@ type ImportSummary struct { MMSImported int CallsImported int AttachmentsImported int + MessageIDs []int64 } diff --git a/internal/vector/build_scope.go b/internal/vector/build_scope.go new file mode 100644 index 000000000..6b53cdbb1 --- /dev/null +++ b/internal/vector/build_scope.go @@ -0,0 +1,65 @@ +package vector + +import ( + "slices" + "sort" + "strings" +) + +// BuildScope limits which messages are eligible for an embedding +// generation. A zero-value scope means the full corpus. +type BuildScope struct { + MessageTypes []string +} + +// NewBuildScope returns a normalized, stable scope. Message types are +// lowercase, trimmed, de-duplicated, and sorted so fingerprints and SQL +// bindings are deterministic. +func NewBuildScope(messageTypes []string) BuildScope { + seen := make(map[string]struct{}, len(messageTypes)) + out := make([]string, 0, len(messageTypes)) + for _, typ := range messageTypes { + typ = strings.TrimSpace(strings.ToLower(typ)) + if typ == "" { + continue + } + if _, ok := seen[typ]; ok { + continue + } + seen[typ] = struct{}{} + out = append(out, typ) + } + sort.Strings(out) + return BuildScope{MessageTypes: out} +} + +func (s BuildScope) IsEmpty() bool { + return len(s.MessageTypes) == 0 +} + +func (s BuildScope) Fingerprint() string { + if s.IsEmpty() { + return "" + } + return "mt-" + strings.Join(s.MessageTypes, ",") +} + +func (s BuildScope) ContainsMessageType(messageType string) bool { + messageType = strings.TrimSpace(strings.ToLower(messageType)) + return slices.Contains(s.MessageTypes, messageType) +} + +func (s BuildScope) AllowsMessageTypes(messageTypes []string) bool { + if s.IsEmpty() { + return true + } + if len(messageTypes) == 0 { + return false + } + for _, typ := range messageTypes { + if !s.ContainsMessageType(typ) { + return false + } + } + return true +} diff --git a/internal/vector/config.go b/internal/vector/config.go index 5dd3ebd85..23ca84483 100644 --- a/internal/vector/config.go +++ b/internal/vector/config.go @@ -213,7 +213,8 @@ type EmbedConfig struct { // stragglers from repair-encoding resets, transient errors, or crashes) // in addition to the per-tick incremental scan. Zero uses the EmbedJob // default (24h); a negative value disables the auto-backstop. - BackstopInterval time.Duration `toml:"backstop_interval"` + BackstopInterval time.Duration `toml:"backstop_interval"` + Scope EmbedScopeConfig `toml:"scope"` } // EmbedScheduleConfig controls when the embed worker runs on its own @@ -223,6 +224,16 @@ type EmbedScheduleConfig struct { RunAfterSync bool `toml:"run_after_sync"` // trigger after each successful sync } +// EmbedScopeConfig limits which messages are included in newly-built +// embedding generations. The zero value means the full corpus. +type EmbedScopeConfig struct { + MessageTypes []string `toml:"message_types"` +} + +func (s EmbedScopeConfig) BuildScope() BuildScope { + return NewBuildScope(s.MessageTypes) +} + // Fingerprint returns the ":" identifier for the // embedding endpoint half of the policy (§6.7 of the spec). Use // Config.GenerationFingerprint when storing or comparing index @@ -256,8 +267,12 @@ func (e EmbeddingsConfig) Fingerprint() string { // in, an active generation built under the old single-vector policy // would silently accept new chunked entries from an upgraded worker. func (c *Config) GenerationFingerprint() string { - return fmt.Sprintf("%s:%s:c%d:e%d", + fp := fmt.Sprintf("%s:%s:c%d:e%d", c.Embeddings.Fingerprint(), c.Preprocess.Fingerprint(), c.Embeddings.MaxInputChars, embedPolicyVersion) + if scopeFP := c.Embed.Scope.BuildScope().Fingerprint(); scopeFP != "" { + fp = fmt.Sprintf("%s:s%s", fp, scopeFP) + } + return fp } // Validate returns a descriptive error if the config is unusable. diff --git a/internal/vector/config_test.go b/internal/vector/config_test.go index cba341454..dd76eee40 100644 --- a/internal/vector/config_test.go +++ b/internal/vector/config_test.go @@ -433,3 +433,18 @@ func TestConfig_GenerationFingerprint_IncludesEmbedPolicyVersion(t *testing.T) { suffix := fmt.Sprintf(":e%d", embedPolicyVersion) assertpkg.True(t, strings.HasSuffix(got, suffix), "GenerationFingerprint() = %q, want suffix %q", got, suffix) } + +func TestConfig_GenerationFingerprint_IncludesEmbedScope(t *testing.T) { + base := Config{ + Embeddings: EmbeddingsConfig{Model: "m", Dimension: 8, MaxInputChars: 6000}, + } + baseline := base.GenerationFingerprint() + + scoped := base + scoped.Embed.Scope.MessageTypes = []string{"MMS", "sms", "sms"} + + assertpkg.NotEqual(t, baseline, scoped.GenerationFingerprint(), + "GenerationFingerprint should change when embedding is scoped") + assertpkg.Contains(t, scoped.GenerationFingerprint(), ":smt-mms,sms", + "scope fingerprint should normalize and sort message types") +} diff --git a/internal/vector/embed/worker.go b/internal/vector/embed/worker.go index 7bb0943e3..e1a180d06 100644 --- a/internal/vector/embed/worker.go +++ b/internal/vector/embed/worker.go @@ -64,6 +64,7 @@ type WorkerDeps struct { Store WorkStore Client EmbeddingClient Preprocess PreprocessConfig + BuildScope vector.BuildScope MaxInputChars int BatchSize int // beforeSkipStamp is a test hook for read-to-stamp race coverage. @@ -309,7 +310,7 @@ func (w *Worker) run(ctx context.Context, gen vector.GenerationID, backstop bool return res, fmt.Errorf("RunOnce: %w", err) } batchStart := time.Now() - ids, err := w.deps.Store.ScanForEmbedding(ctx, int64(gen), afterID, w.deps.BatchSize) + ids, err := w.scanForEmbedding(ctx, int64(gen), afterID) if err != nil { return res, fmt.Errorf("scan for embedding: %w", err) } @@ -552,6 +553,20 @@ func (w *Worker) run(ctx context.Context, gen vector.GenerationID, backstop bool } } +func (w *Worker) scanForEmbedding(ctx context.Context, gen int64, afterID int64) ([]int64, error) { + scope := vector.NewBuildScope(w.deps.BuildScope.MessageTypes) + if scope.IsEmpty() { + return w.deps.Store.ScanForEmbedding(ctx, gen, afterID, w.deps.BatchSize) + } + scoped, ok := w.deps.Store.(interface { + ScanForEmbeddingScoped(ctx context.Context, target int64, afterID int64, limit int, messageTypes []string) ([]int64, error) + }) + if !ok { + return nil, errors.New("work store does not support scoped embedding scans") + } + return scoped.ScanForEmbeddingScoped(ctx, gen, afterID, w.deps.BatchSize, scope.MessageTypes) +} + // advanceWatermark persists the per-gen forward-scan cursor to id after a // batch made forward progress. The backstop never persists (it scans from // 0 by design and must not push the optimistic watermark backward or diff --git a/internal/vector/errors.go b/internal/vector/errors.go index b4d5a5ecf..03a7ba3e0 100644 --- a/internal/vector/errors.go +++ b/internal/vector/errors.go @@ -64,4 +64,10 @@ var ( // "transient backend slow" response so clients can retry instead // of treating it as a permanent failure. ErrEmbeddingTimeout = errors.New("embedding request timed out") + + // ErrIndexScopeMismatch is returned when a scoped embedding index + // is used without an equivalent structured filter. For example, an + // index built only for message_type=sms must not answer an unscoped + // vector query over email + SMS. + ErrIndexScopeMismatch = errors.New("index scope mismatch") ) diff --git a/internal/vector/hybrid/engine.go b/internal/vector/hybrid/engine.go index f264c1ca1..412f0a463 100644 --- a/internal/vector/hybrid/engine.go +++ b/internal/vector/hybrid/engine.go @@ -63,7 +63,8 @@ type Config struct { // participant/label lookup SQL that BuildFilter runs against mainDB. // Pass PostgreSQLDialect.Rebind on PG (pgx rejects bare ?); leave nil // (or SQLiteDialect.Rebind, which is identity) on SQLite. - Rebind func(string) string + Rebind func(string) string + BuildScope vector.BuildScope } // Engine orchestrates the generation check, query embedding, and fusion @@ -111,6 +112,9 @@ func (e *Engine) Search(ctx context.Context, req SearchRequest) ([]vector.FusedH if err != nil { return nil, ResultMeta{}, err } + if err := e.validateBuildScope(req.Filter); err != nil { + return nil, ResultMeta{}, err + } if req.FreeText == "" { return nil, ResultMeta{}, errors.New("empty query") @@ -192,6 +196,24 @@ func (e *Engine) Search(ctx context.Context, req SearchRequest) ([]vector.FusedH }, nil } +func (e *Engine) validateBuildScope(filter vector.Filter) error { + scope := vector.NewBuildScope(e.cfg.BuildScope.MessageTypes) + if scope.IsEmpty() { + return nil + } + if len(filter.MessageTypes) == 0 { + return fmt.Errorf("%w: index is scoped to message_type=%s; add a matching message_type filter", + vector.ErrIndexScopeMismatch, strings.Join(scope.MessageTypes, ",")) + } + if !scope.AllowsMessageTypes(filter.MessageTypes) { + return fmt.Errorf("%w: index is scoped to message_type=%s, query requested message_type=%s", + vector.ErrIndexScopeMismatch, + strings.Join(scope.MessageTypes, ","), + strings.Join(vector.NewBuildScope(filter.MessageTypes).MessageTypes, ",")) + } + return nil +} + // vectorHitsToFused wraps pure-vector hits in the FusedHit schema. // BM25Score and RRFScore are both set to math.NaN(): "not present in // this signal." Pure vector mode never applies Reciprocal Rank Fusion diff --git a/internal/vector/hybrid/engine_test.go b/internal/vector/hybrid/engine_test.go index 3eaef6da3..8e424c184 100644 --- a/internal/vector/hybrid/engine_test.go +++ b/internal/vector/hybrid/engine_test.go @@ -165,6 +165,35 @@ func TestEngine_Hybrid_HappyPath(t *testing.T) { assert.Equal(len(results), meta.ReturnedCount) } +func TestEngine_ScopedIndexRequiresMatchingMessageTypeFilter(t *testing.T) { + ctx := context.Background() + f := newEngineFixture(t) + f.Engine.cfg.BuildScope = vector.NewBuildScope([]string{"sms", "mms"}) + + _, _, err := f.Engine.Search(ctx, SearchRequest{ + Mode: ModeVector, + FreeText: "lunch", + Limit: 5, + }) + requirepkg.ErrorIs(t, err, vector.ErrIndexScopeMismatch) + + _, _, err = f.Engine.Search(ctx, SearchRequest{ + Mode: ModeVector, + FreeText: "lunch", + Limit: 5, + Filter: vector.Filter{MessageTypes: []string{"email"}}, + }) + requirepkg.ErrorIs(t, err, vector.ErrIndexScopeMismatch) + + _, _, err = f.Engine.Search(ctx, SearchRequest{ + Mode: ModeVector, + FreeText: "lunch", + Limit: 5, + Filter: vector.Filter{MessageTypes: []string{"sms"}}, + }) + requirepkg.NoError(t, err) +} + // TestFTSTerms covers the FreeText → dialect-neutral term-slice // tokenizer directly (no DB needed). FreeText is split on whitespace // and terms the FTS5/tsquery tokenizers would drop entirely diff --git a/internal/vector/pgvector/backend.go b/internal/vector/pgvector/backend.go index fb0782e62..ac6782386 100644 --- a/internal/vector/pgvector/backend.go +++ b/internal/vector/pgvector/backend.go @@ -39,6 +39,9 @@ type Options struct { // per-dimension HNSW index on first migration. Optional; if zero // the index is created on first CreateGeneration. Dimension int + // BuildScope limits generation coverage to matching messages. Empty + // means the full corpus. + BuildScope vector.BuildScope // SkipMigrate suppresses the privileged CREATE EXTENSION + full // migrate. A WRITABLE open still applies the (extension-less) schema so // the one-time upgrade lands — read-only-ness is now signalled by @@ -74,7 +77,8 @@ type Options struct { // with the pgvector extension. The same *sql.DB also serves the main // msgvault schema (messages, message_recipients, message_labels). type Backend struct { - db *sql.DB + db *sql.DB + scope vector.BuildScope } // Open verifies the database is reachable, applies the embedding schema @@ -85,7 +89,10 @@ func Open(ctx context.Context, opts Options) (*Backend, error) { if opts.DB == nil { return nil, errors.New("pgvector.Open: Options.DB is required") } - b := &Backend{db: opts.DB} + b := &Backend{ + db: opts.DB, + scope: vector.NewBuildScope(opts.BuildScope.MessageTypes), + } if !opts.SkipMigrate { // serve / build / search: full migrate incl. CREATE EXTENSION (the // extension step is gated by SkipExtension for managed PG). The eager @@ -252,12 +259,22 @@ func isUniqueViolation(err error) bool { // (embed_gen IS NULL OR embed_gen <> gen). Built once and reused by // ActivateGeneration (in-tx, single-DB on PG) and Stats. The $N ordinal // of the generation id is supplied by the caller. -func missingForGenExistsClause(genArg string) string { +func (b *Backend) missingForGenExistsClause(genArg string, firstScopeArg int) (string, []any) { + where := store.LiveMessagesWhere("", true) + args := make([]any, 0, len(b.scope.MessageTypes)) + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "$" + strconv.Itoa(firstScopeArg+i) + args = append(args, typ) + } + where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ",")) + } return fmt.Sprintf(`EXISTS ( SELECT 1 FROM messages WHERE (embed_gen IS NULL OR embed_gen <> %s) AND %s - )`, genArg, store.LiveMessagesWhere("", true)) + )`, genArg, where), args } // ActivateGeneration atomically retires the current active generation @@ -321,18 +338,20 @@ func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationI // phase, which scan-and-fill no longer has, so a legacy/crashed gen with // seeded_at=NULL but full coverage must be activatable. Coverage // (missing==0) is the real gate. + missingClause, missingArgs := b.missingForGenExistsClause("$3", 5) + args := append([]any{now, now, int64(gen), force}, missingArgs...) res, err := tx.ExecContext(ctx, `UPDATE index_generations SET state = 'active', activated_at = $1, completed_at = COALESCE(completed_at, $2) WHERE id = $3 AND state = 'building' - AND ($4 OR NOT `+missingForGenExistsClause("$3")+`)`, - now, now, int64(gen), force) + AND ($4 OR NOT `+missingClause+`)`, + args...) if err != nil { return fmt.Errorf("activate: %w", err) } n, _ := res.RowsAffected() if n == 0 { - return activateGateError(ctx, tx, gen, force) + return b.activateGateError(ctx, tx, gen, force) } if err := tx.Commit(); err != nil { return fmt.Errorf("commit activate generation %d: %w", gen, err) @@ -346,7 +365,7 @@ func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationI // also satisfies the coverage predicate (embed_gen <> gen is true for an // unknown gen id), so checking coverage first would surface the misleading // "messages needing embedding" error instead of the real lifecycle reason. -func activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, force bool) error { +func (b *Backend) activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, force bool) error { var state vector.GenerationState if err := tx.QueryRowContext(ctx, `SELECT state FROM index_generations WHERE id = $1`, int64(gen)).Scan(&state); err != nil { @@ -361,8 +380,10 @@ func activateGateError(ctx context.Context, tx *sql.Tx, gen vector.GenerationID, // Gen exists and is building, so the only remaining reason the gated // promote affected zero rows is the coverage term. var missing bool + missingClause, missingArgs := b.missingForGenExistsClause("$1", 2) + args := append([]any{int64(gen)}, missingArgs...) if err := tx.QueryRowContext(ctx, - `SELECT `+missingForGenExistsClause("$1"), int64(gen)).Scan(&missing); err != nil { + `SELECT `+missingClause, args...).Scan(&missing); err != nil { return fmt.Errorf("check coverage for generation %d: %w", gen, err) } if missing && !force { diff --git a/internal/vector/pgvector/backend_filter_test.go b/internal/vector/pgvector/backend_filter_test.go index dfebe2056..b96d8d185 100644 --- a/internal/vector/pgvector/backend_filter_test.go +++ b/internal/vector/pgvector/backend_filter_test.go @@ -3,6 +3,7 @@ package pgvector import ( + "fmt" "testing" "time" @@ -11,8 +12,24 @@ import ( "go.kenn.io/msgvault/internal/vector" ) +func TestBuildPGFilterClausesMessageTypes(t *testing.T) { + var args []any + bind := func(v any) string { + args = append(args, v) + return fmt.Sprintf("$%d", len(args)) + } + + clauses := buildPGFilterClauses(vector.Filter{MessageTypes: []string{"sms", "mms"}}, bind) + + require.Len(t, clauses, 1) + assert.Equal(t, "m.message_type = ANY($1::text[])", clauses[0]) + assert.Equal(t, []any{`{"sms","mms"}`}, args) +} + func TestBackendSearchStructuredFilters(t *testing.T) { b, ctx, db := newBackendForTest(t) + _, err := db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`) + require.NoError(t, err, "add message_type") gen := seedAndEmbed(t, b, db, map[int64][]float32{ 1: unitVec(4, 0), 2: unitVec(4, 1), @@ -20,7 +37,7 @@ func TestBackendSearchStructuredFilters(t *testing.T) { }) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) - _, err := db.ExecContext(ctx, ` + _, err = db.ExecContext(ctx, ` UPDATE messages SET source_id = CASE id WHEN 1 THEN 10 WHEN 2 THEN 20 ELSE 30 END, message_type = CASE id WHEN 1 THEN 'email' WHEN 2 THEN 'sms' ELSE 'mms' END, @@ -103,6 +120,11 @@ func TestBackendSearchStructuredFilters(t *testing.T) { filter: vector.Filter{LargerThan: &largerThan, SmallerThan: &smallerThan}, want: []int64{2}, }, + { + name: "message type", + filter: vector.Filter{MessageTypes: []string{"sms"}}, + want: []int64{2}, + }, { name: "no match sentinel", filter: vector.Filter{SenderGroups: [][]int64{{-1}}}, @@ -129,6 +151,23 @@ func TestBackendSearchStructuredFilters(t *testing.T) { } } +func TestBackendSearchMessageTypeFilter(t *testing.T) { + b, ctx, db := newBackendForTest(t) + _, err := db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`) + require.NoError(t, err, "add message_type") + gen := seedAndEmbed(t, b, db, map[int64][]float32{ + 1: unitVec(4, 0), + 2: unitVec(4, 1), + 3: unitVec(4, 2), + }) + _, err = db.ExecContext(ctx, `UPDATE messages SET message_type = CASE id WHEN 1 THEN 'email' ELSE 'sms' END`) + require.NoError(t, err, "seed message_type") + + hits, err := b.Search(ctx, gen, unitVec(4, 0), 10, vector.Filter{MessageTypes: []string{"sms"}}) + require.NoError(t, err, "Search") + assert.Equal(t, []int64{2, 3}, hitMessageIDs(hits)) +} + func hitMessageIDs(hits []vector.Hit) []int64 { out := make([]int64, len(hits)) for i, h := range hits { diff --git a/internal/vector/pgvector/ext_stub.go b/internal/vector/pgvector/ext_stub.go index f82e31309..05c0980ca 100644 --- a/internal/vector/pgvector/ext_stub.go +++ b/internal/vector/pgvector/ext_stub.go @@ -27,6 +27,7 @@ type Options struct { SkipMigrate bool ReadOnly bool SkipExtension bool + BuildScope vector.BuildScope } // Backend is a placeholder type so non-pgvector builds can compile diff --git a/internal/vector/pgvector/fused_test.go b/internal/vector/pgvector/fused_test.go index 325228de3..261c4fe62 100644 --- a/internal/vector/pgvector/fused_test.go +++ b/internal/vector/pgvector/fused_test.go @@ -148,6 +148,36 @@ func TestFusedSearch_FTSOnly(t *testing.T) { } } +func TestFusedSearch_MessageTypeFilter(t *testing.T) { + f := newFusedFixture(t) + _, err := f.db.ExecContext(f.ctx, `ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`) + require.NoError(t, err, "add message_type") + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + f.seedMsg(t, 1, "lunch plan", "sms lunch details", 10, base, false) + f.seedMsg(t, 2, "lunch receipt", "email lunch details", 10, base.Add(time.Hour), false) + f.seedMsg(t, 3, "dinner plan", "sms dinner details", 10, base.Add(2*time.Hour), false) + _, err = f.db.ExecContext(f.ctx, `UPDATE messages SET message_type = CASE id WHEN 2 THEN 'email' ELSE 'sms' END`) + require.NoError(t, err, "seed message_type") + f.embedAll(t, map[int64][]float32{ + 1: unitVec(4, 0), + 2: unitVec(4, 1), + 3: unitVec(4, 2), + }) + + hits, saturated, err := f.b.FusedSearch(f.ctx, vector.FusedRequest{ + FTSTerms: []string{"lunch"}, + Generation: f.gen, + KPerSignal: 10, + Limit: 10, + RRFK: 60, + Filter: vector.Filter{MessageTypes: []string{"sms"}}, + }) + require.NoError(t, err, "FusedSearch") + assert.False(t, saturated, "saturated should be false") + require.Len(t, hits, 1, "message_type filter should exclude email FTS hits; hits=%+v", hits) + assert.Equal(t, int64(1), hits[0].MessageID) +} + func TestFusedSearch_ANNOnly(t *testing.T) { f := seedThree(t) hits, saturated, err := f.b.FusedSearch(f.ctx, vector.FusedRequest{ diff --git a/internal/vector/pgvector/parity_test.go b/internal/vector/pgvector/parity_test.go index 532f7372c..37aca933a 100644 --- a/internal/vector/pgvector/parity_test.go +++ b/internal/vector/pgvector/parity_test.go @@ -69,6 +69,7 @@ func buildSqlitevecParity(t *testing.T, corpus []parityDoc) (*sqlitevec.Backend, CREATE TABLE messages ( id INTEGER PRIMARY KEY, subject TEXT, + message_type TEXT NOT NULL DEFAULT 'email', source_id INTEGER, sender_id INTEGER, has_attachments INTEGER DEFAULT 0, diff --git a/internal/vector/sqlitevec/backend.go b/internal/vector/sqlitevec/backend.go index a38c19f05..6b00c09e4 100644 --- a/internal/vector/sqlitevec/backend.go +++ b/internal/vector/sqlitevec/backend.go @@ -28,10 +28,11 @@ var _ vector.Backend = (*Backend)(nil) // Options configures how Open establishes a Backend. type Options struct { - Path string - MainPath string // filesystem path to msgvault.db; required for FusedSearch - Dimension int // default dimension for EnsureVectorTable at open - MainDB *sql.DB // handle to the main msgvault.db + Path string + MainPath string // filesystem path to msgvault.db; required for FusedSearch + Dimension int // default dimension for EnsureVectorTable at open + MainDB *sql.DB // handle to the main msgvault.db + BuildScope vector.BuildScope // empty means full corpus // ReadOnly indicates the main DB handle (MainDB) was opened read-only // — e.g. the MCP server's store.OpenReadOnly (_query_only=true). When // set, Open SKIPS BackfillEmbedGenForUpgrade, which would otherwise @@ -50,6 +51,7 @@ type Backend struct { path string // filesystem path to vectors.db mainPath string // filesystem path to msgvault.db (for ATTACH) dim int + scope vector.BuildScope // readOnly is true when mainDB was opened read-only (MCP). The // one-time upgrade backfill self-guards on it so it never writes // through the read-only main handle. See Options.ReadOnly. @@ -76,6 +78,7 @@ func Open(ctx context.Context, opts Options) (*Backend, error) { path: opts.Path, mainPath: opts.MainPath, dim: opts.Dimension, + scope: vector.NewBuildScope(opts.BuildScope.MessageTypes), readOnly: opts.ReadOnly, } // Orphaned-stamp reset (vectors.db-recreate safety): clear embed_gen for @@ -268,24 +271,35 @@ func isUniqueConstraintErr(err error) bool { // needs embedding for gen (embed_gen IS NULL OR embed_gen <> gen). This is // the scan-and-fill coverage gate. On SQLite the messages live in the main // DB while the generation lifecycle lives in vectors.db, so the gate -// cannot be folded into the activation tx — ActivateGeneration runs this -// Go pre-check against b.mainDB before its vectors.db tx (mirroring the -// intentionally-non-atomic scheduler gate). The full-scan backstop covers -// any TOCTOU window between this read and the flip. +// cannot be folded into the activation tx. func (b *Backend) hasMissingForGen(ctx context.Context, gen vector.GenerationID) (bool, error) { var exists int + where, args := b.missingCoverageWhere(int64(gen)) err := b.mainDB.QueryRowContext(ctx, `SELECT EXISTS ( SELECT 1 FROM messages - WHERE (embed_gen IS NULL OR embed_gen <> ?) - AND `+store.LiveMessagesWhere("", true)+` - )`, int64(gen)).Scan(&exists) + WHERE `+where+` + )`, args...).Scan(&exists) if err != nil { return false, fmt.Errorf("check missing coverage for generation %d: %w", gen, err) } return exists == 1, nil } +func (b *Backend) missingCoverageWhere(gen int64) (string, []any) { + where := "(embed_gen IS NULL OR embed_gen <> ?) AND " + store.LiveMessagesWhere("", true) + args := []any{gen} + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "?" + args = append(args, typ) + } + where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ",")) + } + return where, args +} + // ActivateGeneration atomically retires the current active generation // (if any) and promotes `gen` to active. func (b *Backend) ActivateGeneration(ctx context.Context, gen vector.GenerationID, force bool) error { diff --git a/internal/vector/sqlitevec/backend_test.go b/internal/vector/sqlitevec/backend_test.go index 697fd6d14..7e8734133 100644 --- a/internal/vector/sqlitevec/backend_test.go +++ b/internal/vector/sqlitevec/backend_test.go @@ -292,6 +292,49 @@ func TestBackend_ActivateGeneration_NullSeededAtActivatesWithCoverage(t *testing assertpkg.Equal(t, vector.GenerationActive, genStateSV(t, b, gen), "now active") } +func TestBackend_ActivateCoverageScopesMessageTypes(t *testing.T) { + ctx := context.Background() + main, err := sql.Open("sqlite3", ":memory:") + requirepkg.NoError(t, err, "open main") + t.Cleanup(func() { _ = main.Close() }) + _, err = main.Exec(`CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + message_type TEXT NOT NULL, + embed_gen INTEGER, + deleted_at DATETIME, + deleted_from_source_at DATETIME + )`) + requirepkg.NoError(t, err, "create messages") + _, err = main.Exec(` + INSERT INTO messages (id, message_type, deleted_from_source_at) VALUES + (1, 'email', NULL), + (2, 'sms', NULL), + (3, 'mms', NULL), + (4, 'sms', CURRENT_TIMESTAMP)`) + requirepkg.NoError(t, err, "insert messages") + + b, err := Open(ctx, Options{ + Path: filepath.Join(t.TempDir(), "vectors.db"), + Dimension: 768, + MainDB: main, + BuildScope: vector.NewBuildScope([]string{"sms", "mms"}), + }) + requirepkg.NoError(t, err, "Open") + t.Cleanup(func() { _ = b.Close() }) + + gid, err := b.CreateGeneration(ctx, "m", 768, "") + requirepkg.NoError(t, err, "Create") + missing, err := b.hasMissingForGen(ctx, gid) + requirepkg.NoError(t, err, "hasMissingForGen before stamp") + assertpkg.True(t, missing, "in-scope messages need embedding") + + _, err = main.Exec(`UPDATE messages SET embed_gen = ? WHERE id IN (2, 3)`, int64(gid)) + requirepkg.NoError(t, err, "stamp in-scope messages") + missing, err = b.hasMissingForGen(ctx, gid) + requirepkg.NoError(t, err, "hasMissingForGen after stamp") + assertpkg.False(t, missing, "out-of-scope and deleted messages must not block scoped activation") +} + // TestBackend_CreateGeneration_ResumesBuilding confirms that calling // CreateGeneration while a building row already exists with the same // fingerprint returns the existing id instead of failing on the unique diff --git a/internal/vector/sqlitevec/ext_stub.go b/internal/vector/sqlitevec/ext_stub.go index 081a8398b..55b50a811 100644 --- a/internal/vector/sqlitevec/ext_stub.go +++ b/internal/vector/sqlitevec/ext_stub.go @@ -32,11 +32,12 @@ func Available() bool { return false } // can reference sqlitevec.Options without a compile error; the struct is // never populated at runtime when the PG code path is taken. type Options struct { - Path string - MainPath string - Dimension int - MainDB *sql.DB - ReadOnly bool + Path string + MainPath string + Dimension int + MainDB *sql.DB + BuildScope vector.BuildScope + ReadOnly bool } // Backend is the stub backend type for builds without sqlite_vec. diff --git a/internal/vector/sqlitevec/fused_test.go b/internal/vector/sqlitevec/fused_test.go index 19a34f940..03d0f47a2 100644 --- a/internal/vector/sqlitevec/fused_test.go +++ b/internal/vector/sqlitevec/fused_test.go @@ -287,11 +287,11 @@ func openFusedMainWithSchema(t *testing.T, path string) *sql.DB { t.Cleanup(func() { _ = db.Close() }) // sent_at is DATETIME (text) to match the production schema. schema := ` -CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - subject TEXT, - message_type TEXT NOT NULL DEFAULT 'email', - source_id INTEGER, + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + subject TEXT, + message_type TEXT NOT NULL DEFAULT 'email', + source_id INTEGER, sender_id INTEGER, has_attachments INTEGER DEFAULT 0, size_estimate INTEGER, @@ -322,10 +322,10 @@ func openFusedMainWithoutMessageType(t *testing.T, path string) *sql.DB { requirepkg.NoError(t, err, "open main") t.Cleanup(func() { _ = db.Close() }) schema := ` -CREATE TABLE messages ( - id INTEGER PRIMARY KEY, - subject TEXT, - source_id INTEGER, + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + subject TEXT, + source_id INTEGER, sender_id INTEGER, has_attachments INTEGER DEFAULT 0, size_estimate INTEGER, From dda21d47c70901455880ea0d83117e283da85e71 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 25 Jun 2026 16:57:06 -0500 Subject: [PATCH 2/4] fix(vector): avoid duplicate pg message type DDL The shared pgvector test schema already includes messages.message_type. Re-adding that column in individual tests makes the setup non-idempotent on PostgreSQL and can fail before the filters are exercised. Remove the redundant ALTER TABLE statements so the tests rely on the common fixture schema and only seed message_type values needed by each case. Generated with Codex Co-authored-by: Codex --- internal/vector/pgvector/backend_filter_test.go | 8 ++------ internal/vector/pgvector/fused_test.go | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/internal/vector/pgvector/backend_filter_test.go b/internal/vector/pgvector/backend_filter_test.go index b96d8d185..2dda10245 100644 --- a/internal/vector/pgvector/backend_filter_test.go +++ b/internal/vector/pgvector/backend_filter_test.go @@ -28,8 +28,6 @@ func TestBuildPGFilterClausesMessageTypes(t *testing.T) { func TestBackendSearchStructuredFilters(t *testing.T) { b, ctx, db := newBackendForTest(t) - _, err := db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`) - require.NoError(t, err, "add message_type") gen := seedAndEmbed(t, b, db, map[int64][]float32{ 1: unitVec(4, 0), 2: unitVec(4, 1), @@ -37,7 +35,7 @@ func TestBackendSearchStructuredFilters(t *testing.T) { }) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) - _, err = db.ExecContext(ctx, ` + _, err := db.ExecContext(ctx, ` UPDATE messages SET source_id = CASE id WHEN 1 THEN 10 WHEN 2 THEN 20 ELSE 30 END, message_type = CASE id WHEN 1 THEN 'email' WHEN 2 THEN 'sms' ELSE 'mms' END, @@ -153,14 +151,12 @@ func TestBackendSearchStructuredFilters(t *testing.T) { func TestBackendSearchMessageTypeFilter(t *testing.T) { b, ctx, db := newBackendForTest(t) - _, err := db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`) - require.NoError(t, err, "add message_type") gen := seedAndEmbed(t, b, db, map[int64][]float32{ 1: unitVec(4, 0), 2: unitVec(4, 1), 3: unitVec(4, 2), }) - _, err = db.ExecContext(ctx, `UPDATE messages SET message_type = CASE id WHEN 1 THEN 'email' ELSE 'sms' END`) + _, err := db.ExecContext(ctx, `UPDATE messages SET message_type = CASE id WHEN 1 THEN 'email' ELSE 'sms' END`) require.NoError(t, err, "seed message_type") hits, err := b.Search(ctx, gen, unitVec(4, 0), 10, vector.Filter{MessageTypes: []string{"sms"}}) diff --git a/internal/vector/pgvector/fused_test.go b/internal/vector/pgvector/fused_test.go index 261c4fe62..b4bcb97e9 100644 --- a/internal/vector/pgvector/fused_test.go +++ b/internal/vector/pgvector/fused_test.go @@ -150,13 +150,11 @@ func TestFusedSearch_FTSOnly(t *testing.T) { func TestFusedSearch_MessageTypeFilter(t *testing.T) { f := newFusedFixture(t) - _, err := f.db.ExecContext(f.ctx, `ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`) - require.NoError(t, err, "add message_type") base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) f.seedMsg(t, 1, "lunch plan", "sms lunch details", 10, base, false) f.seedMsg(t, 2, "lunch receipt", "email lunch details", 10, base.Add(time.Hour), false) f.seedMsg(t, 3, "dinner plan", "sms dinner details", 10, base.Add(2*time.Hour), false) - _, err = f.db.ExecContext(f.ctx, `UPDATE messages SET message_type = CASE id WHEN 2 THEN 'email' ELSE 'sms' END`) + _, err := f.db.ExecContext(f.ctx, `UPDATE messages SET message_type = CASE id WHEN 2 THEN 'email' ELSE 'sms' END`) require.NoError(t, err, "seed message_type") f.embedAll(t, map[int64][]float32{ 1: unitVec(4, 0), From 646a65afe547760caed03816c0c65bc8534dba7c Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 25 Jun 2026 18:50:12 -0500 Subject: [PATCH 3/4] fix(vector): scope embeddings management coverage Scoped embedding builds only stamp messages that match the configured build scope. Management commands were still evaluating activation coverage and backend coverage gates against the full live corpus, so a valid scoped generation could look incomplete whenever out-of-scope messages were unstamped. Thread the configured build scope through management backend opens and coverage reads so list/activate/retire evaluate the same message universe as the build worker. The regression seeds an out-of-scope missing email beside an in-scope stamped SMS to lock the activation preflight to scoped coverage. Generated with Codex Co-authored-by: Codex --- cmd/msgvault/cmd/embed_test.go | 32 +++++++++++++++++++++++++++ cmd/msgvault/cmd/embeddings_manage.go | 16 +++++++++----- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/cmd/msgvault/cmd/embed_test.go b/cmd/msgvault/cmd/embed_test.go index e9af2674c..cae34feb0 100644 --- a/cmd/msgvault/cmd/embed_test.go +++ b/cmd/msgvault/cmd/embed_test.go @@ -139,6 +139,22 @@ func TestRunEmbeddingsActivateRefusesMissingWithoutForce(t *testing.T) { assert.Contains(err.Error(), "msgvault embeddings resume --backstop") } +func TestFillCoverageUsesEmbeddingScope(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + dataDir := t.TempDir() + dbPath := newEmbeddingMetadataTestDBFileAt(t, filepath.Join(dataDir, "vectors.db")) + seedMainDBWithScopedCoverageMessages(t, dataDir) + withEmbeddingCommandConfigDataDir(t, dbPath, dataDir) + cfg.Vector.Embed.Scope.MessageTypes = []string{"sms"} + + row := embeddingGenerationRow{ID: 2} + require.NoError(fillCoverage(t.Context(), &row)) + + assert.Equal(int64(1), row.LiveCount) + assert.Equal(int64(0), row.MissingCount) +} + // TestRetireEmbeddingGenerationRefusesActiveWithoutForce_PreCheck pins the // CLI UX gate that runs against the committed metadata read BEFORE any // backend connection: retiring an active generation without --force-active @@ -240,6 +256,22 @@ INSERT INTO messages (id, conversation_id, source_id, source_message_id, message requirepkg.NoError(t, err) } +func seedMainDBWithScopedCoverageMessages(t *testing.T, dataDir string) { + t.Helper() + s, err := store.Open(filepath.Join(dataDir, "msgvault.db")) + requirepkg.NoError(t, err) + defer func() { requirepkg.NoError(t, s.Close()) }() + requirepkg.NoError(t, s.InitSchema()) + _, err = s.DB().Exec(` +INSERT INTO sources (id, source_type, identifier) VALUES (1, 'gmail', 'me@example.com'); +INSERT INTO conversations (id, source_id, conversation_type) VALUES (1, 1, 'email_thread'), (2, 1, 'sms_thread'); +INSERT INTO messages (id, conversation_id, source_id, source_message_id, message_type, embed_gen) VALUES + (1, 1, 1, 'email-missing', 'email', NULL), + (2, 2, 1, 'sms-stamped', 'sms', 2); +`) + requirepkg.NoError(t, err) +} + func withEmbeddingCommandConfig(t *testing.T, vecPath string) { t.Helper() oldCfg := cfg diff --git a/cmd/msgvault/cmd/embeddings_manage.go b/cmd/msgvault/cmd/embeddings_manage.go index b3e0f04fa..81525b603 100644 --- a/cmd/msgvault/cmd/embeddings_manage.go +++ b/cmd/msgvault/cmd/embeddings_manage.go @@ -64,7 +64,8 @@ func fillCoverage(ctx context.Context, row *embeddingGenerationRow) error { return fmt.Errorf("open main db for coverage: %w", err) } defer func() { _ = s.Close() }() - live, _, _, missing, err := s.CoverageCounts(ctx, int64(row.ID)) + scope := cfg.Vector.Embed.Scope.BuildScope() + live, _, _, missing, err := s.CoverageCountsScoped(ctx, int64(row.ID), scope.MessageTypes) if err != nil { return err } @@ -87,7 +88,8 @@ func fillFullCoverage(ctx context.Context, backend vector.Backend, row *embeddin return fmt.Errorf("open main db for coverage: %w", err) } defer func() { _ = s.Close() }() - live, stamped, _, missing, err := s.CoverageCounts(ctx, int64(row.ID)) + scope := cfg.Vector.Embed.Scope.BuildScope() + live, stamped, _, missing, err := s.CoverageCountsScoped(ctx, int64(row.ID), scope.MessageTypes) if err != nil { return err } @@ -426,6 +428,7 @@ func openEmbeddingsBackend(ctx context.Context) (vector.Backend, func(), error) DB: db, Dimension: cfg.Vector.Embeddings.Dimension, SkipMigrate: true, + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { _ = db.Close() @@ -460,10 +463,11 @@ func openEmbeddingsBackend(ctx context.Context) (vector.Backend, func(), error) return nil, nil, fmt.Errorf("open main db for embeddings backend: %w", err) } b, err := sqlitevec.Open(ctx, sqlitevec.Options{ - Path: vecPath, - MainPath: dsn, - Dimension: cfg.Vector.Embeddings.Dimension, - MainDB: mainStore.DB(), + Path: vecPath, + MainPath: dsn, + Dimension: cfg.Vector.Embeddings.Dimension, + MainDB: mainStore.DB(), + BuildScope: cfg.Vector.Embed.Scope.BuildScope(), }) if err != nil { _ = mainStore.Close() From 471559706f8a666af5cf8a09f463a804efe3293b Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 25 Jun 2026 19:01:58 -0500 Subject: [PATCH 4/4] fix(vector): scope embedded coverage counts Scoped embedding coverage must keep all four displayed legs in the same message universe. After management switched live/stamped/missing to the configured build scope, the backend embedded count could still include out-of-scope stamped vectors, producing impossible list output such as embedded > live. Apply the backend build scope inside EmbeddedMessageCount for sqlitevec and pgvector so the embedded leg matches scoped coverage reads. Add a command-level fillFullCoverage regression with an out-of-scope stamped vector, plus backend invariant tests for both storage engines. Generated with Codex Co-authored-by: Codex --- .../cmd/embed_manage_sqlitevec_test.go | 28 +++++++++ cmd/msgvault/cmd/embed_test.go | 16 +++++ internal/vector/backend.go | 16 ++--- internal/vector/pgvector/backend.go | 24 ++++--- internal/vector/pgvector/coverage_test.go | 63 +++++++++++++++++++ internal/vector/sqlitevec/backend.go | 27 +++++--- internal/vector/sqlitevec/coverage_test.go | 60 ++++++++++++++++++ 7 files changed, 210 insertions(+), 24 deletions(-) diff --git a/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go b/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go index 3663648e5..8a815438b 100644 --- a/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go +++ b/cmd/msgvault/cmd/embed_manage_sqlitevec_test.go @@ -5,6 +5,7 @@ package cmd import ( "context" "database/sql" + "path/filepath" "testing" _ "github.com/mattn/go-sqlite3" @@ -48,3 +49,30 @@ func TestRunEmbeddingsRetire_ForceActive(t *testing.T) { row := mustGetEmbeddingGeneration(t.Context(), t, db, 1) assert.Equal(vector.GenerationRetired, row.State) } + +func TestFillFullCoverageUsesEmbeddingScopeForEmbeddedCount(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + ctx := context.Background() + dataDir := t.TempDir() + dbPath := newEmbeddingMetadataTestDBFileAt(t, filepath.Join(dataDir, "vectors.db")) + seedMainDBWithScopedFullCoverageMessages(t, dataDir) + withEmbeddingCommandConfigDataDir(t, dbPath, dataDir) + cfg.Vector.Embed.Scope.MessageTypes = []string{"sms"} + + backend, closeBackend, err := openEmbeddingsBackend(ctx) + require.NoError(err, "open embeddings backend") + t.Cleanup(closeBackend) + require.NoError(backend.Upsert(ctx, 2, []vector.Chunk{ + {MessageID: 1, Vector: []float32{1, 0, 0, 0}}, + {MessageID: 2, Vector: []float32{0, 1, 0, 0}}, + }), "upsert in-scope and out-of-scope vectors") + + row := embeddingGenerationRow{ID: 2} + require.NoError(fillFullCoverage(ctx, backend, &row)) + + assert.Equal(int64(1), row.LiveCount, "only sms is in scope") + assert.Equal(int64(1), row.EmbeddedCount, "out-of-scope email vector is excluded") + assert.Equal(int64(0), row.BlankCount) + assert.Equal(int64(0), row.MissingCount) +} diff --git a/cmd/msgvault/cmd/embed_test.go b/cmd/msgvault/cmd/embed_test.go index cae34feb0..f2eb3512e 100644 --- a/cmd/msgvault/cmd/embed_test.go +++ b/cmd/msgvault/cmd/embed_test.go @@ -272,6 +272,22 @@ INSERT INTO messages (id, conversation_id, source_id, source_message_id, message requirepkg.NoError(t, err) } +func seedMainDBWithScopedFullCoverageMessages(t *testing.T, dataDir string) { + t.Helper() + s, err := store.Open(filepath.Join(dataDir, "msgvault.db")) + requirepkg.NoError(t, err) + defer func() { requirepkg.NoError(t, s.Close()) }() + requirepkg.NoError(t, s.InitSchema()) + _, err = s.DB().Exec(` +INSERT INTO sources (id, source_type, identifier) VALUES (1, 'gmail', 'me@example.com'); +INSERT INTO conversations (id, source_id, conversation_type) VALUES (1, 1, 'email_thread'), (2, 1, 'sms_thread'); +INSERT INTO messages (id, conversation_id, source_id, source_message_id, message_type, embed_gen) VALUES + (1, 1, 1, 'email-stamped', 'email', 2), + (2, 2, 1, 'sms-stamped', 'sms', 2); +`) + requirepkg.NoError(t, err) +} + func withEmbeddingCommandConfig(t *testing.T, vecPath string) { t.Helper() oldCfg := cfg diff --git a/internal/vector/backend.go b/internal/vector/backend.go index 97ec05c88..2d8a2cb4d 100644 --- a/internal/vector/backend.go +++ b/internal/vector/backend.go @@ -181,18 +181,18 @@ type Backend interface { Delete(ctx context.Context, gen GenerationID, messageIDs []int64) error Stats(ctx context.Context, gen GenerationID) (Stats, error) - // EmbeddedMessageCount reports how many distinct LIVE, stamped + // EmbeddedMessageCount reports how many distinct in-scope LIVE, stamped // (embed_gen == gen) messages actually have at least one embedding row // for gen. This is the "embedded" leg of the coverage readout (live / // embedded / blank / missing). It lives on the backend because the // embeddings table is in vectors.db on SQLite (and the main DB on PG); - // only the backend holds that handle. The live+stamped intersection is - // REQUIRED for the coverage invariant to hold: SQLite intersects the - // vectors.db embedding ids against a live+stamped query on the main DB - // (cross-DB json_each, mirroring dropDeletedFromSource), while PostgreSQL - // uses a single JOIN to messages. Distinct from Stats.EmbeddingCount only - // in intent: this is the dedicated coverage helper and never folds the - // aggregate (gen == 0) path. + // only the backend holds that handle. The live+stamped+scope intersection + // is REQUIRED for the coverage invariant to hold: SQLite intersects the + // vectors.db embedding ids against an in-scope live+stamped query on the + // main DB (cross-DB json_each, mirroring dropDeletedFromSource), while + // PostgreSQL uses a single JOIN to messages. Distinct from + // Stats.EmbeddingCount only in intent: this is the dedicated coverage + // helper and never folds the aggregate (gen == 0) path. EmbeddedMessageCount(ctx context.Context, gen GenerationID) (int64, error) // LoadVector returns the embedding for a specific message in the diff --git a/internal/vector/pgvector/backend.go b/internal/vector/pgvector/backend.go index ac6782386..be9a9a237 100644 --- a/internal/vector/pgvector/backend.go +++ b/internal/vector/pgvector/backend.go @@ -1227,9 +1227,9 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St return s, nil } -// EmbeddedMessageCount returns the number of LIVE messages that are -// stamped for gen (embed_gen = gen) AND actually have at least one vector -// for the generation. Used by the coverage readout to split stamped +// EmbeddedMessageCount returns the number of in-scope LIVE messages that +// are stamped for gen (embed_gen = gen) AND actually have at least one +// vector for the generation. Used by the coverage readout to split stamped // messages into embedded vs blank. Counts distinct messages (not chunk // rows) so a long, multi-chunk message counts once, matching the // EmbeddingCount semantic elsewhere. @@ -1247,15 +1247,25 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St // live intersection is a single JOIN against messages, mirroring // store.LiveMessagesWhere's predicate. func (b *Backend) EmbeddedMessageCount(ctx context.Context, gen vector.GenerationID) (int64, error) { + where := `e.generation_id = $1 + AND m.embed_gen = $1 + AND ` + store.LiveMessagesWhere("m", true) + args := []any{int64(gen)} + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "$" + strconv.Itoa(2+i) + args = append(args, typ) + } + where += fmt.Sprintf(" AND m.message_type IN (%s)", strings.Join(placeholders, ",")) + } var n int64 if err := b.db.QueryRowContext(ctx, `SELECT COUNT(DISTINCT e.message_id) FROM embeddings e JOIN messages m ON m.id = e.message_id - WHERE e.generation_id = $1 - AND m.embed_gen = $1 - AND `+store.LiveMessagesWhere("m", true), - int64(gen)).Scan(&n); err != nil { + WHERE `+where, + args...).Scan(&n); err != nil { return 0, fmt.Errorf("count embedded messages: %w", err) } return n, nil diff --git a/internal/vector/pgvector/coverage_test.go b/internal/vector/pgvector/coverage_test.go index 2293b8d9d..ff177ffd5 100644 --- a/internal/vector/pgvector/coverage_test.go +++ b/internal/vector/pgvector/coverage_test.go @@ -90,3 +90,66 @@ func TestCoverageSplit_EmbeddedBlankMissing(t *testing.T) { assert.Equal(live, embeddedCount+blank+missing, "invariant: live == embedded + blank + missing") } + +func TestCoverageSplit_ScopedEmbeddedHoldsInvariant(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + ctx := context.Background() + + db := openPGTestDB(t) // skips when MSGVAULT_TEST_DB is unset + b, err := Open(ctx, Options{ + DB: db, + Dimension: 8, + BuildScope: vector.NewBuildScope([]string{"sms"}), + }) + require.NoError(err, "Open backend") + t.Cleanup(func() { _ = b.Close() }) + + _, err = db.ExecContext(ctx, ` + INSERT INTO messages (id, message_type) VALUES + (1, 'email'), + (2, 'sms') + ON CONFLICT (id) DO UPDATE SET + message_type = EXCLUDED.message_type, + deleted_at = NULL, + deleted_from_source_at = NULL, + embed_gen = NULL`) + require.NoError(err, "seed scoped messages") + + gen, err := b.CreateGeneration(ctx, "test-model", 8, "fp") + require.NoError(err, "CreateGeneration") + require.NoError(b.Upsert(ctx, gen, []vector.Chunk{ + {MessageID: 1, Vector: []float32{1, 0, 0, 0, 0, 0, 0, 0}}, + {MessageID: 2, Vector: []float32{0, 1, 0, 0, 0, 0, 0, 0}}, + }), "Upsert embedded vectors") + _, err = db.ExecContext(ctx, `UPDATE messages SET embed_gen = $1 WHERE id IN (1, 2)`, int64(gen)) + require.NoError(err, "stamp embedded") + + var live, stamped int64 + require.NoError(db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM messages + WHERE message_type = 'sms' + AND deleted_at IS NULL + AND deleted_from_source_at IS NULL`).Scan(&live), + "count scoped live") + require.NoError(db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM messages + WHERE message_type = 'sms' + AND embed_gen = $1 + AND deleted_at IS NULL + AND deleted_from_source_at IS NULL`, + int64(gen)).Scan(&stamped), "count scoped stamped") + missing := live - stamped + + embeddedCount, err := b.EmbeddedMessageCount(ctx, gen) + require.NoError(err, "EmbeddedMessageCount") + blank := max(stamped-embeddedCount, 0) + + assert.Equal(int64(1), live, "only sms is in scope") + assert.Equal(int64(1), stamped, "only scoped stamped messages count") + assert.Equal(int64(1), embeddedCount, "out-of-scope email vector excluded") + assert.Equal(int64(0), blank) + assert.Equal(int64(0), missing) + assert.Equal(live, embeddedCount+blank+missing, + "invariant: live == embedded + blank + missing") +} diff --git a/internal/vector/sqlitevec/backend.go b/internal/vector/sqlitevec/backend.go index 6b00c09e4..72ee9142f 100644 --- a/internal/vector/sqlitevec/backend.go +++ b/internal/vector/sqlitevec/backend.go @@ -1389,9 +1389,9 @@ func (b *Backend) Stats(ctx context.Context, gen vector.GenerationID) (vector.St return s, nil } -// EmbeddedMessageCount returns the number of LIVE messages that are -// stamped for gen (embed_gen = gen) AND actually have at least one vector -// for the generation. Used by the coverage readout to split stamped +// EmbeddedMessageCount returns the number of in-scope LIVE messages that +// are stamped for gen (embed_gen = gen) AND actually have at least one +// vector for the generation. Used by the coverage readout to split stamped // messages into embedded vs blank. Counts distinct messages (not chunk // rows) so a long, multi-chunk message counts once, matching the // EmbeddingCount semantic elsewhere. @@ -1448,18 +1448,27 @@ func (b *Backend) EmbeddedMessageCount(ctx context.Context, gen vector.Generatio return 0, nil } - // Step 2 (main.db): how many of those are live AND stamped for gen. + // Step 2 (main.db): how many of those are in-scope, live, AND stamped for gen. blob, err := json.Marshal(ids) if err != nil { return 0, fmt.Errorf("encode embedded ids: %w", err) } + where := `id IN (SELECT value FROM json_each(?)) + AND embed_gen = ? + AND ` + store.LiveMessagesWhere("", true) + args := []any{string(blob), int64(gen)} + if !b.scope.IsEmpty() { + placeholders := make([]string, len(b.scope.MessageTypes)) + for i, typ := range b.scope.MessageTypes { + placeholders[i] = "?" + args = append(args, typ) + } + where += fmt.Sprintf(" AND message_type IN (%s)", strings.Join(placeholders, ",")) + } var n int64 if err := b.mainDB.QueryRowContext(ctx, - `SELECT COUNT(*) FROM messages - WHERE id IN (SELECT value FROM json_each(?)) - AND embed_gen = ? - AND `+store.LiveMessagesWhere("", true), - string(blob), int64(gen)).Scan(&n); err != nil { + `SELECT COUNT(*) FROM messages WHERE `+where, + args...).Scan(&n); err != nil { return 0, fmt.Errorf("count live embedded messages: %w", err) } return n, nil diff --git a/internal/vector/sqlitevec/coverage_test.go b/internal/vector/sqlitevec/coverage_test.go index 2ffcd0ac6..55c6c7861 100644 --- a/internal/vector/sqlitevec/coverage_test.go +++ b/internal/vector/sqlitevec/coverage_test.go @@ -218,3 +218,63 @@ func TestCoverageSplit_NonLiveEmbeddedHoldsInvariant(t *testing.T) { assert.Equal(live, embedded+blank+missingCount, "invariant: live == embedded + blank + missing") } + +func TestCoverageSplit_ScopedEmbeddedHoldsInvariant(t *testing.T) { + require := requirepkg.New(t) + assert := assertpkg.New(t) + ctx := context.Background() + + st := testutil.NewSQLiteTestStore(t) + b, err := Open(ctx, Options{ + Path: filepath.Join(t.TempDir(), "vectors.db"), + Dimension: 8, + MainDB: st.DB(), + BuildScope: vector.NewBuildScope([]string{"sms"}), + }) + require.NoError(err, "Open backend") + t.Cleanup(func() { _ = b.Close() }) + + source, err := st.GetOrCreateSource("gmail", "me@example.com") + require.NoError(err, "GetOrCreateSource") + emailConvID, err := st.EnsureConversationWithType(source.ID, "conv-email", "email_thread", "Email") + require.NoError(err, "EnsureConversationWithType email") + smsConvID, err := st.EnsureConversationWithType(source.ID, "conv-sms", "sms_thread", "SMS") + require.NoError(err, "EnsureConversationWithType sms") + + makeMsg := func(srcMsgID, typ string, convID int64) int64 { + m := &store.Message{ + SourceID: source.ID, + SourceMessageID: srcMsgID, + ConversationID: convID, + MessageType: typ, + Subject: sql.NullString{String: "s-" + srcMsgID, Valid: true}, + } + id, err := st.UpsertMessage(m) + require.NoErrorf(err, "UpsertMessage %s", srcMsgID) + return id + } + outOfScopeEmail := makeMsg("email-stamped", "email", emailConvID) + inScopeSMS := makeMsg("sms-stamped", "sms", smsConvID) + + gen, err := b.CreateGeneration(ctx, "test-model", 8, "fp") + require.NoError(err, "CreateGeneration") + require.NoError(b.Upsert(ctx, gen, []vector.Chunk{ + {MessageID: outOfScopeEmail, Vector: []float32{1, 0, 0, 0, 0, 0, 0, 0}}, + {MessageID: inScopeSMS, Vector: []float32{0, 1, 0, 0, 0, 0, 0, 0}}, + }), "Upsert embedded vectors") + require.NoError(st.SetEmbedGen(ctx, []int64{outOfScopeEmail, inScopeSMS}, int64(gen)), "stamp embedded") + + live, stamped, _, missingCount, err := st.CoverageCountsScoped(ctx, int64(gen), []string{"sms"}) + require.NoError(err, "CoverageCountsScoped") + embedded, err := b.EmbeddedMessageCount(ctx, gen) + require.NoError(err, "EmbeddedMessageCount") + blank := max(stamped-embedded, 0) + + assert.Equal(int64(1), live, "only sms is in scope") + assert.Equal(int64(1), stamped, "only scoped stamped messages count") + assert.Equal(int64(1), embedded, "out-of-scope email vector excluded") + assert.Equal(int64(0), blank) + assert.Equal(int64(0), missingCount) + assert.Equal(live, embedded+blank+missingCount, + "invariant: live == embedded + blank + missing") +}