diff --git a/src/search/embeddings.ts b/src/search/embeddings.ts index daa2f03..4ca2139 100644 --- a/src/search/embeddings.ts +++ b/src/search/embeddings.ts @@ -17,6 +17,7 @@ export async function embed( body: JSON.stringify({ model: provider.model, input: batch, + ...(provider.dimensions ? { dimensions: provider.dimensions } : {}), }), }); diff --git a/src/search/provider.ts b/src/search/provider.ts index 16ee2b6..4ff382e 100644 --- a/src/search/provider.ts +++ b/src/search/provider.ts @@ -28,7 +28,44 @@ const vercel: EmbeddingProvider = { }), }; +const gemini: EmbeddingProvider = { + name: 'gemini', + apiBase: 'https://generativelanguage.googleapis.com/v1beta/openai', + model: 'gemini-embedding-001', + dimensions: 1536, + headers: (key) => ({ + Authorization: `Bearer ${key}`, + 'Content-Type': 'application/json', + }), +}; + +/** + * Build a custom provider from LAT_LLM_ENDPOINT + optional LAT_LLM_MODEL. + * The endpoint must be OpenAI-compatible (POST /embeddings). + */ +export function customProvider( + endpoint: string, + model?: string, +): EmbeddingProvider { + return { + name: 'custom', + apiBase: endpoint.replace(/\/+$/, ''), + model: model || 'text-embedding-3-small', + dimensions: 1536, + headers: (key) => ({ + Authorization: `Bearer ${key}`, + 'Content-Type': 'application/json', + }), + }; +} + export function detectProvider(key: string): EmbeddingProvider { + // Custom endpoint takes highest priority + const endpoint = process.env.LAT_LLM_ENDPOINT; + if (endpoint) { + return customProvider(endpoint, process.env.LAT_LLM_MODEL); + } + if (key.startsWith('REPLAY_LAT_LLM_KEY::')) { const replayUrl = key.slice('REPLAY_LAT_LLM_KEY::'.length); return { @@ -44,9 +81,11 @@ export function detectProvider(key: string): EmbeddingProvider { "Anthropic doesn't offer an embedding model. Set LAT_LLM_KEY to an OpenAI (sk-...) or Vercel AI Gateway (vck_...) key.", ); } + if (key.startsWith('AIza')) return gemini; if (key.startsWith('vck_')) return vercel; if (key.startsWith('sk-')) return openai; throw new Error( - `Unrecognized LAT_LLM_KEY prefix. Supported: OpenAI (sk-...), Vercel AI Gateway (vck_...).`, + `Unrecognized LAT_LLM_KEY prefix. Supported: OpenAI (sk-...), Vercel AI Gateway (vck_...), Gemini (AIza...). ` + + `Or set LAT_LLM_ENDPOINT for any OpenAI-compatible server.`, ); } diff --git a/tests/search.test.ts b/tests/search.test.ts index bc284dd..0b99b2c 100644 --- a/tests/search.test.ts +++ b/tests/search.test.ts @@ -4,6 +4,7 @@ import { join } from 'node:path'; import { tmpdir } from 'node:os'; import { detectProvider, + customProvider, type EmbeddingProvider, } from '../src/search/provider.js'; import { openDb, ensureSchema, closeDb } from '../src/search/db.js'; @@ -34,6 +35,55 @@ describe('detectProvider', () => { it('rejects unknown key', () => { expect(() => detectProvider('xyz_abc123')).toThrow(/Unrecognized/); }); + + it('detects Gemini key', () => { + const p = detectProvider('AIzaSyExampleKey123'); + expect(p.name).toBe('gemini'); + expect(p.apiBase).toContain('generativelanguage.googleapis.com'); + }); + + it('uses LAT_LLM_ENDPOINT when set', () => { + const prev = process.env.LAT_LLM_ENDPOINT; + const prevModel = process.env.LAT_LLM_MODEL; + try { + process.env.LAT_LLM_ENDPOINT = 'http://localhost:11434/v1'; + process.env.LAT_LLM_MODEL = 'nomic-embed-text'; + const p = detectProvider('sk-abc123'); // key prefix ignored when endpoint set + expect(p.name).toBe('custom'); + expect(p.apiBase).toBe('http://localhost:11434/v1'); + expect(p.model).toBe('nomic-embed-text'); + } finally { + if (prev === undefined) delete process.env.LAT_LLM_ENDPOINT; + else process.env.LAT_LLM_ENDPOINT = prev; + if (prevModel === undefined) delete process.env.LAT_LLM_MODEL; + else process.env.LAT_LLM_MODEL = prevModel; + } + }); + + it('LAT_LLM_ENDPOINT strips trailing slashes', () => { + const prev = process.env.LAT_LLM_ENDPOINT; + try { + process.env.LAT_LLM_ENDPOINT = 'http://localhost:8080/v1/'; + const p = detectProvider('sk-abc123'); + expect(p.apiBase).toBe('http://localhost:8080/v1'); + } finally { + if (prev === undefined) delete process.env.LAT_LLM_ENDPOINT; + else process.env.LAT_LLM_ENDPOINT = prev; + } + }); +}); + +describe('customProvider', () => { + it('builds provider with custom model', () => { + const p = customProvider('http://localhost:11434/v1', 'nomic-embed-text'); + expect(p.name).toBe('custom'); + expect(p.model).toBe('nomic-embed-text'); + }); + + it('defaults model when not specified', () => { + const p = customProvider('http://localhost:11434/v1'); + expect(p.model).toBe('text-embedding-3-small'); + }); }); // --- RAG functional tests ---