diff --git a/go.mod b/go.mod index e9497ee..2cfdf19 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module go.kenn.io/kit go 1.26.3 require ( + github.com/asg017/sqlite-vec-go-bindings v0.1.6 github.com/gofrs/flock v0.13.0 + github.com/mattn/go-sqlite3 v1.14.44 github.com/posthog/posthog-go v1.12.6 github.com/stretchr/testify v1.11.1 go.opentelemetry.io/otel v1.43.0 diff --git a/go.sum b/go.sum index f82d09d..3d118e8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= +github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -38,6 +40,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= +github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posthog/posthog-go v1.12.6 h1:N+FrKWY6DOuDhV2OMgvtKAKDYGTdtS9/nuvr0BTyBp0= diff --git a/vector/AGENTS.md b/vector/AGENTS.md new file mode 100644 index 0000000..6c86733 --- /dev/null +++ b/vector/AGENTS.md @@ -0,0 +1,41 @@ +# vector package invariants + +`go.kenn.io/kit/vector` owns the backend-neutral parts of an embedding +pipeline. Preserve these invariants when changing it. + +## The storage boundary is the point of this package + +- The core `vector` package must not import `database/sql`, a driver, or + any backend client, and must not construct backend SQL. The `Fill` and + `Search` flows reach storage only through the `Store[K, G]` interface. +- Persistence is a function of the caller's source system. Backends live + in their own subpackages (e.g. `vector/sqlitevec`) so a caller wiring + one backend never pulls another backend's driver. New backends + (pgvector, duckdb) go in sibling subpackages, not into the core. +- Backends own query construction. The differences between sqlite-vec + `vec0 MATCH`, pgvector `<=>`, and duckdb `array_distance` belong behind + `QueryGeneration`, never in the core flows. + +## Keys and generations are opaque + +- Document identity is the caller's type `K` and generation identity its + type `G`. msgvault uses `int64`; kata uses UUIDs. Compare them for + equality only; never assume a type, a single id namespace, or an + ordering. Backends additionally require `K`/`G` to be types + `database/sql` can bind and scan. + +## Merge semantics + +- `Merge` takes per-generation lists in descending preference and keeps + the earliest list's hit on overlap (prefer the newer generation during + a migration). Coverage is a union — never drop a document that only one + generation covers, and never emit duplicates. +- Cross-generation scores are not comparable. Default to + `MergeNormalizedScore`; raw-score merging is opt-in. + +## Generations during migration + +- The mid-migration union exists because new documents land only in the + building generation while the active generation still serves the bulk. + `Search` must keep querying every generation `LiveGenerations` returns, + in the order it returns them. diff --git a/vector/chunk.go b/vector/chunk.go new file mode 100644 index 0000000..78239dd --- /dev/null +++ b/vector/chunk.go @@ -0,0 +1,48 @@ +package vector + +// Chunk is a window of text encoded as a single vector. Index is the +// chunk's position within the source content, starting at zero. +type Chunk struct { + Index int + Text string +} + +// SplitOptions controls how Split windows content into chunks. +type SplitOptions struct { + // MaxRunes bounds the number of runes in each chunk. Values <= 0 + // disable splitting and return the content as a single chunk. + MaxRunes int + // Overlap is the number of runes shared between consecutive chunks. + // It is clamped to the range [0, MaxRunes-1]. + Overlap int +} + +// Split windows content into overlapping chunks of at most MaxRunes runes. +// It splits on runes rather than bytes so multi-byte characters are never +// torn apart. Empty content yields no chunks. +// +// Split measures size in runes, not model tokens. Callers that budget by +// tokens should convert their token budget to an approximate rune count. +func Split(content string, o SplitOptions) []Chunk { + if content == "" { + return nil + } + runes := []rune(content) + if o.MaxRunes <= 0 || len(runes) <= o.MaxRunes { + return []Chunk{{Index: 0, Text: content}} + } + + overlap := min(max(o.Overlap, 0), o.MaxRunes-1) + stride := o.MaxRunes - overlap + + var chunks []Chunk + for start, idx := 0, 0; start < len(runes); start, idx = start+stride, idx+1 { + end := start + o.MaxRunes + if end >= len(runes) { + chunks = append(chunks, Chunk{Index: idx, Text: string(runes[start:])}) + break + } + chunks = append(chunks, Chunk{Index: idx, Text: string(runes[start:end])}) + } + return chunks +} diff --git a/vector/chunk_test.go b/vector/chunk_test.go new file mode 100644 index 0000000..e2af4f9 --- /dev/null +++ b/vector/chunk_test.go @@ -0,0 +1,84 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func TestSplit(t *testing.T) { + tests := []struct { + name string + content string + opts vector.SplitOptions + want []vector.Chunk + }{ + { + name: "empty yields no chunks", + content: "", + opts: vector.SplitOptions{MaxRunes: 4}, + want: nil, + }, + { + name: "non-positive max returns single chunk", + content: "hello world", + opts: vector.SplitOptions{MaxRunes: 0}, + want: []vector.Chunk{{Index: 0, Text: "hello world"}}, + }, + { + name: "content shorter than max is one chunk", + content: "abcd", + opts: vector.SplitOptions{MaxRunes: 8}, + want: []vector.Chunk{{Index: 0, Text: "abcd"}}, + }, + { + name: "windows without overlap", + content: "abcdefghij", + opts: vector.SplitOptions{MaxRunes: 5}, + want: []vector.Chunk{ + {Index: 0, Text: "abcde"}, + {Index: 1, Text: "fghij"}, + }, + }, + { + name: "windows with overlap", + content: "abcdefghij", + opts: vector.SplitOptions{MaxRunes: 4, Overlap: 1}, + want: []vector.Chunk{ + {Index: 0, Text: "abcd"}, + {Index: 1, Text: "defg"}, + {Index: 2, Text: "ghij"}, + }, + }, + { + name: "overlap at or above max clamps to max-1", + content: "abcdef", + opts: vector.SplitOptions{MaxRunes: 3, Overlap: 9}, + want: []vector.Chunk{ + {Index: 0, Text: "abc"}, + {Index: 1, Text: "bcd"}, + {Index: 2, Text: "cde"}, + {Index: 3, Text: "def"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, vector.Split(tt.content, tt.opts)) + }) + } +} + +func TestSplitDoesNotTearMultiByteRunes(t *testing.T) { + assert := assert.New(t) + // Each emoji is multiple bytes but one rune. + chunks := vector.Split("😀😁😂🤣", vector.SplitOptions{MaxRunes: 2}) + + assert.Equal([]vector.Chunk{ + {Index: 0, Text: "😀😁"}, + {Index: 1, Text: "😂🤣"}, + }, chunks) +} diff --git a/vector/doc.go b/vector/doc.go new file mode 100644 index 0000000..6d96ae3 --- /dev/null +++ b/vector/doc.go @@ -0,0 +1,24 @@ +// Package vector provides backend-neutral building blocks for embedding +// content and searching the resulting vectors. +// +// It is organized in three layers: +// +// - Transforms and value types: Split windows content into chunks, +// Generation identifies an embedding model, EncodeBatched batches +// encode calls, and RollupByDocument and Merge reduce and combine +// search results across generations. These are pure functions. +// +// - The Store contract: Store[K, G] is the persistence interface the +// flows depend on. Implementations are a function of the caller's +// source system and own all backend SQL and query construction; see +// the sqlitevec subpackage for a worked example. +// +// - Flows: Fill runs the scan-and-fill embedding loop and Search runs +// the cross-generation query-and-merge, both over a Store. +// +// Nothing in this package opens a database, holds an index, or constructs +// backend SQL — the flows delegate every storage operation to the Store. +// Document identity is the caller's own key type K, and generation +// identity its type G; the package compares both for equality but never +// interprets them. +package vector diff --git a/vector/encode.go b/vector/encode.go new file mode 100644 index 0000000..59458bf --- /dev/null +++ b/vector/encode.go @@ -0,0 +1,111 @@ +package vector + +import ( + "context" + "fmt" + "sync" +) + +// Vector is a single embedding. +type Vector []float32 + +// EncodeFunc turns a batch of texts into one vector each, in the same +// order. Implementations own the model or API client and any retry or +// backoff policy, since retryability is provider-specific. +type EncodeFunc func(ctx context.Context, texts []string) ([][]float32, error) + +// BatchOptions controls how EncodeBatched groups and parallelizes calls. +type BatchOptions struct { + // BatchSize is the maximum number of chunks passed to EncodeFunc in a + // single call. Values <= 0 send every chunk in one call. + BatchSize int + // Concurrency bounds how many EncodeFunc calls run at once. Values + // <= 0 mean one call at a time. + Concurrency int +} + +// EncodeBatched splits chunks into batches, invokes enc with bounded +// concurrency, and returns one Vector per input chunk in input order. It +// stops launching work at the first error or when ctx is cancelled, and +// reports the first error encountered. +func EncodeBatched(ctx context.Context, enc EncodeFunc, chunks []Chunk, o BatchOptions) ([]Vector, error) { + if enc == nil { + return nil, fmt.Errorf("encode func is nil") + } + if len(chunks) == 0 { + return nil, nil + } + + batchSize := o.BatchSize + if batchSize <= 0 { + batchSize = len(chunks) + } + concurrency := o.Concurrency + if concurrency <= 0 { + concurrency = 1 + } + + out := make([]Vector, len(chunks)) + sem := make(chan struct{}, concurrency) + var ( + wg sync.WaitGroup + mu sync.Mutex + firstErr error + ) + failed := func() bool { + mu.Lock() + defer mu.Unlock() + return firstErr != nil + } + setErr := func(err error) { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + + for start := 0; start < len(chunks); start += batchSize { + if ctx.Err() != nil { + setErr(ctx.Err()) + break + } + if failed() { + break + } + + end := min(start+batchSize, len(chunks)) + texts := make([]string, end-start) + for i, c := range chunks[start:end] { + texts[i] = c.Text + } + + sem <- struct{}{} + wg.Add(1) + go func(start int, texts []string) { + defer wg.Done() + defer func() { <-sem }() + + vecs, err := enc(ctx, texts) + if err != nil { + setErr(fmt.Errorf("encode batch at %d: %w", start, err)) + return + } + if len(vecs) != len(texts) { + setErr(fmt.Errorf("encode batch at %d: got %d vectors for %d texts", start, len(vecs), len(texts))) + return + } + // Each batch owns a disjoint index range, so writes to out + // never overlap across goroutines. + for i, v := range vecs { + out[start+i] = Vector(v) + } + }(start, texts) + } + + wg.Wait() + if firstErr != nil { + return nil, firstErr + } + return out, nil +} diff --git a/vector/encode_test.go b/vector/encode_test.go new file mode 100644 index 0000000..f783493 --- /dev/null +++ b/vector/encode_test.go @@ -0,0 +1,124 @@ +package vector_test + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" +) + +func chunks(texts ...string) []vector.Chunk { + out := make([]vector.Chunk, len(texts)) + for i, txt := range texts { + out[i] = vector.Chunk{Index: i, Text: txt} + } + return out +} + +// echoEncoder returns one vector per text whose single component encodes +// the text length, so results can be matched back to their input order. +func echoEncoder(record func(batch []string)) vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + if record != nil { + record(texts) + } + out := make([][]float32, len(texts)) + for i, txt := range texts { + out[i] = []float32{float32(len(txt))} + } + return out, nil + } +} + +func TestEncodeBatchedPreservesOrderAcrossBatches(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + var mu sync.Mutex + var sizes []int + enc := echoEncoder(func(batch []string) { + mu.Lock() + sizes = append(sizes, len(batch)) + mu.Unlock() + }) + + in := chunks("a", "bb", "ccc", "dddd", "eeeee") + out, err := vector.EncodeBatched(context.Background(), enc, in, vector.BatchOptions{BatchSize: 2, Concurrency: 3}) + require.NoError(err) + require.Len(out, len(in)) + for i, c := range in { + assert.Equal(float32(len(c.Text)), out[i][0], "vector %d matches its input", i) + } + + mu.Lock() + defer mu.Unlock() + assert.ElementsMatch([]int{2, 2, 1}, sizes, "batches are sized by BatchSize") +} + +func TestEncodeBatchedRespectsConcurrencyBound(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + var inFlight, maxInFlight atomic.Int64 + enc := func(_ context.Context, texts []string) ([][]float32, error) { + cur := inFlight.Add(1) + for { + prev := maxInFlight.Load() + if cur <= prev || maxInFlight.CompareAndSwap(prev, cur) { + break + } + } + defer inFlight.Add(-1) + out := make([][]float32, len(texts)) + return out, nil + } + + in := chunks("a", "b", "c", "d", "e", "f", "g", "h") + _, err := vector.EncodeBatched(context.Background(), enc, in, vector.BatchOptions{BatchSize: 1, Concurrency: 2}) + require.NoError(err) + assert.LessOrEqual(maxInFlight.Load(), int64(2), "never exceeds the concurrency bound") +} + +func TestEncodeBatchedSurfacesEncodeError(t *testing.T) { + assert := assert.New(t) + sentinel := errors.New("boom") + enc := func(_ context.Context, _ []string) ([][]float32, error) { return nil, sentinel } + + _, err := vector.EncodeBatched(context.Background(), enc, chunks("a", "b"), vector.BatchOptions{BatchSize: 1}) + assert.ErrorIs(err, sentinel) +} + +func TestEncodeBatchedRejectsCountMismatch(t *testing.T) { + assert := assert.New(t) + enc := func(_ context.Context, _ []string) ([][]float32, error) { + return [][]float32{{1}}, nil // one vector for two texts + } + + _, err := vector.EncodeBatched(context.Background(), enc, chunks("a", "b"), vector.BatchOptions{}) + assert.ErrorContains(err, "vectors for") +} + +func TestEncodeBatchedNilEncoder(t *testing.T) { + _, err := vector.EncodeBatched(context.Background(), nil, chunks("a"), vector.BatchOptions{}) + assert.Error(t, err) +} + +func TestEncodeBatchedEmptyInput(t *testing.T) { + assert := assert.New(t) + out, err := vector.EncodeBatched(context.Background(), echoEncoder(nil), nil, vector.BatchOptions{}) + assert.NoError(err) + assert.Empty(out) +} + +func TestEncodeBatchedStopsOnCancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := vector.EncodeBatched(ctx, echoEncoder(nil), chunks("a", "b"), vector.BatchOptions{BatchSize: 1}) + assert.ErrorIs(t, err, context.Canceled) +} diff --git a/vector/flow.go b/vector/flow.go new file mode 100644 index 0000000..65c73c9 --- /dev/null +++ b/vector/flow.go @@ -0,0 +1,115 @@ +package vector + +import ( + "context" + "fmt" +) + +// FillOptions configures Fill. +type FillOptions struct { + // ScanBatch is the number of pending documents fetched per scan. + // Values <= 0 use 128. + ScanBatch int + // Split controls how each document's content is windowed into chunks. + Split SplitOptions + // Batch controls how chunks are batched into encode calls. + Batch BatchOptions +} + +// FillStats reports what a Fill run embedded. +type FillStats struct { + Documents int + Chunks int +} + +// Fill embeds every document that still needs the target generation: it +// scans the store for pending documents, splits and encodes each, and +// saves the resulting vectors, repeating until no documents remain. It is +// the generic scan-and-fill loop; the store decides what counts as +// pending and persists the results. +func Fill[K, G comparable](ctx context.Context, store Store[K, G], gen G, enc EncodeFunc, o FillOptions) (FillStats, error) { + scanBatch := o.ScanBatch + if scanBatch <= 0 { + scanBatch = 128 + } + + var stats FillStats + for { + if err := ctx.Err(); err != nil { + return stats, err + } + pending, err := store.PendingForGeneration(ctx, gen, scanBatch) + if err != nil { + return stats, fmt.Errorf("scan pending: %w", err) + } + if len(pending) == 0 { + return stats, nil + } + + for _, p := range pending { + chunks := Split(p.Content, o.Split) + vectors, err := EncodeBatched(ctx, enc, chunks, o.Batch) + if err != nil { + return stats, fmt.Errorf("encode document %v: %w", p.Doc, err) + } + cvs := make([]ChunkVector, len(chunks)) + for i, c := range chunks { + cvs[i] = ChunkVector{ChunkIndex: c.Index, Vector: vectors[i]} + } + if err := store.SaveVectors(ctx, gen, p.Doc, cvs); err != nil { + return stats, fmt.Errorf("save document %v: %w", p.Doc, err) + } + stats.Documents++ + stats.Chunks += len(cvs) + } + } +} + +// SearchOptions configures Search. +type SearchOptions struct { + // PerGeneration caps how many hits are fetched from each generation + // before merging. Values <= 0 use 50. + PerGeneration int + // Merge configures how per-generation results are combined. + Merge MergeOptions +} + +// Search embeds queryText once per live generation (each may use a +// different model), queries each generation, rolls the chunk hits up to +// documents, and merges the per-generation results into one ranking. +// encFor maps a generation to the encoder for that generation's model. +func Search[K, G comparable]( + ctx context.Context, + store Store[K, G], + queryText string, + encFor func(gen G) EncodeFunc, + o SearchOptions, +) ([]Hit[K], error) { + perGen := o.PerGeneration + if perGen <= 0 { + perGen = 50 + } + + gens, err := store.LiveGenerations(ctx) + if err != nil { + return nil, fmt.Errorf("live generations: %w", err) + } + + lists := make([][]Hit[K], 0, len(gens)) + for _, gen := range gens { + enc := encFor(gen) + if enc == nil { + return nil, fmt.Errorf("no encoder for generation %v", gen) + } + vectors, err := EncodeBatched(ctx, enc, []Chunk{{Index: 0, Text: queryText}}, BatchOptions{}) + if err != nil { + return nil, fmt.Errorf("embed query for generation %v: %w", gen, err) + } + hits, err := store.QueryGeneration(ctx, gen, vectors[0], perGen) + if err != nil { + return nil, fmt.Errorf("query generation %v: %w", gen, err) + } + lists = append(lists, RollupByDocument(hits)) + } + return Merge(lists, o.Merge), nil +} diff --git a/vector/flow_test.go b/vector/flow_test.go new file mode 100644 index 0000000..a7135d5 --- /dev/null +++ b/vector/flow_test.go @@ -0,0 +1,180 @@ +package vector_test + +import ( + "context" + "math" + "slices" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" +) + +// memStore is an in-memory Store[int64, int] used to exercise the flows +// without any real backend. Documents are keyed by int64; generations by +// int. QueryGeneration ranks by cosine similarity over stored vectors. +type memStore struct { + content map[int64]string + embedded map[int64]map[int]bool // doc -> gen -> done + vectors map[int]map[int64][]vector.ChunkVector // gen -> doc -> chunks + live []int // descending preference +} + +func newMemStore() *memStore { + return &memStore{ + content: map[int64]string{}, + embedded: map[int64]map[int]bool{}, + vectors: map[int]map[int64][]vector.ChunkVector{}, + } +} + +func (m *memStore) PendingForGeneration(_ context.Context, gen int, limit int) ([]vector.Pending[int64], error) { + keys := make([]int64, 0, len(m.content)) + for doc := range m.content { + if !m.embedded[doc][gen] { + keys = append(keys, doc) + } + } + slices.Sort(keys) + if limit > 0 && len(keys) > limit { + keys = keys[:limit] + } + out := make([]vector.Pending[int64], len(keys)) + for i, doc := range keys { + out[i] = vector.Pending[int64]{Doc: doc, Content: m.content[doc]} + } + return out, nil +} + +func (m *memStore) SaveVectors(_ context.Context, gen int, doc int64, vecs []vector.ChunkVector) error { + if m.vectors[gen] == nil { + m.vectors[gen] = map[int64][]vector.ChunkVector{} + } + m.vectors[gen][doc] = vecs + if m.embedded[doc] == nil { + m.embedded[doc] = map[int]bool{} + } + m.embedded[doc][gen] = true + return nil +} + +func (m *memStore) LiveGenerations(_ context.Context) ([]int, error) { + return m.live, nil +} + +func (m *memStore) QueryGeneration(_ context.Context, gen int, query vector.Vector, limit int) ([]vector.Hit[int64], error) { + var hits []vector.Hit[int64] + for doc, chunks := range m.vectors[gen] { + for _, cv := range chunks { + hits = append(hits, vector.Hit[int64]{Doc: doc, ChunkIndex: cv.ChunkIndex, Score: cosine(query, cv.Vector)}) + } + } + sort.SliceStable(hits, func(i, j int) bool { return hits[i].Score > hits[j].Score }) + if limit > 0 && len(hits) > limit { + hits = hits[:limit] + } + return hits, nil +} + +func cosine(a, b vector.Vector) float32 { + var dot, na, nb float64 + for i := range a { + dot += float64(a[i]) * float64(b[i]) + na += float64(a[i]) * float64(a[i]) + nb += float64(b[i]) * float64(b[i]) + } + if na == 0 || nb == 0 { + return 0 + } + return float32(dot / (math.Sqrt(na) * math.Sqrt(nb))) +} + +// lenEncoder embeds each text as a 1-D vector of its rune length, enough +// to confirm Fill wired chunk content through to SaveVectors. +func lenEncoder() vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i, txt := range texts { + out[i] = []float32{float32(len([]rune(txt)))} + } + return out, nil + } +} + +func TestFillEmbedsAllPendingThenStops(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + + store := newMemStore() + store.content[1] = "alpha" + store.content[2] = "beta gamma delta" + + stats, err := vector.Fill(ctx, store, 7, lenEncoder(), vector.FillOptions{ + ScanBatch: 1, // force multiple scan rounds + Split: vector.SplitOptions{MaxRunes: 4, Overlap: 0}, + }) + require.NoError(err) + + assert.Equal(2, stats.Documents) + assert.True(store.embedded[1][7] && store.embedded[2][7], "both docs stamped for gen 7") + require.Len(store.vectors[7][1], 2, "alpha -> 2 chunks of <=4 runes") + assert.InDelta(4, store.vectors[7][1][0].Vector[0], 1e-6, "first chunk carries its rune length") + + // A second run finds nothing pending and embeds zero documents. + again, err := vector.Fill(ctx, store, 7, lenEncoder(), vector.FillOptions{}) + require.NoError(err) + assert.Equal(0, again.Documents) +} + +func TestSearchRollsUpAndPrefersBuildingGeneration(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + + const active, building = 7, 9 + store := newMemStore() + store.live = []int{building, active} // descending preference + + // Doc 1 is shared; active stored it at chunk 0, building at chunk 5. + store.SaveVectors(ctx, active, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0}}}) + store.SaveVectors(ctx, active, 2, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{0, 1}}}) + store.SaveVectors(ctx, building, 1, []vector.ChunkVector{{ChunkIndex: 5, Vector: vector.Vector{1, 0}}}) + store.SaveVectors(ctx, building, 3, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0}}}) // new, building-only + + // Query vector [1,0] points at docs 1 and 3. + queryEnc := func(int) vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i := range texts { + out[i] = []float32{1, 0} + } + return out, nil + } + } + + got, err := vector.Search(ctx, store, "q", queryEnc, vector.SearchOptions{}) + require.NoError(err) + + byDoc := map[int64]vector.Hit[int64]{} + for _, h := range got { + byDoc[h.Doc] = h + } + assert.Contains(byDoc, int64(1)) + assert.Contains(byDoc, int64(2), "active-only doc is not dropped (union coverage)") + assert.Contains(byDoc, int64(3), "building-only new doc is searchable mid-migration") + assert.Equal(5, byDoc[1].ChunkIndex, "shared doc keeps the building generation's hit") +} + +func TestSearchErrorsWhenNoEncoderForGeneration(t *testing.T) { + ctx := context.Background() + store := newMemStore() + store.live = []int{1} + store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1}}}) + + _, err := vector.Search(ctx, store, "q", func(int) vector.EncodeFunc { return nil }, vector.SearchOptions{}) + assert.ErrorContains(t, err, "no encoder") +} diff --git a/vector/generation.go b/vector/generation.go new file mode 100644 index 0000000..50c896e --- /dev/null +++ b/vector/generation.go @@ -0,0 +1,58 @@ +package vector + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" +) + +// Generation identifies an embedding model configuration. Two pieces of +// content embedded under generations with the same Fingerprint share a +// vector space; a different fingerprint means the caller should treat the +// vectors as a new generation and re-embed. +// +// Every field that affects the vector space must be exported and JSON +// encodable so Fingerprint accounts for it automatically. A field that +// must not affect identity has to be tagged json:"-". +type Generation struct { + // Model names the embedding model, e.g. "text-embedding-3-small". + Model string `json:"model,omitempty"` + // Dimensions is the length of the vectors the model emits. + Dimensions int `json:"dimensions,omitempty"` + // Params holds any additional knobs that change the vector space, + // such as a pooling mode or prompt template. + Params map[string]string `json:"params,omitempty"` +} + +// Fingerprint returns a stable identifier derived from every field that +// affects the vector space. Callers persist it alongside stored vectors +// and compare it to decide whether a new generation is required. +// +// It is built to be stable across future changes to this type: +// +// - It encodes the struct itself, so a field added later participates +// automatically rather than being silently excluded — the failure +// mode that would let two distinct vector spaces share a fingerprint. +// - It then re-encodes through a generic value, and encoding/json sorts +// object keys at every level, so neither struct field order nor map +// insertion order affects the hash. +// - Decoding with UseNumber preserves numeric tokens exactly, so no +// field loses precision through float64. +// - omitempty drops zero-valued fields, so adding an unused field never +// shifts an existing generation's fingerprint. +// +// All values are JSON encodable, so the marshal and decode errors are +// unreachable. +func (g Generation) Fingerprint() string { + raw, _ := json.Marshal(g) + + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.UseNumber() + var generic any + _ = dec.Decode(&generic) + + canonical, _ := json.Marshal(generic) + sum := sha256.Sum256(canonical) + return hex.EncodeToString(sum[:8]) +} diff --git a/vector/generation_test.go b/vector/generation_test.go new file mode 100644 index 0000000..75e71ac --- /dev/null +++ b/vector/generation_test.go @@ -0,0 +1,93 @@ +package vector_test + +import ( + "crypto/sha256" + "encoding/hex" + "reflect" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func TestGenerationFingerprintIsStableAndOrderIndependent(t *testing.T) { + assert := assert.New(t) + a := vector.Generation{ + Model: "text-embedding-3-small", + Dimensions: 1536, + Params: map[string]string{"pooling": "mean", "prompt": "search"}, + } + b := vector.Generation{ + Model: "text-embedding-3-small", + Dimensions: 1536, + Params: map[string]string{"prompt": "search", "pooling": "mean"}, + } + + assert.Equal(a.Fingerprint(), a.Fingerprint(), "same value fingerprints identically") + assert.Equal(a.Fingerprint(), b.Fingerprint(), "map order does not change fingerprint") +} + +func TestGenerationFingerprintIsNotAmbiguousAcrossParams(t *testing.T) { + assert := assert.New(t) + // Two params vs a single param whose value embeds what used to be the + // key/value separator. A naive "key=value\n" join hashes both the + // same; the JSON encoding keeps them distinct. + two := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"pooling": "mean", "prompt": "x"}} + one := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"pooling": "mean\nprompt=x"}} + + assert.NotEqual(two.Fingerprint(), one.Fingerprint()) +} + +func TestGenerationFingerprintChangesWithSpace(t *testing.T) { + assert := assert.New(t) + base := vector.Generation{Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean"}} + + cases := map[string]vector.Generation{ + "model": {Model: "other", Dimensions: 768, Params: map[string]string{"pooling": "mean"}}, + "dimensions": {Model: "m", Dimensions: 1024, Params: map[string]string{"pooling": "mean"}}, + "param value": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "cls"}}, + "extra param": {Model: "m", Dimensions: 768, Params: map[string]string{"pooling": "mean", "prompt": "x"}}, + } + for name, g := range cases { + t.Run(name, func(t *testing.T) { + assert.NotEqual(base.Fingerprint(), g.Fingerprint()) + }) + } +} + +// TestGenerationFingerprintPinsCanonicalEncoding locks the exact hash +// preimage. If the canonical form ever changes — sorting, omit behavior, +// number formatting, or a field added to the struct's encoding — this +// fails, forcing a conscious decision rather than a silent shift of every +// persisted fingerprint. +func TestGenerationFingerprintPinsCanonicalEncoding(t *testing.T) { + g := vector.Generation{Model: "m", Dimensions: 3, Params: map[string]string{"b": "2", "a": "1"}} + + // Keys sorted at every level, zero fields omitted, numbers verbatim. + const canonical = `{"dimensions":3,"model":"m","params":{"a":"1","b":"2"}}` + sum := sha256.Sum256([]byte(canonical)) + want := hex.EncodeToString(sum[:8]) + + assert.Equal(t, want, g.Fingerprint()) +} + +// TestGenerationFieldsAreTracked is a tripwire: adding, removing, or +// renaming a Generation field changes this set. When it fails, decide +// whether the new field affects the vector space. If it does, Fingerprint +// already includes it (it encodes the whole struct); if it must not, tag +// the field json:"-". Then update this expectation and the pinned +// encoding above. +func TestGenerationFieldsAreTracked(t *testing.T) { + want := []string{"Dimensions", "Model", "Params"} + + fields := reflect.VisibleFields(reflect.TypeFor[vector.Generation]()) + got := make([]string, 0, len(fields)) + for _, f := range fields { + got = append(got, f.Name) + } + sort.Strings(got) + + assert.Equal(t, want, got, "Generation fields changed: review fingerprint impact before updating this tripwire") +} diff --git a/vector/search.go b/vector/search.go new file mode 100644 index 0000000..41c8ff2 --- /dev/null +++ b/vector/search.go @@ -0,0 +1,158 @@ +package vector + +import "sort" + +// Hit is a single search result identifying the document it belongs to. K +// is the caller's document key type (for example int64 or a UUID); this +// package compares keys for equality but never interprets them. +type Hit[K comparable] struct { + // Doc identifies the source document. + Doc K + // ChunkIndex is the chunk within Doc that matched. + ChunkIndex int + // Score is the backend's similarity score for this chunk. Merge + // overwrites it with the merged score under the chosen strategy. + Score float32 +} + +// RollupByDocument reduces chunk-level hits to one hit per document, +// keeping the highest-scoring chunk for each, and returns them sorted by +// score descending. It is the chunk->document step a caller applies to a +// single generation's results before merging across generations. +func RollupByDocument[K comparable](hits []Hit[K]) []Hit[K] { + if len(hits) == 0 { + return nil + } + best := make(map[K]Hit[K], len(hits)) + order := make([]K, 0, len(hits)) + for _, h := range hits { + cur, ok := best[h.Doc] + if !ok { + order = append(order, h.Doc) + best[h.Doc] = h + continue + } + if h.Score > cur.Score { + best[h.Doc] = h + } + } + out := make([]Hit[K], 0, len(order)) + for _, k := range order { + out = append(out, best[k]) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score }) + return out +} + +// MergeStrategy selects how Merge orders documents drawn from different +// generations, whose raw scores are not directly comparable. +type MergeStrategy int + +const ( + // MergeNormalizedScore min-max normalizes each generation's scores to + // [0,1] before ordering. It is the default: it keeps score signal + // without letting one generation's score scale dominate. + MergeNormalizedScore MergeStrategy = iota + // MergeRawScore orders by raw score. Use it only when the generations + // share a model family and comparable score distributions. + MergeRawScore + // MergeReciprocalRank ignores absolute scores and fuses by rank. Use + // it when score distributions differ sharply between generations. + MergeReciprocalRank +) + +// MergeOptions configures Merge. +type MergeOptions struct { + // Strategy selects the ordering policy. The zero value is + // MergeNormalizedScore. + Strategy MergeStrategy + // RankConstant is the k term in reciprocal-rank fusion. Values <= 0 + // use 60. + RankConstant float64 + // Limit caps the number of returned hits. Values <= 0 return all. + Limit int +} + +// Merge unions per-generation, document-level result lists into one +// ranking. The lists are given in descending preference: when a document +// appears in more than one list, the hit from the earliest list is kept, +// which is how a caller expresses "prefer the newer generation" during a +// migration. Coverage is a union, so a document present in only one +// generation is never dropped. +// +// Each surviving hit's Score is set to the merged score under the chosen +// strategy, and the result is ordered by that score descending. +func Merge[K comparable](perGeneration [][]Hit[K], o MergeOptions) []Hit[K] { + rep := make(map[K]Hit[K]) + order := make([]K, 0) + score := make(map[K]float64) + + switch o.Strategy { + case MergeReciprocalRank: + k := o.RankConstant + if k <= 0 { + k = 60 + } + for _, list := range perGeneration { + for rank, h := range list { + if _, ok := rep[h.Doc]; !ok { + rep[h.Doc] = h + order = append(order, h.Doc) + } + score[h.Doc] += 1.0 / (k + float64(rank) + 1.0) + } + } + case MergeRawScore: + for _, list := range perGeneration { + for _, h := range list { + if _, ok := rep[h.Doc]; ok { + continue + } + rep[h.Doc] = h + order = append(order, h.Doc) + score[h.Doc] = float64(h.Score) + } + } + default: // MergeNormalizedScore + for _, list := range perGeneration { + lo, hi := scoreRange(list) + span := hi - lo + for _, h := range list { + if _, ok := rep[h.Doc]; ok { + continue + } + rep[h.Doc] = h + order = append(order, h.Doc) + if span > 0 { + score[h.Doc] = float64(h.Score-lo) / float64(span) + } else { + score[h.Doc] = 1 + } + } + } + } + + out := make([]Hit[K], 0, len(order)) + for _, doc := range order { + h := rep[doc] + h.Score = float32(score[doc]) + out = append(out, h) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Score > out[j].Score }) + if o.Limit > 0 && len(out) > o.Limit { + out = out[:o.Limit] + } + return out +} + +func scoreRange[K comparable](hits []Hit[K]) (lo, hi float32) { + for i, h := range hits { + if i == 0 || h.Score < lo { + lo = h.Score + } + if i == 0 || h.Score > hi { + hi = h.Score + } + } + return lo, hi +} diff --git a/vector/search_test.go b/vector/search_test.go new file mode 100644 index 0000000..1137818 --- /dev/null +++ b/vector/search_test.go @@ -0,0 +1,109 @@ +package vector_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.kenn.io/kit/vector" +) + +func docs[K comparable](hits []vector.Hit[K]) []K { + out := make([]K, len(hits)) + for i, h := range hits { + out[i] = h.Doc + } + return out +} + +func TestRollupByDocumentKeepsBestChunkPerDoc(t *testing.T) { + assert := assert.New(t) + hits := []vector.Hit[int64]{ + {Doc: 1, ChunkIndex: 0, Score: 0.2}, + {Doc: 2, ChunkIndex: 0, Score: 0.9}, + {Doc: 1, ChunkIndex: 3, Score: 0.7}, // better chunk for doc 1 + {Doc: 2, ChunkIndex: 1, Score: 0.4}, + } + + got := vector.RollupByDocument(hits) + + assert.Equal([]int64{2, 1}, docs(got), "one hit per doc, ordered by score desc") + assert.Equal(3, got[1].ChunkIndex, "doc 1 keeps its highest-scoring chunk") + assert.InDelta(0.7, got[1].Score, 1e-6) +} + +func TestMergeUnionsAndPrefersEarlierGeneration(t *testing.T) { + assert := assert.New(t) + // String keys stand in for kata's UUIDs; building generation first. + building := []vector.Hit[string]{ + {Doc: "shared", Score: 0.50}, + {Doc: "new-only", Score: 0.40}, + } + active := []vector.Hit[string]{ + {Doc: "shared", Score: 0.99}, // higher raw score, but less preferred + {Doc: "old-only", Score: 0.80}, + } + + got := vector.Merge([][]vector.Hit[string]{building, active}, vector.MergeOptions{Strategy: vector.MergeRawScore}) + + assert.ElementsMatch([]string{"shared", "new-only", "old-only"}, docs(got), "coverage is a union") + for _, h := range got { + if h.Doc == "shared" { + assert.InDelta(0.50, h.Score, 1e-6, "shared doc keeps the preferred (building) hit, not the higher raw score") + } + } +} + +func TestMergeNormalizedScoreIsDefault(t *testing.T) { + assert := assert.New(t) + // Active generation scores live in a compressed high band; building + // generation in a low band. Raw merge would let active dominate; + // normalization puts each generation's top hit at 1.0. + active := []vector.Hit[int]{ + {Doc: 1, Score: 0.90}, + {Doc: 2, Score: 0.85}, + } + building := []vector.Hit[int]{ + {Doc: 3, Score: 0.20}, + {Doc: 4, Score: 0.10}, + } + + got := vector.Merge([][]vector.Hit[int]{building, active}, vector.MergeOptions{}) + + // Each generation's best-normalized hit should reach the top band. + top := got[0] + assert.Contains([]int{1, 3}, top.Doc, "a normalized top hit leads, not just the raw-highest") + assert.InDelta(1.0, float64(top.Score), 1e-6) +} + +func TestMergeReciprocalRankFusesAcrossGenerations(t *testing.T) { + assert := assert.New(t) + // "shared" is rank 1 in one list and rank 2 in the other, so its + // fused score should beat docs that appear in only one list. + a := []vector.Hit[int]{ + {Doc: 10, Score: 0.99}, + {Doc: 99, Score: 0.98}, + } + b := []vector.Hit[int]{ + {Doc: 99, Score: 0.50}, + {Doc: 20, Score: 0.49}, + } + + got := vector.Merge([][]vector.Hit[int]{a, b}, vector.MergeOptions{Strategy: vector.MergeReciprocalRank}) + + assert.Equal(99, got[0].Doc, "the doc found in both generations ranks first") +} + +func TestMergeRespectsLimit(t *testing.T) { + assert := assert.New(t) + list := []vector.Hit[int]{{Doc: 1, Score: 0.9}, {Doc: 2, Score: 0.8}, {Doc: 3, Score: 0.7}} + + got := vector.Merge([][]vector.Hit[int]{list}, vector.MergeOptions{Strategy: vector.MergeRawScore, Limit: 2}) + + assert.Len(got, 2) + assert.Equal([]int{1, 2}, docs(got)) +} + +func TestMergeEmpty(t *testing.T) { + assert.Empty(t, vector.Merge[int](nil, vector.MergeOptions{})) +} diff --git a/vector/sqlitevec/sqlitevec.go b/vector/sqlitevec/sqlitevec.go new file mode 100644 index 0000000..86450f1 --- /dev/null +++ b/vector/sqlitevec/sqlitevec.go @@ -0,0 +1,187 @@ +// Package sqlitevec implements vector.Store on top of SQLite with the +// sqlite-vec extension. It is a reference backend: a worked example of the +// storage contract the vector flows depend on, built against the same +// sqlite-vec binding msgvault uses. +// +// Callers must register the extension once before opening their database: +// +// sqlitevec.Register() +// db, _ := sql.Open("sqlite3", path) +// +// The caller owns the documents table; this package owns a small set of +// vector tables derived from VectorsPrefix. Each generation gets its own +// vec0 virtual table sized to that generation's dimension, so generations +// with different model dimensions coexist during a migration. +package sqlitevec + +import ( + "context" + "database/sql" + "fmt" + "regexp" + + vecext "github.com/asg017/sqlite-vec-go-bindings/cgo" + + "go.kenn.io/kit/vector" +) + +// Register loads the sqlite-vec extension into every SQLite connection +// opened afterwards in this process. It must be called before opening the +// database the store will use. +func Register() { vecext.Auto() } + +var identifierPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// State is a generation's role in the active/building lifecycle. Only +// building and active generations are searched; building sorts ahead of +// active so Merge keeps the newer generation's hit on overlap. +type State string + +const ( + StatePending State = "pending" + StateBuilding State = "building" + StateActive State = "active" + StateRetired State = "retired" +) + +// Schema names the caller's documents table and the prefix for the +// vector tables this package manages. Every field must be a bare SQL +// identifier; values are validated before being interpolated into SQL. +type Schema struct { + DocsTable string // caller's documents table, e.g. "messages" + IDColumn string // primary key column, e.g. "id" + ContentColumn string // text to embed, e.g. "body" + EmbedGenColumn string // nullable generation stamp, e.g. "embed_gen" + VectorsPrefix string // prefix for managed tables, e.g. "message_vectors" +} + +func (s Schema) validate() error { + for name, value := range map[string]string{ + "docs table": s.DocsTable, + "id column": s.IDColumn, + "content column": s.ContentColumn, + "embed gen column": s.EmbedGenColumn, + "vectors prefix": s.VectorsPrefix, + } { + if !identifierPattern.MatchString(value) { + return fmt.Errorf("invalid %s %q", name, value) + } + } + return nil +} + +// Store implements vector.Store[K, G] against SQLite + sqlite-vec. K is the +// caller's document key type and G its generation key type; both must be +// types database/sql can bind and scan (for example int64 or string). +type Store[K, G comparable] struct { + db *sql.DB + schema Schema +} + +// New returns a Store bound to db. The caller retains ownership of db. New +// creates the generations and chunks bookkeeping tables if they do not +// exist; per-generation vec0 tables are created by EnsureGeneration. +func New[K, G comparable](ctx context.Context, db *sql.DB, schema Schema) (*Store[K, G], error) { + if err := schema.validate(); err != nil { + return nil, err + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + s := &Store[K, G]{db: db, schema: schema} + if _, err := db.ExecContext(ctx, fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( + ordinal INTEGER PRIMARY KEY, + gen_key UNIQUE, + dimension INTEGER NOT NULL, + state TEXT NOT NULL +); +CREATE TABLE IF NOT EXISTS %s ( + ordinal INTEGER NOT NULL, + doc_key NOT NULL, + chunk_index INTEGER NOT NULL, + vec_rowid INTEGER NOT NULL, + PRIMARY KEY (ordinal, doc_key, chunk_index) +);`, s.generationsTable(), s.chunksTable())); err != nil { + return nil, fmt.Errorf("create bookkeeping tables: %w", err) + } + return s, nil +} + +func (s *Store[K, G]) generationsTable() string { return s.schema.VectorsPrefix + "_generations" } +func (s *Store[K, G]) chunksTable() string { return s.schema.VectorsPrefix + "_chunks" } +func (s *Store[K, G]) vecTable(ordinal int64) string { + return fmt.Sprintf("%s_v%d", s.schema.VectorsPrefix, ordinal) +} + +// EnsureGeneration registers gen with model's dimension and the given +// state, creating its vec0 table on first use. Calling it again updates +// only the state; a generation's dimension is fixed once created. +func (s *Store[K, G]) EnsureGeneration(ctx context.Context, gen G, model vector.Generation, state State) error { + if model.Dimensions <= 0 { + return fmt.Errorf("generation dimension must be positive, got %d", model.Dimensions) + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin ensure generation: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if _, err := tx.ExecContext(ctx, fmt.Sprintf(` +INSERT INTO %s (gen_key, dimension, state) VALUES (?, ?, ?) +ON CONFLICT(gen_key) DO UPDATE SET state = excluded.state`, s.generationsTable()), + gen, model.Dimensions, string(state)); err != nil { + return fmt.Errorf("upsert generation: %w", err) + } + + ordinal, dimension, err := s.lookupGenerationTx(ctx, tx, gen) + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf( + `CREATE VIRTUAL TABLE IF NOT EXISTS %s USING vec0(embedding float[%d] distance_metric=cosine)`, + s.vecTable(ordinal), dimension)); err != nil { + return fmt.Errorf("create vec0 table: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit ensure generation: %w", err) + } + return nil +} + +// SetGenerationState transitions gen to state. The caller owns the +// active/building lifecycle; this only records the decision. +func (s *Store[K, G]) SetGenerationState(ctx context.Context, gen G, state State) error { + res, err := s.db.ExecContext(ctx, + fmt.Sprintf(`UPDATE %s SET state = ? WHERE gen_key = ?`, s.generationsTable()), + string(state), gen) + if err != nil { + return fmt.Errorf("set generation state: %w", err) + } + if n, _ := res.RowsAffected(); n == 0 { + return fmt.Errorf("generation %v not found", gen) + } + return nil +} + +func (s *Store[K, G]) lookupGeneration(ctx context.Context, gen G) (ordinal int64, dimension int, err error) { + return s.scanGeneration(s.db.QueryRowContext(ctx, + fmt.Sprintf(`SELECT ordinal, dimension FROM %s WHERE gen_key = ?`, s.generationsTable()), gen), gen) +} + +func (s *Store[K, G]) lookupGenerationTx(ctx context.Context, tx *sql.Tx, gen G) (int64, int, error) { + return s.scanGeneration(tx.QueryRowContext(ctx, + fmt.Sprintf(`SELECT ordinal, dimension FROM %s WHERE gen_key = ?`, s.generationsTable()), gen), gen) +} + +func (s *Store[K, G]) scanGeneration(row *sql.Row, gen G) (int64, int, error) { + var ordinal int64 + var dimension int + if err := row.Scan(&ordinal, &dimension); err != nil { + if err == sql.ErrNoRows { + return 0, 0, fmt.Errorf("generation %v not ensured", gen) + } + return 0, 0, fmt.Errorf("lookup generation %v: %w", gen, err) + } + return ordinal, dimension, nil +} diff --git a/vector/sqlitevec/sqlitevec_test.go b/vector/sqlitevec/sqlitevec_test.go new file mode 100644 index 0000000..7691329 --- /dev/null +++ b/vector/sqlitevec/sqlitevec_test.go @@ -0,0 +1,157 @@ +package sqlitevec_test + +import ( + "context" + "database/sql" + "path/filepath" + "strings" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.kenn.io/kit/vector" + "go.kenn.io/kit/vector/sqlitevec" +) + +// topicEncoder maps text to a one-hot 3-D vector by keyword, so queries +// match documents deterministically. +func topicEncoder() vector.EncodeFunc { + return func(_ context.Context, texts []string) ([][]float32, error) { + out := make([][]float32, len(texts)) + for i, text := range texts { + switch { + case strings.Contains(text, "cat"): + out[i] = []float32{1, 0, 0} + case strings.Contains(text, "dog"): + out[i] = []float32{0, 1, 0} + default: + out[i] = []float32{0, 0, 1} + } + } + return out, nil + } +} + +func setup(t *testing.T) (*sql.DB, *sqlitevec.Store[int64, int64]) { + t.Helper() + sqlitevec.Register() + + db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "vec.db")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, db.Close()) }) + + _, err = db.Exec(`CREATE TABLE messages (id INTEGER PRIMARY KEY, body TEXT, embed_gen INTEGER)`) + require.NoError(t, err) + + store, err := sqlitevec.New[int64, int64](context.Background(), db, sqlitevec.Schema{ + DocsTable: "messages", + IDColumn: "id", + ContentColumn: "body", + EmbedGenColumn: "embed_gen", + VectorsPrefix: "message_vectors", + }) + require.NoError(t, err) + return db, store +} + +func TestStoreFillThenSearch(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat sat'), (2, 'a dog ran')`) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + stats, err := vector.Fill(ctx, store, 1, topicEncoder(), vector.FillOptions{}) + require.NoError(err) + assert.Equal(2, stats.Documents) + + pending, err := store.PendingForGeneration(ctx, 1, 10) + require.NoError(err) + assert.Empty(pending, "nothing pending once every document is stamped") + + enc := func(int64) vector.EncodeFunc { return topicEncoder() } + hits, err := vector.Search(ctx, store, "a cat", enc, vector.SearchOptions{}) + require.NoError(err) + require.NotEmpty(hits) + assert.Equal(int64(1), hits[0].Doc, "the cat query ranks the cat document first") +} + +func TestStoreReembeddingReplacesVectors(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat sat')`) + require.NoError(err) + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + require.NoError(store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}})) + require.NoError(store.SaveVectors(ctx, 1, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{0, 1, 0}}})) + + hits, err := store.QueryGeneration(ctx, 1, vector.Vector{0, 1, 0}, 10) + require.NoError(err) + require.Len(hits, 1, "re-embedding replaces the prior vector rather than duplicating it") + assert.InDelta(1.0, hits[0].Score, 1e-6, "stored vector now matches the new query") +} + +func TestStoreSearchUnionsLiveGenerations(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + db, store := setup(t) + + _, err := db.ExecContext(ctx, `INSERT INTO messages (id, body) VALUES (1, 'a cat'), (2, 'a dog')`) + require.NoError(err) + + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "v1", Dimensions: 3}, sqlitevec.StateActive)) + _, err = vector.Fill(ctx, store, 1, topicEncoder(), vector.FillOptions{}) + require.NoError(err) + + // The building generation has covered only doc 1 so far. + require.NoError(store.EnsureGeneration(ctx, 2, vector.Generation{Model: "v2", Dimensions: 3}, sqlitevec.StateBuilding)) + require.NoError(store.SaveVectors(ctx, 2, 1, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}})) + + gens, err := store.LiveGenerations(ctx) + require.NoError(err) + assert.Equal([]int64{2, 1}, gens, "building precedes active in preference order") + + enc := func(int64) vector.EncodeFunc { return topicEncoder() } + hits, err := vector.Search(ctx, store, "a cat", enc, vector.SearchOptions{}) + require.NoError(err) + + found := map[int64]bool{} + for _, h := range hits { + found[h.Doc] = true + } + assert.True(found[1], "shared doc is searchable") + assert.True(found[2], "active-only doc is not dropped mid-migration (union coverage)") +} + +func TestStoreSaveVectorsRejectsMissingDocument(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + ctx := context.Background() + _, store := setup(t) + + require.NoError(store.EnsureGeneration(ctx, 1, vector.Generation{Model: "m", Dimensions: 3}, sqlitevec.StateActive)) + + err := store.SaveVectors(ctx, 1, 999, []vector.ChunkVector{{ChunkIndex: 0, Vector: vector.Vector{1, 0, 0}}}) + require.Error(err, "saving vectors for a document not in the source table fails") + + hits, err := store.QueryGeneration(ctx, 1, vector.Vector{1, 0, 0}, 10) + require.NoError(err) + assert.Empty(hits, "no orphan vectors are committed when the source row is missing") +} + +func TestNewRejectsUnsafeIdentifiers(t *testing.T) { + _, err := sqlitevec.New[int64, int64](context.Background(), nil, sqlitevec.Schema{ + DocsTable: "messages; DROP TABLE messages", + }) + require.Error(t, err) +} diff --git a/vector/sqlitevec/store.go b/vector/sqlitevec/store.go new file mode 100644 index 0000000..69ffaf9 --- /dev/null +++ b/vector/sqlitevec/store.go @@ -0,0 +1,207 @@ +package sqlitevec + +import ( + "context" + "database/sql" + "fmt" + + vecext "github.com/asg017/sqlite-vec-go-bindings/cgo" + + "go.kenn.io/kit/vector" +) + +// PendingForGeneration scans the caller's documents table for rows whose +// stamp does not yet match gen, ordered by primary key for stable paging. +func (s *Store[K, G]) PendingForGeneration(ctx context.Context, gen G, limit int) ([]vector.Pending[K], error) { + query := fmt.Sprintf( + `SELECT %s, %s FROM %s WHERE %s IS NULL OR %s <> ? ORDER BY %s LIMIT ?`, + s.schema.IDColumn, s.schema.ContentColumn, s.schema.DocsTable, + s.schema.EmbedGenColumn, s.schema.EmbedGenColumn, s.schema.IDColumn) + rows, err := s.db.QueryContext(ctx, query, gen, limit) + if err != nil { + return nil, fmt.Errorf("scan pending: %w", err) + } + defer func() { _ = rows.Close() }() + + var pending []vector.Pending[K] + for rows.Next() { + var p vector.Pending[K] + if err := rows.Scan(&p.Doc, &p.Content); err != nil { + return nil, fmt.Errorf("scan pending row: %w", err) + } + pending = append(pending, p) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan pending rows: %w", err) + } + return pending, nil +} + +// SaveVectors replaces doc's chunk vectors for gen and stamps the document +// as embedded for gen, all in one transaction. +func (s *Store[K, G]) SaveVectors(ctx context.Context, gen G, doc K, vectors []vector.ChunkVector) error { + ordinal, dimension, err := s.lookupGeneration(ctx, gen) + if err != nil { + return err + } + for _, cv := range vectors { + if len(cv.Vector) != dimension { + return fmt.Errorf("chunk %d has %d dimensions, generation expects %d", cv.ChunkIndex, len(cv.Vector), dimension) + } + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin save vectors: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // Drop any prior vectors for this document so re-embedding is clean. + rowids, err := s.docRowids(ctx, tx, ordinal, doc) + if err != nil { + return err + } + for _, rowid := range rowids { + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE rowid = ?`, s.vecTable(ordinal)), rowid); err != nil { + return fmt.Errorf("delete stale vector: %w", err) + } + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`DELETE FROM %s WHERE ordinal = ? AND doc_key = ?`, s.chunksTable()), ordinal, doc); err != nil { + return fmt.Errorf("delete stale chunk map: %w", err) + } + + for _, cv := range vectors { + blob, err := vecext.SerializeFloat32(cv.Vector) + if err != nil { + return fmt.Errorf("serialize vector: %w", err) + } + res, err := tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO %s (embedding) VALUES (?)`, s.vecTable(ordinal)), blob) + if err != nil { + return fmt.Errorf("insert vector: %w", err) + } + rowid, err := res.LastInsertId() + if err != nil { + return fmt.Errorf("vector rowid: %w", err) + } + if _, err := tx.ExecContext(ctx, + fmt.Sprintf(`INSERT INTO %s (ordinal, doc_key, chunk_index, vec_rowid) VALUES (?, ?, ?, ?)`, s.chunksTable()), + ordinal, doc, cv.ChunkIndex, rowid); err != nil { + return fmt.Errorf("insert chunk map: %w", err) + } + } + + res, err := tx.ExecContext(ctx, + fmt.Sprintf(`UPDATE %s SET %s = ? WHERE %s = ?`, s.schema.DocsTable, s.schema.EmbedGenColumn, s.schema.IDColumn), + gen, doc) + if err != nil { + return fmt.Errorf("stamp embed generation: %w", err) + } + stamped, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("stamp embed generation rows: %w", err) + } + if stamped == 0 { + // The source row vanished between scan and save (or the key is + // wrong). Roll back rather than commit vectors with no document, + // which QueryGeneration would otherwise surface as orphan hits. + return fmt.Errorf("document %v not present in %s; vectors not persisted", doc, s.schema.DocsTable) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit save vectors: %w", err) + } + return nil +} + +func (s *Store[K, G]) docRowids(ctx context.Context, tx txQuerier, ordinal int64, doc K) ([]int64, error) { + rows, err := tx.QueryContext(ctx, + fmt.Sprintf(`SELECT vec_rowid FROM %s WHERE ordinal = ? AND doc_key = ?`, s.chunksTable()), ordinal, doc) + if err != nil { + return nil, fmt.Errorf("read chunk map: %w", err) + } + defer func() { _ = rows.Close() }() + + var rowids []int64 + for rows.Next() { + var rowid int64 + if err := rows.Scan(&rowid); err != nil { + return nil, fmt.Errorf("scan chunk rowid: %w", err) + } + rowids = append(rowids, rowid) + } + return rowids, rows.Err() +} + +// LiveGenerations returns building and active generations, building first, +// so Merge prefers the newer generation when a document is in both. +func (s *Store[K, G]) LiveGenerations(ctx context.Context) ([]G, error) { + rows, err := s.db.QueryContext(ctx, fmt.Sprintf(` +SELECT gen_key FROM %s + WHERE state IN (?, ?) + ORDER BY CASE state WHEN ? THEN 0 ELSE 1 END, ordinal`, s.generationsTable()), + string(StateBuilding), string(StateActive), string(StateBuilding)) + if err != nil { + return nil, fmt.Errorf("list live generations: %w", err) + } + defer func() { _ = rows.Close() }() + + var gens []G + for rows.Next() { + var gen G + if err := rows.Scan(&gen); err != nil { + return nil, fmt.Errorf("scan generation key: %w", err) + } + gens = append(gens, gen) + } + return gens, rows.Err() +} + +// QueryGeneration runs a cosine KNN search within gen's vec0 table and +// maps each neighbor back to its document and chunk. Score is the cosine +// similarity (1 - cosine distance), so higher is more similar. +func (s *Store[K, G]) QueryGeneration(ctx context.Context, gen G, query vector.Vector, limit int) ([]vector.Hit[K], error) { + ordinal, dimension, err := s.lookupGeneration(ctx, gen) + if err != nil { + return nil, err + } + if len(query) != dimension { + return nil, fmt.Errorf("query has %d dimensions, generation expects %d", len(query), dimension) + } + blob, err := vecext.SerializeFloat32(query) + if err != nil { + return nil, fmt.Errorf("serialize query: %w", err) + } + + // The KNN runs against the vec0 table alone (its required form), then + // joins to the chunk map to recover document keys. + sqlText := fmt.Sprintf(` +WITH knn AS ( + SELECT rowid, distance FROM %s WHERE embedding MATCH ? ORDER BY distance LIMIT ? +) +SELECT c.doc_key, c.chunk_index, knn.distance + FROM knn JOIN %s c ON c.ordinal = ? AND c.vec_rowid = knn.rowid + ORDER BY knn.distance`, s.vecTable(ordinal), s.chunksTable()) + rows, err := s.db.QueryContext(ctx, sqlText, blob, limit, ordinal) + if err != nil { + return nil, fmt.Errorf("query generation: %w", err) + } + defer func() { _ = rows.Close() }() + + var hits []vector.Hit[K] + for rows.Next() { + var ( + doc K + chunkIndex int + distance float64 + ) + if err := rows.Scan(&doc, &chunkIndex, &distance); err != nil { + return nil, fmt.Errorf("scan hit: %w", err) + } + hits = append(hits, vector.Hit[K]{Doc: doc, ChunkIndex: chunkIndex, Score: float32(1 - distance)}) + } + return hits, rows.Err() +} + +// txQuerier is the read surface shared by *sql.DB and *sql.Tx. +type txQuerier interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} diff --git a/vector/sqlitevec/vec0_smoke_test.go b/vector/sqlitevec/vec0_smoke_test.go new file mode 100644 index 0000000..0550cee --- /dev/null +++ b/vector/sqlitevec/vec0_smoke_test.go @@ -0,0 +1,42 @@ +package sqlitevec + +import ( + "database/sql" + "testing" + + sqlitevec "github.com/asg017/sqlite-vec-go-bindings/cgo" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" +) + +// TestVec0LoadsHermetically confirms the sqlite-vec extension is compiled +// in via the cgo binding, so the backend's tests need no external setup. +func TestVec0LoadsHermetically(t *testing.T) { + require := require.New(t) + sqlitevec.Auto() + + db, err := sql.Open("sqlite3", ":memory:") + require.NoError(err) + t.Cleanup(func() { require.NoError(db.Close()) }) + + _, err = db.Exec(`CREATE VIRTUAL TABLE v USING vec0(embedding float[3])`) + require.NoError(err) + + vec, err := sqlitevec.SerializeFloat32([]float32{1, 2, 3}) + require.NoError(err) + _, err = db.Exec(`INSERT INTO v(rowid, embedding) VALUES (1, ?)`, vec) + require.NoError(err) + + query, err := sqlitevec.SerializeFloat32([]float32{1, 2, 3}) + require.NoError(err) + var rowid int64 + var distance float64 + err = db.QueryRow( + `SELECT rowid, distance FROM v WHERE embedding MATCH ? ORDER BY distance LIMIT 1`, + query, + ).Scan(&rowid, &distance) + require.NoError(err) + + require.Equal(int64(1), rowid) + require.InDelta(0, distance, 1e-6, "identical vectors have zero distance") +} diff --git a/vector/store.go b/vector/store.go new file mode 100644 index 0000000..28facc9 --- /dev/null +++ b/vector/store.go @@ -0,0 +1,47 @@ +package vector + +import "context" + +// Pending is one document that still needs embedding for a generation, +// paired with the text to embed. +type Pending[K comparable] struct { + Doc K + Content string +} + +// ChunkVector is a single chunk's embedding, ready to persist. +type ChunkVector struct { + ChunkIndex int + Vector Vector +} + +// Store is the persistence contract the Fill and Search flows depend on. +// Implementations are a function of the caller's source system — a SQLite, +// pgvector, or DuckDB table — and own all backend SQL and query +// construction. The flows never open a database or build SQL themselves. +// +// K is the caller's document key type and G its generation id type; the +// package compares both for equality but never interprets them. +type Store[K, G comparable] interface { + // PendingForGeneration returns up to limit documents that are not yet + // embedded for gen, in a stable order. Implementations typically scan + // for "embed_gen IS NULL OR embed_gen <> gen". A document must stop + // being reported once SaveVectors has persisted it for gen, so that a + // fill loop terminates. + PendingForGeneration(ctx context.Context, gen G, limit int) ([]Pending[K], error) + + // SaveVectors persists every chunk vector for doc under gen and marks + // doc as embedded for gen (the scan-and-fill stamp). + SaveVectors(ctx context.Context, gen G, doc K, vectors []ChunkVector) error + + // LiveGenerations returns the generations a search should query, in + // descending preference. During a migration the building generation + // precedes the active one, so Merge keeps the newer generation's hit + // when a document appears in both. + LiveGenerations(ctx context.Context) ([]G, error) + + // QueryGeneration returns chunk-level hits for query within gen, + // ranked best first and capped at limit. This is where each backend's + // vector query construction lives. + QueryGeneration(ctx context.Context, gen G, query Vector, limit int) ([]Hit[K], error) +}