From e07897bdb5e9a34fe2b67a39aed51b8b2f05f233 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Tue, 23 Jun 2026 23:22:42 -0400 Subject: [PATCH 1/4] Extract vector encoding queue Vector embedding queues were only available inside msgvault, which made other Kenn tools copy or reimplement the same crash-safe claim/release mechanics. Moving the SQL-backed task queue into kit gives callers a shared, app-neutral helper while preserving msgvault's existing table shape. The package keeps schema ownership with callers and validates SQL identifiers before interpolating table or column names, so future consumers can adapt the queue to their own storage without importing msgvault internals. Validation: go test ./...; go vet ./... Generated with Codex Co-authored-by: Codex --- go.mod | 1 + go.sum | 2 + vector/encodingqueue/doc.go | 7 + vector/encodingqueue/queue.go | 297 +++++++++++++++++++++++++++++ vector/encodingqueue/queue_test.go | 211 ++++++++++++++++++++ 5 files changed, 518 insertions(+) create mode 100644 vector/encodingqueue/doc.go create mode 100644 vector/encodingqueue/queue.go create mode 100644 vector/encodingqueue/queue_test.go diff --git a/go.mod b/go.mod index e9497ee..927d7ab 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/klauspost/compress v1.18.6 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.44 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect diff --git a/go.sum b/go.sum index f82d09d..a62d568 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posthog/posthog-go v1.12.6 h1:N+FrKWY6DOuDhV2OMgvtKAKDYGTdtS9/nuvr0BTyBp0= diff --git a/vector/encodingqueue/doc.go b/vector/encodingqueue/doc.go new file mode 100644 index 0000000..a86385f --- /dev/null +++ b/vector/encodingqueue/doc.go @@ -0,0 +1,7 @@ +// Package encodingqueue provides SQLite-compatible queue helpers for +// vector encoding tasks. +// +// The package expects one pending-task table with a compound primary key +// over generation and task IDs, an enqueue timestamp, and nullable claim +// fields. Callers own schema creation and database lifecycle. +package encodingqueue diff --git a/vector/encodingqueue/queue.go b/vector/encodingqueue/queue.go new file mode 100644 index 0000000..d9056bc --- /dev/null +++ b/vector/encodingqueue/queue.go @@ -0,0 +1,297 @@ +package encodingqueue + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "regexp" + "sort" + "time" +) + +var identifierPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// Schema names the pending-task table and columns used by Queue. +// +// Identifiers are interpolated into SQL after validation, so every field +// must be a simple SQL identifier rather than arbitrary SQL text. +type Schema struct { + Table string + GroupIDColumn string + TaskIDColumn string + EnqueuedAtColumn string + ClaimedAtColumn string + ClaimTokenColumn string +} + +// DefaultSchema returns the table shape used by msgvault's vector +// embedding queue. +func DefaultSchema() Schema { + return Schema{ + Table: "pending_embeddings", + GroupIDColumn: "generation_id", + TaskIDColumn: "message_id", + EnqueuedAtColumn: "enqueued_at", + ClaimedAtColumn: "claimed_at", + ClaimTokenColumn: "claim_token", + } +} + +// Queue wraps a SQL pending-task table with crash-safe claim, complete, +// release, and stale-claim reclamation operations. +type Queue struct { + db *sql.DB + schema Schema +} + +type execContext interface { + ExecContext(context.Context, string, ...any) (sql.Result, error) +} + +// NewQueue returns a Queue bound to db. The caller retains ownership of +// db; Queue does not close it. +func NewQueue(db *sql.DB, schema Schema) (*Queue, error) { + if err := schema.validate(); err != nil { + return nil, err + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + return &Queue{db: db, schema: schema}, nil +} + +// Enqueue inserts every task ID once for every group ID. Duplicate +// (group, task) pairs are ignored by the pending table's uniqueness +// constraint. +func (q *Queue) Enqueue(ctx context.Context, groupIDs []int64, taskIDs []int64) error { + if len(groupIDs) == 0 || len(taskIDs) == 0 { + return nil + } + tx, err := q.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin enqueue tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if err := q.EnqueueTx(ctx, tx, groupIDs, taskIDs); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit enqueue: %w", err) + } + return nil +} + +// EnqueueTx inserts every task ID once for every group ID using tx. The +// caller is responsible for committing or rolling back tx. +func (q *Queue) EnqueueTx(ctx context.Context, tx *sql.Tx, groupIDs []int64, taskIDs []int64) error { + if tx == nil { + return fmt.Errorf("tx is nil") + } + return q.enqueueWith(ctx, tx, groupIDs, taskIDs) +} + +// Claim marks up to batch available rows for groupID as claimed by a +// fresh token, returning task IDs in ascending order with the token to +// pass to Complete or Release. +func (q *Queue) Claim(ctx context.Context, groupID int64, batch int) ([]int64, string, error) { + if batch <= 0 { + return nil, "", nil + } + token, err := newToken() + if err != nil { + return nil, "", fmt.Errorf("new token: %w", err) + } + now := time.Now().Unix() + + tx, err := q.db.BeginTx(ctx, nil) + if err != nil { + return nil, "", fmt.Errorf("begin claim tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + rows, err := tx.QueryContext(ctx, q.claimSQL(), now, token, groupID, batch) + if err != nil { + return nil, "", fmt.Errorf("claim query: %w", err) + } + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + _ = rows.Close() + return nil, "", fmt.Errorf("scan claimed task id: %w", err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, "", fmt.Errorf("claim rows: %w", err) + } + if err := rows.Close(); err != nil { + return nil, "", fmt.Errorf("close claim rows: %w", err) + } + if err := tx.Commit(); err != nil { + return nil, "", fmt.Errorf("commit claim: %w", err) + } + if len(ids) == 0 { + return nil, "", nil + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids, token, nil +} + +// Complete deletes claimed rows whose claim token still matches token. +func (q *Queue) Complete(ctx context.Context, groupID int64, token string, taskIDs []int64) error { + if len(taskIDs) == 0 { + return nil + } + return q.updateClaimedIDs(ctx, "delete pending tasks", q.completeSQL(), groupID, token, taskIDs) +} + +// Release clears matching claims so tasks can be retried by another +// worker. +func (q *Queue) Release(ctx context.Context, groupID int64, token string, taskIDs []int64) error { + if len(taskIDs) == 0 { + return nil + } + return q.updateClaimedIDs(ctx, "release pending tasks", q.releaseSQL(), groupID, token, taskIDs) +} + +// ReclaimStale clears claims older than olderThan and returns the number +// of rows reclaimed. +func (q *Queue) ReclaimStale(ctx context.Context, olderThan time.Duration) (int, error) { + cutoff := time.Now().Add(-olderThan).Unix() + res, err := q.db.ExecContext(ctx, q.reclaimSQL(), cutoff) + if err != nil { + return 0, fmt.Errorf("reclaim stale: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("rows affected: %w", err) + } + return int(n), nil +} + +func (q *Queue) updateClaimedIDs(ctx context.Context, op, query string, groupID int64, token string, taskIDs []int64) error { + blob, err := json.Marshal(taskIDs) + if err != nil { + return fmt.Errorf("encode task ids: %w", err) + } + if _, err := q.db.ExecContext(ctx, query, groupID, token, string(blob)); err != nil { + return fmt.Errorf("%s: %w", op, err) + } + return nil +} + +func (q *Queue) enqueueWith(ctx context.Context, execer execContext, groupIDs []int64, taskIDs []int64) error { + if len(groupIDs) == 0 || len(taskIDs) == 0 { + return nil + } + blob, err := json.Marshal(taskIDs) + if err != nil { + return fmt.Errorf("encode task ids: %w", err) + } + now := time.Now().Unix() + for _, groupID := range groupIDs { + if _, err := execer.ExecContext(ctx, q.enqueueSQL(), groupID, now, string(blob)); err != nil { + return fmt.Errorf("insert pending tasks (group=%d): %w", groupID, err) + } + } + return nil +} + +func (q *Queue) enqueueSQL() string { + s := q.schema + return fmt.Sprintf(` + INSERT OR IGNORE INTO %s (%s, %s, %s) + SELECT ?, value, ? FROM json_each(?)`, + s.Table, s.GroupIDColumn, s.TaskIDColumn, s.EnqueuedAtColumn) +} + +func (q *Queue) claimSQL() string { + s := q.schema + return fmt.Sprintf(` + UPDATE %s + SET %s = ?, %s = ? + WHERE (%s, %s) IN ( + SELECT %s, %s + FROM %s + WHERE %s = ? + AND %s IS NULL + ORDER BY %s + LIMIT ?) + RETURNING %s`, + s.Table, + s.ClaimedAtColumn, s.ClaimTokenColumn, + s.GroupIDColumn, s.TaskIDColumn, + s.GroupIDColumn, s.TaskIDColumn, + s.Table, + s.GroupIDColumn, + s.ClaimedAtColumn, + s.TaskIDColumn, + s.TaskIDColumn) +} + +func (q *Queue) completeSQL() string { + s := q.schema + return fmt.Sprintf(` + DELETE FROM %s + WHERE %s = ? + AND %s = ? + AND %s IN (SELECT value FROM json_each(?))`, + s.Table, s.GroupIDColumn, s.ClaimTokenColumn, s.TaskIDColumn) +} + +func (q *Queue) releaseSQL() string { + s := q.schema + return fmt.Sprintf(` + UPDATE %s + SET %s = NULL, %s = NULL + WHERE %s = ? + AND %s = ? + AND %s IN (SELECT value FROM json_each(?))`, + s.Table, + s.ClaimedAtColumn, s.ClaimTokenColumn, + s.GroupIDColumn, + s.ClaimTokenColumn, + s.TaskIDColumn) +} + +func (q *Queue) reclaimSQL() string { + s := q.schema + return fmt.Sprintf(` + UPDATE %s + SET %s = NULL, %s = NULL + WHERE %s IS NOT NULL AND %s < ?`, + s.Table, + s.ClaimedAtColumn, s.ClaimTokenColumn, + s.ClaimedAtColumn, s.ClaimedAtColumn) +} + +func (s Schema) validate() error { + fields := map[string]string{ + "table": s.Table, + "group id column": s.GroupIDColumn, + "task id column": s.TaskIDColumn, + "enqueued at column": s.EnqueuedAtColumn, + "claimed at column": s.ClaimedAtColumn, + "claim token column": s.ClaimTokenColumn, + } + for name, value := range fields { + if !identifierPattern.MatchString(value) { + return fmt.Errorf("invalid %s %q", name, value) + } + } + return nil +} + +func newToken() (string, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("read random: %w", err) + } + return hex.EncodeToString(b), nil +} diff --git a/vector/encodingqueue/queue_test.go b/vector/encodingqueue/queue_test.go new file mode 100644 index 0000000..ea7f89d --- /dev/null +++ b/vector/encodingqueue/queue_test.go @@ -0,0 +1,211 @@ +package encodingqueue_test + +import ( + "context" + "database/sql" + "path/filepath" + "sort" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector/encodingqueue" +) + +func TestQueueEnqueueClaimReleaseComplete(t *testing.T) { + ctx := context.Background() + require := require.New(t) + assert := assert.New(t) + db := openQueueDB(t) + q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) + require.NoError(err) + + require.NoError(q.Enqueue(ctx, []int64{1, 2}, []int64{10, 11, 12})) + + ids, token, err := q.Claim(ctx, 1, 2) + require.NoError(err) + require.NotEmpty(token) + assert.Equal([]int64{10, 11}, ids) + assert.Equal(1, countAvailable(t, db, 1)) + + require.NoError(q.Release(ctx, 1, token, ids)) + assert.Equal(3, countAvailable(t, db, 1)) + + more, token2, err := q.Claim(ctx, 1, 10) + require.NoError(err) + require.NotEqual(token, token2) + assert.Equal([]int64{10, 11, 12}, more) + + require.NoError(q.Complete(ctx, 1, token2, more)) + assert.Equal(0, countPending(t, db, 1)) + assert.Equal(3, countPending(t, db, 2)) +} + +func TestQueueIgnoresDuplicateEnqueue(t *testing.T) { + ctx := context.Background() + require := require.New(t) + assert := assert.New(t) + db := openQueueDB(t) + q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) + require.NoError(err) + + require.NoError(q.Enqueue(ctx, []int64{1}, []int64{42})) + require.NoError(q.Enqueue(ctx, []int64{1}, []int64{42, 42})) + + assert.Equal(1, countPending(t, db, 1)) +} + +func TestQueueEnqueueTxUsesCallerTransaction(t *testing.T) { + ctx := context.Background() + require := require.New(t) + assert := assert.New(t) + db := openQueueDB(t) + q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) + require.NoError(err) + + tx, err := db.BeginTx(ctx, nil) + require.NoError(err) + require.NoError(q.EnqueueTx(ctx, tx, []int64{1}, []int64{10, 11})) + assert.Equal(0, countPending(t, db, 1)) + require.NoError(tx.Commit()) + + assert.Equal(2, countPending(t, db, 1)) +} + +func TestQueueClaimReturnsEmptyForNoWork(t *testing.T) { + ctx := context.Background() + require := require.New(t) + assert := assert.New(t) + db := openQueueDB(t) + q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) + require.NoError(err) + + ids, token, err := q.Claim(ctx, 1, 10) + require.NoError(err) + + assert.Empty(ids) + assert.Empty(token) +} + +func TestQueueReclaimStalePreservesNewClaimFromLateComplete(t *testing.T) { + ctx := context.Background() + require := require.New(t) + assert := assert.New(t) + db := openQueueDB(t) + q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) + require.NoError(err) + require.NoError(q.Enqueue(ctx, []int64{1}, []int64{1, 2})) + + idsA, tokenA, err := q.Claim(ctx, 1, 2) + require.NoError(err) + require.Len(idsA, 2) + backdateClaims(t, db, 20*time.Minute) + + reclaimed, err := q.ReclaimStale(ctx, 10*time.Minute) + require.NoError(err) + assert.Equal(2, reclaimed) + + idsB, tokenB, err := q.Claim(ctx, 1, 2) + require.NoError(err) + require.Len(idsB, 2) + require.NotEqual(tokenA, tokenB) + + require.NoError(q.Complete(ctx, 1, tokenA, idsA)) + assert.Equal(2, countPending(t, db, 1)) + assert.Equal(2, countClaimedByToken(t, db, tokenB)) + + require.NoError(q.Complete(ctx, 1, tokenB, idsB)) + assert.Equal(0, countPending(t, db, 1)) +} + +func TestQueueClaimReturnsIDsAscending(t *testing.T) { + ctx := context.Background() + require := require.New(t) + assert := assert.New(t) + db := openQueueDB(t) + q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) + require.NoError(err) + require.NoError(q.Enqueue(ctx, []int64{1}, []int64{9, 3, 7, 1})) + + ids, _, err := q.Claim(ctx, 1, 10) + require.NoError(err) + + assert.True(sort.SliceIsSorted(ids, func(i, j int) bool { return ids[i] < ids[j] }), "ids: %v", ids) +} + +func TestNewQueueRejectsUnsafeIdentifiers(t *testing.T) { + _, err := encodingqueue.NewQueue(nil, encodingqueue.Schema{ + Table: "pending_embeddings; DROP TABLE pending_embeddings", + GroupIDColumn: "generation_id", + TaskIDColumn: "message_id", + EnqueuedAtColumn: "enqueued_at", + ClaimedAtColumn: "claimed_at", + ClaimTokenColumn: "claim_token", + }) + + require.Error(t, err) +} + +func openQueueDB(t *testing.T) *sql.DB { + t.Helper() + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "queue.db")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, db.Close()) }) + _, err = db.Exec(` +CREATE TABLE pending_embeddings ( + generation_id INTEGER NOT NULL, + message_id INTEGER NOT NULL, + enqueued_at INTEGER NOT NULL, + claimed_at INTEGER, + claim_token TEXT, + PRIMARY KEY (generation_id, message_id) +); +CREATE INDEX idx_pending_available + ON pending_embeddings(generation_id, message_id) WHERE claimed_at IS NULL; +CREATE INDEX idx_pending_claims + ON pending_embeddings(claimed_at) WHERE claimed_at IS NOT NULL;`) + require.NoError(t, err) + return db +} + +func countPending(t *testing.T, db *sql.DB, generationID int64) int { + t.Helper() + var n int + err := db.QueryRow( + `SELECT COUNT(*) FROM pending_embeddings WHERE generation_id = ?`, + generationID, + ).Scan(&n) + require.NoError(t, err) + return n +} + +func countAvailable(t *testing.T, db *sql.DB, generationID int64) int { + t.Helper() + var n int + err := db.QueryRow( + `SELECT COUNT(*) FROM pending_embeddings WHERE generation_id = ? AND claimed_at IS NULL`, + generationID, + ).Scan(&n) + require.NoError(t, err) + return n +} + +func countClaimedByToken(t *testing.T, db *sql.DB, token string) int { + t.Helper() + var n int + err := db.QueryRow(`SELECT COUNT(*) FROM pending_embeddings WHERE claim_token = ?`, token).Scan(&n) + require.NoError(t, err) + return n +} + +func backdateClaims(t *testing.T, db *sql.DB, age time.Duration) { + t.Helper() + _, err := db.Exec( + `UPDATE pending_embeddings SET claimed_at = ? WHERE claimed_at IS NOT NULL`, + time.Now().Add(-age).Unix(), + ) + require.NoError(t, err) +} From 061a0a3042966855c88d5ec04fc87bd4badc0d31 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Thu, 25 Jun 2026 01:05:40 -0400 Subject: [PATCH 2/4] Replace embedding queue with backend-neutral vector toolkit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The initial extraction lifted msgvault's pending_embeddings claim queue, but that is storage and scheduling policy a consumer owns, and msgvault is already moving off it toward scan-and-fill (kenn-io/msgvault#411). Shipping it would have given kit a reusable package msgvault no longer uses. The surface that is genuinely shared across callers — msgvault and kata both embed for search — is chunking, model-generation identity, batched encoding, and merging results across generations. This adds those as pure transforms, plus a Store[K,G] contract and Fill/Search flows that own the scan-and-fill and query-and-merge orchestration so a backend supplies only SQL. Document and generation identity stay opaque (msgvault int64, kata UUID); storage and query construction live in backend subpackages. vector/sqlitevec is the reference backend on the same sqlite-vec binding msgvault uses, with per-generation vec0 tables so a model migration across differing dimensions still serves a union of live generations. This also fixes the go.mod tidy drift that was failing Go hygiene. Validation: go build/vet/test ./... with CGO; vec0 exercised hermetically via the bundled sqlite-vec extension. Generated with Claude Code (Opus 4.8) Co-Authored-By: Claude Opus 4.8 --- go.mod | 3 +- go.sum | 2 + vector/AGENTS.md | 41 ++++ vector/chunk.go | 48 +++++ vector/chunk_test.go | 84 ++++++++ vector/doc.go | 24 +++ vector/encode.go | 111 +++++++++++ vector/encode_test.go | 124 ++++++++++++ vector/encodingqueue/doc.go | 7 - vector/encodingqueue/queue.go | 297 ---------------------------- vector/encodingqueue/queue_test.go | 211 -------------------- vector/flow.go | 115 +++++++++++ vector/flow_test.go | 180 +++++++++++++++++ vector/generation.go | 50 +++++ vector/generation_test.go | 43 ++++ vector/search.go | 158 +++++++++++++++ vector/search_test.go | 109 ++++++++++ vector/sqlitevec/sqlitevec.go | 187 ++++++++++++++++++ vector/sqlitevec/sqlitevec_test.go | 141 +++++++++++++ vector/sqlitevec/store.go | 196 ++++++++++++++++++ vector/sqlitevec/vec0_smoke_test.go | 42 ++++ vector/store.go | 47 +++++ 22 files changed, 1704 insertions(+), 516 deletions(-) create mode 100644 vector/AGENTS.md create mode 100644 vector/chunk.go create mode 100644 vector/chunk_test.go create mode 100644 vector/doc.go create mode 100644 vector/encode.go create mode 100644 vector/encode_test.go delete mode 100644 vector/encodingqueue/doc.go delete mode 100644 vector/encodingqueue/queue.go delete mode 100644 vector/encodingqueue/queue_test.go create mode 100644 vector/flow.go create mode 100644 vector/flow_test.go create mode 100644 vector/generation.go create mode 100644 vector/generation_test.go create mode 100644 vector/search.go create mode 100644 vector/search_test.go create mode 100644 vector/sqlitevec/sqlitevec.go create mode 100644 vector/sqlitevec/sqlitevec_test.go create mode 100644 vector/sqlitevec/store.go create mode 100644 vector/sqlitevec/vec0_smoke_test.go create mode 100644 vector/store.go diff --git a/go.mod b/go.mod index 927d7ab..2cfdf19 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module go.kenn.io/kit go 1.26.3 require ( + github.com/asg017/sqlite-vec-go-bindings v0.1.6 github.com/gofrs/flock v0.13.0 + github.com/mattn/go-sqlite3 v1.14.44 github.com/posthog/posthog-go v1.12.6 github.com/stretchr/testify v1.11.1 go.opentelemetry.io/otel v1.43.0 @@ -32,7 +34,6 @@ require ( github.com/klauspost/compress v1.18.6 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.44 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect diff --git a/go.sum b/go.sum index a62d568..3d118e8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= +github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= diff --git a/vector/AGENTS.md b/vector/AGENTS.md new file mode 100644 index 0000000..6c86733 --- /dev/null +++ b/vector/AGENTS.md @@ -0,0 +1,41 @@ +# vector package invariants + +`go.kenn.io/kit/vector` owns the backend-neutral parts of an embedding +pipeline. Preserve these invariants when changing it. + +## The storage boundary is the point of this package + +- The core `vector` package must not import `database/sql`, a driver, or + any backend client, and must not construct backend SQL. The `Fill` and + `Search` flows reach storage only through the `Store[K, G]` interface. +- Persistence is a function of the caller's source system. Backends live + in their own subpackages (e.g. `vector/sqlitevec`) so a caller wiring + one backend never pulls another backend's driver. New backends + (pgvector, duckdb) go in sibling subpackages, not into the core. +- Backends own query construction. The differences between sqlite-vec + `vec0 MATCH`, pgvector `<=>`, and duckdb `array_distance` belong behind + `QueryGeneration`, never in the core flows. + +## Keys and generations are opaque + +- Document identity is the caller's type `K` and generation identity its + type `G`. msgvault uses `int64`; kata uses UUIDs. Compare them for + equality only; never assume a type, a single id namespace, or an + ordering. Backends additionally require `K`/`G` to be types + `database/sql` can bind and scan. + +## Merge semantics + +- `Merge` takes per-generation lists in descending preference and keeps + the earliest list's hit on overlap (prefer the newer generation during + a migration). Coverage is a union — never drop a document that only one + generation covers, and never emit duplicates. +- Cross-generation scores are not comparable. Default to + `MergeNormalizedScore`; raw-score merging is opt-in. + +## Generations during migration + +- The mid-migration union exists because new documents land only in the + building generation while the active generation still serves the bulk. + `Search` must keep querying every generation `LiveGenerations` returns, + in the order it returns them. diff --git a/vector/chunk.go b/vector/chunk.go new file mode 100644 index 0000000..78239dd --- /dev/null +++ b/vector/chunk.go @@ -0,0 +1,48 @@ +package vector + +// Chunk is a window of text encoded as a single vector. Index is the +// chunk's position within the source content, starting at zero. +type Chunk struct { + Index int + Text string +} + +// SplitOptions controls how Split windows content into chunks. +type SplitOptions struct { + // MaxRunes bounds the number of runes in each chunk. Values <= 0 + // disable splitting and return the content as a single chunk. + MaxRunes int + // Overlap is the number of runes shared between consecutive chunks. + // It is clamped to the range [0, MaxRunes-1]. + Overlap int +} + +// Split windows content into overlapping chunks of at most MaxRunes runes. +// It splits on runes rather than bytes so multi-byte characters are never +// torn apart. Empty content yields no chunks. +// +// Split measures size in runes, not model tokens. Callers that budget by +// tokens should convert their token budget to an approximate rune count. +func Split(content string, o SplitOptions) []Chunk { + if content == "" { + return nil + } + runes := []rune(content) + if o.MaxRunes <= 0 || len(runes) <= o.MaxRunes { + return []Chunk{{Index: 0, Text: content}} + } + + overlap := min(max(o.Overlap, 0), o.MaxRunes-1) + stride := o.MaxRunes - overlap + + var chunks []Chunk + for start, idx := 0, 0; start < len(runes); start, idx = start+stride, idx+1 { + end := start + o.MaxRunes + if end >= len(runes) { + chunks = append(chunks, Chunk{Index: idx, Text: string(runes[start:])}) + break + } + chunks = append(chunks, Chunk{Index: idx, Text: string(runes[start:end])}) + } + return chunks +} diff --git a/vector/chunk_test.go b/vector/chunk_test.go new file mode 100644 index 0000000..e2af4f9 --- /dev/null +++ b/vector/chunk_test.go @@ -0,0 +1,84 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func TestSplit(t *testing.T) { + tests := []struct { + name string + content string + opts vector.SplitOptions + want []vector.Chunk + }{ + { + name: "empty yields no chunks", + content: "", + opts: vector.SplitOptions{MaxRunes: 4}, + want: nil, + }, + { + name: "non-positive max returns single chunk", + content: "hello world", + opts: vector.SplitOptions{MaxRunes: 0}, + want: []vector.Chunk{{Index: 0, Text: "hello world"}}, + }, + { + name: "content shorter than max is one chunk", + content: "abcd", + opts: vector.SplitOptions{MaxRunes: 8}, + want: []vector.Chunk{{Index: 0, Text: "abcd"}}, + }, + { + name: "windows without overlap", + content: "abcdefghij", + opts: vector.SplitOptions{MaxRunes: 5}, + want: []vector.Chunk{ + {Index: 0, Text: "abcde"}, + {Index: 1, Text: "fghij"}, + }, + }, + { + name: "windows with overlap", + content: "abcdefghij", + opts: vector.SplitOptions{MaxRunes: 4, Overlap: 1}, + want: []vector.Chunk{ + {Index: 0, Text: "abcd"}, + {Index: 1, Text: "defg"}, + {Index: 2, Text: "ghij"}, + }, + }, + { + name: "overlap at or above max clamps to max-1", + content: "abcdef", + opts: vector.SplitOptions{MaxRunes: 3, Overlap: 9}, + want: []vector.Chunk{ + {Index: 0, Text: "abc"}, + {Index: 1, Text: "bcd"}, + {Index: 2, Text: "cde"}, + {Index: 3, Text: "def"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, vector.Split(tt.content, tt.opts)) + }) + } +} + +func TestSplitDoesNotTearMultiByteRunes(t *testing.T) { + assert := assert.New(t) + // Each emoji is multiple bytes but one rune. + chunks := vector.Split("😀😁😂🤣", vector.SplitOptions{MaxRunes: 2}) + + assert.Equal([]vector.Chunk{ + {Index: 0, Text: "😀😁"}, + {Index: 1, Text: "😂🤣"}, + }, chunks) +} diff --git a/vector/doc.go b/vector/doc.go new file mode 100644 index 0000000..6d96ae3 --- /dev/null +++ b/vector/doc.go @@ -0,0 +1,24 @@ +// Package vector provides backend-neutral building blocks for embedding +// content and searching the resulting vectors. +// +// It is organized in three layers: +// +// - Transforms and value types: Split windows content into chunks, +// Generation identifies an embedding model, EncodeBatched batches +// encode calls, and RollupByDocument and Merge reduce and combine +// search results across generations. These are pure functions. +// +// - The Store contract: Store[K, G] is the persistence interface the +// flows depend on. Implementations are a function of the caller's +// source system and own all backend SQL and query construction; see +// the sqlitevec subpackage for a worked example. +// +// - Flows: Fill runs the scan-and-fill embedding loop and Search runs +// the cross-generation query-and-merge, both over a Store. +// +// Nothing in this package opens a database, holds an index, or constructs +// backend SQL — the flows delegate every storage operation to the Store. +// Document identity is the caller's own key type K, and generation +// identity its type G; the package compares both for equality but never +// interprets them. +package vector diff --git a/vector/encode.go b/vector/encode.go new file mode 100644 index 0000000..59458bf --- /dev/null +++ b/vector/encode.go @@ -0,0 +1,111 @@ +package vector + +import ( + "context" + "fmt" + "sync" +) + +// Vector is a single embedding. +type Vector []float32 + +// EncodeFunc turns a batch of texts into one vector each, in the same +// order. Implementations own the model or API client and any retry or +// backoff policy, since retryability is provider-specific. +type EncodeFunc func(ctx context.Context, texts []string) ([][]float32, error) + +// BatchOptions controls how EncodeBatched groups and parallelizes calls. +type BatchOptions struct { + // BatchSize is the maximum number of chunks passed to EncodeFunc in a + // single call. Values <= 0 send every chunk in one call. + BatchSize int + // Concurrency bounds how many EncodeFunc calls run at once. Values + // <= 0 mean one call at a time. + Concurrency int +} + +// EncodeBatched splits chunks into batches, invokes enc with bounded +// concurrency, and returns one Vector per input chunk in input order. It +// stops launching work at the first error or when ctx is cancelled, and +// reports the first error encountered. +func EncodeBatched(ctx context.Context, enc EncodeFunc, chunks []Chunk, o BatchOptions) ([]Vector, error) { + if enc == nil { + return nil, fmt.Errorf("encode func is nil") + } + if len(chunks) == 0 { + return nil, nil + } + + batchSize := o.BatchSize + if batchSize <= 0 { + batchSize = len(chunks) + } + concurrency := o.Concurrency + if concurrency <= 0 { + concurrency = 1 + } + + out := make([]Vector, len(chunks)) + sem := make(chan struct{}, concurrency) + var ( + wg sync.WaitGroup + mu sync.Mutex + firstErr error + ) + failed := func() bool { + mu.Lock() + defer mu.Unlock() + return firstErr != nil + } + setErr := func(err error) { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + + for start := 0; start < len(chunks); start += batchSize { + if ctx.Err() != nil { + setErr(ctx.Err()) + break + } + if failed() { + break + } + + end := min(start+batchSize, len(chunks)) + texts := make([]string, end-start) + for i, c := range chunks[start:end] { + texts[i] = c.Text + } + + sem <- struct{}{} + wg.Add(1) + go func(start int, texts []string) { + defer wg.Done() + defer func() { <-sem }() + + vecs, err := enc(ctx, texts) + if err != nil { + setErr(fmt.Errorf("encode batch at %d: %w", start, err)) + return + } + if len(vecs) != len(texts) { + setErr(fmt.Errorf("encode batch at %d: got %d vectors for %d texts", start, len(vecs), len(texts))) + return + } + // Each batch owns a disjoint index range, so writes to out + // never overlap across goroutines. + for i, v := range vecs { + out[start+i] = Vector(v) + } + }(start, texts) + } + + wg.Wait() + if firstErr != nil { + return nil, firstErr + } + return out, nil +} diff --git a/vector/encode_test.go b/vector/encode_test.go new file mode 100644 index 0000000..f783493 --- /dev/null +++ b/vector/encode_test.go @@ -0,0 +1,124 @@ +package vector_test + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" +) + +func chunks(texts ...string) []vector.Chunk { + out := make([]vector.Chunk, len(texts)) + for i, txt := range texts { + out[i] = vector.Chunk{Index: i, Text: txt} + } + return out +} + +// echoEncoder returns one vector per text whose single component encodes +// the text length, so results can be matched back to their input order. +func echoEncoder(record func(batch []string)) vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + if record != nil { + record(texts) + } + out := make([][]float32, len(texts)) + for i, txt := range texts { + out[i] = []float32{float32(len(txt))} + } + return out, nil + } +} + +func TestEncodeBatchedPreservesOrderAcrossBatches(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + var mu sync.Mutex + var sizes []int + enc := echoEncoder(func(batch []string) { + mu.Lock() + sizes = append(sizes, len(batch)) + mu.Unlock() + }) + + in := chunks("a", "bb", "ccc", "dddd", "eeeee") + out, err := vector.EncodeBatched(context.Background(), enc, in, vector.BatchOptions{BatchSize: 2, Concurrency: 3}) + require.NoError(err) + require.Len(out, len(in)) + for i, c := range in { + assert.Equal(float32(len(c.Text)), out[i][0], "vector %d matches its input", i) + } + + mu.Lock() + defer mu.Unlock() + assert.ElementsMatch([]int{2, 2, 1}, sizes, "batches are sized by BatchSize") +} + +func TestEncodeBatchedRespectsConcurrencyBound(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + var inFlight, maxInFlight atomic.Int64 + enc := func(_ context.Context, texts []string) ([][]float32, error) { + cur := inFlight.Add(1) + for { + prev := maxInFlight.Load() + if cur <= prev || maxInFlight.CompareAndSwap(prev, cur) { + break + } + } + defer inFlight.Add(-1) + out := make([][]float32, len(texts)) + return out, nil + } + + in := chunks("a", "b", "c", "d", "e", "f", "g", "h") + _, err := vector.EncodeBatched(context.Background(), enc, in, vector.BatchOptions{BatchSize: 1, Concurrency: 2}) + require.NoError(err) + assert.LessOrEqual(maxInFlight.Load(), int64(2), "never exceeds the concurrency bound") +} + +func TestEncodeBatchedSurfacesEncodeError(t *testing.T) { + assert := assert.New(t) + sentinel := errors.New("boom") + enc := func(_ context.Context, _ []string) ([][]float32, error) { return nil, sentinel } + + _, err := vector.EncodeBatched(context.Background(), enc, chunks("a", "b"), vector.BatchOptions{BatchSize: 1}) + assert.ErrorIs(err, sentinel) +} + +func TestEncodeBatchedRejectsCountMismatch(t *testing.T) { + assert := assert.New(t) + enc := func(_ context.Context, _ []string) ([][]float32, error) { + return [][]float32{{1}}, nil // one vector for two texts + } + + _, err := vector.EncodeBatched(context.Background(), enc, chunks("a", "b"), vector.BatchOptions{}) + assert.ErrorContains(err, "vectors for") +} + +func TestEncodeBatchedNilEncoder(t *testing.T) { + _, err := vector.EncodeBatched(context.Background(), nil, chunks("a"), vector.BatchOptions{}) + assert.Error(t, err) +} + +func TestEncodeBatchedEmptyInput(t *testing.T) { + assert := assert.New(t) + out, err := vector.EncodeBatched(context.Background(), echoEncoder(nil), nil, vector.BatchOptions{}) + assert.NoError(err) + assert.Empty(out) +} + +func TestEncodeBatchedStopsOnCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := vector.EncodeBatched(ctx, echoEncoder(nil), chunks("a", "b"), vector.BatchOptions{BatchSize: 1}) + assert.ErrorIs(t, err, context.Canceled) +} diff --git a/vector/encodingqueue/doc.go b/vector/encodingqueue/doc.go deleted file mode 100644 index a86385f..0000000 --- a/vector/encodingqueue/doc.go +++ /dev/null @@ -1,7 +0,0 @@ -// Package encodingqueue provides SQLite-compatible queue helpers for -// vector encoding tasks. -// -// The package expects one pending-task table with a compound primary key -// over generation and task IDs, an enqueue timestamp, and nullable claim -// fields. Callers own schema creation and database lifecycle. -package encodingqueue diff --git a/vector/encodingqueue/queue.go b/vector/encodingqueue/queue.go deleted file mode 100644 index d9056bc..0000000 --- a/vector/encodingqueue/queue.go +++ /dev/null @@ -1,297 +0,0 @@ -package encodingqueue - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/hex" - "encoding/json" - "fmt" - "regexp" - "sort" - "time" -) - -var identifierPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) - -// Schema names the pending-task table and columns used by Queue. -// -// Identifiers are interpolated into SQL after validation, so every field -// must be a simple SQL identifier rather than arbitrary SQL text. -type Schema struct { - Table string - GroupIDColumn string - TaskIDColumn string - EnqueuedAtColumn string - ClaimedAtColumn string - ClaimTokenColumn string -} - -// DefaultSchema returns the table shape used by msgvault's vector -// embedding queue. -func DefaultSchema() Schema { - return Schema{ - Table: "pending_embeddings", - GroupIDColumn: "generation_id", - TaskIDColumn: "message_id", - EnqueuedAtColumn: "enqueued_at", - ClaimedAtColumn: "claimed_at", - ClaimTokenColumn: "claim_token", - } -} - -// Queue wraps a SQL pending-task table with crash-safe claim, complete, -// release, and stale-claim reclamation operations. -type Queue struct { - db *sql.DB - schema Schema -} - -type execContext interface { - ExecContext(context.Context, string, ...any) (sql.Result, error) -} - -// NewQueue returns a Queue bound to db. The caller retains ownership of -// db; Queue does not close it. -func NewQueue(db *sql.DB, schema Schema) (*Queue, error) { - if err := schema.validate(); err != nil { - return nil, err - } - if db == nil { - return nil, fmt.Errorf("db is nil") - } - return &Queue{db: db, schema: schema}, nil -} - -// Enqueue inserts every task ID once for every group ID. Duplicate -// (group, task) pairs are ignored by the pending table's uniqueness -// constraint. -func (q *Queue) Enqueue(ctx context.Context, groupIDs []int64, taskIDs []int64) error { - if len(groupIDs) == 0 || len(taskIDs) == 0 { - return nil - } - tx, err := q.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("begin enqueue tx: %w", err) - } - defer func() { _ = tx.Rollback() }() - - if err := q.EnqueueTx(ctx, tx, groupIDs, taskIDs); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return fmt.Errorf("commit enqueue: %w", err) - } - return nil -} - -// EnqueueTx inserts every task ID once for every group ID using tx. The -// caller is responsible for committing or rolling back tx. -func (q *Queue) EnqueueTx(ctx context.Context, tx *sql.Tx, groupIDs []int64, taskIDs []int64) error { - if tx == nil { - return fmt.Errorf("tx is nil") - } - return q.enqueueWith(ctx, tx, groupIDs, taskIDs) -} - -// Claim marks up to batch available rows for groupID as claimed by a -// fresh token, returning task IDs in ascending order with the token to -// pass to Complete or Release. -func (q *Queue) Claim(ctx context.Context, groupID int64, batch int) ([]int64, string, error) { - if batch <= 0 { - return nil, "", nil - } - token, err := newToken() - if err != nil { - return nil, "", fmt.Errorf("new token: %w", err) - } - now := time.Now().Unix() - - tx, err := q.db.BeginTx(ctx, nil) - if err != nil { - return nil, "", fmt.Errorf("begin claim tx: %w", err) - } - defer func() { _ = tx.Rollback() }() - - rows, err := tx.QueryContext(ctx, q.claimSQL(), now, token, groupID, batch) - if err != nil { - return nil, "", fmt.Errorf("claim query: %w", err) - } - var ids []int64 - for rows.Next() { - var id int64 - if err := rows.Scan(&id); err != nil { - _ = rows.Close() - return nil, "", fmt.Errorf("scan claimed task id: %w", err) - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - _ = rows.Close() - return nil, "", fmt.Errorf("claim rows: %w", err) - } - if err := rows.Close(); err != nil { - return nil, "", fmt.Errorf("close claim rows: %w", err) - } - if err := tx.Commit(); err != nil { - return nil, "", fmt.Errorf("commit claim: %w", err) - } - if len(ids) == 0 { - return nil, "", nil - } - sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) - return ids, token, nil -} - -// Complete deletes claimed rows whose claim token still matches token. -func (q *Queue) Complete(ctx context.Context, groupID int64, token string, taskIDs []int64) error { - if len(taskIDs) == 0 { - return nil - } - return q.updateClaimedIDs(ctx, "delete pending tasks", q.completeSQL(), groupID, token, taskIDs) -} - -// Release clears matching claims so tasks can be retried by another -// worker. -func (q *Queue) Release(ctx context.Context, groupID int64, token string, taskIDs []int64) error { - if len(taskIDs) == 0 { - return nil - } - return q.updateClaimedIDs(ctx, "release pending tasks", q.releaseSQL(), groupID, token, taskIDs) -} - -// ReclaimStale clears claims older than olderThan and returns the number -// of rows reclaimed. -func (q *Queue) ReclaimStale(ctx context.Context, olderThan time.Duration) (int, error) { - cutoff := time.Now().Add(-olderThan).Unix() - res, err := q.db.ExecContext(ctx, q.reclaimSQL(), cutoff) - if err != nil { - return 0, fmt.Errorf("reclaim stale: %w", err) - } - n, err := res.RowsAffected() - if err != nil { - return 0, fmt.Errorf("rows affected: %w", err) - } - return int(n), nil -} - -func (q *Queue) updateClaimedIDs(ctx context.Context, op, query string, groupID int64, token string, taskIDs []int64) error { - blob, err := json.Marshal(taskIDs) - if err != nil { - return fmt.Errorf("encode task ids: %w", err) - } - if _, err := q.db.ExecContext(ctx, query, groupID, token, string(blob)); err != nil { - return fmt.Errorf("%s: %w", op, err) - } - return nil -} - -func (q *Queue) enqueueWith(ctx context.Context, execer execContext, groupIDs []int64, taskIDs []int64) error { - if len(groupIDs) == 0 || len(taskIDs) == 0 { - return nil - } - blob, err := json.Marshal(taskIDs) - if err != nil { - return fmt.Errorf("encode task ids: %w", err) - } - now := time.Now().Unix() - for _, groupID := range groupIDs { - if _, err := execer.ExecContext(ctx, q.enqueueSQL(), groupID, now, string(blob)); err != nil { - return fmt.Errorf("insert pending tasks (group=%d): %w", groupID, err) - } - } - return nil -} - -func (q *Queue) enqueueSQL() string { - s := q.schema - return fmt.Sprintf(` - INSERT OR IGNORE INTO %s (%s, %s, %s) - SELECT ?, value, ? FROM json_each(?)`, - s.Table, s.GroupIDColumn, s.TaskIDColumn, s.EnqueuedAtColumn) -} - -func (q *Queue) claimSQL() string { - s := q.schema - return fmt.Sprintf(` - UPDATE %s - SET %s = ?, %s = ? - WHERE (%s, %s) IN ( - SELECT %s, %s - FROM %s - WHERE %s = ? - AND %s IS NULL - ORDER BY %s - LIMIT ?) - RETURNING %s`, - s.Table, - s.ClaimedAtColumn, s.ClaimTokenColumn, - s.GroupIDColumn, s.TaskIDColumn, - s.GroupIDColumn, s.TaskIDColumn, - s.Table, - s.GroupIDColumn, - s.ClaimedAtColumn, - s.TaskIDColumn, - s.TaskIDColumn) -} - -func (q *Queue) completeSQL() string { - s := q.schema - return fmt.Sprintf(` - DELETE FROM %s - WHERE %s = ? - AND %s = ? - AND %s IN (SELECT value FROM json_each(?))`, - s.Table, s.GroupIDColumn, s.ClaimTokenColumn, s.TaskIDColumn) -} - -func (q *Queue) releaseSQL() string { - s := q.schema - return fmt.Sprintf(` - UPDATE %s - SET %s = NULL, %s = NULL - WHERE %s = ? - AND %s = ? - AND %s IN (SELECT value FROM json_each(?))`, - s.Table, - s.ClaimedAtColumn, s.ClaimTokenColumn, - s.GroupIDColumn, - s.ClaimTokenColumn, - s.TaskIDColumn) -} - -func (q *Queue) reclaimSQL() string { - s := q.schema - return fmt.Sprintf(` - UPDATE %s - SET %s = NULL, %s = NULL - WHERE %s IS NOT NULL AND %s < ?`, - s.Table, - s.ClaimedAtColumn, s.ClaimTokenColumn, - s.ClaimedAtColumn, s.ClaimedAtColumn) -} - -func (s Schema) validate() error { - fields := map[string]string{ - "table": s.Table, - "group id column": s.GroupIDColumn, - "task id column": s.TaskIDColumn, - "enqueued at column": s.EnqueuedAtColumn, - "claimed at column": s.ClaimedAtColumn, - "claim token column": s.ClaimTokenColumn, - } - for name, value := range fields { - if !identifierPattern.MatchString(value) { - return fmt.Errorf("invalid %s %q", name, value) - } - } - return nil -} - -func newToken() (string, error) { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - return "", fmt.Errorf("read random: %w", err) - } - return hex.EncodeToString(b), nil -} diff --git a/vector/encodingqueue/queue_test.go b/vector/encodingqueue/queue_test.go deleted file mode 100644 index ea7f89d..0000000 --- a/vector/encodingqueue/queue_test.go +++ /dev/null @@ -1,211 +0,0 @@ -package encodingqueue_test - -import ( - "context" - "database/sql" - "path/filepath" - "sort" - "testing" - "time" - - _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "go.kenn.io/kit/vector/encodingqueue" -) - -func TestQueueEnqueueClaimReleaseComplete(t *testing.T) { - ctx := context.Background() - require := require.New(t) - assert := assert.New(t) - db := openQueueDB(t) - q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) - require.NoError(err) - - require.NoError(q.Enqueue(ctx, []int64{1, 2}, []int64{10, 11, 12})) - - ids, token, err := q.Claim(ctx, 1, 2) - require.NoError(err) - require.NotEmpty(token) - assert.Equal([]int64{10, 11}, ids) - assert.Equal(1, countAvailable(t, db, 1)) - - require.NoError(q.Release(ctx, 1, token, ids)) - assert.Equal(3, countAvailable(t, db, 1)) - - more, token2, err := q.Claim(ctx, 1, 10) - require.NoError(err) - require.NotEqual(token, token2) - assert.Equal([]int64{10, 11, 12}, more) - - require.NoError(q.Complete(ctx, 1, token2, more)) - assert.Equal(0, countPending(t, db, 1)) - assert.Equal(3, countPending(t, db, 2)) -} - -func TestQueueIgnoresDuplicateEnqueue(t *testing.T) { - ctx := context.Background() - require := require.New(t) - assert := assert.New(t) - db := openQueueDB(t) - q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) - require.NoError(err) - - require.NoError(q.Enqueue(ctx, []int64{1}, []int64{42})) - require.NoError(q.Enqueue(ctx, []int64{1}, []int64{42, 42})) - - assert.Equal(1, countPending(t, db, 1)) -} - -func TestQueueEnqueueTxUsesCallerTransaction(t *testing.T) { - ctx := context.Background() - require := require.New(t) - assert := assert.New(t) - db := openQueueDB(t) - q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) - require.NoError(err) - - tx, err := db.BeginTx(ctx, nil) - require.NoError(err) - require.NoError(q.EnqueueTx(ctx, tx, []int64{1}, []int64{10, 11})) - assert.Equal(0, countPending(t, db, 1)) - require.NoError(tx.Commit()) - - assert.Equal(2, countPending(t, db, 1)) -} - -func TestQueueClaimReturnsEmptyForNoWork(t *testing.T) { - ctx := context.Background() - require := require.New(t) - assert := assert.New(t) - db := openQueueDB(t) - q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) - require.NoError(err) - - ids, token, err := q.Claim(ctx, 1, 10) - require.NoError(err) - - assert.Empty(ids) - assert.Empty(token) -} - -func TestQueueReclaimStalePreservesNewClaimFromLateComplete(t *testing.T) { - ctx := context.Background() - require := require.New(t) - assert := assert.New(t) - db := openQueueDB(t) - q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) - require.NoError(err) - require.NoError(q.Enqueue(ctx, []int64{1}, []int64{1, 2})) - - idsA, tokenA, err := q.Claim(ctx, 1, 2) - require.NoError(err) - require.Len(idsA, 2) - backdateClaims(t, db, 20*time.Minute) - - reclaimed, err := q.ReclaimStale(ctx, 10*time.Minute) - require.NoError(err) - assert.Equal(2, reclaimed) - - idsB, tokenB, err := q.Claim(ctx, 1, 2) - require.NoError(err) - require.Len(idsB, 2) - require.NotEqual(tokenA, tokenB) - - require.NoError(q.Complete(ctx, 1, tokenA, idsA)) - assert.Equal(2, countPending(t, db, 1)) - assert.Equal(2, countClaimedByToken(t, db, tokenB)) - - require.NoError(q.Complete(ctx, 1, tokenB, idsB)) - assert.Equal(0, countPending(t, db, 1)) -} - -func TestQueueClaimReturnsIDsAscending(t *testing.T) { - ctx := context.Background() - require := require.New(t) - assert := assert.New(t) - db := openQueueDB(t) - q, err := encodingqueue.NewQueue(db, encodingqueue.DefaultSchema()) - require.NoError(err) - require.NoError(q.Enqueue(ctx, []int64{1}, []int64{9, 3, 7, 1})) - - ids, _, err := q.Claim(ctx, 1, 10) - require.NoError(err) - - assert.True(sort.SliceIsSorted(ids, func(i, j int) bool { return ids[i] < ids[j] }), "ids: %v", ids) -} - -func TestNewQueueRejectsUnsafeIdentifiers(t *testing.T) { - _, err := encodingqueue.NewQueue(nil, encodingqueue.Schema{ - Table: "pending_embeddings; DROP TABLE pending_embeddings", - GroupIDColumn: "generation_id", - TaskIDColumn: "message_id", - EnqueuedAtColumn: "enqueued_at", - ClaimedAtColumn: "claimed_at", - ClaimTokenColumn: "claim_token", - }) - - require.Error(t, err) -} - -func openQueueDB(t *testing.T) *sql.DB { - t.Helper() - db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "queue.db")) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, db.Close()) }) - _, err = db.Exec(` -CREATE TABLE pending_embeddings ( - generation_id INTEGER NOT NULL, - message_id INTEGER NOT NULL, - enqueued_at INTEGER NOT NULL, - claimed_at INTEGER, - claim_token TEXT, - PRIMARY KEY (generation_id, message_id) -); -CREATE INDEX idx_pending_available - ON pending_embeddings(generation_id, message_id) WHERE claimed_at IS NULL; -CREATE INDEX idx_pending_claims - ON pending_embeddings(claimed_at) WHERE claimed_at IS NOT NULL;`) - require.NoError(t, err) - return db -} - -func countPending(t *testing.T, db *sql.DB, generationID int64) int { - t.Helper() - var n int - err := db.QueryRow( - `SELECT COUNT(*) FROM pending_embeddings WHERE generation_id = ?`, - generationID, - ).Scan(&n) - require.NoError(t, err) - return n -} - -func countAvailable(t *testing.T, db *sql.DB, generationID int64) int { - t.Helper() - var n int - err := db.QueryRow( - `SELECT COUNT(*) FROM pending_embeddings WHERE generation_id = ? AND claimed_at IS NULL`, - generationID, - ).Scan(&n) - require.NoError(t, err) - return n -} - -func countClaimedByToken(t *testing.T, db *sql.DB, token string) int { - t.Helper() - var n int - err := db.QueryRow(`SELECT COUNT(*) FROM pending_embeddings WHERE claim_token = ?`, token).Scan(&n) - require.NoError(t, err) - return n -} - -func backdateClaims(t *testing.T, db *sql.DB, age time.Duration) { - t.Helper() - _, err := db.Exec( - `UPDATE pending_embeddings SET claimed_at = ? WHERE claimed_at IS NOT NULL`, - time.Now().Add(-age).Unix(), - ) - require.NoError(t, err) -} diff --git a/vector/flow.go b/vector/flow.go new file mode 100644 index 0000000..65c73c9 --- /dev/null +++ b/vector/flow.go @@ -0,0 +1,115 @@ +package vector + +import ( + "context" + "fmt" +) + +// FillOptions configures Fill. +type FillOptions struct { + // ScanBatch is the number of pending documents fetched per scan. + // Values <= 0 use 128. + ScanBatch int + // Split controls how each document's content is windowed into chunks. + Split SplitOptions + // Batch controls how chunks are batched into encode calls. + Batch BatchOptions +} + +// FillStats reports what a Fill run embedded. +type FillStats struct { + Documents int + Chunks int +} + +// Fill embeds every document that still needs the target generation: it +// scans the store for pending documents, splits and encodes each, and +// saves the resulting vectors, repeating until no documents remain. It is +// the generic scan-and-fill loop; the store decides what counts as +// pending and persists the results. +func Fill[K, G comparable](ctx context.Context, store Store[K, G], gen G, enc EncodeFunc, o FillOptions) (FillStats, error) { + scanBatch := o.ScanBatch + if scanBatch <= 0 { + scanBatch = 128 + } + + var stats FillStats + for { + if err := ctx.Err(); err != nil { + return stats, err + } + pending, err := store.PendingForGeneration(ctx, gen, scanBatch) + if err != nil { + return stats, fmt.Errorf("scan pending: %w", err) + } + if len(pending) == 0 { + return stats, nil + } + + for _, p := range pending { + chunks := Split(p.Content, o.Split) + vectors, err := EncodeBatched(ctx, enc, chunks, o.Batch) + if err != nil { + return stats, fmt.Errorf("encode document %v: %w", p.Doc, err) + } + cvs := make([]ChunkVector, len(chunks)) + for i, c := range chunks { + cvs[i] = ChunkVector{ChunkIndex: c.Index, Vector: vectors[i]} + } + if err := store.SaveVectors(ctx, gen, p.Doc, cvs); err != nil { + return stats, fmt.Errorf("save document %v: %w", p.Doc, err) + } + stats.Documents++ + stats.Chunks += len(cvs) + } + } +} + +// SearchOptions configures Search. +type SearchOptions struct { + // PerGeneration caps how many hits are fetched from each generation + // before merging. Values <= 0 use 50. + PerGeneration int + // Merge configures how per-generation results are combined. + Merge MergeOptions +} + +// Search embeds queryText once per live generation (each may use a +// different model), queries each generation, rolls the chunk hits up to +// documents, and merges the per-generation results into one ranking. +// encFor maps a generation to the encoder for that generation's model. +func Search[K, G comparable]( + ctx context.Context, + store Store[K, G], + queryText string, + encFor func(gen G) EncodeFunc, + o SearchOptions, +) ([]Hit[K], error) { + perGen := o.PerGeneration + if perGen <= 0 { + perGen = 50 + } + + gens, err := store.LiveGenerations(ctx) + if err != nil { + return nil, fmt.Errorf("live generations: %w", err) + } + + lists := make([][]Hit[K], 0, len(gens)) + for _, gen := range gens { + enc := encFor(gen) + if enc == nil { + return nil, fmt.Errorf("no encoder for generation %v", gen) + } + vectors, err := EncodeBatched(ctx, enc, []Chunk{{Index: 0, Text: queryText}}, BatchOptions{}) + if err != nil { + return nil, fmt.Errorf("embed query for generation %v: %w", gen, err) + } + hits, err := store.QueryGeneration(ctx, gen, vectors[0], perGen) + if err != nil { + return nil, fmt.Errorf("query generation %v: %w", gen, err) + } + lists = append(lists, RollupByDocument(hits)) + } + return Merge(lists, o.Merge), nil +} diff --git a/vector/flow_test.go b/vector/flow_test.go new file mode 100644 index 0000000..a7135d5 --- /dev/null +++ b/vector/flow_test.go @@ -0,0 +1,180 @@ +package vector_test + +import ( + "context" + "math" + "slices" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" +) + +// memStore is an in-memory Store[int64, int] used to exercise the flows +// without any real backend. Documents are keyed by int64; generations by +// int. QueryGeneration ranks by cosine similarity over stored vectors. +type memStore struct { + content map[int64]string + embedded map[int64]map[int]bool // doc -> gen -> done + vectors map[int]map[int64][]vector.ChunkVector // gen -> doc -> chunks + live []int // descending preference +} + +func newMemStore() *memStore { + return &memStore{ + content: map[int64]string{}, + embedded: map[int64]map[int]bool{}, + vectors: map[int]map[int64][]vector.ChunkVector{}, + } +} + +func (m *memStore) PendingForGeneration(_ context.Context, gen int, limit int) ([]vector.Pending[int64], error) { + keys := make([]int64, 0, len(m.content)) + for doc := range m.content { + if !m.embedded[doc][gen] { + keys = append(keys, doc) + } + } + slices.Sort(keys) + if limit > 0 && len(keys) > limit { + keys = keys[:limit] + } + out := make([]vector.Pending[int64], len(keys)) + for i, doc := range keys { + out[i] = vector.Pending[int64]{Doc: doc, Content: m.content[doc]} + } + return out, nil +} + +func (m *memStore) SaveVectors(_ context.Context, gen int, doc int64, vecs []vector.ChunkVector) error { + if m.vectors[gen] == nil { + m.vectors[gen] = map[int64][]vector.ChunkVector{} + } + m.vectors[gen][doc] = vecs + if m.embedded[doc] == nil { + m.embedded[doc] = map[int]bool{} + } + m.embedded[doc][gen] = true + return nil +} + +func (m *memStore) LiveGenerations(_ context.Context) ([]int, error) { + return m.live, nil +} + +func (m *memStore) QueryGeneration(_ context.Context, gen int, query vector.Vector, limit int) ([]vector.Hit[int64], error) { + var hits []vector.Hit[int64] + for doc, chunks := range m.vectors[gen] { + for _, cv := range chunks { + hits = append(hits, vector.Hit[int64]{Doc: doc, ChunkIndex: cv.ChunkIndex, Score: cosine(query, cv.Vector)}) + } + } + sort.SliceStable(hits, func(i, j int) bool { return hits[i].Score > hits[j].Score }) + if limit > 0 && len(hits) > limit { + hits = hits[:limit] + } + return hits, nil +} + +func cosine(a, b vector.Vector) float32 { + var dot, na, nb float64 + for i := range a { + dot += float64(a[i]) * float64(b[i]) + na += float64(a[i]) * float64(a[i]) + nb += float64(b[i]) * float64(b[i]) + } + if na == 0 || nb == 0 { + return 0 + } + return float32(dot / (math.Sqrt(na) * math.Sqrt(nb))) +} + +// lenEncoder embeds each text as a 1-D vector of its rune length, enough +// to confirm Fill wired chunk content through to SaveVectors. +func lenEncoder() vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i, txt := range texts { + out[i] = []float32{float32(len([]rune(txt)))} + } + return out, nil + } +} + +func TestFillEmbedsAllPendingThenStops(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + + store := newMemStore() + store.content[1] = "alpha" + store.content[2] = "beta gamma delta" + + stats, err := vector.Fill(ctx, store, 7, lenEncoder(), vector.FillOptions{ + ScanBatch: 1, // force multiple scan rounds + Split: vector.SplitOptions{MaxRunes: 4, Overlap: 0}, + }) + require.NoError(err) + + assert.Equal(2, stats.Documents) + assert.True(store.embedded[1][7] && store.embedded[2][7], "both docs stamped for gen 7") + require.Len(store.vectors[7][1], 2, "alpha -> 2 chunks of <=4 runes") + assert.InDelta(4, store.vectors[7][1][0].Vector[0], 1e-6, "first chunk carries its rune length") + + // A second run finds nothing pending and embeds zero documents. + again, err := vector.Fill(ctx, store, 7, lenEncoder(), vector.FillOptions{}) + require.NoError(err) + assert.Equal(0, again.Documents) +} + +func TestSearchRollsUpAndPrefersBuildingGeneration(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + + const active, building = 7, 9 + store := newMemStore() + store.live = []int{building, active} // descending preference + + // Doc 1 is shared; active stored it at chunk 0, building at chunk 5. + store.SaveVectors(ctx, active, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0}}}) + store.SaveVectors(ctx, active, 2, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{0, 1}}}) + store.SaveVectors(ctx, building, 1, []vector.ChunkVector{{ChunkIndex: 5, Vector: vector.Vector{1, 0}}}) + store.SaveVectors(ctx, building, 3, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0}}}) // new, building-only + + // Query vector [1,0] points at docs 1 and 3. + queryEnc := func(int) vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i := range texts { + out[i] = []float32{1, 0} + } + return out, nil + } + } + + got, err := vector.Search(ctx, store, "q", queryEnc, vector.SearchOptions{}) + require.NoError(err) + + byDoc := map[int64]vector.Hit[int64]{} + for _, h := range got { + byDoc[h.Doc] = h + } + assert.Contains(byDoc, int64(1)) + assert.Contains(byDoc, int64(2), "active-only doc is not dropped (union coverage)") + assert.Contains(byDoc, int64(3), "building-only new doc is searchable mid-migration") + assert.Equal(5, byDoc[1].ChunkIndex, "shared doc keeps the building generation's hit") +} + +func TestSearchErrorsWhenNoEncoderForGeneration(t *testing.T) { + ctx := context.Background() + store := newMemStore() + store.live = []int{1} + store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1}}}) + + _, err := vector.Search(ctx, store, "q", func(int) vector.EncodeFunc { return nil }, vector.SearchOptions{}) + assert.ErrorContains(t, err, "no encoder") +} diff --git a/vector/generation.go b/vector/generation.go new file mode 100644 index 0000000..c75ab37 --- /dev/null +++ b/vector/generation.go @@ -0,0 +1,50 @@ +package vector + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strconv" + "strings" +) + +// Generation identifies an embedding model configuration. Two pieces of +// content embedded under generations with the same Fingerprint share a +// vector space; a different fingerprint means the caller should treat the +// vectors as a new generation and re-embed. +type Generation struct { + // Model names the embedding model, e.g. "text-embedding-3-small". + Model string + // Dimensions is the length of the vectors the model emits. + Dimensions int + // Params holds any additional knobs that change the vector space, + // such as a pooling mode or prompt template. Keys are sorted before + // fingerprinting, so map iteration order never affects the result. + Params map[string]string +} + +// Fingerprint returns a stable identifier derived from every field that +// affects the vector space. Callers persist it alongside stored vectors +// and compare it to decide whether a new generation is required. +func (g Generation) Fingerprint() string { + var b strings.Builder + b.WriteString(g.Model) + b.WriteByte('\n') + b.WriteString(strconv.Itoa(g.Dimensions)) + b.WriteByte('\n') + + keys := make([]string, 0, len(g.Params)) + for k := range g.Params { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + b.WriteString(k) + b.WriteByte('=') + b.WriteString(g.Params[k]) + b.WriteByte('\n') + } + + sum := sha256.Sum256([]byte(b.String())) + return hex.EncodeToString(sum[:8]) +} diff --git a/vector/generation_test.go b/vector/generation_test.go new file mode 100644 index 0000000..0f48a8a --- /dev/null +++ b/vector/generation_test.go @@ -0,0 +1,43 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func TestGenerationFingerprintIsStableAndOrderIndependent(t *testing.T) { + assert := assert.New(t) + a := vector.Generation{ + Model: "text-embedding-3-small", + Dimensions: 1536, + Params: map[string]string{"pooling": "mean", "prompt": "search"}, + } + b := vector.Generation{ + Model: "text-embedding-3-small", + Dimensions: 1536, + Params: map[string]string{"prompt": "search", "pooling": "mean"}, + } + + assert.Equal(a.Fingerprint(), a.Fingerprint(), "same value fingerprints identically") + assert.Equal(a.Fingerprint(), b.Fingerprint(), "map order does not change fingerprint") +} + +func TestGenerationFingerprintChangesWithSpace(t *testing.T) { + assert := assert.New(t) + base := vector.Generation{Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean"}} + + cases := map[string]vector.Generation{ + "model": {Model: "other", Dimensions: 768, Params: map[string]string{"pooling": "mean"}}, + "dimensions": {Model: "m", Dimensions: 1024, Params: map[string]string{"pooling": "mean"}}, + "param value": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "cls"}}, + "extra param": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean", "prompt": "x"}}, + } + for name, g := range cases { + t.Run(name, func(t *testing.T) { + assert.NotEqual(base.Fingerprint(), g.Fingerprint()) + }) + } +} diff --git a/vector/search.go b/vector/search.go new file mode 100644 index 0000000..41c8ff2 --- /dev/null +++ b/vector/search.go @@ -0,0 +1,158 @@ +package vector + +import "sort" + +// Hit is a single search result identifying the document it belongs to. K +// is the caller's document key type (for example int64 or a UUID); this +// package compares keys for equality but never interprets them. +type Hit[K comparable] struct { + // Doc identifies the source document. + Doc K + // ChunkIndex is the chunk within Doc that matched. + ChunkIndex int + // Score is the backend's similarity score for this chunk. Merge + // overwrites it with the merged score under the chosen strategy. + Score float32 +} + +// RollupByDocument reduces chunk-level hits to one hit per document, +// keeping the highest-scoring chunk for each, and returns them sorted by +// score descending. It is the chunk->document step a caller applies to a +// single generation's results before merging across generations. +func RollupByDocument[K comparable](hits []Hit[K]) []Hit[K] { + if len(hits) == 0 { + return nil + } + best := make(map[K]Hit[K], len(hits)) + order := make([]K, 0, len(hits)) + for _, h := range hits { + cur, ok := best[h.Doc] + if !ok { + order = append(order, h.Doc) + best[h.Doc] = h + continue + } + if h.Score > cur.Score { + best[h.Doc] = h + } + } + out := make([]Hit[K], 0, len(order)) + for _, k := range order { + out = append(out, best[k]) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score }) + return out +} + +// MergeStrategy selects how Merge orders documents drawn from different +// generations, whose raw scores are not directly comparable. +type MergeStrategy int + +const ( + // MergeNormalizedScore min-max normalizes each generation's scores to + // [0,1] before ordering. It is the default: it keeps score signal + // without letting one generation's score scale dominate. + MergeNormalizedScore MergeStrategy = iota + // MergeRawScore orders by raw score. Use it only when the generations + // share a model family and comparable score distributions. + MergeRawScore + // MergeReciprocalRank ignores absolute scores and fuses by rank. Use + // it when score distributions differ sharply between generations. + MergeReciprocalRank +) + +// MergeOptions configures Merge. +type MergeOptions struct { + // Strategy selects the ordering policy. The zero value is + // MergeNormalizedScore. + Strategy MergeStrategy + // RankConstant is the k term in reciprocal-rank fusion. Values <= 0 + // use 60. + RankConstant float64 + // Limit caps the number of returned hits. Values <= 0 return all. + Limit int +} + +// Merge unions per-generation, document-level result lists into one +// ranking. The lists are given in descending preference: when a document +// appears in more than one list, the hit from the earliest list is kept, +// which is how a caller expresses "prefer the newer generation" during a +// migration. Coverage is a union, so a document present in only one +// generation is never dropped. +// +// Each surviving hit's Score is set to the merged score under the chosen +// strategy, and the result is ordered by that score descending. +func Merge[K comparable](perGeneration [][]Hit[K], o MergeOptions) []Hit[K] { + rep := make(map[K]Hit[K]) + order := make([]K, 0) + score := make(map[K]float64) + + switch o.Strategy { + case MergeReciprocalRank: + k := o.RankConstant + if k <= 0 { + k = 60 + } + for _, list := range perGeneration { + for rank, h := range list { + if _, ok := rep[h.Doc]; !ok { + rep[h.Doc] = h + order = append(order, h.Doc) + } + score[h.Doc] += 1.0 / (k + float64(rank) + 1.0) + } + } + case MergeRawScore: + for _, list := range perGeneration { + for _, h := range list { + if _, ok := rep[h.Doc]; ok { + continue + } + rep[h.Doc] = h + order = append(order, h.Doc) + score[h.Doc] = float64(h.Score) + } + } + default: // MergeNormalizedScore + for _, list := range perGeneration { + lo, hi := scoreRange(list) + span := hi - lo + for _, h := range list { + if _, ok := rep[h.Doc]; ok { + continue + } + rep[h.Doc] = h + order = append(order, h.Doc) + if span > 0 { + score[h.Doc] = float64(h.Score-lo) / float64(span) + } else { + score[h.Doc] = 1 + } + } + } + } + + out := make([]Hit[K], 0, len(order)) + for _, doc := range order { + h := rep[doc] + h.Score = float32(score[doc]) + out = append(out, h) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score }) + if o.Limit > 0 && len(out) > o.Limit { + out = out[:o.Limit] + } + return out +} + +func scoreRange[K comparable](hits []Hit[K]) (lo, hi float32) { + for i, h := range hits { + if i == 0 || h.Score < lo { + lo = h.Score + } + if i == 0 || h.Score > hi { + hi = h.Score + } + } + return lo, hi +} diff --git a/vector/search_test.go b/vector/search_test.go new file mode 100644 index 0000000..1137818 --- /dev/null +++ b/vector/search_test.go @@ -0,0 +1,109 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func docs[K comparable](hits []vector.Hit[K]) []K { + out := make([]K, len(hits)) + for i, h := range hits { + out[i] = h.Doc + } + return out +} + +func TestRollupByDocumentKeepsBestChunkPerDoc(t *testing.T) { + assert := assert.New(t) + hits := []vector.Hit[int64]{ + {Doc: 1, ChunkIndex: 0, Score: 0.2}, + {Doc: 2, ChunkIndex: 0, Score: 0.9}, + {Doc: 1, ChunkIndex: 3, Score: 0.7}, // better chunk for doc 1 + {Doc: 2, ChunkIndex: 1, Score: 0.4}, + } + + got := vector.RollupByDocument(hits) + + assert.Equal([]int64{2, 1}, docs(got), "one hit per doc, ordered by score desc") + assert.Equal(3, got[1].ChunkIndex, "doc 1 keeps its highest-scoring chunk") + assert.InDelta(0.7, got[1].Score, 1e-6) +} + +func TestMergeUnionsAndPrefersEarlierGeneration(t *testing.T) { + assert := assert.New(t) + // String keys stand in for kata's UUIDs; building generation first. + building := []vector.Hit[string]{ + {Doc: "shared", Score: 0.50}, + {Doc: "new-only", Score: 0.40}, + } + active := []vector.Hit[string]{ + {Doc: "shared", Score: 0.99}, // higher raw score, but less preferred + {Doc: "old-only", Score: 0.80}, + } + + got := vector.Merge([][]vector.Hit[string]{building, active}, vector.MergeOptions{Strategy: vector.MergeRawScore}) + + assert.ElementsMatch([]string{"shared", "new-only", "old-only"}, docs(got), "coverage is a union") + for _, h := range got { + if h.Doc == "shared" { + assert.InDelta(0.50, h.Score, 1e-6, "shared doc keeps the preferred (building) hit, not the higher raw score") + } + } +} + +func TestMergeNormalizedScoreIsDefault(t *testing.T) { + assert := assert.New(t) + // Active generation scores live in a compressed high band; building + // generation in a low band. Raw merge would let active dominate; + // normalization puts each generation's top hit at 1.0. + active := []vector.Hit[int]{ + {Doc: 1, Score: 0.90}, + {Doc: 2, Score: 0.85}, + } + building := []vector.Hit[int]{ + {Doc: 3, Score: 0.20}, + {Doc: 4, Score: 0.10}, + } + + got := vector.Merge([][]vector.Hit[int]{building, active}, vector.MergeOptions{}) + + // Each generation's best-normalized hit should reach the top band. + top := got[0] + assert.Contains([]int{1, 3}, top.Doc, "a normalized top hit leads, not just the raw-highest") + assert.InDelta(1.0, float64(top.Score), 1e-6) +} + +func TestMergeReciprocalRankFusesAcrossGenerations(t *testing.T) { + assert := assert.New(t) + // "shared" is rank 1 in one list and rank 2 in the other, so its + // fused score should beat docs that appear in only one list. + a := []vector.Hit[int]{ + {Doc: 10, Score: 0.99}, + {Doc: 99, Score: 0.98}, + } + b := []vector.Hit[int]{ + {Doc: 99, Score: 0.50}, + {Doc: 20, Score: 0.49}, + } + + got := vector.Merge([][]vector.Hit[int]{a, b}, vector.MergeOptions{Strategy: vector.MergeReciprocalRank}) + + assert.Equal(99, got[0].Doc, "the doc found in both generations ranks first") +} + +func TestMergeRespectsLimit(t *testing.T) { + assert := assert.New(t) + list := []vector.Hit[int]{{Doc: 1, Score: 0.9}, {Doc: 2, Score: 0.8}, {Doc: 3, Score: 0.7}} + + got := vector.Merge([][]vector.Hit[int]{list}, vector.MergeOptions{Strategy: vector.MergeRawScore, Limit: 2}) + + assert.Len(got, 2) + assert.Equal([]int{1, 2}, docs(got)) +} + +func TestMergeEmpty(t *testing.T) { + assert.Empty(t, vector.Merge[int](nil, vector.MergeOptions{})) +} diff --git a/vector/sqlitevec/sqlitevec.go b/vector/sqlitevec/sqlitevec.go new file mode 100644 index 0000000..86450f1 --- /dev/null +++ b/vector/sqlitevec/sqlitevec.go @@ -0,0 +1,187 @@ +// Package sqlitevec implements vector.Store on top of SQLite with the +// sqlite-vec extension. It is a reference backend: a worked example of the +// storage contract the vector flows depend on, built against the same +// sqlite-vec binding msgvault uses. +// +// Callers must register the extension once before opening their database: +// +// sqlitevec.Register() +// db, _ := sql.Open("sqlite3", path) +// +// The caller owns the documents table; this package owns a small set of +// vector tables derived from VectorsPrefix. Each generation gets its own +// vec0 virtual table sized to that generation's dimension, so generations +// with different model dimensions coexist during a migration. +package sqlitevec + +import ( + "context" + "database/sql" + "fmt" + "regexp" + + vecext "github.com/asg017/sqlite-vec-go-bindings/cgo" + + "go.kenn.io/kit/vector" +) + +// Register loads the sqlite-vec extension into every SQLite connection +// opened afterwards in this process. It must be called before opening the +// database the store will use. +func Register() { vecext.Auto() } + +var identifierPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// State is a generation's role in the active/building lifecycle. Only +// building and active generations are searched; building sorts ahead of +// active so Merge keeps the newer generation's hit on overlap. +type State string + +const ( + StatePending State = "pending" + StateBuilding State = "building" + StateActive State = "active" + StateRetired State = "retired" +) + +// Schema names the caller's documents table and the prefix for the +// vector tables this package manages. Every field must be a bare SQL +// identifier; values are validated before being interpolated into SQL. +type Schema struct { + DocsTable string // caller's documents table, e.g. "messages" + IDColumn string // primary key column, e.g. "id" + ContentColumn string // text to embed, e.g. "body" + EmbedGenColumn string // nullable generation stamp, e.g. "embed_gen" + VectorsPrefix string // prefix for managed tables, e.g. "message_vectors" +} + +func (s Schema) validate() error { + for name, value := range map[string]string{ + "docs table": s.DocsTable, + "id column": s.IDColumn, + "content column": s.ContentColumn, + "embed gen column": s.EmbedGenColumn, + "vectors prefix": s.VectorsPrefix, + } { + if !identifierPattern.MatchString(value) { + return fmt.Errorf("invalid %s %q", name, value) + } + } + return nil +} + +// Store implements vector.Store[K, G] against SQLite + sqlite-vec. K is the +// caller's document key type and G its generation key type; both must be +// types database/sql can bind and scan (for example int64 or string). +type Store[K, G comparable] struct { + db *sql.DB + schema Schema +} + +// New returns a Store bound to db. The caller retains ownership of db. New +// creates the generations and chunks bookkeeping tables if they do not +// exist; per-generation vec0 tables are created by EnsureGeneration. +func New[K, G comparable](ctx context.Context, db *sql.DB, schema Schema) (*Store[K, G], error) { + if err := schema.validate(); err != nil { + return nil, err + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + s := &Store[K, G]{db: db, schema: schema} + if _, err := db.ExecContext(ctx, fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( + ordinal INTEGER PRIMARY KEY, + gen_key UNIQUE, + dimension INTEGER NOT NULL, + state TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS %s ( + ordinal INTEGER NOT NULL, + doc_key NOT NULL, + chunk_index INTEGER NOT NULL, + vec_rowid INTEGER NOT NULL, + PRIMARY KEY (ordinal, doc_key, chunk_index) +);`, s.generationsTable(), s.chunksTable())); err != nil { + return nil, fmt.Errorf("create bookkeeping tables: %w", err) + } + return s, nil +} + +func (s *Store[K, G]) generationsTable() string { return s.schema.VectorsPrefix + "_generations" } +func (s *Store[K, G]) chunksTable() string { return s.schema.VectorsPrefix + "_chunks" } +func (s *Store[K, G]) vecTable(ordinal int64) string { + return fmt.Sprintf("%s_v%d", s.schema.VectorsPrefix, ordinal) +} + +// EnsureGeneration registers gen with model's dimension and the given +// state, creating its vec0 table on first use. Calling it again updates +// only the state; a generation's dimension is fixed once created. +func (s *Store[K, G]) EnsureGeneration(ctx context.Context, gen G, model vector.Generation, state State) error { + if model.Dimensions <= 0 { + return fmt.Errorf("generation dimension must be positive, got %d", model.Dimensions) + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin ensure generation: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if _, err := tx.ExecContext(ctx, fmt.Sprintf(` +INSERT INTO %s (gen_key, dimension, state) VALUES (?, ?, ?) +ON CONFLICT(gen_key) DO UPDATE SET state = excluded.state`, s.generationsTable()), + gen, model.Dimensions, string(state)); err != nil { + return fmt.Errorf("upsert generation: %w", err) + } + + ordinal, dimension, err := s.lookupGenerationTx(ctx, tx, gen) + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf( + `CREATE VIRTUAL TABLE IF NOT EXISTS %s USING vec0(embedding float[%d] distance_metric=cosine)`, + s.vecTable(ordinal), dimension)); err != nil { + return fmt.Errorf("create vec0 table: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit ensure generation: %w", err) + } + return nil +} + +// SetGenerationState transitions gen to state. The caller owns the +// active/building lifecycle; this only records the decision. +func (s *Store[K, G]) SetGenerationState(ctx context.Context, gen G, state State) error { + res, err := s.db.ExecContext(ctx, + fmt.Sprintf(`UPDATE %s SET state = ? WHERE gen_key = ?`, s.generationsTable()), + string(state), gen) + if err != nil { + return fmt.Errorf("set generation state: %w", err) + } + if n, _ := res.RowsAffected(); n == 0 { + return fmt.Errorf("generation %v not found", gen) + } + return nil +} + +func (s *Store[K, G]) lookupGeneration(ctx context.Context, gen G) (ordinal int64, dimension int, err error) { + return s.scanGeneration(s.db.QueryRowContext(ctx, + fmt.Sprintf(`SELECT ordinal, dimension FROM %s WHERE gen_key = ?`, s.generationsTable()), gen), gen) +} + +func (s *Store[K, G]) lookupGenerationTx(ctx context.Context, tx *sql.Tx, gen G) (int64, int, error) { + return s.scanGeneration(tx.QueryRowContext(ctx, + fmt.Sprintf(`SELECT ordinal, dimension FROM %s WHERE gen_key = ?`, s.generationsTable()), gen), gen) +} + +func (s *Store[K, G]) scanGeneration(row *sql.Row, gen G) (int64, int, error) { + var ordinal int64 + var dimension int + if err := row.Scan(&ordinal, &dimension); err != nil { + if err == sql.ErrNoRows { + return 0, 0, fmt.Errorf("generation %v not ensured", gen) + } + return 0, 0, fmt.Errorf("lookup generation %v: %w", gen, err) + } + return ordinal, dimension, nil +} diff --git a/vector/sqlitevec/sqlitevec_test.go b/vector/sqlitevec/sqlitevec_test.go new file mode 100644 index 0000000..255c9ab --- /dev/null +++ b/vector/sqlitevec/sqlitevec_test.go @@ -0,0 +1,141 @@ +package sqlitevec_test + +import ( + "context" + "database/sql" + "path/filepath" + "strings" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" + "go.kenn.io/kit/vector/sqlitevec" +) + +// topicEncoder maps text to a one-hot 3-D vector by keyword, so queries +// match documents deterministically. +func topicEncoder() vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i, text := range texts { + switch { + case strings.Contains(text, "cat"): + out[i] = []float32{1, 0, 0} + case strings.Contains(text, "dog"): + out[i] = []float32{0, 1, 0} + default: + out[i] = []float32{0, 0, 1} + } + } + return out, nil + } +} + +func setup(t *testing.T) (*sql.DB, *sqlitevec.Store[int64, int64]) { + t.Helper() + sqlitevec.Register() + + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "vec.db")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, db.Close()) }) + + _, err = db.Exec(`CREATE TABLE messages (id INTEGER PRIMARY KEY, body TEXT, embed_gen INTEGER)`) + require.NoError(t, err) + + store, err := sqlitevec.New[int64, int64](context.Background(), db, sqlitevec.Schema{ + DocsTable: "messages", + IDColumn: "id", + ContentColumn: "body", + EmbedGenColumn: "embed_gen", + VectorsPrefix: "message_vectors", + }) + require.NoError(t, err) + return db, store +} + +func TestStoreFillThenSearch(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat sat'), (2, 'a dog ran')`) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + stats, err := vector.Fill(ctx, store, 1, topicEncoder(), vector.FillOptions{}) + require.NoError(err) + assert.Equal(2, stats.Documents) + + pending, err := store.PendingForGeneration(ctx, 1, 10) + require.NoError(err) + assert.Empty(pending, "nothing pending once every document is stamped") + + enc := func(int64) vector.EncodeFunc { return topicEncoder() } + hits, err := vector.Search(ctx, store, "a cat", enc, vector.SearchOptions{}) + require.NoError(err) + require.NotEmpty(hits) + assert.Equal(int64(1), hits[0].Doc, "the cat query ranks the cat document first") +} + +func TestStoreReembeddingReplacesVectors(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat sat')`) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + require.NoError(store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}})) + require.NoError(store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{0, 1, 0}}})) + + hits, err := store.QueryGeneration(ctx, 1, vector.Vector{0, 1, 0}, 10) + require.NoError(err) + require.Len(hits, 1, "re-embedding replaces the prior vector rather than duplicating it") + assert.InDelta(1.0, hits[0].Score, 1e-6, "stored vector now matches the new query") +} + +func TestStoreSearchUnionsLiveGenerations(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat'), (2, 'a dog')`) + require.NoError(err) + + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "v1", Dimensions: 3}, sqlitevec.StateActive)) + _, err = vector.Fill(ctx, store, 1, topicEncoder(), vector.FillOptions{}) + require.NoError(err) + + // The building generation has covered only doc 1 so far. + require.NoError(store.EnsureGeneration(ctx, 2, vector.Generation{Model: "v2", Dimensions: 3}, sqlitevec.StateBuilding)) + require.NoError(store.SaveVectors(ctx, 2, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}})) + + gens, err := store.LiveGenerations(ctx) + require.NoError(err) + assert.Equal([]int64{2, 1}, gens, "building precedes active in preference order") + + enc := func(int64) vector.EncodeFunc { return topicEncoder() } + hits, err := vector.Search(ctx, store, "a cat", enc, vector.SearchOptions{}) + require.NoError(err) + + found := map[int64]bool{} + for _, h := range hits { + found[h.Doc] = true + } + assert.True(found[1], "shared doc is searchable") + assert.True(found[2], "active-only doc is not dropped mid-migration (union coverage)") +} + +func TestNewRejectsUnsafeIdentifiers(t *testing.T) { + _, err := sqlitevec.New[int64, int64](context.Background(), nil, sqlitevec.Schema{ + DocsTable: "messages; DROP TABLE messages", + }) + require.Error(t, err) +} diff --git a/vector/sqlitevec/store.go b/vector/sqlitevec/store.go new file mode 100644 index 0000000..aad72c5 --- /dev/null +++ b/vector/sqlitevec/store.go @@ -0,0 +1,196 @@ +package sqlitevec + +import ( + "context" + "database/sql" + "fmt" + + vecext "github.com/asg017/sqlite-vec-go-bindings/cgo" + + "go.kenn.io/kit/vector" +) + +// PendingForGeneration scans the caller's documents table for rows whose +// stamp does not yet match gen, ordered by primary key for stable paging. +func (s *Store[K, G]) PendingForGeneration(ctx context.Context, gen G, limit int) ([]vector.Pending[K], error) { + query := fmt.Sprintf( + `SELECT %s, %s FROM %s WHERE %s IS NULL OR %s <> ? ORDER BY %s LIMIT ?`, + s.schema.IDColumn, s.schema.ContentColumn, s.schema.DocsTable, + s.schema.EmbedGenColumn, s.schema.EmbedGenColumn, s.schema.IDColumn) + rows, err := s.db.QueryContext(ctx, query, gen, limit) + if err != nil { + return nil, fmt.Errorf("scan pending: %w", err) + } + defer func() { _ = rows.Close() }() + + var pending []vector.Pending[K] + for rows.Next() { + var p vector.Pending[K] + if err := rows.Scan(&p.Doc, &p.Content); err != nil { + return nil, fmt.Errorf("scan pending row: %w", err) + } + pending = append(pending, p) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan pending rows: %w", err) + } + return pending, nil +} + +// SaveVectors replaces doc's chunk vectors for gen and stamps the document +// as embedded for gen, all in one transaction. +func (s *Store[K, G]) SaveVectors(ctx context.Context, gen G, doc K, vectors []vector.ChunkVector) error { + ordinal, dimension, err := s.lookupGeneration(ctx, gen) + if err != nil { + return err + } + for _, cv := range vectors { + if len(cv.Vector) != dimension { + return fmt.Errorf("chunk %d has %d dimensions, generation expects %d", cv.ChunkIndex, len(cv.Vector), dimension) + } + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin save vectors: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Drop any prior vectors for this document so re-embedding is clean. + rowids, err := s.docRowids(ctx, tx, ordinal, doc) + if err != nil { + return err + } + for _, rowid := range rowids { + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE rowid = ?`, s.vecTable(ordinal)), rowid); err != nil { + return fmt.Errorf("delete stale vector: %w", err) + } + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE ordinal = ? AND doc_key = ?`, s.chunksTable()), ordinal, doc); err != nil { + return fmt.Errorf("delete stale chunk map: %w", err) + } + + for _, cv := range vectors { + blob, err := vecext.SerializeFloat32(cv.Vector) + if err != nil { + return fmt.Errorf("serialize vector: %w", err) + } + res, err := tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO %s (embedding) VALUES (?)`, s.vecTable(ordinal)), blob) + if err != nil { + return fmt.Errorf("insert vector: %w", err) + } + rowid, err := res.LastInsertId() + if err != nil { + return fmt.Errorf("vector rowid: %w", err) + } + if _, err := tx.ExecContext(ctx, + fmt.Sprintf(`INSERT INTO %s (ordinal, doc_key, chunk_index, vec_rowid) VALUES (?, ?, ?, ?)`, s.chunksTable()), + ordinal, doc, cv.ChunkIndex, rowid); err != nil { + return fmt.Errorf("insert chunk map: %w", err) + } + } + + if _, err := tx.ExecContext(ctx, + fmt.Sprintf(`UPDATE %s SET %s = ? WHERE %s = ?`, s.schema.DocsTable, s.schema.EmbedGenColumn, s.schema.IDColumn), + gen, doc); err != nil { + return fmt.Errorf("stamp embed generation: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit save vectors: %w", err) + } + return nil +} + +func (s *Store[K, G]) docRowids(ctx context.Context, tx txQuerier, ordinal int64, doc K) ([]int64, error) { + rows, err := tx.QueryContext(ctx, + fmt.Sprintf(`SELECT vec_rowid FROM %s WHERE ordinal = ? AND doc_key = ?`, s.chunksTable()), ordinal, doc) + if err != nil { + return nil, fmt.Errorf("read chunk map: %w", err) + } + defer func() { _ = rows.Close() }() + + var rowids []int64 + for rows.Next() { + var rowid int64 + if err := rows.Scan(&rowid); err != nil { + return nil, fmt.Errorf("scan chunk rowid: %w", err) + } + rowids = append(rowids, rowid) + } + return rowids, rows.Err() +} + +// LiveGenerations returns building and active generations, building first, +// so Merge prefers the newer generation when a document is in both. +func (s *Store[K, G]) LiveGenerations(ctx context.Context) ([]G, error) { + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` +SELECT gen_key FROM %s + WHERE state IN (?, ?) + ORDER BY CASE state WHEN ? THEN 0 ELSE 1 END, ordinal`, s.generationsTable()), + string(StateBuilding), string(StateActive), string(StateBuilding)) + if err != nil { + return nil, fmt.Errorf("list live generations: %w", err) + } + defer func() { _ = rows.Close() }() + + var gens []G + for rows.Next() { + var gen G + if err := rows.Scan(&gen); err != nil { + return nil, fmt.Errorf("scan generation key: %w", err) + } + gens = append(gens, gen) + } + return gens, rows.Err() +} + +// QueryGeneration runs a cosine KNN search within gen's vec0 table and +// maps each neighbor back to its document and chunk. Score is the cosine +// similarity (1 - cosine distance), so higher is more similar. +func (s *Store[K, G]) QueryGeneration(ctx context.Context, gen G, query vector.Vector, limit int) ([]vector.Hit[K], error) { + ordinal, dimension, err := s.lookupGeneration(ctx, gen) + if err != nil { + return nil, err + } + if len(query) != dimension { + return nil, fmt.Errorf("query has %d dimensions, generation expects %d", len(query), dimension) + } + blob, err := vecext.SerializeFloat32(query) + if err != nil { + return nil, fmt.Errorf("serialize query: %w", err) + } + + // The KNN runs against the vec0 table alone (its required form), then + // joins to the chunk map to recover document keys. + sqlText := fmt.Sprintf(` +WITH knn AS ( + SELECT rowid, distance FROM %s WHERE embedding MATCH ? ORDER BY distance LIMIT ? +) +SELECT c.doc_key, c.chunk_index, knn.distance + FROM knn JOIN %s c ON c.ordinal = ? AND c.vec_rowid = knn.rowid + ORDER BY knn.distance`, s.vecTable(ordinal), s.chunksTable()) + rows, err := s.db.QueryContext(ctx, sqlText, blob, limit, ordinal) + if err != nil { + return nil, fmt.Errorf("query generation: %w", err) + } + defer func() { _ = rows.Close() }() + + var hits []vector.Hit[K] + for rows.Next() { + var ( + doc K + chunkIndex int + distance float64 + ) + if err := rows.Scan(&doc, &chunkIndex, &distance); err != nil { + return nil, fmt.Errorf("scan hit: %w", err) + } + hits = append(hits, vector.Hit[K]{Doc: doc, ChunkIndex: chunkIndex, Score: float32(1 - distance)}) + } + return hits, rows.Err() +} + +// txQuerier is the read surface shared by *sql.DB and *sql.Tx. +type txQuerier interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} diff --git a/vector/sqlitevec/vec0_smoke_test.go b/vector/sqlitevec/vec0_smoke_test.go new file mode 100644 index 0000000..0550cee --- /dev/null +++ b/vector/sqlitevec/vec0_smoke_test.go @@ -0,0 +1,42 @@ +package sqlitevec + +import ( + "database/sql" + "testing" + + sqlitevec "github.com/asg017/sqlite-vec-go-bindings/cgo" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" +) + +// TestVec0LoadsHermetically confirms the sqlite-vec extension is compiled +// in via the cgo binding, so the backend's tests need no external setup. +func TestVec0LoadsHermetically(t *testing.T) { + require := require.New(t) + sqlitevec.Auto() + + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(err) + t.Cleanup(func() { require.NoError(db.Close()) }) + + _, err = db.Exec(`CREATE VIRTUAL TABLE v USING vec0(embedding float[3])`) + require.NoError(err) + + vec, err := sqlitevec.SerializeFloat32([]float32{1, 2, 3}) + require.NoError(err) + _, err = db.Exec(`INSERT INTO v(rowid, embedding) VALUES (1, ?)`, vec) + require.NoError(err) + + query, err := sqlitevec.SerializeFloat32([]float32{1, 2, 3}) + require.NoError(err) + var rowid int64 + var distance float64 + err = db.QueryRow( + `SELECT rowid, distance FROM v WHERE embedding MATCH ? ORDER BY distance LIMIT 1`, + query, + ).Scan(&rowid, &distance) + require.NoError(err) + + require.Equal(int64(1), rowid) + require.InDelta(0, distance, 1e-6, "identical vectors have zero distance") +} diff --git a/vector/store.go b/vector/store.go new file mode 100644 index 0000000..28facc9 --- /dev/null +++ b/vector/store.go @@ -0,0 +1,47 @@ +package vector + +import "context" + +// Pending is one document that still needs embedding for a generation, +// paired with the text to embed. +type Pending[K comparable] struct { + Doc K + Content string +} + +// ChunkVector is a single chunk's embedding, ready to persist. +type ChunkVector struct { + ChunkIndex int + Vector Vector +} + +// Store is the persistence contract the Fill and Search flows depend on. +// Implementations are a function of the caller's source system — a SQLite, +// pgvector, or DuckDB table — and own all backend SQL and query +// construction. The flows never open a database or build SQL themselves. +// +// K is the caller's document key type and G its generation id type; the +// package compares both for equality but never interprets them. +type Store[K, G comparable] interface { + // PendingForGeneration returns up to limit documents that are not yet + // embedded for gen, in a stable order. Implementations typically scan + // for "embed_gen IS NULL OR embed_gen <> gen". A document must stop + // being reported once SaveVectors has persisted it for gen, so that a + // fill loop terminates. + PendingForGeneration(ctx context.Context, gen G, limit int) ([]Pending[K], error) + + // SaveVectors persists every chunk vector for doc under gen and marks + // doc as embedded for gen (the scan-and-fill stamp). + SaveVectors(ctx context.Context, gen G, doc K, vectors []ChunkVector) error + + // LiveGenerations returns the generations a search should query, in + // descending preference. During a migration the building generation + // precedes the active one, so Merge keeps the newer generation's hit + // when a document appears in both. + LiveGenerations(ctx context.Context) ([]G, error) + + // QueryGeneration returns chunk-level hits for query within gen, + // ranked best first and capped at limit. This is where each backend's + // vector query construction lives. + QueryGeneration(ctx context.Context, gen G, query Vector, limit int) ([]Hit[K], error) +} From 792a47d3faeb6335c436d07831f9d28f0a25234e Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Thu, 25 Jun 2026 09:49:52 -0400 Subject: [PATCH 3/4] Harden generation fingerprint and reject orphan vectors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two medium findings from the roborev panel on 061a0a3. Generation.Fingerprint built its preimage by joining params as key=value lines, so a value containing the separator could hash identically to a different param set — distinct vector spaces could share a fingerprint and silently skip a needed re-embed. It now hashes the JSON encoding, which escapes values and sorts map keys, so the preimage is unambiguous. SaveVectors stamped the source document without checking the update hit a row, so if a document was deleted between scan and save (or a caller passed a missing key) the transaction committed vector rows with no backing document, which QueryGeneration would later return as orphan hits. It now checks RowsAffected and rolls back when nothing was stamped. Validation: go vet/test ./vector/... with CGO. Generated with Claude Code (Opus 4.8) Co-Authored-By: Claude Opus 4.8 --- vector/generation.go | 42 +++++++++++------------------- vector/generation_test.go | 11 ++++++++ vector/sqlitevec/sqlitevec_test.go | 16 ++++++++++++ vector/sqlitevec/store.go | 15 +++++++++-- 4 files changed, 55 insertions(+), 29 deletions(-) diff --git a/vector/generation.go b/vector/generation.go index c75ab37..5cc6508 100644 --- a/vector/generation.go +++ b/vector/generation.go @@ -3,9 +3,7 @@ package vector import ( "crypto/sha256" "encoding/hex" - "sort" - "strconv" - "strings" + "encoding/json" ) // Generation identifies an embedding model configuration. Two pieces of @@ -14,37 +12,27 @@ import ( // vectors as a new generation and re-embed. type Generation struct { // Model names the embedding model, e.g. "text-embedding-3-small". - Model string + Model string `json:"model,omitzero"` // Dimensions is the length of the vectors the model emits. - Dimensions int + Dimensions int `json:"dimensions,omitzero"` // Params holds any additional knobs that change the vector space, - // such as a pooling mode or prompt template. Keys are sorted before - // fingerprinting, so map iteration order never affects the result. - Params map[string]string + // such as a pooling mode or prompt template. + Params map[string]string `json:"params,omitzero"` } // Fingerprint returns a stable identifier derived from every field that // affects the vector space. Callers persist it alongside stored vectors // and compare it to decide whether a new generation is required. +// +// It hashes the JSON encoding rather than a hand-built string so that +// values are escaped and delimited unambiguously: a param value that +// itself contains the separator can never collide with two distinct +// params. func (g Generation) Fingerprint() string { - var b strings.Builder - b.WriteString(g.Model) - b.WriteByte('\n') - b.WriteString(strconv.Itoa(g.Dimensions)) - b.WriteByte('\n') - - keys := make([]string, 0, len(g.Params)) - for k := range g.Params { - keys = append(keys, k) - } - sort.Strings(keys) - for _, k := range keys { - b.WriteString(k) - b.WriteByte('=') - b.WriteString(g.Params[k]) - b.WriteByte('\n') - } - - sum := sha256.Sum256([]byte(b.String())) + // Generation holds only strings, an int, and a string map, all of + // which encoding/json can always marshal — and it sorts map keys, so + // the encoding is canonical. The error is structurally unreachable. + data, _ := json.Marshal(g) + sum := sha256.Sum256(data) return hex.EncodeToString(sum[:8]) } diff --git a/vector/generation_test.go b/vector/generation_test.go index 0f48a8a..189111b 100644 --- a/vector/generation_test.go +++ b/vector/generation_test.go @@ -25,6 +25,17 @@ func TestGenerationFingerprintIsStableAndOrderIndependent(t *testing.T) { assert.Equal(a.Fingerprint(), b.Fingerprint(), "map order does not change fingerprint") } +func TestGenerationFingerprintIsNotAmbiguousAcrossParams(t *testing.T) { + assert := assert.New(t) + // Two params vs a single param whose value embeds what used to be the + // key/value separator. A naive "key=value\n" join hashes both the + // same; the JSON encoding keeps them distinct. + two := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"pooling": "mean", "prompt": "x"}} + one := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"pooling": "mean\nprompt=x"}} + + assert.NotEqual(two.Fingerprint(), one.Fingerprint()) +} + func TestGenerationFingerprintChangesWithSpace(t *testing.T) { assert := assert.New(t) base := vector.Generation{Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean"}} diff --git a/vector/sqlitevec/sqlitevec_test.go b/vector/sqlitevec/sqlitevec_test.go index 255c9ab..7691329 100644 --- a/vector/sqlitevec/sqlitevec_test.go +++ b/vector/sqlitevec/sqlitevec_test.go @@ -133,6 +133,22 @@ func TestStoreSearchUnionsLiveGenerations(t *testing.T) { assert.True(found[2], "active-only doc is not dropped mid-migration (union coverage)") } +func TestStoreSaveVectorsRejectsMissingDocument(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + _, store := setup(t) + + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + err := store.SaveVectors(ctx, 1, 999, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}}) + require.Error(err, "saving vectors for a document not in the source table fails") + + hits, err := store.QueryGeneration(ctx, 1, vector.Vector{1, 0, 0}, 10) + require.NoError(err) + assert.Empty(hits, "no orphan vectors are committed when the source row is missing") +} + func TestNewRejectsUnsafeIdentifiers(t *testing.T) { _, err := sqlitevec.New[int64, int64](context.Background(), nil, sqlitevec.Schema{ DocsTable: "messages; DROP TABLE messages", diff --git a/vector/sqlitevec/store.go b/vector/sqlitevec/store.go index aad72c5..69ffaf9 100644 --- a/vector/sqlitevec/store.go +++ b/vector/sqlitevec/store.go @@ -90,11 +90,22 @@ func (s *Store[K, G]) SaveVectors(ctx context.Context, gen G, doc K, vectors []v } } - if _, err := tx.ExecContext(ctx, + res, err := tx.ExecContext(ctx, fmt.Sprintf(`UPDATE %s SET %s = ? WHERE %s = ?`, s.schema.DocsTable, s.schema.EmbedGenColumn, s.schema.IDColumn), - gen, doc); err != nil { + gen, doc) + if err != nil { return fmt.Errorf("stamp embed generation: %w", err) } + stamped, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("stamp embed generation rows: %w", err) + } + if stamped == 0 { + // The source row vanished between scan and save (or the key is + // wrong). Roll back rather than commit vectors with no document, + // which QueryGeneration would otherwise surface as orphan hits. + return fmt.Errorf("document %v not present in %s; vectors not persisted", doc, s.schema.DocsTable) + } if err := tx.Commit(); err != nil { return fmt.Errorf("commit save vectors: %w", err) } From 2a9fde1b1832c57aa7bca5a70ad95d119ed86a24 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Thu, 25 Jun 2026 14:10:11 -0400 Subject: [PATCH 4/4] Make generation fingerprint robust to future type changes The fingerprint is persisted by callers and compared to decide whether to re-embed, so its stability has to survive future edits to Generation. The prior version hand-listed the fields to hash, which has a dangerous failure mode: a field added later would be silently excluded, letting two distinct vector spaces share a fingerprint and skip a needed re-embed. Fingerprint now encodes the struct itself, so any field added later participates automatically, then re-encodes through a generic value. encoding/json sorts object keys at every level, so neither struct field order nor map insertion order affects the hash; UseNumber preserves numeric precision; and omitempty keeps an unused new field from shifting existing fingerprints. A pinned canonical-encoding test and a reflection tripwire on the field set make any change to the type or encoding fail CI rather than silently re-fingerprint every stored vector. Validation: go vet/test ./vector/... with CGO. Generated with Claude Code (Opus 4.8) Co-Authored-By: Claude Opus 4.8 --- vector/generation.go | 44 ++++++++++++++++++++++++++++----------- vector/generation_test.go | 43 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/vector/generation.go b/vector/generation.go index 5cc6508..50c896e 100644 --- a/vector/generation.go +++ b/vector/generation.go @@ -1,6 +1,7 @@ package vector import ( + "bytes" "crypto/sha256" "encoding/hex" "encoding/json" @@ -10,29 +11,48 @@ import ( // content embedded under generations with the same Fingerprint share a // vector space; a different fingerprint means the caller should treat the // vectors as a new generation and re-embed. +// +// Every field that affects the vector space must be exported and JSON +// encodable so Fingerprint accounts for it automatically. A field that +// must not affect identity has to be tagged json:"-". type Generation struct { // Model names the embedding model, e.g. "text-embedding-3-small". - Model string `json:"model,omitzero"` + Model string `json:"model,omitempty"` // Dimensions is the length of the vectors the model emits. - Dimensions int `json:"dimensions,omitzero"` + Dimensions int `json:"dimensions,omitempty"` // Params holds any additional knobs that change the vector space, // such as a pooling mode or prompt template. - Params map[string]string `json:"params,omitzero"` + Params map[string]string `json:"params,omitempty"` } // Fingerprint returns a stable identifier derived from every field that // affects the vector space. Callers persist it alongside stored vectors // and compare it to decide whether a new generation is required. // -// It hashes the JSON encoding rather than a hand-built string so that -// values are escaped and delimited unambiguously: a param value that -// itself contains the separator can never collide with two distinct -// params. +// It is built to be stable across future changes to this type: +// +// - It encodes the struct itself, so a field added later participates +// automatically rather than being silently excluded — the failure +// mode that would let two distinct vector spaces share a fingerprint. +// - It then re-encodes through a generic value, and encoding/json sorts +// object keys at every level, so neither struct field order nor map +// insertion order affects the hash. +// - Decoding with UseNumber preserves numeric tokens exactly, so no +// field loses precision through float64. +// - omitempty drops zero-valued fields, so adding an unused field never +// shifts an existing generation's fingerprint. +// +// All values are JSON encodable, so the marshal and decode errors are +// unreachable. func (g Generation) Fingerprint() string { - // Generation holds only strings, an int, and a string map, all of - // which encoding/json can always marshal — and it sorts map keys, so - // the encoding is canonical. The error is structurally unreachable. - data, _ := json.Marshal(g) - sum := sha256.Sum256(data) + raw, _ := json.Marshal(g) + + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.UseNumber() + var generic any + _ = dec.Decode(&generic) + + canonical, _ := json.Marshal(generic) + sum := sha256.Sum256(canonical) return hex.EncodeToString(sum[:8]) } diff --git a/vector/generation_test.go b/vector/generation_test.go index 189111b..75e71ac 100644 --- a/vector/generation_test.go +++ b/vector/generation_test.go @@ -1,6 +1,10 @@ package vector_test import ( + "crypto/sha256" + "encoding/hex" + "reflect" + "sort" "testing" "github.com/stretchr/testify/assert" @@ -41,8 +45,8 @@ func TestGenerationFingerprintChangesWithSpace(t *testing.T) { base := vector.Generation{Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean"}} cases := map[string]vector.Generation{ - "model": {Model: "other", Dimensions: 768, Params: map[string]string{"pooling": "mean"}}, - "dimensions": {Model: "m", Dimensions: 1024, Params: map[string]string{"pooling": "mean"}}, + "model": {Model: "other", Dimensions: 768, Params: map[string]string{"pooling": "mean"}}, + "dimensions": {Model: "m", Dimensions: 1024, Params: map[string]string{"pooling": "mean"}}, "param value": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "cls"}}, "extra param": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean", "prompt": "x"}}, } @@ -52,3 +56,38 @@ func TestGenerationFingerprintChangesWithSpace(t *testing.T) { }) } } + +// TestGenerationFingerprintPinsCanonicalEncoding locks the exact hash +// preimage. If the canonical form ever changes — sorting, omit behavior, +// number formatting, or a field added to the struct's encoding — this +// fails, forcing a conscious decision rather than a silent shift of every +// persisted fingerprint. +func TestGenerationFingerprintPinsCanonicalEncoding(t *testing.T) { + g := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"b": "2", "a": "1"}} + + // Keys sorted at every level, zero fields omitted, numbers verbatim. + const canonical = `{"dimensions":3,"model":"m","params":{"a":"1","b":"2"}}` + sum := sha256.Sum256([]byte(canonical)) + want := hex.EncodeToString(sum[:8]) + + assert.Equal(t, want, g.Fingerprint()) +} + +// TestGenerationFieldsAreTracked is a tripwire: adding, removing, or +// renaming a Generation field changes this set. When it fails, decide +// whether the new field affects the vector space. If it does, Fingerprint +// already includes it (it encodes the whole struct); if it must not, tag +// the field json:"-". Then update this expectation and the pinned +// encoding above. +func TestGenerationFieldsAreTracked(t *testing.T) { + want := []string{"Dimensions", "Model", "Params"} + + fields := reflect.VisibleFields(reflect.TypeFor[vector.Generation]()) + got := make([]string, 0, len(fields)) + for _, f := range fields { + got = append(got, f.Name) + } + sort.Strings(got) + + assert.Equal(t, want, got, "Generation fields changed: review fingerprint impact before updating this tripwire") +}