diff --git a/src/index.ts b/src/index.ts index 69e1ae8..616eb07 100644 --- a/src/index.ts +++ b/src/index.ts @@ -36,6 +36,8 @@ function graceMs(env: Env): number { const DUPLICATE_BLOCK_THRESHOLD = 0.95; const DUPLICATE_FLAG_THRESHOLD = 0.85; const CANDIDATE_SCORE_THRESHOLD = 0.45; +const TAG_BOOST_STEP = 0.15; +const TAG_BOOST_MAX = 1.5; // ─── Model constants ────────────────────────────────────────────────────────── @@ -411,7 +413,8 @@ export function cosineSim(a: ArrayLike, b: ArrayLike): number { export function rerankWithTimeDecay( matches: VectorizeMatch[], recallCounts: Map = new Map(), - importanceScores: Map = new Map() + importanceScores: Map = new Map(), + queryTags: string[] = [] ): VectorizeMatch[] { const now = Date.now(); @@ -440,7 +443,12 @@ export function rerankWithTimeDecay( const imp = importanceScores.get(parentId) ?? 0; const importanceMultiplier = imp === 0 ? 1.0 : 0.8 + (imp / 5) * 0.4; - return { ...match, score: match.score * combinedMultiplier * appendPenalty * rolledUpPenalty * importanceMultiplier }; + // Tag boost: applied outside the recency ≤1.0 cap so a tag-relevant memory can + // surface above a marginally-closer but irrelevant one. + const overlap = queryTags.length ? tags.filter(t => queryTags.includes(t)).length : 0; + const tagBoost = overlap ? Math.min(TAG_BOOST_MAX, 1 + overlap * TAG_BOOST_STEP) : 1.0; + + return { ...match, score: match.score * combinedMultiplier * appendPenalty * rolledUpPenalty * importanceMultiplier * tagBoost }; }) .sort((a, b) => b.score - a.score); } @@ -520,6 +528,43 @@ export function extractHashtags(content: string): { cleanContent: string; hashta return { cleanContent, hashtags }; } +// ─── Query tag inference ────────────────────────────────────────────────────── + +export async function inferQueryTags(query: string, env: Env): Promise { + const { hashtags } = extractHashtags(query); + if (hashtags.length) return hashtags; + + const { results: tagRows } = await env.DB.prepare( + `SELECT DISTINCT value FROM entries, json_each(entries.tags) ORDER BY value` + ).all(); + const knownTags = (tagRows as { value: string }[]).map(r => r.value); + + const lowerQuery = query.toLowerCase(); + const keywordMatches = knownTags.filter(t => + new RegExp(`(? t.trim().toLowerCase()).filter(t => t && knownSet.has(t)); + } catch { + return []; + } +} + // ─── Shared entry-listing filter builder ───────────────────────────────────── // Builds the WHERE/ORDER/LIMIT clause shared by list_recent and GET /list so // both stay in sync on which filters (tag, after, before) are supported. @@ -932,7 +977,10 @@ export async function recallEntries( embedQuery = parsed.cleanQuery; } - const values = await embed(embedQuery, env); + const [values, queryTags] = await Promise.all([ + embed(embedQuery, env), + inferQueryTags(embedQuery, env), + ]); let results: { matches: VectorizeMatch[] }; if (tag) { @@ -996,7 +1044,7 @@ export async function recallEntries( const recallCounts = new Map(rcRows.map(r => [r.id, r.recall_count ?? 0])); const importanceScores = new Map(rcRows.map(r => [r.id, r.importance_score ?? 0])); - const reranked = rerankWithTimeDecay(results.matches as VectorizeMatch[], recallCounts, importanceScores); + const reranked = rerankWithTimeDecay(results.matches as VectorizeMatch[], recallCounts, importanceScores, queryTags); const seen = new Set(); const deduped = reranked.filter((m) => { diff --git a/test/integration/recall.test.ts b/test/integration/recall.test.ts index 652b376..e42022c 100644 --- a/test/integration/recall.test.ts +++ b/test/integration/recall.test.ts @@ -278,4 +278,62 @@ describe("GET /recall", () => { const scoringCalls = prepareSpy.mock.calls.filter(([sql]) => sql.includes("recall_count, importance_score")); expect(scoringCalls).toHaveLength(2); }); + + it("hashtag or keyword in query skips the LLM during tag inference", async () => { + db.entries.push( + { id: "entry-1", content: "Work meeting notes", tags: '["work"]', source: "api", created_at: 1000, vector_ids: '["entry-1"]', recall_count: 0, importance_score: 0 }, + ); + const aiRun = vi.fn().mockImplementation(async (model: string) => { + if (model === "@cf/baai/bge-small-en-v1.5") return { data: [new Array(384).fill(0.1)] }; + return new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode('data: {"response":"work"}\n\n')); + c.enqueue(new TextEncoder().encode("data: [DONE]\n\n")); + c.close(); + }, + }); + }); + env = makeTestEnv(db, { + AI: { run: aiRun } as unknown as Ai, + VECTORIZE: makeVectorizeMock({ + query: vi.fn().mockResolvedValue({ matches: [makeMatch("entry-1", 0.9)] }), + }), + }); + + const res = await worker.fetch(req("GET", "/recall?query=work+meeting"), env, ctx); + expect(res.status).toBe(200); + // "work" is a known tag AND appears as a keyword in the query → LLM not called for inference + // (embed call uses BGE model; only LLM calls use other models) + const llmCalls = aiRun.mock.calls.filter((args: any[]) => args[0] !== "@cf/baai/bge-small-en-v1.5"); + expect(llmCalls).toHaveLength(0); + }); + + it("query with no matching keywords exercises the LLM fallback for tag inference", async () => { + db.entries.push( + { id: "entry-1", content: "Office lease renewal", tags: '["work"]', source: "api", created_at: 1000, vector_ids: '["entry-1"]', recall_count: 0, importance_score: 0 }, + ); + const aiRun = vi.fn().mockImplementation(async (model: string) => { + if (model === "@cf/baai/bge-small-en-v1.5") return { data: [new Array(384).fill(0.1)] }; + return new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode('data: {"response":"work"}\n\n')); + c.enqueue(new TextEncoder().encode("data: [DONE]\n\n")); + c.close(); + }, + }); + }); + env = makeTestEnv(db, { + AI: { run: aiRun } as unknown as Ai, + VECTORIZE: makeVectorizeMock({ + query: vi.fn().mockResolvedValue({ matches: [makeMatch("entry-1", 0.9)] }), + }), + }); + + // "quarterly planning" — no hashtags, "work" is not a whole word in this query + const res = await worker.fetch(req("GET", "/recall?query=quarterly+planning"), env, ctx); + expect(res.status).toBe(200); + // LLM called at least once (for tag inference); embedding uses BGE model (not counted) + const llmCalls = aiRun.mock.calls.filter((args: any[]) => args[0] !== "@cf/baai/bge-small-en-v1.5"); + expect(llmCalls.length).toBeGreaterThanOrEqual(1); + }); }); diff --git a/test/unit/infer-query-tags.test.ts b/test/unit/infer-query-tags.test.ts new file mode 100644 index 0000000..8144c6c --- /dev/null +++ b/test/unit/infer-query-tags.test.ts @@ -0,0 +1,105 @@ +import { describe, it, expect, vi } from "vitest"; +import { inferQueryTags } from "../../src/index"; +import { makeTestEnv, makeTestDb } from "../helpers/make-env"; + +function makeSseStream(response: string) { + return new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode(`data: {"response":${JSON.stringify(response)}}\n\n`)); + c.enqueue(new TextEncoder().encode("data: [DONE]\n\n")); + c.close(); + }, + }); +} + +describe("inferQueryTags", () => { + it("returns hashtags extracted from the query without hitting the DB", async () => { + const db = makeTestDb(); + const aiRun = vi.fn(); + const dbPrepareSpy = vi.spyOn(db, "prepare"); + const env = makeTestEnv(db, { AI: { run: aiRun } as unknown as Ai }); + const tags = await inferQueryTags("what did I decide about #work today?", env); + expect(tags).toEqual(["work"]); + // Early return — no DB or LLM call + expect(dbPrepareSpy).not.toHaveBeenCalled(); + expect(aiRun).not.toHaveBeenCalled(); + }); + + it("returns keyword-matched known tags (whole-word match, case-insensitive)", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Office lease note", tags: '["work","legal"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const env = makeTestEnv(db); + const tags = await inferQueryTags("what work and legal things did I decide?", env); + expect(tags).toHaveLength(2); + expect(tags).toEqual(expect.arrayContaining(["work", "legal"])); + }); + + it("does not call the LLM when keyword matches are found", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["work"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const aiRun = vi.fn().mockResolvedValue(makeSseStream("work")); + const env = makeTestEnv(db, { AI: { run: aiRun } as unknown as Ai }); + const tags = await inferQueryTags("work meeting notes", env); + expect(tags).toContain("work"); + expect(aiRun).not.toHaveBeenCalled(); + }); + + it("calls the LLM and intersects with known tags when cheap inference finds nothing", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["work","personal"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const aiRun = vi.fn().mockResolvedValue(makeSseStream("work, personal")); + const env = makeTestEnv(db, { AI: { run: aiRun } as unknown as Ai }); + const tags = await inferQueryTags("quarterly planning session", env); + expect(tags).toHaveLength(2); + expect(tags).toEqual(expect.arrayContaining(["work", "personal"])); + expect(aiRun).toHaveBeenCalledTimes(1); + }); + + it("filters out unknown tags returned by the LLM (intersects with known set)", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["work"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const aiRun = vi.fn().mockResolvedValue(makeSseStream("work, invented-tag, random")); + const env = makeTestEnv(db, { AI: { run: aiRun } as unknown as Ai }); + const tags = await inferQueryTags("quarterly planning session", env); + expect(tags).toEqual(["work"]); + }); + + it("returns empty array when the LLM throws — never propagates error", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["work"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const aiRun = vi.fn().mockRejectedValue(new Error("AI unavailable")); + const env = makeTestEnv(db, { AI: { run: aiRun } as unknown as Ai }); + await expect(inferQueryTags("quarterly planning session", env)).resolves.toEqual([]); + }); + + it("returns empty array when DB has no entries (no vocabulary to match against)", async () => { + const db = makeTestDb(); + const env = makeTestEnv(db); + const tags = await inferQueryTags("quarterly planning session", env); + expect(tags).toEqual([]); + }); + + it("does not partially match — 'networking' does not match tag 'net'", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["net"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const env = makeTestEnv(db); + const tags = await inferQueryTags("networking event", env); + expect(tags).not.toContain("net"); + }); + + it("does not match a hyphenated compound — 'my-claude-response-thing' does not match tag 'claude-response'", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["claude-response"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const env = makeTestEnv(db); + const tags = await inferQueryTags("my-claude-response-thing happened", env); + expect(tags).not.toContain("claude-response"); + }); + + it("matches a hyphenated tag that appears standalone in the query", async () => { + const db = makeTestDb(); + db.entries.push({ id: "e1", content: "Note", tags: '["claude-response"]', source: "api", created_at: 1000, vector_ids: "[]", recall_count: 0, importance_score: 0 }); + const env = makeTestEnv(db); + const tags = await inferQueryTags("what claude-response notes do I have", env); + expect(tags).toContain("claude-response"); + }); +}); diff --git a/test/unit/rerank.test.ts b/test/unit/rerank.test.ts index 84ce068..78d1704 100644 --- a/test/unit/rerank.test.ts +++ b/test/unit/rerank.test.ts @@ -106,4 +106,19 @@ describe("rerankWithTimeDecay", () => { const result = rerankWithTimeDecay([old, fresh], new Map(), importance); expect(result[0].id).toBe("old"); }); + + it("tag-overlapping entry outranks equal-vector-score entry without matching tag", () => { + const withTag = match("tagged", 0.9, NOW - 5 * MS_DAY, ["work"]); + const withoutTag = match("untagged", 0.9, NOW - 5 * MS_DAY, ["personal"]); + const result = rerankWithTimeDecay([withoutTag, withTag], new Map(), new Map(), ["work"]); + expect(result[0].id).toBe("tagged"); + expect(result[0].score).toBeGreaterThan(result[1].score); + }); + + it("queryTags=[] produces identical scores to no queryTags argument (backward compat)", () => { + const m = match("entry", 0.9, NOW - 5 * MS_DAY, ["work"]); + const [withEmpty] = rerankWithTimeDecay([m], new Map(), new Map(), []); + const [withDefault] = rerankWithTimeDecay([m]); + expect(withEmpty.score).toBeCloseTo(withDefault.score, 6); + }); });