diff --git a/README.md b/README.md index a7711e3..152c9a1 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,47 @@ Semantic search (`lat search`) requires an OpenAI (`sk-...`) or Vercel AI Gatewa 3. `LAT_LLM_KEY_HELPER` env var — shell command that prints the key (10s timeout) 4. Config file — saved by `lat init`. Run `lat config` to see its location. +### Other providers + +You can point `lat search` at any embedding provider, hosted or local, that exposes an OpenAI-compatible `/v1/embeddings` endpoint. + +Example for [OpenRouter](https://openrouter.ai) configured via environment variables: + +```bash +LAT_LLM_BASE=https://openrouter.ai/api/v1 +LAT_LLM_MODEL=qwen/qwen3-embedding-8b +LAT_LLM_DIMENSIONS=4096 # must match model output dimensions +LAT_LLM_KEY=sk-or-v1-xyz +``` + +Or via the config file (run `lat config` to see its location): + +```json +{ + "llm_base": "https://openrouter.ai/api/v1", + "llm_model": "qwen/qwen3-embedding-8b", + "llm_dimensions": 4096 +} +``` +Environment variables take precedence over config file values. + +### Local models + +For example, using `llama-server` from [llama.cpp](https://github.com/ggerganov/llama.cpp), start a local server with an embedding model: + +```bash +llama-server -hf Qwen/Qwen3-Embedding-0.6B-GGUF --embedding --gpu-layers 99 +``` + +Then point `lat` at the local server: + +```bash +export LAT_LLM_BASE=http://localhost:8080/v1 +export LAT_LLM_DIMENSIONS=1024 +``` + +When `LAT_LLM_BASE` is set, the API key is optional. If your server doesn't require authentication, you can omit `LAT_LLM_KEY` entirely. + ## Development Requires Node.js 22+ and pnpm. 
diff --git a/src/cli/hook.ts b/src/cli/hook.ts index 4e77653..57772aa 100644 --- a/src/cli/hook.ts +++ b/src/cli/hook.ts @@ -5,7 +5,8 @@ import { plainStyler, type CmdContext } from '../context.js'; import { expandPrompt } from './expand.js'; import { runSearch } from './search.js'; import { getSection, formatSectionOutput } from './section.js'; -import { getLlmKey } from '../config.js'; +import { getLlmKey, readConfig } from '../config.js'; +import { detectProvider } from '../search/provider.js'; import { checkMd, checkCodeRefs, checkIndex, checkSections } from './check.js'; import { SOURCE_EXTENSIONS } from '../source-parser.js'; @@ -62,15 +63,16 @@ async function searchAndExpand( ctx: CmdContext, userPrompt: string, ): Promise { + const config = readConfig(); let key: string | undefined; try { key = getLlmKey(); } catch { return null; } - if (!key) return null; + const provider = detectProvider(key, config); - const result = await runSearch(ctx.latDir, userPrompt, key, 5); + const result = await runSearch(ctx.latDir, userPrompt, provider, key, 5); if (result.matches.length === 0) return null; const parts: string[] = [ diff --git a/src/cli/init.ts b/src/cli/init.ts index 9490cdc..34e3b71 100644 --- a/src/cli/init.ts +++ b/src/cli/init.ts @@ -1168,6 +1168,11 @@ async function setupLlmKey( styleText('yellow', ' Unrecognized key prefix.') + ' Expected sk-... (OpenAI) or vck_... (Vercel AI Gateway).', ); + console.log( + ' For a custom endpoint (e.g. 
local llama-server), set ' + + styleText('cyan', 'LAT_LLM_BASE') + + ' instead.', + ); console.log(' Saving anyway — you can update it later.'); } diff --git a/src/cli/search.ts b/src/cli/search.ts index 46eb1e6..4466f02 100644 --- a/src/cli/search.ts +++ b/src/cli/search.ts @@ -1,6 +1,6 @@ import type { CmdContext, CmdResult, Styler } from '../context.js'; import { openDb, ensureSchema, closeDb } from '../search/db.js'; -import { detectProvider } from '../search/provider.js'; +import type { EmbeddingProvider } from '../search/provider.js'; import { indexSections, type IndexStats } from '../search/index.js'; import { searchSections } from '../search/search.js'; import { @@ -24,14 +24,11 @@ export type IndexProgress = { async function withDb( latDir: string, - key: string, + provider: EmbeddingProvider, + key: string | undefined, progress: IndexProgress | undefined, - fn: ( - db: Awaited>, - provider: ReturnType, - ) => Promise, + fn: (db: Awaited>) => Promise, ): Promise { - const provider = detectProvider(key); const db = openDb(latDir); try { @@ -44,7 +41,7 @@ async function withDb( const stats = await indexSections(latDir, db, provider, key); progress?.afterIndex?.(stats, isEmpty); - return await fn(db, provider); + return await fn(db); } finally { await closeDb(db); } @@ -57,11 +54,12 @@ async function withDb( export async function runSearch( latDir: string, query: string, - key: string, + provider: EmbeddingProvider, + key: string | undefined, limit: number, progress?: IndexProgress, ): Promise { - return withDb(latDir, key, progress, async (db, provider) => { + return withDb(latDir, provider, key, progress, async (db) => { const results = await searchSections(db, query, provider, key, limit); if (results.length === 0) { return { query, matches: [] }; @@ -85,10 +83,11 @@ export async function runSearch( */ export async function runIndex( latDir: string, - key: string, + provider: EmbeddingProvider, + key: string | undefined, progress?: IndexProgress, ): 
Promise { - await withDb(latDir, key, progress, async () => {}); + await withDb(latDir, provider, key, progress, async () => {}); } export function cliProgress(reindex: boolean, s: Styler): IndexProgress { @@ -123,18 +122,25 @@ export async function searchCommand( opts: { limit: number; reindex?: boolean }, progress?: IndexProgress, ): Promise { - const { getLlmKey, getConfigPath } = await import('../config.js'); + const { getLlmKey, readConfig, getConfigPath } = await import('../config.js'); + const { detectProvider } = await import('../search/provider.js'); + + const config = readConfig(); let key: string | undefined; try { key = getLlmKey(); } catch (err) { return { output: (err as Error).message, isError: true }; } - if (!key) { + + let provider: Awaited>; + try { + provider = detectProvider(key, config); + } catch (err) { const s = ctx.styler; return { output: - s.red('No API key configured.') + + s.red((err as Error).message) + ' Provide a key via LAT_LLM_KEY, LAT_LLM_KEY_FILE, LAT_LLM_KEY_HELPER, or run ' + s.cyan('lat init') + (ctx.mode === 'cli' @@ -146,11 +152,18 @@ export async function searchCommand( } if (!query) { - await runIndex(ctx.latDir, key, progress); + await runIndex(ctx.latDir, provider, key, progress); return { output: '' }; } - const result = await runSearch(ctx.latDir, query, key, opts.limit, progress); + const result = await runSearch( + ctx.latDir, + query, + provider, + key, + opts.limit, + progress, + ); if (result.matches.length === 0) { return { output: 'No results found.' 
}; diff --git a/src/config.ts b/src/config.ts index ddf04e0..3abcb51 100644 --- a/src/config.ts +++ b/src/config.ts @@ -17,6 +17,9 @@ export function getConfigPath(): string { export type LatConfig = { llm_key?: string; + llm_base?: string; + llm_model?: string; + llm_dimensions?: number; }; export function readConfig(): LatConfig { diff --git a/src/search/db.ts b/src/search/db.ts index d92fb7b..c5c6357 100644 --- a/src/search/db.ts +++ b/src/search/db.ts @@ -17,6 +17,21 @@ export async function ensureSchema( db: Client, dimensions: number, ): Promise { + const embeddingType = `F32_BLOB(${dimensions})`; + + // Detect dimension change from the column type in the existing schema + const schema = await db.execute( + `SELECT sql FROM sqlite_master WHERE type='table' AND name='sections'`, + ); + if ( + schema.rows.length > 0 && + !String(schema.rows[0].sql).includes(embeddingType) + ) { + await db.execute('DROP INDEX IF EXISTS sections_vec_idx'); + await db.execute('DROP TABLE IF EXISTS sections'); + process.stderr.write(`Embedding type changed, re-indexing...\n`); + } + await db.execute( `CREATE TABLE IF NOT EXISTS sections ( id TEXT PRIMARY KEY, @@ -24,7 +39,7 @@ export async function ensureSchema( heading TEXT NOT NULL, content TEXT NOT NULL, content_hash TEXT NOT NULL, - embedding F32_BLOB(${dimensions}), + embedding ${embeddingType}, updated_at INTEGER NOT NULL )`, ); diff --git a/src/search/embeddings.ts b/src/search/embeddings.ts index daa2f03..73c6091 100644 --- a/src/search/embeddings.ts +++ b/src/search/embeddings.ts @@ -5,7 +5,7 @@ const MAX_BATCH = 2048; export async function embed( texts: string[], provider: EmbeddingProvider, - key: string, + key?: string, ): Promise { const results: number[][] = []; diff --git a/src/search/index.ts b/src/search/index.ts index 70076cf..6fa8ff1 100644 --- a/src/search/index.ts +++ b/src/search/index.ts @@ -31,7 +31,7 @@ export async function indexSections( latDir: string, db: Client, provider: EmbeddingProvider, - key: 
string, + key?: string, ): Promise { const projectRoot = dirname(latDir); const allSections = await loadAllSections(latDir); diff --git a/src/search/provider.ts b/src/search/provider.ts index 16ee2b6..9ed14f3 100644 --- a/src/search/provider.ts +++ b/src/search/provider.ts @@ -1,9 +1,11 @@ +import type { LatConfig } from '../config.js'; + export type EmbeddingProvider = { name: string; apiBase: string; model: string; dimensions: number; - headers: (key: string) => Record; + headers: (key?: string) => Record; }; const openai: EmbeddingProvider = { @@ -28,8 +30,32 @@ const vercel: EmbeddingProvider = { }), }; -export function detectProvider(key: string): EmbeddingProvider { - if (key.startsWith('REPLAY_LAT_LLM_KEY::')) { +function customProvider(config: LatConfig): EmbeddingProvider | null { + const base = process.env.LAT_LLM_BASE ?? config.llm_base; + if (!base) return null; + return { + name: 'custom', + apiBase: base, + model: process.env.LAT_LLM_MODEL ?? config.llm_model ?? 'default', + dimensions: + process.env.LAT_LLM_DIMENSIONS != null + ? Number(process.env.LAT_LLM_DIMENSIONS) + : (config.llm_dimensions ?? 1536), + headers: (k) => { + const h: Record = { + 'Content-Type': 'application/json', + }; + if (k) h['Authorization'] = `Bearer ${k}`; + return h; + }, + }; +} + +export function detectProvider( + key: string | undefined, + config: LatConfig = {}, +): EmbeddingProvider { + if (key?.startsWith('REPLAY_LAT_LLM_KEY::')) { const replayUrl = key.slice('REPLAY_LAT_LLM_KEY::'.length); return { name: 'replay', @@ -39,6 +65,13 @@ export function detectProvider(key: string): EmbeddingProvider { headers: () => ({ 'Content-Type': 'application/json' }), }; } + + const custom = customProvider(config); + if (custom) return custom; + + if (!key) { + throw new Error('No API key configured.'); + } if (key.startsWith('sk-ant-')) { throw new Error( "Anthropic doesn't offer an embedding model. Set LAT_LLM_KEY to an OpenAI (sk-...) or Vercel AI Gateway (vck_...) 
key.", @@ -47,6 +80,6 @@ export function detectProvider(key: string): EmbeddingProvider { if (key.startsWith('vck_')) return vercel; if (key.startsWith('sk-')) return openai; throw new Error( - `Unrecognized LAT_LLM_KEY prefix. Supported: OpenAI (sk-...), Vercel AI Gateway (vck_...).`, + `Unrecognized LAT_LLM_KEY prefix. Supported: OpenAI (sk-...), Vercel AI Gateway (vck_...), or set LAT_LLM_BASE for a custom endpoint.`, ); } diff --git a/src/search/search.ts b/src/search/search.ts index bb474cf..11ba878 100644 --- a/src/search/search.ts +++ b/src/search/search.ts @@ -13,7 +13,7 @@ export async function searchSections( db: Client, query: string, provider: EmbeddingProvider, - key: string, + key: string | undefined, limit = 5, ): Promise { const [queryVec] = await embed([query], provider, key); diff --git a/tests/search.test.ts b/tests/search.test.ts index bc284dd..ca8dbe3 100644 --- a/tests/search.test.ts +++ b/tests/search.test.ts @@ -34,6 +34,39 @@ describe('detectProvider', () => { it('rejects unknown key', () => { expect(() => detectProvider('xyz_abc123')).toThrow(/Unrecognized/); }); + + it('uses custom provider when LAT_LLM_BASE is set', () => { + process.env.LAT_LLM_BASE = 'http://localhost:8080/v1'; + process.env.LAT_LLM_MODEL = 'qwen3-embedding'; + process.env.LAT_LLM_DIMENSIONS = '1024'; + try { + const p = detectProvider('sk-abc123'); + expect(p).toMatchObject({ + name: 'custom', + apiBase: 'http://localhost:8080/v1', + model: 'qwen3-embedding', + dimensions: 1024, + }); + } finally { + delete process.env.LAT_LLM_BASE; + delete process.env.LAT_LLM_MODEL; + delete process.env.LAT_LLM_DIMENSIONS; + } + }); + + it('custom provider omits Authorization when key is undefined', () => { + process.env.LAT_LLM_BASE = 'http://localhost:8080/v1'; + process.env.LAT_LLM_DIMENSIONS = '1024'; + try { + const p = detectProvider(undefined); + expect(p.headers()).not.toHaveProperty('Authorization'); + expect(p.headers('sk-real')).toHaveProperty('Authorization'); + } 
finally { + delete process.env.LAT_LLM_BASE; + delete process.env.LAT_LLM_DIMENSIONS; + } + }); + }); // --- RAG functional tests ---