From a5379dc2781427796f6790300e2b14ecb042a487 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Thu, 14 May 2026 17:17:48 +0200 Subject: [PATCH 1/6] ignore debug scripts --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 70daa8d..6acb725 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ Thumbs.db # Debug __debug_bin* +/debug/ # Python __pycache__/ From 6c5ec43f2dde73d536004fc1c2d704bf9dd5e04a Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Thu, 14 May 2026 17:42:37 +0200 Subject: [PATCH 2/6] config: global llm config block merged into repo-local config ~/.config/gortex/config.yaml gains an `llm:` block. MergeLLMInto fills zero fields of the repo-local llm config from the global one (local non-zero values win, including an explicit per-repo override of an inherited model path), with leading `~/` in model paths expanded against $HOME. The daemon and `gortex mcp` startup paths both load and merge the global config before SetupLLM; env vars still override last. --- cmd/gortex/daemon_state.go | 15 ++-- cmd/gortex/mcp.go | 8 ++- internal/config/global.go | 59 ++++++++++++++++ internal/config/global_llm_test.go | 110 +++++++++++++++++++++++++++++ 4 files changed, 182 insertions(+), 10 deletions(-) create mode 100644 internal/config/global_llm_test.go diff --git a/cmd/gortex/daemon_state.go b/cmd/gortex/daemon_state.go index 5fa0047..b6cb3c3 100644 --- a/cmd/gortex/daemon_state.go +++ b/cmd/gortex/daemon_state.go @@ -302,13 +302,14 @@ func buildDaemonState(logger *zap.Logger) (*daemonState, error) { logger.Warn("daemon: savings persistence disabled", zap.Error(err)) } - // In-process LLM service (opt-in via `.gortex.yaml` `llm.model:` or - // GORTEX_LLM_MODEL env var). Builds and attaches an in-process - // backend wired to this engine + contract registry, then registers - // the `ask` MCP tool. No-op when cfg.LLM is empty after env-merge, - // or when gortex was built without `-tags llama` (stub service + - // stub registerLLMTools). - srv.SetupLLM(cfg.LLM) + // In-process LLM service (opt-in via `.gortex.yaml` `llm.model:`, + // `~/.config/gortex/config.yaml::llm:`, or GORTEX_LLM_MODEL env + // var). Repo-local config wins per non-zero field; global fills + // the rest. Env overrides land last inside SetupLLM via MergeEnv. + // No-op when the merged config has no model, or when gortex was + // built without `-tags llama` (stub service + stub registerLLMTools). + gc, _ := config.LoadGlobal() + srv.SetupLLM(gc.MergeLLMInto(cfg.LLM)) // MultiWatcher is created in warmupDaemonState after tracked repos // have been re-indexed — NewMultiWatcher needs mi.AllMetadata() to be diff --git a/cmd/gortex/mcp.go b/cmd/gortex/mcp.go index 73943b7..f77b114 100644 --- a/cmd/gortex/mcp.go +++ b/cmd/gortex/mcp.go @@ -345,9 +345,11 @@ func runMCP(cmd *cobra.Command, args []string) error { fmt.Fprintf(os.Stderr, "[gortex] savings persistence disabled: %v\n", err) } - // In-process LLM service — same wiring as the daemon path. No-op - // when cfg.LLM is empty or gortex was built without `-tags llama`. - srv.SetupLLM(cfg.LLM) + // In-process LLM service — same wiring as the daemon path: repo + // config wins per non-zero field, global ~/.config/gortex/config.yaml + // fills the rest, env vars override last inside SetupLLM. 
+ gc, _ := config.LoadGlobal() + srv.SetupLLM(gc.MergeLLMInto(cfg.LLM)) fmt.Fprintf(os.Stderr, "[gortex] MCP server ready (transport: %s)\n", mcpTransport) diff --git a/internal/config/global.go b/internal/config/global.go index cb68fbd..5e88e03 100644 --- a/internal/config/global.go +++ b/internal/config/global.go @@ -9,6 +9,8 @@ import ( "sync" "gopkg.in/yaml.v3" + + "github.com/zzet/gortex/internal/llm" ) var ( @@ -54,10 +56,67 @@ type GlobalConfig struct { // baseline and below per-RepoEntry / workspace lists. Exclude []string `mapstructure:"exclude" yaml:"exclude,omitempty"` + // LLM is the user-level local-LLM service config (`llm.model:` etc.). + // Merged into the repo-local Config.LLM at daemon startup via + // MergeLLMInto — local non-zero fields win, global fills the rest. + // Lets users keep model paths and tuning in one place across repos + // without duplicating an `llm:` block in every `.gortex.yaml`. + LLM llm.Config `mapstructure:"llm" yaml:"llm,omitempty"` + // configPath stores the file path used for Save(). Set by LoadGlobal or SetConfigPath. configPath string `yaml:"-"` } +// MergeLLMInto returns local with any zero fields filled from gc.LLM. +// Local non-zero values always win — including an explicit per-repo +// override of an inherited global model path. Safe to call on a nil +// receiver (returns local unchanged), so daemon startup paths don't +// need separate nil-checks for the global config. +func (gc *GlobalConfig) MergeLLMInto(local llm.Config) llm.Config { + if gc == nil { + return local + } + g := gc.LLM + if local.Model == "" { + local.Model = g.Model + } + if local.Ctx == 0 { + local.Ctx = g.Ctx + } + if local.GPULayers == 0 { + local.GPULayers = g.GPULayers + } + if local.MaxSteps == 0 { + local.MaxSteps = g.MaxSteps + } + if local.Template == "" { + local.Template = g.Template + } + local.Model = expandHome(local.Model) + return local +} + +// expandHome resolves a leading `~/` in a path against $HOME so users +// can write portable model paths in their global config. No-op when +// the path is empty, absolute without `~`, or `~` is not the first +// character. Returns the input unchanged on any os.UserHomeDir error. +func expandHome(p string) string { + if p == "" || !strings.HasPrefix(p, "~") { + return p + } + home, err := os.UserHomeDir() + if err != nil { + return p + } + if p == "~" { + return home + } + if strings.HasPrefix(p, "~/") { + return filepath.Join(home, p[2:]) + } + return p +} + // DefaultGlobalConfigPath returns the default path: ~/.config/gortex/config.yaml. 
// // Resolved fresh on every call so HOME changes (notably t.Setenv in tests) diff --git a/internal/config/global_llm_test.go b/internal/config/global_llm_test.go new file mode 100644 index 0000000..4798c2b --- /dev/null +++ b/internal/config/global_llm_test.go @@ -0,0 +1,110 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zzet/gortex/internal/llm" +) + +func TestLoadGlobal_LLMSectionRoundTrip(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(cfgPath, []byte(`active_project: "" +repos: [] +llm: + model: /opt/models/qwen.gguf + template: chatml + ctx: 4096 + max_steps: 12 + gpu_layers: 999 +`), 0o644)) + + gc, err := LoadGlobal(cfgPath) + require.NoError(t, err) + require.NotNil(t, gc) + assert.Equal(t, "/opt/models/qwen.gguf", gc.LLM.Model) + assert.Equal(t, "chatml", gc.LLM.Template) + assert.Equal(t, 4096, gc.LLM.Ctx) + assert.Equal(t, 12, gc.LLM.MaxSteps) + assert.Equal(t, 999, gc.LLM.GPULayers) +} + +func TestGlobalConfig_MergeLLMInto_FillsZeroFields(t *testing.T) { + gc := &GlobalConfig{LLM: llm.Config{ + Model: "/global/qwen.gguf", + Template: "chatml", + Ctx: 4096, + MaxSteps: 16, + GPULayers: 999, + }} + + got := gc.MergeLLMInto(llm.Config{}) + assert.Equal(t, "/global/qwen.gguf", got.Model) + assert.Equal(t, "chatml", got.Template) + assert.Equal(t, 4096, got.Ctx) + assert.Equal(t, 16, got.MaxSteps) + assert.Equal(t, 999, got.GPULayers) +} + +func TestGlobalConfig_MergeLLMInto_LocalWinsPerField(t *testing.T) { + gc := &GlobalConfig{LLM: llm.Config{ + Model: "/global/qwen.gguf", + Template: "chatml", + Ctx: 4096, + MaxSteps: 16, + }} + + got := gc.MergeLLMInto(llm.Config{ + Model: "/repo/override.gguf", // local wins + Ctx: 8192, // local wins + }) + assert.Equal(t, "/repo/override.gguf", got.Model) + assert.Equal(t, 8192, got.Ctx) + // Unset locals fall through to global. + assert.Equal(t, "chatml", got.Template) + assert.Equal(t, 16, got.MaxSteps) +} + +func TestGlobalConfig_MergeLLMInto_NilReceiver(t *testing.T) { + var gc *GlobalConfig // nil + local := llm.Config{Model: "/repo/x.gguf"} + got := gc.MergeLLMInto(local) + assert.Equal(t, "/repo/x.gguf", got.Model) +} + +func TestGlobalConfig_MergeLLMInto_ExpandsHomeInModelPath(t *testing.T) { + home, err := os.UserHomeDir() + require.NoError(t, err) + + gc := &GlobalConfig{LLM: llm.Config{Model: "~/models/qwen.gguf"}} + got := gc.MergeLLMInto(llm.Config{}) + assert.Equal(t, filepath.Join(home, "models/qwen.gguf"), got.Model) + + // Local override also gets expanded. 
+ got = gc.MergeLLMInto(llm.Config{Model: "~/repo-override.gguf"}) + assert.Equal(t, filepath.Join(home, "repo-override.gguf"), got.Model) +} + +func TestExpandHome(t *testing.T) { + home, err := os.UserHomeDir() + require.NoError(t, err) + + cases := []struct { + in, want string + }{ + {"", ""}, + {"/abs/path", "/abs/path"}, + {"relative/path", "relative/path"}, + {"~", home}, + {"~/models/foo.gguf", filepath.Join(home, "models/foo.gguf")}, + {"~weird", "~weird"}, // only `~/` form is expanded + } + for _, tc := range cases { + assert.Equal(t, tc.want, expandHome(tc.in), "in=%q", tc.in) + } +} From ffd9943413f5a615bada934347d5ac29b8ff4def Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Thu, 14 May 2026 17:42:37 +0200 Subject: [PATCH 3/6] llm, mcp: LLM-assisted search ranking via search_symbols `assist` arg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit search_symbols gains an `assist` arg (auto/on/off/deep). `auto` uses a cheap NL heuristic to skip identifier lookups; active modes run grammar- constrained query expansion + name/signature rerank, and `deep` adds a body-grounded verification pass that reads candidate bodies and callers and honestly drops irrelevant matches — an empty result is preserved as the load-bearing negative signal. The svc layer adds a pre-warmed assist context with its own mutex and KV cache so short assist calls don't head-of-line block a long `ask`, plus LRU caches for expand/rerank/verify. The stub build returns errServiceUnavailable for all three new methods. --- CLAUDE.md | 13 + internal/llm/assist.go | 70 +++ internal/llm/svc/assist.go | 559 +++++++++++++++++++++++ internal/llm/svc/assist_e2e_test.go | 385 ++++++++++++++++ internal/llm/svc/assist_test.go | 276 +++++++++++ internal/llm/svc/cache.go | 121 +++++ internal/llm/svc/service.go | 37 +- internal/llm/svc/service_stub.go | 18 + internal/mcp/tools_core.go | 70 ++- internal/mcp/tools_search_assist.go | 457 ++++++++++++++++++ internal/mcp/tools_search_assist_test.go | 279 +++++++++++ 11 files changed, 2280 insertions(+), 5 deletions(-) create mode 100644 internal/llm/assist.go create mode 100644 internal/llm/svc/assist.go create mode 100644 internal/llm/svc/assist_e2e_test.go create mode 100644 internal/llm/svc/assist_test.go create mode 100644 internal/llm/svc/cache.go create mode 100644 internal/mcp/tools_search_assist.go create mode 100644 internal/mcp/tools_search_assist_test.go diff --git a/CLAUDE.md b/CLAUDE.md index 29fa438..ede4bd6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -32,6 +32,19 @@ When the daemon is built with `-tags llama` and `llm.model` is set in `.gortex.y If `ask` isn't in `tools/list`, gortex was built without `-tags llama` or `llm.model` is unset. Fall through to direct tools. +### Optional: LLM-assisted search ranking (`search_symbols` `assist:` arg) + +When the same `-tags llama` build + `llm.model` is in place, `search_symbols` accepts an `assist` argument that engages the local model in the search pipeline. The default `auto` is sub-100 ms on identifier lookups; the active modes add latency but materially improve precision on natural-language queries. + +| `assist` value | Behaviour | Cost | +|----------------|-----------|------| +| `auto` (default) | NL heuristic decides per-query. Identifier-shaped queries (`Server.handleAsk`, `parseToolCall`) skip the LLM. NL queries (≥3 tokens with a stop word, or ≥4 plain-word tokens) trigger query expansion + name+sig rerank. | None for identifier lookups; +200–500 ms for NL. 
| +| `on` | Forces expansion + name+sig rerank regardless of shape. Use when you know the query is fuzzy. | +200–500 ms. | +| `off` | Pure BM25 + combo/frecency. No LLM. | None. | +| `deep` | `on` plus a body-grounded verification pass — reads each top candidate's body + callers and HONESTLY drops candidates whose code isn't about the query. May return zero results when nothing genuinely matches; that's the load-bearing honest-negative signal. | +1.5–4 s. Quality is **highly model-dependent**: Qwen2.5-Coder 3B is unreliable on disambiguation cases (e.g. "hash passwords" vs functions that hash other data); Qwen2.5-Coder 7B and above produce stable, useful results. Prefer 7B+ if you want to rely on `deep`. | + +The response gains an `assist` debug block when an active mode engaged: `terms` (expansion words), `primary_count` (raw BM25 hits on the original query), `merged_count` (after expansion union), `final_count` (after filter/rerank), plus `verify_kept_ids` / `verify_dropped` for `deep`. + ### Navigation and Reading | Instead of... | You MUST use... | diff --git a/internal/llm/assist.go b/internal/llm/assist.go new file mode 100644 index 0000000..d48d390 --- /dev/null +++ b/internal/llm/assist.go @@ -0,0 +1,70 @@ +package llm + +// RerankCandidate is one entry the caller asks the LLM to consider in +// Service.RerankSymbols. ID is opaque to the model — the model only +// sees it as an identifier string to echo back in the new order — so +// callers can use whatever stable handle their graph layer provides +// (typically graph.Node.ID). +type RerankCandidate struct { + ID string + Name string + Signature string + Path string +} + +// ExpandResult is the output of Service.ExpandQuery. Terms are +// additional identifier-style search terms the caller should OR with +// the original query before BM25. Original is the trimmed input. +// Cached reports whether the result came from the in-memory LRU. +type ExpandResult struct { + Original string + Terms []string + Cached bool +} + +// RerankResult is the output of Service.RerankSymbols. Order is a +// permutation of the candidate IDs from the input — IDs the model +// dropped are appended in their original input order so the caller +// never loses candidates. Cached reports whether the result came from +// the in-memory LRU. +type RerankResult struct { + Order []string + Cached bool +} + +// VerifyCandidate is one entry the caller asks the LLM to read + +// verify against the query in Service.VerifyRelevance. The prompt +// includes the function body — the model is meant to read what the +// code actually DOES, not infer relevance from the name alone. Body +// should be pre-truncated (a single noisy candidate can blow the +// assist context). Callers carry independent contextual signal that +// distinguishes "same operation on different data" cases — e.g. a +// hashing function called only from a diagnostic-publish path is +// almost certainly not password hashing. +type VerifyCandidate struct { + ID string + Name string + Signature string + Body string + Callers []CallerInfo +} + +// CallerInfo is a compact reference to one caller of a verify +// candidate. Name + Signature is usually enough to disambiguate +// "what kind of data flows into this function" without dragging in +// the full caller body. +type CallerInfo struct { + Name string + Signature string +} + +// VerifyResult is the output of Service.VerifyRelevance. 
Keep is the +// subset of input IDs whose body the model judged genuinely related +// to the query, in the model's preferred order. Empty is a valid and +// load-bearing result — the model is allowed to say "nothing here +// matches" and the caller should treat that as honest negative +// evidence rather than fall back to BM25. +type VerifyResult struct { + Keep []string + Cached bool +} diff --git a/internal/llm/svc/assist.go b/internal/llm/svc/assist.go new file mode 100644 index 0000000..c4e1b6d --- /dev/null +++ b/internal/llm/svc/assist.go @@ -0,0 +1,559 @@ +//go:build llama + +package svc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/zzet/gortex/internal/llm" + "github.com/zzet/gortex/internal/llm/agent" +) + +// assistCtxSize is the KV-cache window for the short-call assist +// context. Sized for the heaviest user — verify with body + callers +// at ~3.5K tokens for 10 candidates. Expansion and rerank use a +// fraction of this; the extra KV cache is cheap (a few hundred MB). +const assistCtxSize = 4096 + +// Token caps per call. Expansion emits at most a small JSON list; +// rerank emits at most one ID per candidate. Verify emits one ID per +// surviving candidate, so its cap is comparable to rerank. +const ( + expandMaxTokens = 192 + rerankMaxTokens = 512 + verifyMaxTokens = 512 +) + +// Grammar for {"terms":[, ...]}. Strings are arbitrary JSON +// strings — callers filter the output to whatever's actually useful. +const expandGrammar = `root ::= ws "{" ws "\"terms\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws +str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +ws ::= [ \t\n]* +` + +// Grammar for {"order":[, ...]}. Same shape as expand, +// different top-level key — kept as two constants so each call site +// skips a Sprintf on the hot path. +const rerankGrammar = `root ::= ws "{" ws "\"order\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws +str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +ws ::= [ \t\n]* +` + +// Grammar for {"keep":[, ...]}. The body-grounded verifier +// MUST be allowed to emit an empty array — that's the load-bearing +// "honest negative" signal — so the array body is fully optional. +const verifyGrammar = `root ::= ws "{" ws "\"keep\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws +str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +ws ::= [ \t\n]* +` + +const expandSystem = `You expand a code-search query into a small set of CONCRETE identifier-style terms a programmer would actually grep for. ` + + `Output strict JSON: {"terms":["","",...]}. ` + + `Include 2 to 5 terms. Each term MUST be a single word with no spaces and no punctuation other than underscores. ` + + ` +RULES: +1. Prefer DOMAIN-SPECIFIC terms over generic English. ` + + `GOOD examples: bcrypt, argon2, scrypt, sha256, hmac, jwt, oauth, pbkdf2, kdf, salt. ` + + `BAD examples (NEVER emit): function, library, algorithm, code, system, data, service, value, info, content, thing, stuff, name, general, common, logic, process, handler, flow, action, helper, util, utility. ` + + ` +2. Prefer terms that are likely SYMBOL names in a codebase (camelCase / snake_case / PascalCase fragments), library or protocol names, well-known acronyms. ` + + ` +3. Do NOT echo the original query words. ` + + ` +4. 
If the query has no obvious domain-specific neighbours, emit FEWER terms (or an empty array) — quality over quantity.` + +const rerankSystem = `You rerank code-search results by relevance to a natural-language task. ` + + `Given a query and a list of candidate symbols (id | name | optional signature), output strict JSON: {"order":["id1","id2",...]} ` + + `with the most relevant candidates first. ` + + `Use ONLY the provided ids verbatim. Do not invent ids. You may drop ids that are clearly unrelated.` + +const verifySystem = `You filter code-search candidates by reading their BODY, SIGNATURE, and CALLERS, and keeping every one whose code is genuinely about the user's query. ` + + `Each candidate is presented as: + + | | +body: + +callers: +- | +- ... +--- + +Output strict JSON: {"keep":["id1","id2",...]} listing EVERY id whose code is meaningfully related to the query, in your preferred order (most relevant first). + +RULES (follow exactly): +1. Evaluate EACH candidate INDEPENDENTLY. Multiple candidates can be valid matches — keep them all. +2. A name that contains a query word is not enough by itself — read what the code DOES. +3. Cross-reference the CALLERS and the SIGNATURE's parameter types against the query DOMAIN. If a function hashes data but is only called from a "publishDiagnostics" or "renderLog" path with a non-password parameter type, it is NOT about hashing passwords — DROP it. +4. Be GENEROUS, not restrictive: if a candidate's body AND callers AND signature are all plausibly about the query, KEEP it. The user wants signal, not a single "best" pick. +5. Drop a candidate when its body, signature, or callers reveal the operation is on the wrong KIND of data for the query. +6. Returning {"keep":[]} is valid ONLY when NO candidate is genuinely about the query. +7. Use ONLY the provided ids verbatim. Never invent or modify an id.` + +// ensureAssist lazily allocates the short-call context the first time +// an assist method is called. Safe to invoke before locking +// assistMu — the underlying sync.Once handles concurrent first calls. +// Subsequent callers MUST still take assistMu before touching +// assistCtx, since the context itself is single-stream. +func (s *Service) ensureAssist() error { + if err := s.ensureLoaded(); err != nil { + return err + } + s.assistOnce.Do(func() { + c, err := s.model.NewContext(assistCtxSize, 0) + if err != nil { + s.assistErr = fmt.Errorf("llm: assist context: %w", err) + return + } + s.assistCtx = c + }) + return s.assistErr +} + +// ExpandQuery turns a natural-language search query into a small set +// of related identifier-style terms via one grammar-constrained +// inference pass. Result is cached by query string. Empty / blank +// input returns an empty result without touching the model. +// +// The caller is expected to OR the returned terms with the original +// query and rerank by combined BM25 score. 
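+//
+// A minimal illustrative sketch (hypothetical query and output — the
+// actual terms depend on the model; assumes a loaded Service s):
+//
+//	res, err := s.ExpandQuery(ctx, "where do we hash passwords")
+//	// err == nil → res.Terms might look like {"bcrypt", "argon2", "scrypt"};
+//	// OR these with the original query and rerank by combined BM25 score.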
+func (s *Service) ExpandQuery(ctx context.Context, query string) (*llm.ExpandResult, error) { + _ = ctx + query = strings.TrimSpace(query) + if query == "" { + return &llm.ExpandResult{Original: query}, nil + } + + if cached, ok := s.expandCache.Get(query); ok { + return &llm.ExpandResult{Original: query, Terms: cached, Cached: true}, nil + } + if err := s.ensureAssist(); err != nil { + return nil, err + } + + tmpl, err := agent.TemplateByName(s.cfg.Template) + if err != nil { + return nil, err + } + prompt := buildAssistPrompt(tmpl, expandSystem, "Query: "+query) + + raw, err := s.runAssist(prompt, expandGrammar, expandMaxTokens) + if err != nil { + return nil, err + } + + terms := parseStringList(raw, "terms") + terms = dedupeFilter(terms, query) + // Even an empty result is worth caching — re-issuing the prompt + // won't change a model that consistently emits nothing useful. + s.expandCache.Set(query, terms) + return &llm.ExpandResult{Original: query, Terms: terms}, nil +} + +// RerankSymbols asks the model to reorder a candidate set by +// relevance to the query. IDs the model drops are appended at the +// tail in original input order so the caller never loses a candidate. +// Empty input returns an empty order without touching the model. +// +// Cache key includes the candidate ID set so two callers passing the +// same query against different candidate pools each get their own +// cache entry; ordering of input candidates does not affect the key. +func (s *Service) RerankSymbols(ctx context.Context, query string, cands []llm.RerankCandidate) (*llm.RerankResult, error) { + _ = ctx + query = strings.TrimSpace(query) + if query == "" || len(cands) == 0 { + return &llm.RerankResult{Order: candIDs(cands)}, nil + } + + key := rerankCacheKey(query, cands) + if cached, ok := s.rerankCache.Get(key); ok { + return &llm.RerankResult{Order: cached, Cached: true}, nil + } + if err := s.ensureAssist(); err != nil { + return nil, err + } + + tmpl, err := agent.TemplateByName(s.cfg.Template) + if err != nil { + return nil, err + } + user := buildRerankUser(query, cands) + prompt := buildAssistPrompt(tmpl, rerankSystem, user) + + raw, err := s.runAssist(prompt, rerankGrammar, rerankMaxTokens) + if err != nil { + // Surface the error but keep input order intact so the caller + // can still return *something* — search-assist must never + // degrade below baseline BM25 quality. + return &llm.RerankResult{Order: candIDs(cands)}, err + } + + rawOrder := parseStringList(raw, "order") + order := filterToInputAppend(rawOrder, cands) + s.rerankCache.Set(key, order) + return &llm.RerankResult{Order: order}, nil +} + +// VerifyRelevance reads each candidate's code body and returns only +// the IDs the model judges genuinely related to the query — an empty +// list means "no candidate's code actually does what was asked", +// which is a load-bearing honest-negative signal the caller should +// preserve rather than fall back to BM25 noise. +// +// Cache key includes (query, sorted IDs, body hash) so a re-indexed +// codebase doesn't return stale verifications. Empty input short- +// circuits without touching the model. +// +// On any inference or parse failure, returns the input order +// unchanged with the error — the caller should treat that as "could +// not verify" rather than "nothing matched". 
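+//
+// Illustrative sketch (hypothetical IDs; real keep sets are model-dependent):
+//
+//	res, err := s.VerifyRelevance(ctx, "where do we hash passwords", cands)
+//	// err == nil && len(res.Keep) == 0 → honest negative: surface "no
+//	// genuine match" instead of falling back to the raw BM25 list.
+//	// err != nil → res.Keep still holds every input ID ("could not verify").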
+func (s *Service) VerifyRelevance(ctx context.Context, query string, cands []llm.VerifyCandidate) (*llm.VerifyResult, error) { + _ = ctx + query = strings.TrimSpace(query) + if query == "" || len(cands) == 0 { + return &llm.VerifyResult{Keep: verifyIDs(cands)}, nil + } + + key := verifyCacheKey(query, cands) + if cached, ok := s.verifyCache.Get(key); ok { + return &llm.VerifyResult{Keep: cached, Cached: true}, nil + } + if err := s.ensureAssist(); err != nil { + return nil, err + } + + tmpl, err := agent.TemplateByName(s.cfg.Template) + if err != nil { + return nil, err + } + user := buildVerifyUser(query, cands) + prompt := buildAssistPrompt(tmpl, verifySystem, user) + + raw, err := s.runAssist(prompt, verifyGrammar, verifyMaxTokens) + if err != nil { + // On failure, surface the error and keep all input candidates + // — better to over-include than to silently drop them. + return &llm.VerifyResult{Keep: verifyIDs(cands)}, err + } + + rawKeep := parseStringList(raw, "keep") + keep := filterKeepToInput(rawKeep, cands) + s.verifyCache.Set(key, keep) + return &llm.VerifyResult{Keep: keep}, nil +} + +// runAssist is the shared inference primitive for the two assist +// methods. Holds assistMu, resets KV cache, installs the grammar, +// generates with the jsonComplete early-stop predicate, and returns +// the raw model output trimmed of surrounding whitespace. +func (s *Service) runAssist(prompt, grammar string, maxTokens int) (string, error) { + s.assistMu.Lock() + defer s.assistMu.Unlock() + + if s.assistCtx == nil { + return "", errors.New("llm: assist context not initialised") + } + + s.assistCtx.Reset() + if err := s.assistCtx.SetGrammar(grammar); err != nil { + return "", fmt.Errorf("llm: install assist grammar: %w", err) + } + + var buf strings.Builder + _, err := s.assistCtx.Generate(prompt, maxTokens, func(piece string) bool { + buf.WriteString(piece) + return !assistJSONComplete(buf.String()) + }) + if err != nil { + return "", err + } + return strings.TrimSpace(buf.String()), nil +} + +// buildAssistPrompt is the single-turn equivalent of agent.initialPrompt: +// no tool list, no AssistEnd round-trip — just System + User + AssistPrime. +func buildAssistPrompt(tmpl agent.ChatTemplate, system, user string) string { + return tmpl.BOS + tmpl.System(system) + tmpl.User(user) + tmpl.AssistPrime +} + +// buildVerifyUser formats the candidate list for the body-grounded +// verify prompt. Each candidate ships with its body and a compact +// callers block — the callers carry independent contextual signal +// that lets the model distinguish "same operation, different data" +// cases the body alone can't disambiguate. Bodies and signatures +// must be pre-truncated by the caller — this is a formatter, not +// the place to enforce length limits. 
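+//
+// Shape of the formatted block for one candidate (illustrative values
+// only, borrowed from the e2e hash-passwords scenario):
+//
+//	Query: where do we hash passwords
+//
+//	Candidates:
+//	synthetic.pkg.user.HashPassword | HashPassword | func HashPassword(plain string) ([]byte, error)
+//	body:
+//	func HashPassword(plain string) ([]byte, error) { ... }
+//	callers:
+//	- UserHandler.RegisterUser | func (h *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request)
+//	---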
+func buildVerifyUser(query string, cands []llm.VerifyCandidate) string { + var b strings.Builder + b.WriteString("Query: ") + b.WriteString(query) + b.WriteString("\n\nCandidates:\n") + for _, c := range cands { + b.WriteString(c.ID) + b.WriteString(" | ") + b.WriteString(c.Name) + if sig := strings.TrimSpace(c.Signature); sig != "" { + b.WriteString(" | ") + if len(sig) > 160 { + sig = sig[:160] + "…" + } + b.WriteString(sig) + } + b.WriteString("\nbody:\n") + if body := strings.TrimSpace(c.Body); body != "" { + b.WriteString(body) + if !strings.HasSuffix(body, "\n") { + b.WriteString("\n") + } + } else { + b.WriteString("(no body — signature-only)\n") + } + if len(c.Callers) > 0 { + b.WriteString("callers:\n") + for _, cl := range c.Callers { + b.WriteString("- ") + b.WriteString(cl.Name) + if sig := strings.TrimSpace(cl.Signature); sig != "" { + b.WriteString(" | ") + if len(sig) > 120 { + sig = sig[:120] + "…" + } + b.WriteString(sig) + } + b.WriteString("\n") + } + } else { + b.WriteString("callers: (none indexed)\n") + } + b.WriteString("---\n") + } + return b.String() +} + +// buildRerankUser formats the candidate list for the rerank prompt. +// One line per candidate: "id | name | signature?". Truncates very +// long signatures so a single noisy entry can't blow the context. +func buildRerankUser(query string, cands []llm.RerankCandidate) string { + var b strings.Builder + b.WriteString("Query: ") + b.WriteString(query) + b.WriteString("\nCandidates:\n") + for _, c := range cands { + b.WriteString("- ") + b.WriteString(c.ID) + b.WriteString(" | ") + b.WriteString(c.Name) + if sig := strings.TrimSpace(c.Signature); sig != "" { + b.WriteString(" | ") + if len(sig) > 120 { + sig = sig[:120] + "…" + } + b.WriteString(sig) + } + b.WriteString("\n") + } + return b.String() +} + +// assistJSONComplete is the same shape as agent.jsonComplete: stop +// generation as soon as the top-level JSON object closes and parses. +// Replicated rather than exported from package agent to keep that +// package's surface minimal. +func assistJSONComplete(s string) bool { + s = strings.TrimSpace(s) + if !strings.HasPrefix(s, "{") || !strings.HasSuffix(s, "}") { + return false + } + var v any + return json.Unmarshal([]byte(s), &v) == nil +} + +// parseStringList extracts a top-level JSON string array under the +// given key. Returns nil on any parse failure — the caller decides +// the fallback behaviour. +func parseStringList(raw, key string) []string { + if raw == "" { + return nil + } + m := map[string]json.RawMessage{} + if err := json.Unmarshal([]byte(raw), &m); err != nil { + return nil + } + v, ok := m[key] + if !ok { + return nil + } + var out []string + if err := json.Unmarshal(v, &out); err != nil { + return nil + } + if len(out) == 0 { + return nil + } + return out +} + +// expansionStoplist is the conservative list of generic English nouns +// that the BM25 layer matches against thousands of unrelated symbols +// (e.g. `function`, `data`, `library`). These rarely carry useful +// search signal on their own and almost always inflate the candidate +// pool with noise. Members were chosen by inspecting real expansion +// outputs from Qwen2.5-Coder 3B against the gortex corpus — words +// that produced no relevant additional hits but many irrelevant ones. +// +// Borderline / domain-bearing words like `encryption`, `algorithm`, +// `security`, `key` are deliberately NOT here: they can be load-bearing +// in some codebases (a crypto library is a different story than a code +// intelligence tool). 
Keep this list short — over-filtering throws +// away the only signal expansion has to offer. +var expansionStoplist = map[string]bool{ + "function": true, "functions": true, "method": true, "methods": true, + "library": true, "libraries": true, + "module": true, "modules": true, "package": true, "packages": true, + "system": true, "systems": true, + "service": true, "services": true, + "code": true, "codes": true, "source": true, + "data": true, "datum": true, + "value": true, "values": true, + "object": true, "objects": true, "item": true, "items": true, + "thing": true, "things": true, + "info": true, "information": true, + "content": true, "contents": true, + "stuff": true, + "general": true, "common": true, "basic": true, "simple": true, "main": true, + "text": true, + // Generic verbs/nouns that slip through with NL queries — observed + // in the wild: "where is the rerank logic for search results" pulled + // in "logic" as an expansion term, which broadens BM25 enormously + // against any *_logic or logical_* identifier. + "logic": true, "logical": true, + "process": true, "processing": true, + "handle": true, "handler": true, "handling": true, + "flow": true, "flows": true, + "action": true, "actions": true, + "helper": true, "helpers": true, + "util": true, "utils": true, "utility": true, "utilities": true, +} + +// minExpansionTermLen rejects terms shorter than this. Sub-3 char +// fragments (`do`, `is`, `id`) generate huge BM25 hit lists and +// almost never carry useful signal. The threshold is conservative — +// short identifiers like `js`, `db`, `ui` get through. +const minExpansionTermLen = 3 + +// dedupeFilter trims, lowercases for comparison, and drops terms that +// are empty, duplicates, the original query, in expansionStoplist, or +// shorter than minExpansionTermLen. Preserves order of the surviving +// terms. The cap at maxExpansionTerms keeps the merged candidate pool +// bounded even when the model ignores the "2 to 5" prompt instruction. +func dedupeFilter(terms []string, query string) []string { + queryLower := strings.ToLower(strings.TrimSpace(query)) + seen := map[string]bool{queryLower: true} + out := make([]string, 0, len(terms)) + for _, t := range terms { + t = strings.TrimSpace(t) + if t == "" { + continue + } + k := strings.ToLower(t) + if seen[k] || expansionStoplist[k] { + continue + } + if len(t) < minExpansionTermLen { + continue + } + seen[k] = true + out = append(out, t) + if len(out) >= maxExpansionTerms { + break + } + } + return out +} + +// maxExpansionTerms caps the per-call expansion regardless of model +// output. Each extra term adds a BM25 sweep + candidate-pool growth, +// so trimming aggressively saves both latency and rerank prompt size. +const maxExpansionTerms = 5 + +// candIDs extracts just the ID slice from a candidate list, +// preserving order. Returned for fallback paths so the caller still +// gets a valid (if unhelpful) ordering. +func candIDs(cands []llm.RerankCandidate) []string { + if len(cands) == 0 { + return nil + } + out := make([]string, len(cands)) + for i, c := range cands { + out[i] = c.ID + } + return out +} + +// verifyIDs is the VerifyCandidate equivalent of candIDs — used on +// fallback paths where we want to preserve every input ID rather +// than drop them silently. 
+func verifyIDs(cands []llm.VerifyCandidate) []string { + if len(cands) == 0 { + return nil + } + out := make([]string, len(cands)) + for i, c := range cands { + out[i] = c.ID + } + return out +} + +// filterKeepToInput is the VerifyResult equivalent of +// filterToInputAppend but with one critical difference: dropped IDs +// are NOT appended at the tail. An empty result IS the load-bearing +// honest-negative signal, so callers must see exactly what the +// model decided to keep. +// +// Hallucinated and duplicate IDs are still filtered defensively. +func filterKeepToInput(modelKeep []string, cands []llm.VerifyCandidate) []string { + valid := make(map[string]bool, len(cands)) + for _, c := range cands { + valid[c.ID] = true + } + used := make(map[string]bool, len(cands)) + out := make([]string, 0, len(modelKeep)) + for _, id := range modelKeep { + if !valid[id] || used[id] { + continue + } + used[id] = true + out = append(out, id) + } + return out +} + +// filterToInputAppend builds the final rerank order: every model ID +// that matches an input candidate, in model-supplied order, then any +// remaining input IDs in their original order. This makes the result +// a stable permutation of the input set even when the model drops or +// hallucinates entries. +func filterToInputAppend(modelOrder []string, cands []llm.RerankCandidate) []string { + valid := make(map[string]bool, len(cands)) + for _, c := range cands { + valid[c.ID] = true + } + used := make(map[string]bool, len(cands)) + out := make([]string, 0, len(cands)) + for _, id := range modelOrder { + if !valid[id] || used[id] { + continue + } + used[id] = true + out = append(out, id) + } + for _, c := range cands { + if !used[c.ID] { + out = append(out, c.ID) + } + } + return out +} diff --git a/internal/llm/svc/assist_e2e_test.go b/internal/llm/svc/assist_e2e_test.go new file mode 100644 index 0000000..3dd7359 --- /dev/null +++ b/internal/llm/svc/assist_e2e_test.go @@ -0,0 +1,385 @@ +//go:build llama + +package svc + +import ( + "context" + "os" + "path/filepath" + "sort" + "strings" + "testing" + "time" + + "github.com/zzet/gortex/internal/llm" +) + +// TestE2E_AssistAgainstRealModel exercises ExpandQuery and +// RerankSymbols against the model configured at GORTEX_LLM_MODEL (or +// ~/models/qwen2.5-coder-3b-instruct-q4_k_m.gguf if unset). Skipped +// when the file isn't present, so the test is safe to commit and is +// only opted into when a developer has a model on disk. +// +// What it asserts: +// - ExpandQuery on an NL query returns >=1 non-empty term, none of +// which is the original query verbatim. +// - RerankSymbols echoes back a permutation of the candidate IDs. +// - The dedicated assist context survives back-to-back calls (no +// KV bleed) and the cache hits the second time. +// +// What it does NOT assert: specific term content or ordering — those +// depend on the model. 
+func TestE2E_AssistAgainstRealModel(t *testing.T) { + modelPath := resolveModelPath(t) + if modelPath == "" { + t.Skip("no model configured (set GORTEX_LLM_MODEL or place qwen2.5-coder-3b-instruct-q4_k_m.gguf in ~/models)") + } + + cfg := llm.Config{ + Model: modelPath, + Template: "chatml", + Ctx: 4096, + MaxSteps: 16, + }.ApplyDefaults() + + svcInst := NewService(cfg, llm.MockBackend{}) + t.Cleanup(func() { _ = svcInst.Close() }) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + t.Cleanup(cancel) + + t.Run("ExpandQuery", func(t *testing.T) { + query := "where do we hash passwords" + t0 := time.Now() + got, err := svcInst.ExpandQuery(ctx, query) + t.Logf("ExpandQuery %q → %v (%v)", query, got, time.Since(t0)) + if err != nil { + t.Fatalf("ExpandQuery failed: %v", err) + } + if got == nil { + t.Fatal("got nil result") + } + if len(got.Terms) == 0 { + t.Fatal("model returned 0 terms — prompt may need tuning") + } + for _, term := range got.Terms { + if strings.EqualFold(strings.TrimSpace(term), query) { + t.Errorf("term echoes original query: %q", term) + } + if strings.TrimSpace(term) == "" { + t.Error("empty term in output") + } + if expansionStoplist[strings.ToLower(strings.TrimSpace(term))] { + t.Errorf("stoplisted generic noun leaked through filter: %q", term) + } + if len(strings.TrimSpace(term)) < minExpansionTermLen { + t.Errorf("sub-min-length term leaked through filter: %q", term) + } + } + if len(got.Terms) > maxExpansionTerms { + t.Errorf("expansion exceeded cap: %d > %d", len(got.Terms), maxExpansionTerms) + } + + // Second call must hit the cache. + t1 := time.Now() + got2, err := svcInst.ExpandQuery(ctx, query) + if err != nil { + t.Fatalf("second ExpandQuery failed: %v", err) + } + if !got2.Cached { + t.Error("expected Cached=true on second call") + } + t.Logf("ExpandQuery cached lookup: %v", time.Since(t1)) + }) + + t.Run("RerankSymbols", func(t *testing.T) { + query := "validate user authentication token" + cands := []llm.RerankCandidate{ + {ID: "pkg/auth.parseJWT", Name: "parseJWT", Signature: "func parseJWT(s string) (*Claims, error)"}, + {ID: "pkg/auth.ValidateToken", Name: "ValidateToken", Signature: "func ValidateToken(tok string) error"}, + {ID: "pkg/auth.AuthMiddleware", Name: "AuthMiddleware", Signature: "func AuthMiddleware(h http.Handler) http.Handler"}, + {ID: "pkg/user.hashPassword", Name: "hashPassword", Signature: "func hashPassword(p string) []byte"}, + {ID: "pkg/user.NewUser", Name: "NewUser", Signature: "func NewUser(name string) *User"}, + } + t0 := time.Now() + got, err := svcInst.RerankSymbols(ctx, query, cands) + t.Logf("RerankSymbols %q → %v (%v)", query, got.Order, time.Since(t0)) + if err != nil { + t.Fatalf("RerankSymbols failed: %v", err) + } + if got == nil { + t.Fatal("got nil result") + } + // Output must be a permutation of input IDs. + assertPermutation(t, got.Order, cands) + }) + + // The verify scenarios are intentionally soft: they LOG kept ids + // rather than fail the test. Verify quality is model-dependent + // (3B fails the disambiguation cases; we want to compare against + // 7B without breaking CI on the smaller model). The structural + // invariant (Keep ⊆ input ids) is asserted via assertSubset. 
+ t.Run("VerifyRelevance_HashPasswords", func(t *testing.T) { + query := "where do we hash passwords" + cands := verifyHashPasswordsScenario() + t0 := time.Now() + got, err := svcInst.VerifyRelevance(ctx, query, cands) + t.Logf("VerifyRelevance %q → kept=%v dropped=%d (%v)", query, got.Keep, len(cands)-len(got.Keep), time.Since(t0)) + if err != nil { + t.Fatalf("VerifyRelevance failed: %v", err) + } + assertVerifySubset(t, got.Keep, cands) + + // Soft scoring — we LOG these rather than fail. + expectKept := "synthetic.pkg.user.HashPassword" + expectDropped := "real.hashDiagnostics" + kept := containsID(got.Keep, expectKept) + dropped := !containsID(got.Keep, expectDropped) + t.Logf("VERDICT(hash-passwords): kept HashPassword=%v, dropped hashDiagnostics=%v", kept, dropped) + }) + + t.Run("VerifyRelevance_BM25", func(t *testing.T) { + query := "how does the BM25 search rank symbols" + cands := verifyBM25Scenario() + t0 := time.Now() + got, err := svcInst.VerifyRelevance(ctx, query, cands) + t.Logf("VerifyRelevance %q → kept=%v dropped=%d (%v)", query, got.Keep, len(cands)-len(got.Keep), time.Since(t0)) + if err != nil { + t.Fatalf("VerifyRelevance failed: %v", err) + } + assertVerifySubset(t, got.Keep, cands) + + expectKept := "real.NewBM25" + expectDropped := "synthetic.unrelated.parseTSConfig" + kept := containsID(got.Keep, expectKept) + dropped := !containsID(got.Keep, expectDropped) + t.Logf("VERDICT(BM25): kept NewBM25=%v, dropped parseTSConfig=%v", kept, dropped) + }) +} + +// verifyHashPasswordsScenario reproduces the live failure case: the +// real hashDiagnostics (which calls sha256.Sum256 on diagnostic JSON) +// alongside a synthetic genuine password hasher and unrelated +// distractors. A strong-enough model should keep HashPassword and +// drop hashDiagnostics on caller/signature evidence. 
+func verifyHashPasswordsScenario() []llm.VerifyCandidate { + return []llm.VerifyCandidate{ + { + ID: "real.hashDiagnostics", + Name: "hashDiagnostics", + Signature: "func hashDiagnostics(diags []lsp.Diagnostic) string", + Body: `func hashDiagnostics(diags []lsp.Diagnostic) string { + if len(diags) == 0 { + return "empty" + } + b, err := json.Marshal(diags) + if err != nil { + return "" + } + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +}`, + Callers: []llm.CallerInfo{ + {Name: "diagnosticsBroadcaster.publish", Signature: "func (b *diagnosticsBroadcaster) publish(uri string, diags []lsp.Diagnostic)"}, + {Name: "Server.SetLSPDiagnosticsBroadcasting", Signature: "func (s *Server) SetLSPDiagnosticsBroadcasting()"}, + }, + }, + { + ID: "synthetic.pkg.user.HashPassword", + Name: "HashPassword", + Signature: "func HashPassword(plain string) ([]byte, error)", + Body: `func HashPassword(plain string) ([]byte, error) { + return bcrypt.GenerateFromPassword([]byte(plain), bcrypt.DefaultCost) +}`, + Callers: []llm.CallerInfo{ + {Name: "UserHandler.RegisterUser", Signature: "func (h *UserHandler) RegisterUser(w http.ResponseWriter, r *http.Request)"}, + {Name: "AdminHandler.SetPassword", Signature: "func (h *AdminHandler) SetPassword(userID int, newPlain string) error"}, + }, + }, + { + ID: "synthetic.gitCommitHash", + Name: "gitCommitHash", + Signature: "func gitCommitHash(dir string) string", + Body: `func gitCommitHash(dir string) string { + out, err := exec.Command("git", "-C", dir, "rev-parse", "HEAD").Output() + if err != nil { + return "" + } + return strings.TrimSpace(string(out)) +}`, + Callers: []llm.CallerInfo{ + {Name: "indexer.snapshot", Signature: "func (i *Indexer) snapshot() Snapshot"}, + }, + }, + { + ID: "synthetic.do_setup", + Name: "do_setup", + Signature: "do_setup()", + Body: `do_setup() { + echo "provisioning hetzner instance" + hcloud server create --type cpx21 --image ubuntu-22.04 --name "$1" +}`, + Callers: nil, + }, + { + ID: "synthetic.findDoBlock", + Name: "findDoBlock", + Signature: "func (e *ElixirExtractor) findDoBlock(callNode *sitter.Node) *sitter.Node", + Body: `func (e *ElixirExtractor) findDoBlock(callNode *sitter.Node) *sitter.Node { + for i := uint32(0); i < callNode.ChildCount(); i++ { + c := callNode.Child(i) + if c.Type() == "do_block" { + return c + } + } + return nil +}`, + Callers: []llm.CallerInfo{ + {Name: "ElixirExtractor.extractDefs", Signature: "func (e *ElixirExtractor) extractDefs(root *sitter.Node, src []byte)"}, + }, + }, + } +} + +// verifyBM25Scenario tests the model's ability to keep the genuine +// BM25 ranking implementation while dropping unrelated parsers / +// fixtures. +func verifyBM25Scenario() []llm.VerifyCandidate { + return []llm.VerifyCandidate{ + { + ID: "real.NewBM25", + Name: "NewBM25", + Signature: "func NewBM25() *BM25Backend", + Body: `func NewBM25() *BM25Backend { + return &BM25Backend{ + inverted: make(map[string][]posting), + bigrams: make(map[string]map[string]struct{}), + docs: make(map[string]doc), + } +}`, + Callers: []llm.CallerInfo{ + {Name: "indexer.New", Signature: "func New(g *graph.Graph, ...) 
*Indexer"}, + }, + }, + { + ID: "real.BM25Backend.Search", + Name: "Search", + Signature: "func (b *BM25Backend) Search(query string, limit int) []scored", + Body: `func (b *BM25Backend) Search(query string, limit int) []scored { + terms := tokenize(query) + scores := map[string]float64{} + for _, t := range terms { + for _, p := range b.inverted[t] { + scores[p.id] += b.bm25Score(p, t) + } + } + return topK(scores, limit) +}`, + Callers: []llm.CallerInfo{ + {Name: "Engine.SearchSymbolsScoped", Signature: "func (e *Engine) SearchSymbolsScoped(q string, limit int, opts QueryOptions) []*graph.Node"}, + }, + }, + { + ID: "synthetic.unrelated.parseTSConfig", + Name: "parseTSConfig", + Signature: "func parseTSConfig(path string) (*tsConfig, error)", + Body: `func parseTSConfig(path string) (*tsConfig, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var cfg tsConfig + return &cfg, json.Unmarshal(b, &cfg) +}`, + Callers: []llm.CallerInfo{ + {Name: "TypeScriptExtractor.detect", Signature: "func (e *TypeScriptExtractor) detect(path string) bool"}, + }, + }, + { + ID: "synthetic.unrelated.do_eval", + Name: "do_eval", + Signature: "do_eval()", + Body: `do_eval() { + ssh root@"$IP" "cd /opt/gortex && ./run-bench.sh" +}`, + Callers: nil, + }, + } +} + +func assertVerifySubset(t *testing.T, kept []string, cands []llm.VerifyCandidate) { + t.Helper() + valid := make(map[string]bool, len(cands)) + for _, c := range cands { + valid[c.ID] = true + } + for _, id := range kept { + if !valid[id] { + t.Errorf("verify emitted unknown id: %q", id) + } + } +} + +func containsID(ids []string, want string) bool { + for _, id := range ids { + if id == want { + return true + } + } + return false +} + +// resolveModelPath checks env first, then the documented default +// location used by the user's daemon config. 
+func resolveModelPath(t *testing.T) string { + if p := os.Getenv("GORTEX_LLM_MODEL"); p != "" { + if _, err := os.Stat(p); err == nil { + return p + } + } + home, err := os.UserHomeDir() + if err != nil { + return "" + } + candidate := filepath.Join(home, "models", "qwen2.5-coder-3b-instruct-q4_k_m.gguf") + if _, err := os.Stat(candidate); err == nil { + return candidate + } + return "" +} + +func assertPermutation(t *testing.T, got []string, cands []llm.RerankCandidate) { + t.Helper() + if len(got) != len(cands) { + t.Fatalf("length mismatch: got=%d cands=%d", len(got), len(cands)) + } + wantSet := make(map[string]bool, len(cands)) + for _, c := range cands { + wantSet[c.ID] = true + } + gotSet := make(map[string]bool, len(got)) + for _, id := range got { + if !wantSet[id] { + t.Errorf("rerank emitted unknown id: %q", id) + } + if gotSet[id] { + t.Errorf("rerank emitted duplicate id: %q", id) + } + gotSet[id] = true + } + if len(gotSet) != len(wantSet) { + gotKeys := keysOf(gotSet) + wantKeys := keysOf(wantSet) + sort.Strings(gotKeys) + sort.Strings(wantKeys) + t.Fatalf("missing ids:\n got: %v\n want: %v", gotKeys, wantKeys) + } +} + +func keysOf(m map[string]bool) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} diff --git a/internal/llm/svc/assist_test.go b/internal/llm/svc/assist_test.go new file mode 100644 index 0000000..91f85e5 --- /dev/null +++ b/internal/llm/svc/assist_test.go @@ -0,0 +1,276 @@ +//go:build llama + +package svc + +import ( + "reflect" + "sync" + "testing" + + "github.com/zzet/gortex/internal/llm" +) + +func TestAssistCache_GetSetEvict(t *testing.T) { + c := newAssistCache(2) + + c.Set("a", []string{"1"}) + c.Set("b", []string{"2"}) + if got, ok := c.Get("a"); !ok || !reflect.DeepEqual(got, []string{"1"}) { + t.Fatalf("get a: ok=%v got=%v", ok, got) + } + + // Insert beyond cap; oldest ("a") must evict. + c.Set("c", []string{"3"}) + if _, ok := c.Get("a"); ok { + t.Fatalf("expected a to be evicted") + } + if _, ok := c.Get("b"); !ok { + t.Fatalf("expected b to remain") + } + if _, ok := c.Get("c"); !ok { + t.Fatalf("expected c to be present") + } +} + +func TestAssistCache_UpdateInPlace(t *testing.T) { + c := newAssistCache(2) + c.Set("a", []string{"1"}) + c.Set("b", []string{"2"}) + // Update "a" — must NOT count as a fresh insert that would + // evict "b". + c.Set("a", []string{"1", "1b"}) + if got, _ := c.Get("a"); !reflect.DeepEqual(got, []string{"1", "1b"}) { + t.Fatalf("update lost: %v", got) + } + if _, ok := c.Get("b"); !ok { + t.Fatalf("update should not evict; b missing") + } +} + +func TestAssistCache_CopyOnReadAndWrite(t *testing.T) { + c := newAssistCache(2) + val := []string{"x", "y"} + c.Set("k", val) + + // Mutate caller-side input post-Set; cache must be unaffected. + val[0] = "MUTATED" + got, _ := c.Get("k") + if got[0] != "x" { + t.Fatalf("cache mirrored caller mutation: %v", got) + } + + // Mutate returned slice; subsequent Get must still return the + // original. 
+ got[0] = "ALSO_MUTATED" + got2, _ := c.Get("k") + if got2[0] != "x" { + t.Fatalf("cache returns aliased slice: %v", got2) + } +} + +func TestAssistCache_Concurrent(t *testing.T) { + c := newAssistCache(128) + var wg sync.WaitGroup + for i := 0; i < 32; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + key := string(rune('a' + (i % 26))) + c.Set(key, []string{key}) + c.Get(key) + }(i) + } + wg.Wait() +} + +func TestRerankCacheKey_StableAcrossOrder(t *testing.T) { + a := []llm.RerankCandidate{{ID: "x"}, {ID: "y"}, {ID: "z"}} + b := []llm.RerankCandidate{{ID: "z"}, {ID: "x"}, {ID: "y"}} + if rerankCacheKey("q", a) != rerankCacheKey("q", b) { + t.Fatalf("key must be independent of input ordering") + } +} + +func TestRerankCacheKey_DiffersOnQuery(t *testing.T) { + c := []llm.RerankCandidate{{ID: "x"}} + if rerankCacheKey("q1", c) == rerankCacheKey("q2", c) { + t.Fatalf("different queries must produce different keys") + } +} + +func TestParseStringList(t *testing.T) { + cases := []struct { + name string + raw string + key string + want []string + }{ + {"happy", `{"terms":["a","b","c"]}`, "terms", []string{"a", "b", "c"}}, + {"empty array", `{"terms":[]}`, "terms", nil}, + {"missing key", `{"order":["a"]}`, "terms", nil}, + {"malformed", `{terms:[a]}`, "terms", nil}, + {"non-array value", `{"terms":"oops"}`, "terms", nil}, + {"blank input", ``, "terms", nil}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := parseStringList(tc.raw, tc.key) + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got=%v want=%v", got, tc.want) + } + }) + } +} + +func TestDedupeFilter(t *testing.T) { + got := dedupeFilter([]string{"BCrypt", "bcrypt", "", " ", "argon2", "validate"}, "validate") + want := []string{"BCrypt", "argon2"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestDedupeFilter_DropsStoplist(t *testing.T) { + got := dedupeFilter( + []string{"function", "library", "bcrypt", "data", "argon2", "general"}, + "hash passwords") + want := []string{"bcrypt", "argon2"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestDedupeFilter_DropsShortTerms(t *testing.T) { + got := dedupeFilter([]string{"jwt", "is", "do", "id", "bcrypt", "ab"}, "auth") + want := []string{"jwt", "bcrypt"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestDedupeFilter_RespectsMaxCap(t *testing.T) { + in := []string{"a1234", "b1234", "c1234", "d1234", "e1234", "f1234", "g1234"} + got := dedupeFilter(in, "q") + if len(got) != maxExpansionTerms { + t.Fatalf("len=%d want=%d", len(got), maxExpansionTerms) + } +} + +func TestDedupeFilter_StoplistCaseInsensitive(t *testing.T) { + got := dedupeFilter([]string{"FUNCTION", "Library", "BCrypt"}, "q") + want := []string{"BCrypt"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestFilterToInputAppend(t *testing.T) { + cands := []llm.RerankCandidate{ + {ID: "a"}, {ID: "b"}, {ID: "c"}, {ID: "d"}, + } + model := []string{"c", "hallucinated", "a", "c"} // dup + hallucinated + got := filterToInputAppend(model, cands) + want := []string{"c", "a", "b", "d"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestFilterToInputAppend_ModelEmpty(t *testing.T) { + cands := []llm.RerankCandidate{{ID: "a"}, {ID: "b"}} + got := filterToInputAppend(nil, cands) + want := []string{"a", "b"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, 
want) + } +} + +func TestVerifyIDs(t *testing.T) { + if got := verifyIDs(nil); got != nil { + t.Fatalf("nil input: got=%v", got) + } + in := []llm.VerifyCandidate{{ID: "a"}, {ID: "b"}} + got := verifyIDs(in) + want := []string{"a", "b"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestFilterKeepToInput(t *testing.T) { + cands := []llm.VerifyCandidate{ + {ID: "a"}, {ID: "b"}, {ID: "c"}, {ID: "d"}, + } + model := []string{"c", "hallucinated", "a", "c"} // dup + hallucinated + got := filterKeepToInput(model, cands) + want := []string{"c", "a"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestFilterKeepToInput_EmptyKeepHonoured(t *testing.T) { + // Critical contract: an empty model result must produce an empty + // keep list. Dropped IDs MUST NOT be appended back. + cands := []llm.VerifyCandidate{{ID: "a"}, {ID: "b"}} + got := filterKeepToInput(nil, cands) + if len(got) != 0 { + t.Fatalf("expected empty result, got %v", got) + } +} + +func TestFilterKeepToInput_AllHallucinated(t *testing.T) { + cands := []llm.VerifyCandidate{{ID: "a"}, {ID: "b"}} + got := filterKeepToInput([]string{"x", "y"}, cands) + if len(got) != 0 { + t.Fatalf("expected empty result, got %v", got) + } +} + +func TestVerifyCacheKey_DiffersWhenBodyChanges(t *testing.T) { + a := []llm.VerifyCandidate{{ID: "x", Body: "old code"}} + b := []llm.VerifyCandidate{{ID: "x", Body: "new code"}} + if verifyCacheKey("q", a) == verifyCacheKey("q", b) { + t.Fatalf("body change must produce different cache key") + } +} + +func TestVerifyCacheKey_StableAcrossOrder(t *testing.T) { + a := []llm.VerifyCandidate{{ID: "x", Body: "X"}, {ID: "y", Body: "Y"}} + b := []llm.VerifyCandidate{{ID: "y", Body: "Y"}, {ID: "x", Body: "X"}} + if verifyCacheKey("q", a) != verifyCacheKey("q", b) { + t.Fatalf("key must be independent of input ordering") + } +} + +func TestCandIDs(t *testing.T) { + if got := candIDs(nil); got != nil { + t.Fatalf("nil input: got=%v", got) + } + in := []llm.RerankCandidate{{ID: "z"}, {ID: "a"}} + got := candIDs(in) + want := []string{"z", "a"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got=%v want=%v", got, want) + } +} + +func TestAssistJSONComplete(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"{}", true}, + {`{"a":1}`, true}, + {" {} ", true}, + {"{", false}, + {`{"a":`, false}, + {"not json", false}, + {"", false}, + } + for _, tc := range cases { + if got := assistJSONComplete(tc.in); got != tc.want { + t.Fatalf("in=%q got=%v want=%v", tc.in, got, tc.want) + } + } +} diff --git a/internal/llm/svc/cache.go b/internal/llm/svc/cache.go new file mode 100644 index 0000000..55987b8 --- /dev/null +++ b/internal/llm/svc/cache.go @@ -0,0 +1,121 @@ +//go:build llama + +package svc + +import ( + "crypto/sha256" + "encoding/hex" + "sort" + "strings" + "sync" + + "github.com/zzet/gortex/internal/llm" +) + +// assistCache is a tiny FIFO-evicting LRU keyed on a string. Values +// are []string lists of either expanded query terms or reranked node +// IDs. Hand-rolled to avoid a dependency for what's a few-hundred-entry +// map. +// +// Concurrent-safe: every public method takes the lock. Sub-microsecond +// per op, fine for the inline call sites in ExpandQuery / RerankSymbols. 
+type assistCache struct { + mu sync.Mutex + max int + data map[string][]string + keys []string // insertion order; oldest first +} + +func newAssistCache(max int) *assistCache { + if max <= 0 { + max = 256 + } + return &assistCache{ + max: max, + data: make(map[string][]string, max), + keys: make([]string, 0, max), + } +} + +func (c *assistCache) Get(key string) ([]string, bool) { + c.mu.Lock() + defer c.mu.Unlock() + v, ok := c.data[key] + if !ok { + return nil, false + } + // Copy so callers can't mutate the cached slice. + out := make([]string, len(v)) + copy(out, v) + return out, true +} + +func (c *assistCache) Set(key string, val []string) { + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.data[key]; ok { + // Update in place; keep position in keys. + stored := make([]string, len(val)) + copy(stored, val) + c.data[key] = stored + return + } + if len(c.keys) >= c.max { + // Evict oldest. + oldest := c.keys[0] + c.keys = c.keys[1:] + delete(c.data, oldest) + } + stored := make([]string, len(val)) + copy(stored, val) + c.data[key] = stored + c.keys = append(c.keys, key) +} + +// Len reports current cache size. Test-only convenience. +func (c *assistCache) Len() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.data) +} + +// rerankCacheKey hashes the candidate ID set together with the query +// so cache lookups are stable across input orderings (two callers that +// pass the same candidates in a different order still hit the cache). +func rerankCacheKey(query string, cands []llm.RerankCandidate) string { + ids := make([]string, len(cands)) + for i, c := range cands { + ids[i] = c.ID + } + sort.Strings(ids) + h := sha256.New() + h.Write([]byte(query)) + h.Write([]byte{0x1f}) // unit separator + h.Write([]byte(strings.Join(ids, "\x1e"))) + return hex.EncodeToString(h.Sum(nil)[:16]) +} + +// verifyCacheKey hashes (query, sorted-ids, body-fingerprint) so a +// re-indexed codebase doesn't serve stale verifications. The body +// fingerprint is intentionally NOT just the ID set: an agent could +// re-issue the same query after editing one of the candidates' source, +// and we want to re-verify in that case. +func verifyCacheKey(query string, cands []llm.VerifyCandidate) string { + type idBody struct{ id, body string } + entries := make([]idBody, len(cands)) + for i, c := range cands { + entries[i] = idBody{id: c.ID, body: c.Body} + } + sort.Slice(entries, func(i, j int) bool { return entries[i].id < entries[j].id }) + + h := sha256.New() + h.Write([]byte(query)) + h.Write([]byte{0x1f}) + for _, e := range entries { + h.Write([]byte(e.id)) + h.Write([]byte{0x1e}) + h.Write([]byte(e.body)) + h.Write([]byte{0x1d}) // record separator + } + return hex.EncodeToString(h.Sum(nil)[:16]) +} diff --git a/internal/llm/svc/service.go b/internal/llm/svc/service.go index 018ae43..7cc3be0 100644 --- a/internal/llm/svc/service.go +++ b/internal/llm/svc/service.go @@ -32,6 +32,14 @@ import ( // // Both go through the same model and the same inference mutex — // llama.cpp is single-stream on a given device. +// +// In addition to the full-size RunAgent / Generate contexts, Service +// keeps a pre-warmed *assist context* — a smaller llama context used +// for short single-shot grammar-constrained calls (ExpandQuery, +// RerankSymbols). The assist context has its own mutex so a long +// `ask` doesn't head-of-line block hot-path NL search calls; at the +// llama.cpp level the two contexts share the model weights but each +// holds its own KV cache. 
type Service struct { cfg llm.Config backend llm.Backend @@ -41,6 +49,14 @@ type Service struct { loadErr error infer sync.Mutex + + assistOnce sync.Once + assistCtx *llm.Context + assistErr error + assistMu sync.Mutex + expandCache *assistCache + rerankCache *assistCache + verifyCache *assistCache } // NewService is cheap — it just stores the config and backend. The @@ -48,8 +64,11 @@ type Service struct { // Generate / RunAgent call, so daemon startup isn't slowed. func NewService(cfg llm.Config, backend llm.Backend) *Service { return &Service{ - cfg: cfg.ApplyDefaults(), - backend: backend, + cfg: cfg.ApplyDefaults(), + backend: backend, + expandCache: newAssistCache(256), + rerankCache: newAssistCache(256), + verifyCache: newAssistCache(256), } } @@ -76,9 +95,19 @@ func (s *Service) ensureLoaded() error { return s.loadErr } -// Close releases the underlying model. Safe to call multiple times. -// After Close, every operational method returns an error. +// Close releases the underlying model and any assist context. Safe +// to call multiple times. After Close, every operational method +// returns an error. func (s *Service) Close() error { + // Order matters: drop the assist context first so its KV cache + // is freed before the model itself goes away. + s.assistMu.Lock() + if s.assistCtx != nil { + s.assistCtx.Close() + s.assistCtx = nil + } + s.assistMu.Unlock() + s.infer.Lock() defer s.infer.Unlock() if s.model != nil { diff --git a/internal/llm/svc/service_stub.go b/internal/llm/svc/service_stub.go index b3143bd..52b936c 100644 --- a/internal/llm/svc/service_stub.go +++ b/internal/llm/svc/service_stub.go @@ -39,5 +39,23 @@ func (s *Service) RunAgent(_ context.Context, _ llm.RunAgentOptions) (*llm.Agent return nil, errServiceUnavailable } +// ExpandQuery is a no-op in the stub; returns errServiceUnavailable. +// Callers should check Enabled() first and skip the call entirely. +func (s *Service) ExpandQuery(_ context.Context, _ string) (*llm.ExpandResult, error) { + return nil, errServiceUnavailable +} + +// RerankSymbols is a no-op in the stub; returns errServiceUnavailable. +// Callers should check Enabled() first and skip the call entirely. +func (s *Service) RerankSymbols(_ context.Context, _ string, _ []llm.RerankCandidate) (*llm.RerankResult, error) { + return nil, errServiceUnavailable +} + +// VerifyRelevance is a no-op in the stub; returns errServiceUnavailable. +// Callers should check Enabled() first and skip the call entirely. +func (s *Service) VerifyRelevance(_ context.Context, _ string, _ []llm.VerifyCandidate) (*llm.VerifyResult, error) { + return nil, errServiceUnavailable +} + // Close is a no-op in the stub. func (s *Service) Close() error { return nil } diff --git a/internal/mcp/tools_core.go b/internal/mcp/tools_core.go index ae05644..3f6908c 100644 --- a/internal/mcp/tools_core.go +++ b/internal/mcp/tools_core.go @@ -576,6 +576,7 @@ func (s *Server) registerCoreTools() { mcp.WithString("project", mcp.Description("Filter results to repositories in a specific project")), mcp.WithString("ref", mcp.Description("Filter results to repositories with a specific reference tag")), mcp.WithString("kind", mcp.Description("Filter to one or more node kinds (comma-separated). Standard kinds: function, method, type, interface, variable, constant, field, file, package, import, contract. 
Coverage kinds: param, closure, enum_member, generic_param, module, table, column, config_key, flag, event, migration, fixture, todo, team, license, release.")), + mcp.WithString("assist", mcp.Description("LLM assist mode: \"auto\" (default — engages on natural-language queries, skips identifier lookups), \"on\" (force engage), \"off\" (bypass), \"deep\" (on + a body-grounded verification pass that reads candidate code and HONESTLY drops irrelevant matches — slower, may return empty results when nothing genuinely matches). Requires the daemon to be built with -tags llama and a configured model; otherwise behaves as \"off\".")), ), s.handleSearchSymbols, ) @@ -825,7 +826,34 @@ func (s *Server) handleSearchSymbols(ctx context.Context, req mcp.CallToolReques projectArg := req.GetString("project", "") scopeWS, scopeProj := s.resolveQueryScope(ctx, workspaceArg, projectArg) scope := query.QueryOptions{WorkspaceID: scopeWS, ProjectID: scopeProj} - nodes := s.engine.SearchSymbolsScoped(q, offset+limit+10, scope) + + // LLM assist gate: decides whether the expansion + rerank passes + // run for this query. The service-enabled check is layered inside + // the helpers so a stub build is a clean bypass. + assist := parseAssistMode(req) + engage := shouldEngageAssist(assist, q) && s.llmService != nil && s.llmService.Enabled() + + fetchLimit := offset + limit + 10 + if engage { + // Slightly widen the BM25 over-fetch when we're going to + // rerank: more head candidates means a more useful reorder. + fetchLimit = offset + limit + rerankCap + } + + var expandedTerms []string + if engage { + expandedTerms = expandSearchTerms(ctx, s, q) + } + + var nodes []*graph.Node + var primaryCount int + if len(expandedTerms) > 0 { + nodes, primaryCount = fetchAndMergeBM25(s, q, expandedTerms, fetchLimit, scope) + } else { + nodes = s.engine.SearchSymbolsScoped(q, fetchLimit, scope) + primaryCount = len(nodes) + } + mergedCount := len(nodes) // pre-filter; comparable to primaryCount // Apply repo/project/ref filter. allowed, filterErr := s.resolveRepoFilter(ctx, req) @@ -842,6 +870,24 @@ func (s *Server) handleSearchSymbols(ctx context.Context, req mcp.CallToolReques nodes = filterNodesByKind(nodes, kindArg) } + // LLM rerank runs AFTER kind/repo filters so the model only sees + // the candidate pool the caller will actually receive, and BEFORE + // the combo/frecency boost so per-session signals can still + // override a stale rerank. + var verifyDbg verifyDebug + var verifyRan bool + if engage { + nodes = rerankWithLLM(ctx, s, q, nodes) + // `deep` mode adds a body-grounded verification pass that + // reads candidate code and HONESTLY drops the ones whose + // body isn't actually about the query. An empty kept set is + // preserved — it's the load-bearing "nothing genuinely matches" + // signal that distinguishes deep mode from plain rerank. + if assist == assistDeep { + nodes, verifyDbg, verifyRan = verifyWithLLM(ctx, s, q, nodes) + } + } + // Rerank: fold locality + combo + frecency signals over the backend's // BM25 order. Locality ranks the session's home repo / project above // the rest of its workspace; combo + frecency are per-repo and @@ -902,6 +948,28 @@ func (s *Server) handleSearchSymbols(ctx context.Context, req mcp.CallToolReques if nextCursor != "" { resp["next_cursor"] = nextCursor } + // When LLM assist engaged, expose a small debug surface so callers + // (and the agent itself) can see what the model contributed. 
+ // Suppressed when engage was false to keep the common-path response + // shape unchanged. + if engage { + assistDebug := map[string]any{ + "engaged": true, + "primary_count": primaryCount, // BM25 hits on original query alone, pre-filter + "merged_count": mergedCount, // BM25 hits after merging expansion terms, pre-filter + "final_count": total, // post-filter, post-rerank — matches the top-level total + } + if len(expandedTerms) > 0 { + assistDebug["terms"] = expandedTerms + } + if verifyRan { + assistDebug["verify_considered"] = verifyDbg.Considered + assistDebug["verify_kept_ids"] = verifyDbg.Kept + assistDebug["verify_kept"] = len(verifyDbg.Kept) + assistDebug["verify_dropped"] = len(verifyDbg.Considered) - len(verifyDbg.Kept) + } + resp["assist"] = assistDebug + } return s.respondJSONOrTOON(ctx, req, resp) } diff --git a/internal/mcp/tools_search_assist.go b/internal/mcp/tools_search_assist.go new file mode 100644 index 0000000..b8fd83a --- /dev/null +++ b/internal/mcp/tools_search_assist.go @@ -0,0 +1,457 @@ +package mcp + +import ( + "context" + "strings" + + mcpgo "github.com/mark3labs/mcp-go/mcp" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/llm" + "github.com/zzet/gortex/internal/query" +) + +// assistMode controls whether the LLM-driven query expansion and +// rerank passes run during handleSearchSymbols. Default is `auto` — +// the NL heuristic decides per-query. `on` and `off` are explicit +// overrides for callers that know their query character. `deep` is +// `on` plus a body-grounded verification pass that reads candidate +// code bodies and HONESTLY drops the ones whose code is not +// actually about the query (paid for in extra latency). +type assistMode int + +const ( + assistAuto assistMode = iota + assistOn + assistOff + assistDeep +) + +// parseAssistMode reads the `assist` arg. Unrecognised values fall +// back to `auto` rather than erroring so callers can't accidentally +// break search by typoing the flag. +func parseAssistMode(req mcpgo.CallToolRequest) assistMode { + switch strings.ToLower(strings.TrimSpace(req.GetString("assist", ""))) { + case "on", "yes", "true", "force": + return assistOn + case "off", "no", "false", "skip": + return assistOff + case "deep", "verify", "body": + return assistDeep + default: + return assistAuto + } +} + +// looksNaturalLanguage is the cheap pre-LLM gate. Returns true when +// the query is shaped like a natural-language description rather than +// an identifier lookup. Heuristics: +// - Fewer than 3 whitespace tokens → identifier; skip. +// - Any token containing dot / slash / scope-resolution → qualified +// identifier; skip. +// - Any token that's PascalCase or camelCase → identifier; skip. +// - At least one common English stop word among 3+ tokens → engage. +// - 4+ plain-word tokens with no identifier shape → engage. +// +// Empty / blank input never engages. +func looksNaturalLanguage(q string) bool { + q = strings.TrimSpace(q) + if q == "" { + return false + } + tokens := strings.Fields(q) + if len(tokens) < 3 { + return false + } + for _, t := range tokens { + if strings.ContainsAny(t, "./:_") { + return false + } + if hasMixedCase(t) { + return false + } + } + if hasStopWord(tokens) { + return true + } + return len(tokens) >= 4 +} + +// hasMixedCase reports whether a token contains both upper and lower +// ASCII letters — i.e. PascalCase or camelCase. Pure lowercase / +// pure uppercase plain words don't qualify. 
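+//
+// Examples (consistent with the unit tests):
+//
+//	hasMixedCase("handleAsk") // true:  camelCase
+//	hasMixedCase("Server")    // true:  PascalCase
+//	hasMixedCase("validate")  // false: plain lowercase
+//	hasMixedCase("JWT")       // false: plain uppercase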
+func hasMixedCase(t string) bool { + var hasUpper, hasLower bool + for _, r := range t { + switch { + case r >= 'A' && r <= 'Z': + hasUpper = true + case r >= 'a' && r <= 'z': + hasLower = true + } + if hasUpper && hasLower { + return true + } + } + return false +} + +// assistStopWords is a tight list of English function words that +// rarely appear in identifier-style queries. Matching any of them in +// a 3+ token query strongly signals natural language. Kept short +// deliberately — false positives here cost LLM latency on every call. +var assistStopWords = map[string]struct{}{ + "where": {}, "how": {}, "what": {}, "why": {}, "which": {}, "when": {}, + "the": {}, "is": {}, "a": {}, "an": {}, + "in": {}, "of": {}, "to": {}, "for": {}, "with": {}, "by": {}, "from": {}, + "do": {}, "does": {}, "are": {}, "we": {}, "us": {}, "our": {}, + "and": {}, "or": {}, "not": {}, +} + +func hasStopWord(tokens []string) bool { + for _, t := range tokens { + if _, ok := assistStopWords[strings.ToLower(t)]; ok { + return true + } + } + return false +} + +// shouldEngageAssist combines the caller's explicit mode with the +// auto-gate heuristic. `on` and `deep` always engage, `off` never +// engages, and `auto` defers to looksNaturalLanguage. The service- +// side enabled check is layered on top — callers wrap this with +// `s.llmService != nil && s.llmService.Enabled()` so that a stub +// build short-circuits regardless of mode. +func shouldEngageAssist(mode assistMode, query string) bool { + switch mode { + case assistOff: + return false + case assistOn, assistDeep: + return true + default: + return looksNaturalLanguage(query) + } +} + +// expandSearchTerms calls the LLM expansion path and returns the +// extra terms. Returns nil (no expansion) on any failure so the +// search path stays at parity with today's behaviour when the model +// hiccups or isn't loaded yet. +func expandSearchTerms(ctx context.Context, s *Server, query string) []string { + if s.llmService == nil || !s.llmService.Enabled() { + return nil + } + res, err := s.llmService.ExpandQuery(ctx, query) + if err != nil || res == nil { + return nil + } + return res.Terms +} + +// fetchAndMergeBM25 runs BM25 once per term (original + expansions), +// then folds the results into a single deduplicated slice. The +// original query's hits win position; expansion hits append in their +// own BM25 order with duplicates skipped. +// +// fetchLimit is the per-term over-fetch budget. Bounded by the caller +// so a wide expansion can't blow up the candidate pool. +// +// primaryCount is the size of the original-query BM25 result before +// merging; useful for diagnostic / debug surfaces that want to show +// how many candidates expansion contributed. 
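+//
+// Worked example (IDs invented): with primary hits [a b c] and one
+// expansion term returning [b d], the merge yields
+//
+//	merged       == [a b c d] // primary order kept, duplicate b dropped
+//	primaryCount == 3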
+func fetchAndMergeBM25(s *Server, original string, expanded []string, fetchLimit int, scope query.QueryOptions) (merged []*graph.Node, primaryCount int) { + primary := s.engine.SearchSymbolsScoped(original, fetchLimit, scope) + primaryCount = len(primary) + if len(expanded) == 0 { + return primary, primaryCount + } + seen := make(map[string]bool, len(primary)) + merged = make([]*graph.Node, 0, len(primary)) + for _, n := range primary { + if seen[n.ID] { + continue + } + seen[n.ID] = true + merged = append(merged, n) + } + for _, term := range expanded { + term = strings.TrimSpace(term) + if term == "" { + continue + } + extra := s.engine.SearchSymbolsScoped(term, fetchLimit, scope) + for _, n := range extra { + if seen[n.ID] { + continue + } + seen[n.ID] = true + merged = append(merged, n) + } + } + return merged, primaryCount +} + +// rerankCap bounds how many candidates the rerank pass sees. The +// model has limited working memory; past ~25 items its judgement +// degrades and the prompt blows the assist context. Trailing +// candidates beyond rerankCap stay in BM25 order and are appended +// after the reranked head. +const rerankCap = 20 + +// prioritizeCallables stably re-orders nodes so functions and methods +// come first, preserving BM25 order within each bucket. Everything +// non-callable (fields, params, variables, constants, types, files, +// imports, …) sinks to the tail in its original order. The intent +// is to make sure the rerank head — which is what the LLM sees and +// reorders — is populated with the symbols that actually *do* things, +// not their structural siblings that just happen to share tokens. +func prioritizeCallables(nodes []*graph.Node) []*graph.Node { + callable := make([]*graph.Node, 0, len(nodes)) + others := make([]*graph.Node, 0, len(nodes)) + for _, n := range nodes { + if n.Kind == graph.KindFunction || n.Kind == graph.KindMethod { + callable = append(callable, n) + } else { + others = append(others, n) + } + } + return append(callable, others...) +} + +// verifyCap bounds how many candidates the body-grounded verifier +// sees. Each candidate ships with its function body (truncated), so +// the input is much heavier than the name+sig rerank — keep it +// smaller to stay inside the assist context. +const verifyCap = 10 + +// verifyBodyMaxLines and verifyBodyMaxChars cap the per-candidate +// body fed to the model. We want enough to see what the code DOES +// (a function header + a few lines of logic) without including +// every helper call. Empirically 8 non-blank lines is plenty for +// the verify decision. +const ( + verifyBodyMaxLines = 8 + verifyBodyMaxChars = 600 +) + +// verifyCallersPerCand caps the number of callers sent per candidate. +// More callers = more disambiguation signal, but also more tokens. +// Three is empirically enough to anchor the data-domain of most +// functions without blowing the assist context for a 10-candidate batch. +const verifyCallersPerCand = 3 + +// topCallersForVerify returns up to verifyCallersPerCand callers of n, +// each with name + truncated signature. The query depth is 1 (direct +// callers only) and the brief detail level keeps memory pressure low. +// Returns nil for non-callable kinds or when GetCallers yields nothing. 
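+//
+// The caller names are the disambiguation signal: a hypothetical
+// `hash(data []byte)` called from `StorePassword` / `CheckPassword`
+// reads as password hashing, while the same signature called only
+// from `writeChunk` does not.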
+func topCallersForVerify(s *Server, n *graph.Node) []llm.CallerInfo { + if n.Kind != graph.KindFunction && n.Kind != graph.KindMethod { + return nil + } + sg := s.engine.GetCallers(n.ID, query.QueryOptions{ + Depth: 1, + Limit: verifyCallersPerCand + 4, // over-fetch a little: self + non-callers get filtered + Detail: "brief", + }) + if sg == nil || len(sg.Nodes) == 0 { + return nil + } + out := make([]llm.CallerInfo, 0, verifyCallersPerCand) + for _, cn := range sg.Nodes { + if cn == nil || cn.ID == n.ID { + continue + } + if cn.Kind != graph.KindFunction && cn.Kind != graph.KindMethod { + continue + } + sig, _ := cn.Meta["signature"].(string) + out = append(out, llm.CallerInfo{ + Name: cn.Name, + Signature: sig, + }) + if len(out) >= verifyCallersPerCand { + break + } + } + return out +} + +// extractBodyForVerify reads a node's source body, returns the first +// verifyBodyMaxLines non-blank lines truncated to verifyBodyMaxChars. +// Returns "" when no source can be read or when the node isn't a +// function/method — non-function symbols pass through to the verifier +// with signature-only context, which the prompt handles explicitly. +func extractBodyForVerify(s *Server, n *graph.Node) string { + if n.Kind != graph.KindFunction && n.Kind != graph.KindMethod { + return "" + } + if n.StartLine <= 0 || n.EndLine <= 0 { + return "" + } + abs, err := s.resolveNodePath(n) + if err != nil { + return "" + } + source, _, _, err := readLines(abs, n.StartLine, n.EndLine, 0) + if err != nil { + return "" + } + return truncateBody(source, verifyBodyMaxLines, verifyBodyMaxChars) +} + +// truncateBody keeps the first maxLines non-blank lines, then +// caps the result at maxChars. Blank lines between code count +// against neither budget — they're skipped. Returns the truncated +// text with a trailing "…" marker when either cap fires. +func truncateBody(src string, maxLines, maxChars int) string { + if src == "" { + return "" + } + lines := strings.Split(src, "\n") + var b strings.Builder + kept := 0 + for _, ln := range lines { + if strings.TrimSpace(ln) == "" { + b.WriteString("\n") + continue + } + b.WriteString(ln) + b.WriteString("\n") + kept++ + if kept >= maxLines { + b.WriteString("…\n") + break + } + } + out := b.String() + if len(out) > maxChars { + out = out[:maxChars] + "…\n" + } + return out +} + +// verifyDebug captures what the verify pass saw and decided, so the +// debug surface can return it for diagnostic inspection. Lightweight +// — only ID lists, no bodies. +type verifyDebug struct { + Considered []string // IDs sent to the verifier (top-verifyCap of head) + Kept []string // IDs the model chose to keep, in keep order +} + +// verifyWithLLM runs the body-grounded verification pass on the head +// of `nodes`. Returns the model's kept-and-ordered subset followed +// by anything past verifyCap (unverified tail). On failure or empty +// service the input is returned unchanged. +// +// An empty `keep` is HONORED: when the model says "nothing here +// matches", we return only the unverified tail. The caller is meant +// to treat that as a legitimate negative result rather than fall back +// to the noisy pre-verify candidates. 
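+//
+// Outcome summary (restating the contract above):
+//
+//	service nil/disabled or no nodes → input unchanged, ok=false
+//	VerifyRelevance error            → input unchanged, ok=false
+//	keep == ["b","a"]                → [b a] + unverified tail, ok=true
+//	keep == [] (honest negative)     → unverified tail only, ok=true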
+func verifyWithLLM(ctx context.Context, s *Server, query string, nodes []*graph.Node) (result []*graph.Node, dbg verifyDebug, ok bool) { + if s.llmService == nil || !s.llmService.Enabled() || len(nodes) == 0 { + return nodes, dbg, false + } + head := nodes + tail := []*graph.Node(nil) + if len(nodes) > verifyCap { + head = nodes[:verifyCap] + tail = nodes[verifyCap:] + } + + cands := make([]llm.VerifyCandidate, len(head)) + idx := make(map[string]*graph.Node, len(head)) + dbg.Considered = make([]string, len(head)) + for i, n := range head { + sig, _ := n.Meta["signature"].(string) + cands[i] = llm.VerifyCandidate{ + ID: n.ID, + Name: n.Name, + Signature: sig, + Body: extractBodyForVerify(s, n), + Callers: topCallersForVerify(s, n), + } + idx[n.ID] = n + dbg.Considered[i] = n.ID + } + + res, err := s.llmService.VerifyRelevance(ctx, query, cands) + if err != nil || res == nil { + return nodes, dbg, false + } + + keptNodes := make([]*graph.Node, 0, len(res.Keep)) + usedIDs := make(map[string]bool, len(res.Keep)) + for _, id := range res.Keep { + if n, ok := idx[id]; ok && !usedIDs[id] { + usedIDs[id] = true + keptNodes = append(keptNodes, n) + dbg.Kept = append(dbg.Kept, id) + } + } + out := append(keptNodes, tail...) + return out, dbg, true +} + +// rerankWithLLM packs the head of `nodes` into RerankCandidates, +// calls the service, and rebuilds the slice in the model's order. +// Trailing candidates beyond rerankCap are kept verbatim after the +// reranked head. On any failure, returns the input unchanged. +// +// Before partitioning into head/tail, nodes are re-sorted so callable +// kinds (function / method) come before everything else — preserving +// BM25 order within each bucket. Without this, a high-scoring param +// or field node (e.g. `BM25Backend.Search#param:limit`) can pre-empt +// the enclosing method (`BM25Backend.Search`) inside the rerank +// window, leaving the model unable to surface the real callable. +func rerankWithLLM(ctx context.Context, s *Server, query string, nodes []*graph.Node) []*graph.Node { + if s.llmService == nil || !s.llmService.Enabled() || len(nodes) < 2 { + return nodes + } + nodes = prioritizeCallables(nodes) + head := nodes + tail := []*graph.Node(nil) + if len(nodes) > rerankCap { + head = nodes[:rerankCap] + tail = nodes[rerankCap:] + } + + cands := make([]llm.RerankCandidate, len(head)) + idx := make(map[string]*graph.Node, len(head)) + for i, n := range head { + sig, _ := n.Meta["signature"].(string) + cands[i] = llm.RerankCandidate{ + ID: n.ID, + Name: n.Name, + Signature: sig, + Path: n.FilePath, + } + idx[n.ID] = n + } + + res, err := s.llmService.RerankSymbols(ctx, query, cands) + if err != nil || res == nil || len(res.Order) == 0 { + return nodes + } + + reordered := make([]*graph.Node, 0, len(nodes)) + used := make(map[string]bool, len(head)) + for _, id := range res.Order { + n, ok := idx[id] + if !ok || used[id] { + continue + } + used[id] = true + reordered = append(reordered, n) + } + // Defensive: the service guarantees a permutation, but if any + // head node is missing for any reason, append it after the + // reranked head in its original position. + for _, n := range head { + if !used[n.ID] { + reordered = append(reordered, n) + } + } + reordered = append(reordered, tail...) 
+ return reordered +} diff --git a/internal/mcp/tools_search_assist_test.go b/internal/mcp/tools_search_assist_test.go new file mode 100644 index 0000000..cad9830 --- /dev/null +++ b/internal/mcp/tools_search_assist_test.go @@ -0,0 +1,279 @@ +package mcp + +import ( + "encoding/json" + "strings" + "testing" + + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zzet/gortex/internal/graph" + "github.com/zzet/gortex/internal/query" +) + +func TestParseAssistMode(t *testing.T) { + cases := []struct { + in string + want assistMode + }{ + {"", assistAuto}, + {"auto", assistAuto}, + {"AUTO", assistAuto}, + {" auto ", assistAuto}, + {"on", assistOn}, + {"ON", assistOn}, + {"yes", assistOn}, + {"true", assistOn}, + {"force", assistOn}, + {"off", assistOff}, + {"OFF", assistOff}, + {"no", assistOff}, + {"false", assistOff}, + {"skip", assistOff}, + {"deep", assistDeep}, + {"DEEP", assistDeep}, + {"verify", assistDeep}, + {"body", assistDeep}, + {"garbage", assistAuto}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]any{"assist": tc.in} + got := parseAssistMode(req) + assert.Equal(t, tc.want, got, "input=%q", tc.in) + }) + } +} + +func TestLooksNaturalLanguage(t *testing.T) { + cases := []struct { + name string + q string + want bool + }{ + {"empty", "", false}, + {"blanks", " ", false}, + {"single token", "handler", false}, + {"two tokens", "handle user", false}, + + {"qualified identifier", "pkg/foo bar baz", false}, + {"camelCase token", "handleSomething for fun", false}, + {"PascalCase token", "MyHandler tests pass", false}, + {"dotted identifier", "foo.Bar baz qux", false}, + {"snake_case identifier", "do_thing in cluster", false}, + {"scoped identifier", "ns::Type does stuff", false}, + + {"NL with stop word", "where do we hash passwords", true}, + {"NL plain 4 tokens", "validate token auth flow", true}, + {"NL plain 3 tokens no stop word", "validate token auth", false}, + {"NL with the", "the user login flow", true}, + + {"mixed identifier short-circuits stop word", "the handleAsk token", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := looksNaturalLanguage(tc.q) + assert.Equal(t, tc.want, got, "q=%q", tc.q) + }) + } +} + +func TestShouldEngageAssist(t *testing.T) { + // `on` always engages, regardless of shape. + assert.True(t, shouldEngageAssist(assistOn, "Foo")) + assert.True(t, shouldEngageAssist(assistOn, "")) + + // `off` never engages. + assert.False(t, shouldEngageAssist(assistOff, "where do we hash")) + assert.False(t, shouldEngageAssist(assistOff, "")) + + // `auto` defers to the heuristic. + assert.False(t, shouldEngageAssist(assistAuto, "handleAsk")) + assert.True(t, shouldEngageAssist(assistAuto, "where do we hash")) + + // `deep` always engages — its whole purpose is opt-in verification + // for cases the caller knows are NL queries. 
+ assert.True(t, shouldEngageAssist(assistDeep, "Foo")) + assert.True(t, shouldEngageAssist(assistDeep, "where do we hash")) +} + +func TestTruncateBody(t *testing.T) { + cases := []struct { + name string + src string + maxLines int + maxChars int + want string + }{ + {"empty", "", 8, 600, ""}, + { + "under both caps", + "a()\nb()\nc()", + 8, 600, + "a()\nb()\nc()\n", + }, + { + "blank lines skipped from line count", + "a()\n\nb()\n\nc()\n", + 3, 600, + "a()\n\nb()\n\nc()\n…\n", + }, + { + "line cap fires", + "l1\nl2\nl3\nl4\nl5", + 3, 600, + "l1\nl2\nl3\n…\n", + }, + { + "char cap fires after line cap", + strings.Repeat("X", 700), + 8, 100, + strings.Repeat("X", 100) + "…\n", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := truncateBody(tc.src, tc.maxLines, tc.maxChars) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestHasMixedCase(t *testing.T) { + assert.False(t, hasMixedCase("lower")) + assert.False(t, hasMixedCase("UPPER")) + assert.False(t, hasMixedCase("")) + assert.False(t, hasMixedCase("123")) + assert.True(t, hasMixedCase("camelCase")) + assert.True(t, hasMixedCase("PascalCase")) +} + +func TestHasStopWord(t *testing.T) { + assert.True(t, hasStopWord([]string{"hello", "where", "world"})) + assert.True(t, hasStopWord([]string{"WHERE", "is", "x"})) + assert.False(t, hasStopWord([]string{"validate", "token", "auth"})) + assert.False(t, hasStopWord(nil)) +} + +// TestFetchAndMergeBM25_DedupesAcrossTerms verifies that when the +// same node matches multiple terms, it appears only once and keeps +// its primary-term position. +func TestFetchAndMergeBM25_DedupesAcrossTerms(t *testing.T) { + srv, _ := setupTestServer(t) + scope := query.QueryOptions{} + + // Primary term that hits "helper". + primary := srv.engine.SearchSymbolsScoped("helper", 20, scope) + require.NotEmpty(t, primary) + + // Merging with the same term as an "expansion" must produce the + // same list, not duplicates. + merged, primaryCount := fetchAndMergeBM25(srv, "helper", []string{"helper"}, 20, scope) + assert.Equal(t, len(primary), primaryCount) + assert.Equal(t, idsOf(primary), idsOf(merged)) +} + +// TestFetchAndMergeBM25_AppendsNewMatches verifies that expansion +// terms bring in additional candidates the primary term missed. +func TestFetchAndMergeBM25_AppendsNewMatches(t *testing.T) { + srv, _ := setupTestServer(t) + scope := query.QueryOptions{} + + primary := srv.engine.SearchSymbolsScoped("helper", 20, scope) + merged, primaryCount := fetchAndMergeBM25(srv, "helper", []string{"main"}, 20, scope) + assert.Equal(t, len(primary), primaryCount) + + primaryIDs := idsOf(primary) + mergedIDs := idsOf(merged) + + // Every primary ID appears in the merged set, in primary order + // at the head. + require.GreaterOrEqual(t, len(mergedIDs), len(primaryIDs)) + for i, id := range primaryIDs { + assert.Equal(t, id, mergedIDs[i], "primary order broken at index %d", i) + } + // The merge brought in at least one "main"-matched node. + assert.Greater(t, len(mergedIDs), len(primaryIDs)) +} + +// TestSearchSymbols_AssistArgPassThrough verifies the new assist arg +// parses and doesn't break the no-LLM path. Without a service the +// gate always reads as "no engage" regardless of mode, so results +// match the no-assist baseline exactly. 
+func TestSearchSymbols_AssistArgPassThrough(t *testing.T) { + srv, _ := setupTestServer(t) + + for _, mode := range []string{"", "auto", "on", "off"} { + t.Run("assist="+mode, func(t *testing.T) { + args := map[string]any{"query": "helper"} + if mode != "" { + args["assist"] = mode + } + result := callTool(t, srv, "search_symbols", args) + require.False(t, result.IsError, "search failed for mode=%q", mode) + text := result.Content[0].(mcplib.TextContent).Text + var resp map[string]any + require.NoError(t, json.Unmarshal([]byte(text), &resp)) + results := resp["results"].([]any) + require.NotEmpty(t, results, "no results for mode=%q", mode) + }) + } +} + +func TestPrioritizeCallables(t *testing.T) { + // Mixed input: BM25-ranked, with callable kinds interleaved among + // param/field/type nodes. Expected output: callables in their + // original order, then everything else in its original order. + nodes := []*graph.Node{ + {ID: "p1", Kind: graph.KindParam}, + {ID: "f1", Kind: graph.KindFunction}, + {ID: "fld1", Kind: graph.KindField}, + {ID: "m1", Kind: graph.KindMethod}, + {ID: "t1", Kind: graph.KindType}, + {ID: "f2", Kind: graph.KindFunction}, + } + got := prioritizeCallables(nodes) + want := []string{"f1", "m1", "f2", "p1", "fld1", "t1"} + gotIDs := idsOf(got) + if len(gotIDs) != len(want) { + t.Fatalf("length mismatch: got=%v want=%v", gotIDs, want) + } + for i := range want { + if gotIDs[i] != want[i] { + t.Errorf("position %d: got=%q want=%q", i, gotIDs[i], want[i]) + } + } +} + +func TestPrioritizeCallables_AllCallable(t *testing.T) { + nodes := []*graph.Node{ + {ID: "a", Kind: graph.KindFunction}, + {ID: "b", Kind: graph.KindMethod}, + } + got := prioritizeCallables(nodes) + if got[0].ID != "a" || got[1].ID != "b" { + t.Fatalf("order changed when no reordering needed: %v", idsOf(got)) + } +} + +func TestPrioritizeCallables_NoCallable(t *testing.T) { + nodes := []*graph.Node{ + {ID: "a", Kind: graph.KindParam}, + {ID: "b", Kind: graph.KindField}, + } + got := prioritizeCallables(nodes) + if got[0].ID != "a" || got[1].ID != "b" { + t.Fatalf("order changed when no callables present: %v", idsOf(got)) + } +} + +func idsOf(nodes []*graph.Node) []string { + out := make([]string, len(nodes)) + for i, n := range nodes { + out[i] = n.ID + } + return out +} From 2fd7dc8382f0474abb5c7e57beb5421c6bdffbea Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Thu, 14 May 2026 18:47:46 +0200 Subject: [PATCH 4/6] llm, mcp, config: pluggable multi-provider LLM backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce an llm.Provider interface so the `ask` agent and the search_symbols assist passes run on any of four backends, selected by the `llm.provider` config key: - local — in-process llama.cpp, the only `-tags llama` package; ships a non-llama stub so the provider factory always compiles - anthropic / openai / ollama — pure-Go HTTP clients, available in every build, via forced tool-use / json_schema / `format` schema respectively Provider is a single-method surface (Complete) over a provider-neutral []Message conversation; the agent tool-loop and the three assist passes are all built on it. Tool calls travel as plain text (emulated protocol) so one Message shape works across every provider. llm.Config gains a `provider` key plus per-provider sub-blocks, replacing the flat llama-only shape; MergedWith / MergeLLMInto merge per sub-block. 
Assist prompts are tiered by capability (small / frontier) and keyed off the active provider's name rather than carried per-provider. Knock-on: agent and svc are now pure Go — the build-tag split is contained entirely in provider/local. The `ask` tool and assist modes work with any provider, not just llama builds; service_stub.go and tools_llm_stub.go are deleted and registerLLMTools is unconditional. SetupLLM logs provider-construction errors instead of silently disabling. --- CLAUDE.md | 23 +- cmd/gortex/daemon_state.go | 15 +- cmd/gortex/mcp.go | 7 +- internal/config/config.go | 9 +- internal/config/global.go | 36 +-- internal/config/global_llm_test.go | 95 ++++-- internal/llm/agent/agent.go | 216 ++++---------- internal/llm/agent/tools.go | 2 - internal/llm/cmd/agentdemo/main.go | 46 ++- internal/llm/cmd/bench/main.go | 36 +-- internal/llm/config.go | 280 ++++++++++++++--- internal/llm/config_test.go | 119 ++++++++ internal/llm/prompts.go | 178 +++++++++++ internal/llm/prompts_test.go | 86 ++++++ internal/llm/provider.go | 110 +++++++ internal/llm/provider/anthropic/anthropic.go | 247 +++++++++++++++ .../llm/provider/anthropic/anthropic_test.go | 169 +++++++++++ internal/llm/provider/local/local.go | 213 +++++++++++++ internal/llm/provider/local/stub.go | 24 ++ internal/llm/provider/local/template.go | 163 ++++++++++ internal/llm/provider/ollama/ollama.go | 173 +++++++++++ internal/llm/provider/ollama/ollama_test.go | 131 ++++++++ internal/llm/provider/openai/openai.go | 216 ++++++++++++++ internal/llm/provider/openai/openai_test.go | 147 +++++++++ internal/llm/provider/provider.go | 41 +++ internal/llm/provider/provider_test.go | 62 ++++ internal/llm/svc/assist.go | 281 +++++------------- internal/llm/svc/assist_e2e_test.go | 9 +- internal/llm/svc/assist_test.go | 22 -- internal/llm/svc/cache.go | 2 - internal/llm/svc/service.go | 245 +++++++-------- internal/llm/svc/service_stub.go | 61 ---- internal/mcp/server.go | 42 ++- internal/mcp/tools_core.go | 2 +- internal/mcp/tools_llm.go | 5 +- internal/mcp/tools_llm_stub.go | 11 - 36 files changed, 2731 insertions(+), 793 deletions(-) create mode 100644 internal/llm/config_test.go create mode 100644 internal/llm/prompts.go create mode 100644 internal/llm/prompts_test.go create mode 100644 internal/llm/provider.go create mode 100644 internal/llm/provider/anthropic/anthropic.go create mode 100644 internal/llm/provider/anthropic/anthropic_test.go create mode 100644 internal/llm/provider/local/local.go create mode 100644 internal/llm/provider/local/stub.go create mode 100644 internal/llm/provider/local/template.go create mode 100644 internal/llm/provider/ollama/ollama.go create mode 100644 internal/llm/provider/ollama/ollama_test.go create mode 100644 internal/llm/provider/openai/openai.go create mode 100644 internal/llm/provider/openai/openai_test.go create mode 100644 internal/llm/provider/provider.go create mode 100644 internal/llm/provider/provider_test.go delete mode 100644 internal/llm/svc/service_stub.go delete mode 100644 internal/mcp/tools_llm_stub.go diff --git a/CLAUDE.md b/CLAUDE.md index ede4bd6..83250e7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,9 +20,22 @@ go test -race ./... # all test packages must pass Gortex is running as an MCP server. You MUST use graph queries instead of file reads whenever possible. This saves thousands of tokens per task. 
-### Optional: delegate research to a local agent +### Optional: LLM features and provider selection -When the daemon is built with `-tags llama` and `llm.model` is set in `.gortex.yaml` (or via the `GORTEX_LLM_MODEL` env var), the `ask` MCP tool is registered. It runs a grammar-constrained agent locally that uses gortex tools to research one question and returns a synthesized answer — useful when you'd otherwise issue many `search_symbols` / `get_callers` / `contracts` calls. +The `ask` tool and the `search_symbols` `assist` modes are backed by an LLM provider, selected by the `llm.provider` config key (in `.gortex.yaml` or `~/.config/gortex/config.yaml`): + +| `llm.provider` | Backend | Requires | +|----------------|---------|----------| +| `local` (default) | in-process llama.cpp | a `-tags llama` build + `llm.local.model` (a `.gguf` path) | +| `anthropic` | Anthropic Messages API | `llm.anthropic.model` + `ANTHROPIC_API_KEY` | +| `openai` | OpenAI Chat Completions | `llm.openai.model` + `OPENAI_API_KEY` | +| `ollama` | Ollama daemon | `llm.ollama.model` (+ `llm.ollama.host`, default `localhost:11434`) | + +The HTTP providers are pure Go — available without `-tags llama`. `GORTEX_LLM_PROVIDER` / `GORTEX_LLM_MODEL` env vars override the file config. If the active provider can't be constructed (missing model / API key, or `local` without `-tags llama`), the daemon logs a warning and the LLM features stay absent. + +### Optional: delegate research to the `ask` agent + +When a provider is configured, the `ask` MCP tool is registered. It runs a structured tool-calling agent that uses gortex tools to research one question and returns a synthesized answer — useful when you'd otherwise issue many `search_symbols` / `get_callers` / `contracts` calls. | When you'd otherwise... | Consider... | |---------------------------------------|------------------------------------------| @@ -30,18 +43,18 @@ When the daemon is built with `-tags llama` and `llm.model` is set in `.gortex.y | Trace a request across repos (consumer → contract → handler → downstream) | `ask` with `chain: true` | | Look up a single known fact | Skip `ask` — direct tools are faster | -If `ask` isn't in `tools/list`, gortex was built without `-tags llama` or `llm.model` is unset. Fall through to direct tools. +If `ask` isn't in `tools/list`, no LLM provider is configured (or it failed to construct). Fall through to direct tools. ### Optional: LLM-assisted search ranking (`search_symbols` `assist:` arg) -When the same `-tags llama` build + `llm.model` is in place, `search_symbols` accepts an `assist` argument that engages the local model in the search pipeline. The default `auto` is sub-100 ms on identifier lookups; the active modes add latency but materially improve precision on natural-language queries. +When a provider is configured, `search_symbols` accepts an `assist` argument that engages the model in the search pipeline. The default `auto` is sub-100 ms on identifier lookups; the active modes add latency but materially improve precision on natural-language queries. | `assist` value | Behaviour | Cost | |----------------|-----------|------| | `auto` (default) | NL heuristic decides per-query. Identifier-shaped queries (`Server.handleAsk`, `parseToolCall`) skip the LLM. NL queries (≥3 tokens with a stop word, or ≥4 plain-word tokens) trigger query expansion + name+sig rerank. | None for identifier lookups; +200–500 ms for NL. | | `on` | Forces expansion + name+sig rerank regardless of shape. Use when you know the query is fuzzy. 
| +200–500 ms. | | `off` | Pure BM25 + combo/frecency. No LLM. | None. | -| `deep` | `on` plus a body-grounded verification pass — reads each top candidate's body + callers and HONESTLY drops candidates whose code isn't about the query. May return zero results when nothing genuinely matches; that's the load-bearing honest-negative signal. | +1.5–4 s. Quality is **highly model-dependent**: Qwen2.5-Coder 3B is unreliable on disambiguation cases (e.g. "hash passwords" vs functions that hash other data); Qwen2.5-Coder 7B and above produce stable, useful results. Prefer 7B+ if you want to rely on `deep`. | +| `deep` | `on` plus a body-grounded verification pass — reads each top candidate's body + callers and HONESTLY drops candidates whose code isn't about the query. May return zero results when nothing genuinely matches; that's the load-bearing honest-negative signal. | +1.5–4 s. Quality is **highly model-dependent**: small local models (Qwen2.5-Coder 3B) are unreliable on disambiguation cases (e.g. "hash passwords" vs functions that hash other data); a 7B-class local model or any hosted provider produces stable, useful results. The assist prompts are tiered automatically — terser for hosted frontier models, rule-heavy for small local ones. | The response gains an `assist` debug block when an active mode engaged: `terms` (expansion words), `primary_count` (raw BM25 hits on the original query), `merged_count` (after expansion union), `final_count` (after filter/rerank), plus `verify_kept_ids` / `verify_dropped` for `deep`. diff --git a/cmd/gortex/daemon_state.go b/cmd/gortex/daemon_state.go index b6cb3c3..4d62186 100644 --- a/cmd/gortex/daemon_state.go +++ b/cmd/gortex/daemon_state.go @@ -302,12 +302,15 @@ func buildDaemonState(logger *zap.Logger) (*daemonState, error) { logger.Warn("daemon: savings persistence disabled", zap.Error(err)) } - // In-process LLM service (opt-in via `.gortex.yaml` `llm.model:`, - // `~/.config/gortex/config.yaml::llm:`, or GORTEX_LLM_MODEL env - // var). Repo-local config wins per non-zero field; global fills - // the rest. Env overrides land last inside SetupLLM via MergeEnv. - // No-op when the merged config has no model, or when gortex was - // built without `-tags llama` (stub service + stub registerLLMTools). + // LLM service (opt-in via the `.gortex.yaml` `llm:` block, + // `~/.config/gortex/config.yaml::llm:`, or GORTEX_LLM_* env vars). + // Repo-local config wins per non-zero field; the global config + // fills the rest; env overrides land last inside SetupLLM via + // MergeEnv. The active provider is chosen by `llm.provider` + // (local / anthropic / openai / ollama). No-op when the active + // provider has no model configured; a provider that fails to + // construct (e.g. "local" without `-tags llama`, or a missing API + // key) is logged and the service stays disabled. gc, _ := config.LoadGlobal() srv.SetupLLM(gc.MergeLLMInto(cfg.LLM)) diff --git a/cmd/gortex/mcp.go b/cmd/gortex/mcp.go index f77b114..996fea1 100644 --- a/cmd/gortex/mcp.go +++ b/cmd/gortex/mcp.go @@ -345,9 +345,10 @@ func runMCP(cmd *cobra.Command, args []string) error { fmt.Fprintf(os.Stderr, "[gortex] savings persistence disabled: %v\n", err) } - // In-process LLM service — same wiring as the daemon path: repo - // config wins per non-zero field, global ~/.config/gortex/config.yaml - // fills the rest, env vars override last inside SetupLLM. 
+ // LLM service — same wiring as the daemon path: repo config wins + // per non-zero field, global ~/.config/gortex/config.yaml fills the + // rest, env vars override last inside SetupLLM. The active provider + // is chosen by `llm.provider` (local / anthropic / openai / ollama). gc, _ := config.LoadGlobal() srv.SetupLLM(gc.MergeLLMInto(cfg.LLM)) diff --git a/internal/config/config.go b/internal/config/config.go index 2223624..0d91240 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -230,10 +230,11 @@ type Config struct { Guards GuardsConfig `mapstructure:"guards" yaml:"guards,omitempty"` Multi MultiRepoConfig `mapstructure:"multi" yaml:"multi,omitempty"` Semantic SemanticConfig `mapstructure:"semantic" yaml:"semantic,omitempty"` - // LLM configures the in-process local-LLM service that backs the - // `ask` MCP tool (and future wiki / doc generators). Empty by - // default — daemon skips LLM wiring entirely when llm.model is - // unset. Env vars GORTEX_LLM_* override file values; see + // LLM configures the LLM service that backs the `ask` MCP tool and + // the search-assist passes. Empty by default — daemon skips LLM + // wiring entirely when the active provider has no model configured. + // The `llm.provider` key selects the backend (local / anthropic / + // openai / ollama); env vars GORTEX_LLM_* override file values; see // internal/llm/config.go::Config.MergeEnv. LLM llm.Config `mapstructure:"llm" yaml:"llm,omitempty"` } diff --git a/internal/config/global.go b/internal/config/global.go index 5e88e03..b55c521 100644 --- a/internal/config/global.go +++ b/internal/config/global.go @@ -67,32 +67,20 @@ type GlobalConfig struct { configPath string `yaml:"-"` } -// MergeLLMInto returns local with any zero fields filled from gc.LLM. -// Local non-zero values always win — including an explicit per-repo -// override of an inherited global model path. Safe to call on a nil -// receiver (returns local unchanged), so daemon startup paths don't -// need separate nil-checks for the global config. +// MergeLLMInto layers a repo-local llm.Config over the global user +// config: each zero-valued field of local is filled from gc.LLM, +// per provider sub-block. Local non-zero values always win — including +// an explicit per-repo override of an inherited global model path. +// Safe to call on a nil receiver (returns local unchanged), so daemon +// startup paths don't need separate nil-checks for the global config. +// +// The local provider's model path additionally gets `~/` expanded +// against $HOME so users can write portable paths in either config. 
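+//
+// Sketch of the per-sub-block behaviour (mirrors the tests below;
+// model names are illustrative):
+//
+//	gc.LLM = llm.Config{
+//		Anthropic: llm.RemoteConfig{Model: "claude-sonnet-4-6", APIKeyEnv: "ANTHROPIC_API_KEY"},
+//	}
+//	got := gc.MergeLLMInto(llm.Config{
+//		Provider:  "anthropic",
+//		Anthropic: llm.RemoteConfig{Model: "claude-opus-4-7"},
+//	})
+//	// got.Provider            == "anthropic"         (local wins)
+//	// got.Anthropic.Model     == "claude-opus-4-7"   (local wins)
+//	// got.Anthropic.APIKeyEnv == "ANTHROPIC_API_KEY" (global fills)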
func (gc *GlobalConfig) MergeLLMInto(local llm.Config) llm.Config { - if gc == nil { - return local - } - g := gc.LLM - if local.Model == "" { - local.Model = g.Model - } - if local.Ctx == 0 { - local.Ctx = g.Ctx - } - if local.GPULayers == 0 { - local.GPULayers = g.GPULayers - } - if local.MaxSteps == 0 { - local.MaxSteps = g.MaxSteps - } - if local.Template == "" { - local.Template = g.Template + if gc != nil { + local = local.MergedWith(gc.LLM) } - local.Model = expandHome(local.Model) + local.Local.Model = expandHome(local.Local.Model) return local } diff --git a/internal/config/global_llm_test.go b/internal/config/global_llm_test.go index 4798c2b..10c352b 100644 --- a/internal/config/global_llm_test.go +++ b/internal/config/global_llm_test.go @@ -17,77 +17,110 @@ func TestLoadGlobal_LLMSectionRoundTrip(t *testing.T) { require.NoError(t, os.WriteFile(cfgPath, []byte(`active_project: "" repos: [] llm: - model: /opt/models/qwen.gguf - template: chatml - ctx: 4096 + provider: local max_steps: 12 - gpu_layers: 999 + local: + model: /opt/models/qwen.gguf + template: chatml + ctx: 4096 + gpu_layers: 999 + anthropic: + model: claude-sonnet-4-6 `), 0o644)) gc, err := LoadGlobal(cfgPath) require.NoError(t, err) require.NotNil(t, gc) - assert.Equal(t, "/opt/models/qwen.gguf", gc.LLM.Model) - assert.Equal(t, "chatml", gc.LLM.Template) - assert.Equal(t, 4096, gc.LLM.Ctx) + assert.Equal(t, "local", gc.LLM.Provider) assert.Equal(t, 12, gc.LLM.MaxSteps) - assert.Equal(t, 999, gc.LLM.GPULayers) + assert.Equal(t, "/opt/models/qwen.gguf", gc.LLM.Local.Model) + assert.Equal(t, "chatml", gc.LLM.Local.Template) + assert.Equal(t, 4096, gc.LLM.Local.Ctx) + assert.Equal(t, 999, gc.LLM.Local.GPULayers) + assert.Equal(t, "claude-sonnet-4-6", gc.LLM.Anthropic.Model) } func TestGlobalConfig_MergeLLMInto_FillsZeroFields(t *testing.T) { gc := &GlobalConfig{LLM: llm.Config{ - Model: "/global/qwen.gguf", - Template: "chatml", - Ctx: 4096, - MaxSteps: 16, - GPULayers: 999, + Provider: "local", + MaxSteps: 16, + Local: llm.LocalConfig{ + Model: "/global/qwen.gguf", + Template: "chatml", + Ctx: 4096, + GPULayers: 999, + }, }} got := gc.MergeLLMInto(llm.Config{}) - assert.Equal(t, "/global/qwen.gguf", got.Model) - assert.Equal(t, "chatml", got.Template) - assert.Equal(t, 4096, got.Ctx) + assert.Equal(t, "local", got.Provider) assert.Equal(t, 16, got.MaxSteps) - assert.Equal(t, 999, got.GPULayers) + assert.Equal(t, "/global/qwen.gguf", got.Local.Model) + assert.Equal(t, "chatml", got.Local.Template) + assert.Equal(t, 4096, got.Local.Ctx) + assert.Equal(t, 999, got.Local.GPULayers) } func TestGlobalConfig_MergeLLMInto_LocalWinsPerField(t *testing.T) { gc := &GlobalConfig{LLM: llm.Config{ - Model: "/global/qwen.gguf", - Template: "chatml", - Ctx: 4096, + Provider: "local", MaxSteps: 16, + Local: llm.LocalConfig{ + Model: "/global/qwen.gguf", + Template: "chatml", + Ctx: 4096, + }, }} got := gc.MergeLLMInto(llm.Config{ - Model: "/repo/override.gguf", // local wins - Ctx: 8192, // local wins + Local: llm.LocalConfig{ + Model: "/repo/override.gguf", // local wins + Ctx: 8192, // local wins + }, }) - assert.Equal(t, "/repo/override.gguf", got.Model) - assert.Equal(t, 8192, got.Ctx) + assert.Equal(t, "/repo/override.gguf", got.Local.Model) + assert.Equal(t, 8192, got.Local.Ctx) // Unset locals fall through to global. 
- assert.Equal(t, "chatml", got.Template) + assert.Equal(t, "chatml", got.Local.Template) assert.Equal(t, 16, got.MaxSteps) + assert.Equal(t, "local", got.Provider) +} + +func TestGlobalConfig_MergeLLMInto_PerProviderSubBlocks(t *testing.T) { + gc := &GlobalConfig{LLM: llm.Config{ + Anthropic: llm.RemoteConfig{Model: "claude-sonnet-4-6", APIKeyEnv: "ANTHROPIC_API_KEY"}, + Ollama: llm.OllamaConfig{Host: "http://localhost:11434"}, + }} + + // Repo selects a different provider and overrides only one field. + got := gc.MergeLLMInto(llm.Config{ + Provider: "anthropic", + Anthropic: llm.RemoteConfig{Model: "claude-opus-4-7"}, + }) + assert.Equal(t, "anthropic", got.Provider) + assert.Equal(t, "claude-opus-4-7", got.Anthropic.Model) // local wins + assert.Equal(t, "ANTHROPIC_API_KEY", got.Anthropic.APIKeyEnv) // global fills + assert.Equal(t, "http://localhost:11434", got.Ollama.Host) // unrelated block still merges } func TestGlobalConfig_MergeLLMInto_NilReceiver(t *testing.T) { var gc *GlobalConfig // nil - local := llm.Config{Model: "/repo/x.gguf"} + local := llm.Config{Local: llm.LocalConfig{Model: "/repo/x.gguf"}} got := gc.MergeLLMInto(local) - assert.Equal(t, "/repo/x.gguf", got.Model) + assert.Equal(t, "/repo/x.gguf", got.Local.Model) } func TestGlobalConfig_MergeLLMInto_ExpandsHomeInModelPath(t *testing.T) { home, err := os.UserHomeDir() require.NoError(t, err) - gc := &GlobalConfig{LLM: llm.Config{Model: "~/models/qwen.gguf"}} + gc := &GlobalConfig{LLM: llm.Config{Local: llm.LocalConfig{Model: "~/models/qwen.gguf"}}} got := gc.MergeLLMInto(llm.Config{}) - assert.Equal(t, filepath.Join(home, "models/qwen.gguf"), got.Model) + assert.Equal(t, filepath.Join(home, "models/qwen.gguf"), got.Local.Model) // Local override also gets expanded. - got = gc.MergeLLMInto(llm.Config{Model: "~/repo-override.gguf"}) - assert.Equal(t, filepath.Join(home, "repo-override.gguf"), got.Model) + got = gc.MergeLLMInto(llm.Config{Local: llm.LocalConfig{Model: "~/repo-override.gguf"}}) + assert.Equal(t, filepath.Join(home, "repo-override.gguf"), got.Local.Model) } func TestExpandHome(t *testing.T) { diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 277ab4e..5a46ac0 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -1,20 +1,25 @@ -//go:build llama - -// Package agent runs a grammar-constrained tool-calling loop on top of -// the internal/llm wrapper. The model can only emit JSON of the -// shape {"tool": "", "args": {...}}; -// each tool call is executed and its result fed back as a new turn. +// Package agent runs a provider-agnostic tool-calling loop. On each +// turn the model emits one JSON object {"tool":"","args":{...}}; +// the loop executes that call and feeds the result back as a new turn. // The loop terminates when the model calls the final_answer tool. +// +// The structured-output constraint — the model may only emit a valid +// tool-call object — is enforced by the llm.Provider via +// CompletionRequest.Shape == ShapeToolCall: a GBNF grammar for the +// local llama.cpp provider, json-schema / forced-tool for the HTTP +// providers. The loop carries the conversation as a provider-neutral +// []llm.Message, so the same agent drives every provider. package agent import ( + "context" "encoding/json" "errors" "fmt" "sort" "strings" - llm "github.com/zzet/gortex/internal/llm" + "github.com/zzet/gortex/internal/llm" ) // ToolFunc executes a single tool call. 
args is the parsed JSON object @@ -37,78 +42,28 @@ type Step struct { Args map[string]any // parsed args (call/final) } -// ChatTemplate describes how to wrap conversation turns for a given -// model family. The four wrappers each take raw content and return -// the fully-marked-up turn; AssistPrime is the marker we append to -// the running conversation right before each generate call so the -// model starts emitting an assistant turn. -type ChatTemplate struct { - Name string - BOS string - System func(content string) string - User func(content string) string - AssistEnd string // marker appended after a captured assistant emission - Tool func(content string) string - AssistPrime string -} - -// TemplateChatML covers Qwen2.5 family and Nous Hermes-3 (which -// re-trains Llama-3 onto ChatML). -var TemplateChatML = ChatTemplate{ - Name: "chatml", - System: func(c string) string { return "<|im_start|>system\n" + c + "<|im_end|>\n" }, - User: func(c string) string { return "<|im_start|>user\n" + c + "<|im_end|>\n" }, - AssistEnd: "<|im_end|>\n", - Tool: func(c string) string { return "<|im_start|>tool\n" + c + "<|im_end|>\n" }, - AssistPrime: "<|im_start|>assistant\n", -} - -// TemplateLlama3 covers Meta's Llama-3.x stock instruct format. Used by -// models that keep Llama-3's native template (NOT Hermes-3, which -// switches to ChatML). -var TemplateLlama3 = ChatTemplate{ - Name: "llama3", - BOS: "<|begin_of_text|>", - System: func(c string) string { - return "<|start_header_id|>system<|end_header_id|>\n\n" + c + "<|eot_id|>" - }, - User: func(c string) string { - return "<|start_header_id|>user<|end_header_id|>\n\n" + c + "<|eot_id|>" - }, - AssistEnd: "<|eot_id|>", - Tool: func(c string) string { - return "<|start_header_id|>ipython<|end_header_id|>\n\n" + c + "<|eot_id|>" - }, - AssistPrime: "<|start_header_id|>assistant<|end_header_id|>\n\n", -} +// FinalAnswerTool is the name of the synthetic terminator tool. It is +// registered automatically by New; callers must not pre-register it. +const FinalAnswerTool = "final_answer" -// TemplateByName returns a known chat template by short name. -func TemplateByName(name string) (ChatTemplate, error) { - switch name { - case "", "chatml", "qwen", "hermes": - return TemplateChatML, nil - case "llama3", "llama": - return TemplateLlama3, nil - } - return ChatTemplate{}, fmt.Errorf("unknown chat template %q", name) -} +// stepMaxTokens caps a single tool-call emission. A tool call is a +// small JSON object, so this is generous. +const stepMaxTokens = 512 type Agent struct { - ctx *llm.Context - tmpl ChatTemplate - tools map[string]Tool - names []string // sorted, for stable grammar - grammar string - maxTok int + provider llm.Provider + tools map[string]Tool + names []string // sorted, stable iteration order + specs []llm.ToolSpec // sorted by name; handed to the provider } -// FinalAnswerTool is the name of the synthetic terminator tool. It is -// registered automatically by New; callers should not pre-register it. -const FinalAnswerTool = "final_answer" - -func New(ctx *llm.Context, tools []Tool, tmpl ChatTemplate) (*Agent, error) { - if tmpl.AssistPrime == "" { - tmpl = TemplateChatML +// New builds an Agent over a provider and a tool set. The synthetic +// final_answer tool is appended automatically. Returns an error for a +// nil provider or a malformed tool (empty name, reserved name, nil +// Run). 
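+//
+// Call shape (a hedged caller-side sketch; Run's signature appears
+// below, provider construction and the tool slice are elided):
+//
+//	ag, err := agent.New(provider, tools)
+//	if err != nil {
+//		return err
+//	}
+//	answer, transcript, err := ag.Run(ctx, "", "where do we hash passwords?", maxSteps)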
+func New(provider llm.Provider, tools []Tool) (*Agent, error) { + if provider == nil { + return nil, errors.New("agent: nil provider") } reg := make(map[string]Tool, len(tools)+1) for _, t := range tools { @@ -134,44 +89,36 @@ func New(ctx *llm.Context, tools []Tool, tmpl ChatTemplate) (*Agent, error) { names = append(names, n) } sort.Strings(names) - - a := &Agent{ - ctx: ctx, - tmpl: tmpl, - tools: reg, - names: names, - maxTok: 512, - } - a.grammar = buildGrammar(names) - if err := ctx.SetGrammar(a.grammar); err != nil { - return nil, fmt.Errorf("agent: install grammar: %w", err) + specs := make([]llm.ToolSpec, len(names)) + for i, n := range names { + specs[i] = llm.ToolSpec{Name: n, Description: reg[n].Description} } - return a, nil -} -// Grammar returns the GBNF the agent installed. Exposed for debugging. -func (a *Agent) Grammar() string { return a.grammar } + return &Agent{provider: provider, tools: reg, names: names, specs: specs}, nil +} // Run executes the tool-calling loop until the model invokes // final_answer or maxSteps is reached. The transcript captures every // call/result/final step in order. -func (a *Agent) Run(systemExtras, userQuestion string, maxSteps int) (answer string, transcript []Step, err error) { - conv := a.initialPrompt(systemExtras, userQuestion) +func (a *Agent) Run(ctx context.Context, systemExtras, userQuestion string, maxSteps int) (answer string, transcript []Step, err error) { + conv := []llm.Message{ + {Role: llm.RoleSystem, Content: a.systemPrompt(systemExtras)}, + {Role: llm.RoleUser, Content: userQuestion}, + } seen := map[string]struct{}{} - for step := 0; step < maxSteps; step++ { - a.ctx.Reset() - - var buf strings.Builder - _, gerr := a.ctx.Generate(conv, a.maxTok, func(piece string) bool { - buf.WriteString(piece) - return !jsonComplete(buf.String()) + for step := range maxSteps { + resp, gerr := a.provider.Complete(ctx, llm.CompletionRequest{ + Messages: conv, + MaxTokens: stepMaxTokens, + Shape: llm.ShapeToolCall, + Tools: a.specs, }) if gerr != nil { return "", transcript, fmt.Errorf("step %d generate: %w", step, gerr) } - raw := strings.TrimSpace(buf.String()) + raw := strings.TrimSpace(resp.Text) call, perr := parseToolCall(raw) if perr != nil { return "", transcript, fmt.Errorf("step %d parse %q: %w", step, raw, perr) @@ -187,24 +134,25 @@ func (a *Agent) Run(systemExtras, userQuestion string, maxSteps int) (answer str tool, ok := a.tools[call.Tool] if !ok { - // Grammar shouldn't allow this, but defend anyway. - return "", transcript, fmt.Errorf("step %d unknown tool %q (grammar bug?)", step, call.Tool) + // Structured output shouldn't allow this, but defend anyway. + return "", transcript, fmt.Errorf("step %d unknown tool %q (provider bug?)", step, call.Tool) } transcript = append(transcript, Step{ Kind: "call", Raw: raw, Tool: call.Tool, Args: call.Args, }) // Loop detection: if we've already executed this exact - // (tool, args) pair in this run, refuse to execute it again - // and feed back a synthetic loop_detected observation so the - // model is forced to change strategy. + // (tool, args) pair in this run, refuse to execute it again and + // feed back a synthetic loop_detected observation so the model + // is forced to change strategy. key := callKey(call.Tool, call.Args) if _, dup := seen[key]; dup { loopResult := `{"error":"loop_detected","message":"You already called this exact tool with these exact args; the result did not help. 
Try DIFFERENT args, a DIFFERENT tool, or call final_answer to give your best summary of what you found."}` transcript = append(transcript, Step{Kind: "result", Raw: loopResult}) - conv += raw + a.tmpl.AssistEnd + - a.tmpl.Tool(loopResult) + - a.tmpl.AssistPrime + conv = append(conv, + llm.Message{Role: llm.RoleAssistant, Content: raw}, + llm.Message{Role: llm.RoleTool, Content: loopResult, ToolName: call.Tool}, + ) continue } seen[key] = struct{}{} @@ -215,22 +163,25 @@ func (a *Agent) Run(systemExtras, userQuestion string, maxSteps int) (answer str } transcript = append(transcript, Step{Kind: "result", Raw: result}) - conv += raw + a.tmpl.AssistEnd + - a.tmpl.Tool(result) + - a.tmpl.AssistPrime + conv = append(conv, + llm.Message{Role: llm.RoleAssistant, Content: raw}, + llm.Message{Role: llm.RoleTool, Content: result, ToolName: call.Tool}, + ) } return "", transcript, fmt.Errorf("agent: exceeded %d steps without final_answer", maxSteps) } -func (a *Agent) initialPrompt(extras, question string) string { +// systemPrompt assembles the agent's system message: the tool-call +// protocol, the tool catalogue, and any caller-supplied extras +// (the simple / chain mode rule sets). +func (a *Agent) systemPrompt(extras string) string { var sys strings.Builder sys.WriteString("You are a tool-using agent. ") sys.WriteString("On each turn, emit ONE JSON object: ") sys.WriteString(`{"tool": "", "args": {...}}.`) sys.WriteString(" Available tools:\n") for _, n := range a.names { - t := a.tools[n] - fmt.Fprintf(&sys, "- %s: %s\n", n, t.Description) + fmt.Fprintf(&sys, "- %s: %s\n", n, a.tools[n].Description) } sys.WriteString("\nAfter receiving a tool result, call the next tool. ") sys.WriteString("When you have enough information, call ") @@ -240,11 +191,7 @@ func (a *Agent) initialPrompt(extras, question string) string { sys.WriteString("\n\n") sys.WriteString(extras) } - - return a.tmpl.BOS + - a.tmpl.System(sys.String()) + - a.tmpl.User(question) + - a.tmpl.AssistPrime + return sys.String() } type toolCall struct { @@ -272,36 +219,3 @@ func callKey(tool string, args map[string]any) string { b, _ := json.Marshal(args) return tool + ":" + string(b) } - -// jsonComplete reports whether s is a complete top-level JSON object. -// Used to early-stop generation as soon as the grammar-driven model -// closes the brace, instead of waiting on EOS. -func jsonComplete(s string) bool { - s = strings.TrimSpace(s) - if !strings.HasPrefix(s, "{") || !strings.HasSuffix(s, "}") { - return false - } - var v any - return json.Unmarshal([]byte(s), &v) == nil -} - -// buildGrammar returns a GBNF that accepts {"tool": "", -// "args": {}} with whitespace tolerance. -func buildGrammar(names []string) string { - alt := make([]string, len(names)) - for i, n := range names { - alt[i] = `"\"" "` + n + `" "\""` - } - toolname := strings.Join(alt, " | ") - - return `root ::= ws "{" ws "\"tool\"" ws ":" ws toolname ws "," ws "\"args\"" ws ":" ws object ws "}" ws -toolname ::= ` + toolname + ` -object ::= "{" ws ( pair ( ws "," ws pair )* )? ws "}" -pair ::= string ws ":" ws value -array ::= "[" ws ( value ( ws "," ws value )* )? ws "]" -value ::= string | number | object | array | "true" | "false" | "null" -string ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" -number ::= "-"? ( "0" | [1-9] [0-9]* ) ( "." [0-9]+ )? ( [eE] [-+]? [0-9]+ )? 
-ws ::= [ \t\n]* -` -} diff --git a/internal/llm/agent/tools.go b/internal/llm/agent/tools.go index 91a244a..884c8a0 100644 --- a/internal/llm/agent/tools.go +++ b/internal/llm/agent/tools.go @@ -1,5 +1,3 @@ -//go:build llama - package agent import ( diff --git a/internal/llm/cmd/agentdemo/main.go b/internal/llm/cmd/agentdemo/main.go index 0bc8190..a3bd77d 100644 --- a/internal/llm/cmd/agentdemo/main.go +++ b/internal/llm/cmd/agentdemo/main.go @@ -1,8 +1,8 @@ //go:build llama -// agentdemo: drive the grammar-constrained tool-calling agent against -// either canned mock data or the real gortex daemon. Same model, same -// agent loop, the only variable is the backend. +// agentdemo: drive the structured tool-calling agent against either +// canned mock data or the real gortex daemon. Same agent loop, the +// only variable is the backend. // // go build -tags llama -o /tmp/agentdemo ./internal/llm/cmd/agentdemo // /tmp/agentdemo -model ~/models/qwen2.5-3b-instruct-q4_k_m.gguf \ @@ -13,6 +13,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -20,6 +21,7 @@ import ( "github.com/zzet/gortex/internal/llm" "github.com/zzet/gortex/internal/llm/agent" + "github.com/zzet/gortex/internal/llm/provider" ) const promptP1 = `RULES (follow these exactly): @@ -88,7 +90,6 @@ func main() { ref := flag.String("ref", "", "restrict queries to this ref tag") promptName := flag.String("prompt", "p2", "prompt variant: p0 | p1 | p2 | chain") chainMode := flag.Bool("chain", false, "register contract + dependency tools for cross-system tracing") - showGrammar := flag.Bool("show-grammar", false, "print the generated GBNF and exit") flag.Parse() systemExtras, err := promptByName(*promptName) @@ -96,12 +97,6 @@ func main() { fmt.Fprintln(os.Stderr, err) os.Exit(2) } - - tmpl, err := agent.TemplateByName(*tmplName) - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(2) - } if *modelPath == "" { fmt.Fprintln(os.Stderr, "error: -model is required") os.Exit(2) @@ -132,33 +127,30 @@ func main() { tools = agent.GortexTools(backend, scope) } - m, err := llm.LoadModel(*modelPath, *gpu) + cfg := llm.Config{ + Provider: "local", + Local: llm.LocalConfig{ + Model: *modelPath, + Ctx: *nCtx, + GPULayers: *gpu, + Template: *tmplName, + }, + }.ApplyDefaults() + prov, err := provider.New(cfg) if err != nil { - fmt.Fprintf(os.Stderr, "load: %v\n", err) + fmt.Fprintf(os.Stderr, "provider: %v\n", err) os.Exit(1) } - defer m.Close() + defer prov.Close() - ctx, err := m.NewContext(*nCtx, 0) - if err != nil { - fmt.Fprintf(os.Stderr, "context: %v\n", err) - os.Exit(1) - } - defer ctx.Close() - - ag, err := agent.New(ctx, tools, tmpl) + ag, err := agent.New(prov, tools) if err != nil { fmt.Fprintf(os.Stderr, "agent: %v\n", err) os.Exit(1) } - if *showGrammar { - fmt.Println(ag.Grammar()) - return - } - t0 := time.Now() - answer, transcript, runErr := ag.Run(systemExtras, *question, *maxSteps) + answer, transcript, runErr := ag.Run(context.Background(), systemExtras, *question, *maxSteps) fmt.Println("=== TRANSCRIPT ===") for i, st := range transcript { diff --git a/internal/llm/cmd/bench/main.go b/internal/llm/cmd/bench/main.go index ee180fd..ed8f3b3 100644 --- a/internal/llm/cmd/bench/main.go +++ b/internal/llm/cmd/bench/main.go @@ -9,6 +9,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -18,6 +19,7 @@ import ( "github.com/zzet/gortex/internal/llm" "github.com/zzet/gortex/internal/llm/agent" + "github.com/zzet/gortex/internal/llm/provider" ) type modelSpec struct { @@ -158,29 +160,25 @@ func promptByName(name 
string) (string, error) { func runOne(spec modelSpec, qs []question, ctxSize int, systemExtras string, backend llm.Backend, runScope llm.Scope, chain bool) []result { results := make([]result, 0, len(qs)) - m, err := llm.LoadModel(spec.Path, 999) + cfg := llm.Config{ + Provider: "local", + Local: llm.LocalConfig{ + Model: spec.Path, + Ctx: ctxSize, + GPULayers: 999, + Template: spec.Template, + }, + }.ApplyDefaults() + prov, err := provider.New(cfg) if err != nil { for _, q := range qs { - results = append(results, result{Question: q.Name, Err: "load: " + err.Error()}) - } - return results - } - defer m.Close() - - tmpl, err := agent.TemplateByName(spec.Template) - if err != nil { - for _, q := range qs { - results = append(results, result{Question: q.Name, Err: "template: " + err.Error()}) + results = append(results, result{Question: q.Name, Err: "provider: " + err.Error()}) } return results } + defer prov.Close() for _, q := range qs { - ctx, err := m.NewContext(ctxSize, 0) - if err != nil { - results = append(results, result{Question: q.Name, Err: "context: " + err.Error()}) - continue - } // Rebuild tools per question so the scope can differ across // the cross-repo set. Chain mode adds contracts + // get_dependencies for cross-system tracing. @@ -190,10 +188,9 @@ func runOne(spec modelSpec, qs []question, ctxSize int, systemExtras string, bac } else { tools = agent.GortexTools(backend, q.effectiveScope(runScope)) } - ag, err := agent.New(ctx, tools, tmpl) + ag, err := agent.New(prov, tools) if err != nil { results = append(results, result{Question: q.Name, Err: "agent: " + err.Error()}) - ctx.Close() continue } @@ -202,7 +199,7 @@ func runOne(spec modelSpec, qs []question, ctxSize int, systemExtras string, bac maxSteps = 20 } t0 := time.Now() - ans, transcript, runErr := ag.Run(systemExtras, q.Text, maxSteps) + ans, transcript, runErr := ag.Run(context.Background(), systemExtras, q.Text, maxSteps) elapsed := time.Since(t0) r := result{Question: q.Name, Steps: stepCount(transcript), Answer: ans, Elapsed: elapsed} @@ -210,7 +207,6 @@ func runOne(spec modelSpec, qs []question, ctxSize int, systemExtras string, bac r.Err = runErr.Error() } results = append(results, r) - ctx.Close() } return results } diff --git a/internal/llm/config.go b/internal/llm/config.go index 9398973..aac853b 100644 --- a/internal/llm/config.go +++ b/internal/llm/config.go @@ -1,12 +1,13 @@ -// Package llm — config loader for the in-process LLM service. +// Package llm — config loader for the LLM service. // -// This file is pure Go (no build tag) so both llama and non-llama -// builds can compile it. The actual service construction lives in -// service.go (llama) and service_stub.go (!llama). +// This file is pure Go (no build tag) so every build can compile it. +// The actual provider construction lives under internal/llm/provider/ +// — the `local` provider is the only one that needs `-tags llama`. // // Resolution order: file values are populated by the gortex config -// loader; MergeEnv overlays any GORTEX_LLM_* env var that's set -// (env wins). Empty fields fall back to defaults applied here. +// loader; MergeEnv overlays any GORTEX_LLM_* env var that's set (env +// wins); ApplyDefaults fills any remaining zero fields. A repo-local +// Config can additionally be layered over a global one via MergedWith. package llm import ( @@ -15,79 +16,264 @@ import ( "strings" ) -// Config is the YAML-friendly LLM block. Lives alongside the rest of -// the gortex config; promoted from .gortex.yaml's `llm:` section. 
+// Config is the YAML-friendly `llm:` block. The active backend is +// chosen by Provider; each provider reads its own sub-block, so a +// single config file can carry settings for several providers and +// switch between them by changing one key. type Config struct { - // Path to a .gguf model file. Required — empty disables the - // service entirely (no tool registered, no startup cost). - Model string `mapstructure:"model" yaml:"model,omitempty"` + // Provider selects the inference backend: "local" (llama.cpp, + // in-process, requires a `-tags llama` build), "anthropic", + // "openai", or "ollama". Empty defaults to "local". + Provider string `mapstructure:"provider" yaml:"provider,omitempty"` - // Context size in tokens. Defaults to 4096. - Ctx int `mapstructure:"ctx" yaml:"ctx,omitempty"` + // MaxSteps caps the agent tool-loop. Provider-agnostic. Defaults + // to 16. + MaxSteps int `mapstructure:"max_steps" yaml:"max_steps,omitempty"` + + // Local configures the in-process llama.cpp provider. + Local LocalConfig `mapstructure:"local" yaml:"local,omitempty"` + // Anthropic configures the hosted Anthropic Messages API provider. + Anthropic RemoteConfig `mapstructure:"anthropic" yaml:"anthropic,omitempty"` + // OpenAI configures the hosted OpenAI Chat Completions provider. + OpenAI RemoteConfig `mapstructure:"openai" yaml:"openai,omitempty"` + // Ollama configures a local/remote Ollama daemon provider. + Ollama OllamaConfig `mapstructure:"ollama" yaml:"ollama,omitempty"` +} - // Number of layers to offload to GPU (Metal/CUDA). 999 = all. - // 0 = CPU-only. Defaults to 999. +// LocalConfig is the `llm.local:` sub-block — settings for the +// in-process llama.cpp provider. +type LocalConfig struct { + // Model is the path to a .gguf model file. Required for the local + // provider — empty disables it. + Model string `mapstructure:"model" yaml:"model,omitempty"` + // Ctx is the context window in tokens. Defaults to 4096. + Ctx int `mapstructure:"ctx" yaml:"ctx,omitempty"` + // GPULayers is the number of layers to offload to GPU (Metal / + // CUDA). 999 = all, 0 = CPU-only. Defaults to 999. GPULayers int `mapstructure:"gpu_layers" yaml:"gpu_layers,omitempty"` + // Template is the chat-template family: "chatml" (Qwen2.5, + // Hermes-3) or "llama3" (Llama-3.x native). Defaults to "chatml". + Template string `mapstructure:"template" yaml:"template,omitempty"` +} - // Maximum agent loop steps before giving up. Defaults to 16. - MaxSteps int `mapstructure:"max_steps" yaml:"max_steps,omitempty"` +// RemoteConfig is the sub-block shared by the HTTP API providers +// (Anthropic, OpenAI). +type RemoteConfig struct { + // Model is the API model identifier (e.g. "claude-sonnet-4-6", + // "gpt-4o"). Defaulted per provider by ApplyDefaults. + Model string `mapstructure:"model" yaml:"model,omitempty"` + // APIKeyEnv names the environment variable holding the API key. + // Defaulted per provider by ApplyDefaults. The key itself is never + // stored in the config file. + APIKeyEnv string `mapstructure:"api_key_env" yaml:"api_key_env,omitempty"` + // BaseURL overrides the API endpoint (proxies, gateways, Azure). + // Defaulted per provider by ApplyDefaults. + BaseURL string `mapstructure:"base_url" yaml:"base_url,omitempty"` +} - // Chat template family: "chatml" (Qwen2.5, Hermes-3) or - // "llama3" (Llama-3.x native). Defaults to "chatml". - Template string `mapstructure:"template" yaml:"template,omitempty"` +// OllamaConfig is the `llm.ollama:` sub-block. 
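+//
+// An illustrative `.gortex.yaml` fragment selecting this provider
+// (the model tag is an example):
+//
+//	llm:
+//	  provider: ollama
+//	  ollama:
+//	    model: qwen2.5-coder:7b
+//	    host: http://localhost:11434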
+type OllamaConfig struct { + // Model is the Ollama model tag (e.g. "qwen2.5-coder:7b"). + // Required for the Ollama provider — empty disables it. + Model string `mapstructure:"model" yaml:"model,omitempty"` + // Host is the Ollama daemon base URL. Defaults to + // "http://localhost:11434". + Host string `mapstructure:"host" yaml:"host,omitempty"` +} + +// Default endpoints / key env vars, applied by ApplyDefaults. +const ( + defaultAnthropicModel = "claude-sonnet-4-6" + defaultAnthropicBaseURL = "https://api.anthropic.com" + defaultAnthropicKeyEnv = "ANTHROPIC_API_KEY" + + defaultOpenAIModel = "gpt-4o" + defaultOpenAIBaseURL = "https://api.openai.com" + defaultOpenAIKeyEnv = "OPENAI_API_KEY" + + defaultOllamaHost = "http://localhost:11434" +) + +// ProviderName returns the effective provider, applying the "local" +// default for an empty value. +func (c Config) ProviderName() string { + if strings.TrimSpace(c.Provider) == "" { + return "local" + } + return strings.ToLower(strings.TrimSpace(c.Provider)) } -// IsEnabled reports whether the config carries enough to start a -// service. Empty Model = disabled. -func (c Config) IsEnabled() bool { return strings.TrimSpace(c.Model) != "" } +// IsEnabled reports whether the config carries enough to start the +// active provider. A provider is enabled once its required fields are +// set: the local and Ollama providers need a model; the hosted +// providers need a model (defaulted) — the API key is validated at +// provider-construction time, not here. +func (c Config) IsEnabled() bool { + switch c.ProviderName() { + case "local": + return strings.TrimSpace(c.Local.Model) != "" + case "anthropic": + return strings.TrimSpace(c.Anthropic.Model) != "" + case "openai": + return strings.TrimSpace(c.OpenAI.Model) != "" + case "ollama": + return strings.TrimSpace(c.Ollama.Model) != "" + default: + return false + } +} // MergeEnv overlays any GORTEX_LLM_* env var on top of the file -// values. Env wins. After merging, ApplyDefaults fills in any -// remaining zero values. +// values, then applies defaults. Env wins over file. GORTEX_LLM_MODEL +// targets the *active* provider's model so the common "swap the +// model" case needs only one variable. func (c Config) MergeEnv() Config { - if v := os.Getenv("GORTEX_LLM_MODEL"); v != "" { - c.Model = v + if v := os.Getenv("GORTEX_LLM_PROVIDER"); v != "" { + c.Provider = v } - if v := os.Getenv("GORTEX_LLM_CTX"); v != "" { + if v := os.Getenv("GORTEX_LLM_MAX_STEPS"); v != "" { if n, err := strconv.Atoi(v); err == nil { - c.Ctx = n + c.MaxSteps = n } } - if v := os.Getenv("GORTEX_LLM_GPU_LAYERS"); v != "" { + if v := os.Getenv("GORTEX_LLM_MODEL"); v != "" { + switch c.ProviderName() { + case "anthropic": + c.Anthropic.Model = v + case "openai": + c.OpenAI.Model = v + case "ollama": + c.Ollama.Model = v + default: + c.Local.Model = v + } + } + if v := os.Getenv("GORTEX_LLM_CTX"); v != "" { if n, err := strconv.Atoi(v); err == nil { - c.GPULayers = n + c.Local.Ctx = n } } - if v := os.Getenv("GORTEX_LLM_MAX_STEPS"); v != "" { + if v := os.Getenv("GORTEX_LLM_GPU_LAYERS"); v != "" { if n, err := strconv.Atoi(v); err == nil { - c.MaxSteps = n + c.Local.GPULayers = n } } if v := os.Getenv("GORTEX_LLM_TEMPLATE"); v != "" { - c.Template = v + c.Local.Template = v } return c.ApplyDefaults() } // ApplyDefaults fills zero-valued fields with the canonical defaults. -// Called by MergeEnv; safe to call standalone. +// Called by MergeEnv; safe to call standalone and idempotent. 
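+//
+// For example, an otherwise-empty Config comes back as:
+//
+//	c := Config{}.ApplyDefaults()
+//	// c.Provider == "local", c.MaxSteps == 16,
+//	// c.Local == LocalConfig{Ctx: 4096, GPULayers: 999, Template: "chatml"}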
 func (c Config) ApplyDefaults() Config {
-	if c.Ctx == 0 {
-		c.Ctx = 4096
-	}
-	if c.GPULayers == 0 {
-		// 0 stays 0 only if the user explicitly set it; we can't
-		// distinguish "set to 0" from "not set" at the struct level.
-		// Convention: 0 means CPU-only is what the user wants iff
-		// they put it in the YAML or env. For default-zero we pick
-		// 999 (offload all layers).
-		c.GPULayers = 999
+	if strings.TrimSpace(c.Provider) == "" {
+		c.Provider = "local"
 	}
 	if c.MaxSteps == 0 {
 		c.MaxSteps = 16
 	}
-	if c.Template == "" {
-		c.Template = "chatml"
+
+	// local
+	if c.Local.Ctx == 0 {
+		c.Local.Ctx = 4096
+	}
+	if c.Local.GPULayers == 0 {
+		// 0 is indistinguishable from "unset" at the struct level, so
+		// an explicit gpu_layers: 0 in YAML cannot be honoured as
+		// CPU-only via this field; the default offloads all layers.
+		// 999 is the safe, fast default.
+		c.Local.GPULayers = 999
+	}
+	if c.Local.Template == "" {
+		c.Local.Template = "chatml"
+	}
+
+	// anthropic
+	if c.Anthropic.Model == "" {
+		c.Anthropic.Model = defaultAnthropicModel
+	}
+	if c.Anthropic.APIKeyEnv == "" {
+		c.Anthropic.APIKeyEnv = defaultAnthropicKeyEnv
+	}
+	if c.Anthropic.BaseURL == "" {
+		c.Anthropic.BaseURL = defaultAnthropicBaseURL
+	}
+
+	// openai
+	if c.OpenAI.Model == "" {
+		c.OpenAI.Model = defaultOpenAIModel
 	}
+	if c.OpenAI.APIKeyEnv == "" {
+		c.OpenAI.APIKeyEnv = defaultOpenAIKeyEnv
+	}
+	if c.OpenAI.BaseURL == "" {
+		c.OpenAI.BaseURL = defaultOpenAIBaseURL
+	}
+
+	// ollama
+	if c.Ollama.Host == "" {
+		c.Ollama.Host = defaultOllamaHost
+	}
+
 	return c
 }
+
+// MergedWith returns c with each zero-valued field filled from fb.
+// Non-zero fields of c always win — including an explicit per-repo
+// override of an inherited global value. Used to layer a repo-local
+// Config (c) over a global user Config (fb). Call before ApplyDefaults
+// so genuine zero values still merge.
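+//
+// A small sketch (paths are hypothetical): the repo keeps its own
+// model while the global config supplies the context size:
+//
+//	repo := Config{Local: LocalConfig{Model: "/repo/m.gguf"}}
+//	glob := Config{Local: LocalConfig{Model: "~/m.gguf", Ctx: 8192}}
+//	merged := repo.MergedWith(glob)
+//	// merged.Local.Model == "/repo/m.gguf", merged.Local.Ctx == 8192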
+func (c Config) MergedWith(fb Config) Config { + if c.Provider == "" { + c.Provider = fb.Provider + } + if c.MaxSteps == 0 { + c.MaxSteps = fb.MaxSteps + } + c.Local = c.Local.mergedWith(fb.Local) + c.Anthropic = c.Anthropic.mergedWith(fb.Anthropic) + c.OpenAI = c.OpenAI.mergedWith(fb.OpenAI) + c.Ollama = c.Ollama.mergedWith(fb.Ollama) + return c +} + +func (l LocalConfig) mergedWith(fb LocalConfig) LocalConfig { + if l.Model == "" { + l.Model = fb.Model + } + if l.Ctx == 0 { + l.Ctx = fb.Ctx + } + if l.GPULayers == 0 { + l.GPULayers = fb.GPULayers + } + if l.Template == "" { + l.Template = fb.Template + } + return l +} + +func (r RemoteConfig) mergedWith(fb RemoteConfig) RemoteConfig { + if r.Model == "" { + r.Model = fb.Model + } + if r.APIKeyEnv == "" { + r.APIKeyEnv = fb.APIKeyEnv + } + if r.BaseURL == "" { + r.BaseURL = fb.BaseURL + } + return r +} + +func (o OllamaConfig) mergedWith(fb OllamaConfig) OllamaConfig { + if o.Model == "" { + o.Model = fb.Model + } + if o.Host == "" { + o.Host = fb.Host + } + return o +} diff --git a/internal/llm/config_test.go b/internal/llm/config_test.go new file mode 100644 index 0000000..bc3f86a --- /dev/null +++ b/internal/llm/config_test.go @@ -0,0 +1,119 @@ +package llm + +import "testing" + +func TestConfig_ProviderName_DefaultsToLocal(t *testing.T) { + if got := (Config{}).ProviderName(); got != "local" { + t.Fatalf("empty provider: got %q want local", got) + } + if got := (Config{Provider: " Anthropic "}).ProviderName(); got != "anthropic" { + t.Fatalf("normalisation: got %q want anthropic", got) + } +} + +func TestConfig_IsEnabled(t *testing.T) { + cases := []struct { + name string + cfg Config + want bool + }{ + {"empty", Config{}, false}, + {"local with model", Config{Provider: "local", Local: LocalConfig{Model: "/m.gguf"}}, true}, + {"local no model", Config{Provider: "local"}, false}, + {"anthropic with model", Config{Provider: "anthropic", Anthropic: RemoteConfig{Model: "claude"}}, true}, + {"anthropic no model", Config{Provider: "anthropic"}, false}, + {"openai with model", Config{Provider: "openai", OpenAI: RemoteConfig{Model: "gpt"}}, true}, + {"ollama with model", Config{Provider: "ollama", Ollama: OllamaConfig{Model: "qwen"}}, true}, + {"ollama no model", Config{Provider: "ollama"}, false}, + {"unknown provider", Config{Provider: "bogus", Local: LocalConfig{Model: "/m.gguf"}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.cfg.IsEnabled(); got != tc.want { + t.Fatalf("IsEnabled=%v want %v", got, tc.want) + } + }) + } +} + +func TestConfig_ApplyDefaults(t *testing.T) { + c := Config{}.ApplyDefaults() + if c.Provider != "local" { + t.Errorf("provider=%q want local", c.Provider) + } + if c.MaxSteps != 16 { + t.Errorf("max_steps=%d want 16", c.MaxSteps) + } + if c.Local.Ctx != 4096 || c.Local.GPULayers != 999 || c.Local.Template != "chatml" { + t.Errorf("local defaults wrong: %+v", c.Local) + } + if c.Anthropic.Model != defaultAnthropicModel || c.Anthropic.APIKeyEnv != defaultAnthropicKeyEnv || c.Anthropic.BaseURL != defaultAnthropicBaseURL { + t.Errorf("anthropic defaults wrong: %+v", c.Anthropic) + } + if c.OpenAI.Model != defaultOpenAIModel || c.OpenAI.APIKeyEnv != defaultOpenAIKeyEnv || c.OpenAI.BaseURL != defaultOpenAIBaseURL { + t.Errorf("openai defaults wrong: %+v", c.OpenAI) + } + if c.Ollama.Host != defaultOllamaHost { + t.Errorf("ollama host=%q want %q", c.Ollama.Host, defaultOllamaHost) + } +} + +func TestConfig_ApplyDefaults_Idempotent(t *testing.T) { + once := Config{Provider: 
"anthropic", Anthropic: RemoteConfig{Model: "m"}}.ApplyDefaults() + twice := once.ApplyDefaults() + if once != twice { + t.Fatalf("ApplyDefaults not idempotent:\n once=%+v\n twice=%+v", once, twice) + } +} + +func TestConfig_MergeEnv(t *testing.T) { + t.Setenv("GORTEX_LLM_PROVIDER", "anthropic") + t.Setenv("GORTEX_LLM_MODEL", "claude-opus-4-7") + t.Setenv("GORTEX_LLM_MAX_STEPS", "8") + c := Config{}.MergeEnv() + if c.Provider != "anthropic" { + t.Errorf("provider=%q want anthropic", c.Provider) + } + if c.Anthropic.Model != "claude-opus-4-7" { + t.Errorf("anthropic model=%q — GORTEX_LLM_MODEL should target the active provider", c.Anthropic.Model) + } + if c.MaxSteps != 8 { + t.Errorf("max_steps=%d want 8", c.MaxSteps) + } +} + +func TestConfig_MergeEnv_ModelTargetsLocalByDefault(t *testing.T) { + t.Setenv("GORTEX_LLM_PROVIDER", "") + t.Setenv("GORTEX_LLM_MODEL", "/local/m.gguf") + c := Config{}.MergeEnv() + if c.Local.Model != "/local/m.gguf" { + t.Errorf("local model=%q want /local/m.gguf", c.Local.Model) + } +} + +func TestConfig_MergedWith(t *testing.T) { + global := Config{ + Provider: "local", + MaxSteps: 16, + Local: LocalConfig{Model: "/g.gguf", Template: "chatml", Ctx: 4096}, + Anthropic: RemoteConfig{APIKeyEnv: "ANTHROPIC_API_KEY"}, + } + local := Config{Local: LocalConfig{Model: "/repo.gguf"}} // overrides only the model + + got := local.MergedWith(global) + if got.Provider != "local" { + t.Errorf("provider=%q — global should fill", got.Provider) + } + if got.Local.Model != "/repo.gguf" { + t.Errorf("local model=%q — repo should win", got.Local.Model) + } + if got.Local.Template != "chatml" || got.Local.Ctx != 4096 { + t.Errorf("local sub-fields not filled from global: %+v", got.Local) + } + if got.Anthropic.APIKeyEnv != "ANTHROPIC_API_KEY" { + t.Errorf("anthropic block not merged: %+v", got.Anthropic) + } + if got.MaxSteps != 16 { + t.Errorf("max_steps=%d — global should fill", got.MaxSteps) + } +} diff --git a/internal/llm/prompts.go b/internal/llm/prompts.go new file mode 100644 index 0000000..edd209c --- /dev/null +++ b/internal/llm/prompts.go @@ -0,0 +1,178 @@ +// Package llm — prompt tiers and structured-output schemas. +// +// The search-assist passes (expand / rerank / verify) need different +// prompting depending on the model behind the active provider. A +// small local GGUF model needs verbose, rule-heavy, example-laden +// instructions; a hosted frontier model reasons well with light +// direction and is measurably *hurt* by over-constraining prompts. +// +// Rather than carry a prompt set per provider, prompts are keyed by a +// capability *tier* (PromptProfile): two sets to maintain, not four. +// The svc layer maps a provider to its tier via ProfileForProvider. +package llm + +// PromptProfile selects a prompt tier. +type PromptProfile int + +const ( + // ProfileSmall is the verbose, rule-heavy tier — tuned for small + // local GGUF models and small Ollama coder models, which need + // explicit rules and blocklists to behave. + ProfileSmall PromptProfile = iota + // ProfileFrontier is the terse tier — tuned for hosted frontier + // models (Anthropic, OpenAI) that reason well with light direction + // and lose quality when over-constrained. + ProfileFrontier +) + +// ProfileForProvider maps a provider's Name() to its prompt tier. +// "local" and "ollama" run small models → ProfileSmall; the hosted +// providers run frontier models → ProfileFrontier. Unknown names fall +// back to ProfileSmall (the safe, more-explicit tier). 
+func ProfileForProvider(name string) PromptProfile {
+	switch name {
+	case "anthropic", "openai":
+		return ProfileFrontier
+	default:
+		return ProfileSmall
+	}
+}
+
+// --- Expand -----------------------------------------------------------------
+
+const expandSystemSmall = `You expand a code-search query into a small set of CONCRETE identifier-style terms a programmer would actually grep for. ` +
+	`Output strict JSON: {"terms":["<term1>","<term2>",...]}. ` +
+	`Include 2 to 5 terms. Each term MUST be a single word with no spaces and no punctuation other than underscores. ` +
+	`
+RULES:
+1. Prefer DOMAIN-SPECIFIC terms over generic English. ` +
+	`GOOD examples: bcrypt, argon2, scrypt, sha256, hmac, jwt, oauth, pbkdf2, kdf, salt. ` +
+	`BAD examples (NEVER emit): function, library, algorithm, code, system, data, service, value, info, content, thing, stuff, name, general, common, logic, process, handler, flow, action, helper, util, utility. ` +
+	`
+2. Prefer terms that are likely SYMBOL names in a codebase (camelCase / snake_case / PascalCase fragments), library or protocol names, well-known acronyms. ` +
+	`
+3. Do NOT echo the original query words. ` +
+	`
+4. If the query has no obvious domain-specific neighbours, emit FEWER terms (or an empty array) — quality over quantity.`
+
+const expandSystemFrontier = `Expand a code-search query into 2-5 concrete identifier-style terms a programmer would grep for: library and protocol names, well-known acronyms, camelCase / snake_case symbol fragments. ` +
+	`Skip generic English nouns (function, data, handler, ...). Do not echo the query words. Fewer strong terms beat many weak ones — an empty list is fine when there are no good neighbours. ` +
+	`Output JSON: {"terms":["<term>",...]}.`
+
+// ExpandSystemPrompt returns the system prompt for the query-expansion
+// pass at the given tier.
+func ExpandSystemPrompt(p PromptProfile) string {
+	if p == ProfileFrontier {
+		return expandSystemFrontier
+	}
+	return expandSystemSmall
+}
+
+// --- Rerank -----------------------------------------------------------------
+
+const rerankSystemSmall = `You rerank code-search results by relevance to a natural-language task. ` +
+	`Given a query and a list of candidate symbols (id | name | optional signature), output strict JSON: {"order":["id1","id2",...]} ` +
+	`with the most relevant candidates first. ` +
+	`Use ONLY the provided ids verbatim. Do not invent ids. You may drop ids that are clearly unrelated.`
+
+const rerankSystemFrontier = `Reorder the candidate symbols by relevance to the query, most relevant first. ` +
+	`Use the provided ids verbatim; drop ones that are clearly unrelated. Output JSON: {"order":["<id>",...]}.`
+
+// RerankSystemPrompt returns the system prompt for the rerank pass at
+// the given tier.
+func RerankSystemPrompt(p PromptProfile) string {
+	if p == ProfileFrontier {
+		return rerankSystemFrontier
+	}
+	return rerankSystemSmall
+}
+
+// --- Verify -----------------------------------------------------------------
+
+const verifySystemSmall = `You filter code-search candidates by reading their BODY, SIGNATURE, and CALLERS, and keeping every one whose code is genuinely about the user's query. ` +
+	`Each candidate is presented as:
+
+<id> | <name> | <signature>
+body:
+<body>
+callers:
+- <caller name> | <caller signature>
+- ...
+---
+
+Output strict JSON: {"keep":["id1","id2",...]} listing EVERY id whose code is meaningfully related to the query, in your preferred order (most relevant first).
+
+RULES (follow exactly):
+1. Evaluate EACH candidate INDEPENDENTLY. Multiple candidates can be valid matches — keep them all.
+2. A name that contains a query word is not enough by itself — read what the code DOES.
+3. Cross-reference the CALLERS and the SIGNATURE's parameter types against the query DOMAIN. If a function hashes data but is only called from a "publishDiagnostics" or "renderLog" path with a non-password parameter type, it is NOT about hashing passwords — DROP it.
+4. Be GENEROUS, not restrictive: if a candidate's body AND callers AND signature are all plausibly about the query, KEEP it. The user wants signal, not a single "best" pick.
+5. Drop a candidate when its body, signature, or callers reveal the operation is on the wrong KIND of data for the query.
+6. Returning {"keep":[]} is valid ONLY when NO candidate is genuinely about the query.
+7. Use ONLY the provided ids verbatim. Never invent or modify an id.`
+
+const verifySystemFrontier = `Filter code-search candidates: keep every one whose body, signature, and callers show it genuinely concerns the query; drop the rest. ` +
+	`Judge by what the code DOES and the data domain its callers imply — not by name overlap with the query. Keep all genuine matches, not just the single best. ` +
+	`An empty result {"keep":[]} is valid and correct when nothing genuinely matches. Use the provided ids verbatim. Output JSON: {"keep":["<id>",...]}.`
+
+// VerifySystemPrompt returns the system prompt for the body-grounded
+// verification pass at the given tier.
+func VerifySystemPrompt(p PromptProfile) string {
+	if p == ProfileFrontier {
+		return verifySystemFrontier
+	}
+	return verifySystemSmall
+}
+
+// --- Structured-output schemas ----------------------------------------------
+
+// JSONSchemaFor returns a provider-agnostic JSON Schema (as a
+// marshalable map) describing the response shape for a JSONShape. The
+// HTTP providers feed it to their native structured-output mechanism
+// (Anthropic forced-tool input_schema, OpenAI json_schema, Ollama
+// format); the local provider ignores it and uses a GBNF grammar
+// instead. Returns nil for ShapeFreeform.
+//
+// tools is consulted only for ShapeToolCall, where the "tool" field is
+// constrained to an enum of the tool names.
+func JSONSchemaFor(shape JSONShape, tools []ToolSpec) map[string]any {
+	stringArray := map[string]any{
+		"type":  "array",
+		"items": map[string]any{"type": "string"},
+	}
+	switch shape {
+	case ShapeExpandTerms:
+		return listSchema("terms", stringArray)
+	case ShapeRerankOrder:
+		return listSchema("order", stringArray)
+	case ShapeVerifyKeep:
+		return listSchema("keep", stringArray)
+	case ShapeToolCall:
+		names := make([]any, len(tools))
+		for i, t := range tools {
+			names[i] = t.Name
+		}
+		return map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"tool": map[string]any{"type": "string", "enum": names},
+				"args": map[string]any{"type": "object"},
+			},
+			"required":             []any{"tool", "args"},
+			"additionalProperties": false,
+		}
+	default:
+		return nil
+	}
+}
+
+// listSchema builds the schema for a single-key object whose value is
+// a JSON array — the shape shared by expand / rerank / verify.
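+//
+// For example, listSchema("terms", stringArray) describes:
+//
+//	{"type": "object",
+//	 "properties": {"terms": {"type": "array", "items": {"type": "string"}}},
+//	 "required": ["terms"],
+//	 "additionalProperties": false}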
+func listSchema(key string, arraySchema map[string]any) map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{key: arraySchema}, + "required": []any{key}, + "additionalProperties": false, + } +} diff --git a/internal/llm/prompts_test.go b/internal/llm/prompts_test.go new file mode 100644 index 0000000..a52b7f5 --- /dev/null +++ b/internal/llm/prompts_test.go @@ -0,0 +1,86 @@ +package llm + +import "testing" + +func TestProfileForProvider(t *testing.T) { + cases := map[string]PromptProfile{ + "anthropic": ProfileFrontier, + "openai": ProfileFrontier, + "local": ProfileSmall, + "ollama": ProfileSmall, + "": ProfileSmall, + "unknown": ProfileSmall, + } + for name, want := range cases { + if got := ProfileForProvider(name); got != want { + t.Errorf("ProfileForProvider(%q)=%v want %v", name, got, want) + } + } +} + +func TestSystemPrompts_DifferByTier(t *testing.T) { + if ExpandSystemPrompt(ProfileSmall) == ExpandSystemPrompt(ProfileFrontier) { + t.Error("expand prompt identical across tiers") + } + if RerankSystemPrompt(ProfileSmall) == RerankSystemPrompt(ProfileFrontier) { + t.Error("rerank prompt identical across tiers") + } + if VerifySystemPrompt(ProfileSmall) == VerifySystemPrompt(ProfileFrontier) { + t.Error("verify prompt identical across tiers") + } + for _, s := range []string{ + ExpandSystemPrompt(ProfileSmall), ExpandSystemPrompt(ProfileFrontier), + RerankSystemPrompt(ProfileSmall), RerankSystemPrompt(ProfileFrontier), + VerifySystemPrompt(ProfileSmall), VerifySystemPrompt(ProfileFrontier), + } { + if s == "" { + t.Error("empty system prompt") + } + } +} + +func TestJSONSchemaFor_Freeform(t *testing.T) { + if JSONSchemaFor(ShapeFreeform, nil) != nil { + t.Error("freeform shape must have no schema") + } +} + +func TestJSONSchemaFor_ListShapes(t *testing.T) { + cases := map[JSONShape]string{ + ShapeExpandTerms: "terms", + ShapeRerankOrder: "order", + ShapeVerifyKeep: "keep", + } + for shape, key := range cases { + s := JSONSchemaFor(shape, nil) + if s == nil { + t.Fatalf("shape %d: nil schema", shape) + } + if s["type"] != "object" { + t.Errorf("shape %d: type=%v want object", shape, s["type"]) + } + props, ok := s["properties"].(map[string]any) + if !ok { + t.Fatalf("shape %d: properties not a map: %v", shape, s["properties"]) + } + if _, ok := props[key]; !ok { + t.Errorf("shape %d: missing property %q (have %v)", shape, key, props) + } + } +} + +func TestJSONSchemaFor_ToolCallEnumeratesNames(t *testing.T) { + s := JSONSchemaFor(ShapeToolCall, []ToolSpec{{Name: "search_symbols"}, {Name: "get_callers"}}) + props := s["properties"].(map[string]any) + tool := props["tool"].(map[string]any) + enum, ok := tool["enum"].([]any) + if !ok || len(enum) != 2 { + t.Fatalf("tool enum=%v", tool["enum"]) + } + if enum[0] != "search_symbols" || enum[1] != "get_callers" { + t.Errorf("tool enum order=%v", enum) + } + if _, ok := props["args"]; !ok { + t.Error("tool-call schema missing args property") + } +} diff --git a/internal/llm/provider.go b/internal/llm/provider.go new file mode 100644 index 0000000..e78b549 --- /dev/null +++ b/internal/llm/provider.go @@ -0,0 +1,110 @@ +// Package llm — provider abstraction. +// +// Provider isolates every LLM operation (the agent tool-loop and the +// three search-assist passes) from where inference actually runs. Four +// implementations live under internal/llm/provider/: a llama.cpp +// `local` provider (CGO, `-tags llama`) and three pure-Go HTTP +// providers (`anthropic`, `openai`, `ollama`). 
They are swapped via +// the `llm.provider` config key — see Config. +// +// The whole surface is a single method, Complete: one structured +// single-turn call. The agent loop is just repeated Complete calls +// with a growing Messages slice; the assist passes are one Complete +// call each. Keeping the interface to one method is what lets the +// HTTP providers stay small and the build-tag split stay contained to +// the `local` package. +package llm + +import "context" + +// Role identifies who produced a Message in a provider conversation. +type Role string + +const ( + RoleSystem Role = "system" + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleTool Role = "tool" +) + +// Message is one turn in a provider conversation. Content is the plain +// text payload — for a RoleAssistant message that represents a tool +// call it holds the raw tool-call JSON; for a RoleTool message it +// holds the tool's observation and ToolName names the tool that +// produced it. +// +// The conversation is provider-neutral: the `local` provider flattens +// it through a llama chat template, the HTTP providers map it onto +// their native messages array. Tool calls are carried as plain text +// (the "emulated" protocol the local model already uses), not via any +// provider's native tool-use wire format — that keeps a single +// Message shape working across all four providers. +type Message struct { + Role Role + Content string + ToolName string // set on RoleTool messages: which tool produced Content +} + +// ToolSpec describes one callable tool to a provider. When a request +// carries Shape == ShapeToolCall the provider constrains the emitted +// "tool" field to exactly these names. +type ToolSpec struct { + Name string + Description string +} + +// JSONShape names the structured-output schema a provider must enforce +// on a completion. ShapeFreeform applies no constraint; every other +// value corresponds to a concrete JSON object shape (see +// JSONSchemaFor) that the provider guarantees the response conforms +// to — via a GBNF grammar (local) or a json-schema / forced-tool +// mechanism (HTTP providers). +type JSONShape int + +const ( + ShapeFreeform JSONShape = iota // no structured constraint + ShapeExpandTerms // {"terms":[...]} + ShapeRerankOrder // {"order":[...]} + ShapeVerifyKeep // {"keep":[...]} + ShapeToolCall // {"tool":,"args":{...}} +) + +// CompletionRequest is one single-turn request to a Provider. The +// provider flattens Messages into its native wire format, applies the +// structured-output mechanism implied by Shape, and returns the raw +// model text. +type CompletionRequest struct { + // Messages is the conversation so far, oldest first. The last + // message is whatever the model should respond to. + Messages []Message + // MaxTokens caps generation length. 0 lets the provider pick a + // sensible default. + MaxTokens int + // Shape is the structured-output contract for the response. + Shape JSONShape + // Tools is consulted only when Shape == ShapeToolCall: the provider + // constrains the emitted "tool" field to these names. + Tools []ToolSpec +} + +// CompletionResponse is a Provider's single-turn output. Text is the +// raw model text — JSON conforming to the requested Shape when Shape +// is not ShapeFreeform. +type CompletionResponse struct { + Text string +} + +// Provider is a single-turn LLM completion backend. The agent loop and +// the search-assist passes are both built on repeated Complete calls. 
+type Provider interface { + // Name returns the provider's short identifier — one of "local", + // "anthropic", "openai", "ollama". Used to pick the prompt tier + // (see PromptProfile) and for diagnostics. + Name() string + // Complete runs one single-turn completion, honouring req.Shape + // with whatever structured-output mechanism the provider has. + Complete(ctx context.Context, req CompletionRequest) (CompletionResponse, error) + // Close releases any held resources (model weights, idle HTTP + // connections). Safe to call multiple times. + Close() error +} diff --git a/internal/llm/provider/anthropic/anthropic.go b/internal/llm/provider/anthropic/anthropic.go new file mode 100644 index 0000000..60d93f9 --- /dev/null +++ b/internal/llm/provider/anthropic/anthropic.go @@ -0,0 +1,247 @@ +// Package anthropic is the hosted Anthropic Messages API llm.Provider. +// +// It is pure Go — available in every build, no `-tags llama` needed. +// Structured output (the expand / rerank / verify shapes and the agent +// tool-call shape) is obtained by declaring a single forced tool whose +// input_schema is the requested JSONShape: the model's tool_use block +// carries the structured JSON, which is marshaled back to text. The +// agent tool-loop itself uses the *emulated* protocol — tool calls and +// results travel as plain text turns — so a single llm.Message shape +// works across all four providers. +package anthropic + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/zzet/gortex/internal/llm" +) + +// anthropicVersion is the API version header value the Messages API +// requires. +const anthropicVersion = "2023-06-01" + +// respondToolName is the synthetic tool used to force structured +// output. The model is given exactly this one tool with tool_choice +// pinned to it; its input is our JSON payload. +const respondToolName = "respond" + +// Provider implements llm.Provider against api.anthropic.com. +type Provider struct { + model string + apiKey string + baseURL string + client *http.Client +} + +var _ llm.Provider = (*Provider)(nil) + +// New constructs the Anthropic provider. The API key is read from the +// env var named by cfg.APIKeyEnv (default ANTHROPIC_API_KEY) — an +// unset key is a hard error so misconfiguration surfaces at startup, +// not on the first query. +func New(cfg llm.RemoteConfig) (llm.Provider, error) { + keyEnv := strings.TrimSpace(cfg.APIKeyEnv) + if keyEnv == "" { + keyEnv = "ANTHROPIC_API_KEY" + } + key := strings.TrimSpace(os.Getenv(keyEnv)) + if key == "" { + return nil, fmt.Errorf("anthropic: API key env %q is not set", keyEnv) + } + if strings.TrimSpace(cfg.Model) == "" { + return nil, errors.New("anthropic: llm.anthropic.model is empty") + } + base := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if base == "" { + base = "https://api.anthropic.com" + } + return &Provider{ + model: cfg.Model, + apiKey: key, + baseURL: base, + client: &http.Client{Timeout: 120 * time.Second}, + }, nil +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return "anthropic" } + +// Close releases idle HTTP connections. +func (p *Provider) Close() error { + p.client.CloseIdleConnections() + return nil +} + +// wire types for the Messages API request/response. 
+type apiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type apiTool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"input_schema"` +} + +type apiRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []apiMessage `json:"messages"` + Tools []apiTool `json:"tools,omitempty"` + ToolChoice map[string]any `json:"tool_choice,omitempty"` +} + +type apiContentBlock struct { + Type string `json:"type"` + Text string `json:"text"` + Name string `json:"name"` + Input json.RawMessage `json:"input"` +} + +type apiResponse struct { + Content []apiContentBlock `json:"content"` + Error *struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + +// Complete implements llm.Provider. +func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm.CompletionResponse, error) { + system, msgs := splitMessages(req.Messages) + maxTokens := req.MaxTokens + if maxTokens <= 0 { + maxTokens = 1024 + } + + body := apiRequest{ + Model: p.model, + MaxTokens: maxTokens, + System: system, + Messages: msgs, + } + structured := req.Shape != llm.ShapeFreeform + if structured { + body.Tools = []apiTool{{ + Name: respondToolName, + Description: "Return your response as the structured arguments of this tool.", + InputSchema: llm.JSONSchemaFor(req.Shape, req.Tools), + }} + body.ToolChoice = map[string]any{"type": "tool", "name": respondToolName} + } + + raw, err := json.Marshal(body) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("anthropic: marshal request: %w", err) + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/messages", bytes.NewReader(raw)) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("anthropic: build request: %w", err) + } + httpReq.Header.Set("content-type", "application/json") + httpReq.Header.Set("x-api-key", p.apiKey) + httpReq.Header.Set("anthropic-version", anthropicVersion) + + resp, err := p.client.Do(httpReq) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("anthropic: request failed: %w", err) + } + defer resp.Body.Close() + payload, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("anthropic: read response: %w", err) + } + + var parsed apiResponse + if err := json.Unmarshal(payload, &parsed); err != nil { + return llm.CompletionResponse{}, fmt.Errorf("anthropic: decode response (status %d): %w", resp.StatusCode, err) + } + if resp.StatusCode != http.StatusOK { + if parsed.Error != nil { + return llm.CompletionResponse{}, fmt.Errorf("anthropic: API error (status %d): %s: %s", resp.StatusCode, parsed.Error.Type, parsed.Error.Message) + } + return llm.CompletionResponse{}, fmt.Errorf("anthropic: API error (status %d): %s", resp.StatusCode, snippet(payload)) + } + + text, err := extractText(parsed.Content, structured) + if err != nil { + return llm.CompletionResponse{}, err + } + return llm.CompletionResponse{Text: text}, nil +} + +// splitMessages pulls every RoleSystem message into the top-level +// `system` string (Anthropic carries system separately from the +// messages array) and maps the rest onto user/assistant turns. Tool +// observations are rendered as user turns — the emulated tool-call +// protocol — which keeps the user/assistant alternation the API +// requires intact. 
+func splitMessages(in []llm.Message) (system string, msgs []apiMessage) { + var sys []string + for _, m := range in { + switch m.Role { + case llm.RoleSystem: + if s := strings.TrimSpace(m.Content); s != "" { + sys = append(sys, s) + } + case llm.RoleAssistant: + msgs = append(msgs, apiMessage{Role: "assistant", Content: m.Content}) + case llm.RoleTool: + msgs = append(msgs, apiMessage{Role: "user", Content: renderToolResult(m)}) + default: // RoleUser and anything unexpected + msgs = append(msgs, apiMessage{Role: "user", Content: m.Content}) + } + } + return strings.Join(sys, "\n\n"), msgs +} + +// renderToolResult formats a RoleTool message as a plain-text user +// turn for the emulated tool-call protocol. +func renderToolResult(m llm.Message) string { + if m.ToolName != "" { + return "Tool result (" + m.ToolName + "):\n" + m.Content + } + return "Tool result:\n" + m.Content +} + +// extractText pulls the response text out of the content blocks. For a +// structured request it returns the forced tool's input JSON; for a +// freeform request it concatenates the text blocks. +func extractText(blocks []apiContentBlock, structured bool) (string, error) { + if structured { + for _, b := range blocks { + if b.Type == "tool_use" && b.Name == respondToolName && len(b.Input) > 0 { + return strings.TrimSpace(string(b.Input)), nil + } + } + return "", errors.New("anthropic: response carried no forced-tool output") + } + var b strings.Builder + for _, blk := range blocks { + if blk.Type == "text" { + b.WriteString(blk.Text) + } + } + return strings.TrimSpace(b.String()), nil +} + +// snippet truncates a response body for inclusion in an error. +func snippet(b []byte) string { + const max = 300 + s := strings.TrimSpace(string(b)) + if len(s) > max { + return s[:max] + "…" + } + return s +} diff --git a/internal/llm/provider/anthropic/anthropic_test.go b/internal/llm/provider/anthropic/anthropic_test.go new file mode 100644 index 0000000..cbe29a0 --- /dev/null +++ b/internal/llm/provider/anthropic/anthropic_test.go @@ -0,0 +1,169 @@ +package anthropic + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/zzet/gortex/internal/llm" +) + +func TestNew_MissingKey(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "") + if _, err := New(llm.RemoteConfig{Model: "claude-x"}); err == nil { + t.Fatal("expected error when API key env is unset") + } +} + +func TestNew_MissingModel(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "k") + if _, err := New(llm.RemoteConfig{}); err == nil { + t.Fatal("expected error when model is unset") + } +} + +func TestComplete_StructuredUsesForcedTool(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + t.Errorf("path=%q want /v1/messages", r.URL.Path) + } + if r.Header.Get("x-api-key") != "test-key" { + t.Errorf("x-api-key=%q", r.Header.Get("x-api-key")) + } + if r.Header.Get("anthropic-version") == "" { + t.Error("missing anthropic-version header") + } + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.Header().Set("content-type", "application/json") + _, _ = io.WriteString(w, `{"content":[{"type":"tool_use","name":"respond","input":{"terms":["bcrypt","argon2"]}}]}`) + })) + defer srv.Close() + + t.Setenv("ANTHROPIC_API_KEY", "test-key") + p, err := New(llm.RemoteConfig{Model: "claude-x", BaseURL: srv.URL}) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + resp, err := 
p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{ + {Role: llm.RoleSystem, Content: "you expand queries"}, + {Role: llm.RoleUser, Content: "Query: hashing"}, + }, + Shape: llm.ShapeExpandTerms, + MaxTokens: 100, + }) + if err != nil { + t.Fatal(err) + } + if resp.Text != `{"terms":["bcrypt","argon2"]}` { + t.Errorf("text=%q", resp.Text) + } + if gotBody["system"] != "you expand queries" { + t.Errorf("system=%v — system message should be hoisted to the top-level field", gotBody["system"]) + } + if gotBody["tool_choice"] == nil { + t.Error("structured request must force a tool_choice") + } + if tools, _ := gotBody["tools"].([]any); len(tools) != 1 { + t.Errorf("tools=%v want exactly the respond tool", gotBody["tools"]) + } + if msgs, _ := gotBody["messages"].([]any); len(msgs) != 1 { + t.Errorf("messages=%v — system should be extracted, leaving just the user turn", gotBody["messages"]) + } +} + +func TestComplete_FreeformNoTools(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"content":[{"type":"text","text":"hello world"}]}`) + })) + defer srv.Close() + + t.Setenv("ANTHROPIC_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) + defer p.Close() + + resp, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + Shape: llm.ShapeFreeform, + }) + if err != nil { + t.Fatal(err) + } + if resp.Text != "hello world" { + t.Errorf("text=%q want 'hello world'", resp.Text) + } + if _, ok := gotBody["tools"]; ok { + t.Error("freeform request must not send tools") + } +} + +func TestComplete_ToolResultBecomesUserTurn(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"content":[{"type":"text","text":"ok"}]}`) + })) + defer srv.Close() + + t.Setenv("ANTHROPIC_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) + defer p.Close() + + _, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{ + {Role: llm.RoleUser, Content: "q"}, + {Role: llm.RoleAssistant, Content: `{"tool":"x","args":{}}`}, + {Role: llm.RoleTool, Content: `{"result":1}`, ToolName: "x"}, + }, + Shape: llm.ShapeFreeform, + }) + if err != nil { + t.Fatal(err) + } + msgs := gotBody["messages"].([]any) + if len(msgs) != 3 { + t.Fatalf("messages=%d want 3", len(msgs)) + } + last := msgs[2].(map[string]any) + if last["role"] != "user" { + t.Errorf("tool result role=%v want user (emulated protocol)", last["role"]) + } +} + +func TestComplete_APIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":{"type":"invalid_request_error","message":"bad model"}}`) + })) + defer srv.Close() + + t.Setenv("ANTHROPIC_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) + defer p.Close() + + _, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected an error for a non-200 response") + } +} + +func TestName(t *testing.T) { + 
t.Setenv("ANTHROPIC_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m"}) + if p.Name() != "anthropic" { + t.Errorf("Name()=%q", p.Name()) + } +} diff --git a/internal/llm/provider/local/local.go b/internal/llm/provider/local/local.go new file mode 100644 index 0000000..e0cd875 --- /dev/null +++ b/internal/llm/provider/local/local.go @@ -0,0 +1,213 @@ +//go:build llama + +// Package local is the in-process llama.cpp llm.Provider. It wraps the +// CGO model/context from package llm and is the only provider that +// needs a `-tags llama` build; the non-llama build (stub.go) compiles +// a New that reports the provider as unavailable. +package local + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "strings" + "sync" + + "github.com/zzet/gortex/internal/llm" +) + +// assistCtxSize is the KV-cache window for the short-call assist +// context (expand / rerank / verify). Sized for the heaviest user — +// verify with body + callers at ~3.5K tokens for 10 candidates. +const assistCtxSize = 4096 + +// defaultMaxTokens caps a Complete call whose request leaves +// MaxTokens unset. +const defaultMaxTokens = 512 + +// Provider is the local llama.cpp implementation of llm.Provider. +// +// It keeps two llama contexts behind separate mutexes: a small +// assistCtx for the structured search-assist shapes, and a full-size +// mainCtx for the agent tool-loop (ShapeToolCall) and freeform +// generation. Splitting them means a long agent run can't head-of-line +// block a hot-path assist call — at the llama.cpp level both share the +// model weights, each holds its own KV cache. +// +// Every Complete call is self-contained: it resets the context's KV +// cache and prefills the entire conversation passed in the request, so +// no cross-call state lives in a context. That makes per-call locking +// (rather than per-agent-run locking) correct even under concurrent +// callers. +type Provider struct { + cfg llm.LocalConfig + tmpl chatTemplate + + loadOnce sync.Once + loadErr error + model *llm.Model + + assistMu sync.Mutex + assistCtx *llm.Context + + mainMu sync.Mutex + mainCtx *llm.Context +} + +// compile-time assertion that *Provider satisfies the interface. +var _ llm.Provider = (*Provider)(nil) + +// New constructs the local provider from its config sub-block. The +// model is NOT loaded here — that happens lazily on the first Complete +// call so daemon startup isn't slowed. New only validates that a model +// path is set and the file exists, and that the chat template is +// known, so misconfiguration surfaces immediately. +// +// Returns the llm.Provider interface (not the concrete *Provider) so +// the signature matches the non-llama stub and the provider factory +// can treat both builds uniformly. +func New(cfg llm.LocalConfig) (llm.Provider, error) { + path := strings.TrimSpace(cfg.Model) + if path == "" { + return nil, errors.New("local: llm.local.model is empty") + } + if _, err := os.Stat(path); err != nil { + return nil, fmt.Errorf("local: model file: %w", err) + } + tmpl, err := templateByName(cfg.Template) + if err != nil { + return nil, err + } + if cfg.Ctx <= 0 { + cfg.Ctx = 4096 + } + return &Provider{cfg: cfg, tmpl: tmpl}, nil +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return "local" } + +// ensureLoaded mmaps the model and allocates both contexts on first +// use. Idempotent; the stored loadErr is returned on every subsequent +// call once a load has failed. 
+func (p *Provider) ensureLoaded() error { + p.loadOnce.Do(func() { + m, err := llm.LoadModel(p.cfg.Model, p.cfg.GPULayers) + if err != nil { + p.loadErr = fmt.Errorf("local: load model: %w", err) + return + } + assistCtx, err := m.NewContext(assistCtxSize, 0) + if err != nil { + m.Close() + p.loadErr = fmt.Errorf("local: assist context: %w", err) + return + } + mainCtx, err := m.NewContext(p.cfg.Ctx, 0) + if err != nil { + assistCtx.Close() + m.Close() + p.loadErr = fmt.Errorf("local: main context: %w", err) + return + } + p.model = m + p.assistCtx = assistCtx + p.mainCtx = mainCtx + }) + return p.loadErr +} + +// Complete implements llm.Provider. It flattens the conversation +// through the chat template, installs the GBNF grammar implied by +// req.Shape, and runs greedy decoding with a JSON-complete early-stop. +func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm.CompletionResponse, error) { + if err := ctx.Err(); err != nil { + return llm.CompletionResponse{}, err + } + if err := p.ensureLoaded(); err != nil { + return llm.CompletionResponse{}, err + } + + maxTokens := req.MaxTokens + if maxTokens <= 0 { + maxTokens = defaultMaxTokens + } + prompt := p.tmpl.flatten(req.Messages) + grammar := grammarForShape(req.Shape, req.Tools) + structured := req.Shape != llm.ShapeFreeform + + llmCtx, mu := p.contextFor(req.Shape) + mu.Lock() + defer mu.Unlock() + + llmCtx.Reset() + if err := llmCtx.SetGrammar(grammar); err != nil { + return llm.CompletionResponse{}, fmt.Errorf("local: install grammar: %w", err) + } + + var buf strings.Builder + _, err := llmCtx.Generate(prompt, maxTokens, func(piece string) bool { + buf.WriteString(piece) + // For a structured shape the grammar guarantees the output is + // JSON; stop as soon as the top-level object closes and parses + // instead of waiting on EOS. Freeform runs to EOS / maxTokens. + if structured { + return !jsonComplete(buf.String()) + } + return true + }) + if err != nil { + return llm.CompletionResponse{}, err + } + return llm.CompletionResponse{Text: strings.TrimSpace(buf.String())}, nil +} + +// contextFor routes a shape to its context + mutex. The structured +// search-assist shapes use the small assist context; the agent loop +// and freeform generation use the full-size main context. +func (p *Provider) contextFor(shape llm.JSONShape) (*llm.Context, *sync.Mutex) { + switch shape { + case llm.ShapeExpandTerms, llm.ShapeRerankOrder, llm.ShapeVerifyKeep: + return p.assistCtx, &p.assistMu + default: + return p.mainCtx, &p.mainMu + } +} + +// Close releases the contexts and the model. Safe to call multiple +// times and before any Complete (when nothing was ever loaded). +func (p *Provider) Close() error { + p.assistMu.Lock() + if p.assistCtx != nil { + p.assistCtx.Close() + p.assistCtx = nil + } + p.assistMu.Unlock() + + p.mainMu.Lock() + if p.mainCtx != nil { + p.mainCtx.Close() + p.mainCtx = nil + } + p.mainMu.Unlock() + + if p.model != nil { + p.model.Close() + p.model = nil + } + return nil +} + +// jsonComplete reports whether s is a complete, parseable top-level +// JSON object — the early-stop predicate for grammar-constrained +// generation. 
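+//
+// For example, `{"keep":["a"]}` is complete, while `{"keep":["a"` and
+// plain non-JSON text are not — generation keeps running until the
+// top-level object both closes and parses.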
+func jsonComplete(s string) bool { + s = strings.TrimSpace(s) + if !strings.HasPrefix(s, "{") || !strings.HasSuffix(s, "}") { + return false + } + var v any + return json.Unmarshal([]byte(s), &v) == nil +} diff --git a/internal/llm/provider/local/stub.go b/internal/llm/provider/local/stub.go new file mode 100644 index 0000000..48f3628 --- /dev/null +++ b/internal/llm/provider/local/stub.go @@ -0,0 +1,24 @@ +//go:build !llama + +// Package local — stub variant for builds without `-tags llama`. +// +// The real local provider is a CGO wrapper around llama.cpp; without +// the build tag it cannot be compiled. New therefore returns a clear +// error so the provider factory can fall through (or report the +// misconfiguration) instead of failing to compile. The HTTP providers +// are unaffected — they are pure Go and available in every build. +package local + +import ( + "errors" + + "github.com/zzet/gortex/internal/llm" +) + +// ErrUnavailable is returned by New in a non-llama build. +var ErrUnavailable = errors.New("local: provider unavailable — gortex was built without `-tags llama`") + +// New reports the local provider as unavailable in this build. +func New(_ llm.LocalConfig) (llm.Provider, error) { + return nil, ErrUnavailable +} diff --git a/internal/llm/provider/local/template.go b/internal/llm/provider/local/template.go new file mode 100644 index 0000000..cc9fcf0 --- /dev/null +++ b/internal/llm/provider/local/template.go @@ -0,0 +1,163 @@ +//go:build llama + +// Package local — chat templates and GBNF grammars. +// +// These are the parts of the old internal/llm/agent and +// internal/llm/svc layers that are specific to a local llama.cpp +// model: how to flatten a provider-neutral []llm.Message into a single +// prompt string with the right turn markers, and how to constrain +// token sampling to a JSON shape via GBNF. The HTTP providers need +// neither — they speak structured messages and json-schema natively. +package local + +import ( + "fmt" + "strings" + + "github.com/zzet/gortex/internal/llm" +) + +// chatTemplate describes how to wrap conversation turns for a given +// model family. The wrappers each take raw content and return the +// fully-marked-up turn; assistPrime is the marker appended right +// before a generate call so the model starts an assistant turn. +type chatTemplate struct { + name string + bos string + system func(content string) string + user func(content string) string + tool func(content string) string + assistEnd string // marker appended after a captured assistant emission + assistPrime string +} + +// templateChatML covers the Qwen2.5 family and Nous Hermes-3 (which +// re-trains Llama-3 onto ChatML). +var templateChatML = chatTemplate{ + name: "chatml", + system: func(c string) string { return "<|im_start|>system\n" + c + "<|im_end|>\n" }, + user: func(c string) string { return "<|im_start|>user\n" + c + "<|im_end|>\n" }, + tool: func(c string) string { return "<|im_start|>tool\n" + c + "<|im_end|>\n" }, + assistEnd: "<|im_end|>\n", + assistPrime: "<|im_start|>assistant\n", +} + +// templateLlama3 covers Meta's Llama-3.x stock instruct format. Used +// by models that keep Llama-3's native template (NOT Hermes-3, which +// switches to ChatML). 
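+//
+// A single user turn flattens to roughly
+//
+//	<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n ... <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
+//
+// with the trailing assistant header acting as the generation prime
+// (see flatten below).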
+var templateLlama3 = chatTemplate{ + name: "llama3", + bos: "<|begin_of_text|>", + system: func(c string) string { + return "<|start_header_id|>system<|end_header_id|>\n\n" + c + "<|eot_id|>" + }, + user: func(c string) string { + return "<|start_header_id|>user<|end_header_id|>\n\n" + c + "<|eot_id|>" + }, + tool: func(c string) string { + return "<|start_header_id|>ipython<|end_header_id|>\n\n" + c + "<|eot_id|>" + }, + assistEnd: "<|eot_id|>", + assistPrime: "<|start_header_id|>assistant<|end_header_id|>\n\n", +} + +// templateByName returns a known chat template by short name. Empty +// falls back to ChatML. +func templateByName(name string) (chatTemplate, error) { + switch name { + case "", "chatml", "qwen", "hermes": + return templateChatML, nil + case "llama3", "llama": + return templateLlama3, nil + } + return chatTemplate{}, fmt.Errorf("local: unknown chat template %q", name) +} + +// flatten renders a provider-neutral conversation into a single prompt +// string and primes an assistant turn. RoleAssistant messages are +// wrapped as a complete assistant emission (prime + content + end); +// RoleTool messages use the tool wrapper. The trailing assistPrime is +// what makes the model start generating an assistant turn. +func (t chatTemplate) flatten(msgs []llm.Message) string { + var b strings.Builder + b.WriteString(t.bos) + for _, m := range msgs { + switch m.Role { + case llm.RoleSystem: + b.WriteString(t.system(m.Content)) + case llm.RoleUser: + b.WriteString(t.user(m.Content)) + case llm.RoleAssistant: + b.WriteString(t.assistPrime) + b.WriteString(m.Content) + b.WriteString(t.assistEnd) + case llm.RoleTool: + b.WriteString(t.tool(m.Content)) + } + } + b.WriteString(t.assistPrime) + return b.String() +} + +// --- GBNF grammars ---------------------------------------------------------- + +// The expand / rerank / verify grammars each accept a single-key JSON +// object whose value is a string array. The array body is fully +// optional so the model is always allowed to emit an empty list — +// for verify that empty list is the load-bearing "honest negative". +const ( + expandGrammar = `root ::= ws "{" ws "\"terms\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws +str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +ws ::= [ \t\n]* +` + rerankGrammar = `root ::= ws "{" ws "\"order\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws +str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +ws ::= [ \t\n]* +` + verifyGrammar = `root ::= ws "{" ws "\"keep\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws +str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +ws ::= [ \t\n]* +` +) + +// buildToolCallGrammar returns a GBNF that accepts +// {"tool":"","args":{}} with +// whitespace tolerance. names must be non-empty. +func buildToolCallGrammar(names []string) string { + alt := make([]string, len(names)) + for i, n := range names { + alt[i] = `"\"" "` + n + `" "\""` + } + toolname := strings.Join(alt, " | ") + return `root ::= ws "{" ws "\"tool\"" ws ":" ws toolname ws "," ws "\"args\"" ws ":" ws object ws "}" ws +toolname ::= ` + toolname + ` +object ::= "{" ws ( pair ( ws "," ws pair )* )? ws "}" +pair ::= string ws ":" ws value +array ::= "[" ws ( value ( ws "," ws value )* )? 
ws "]" +value ::= string | number | object | array | "true" | "false" | "null" +string ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" +number ::= "-"? ( "0" | [1-9] [0-9]* ) ( "." [0-9]+ )? ( [eE] [-+]? [0-9]+ )? +ws ::= [ \t\n]* +` +} + +// grammarForShape returns the GBNF for a structured shape. ShapeFreeform +// returns "" (no constraint). ShapeToolCall requires the tool name set. +func grammarForShape(shape llm.JSONShape, tools []llm.ToolSpec) string { + switch shape { + case llm.ShapeExpandTerms: + return expandGrammar + case llm.ShapeRerankOrder: + return rerankGrammar + case llm.ShapeVerifyKeep: + return verifyGrammar + case llm.ShapeToolCall: + names := make([]string, len(tools)) + for i, t := range tools { + names[i] = t.Name + } + return buildToolCallGrammar(names) + default: + return "" + } +} diff --git a/internal/llm/provider/ollama/ollama.go b/internal/llm/provider/ollama/ollama.go new file mode 100644 index 0000000..aa004c9 --- /dev/null +++ b/internal/llm/provider/ollama/ollama.go @@ -0,0 +1,173 @@ +// Package ollama is the Ollama daemon llm.Provider. +// +// It is pure Go — available in every build. Ollama runs models +// locally (or on a remote host) and exposes an OpenAI-ish /api/chat +// endpoint. Structured output uses Ollama's `format` field, which +// accepts a JSON schema directly. The agent tool-loop uses the +// emulated protocol — tool calls and results travel as plain text +// turns. +package ollama + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/zzet/gortex/internal/llm" +) + +// Provider implements llm.Provider against an Ollama daemon. +type Provider struct { + model string + host string + client *http.Client +} + +var _ llm.Provider = (*Provider)(nil) + +// New constructs the Ollama provider. Unlike the hosted providers +// there is no API key; New only requires a model tag and a reachable +// host (default http://localhost:11434). Reachability is not probed +// here — that surfaces on the first Complete call. +func New(cfg llm.OllamaConfig) (llm.Provider, error) { + if strings.TrimSpace(cfg.Model) == "" { + return nil, errors.New("ollama: llm.ollama.model is empty") + } + host := strings.TrimRight(strings.TrimSpace(cfg.Host), "/") + if host == "" { + host = "http://localhost:11434" + } + return &Provider{ + model: cfg.Model, + host: host, + client: &http.Client{Timeout: 120 * time.Second}, + }, nil +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return "ollama" } + +// Close releases idle HTTP connections. +func (p *Provider) Close() error { + p.client.CloseIdleConnections() + return nil +} + +// wire types for the /api/chat endpoint. +type apiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type apiRequest struct { + Model string `json:"model"` + Messages []apiMessage `json:"messages"` + Stream bool `json:"stream"` + Format json.RawMessage `json:"format,omitempty"` + Options map[string]any `json:"options,omitempty"` +} + +type apiResponse struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + Error string `json:"error"` +} + +// Complete implements llm.Provider. 
+func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm.CompletionResponse, error) { + body := apiRequest{ + Model: p.model, + Messages: mapMessages(req.Messages), + Stream: false, + } + if schema := llm.JSONSchemaFor(req.Shape, req.Tools); schema != nil { + // Ollama's `format` accepts a JSON schema verbatim. + encoded, err := json.Marshal(schema) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("ollama: marshal schema: %w", err) + } + body.Format = encoded + } + if req.MaxTokens > 0 { + body.Options = map[string]any{"num_predict": req.MaxTokens} + } + + raw, err := json.Marshal(body) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("ollama: marshal request: %w", err) + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.host+"/api/chat", bytes.NewReader(raw)) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("ollama: build request: %w", err) + } + httpReq.Header.Set("content-type", "application/json") + + resp, err := p.client.Do(httpReq) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("ollama: request failed: %w", err) + } + defer resp.Body.Close() + payload, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("ollama: read response: %w", err) + } + + var parsed apiResponse + if err := json.Unmarshal(payload, &parsed); err != nil { + return llm.CompletionResponse{}, fmt.Errorf("ollama: decode response (status %d): %w", resp.StatusCode, err) + } + if resp.StatusCode != http.StatusOK { + if parsed.Error != "" { + return llm.CompletionResponse{}, fmt.Errorf("ollama: API error (status %d): %s", resp.StatusCode, parsed.Error) + } + return llm.CompletionResponse{}, fmt.Errorf("ollama: API error (status %d): %s", resp.StatusCode, snippet(payload)) + } + if parsed.Error != "" { + return llm.CompletionResponse{}, fmt.Errorf("ollama: %s", parsed.Error) + } + return llm.CompletionResponse{Text: strings.TrimSpace(parsed.Message.Content)}, nil +} + +// mapMessages flattens the provider-neutral conversation onto Ollama +// chat roles. Tool observations become user turns (emulated tool-call +// protocol). +func mapMessages(in []llm.Message) []apiMessage { + out := make([]apiMessage, 0, len(in)) + for _, m := range in { + switch m.Role { + case llm.RoleSystem: + out = append(out, apiMessage{Role: "system", Content: m.Content}) + case llm.RoleAssistant: + out = append(out, apiMessage{Role: "assistant", Content: m.Content}) + case llm.RoleTool: + out = append(out, apiMessage{Role: "user", Content: renderToolResult(m)}) + default: + out = append(out, apiMessage{Role: "user", Content: m.Content}) + } + } + return out +} + +func renderToolResult(m llm.Message) string { + if m.ToolName != "" { + return "Tool result (" + m.ToolName + "):\n" + m.Content + } + return "Tool result:\n" + m.Content +} + +// snippet truncates a response body for inclusion in an error. 
+func snippet(b []byte) string { + const max = 300 + s := strings.TrimSpace(string(b)) + if len(s) > max { + return s[:max] + "…" + } + return s +} diff --git a/internal/llm/provider/ollama/ollama_test.go b/internal/llm/provider/ollama/ollama_test.go new file mode 100644 index 0000000..e7a6b62 --- /dev/null +++ b/internal/llm/provider/ollama/ollama_test.go @@ -0,0 +1,131 @@ +package ollama + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/zzet/gortex/internal/llm" +) + +func TestNew_MissingModel(t *testing.T) { + if _, err := New(llm.OllamaConfig{}); err == nil { + t.Fatal("expected error when model is unset") + } +} + +func TestNew_DefaultsHost(t *testing.T) { + p, err := New(llm.OllamaConfig{Model: "qwen"}) + if err != nil { + t.Fatal(err) + } + defer p.Close() + if p.Name() != "ollama" { + t.Errorf("Name()=%q", p.Name()) + } +} + +func TestComplete_StructuredSendsFormatSchema(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/chat" { + t.Errorf("path=%q want /api/chat", r.URL.Path) + } + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"{\"keep\":[\"a\"]}"}}`) + })) + defer srv.Close() + + p, err := New(llm.OllamaConfig{Model: "qwen2.5-coder:7b", Host: srv.URL}) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + resp, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "verify"}}, + Shape: llm.ShapeVerifyKeep, + MaxTokens: 128, + }) + if err != nil { + t.Fatal(err) + } + if resp.Text != `{"keep":["a"]}` { + t.Errorf("text=%q", resp.Text) + } + if gotBody["format"] == nil { + t.Error("structured request must send a `format` schema") + } + if gotBody["stream"] != false { + t.Errorf("stream=%v want false", gotBody["stream"]) + } + opts, _ := gotBody["options"].(map[string]any) + if opts == nil || opts["num_predict"] == nil { + t.Errorf("options.num_predict missing: %v", gotBody["options"]) + } +} + +func TestComplete_FreeformNoFormat(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"message":{"role":"assistant","content":"hi there"}}`) + })) + defer srv.Close() + + p, _ := New(llm.OllamaConfig{Model: "m", Host: srv.URL}) + defer p.Close() + + resp, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + Shape: llm.ShapeFreeform, + }) + if err != nil { + t.Fatal(err) + } + if resp.Text != "hi there" { + t.Errorf("text=%q", resp.Text) + } + if _, ok := gotBody["format"]; ok { + t.Error("freeform request must not send a `format` field") + } +} + +func TestComplete_APIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = io.WriteString(w, `{"error":"model not found"}`) + })) + defer srv.Close() + + p, _ := New(llm.OllamaConfig{Model: "m", Host: srv.URL}) + defer p.Close() + + if _, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + }); err == nil { + t.Fatal("expected an error for a non-200 response") + } +} + +func 
TestComplete_InlineErrorField(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 200 OK but an error payload — Ollama does this for some failures. + _, _ = io.WriteString(w, `{"error":"something went wrong"}`) + })) + defer srv.Close() + + p, _ := New(llm.OllamaConfig{Model: "m", Host: srv.URL}) + defer p.Close() + + if _, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + }); err == nil { + t.Fatal("expected an error when the response carries an inline error field") + } +} diff --git a/internal/llm/provider/openai/openai.go b/internal/llm/provider/openai/openai.go new file mode 100644 index 0000000..2f55887 --- /dev/null +++ b/internal/llm/provider/openai/openai.go @@ -0,0 +1,216 @@ +// Package openai is the hosted OpenAI Chat Completions llm.Provider. +// +// It is pure Go — available in every build. Structured output uses the +// Chat Completions `response_format` field: a strict json_schema for +// the fixed list shapes (expand / rerank / verify), and a non-strict +// json_schema for the agent tool-call shape whose `args` object is +// intentionally open-ended. The agent tool-loop uses the emulated +// protocol — tool calls and results travel as plain text turns. +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/zzet/gortex/internal/llm" +) + +// Provider implements llm.Provider against api.openai.com. +type Provider struct { + model string + apiKey string + baseURL string + client *http.Client +} + +var _ llm.Provider = (*Provider)(nil) + +// New constructs the OpenAI provider. The API key is read from the env +// var named by cfg.APIKeyEnv (default OPENAI_API_KEY); an unset key is +// a hard error. +func New(cfg llm.RemoteConfig) (llm.Provider, error) { + keyEnv := strings.TrimSpace(cfg.APIKeyEnv) + if keyEnv == "" { + keyEnv = "OPENAI_API_KEY" + } + key := strings.TrimSpace(os.Getenv(keyEnv)) + if key == "" { + return nil, fmt.Errorf("openai: API key env %q is not set", keyEnv) + } + if strings.TrimSpace(cfg.Model) == "" { + return nil, errors.New("openai: llm.openai.model is empty") + } + base := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if base == "" { + base = "https://api.openai.com" + } + return &Provider{ + model: cfg.Model, + apiKey: key, + baseURL: base, + client: &http.Client{Timeout: 120 * time.Second}, + }, nil +} + +// Name implements llm.Provider. +func (p *Provider) Name() string { return "openai" } + +// Close releases idle HTTP connections. +func (p *Provider) Close() error { + p.client.CloseIdleConnections() + return nil +} + +// wire types for the Chat Completions API. +type apiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type apiRequest struct { + Model string `json:"model"` + Messages []apiMessage `json:"messages"` + MaxTokens int `json:"max_completion_tokens,omitempty"` + ResponseFormat map[string]any `json:"response_format,omitempty"` +} + +type apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + Error *struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` +} + +// Complete implements llm.Provider. 
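+//
+// As a rough sketch, a structured request body is shaped like
+//
+//	{"model":"...","messages":[...],"max_completion_tokens":64,
+//	 "response_format":{"type":"json_schema","json_schema":{...}}}
+//
+// sent with an `authorization: Bearer <key>` header; freeform requests
+// omit response_format entirely.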
+func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm.CompletionResponse, error) { + body := apiRequest{ + Model: p.model, + Messages: mapMessages(req.Messages), + MaxTokens: req.MaxTokens, + } + if rf := responseFormat(req.Shape, req.Tools); rf != nil { + body.ResponseFormat = rf + } + + raw, err := json.Marshal(body) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("openai: marshal request: %w", err) + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/chat/completions", bytes.NewReader(raw)) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("openai: build request: %w", err) + } + httpReq.Header.Set("content-type", "application/json") + httpReq.Header.Set("authorization", "Bearer "+p.apiKey) + + resp, err := p.client.Do(httpReq) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("openai: request failed: %w", err) + } + defer resp.Body.Close() + payload, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return llm.CompletionResponse{}, fmt.Errorf("openai: read response: %w", err) + } + + var parsed apiResponse + if err := json.Unmarshal(payload, &parsed); err != nil { + return llm.CompletionResponse{}, fmt.Errorf("openai: decode response (status %d): %w", resp.StatusCode, err) + } + if resp.StatusCode != http.StatusOK { + if parsed.Error != nil { + return llm.CompletionResponse{}, fmt.Errorf("openai: API error (status %d): %s: %s", resp.StatusCode, parsed.Error.Type, parsed.Error.Message) + } + return llm.CompletionResponse{}, fmt.Errorf("openai: API error (status %d): %s", resp.StatusCode, snippet(payload)) + } + if len(parsed.Choices) == 0 { + return llm.CompletionResponse{}, errors.New("openai: response carried no choices") + } + return llm.CompletionResponse{Text: strings.TrimSpace(parsed.Choices[0].Message.Content)}, nil +} + +// mapMessages flattens the provider-neutral conversation onto OpenAI +// chat roles. Tool observations become user turns (emulated tool-call +// protocol). +func mapMessages(in []llm.Message) []apiMessage { + out := make([]apiMessage, 0, len(in)) + for _, m := range in { + switch m.Role { + case llm.RoleSystem: + out = append(out, apiMessage{Role: "system", Content: m.Content}) + case llm.RoleAssistant: + out = append(out, apiMessage{Role: "assistant", Content: m.Content}) + case llm.RoleTool: + out = append(out, apiMessage{Role: "user", Content: renderToolResult(m)}) + default: + out = append(out, apiMessage{Role: "user", Content: m.Content}) + } + } + return out +} + +func renderToolResult(m llm.Message) string { + if m.ToolName != "" { + return "Tool result (" + m.ToolName + "):\n" + m.Content + } + return "Tool result:\n" + m.Content +} + +// responseFormat builds the `response_format` payload. The fixed list +// shapes get a strict json_schema (they are fully strict-compliant); +// the tool-call shape gets a non-strict json_schema because its `args` +// object is deliberately unconstrained, which strict mode forbids. +// ShapeFreeform returns nil — no constraint. 
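+//
+// For ShapeExpandTerms, for example, the emitted payload is
+//
+//	{"type":"json_schema","json_schema":{"name":"expand_terms","schema":{...},"strict":true}}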
+func responseFormat(shape llm.JSONShape, tools []llm.ToolSpec) map[string]any { + schema := llm.JSONSchemaFor(shape, tools) + if schema == nil { + return nil + } + strict := shape != llm.ShapeToolCall + return map[string]any{ + "type": "json_schema", + "json_schema": map[string]any{ + "name": schemaName(shape), + "schema": schema, + "strict": strict, + }, + } +} + +func schemaName(shape llm.JSONShape) string { + switch shape { + case llm.ShapeExpandTerms: + return "expand_terms" + case llm.ShapeRerankOrder: + return "rerank_order" + case llm.ShapeVerifyKeep: + return "verify_keep" + case llm.ShapeToolCall: + return "tool_call" + default: + return "response" + } +} + +// snippet truncates a response body for inclusion in an error. +func snippet(b []byte) string { + const max = 300 + s := strings.TrimSpace(string(b)) + if len(s) > max { + return s[:max] + "…" + } + return s +} diff --git a/internal/llm/provider/openai/openai_test.go b/internal/llm/provider/openai/openai_test.go new file mode 100644 index 0000000..fcf36b9 --- /dev/null +++ b/internal/llm/provider/openai/openai_test.go @@ -0,0 +1,147 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/zzet/gortex/internal/llm" +) + +func TestNew_MissingKey(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "") + if _, err := New(llm.RemoteConfig{Model: "gpt-x"}); err == nil { + t.Fatal("expected error when API key env is unset") + } +} + +func TestComplete_StructuredUsesJSONSchema(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + t.Errorf("path=%q", r.URL.Path) + } + if r.Header.Get("authorization") != "Bearer test-key" { + t.Errorf("authorization=%q", r.Header.Get("authorization")) + } + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"{\"terms\":[\"jwt\"]}"}}]}`) + })) + defer srv.Close() + + t.Setenv("OPENAI_API_KEY", "test-key") + p, err := New(llm.RemoteConfig{Model: "gpt-x", BaseURL: srv.URL}) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + resp, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "Query: auth"}}, + Shape: llm.ShapeExpandTerms, + MaxTokens: 64, + }) + if err != nil { + t.Fatal(err) + } + if resp.Text != `{"terms":["jwt"]}` { + t.Errorf("text=%q", resp.Text) + } + rf, ok := gotBody["response_format"].(map[string]any) + if !ok { + t.Fatalf("response_format missing/invalid: %v", gotBody["response_format"]) + } + if rf["type"] != "json_schema" { + t.Errorf("response_format.type=%v want json_schema", rf["type"]) + } + js := rf["json_schema"].(map[string]any) + if js["strict"] != true { + t.Errorf("list shapes should request strict json_schema, got strict=%v", js["strict"]) + } +} + +func TestComplete_ToolCallShapeIsNonStrict(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"{\"tool\":\"x\",\"args\":{}}"}}]}`) + })) + defer srv.Close() + + t.Setenv("OPENAI_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) + defer p.Close() + + _, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: 
llm.RoleUser, Content: "go"}}, + Shape: llm.ShapeToolCall, + Tools: []llm.ToolSpec{{Name: "x"}}, + }) + if err != nil { + t.Fatal(err) + } + rf := gotBody["response_format"].(map[string]any) + js := rf["json_schema"].(map[string]any) + if js["strict"] != false { + t.Errorf("tool-call shape must be non-strict (args is open-ended), got strict=%v", js["strict"]) + } +} + +func TestComplete_FreeformNoResponseFormat(t *testing.T) { + var gotBody map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"plain text"}}]}`) + })) + defer srv.Close() + + t.Setenv("OPENAI_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) + defer p.Close() + + resp, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + Shape: llm.ShapeFreeform, + }) + if err != nil { + t.Fatal(err) + } + if resp.Text != "plain text" { + t.Errorf("text=%q", resp.Text) + } + if _, ok := gotBody["response_format"]; ok { + t.Error("freeform request must not send response_format") + } +} + +func TestComplete_APIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = io.WriteString(w, `{"error":{"type":"invalid_api_key","message":"bad key"}}`) + })) + defer srv.Close() + + t.Setenv("OPENAI_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) + defer p.Close() + + if _, err := p.Complete(context.Background(), llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, + }); err == nil { + t.Fatal("expected an error for a non-200 response") + } +} + +func TestName(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "k") + p, _ := New(llm.RemoteConfig{Model: "m"}) + if p.Name() != "openai" { + t.Errorf("Name()=%q", p.Name()) + } +} diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go new file mode 100644 index 0000000..e9b0a24 --- /dev/null +++ b/internal/llm/provider/provider.go @@ -0,0 +1,41 @@ +// Package provider is the llm.Provider factory. It is the one place +// that imports every concrete provider implementation; everything else +// (the svc layer, the agent loop, the cmd demos) depends only on the +// llm.Provider interface and calls New here. +// +// The build-tag split is contained entirely within the `local` +// subpackage: provider.New compiles in every build, and selecting +// "local" without a `-tags llama` binary surfaces as a runtime error +// from local.New, not a compile failure. +package provider + +import ( + "fmt" + + "github.com/zzet/gortex/internal/llm" + "github.com/zzet/gortex/internal/llm/provider/anthropic" + "github.com/zzet/gortex/internal/llm/provider/local" + "github.com/zzet/gortex/internal/llm/provider/ollama" + "github.com/zzet/gortex/internal/llm/provider/openai" +) + +// New builds the llm.Provider selected by cfg.Provider. cfg should +// already have defaults applied (see llm.Config.ApplyDefaults) — the +// HTTP providers rely on the defaulted model / endpoint / key-env +// values. Returns an error when the provider is unknown or +// misconfigured (missing model, unset API key) or, for "local", when +// the binary was built without `-tags llama`. 
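+//
+// Minimal usage, mirroring provider_test.go:
+//
+//	cfg := llm.Config{Provider: "ollama", Ollama: llm.OllamaConfig{Model: "qwen2.5-coder:7b"}}.ApplyDefaults()
+//	p, err := provider.New(cfg)   // p.Name() == "ollama"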
+func New(cfg llm.Config) (llm.Provider, error) { + switch cfg.ProviderName() { + case "local": + return local.New(cfg.Local) + case "anthropic": + return anthropic.New(cfg.Anthropic) + case "openai": + return openai.New(cfg.OpenAI) + case "ollama": + return ollama.New(cfg.Ollama) + default: + return nil, fmt.Errorf("llm: unknown provider %q (want local|anthropic|openai|ollama)", cfg.ProviderName()) + } +} diff --git a/internal/llm/provider/provider_test.go b/internal/llm/provider/provider_test.go new file mode 100644 index 0000000..4055f1a --- /dev/null +++ b/internal/llm/provider/provider_test.go @@ -0,0 +1,62 @@ +package provider + +import ( + "testing" + + "github.com/zzet/gortex/internal/llm" +) + +func TestNew_UnknownProvider(t *testing.T) { + if _, err := New(llm.Config{Provider: "bogus"}.ApplyDefaults()); err == nil { + t.Fatal("expected error for an unknown provider") + } +} + +func TestNew_AnthropicMissingKey(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "") + if _, err := New(llm.Config{Provider: "anthropic"}.ApplyDefaults()); err == nil { + t.Fatal("expected error when ANTHROPIC_API_KEY is unset") + } +} + +func TestNew_AnthropicOK(t *testing.T) { + t.Setenv("ANTHROPIC_API_KEY", "test-key") + p, err := New(llm.Config{Provider: "anthropic"}.ApplyDefaults()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer p.Close() + if p.Name() != "anthropic" { + t.Errorf("Name()=%q want anthropic", p.Name()) + } +} + +func TestNew_OpenAIOK(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "test-key") + p, err := New(llm.Config{Provider: "openai"}.ApplyDefaults()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer p.Close() + if p.Name() != "openai" { + t.Errorf("Name()=%q want openai", p.Name()) + } +} + +func TestNew_OllamaMissingModel(t *testing.T) { + if _, err := New(llm.Config{Provider: "ollama"}.ApplyDefaults()); err == nil { + t.Fatal("expected error when llm.ollama.model is unset") + } +} + +func TestNew_OllamaOK(t *testing.T) { + cfg := llm.Config{Provider: "ollama", Ollama: llm.OllamaConfig{Model: "qwen2.5-coder:7b"}}.ApplyDefaults() + p, err := New(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer p.Close() + if p.Name() != "ollama" { + t.Errorf("Name()=%q want ollama", p.Name()) + } +} diff --git a/internal/llm/svc/assist.go b/internal/llm/svc/assist.go index c4e1b6d..01d5d0c 100644 --- a/internal/llm/svc/assist.go +++ b/internal/llm/svc/assist.go @@ -1,151 +1,56 @@ -//go:build llama - package svc import ( "context" "encoding/json" - "errors" - "fmt" "strings" "github.com/zzet/gortex/internal/llm" - "github.com/zzet/gortex/internal/llm/agent" ) -// assistCtxSize is the KV-cache window for the short-call assist -// context. Sized for the heaviest user — verify with body + callers -// at ~3.5K tokens for 10 candidates. Expansion and rerank use a -// fraction of this; the extra KV cache is cheap (a few hundred MB). -const assistCtxSize = 4096 - -// Token caps per call. Expansion emits at most a small JSON list; -// rerank emits at most one ID per candidate. Verify emits one ID per -// surviving candidate, so its cap is comparable to rerank. +// Token caps per assist call. Expansion emits at most a small JSON +// list; rerank emits one id per candidate; verify emits one id per +// surviving candidate. Handed to the provider as CompletionRequest +// .MaxTokens — the provider applies its own structured-output +// early-stop on top. 
const ( expandMaxTokens = 192 rerankMaxTokens = 512 verifyMaxTokens = 512 ) -// Grammar for {"terms":[, ...]}. Strings are arbitrary JSON -// strings — callers filter the output to whatever's actually useful. -const expandGrammar = `root ::= ws "{" ws "\"terms\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws -str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" -ws ::= [ \t\n]* -` - -// Grammar for {"order":[, ...]}. Same shape as expand, -// different top-level key — kept as two constants so each call site -// skips a Sprintf on the hot path. -const rerankGrammar = `root ::= ws "{" ws "\"order\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws -str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" -ws ::= [ \t\n]* -` - -// Grammar for {"keep":[, ...]}. The body-grounded verifier -// MUST be allowed to emit an empty array — that's the load-bearing -// "honest negative" signal — so the array body is fully optional. -const verifyGrammar = `root ::= ws "{" ws "\"keep\"" ws ":" ws "[" ws ( str ( ws "," ws str )* )? ws "]" ws "}" ws -str ::= "\"" ( [^"\\] | "\\" ( ["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] ) )* "\"" -ws ::= [ \t\n]* -` - -const expandSystem = `You expand a code-search query into a small set of CONCRETE identifier-style terms a programmer would actually grep for. ` + - `Output strict JSON: {"terms":["","",...]}. ` + - `Include 2 to 5 terms. Each term MUST be a single word with no spaces and no punctuation other than underscores. ` + - ` -RULES: -1. Prefer DOMAIN-SPECIFIC terms over generic English. ` + - `GOOD examples: bcrypt, argon2, scrypt, sha256, hmac, jwt, oauth, pbkdf2, kdf, salt. ` + - `BAD examples (NEVER emit): function, library, algorithm, code, system, data, service, value, info, content, thing, stuff, name, general, common, logic, process, handler, flow, action, helper, util, utility. ` + - ` -2. Prefer terms that are likely SYMBOL names in a codebase (camelCase / snake_case / PascalCase fragments), library or protocol names, well-known acronyms. ` + - ` -3. Do NOT echo the original query words. ` + - ` -4. If the query has no obvious domain-specific neighbours, emit FEWER terms (or an empty array) — quality over quantity.` - -const rerankSystem = `You rerank code-search results by relevance to a natural-language task. ` + - `Given a query and a list of candidate symbols (id | name | optional signature), output strict JSON: {"order":["id1","id2",...]} ` + - `with the most relevant candidates first. ` + - `Use ONLY the provided ids verbatim. Do not invent ids. You may drop ids that are clearly unrelated.` - -const verifySystem = `You filter code-search candidates by reading their BODY, SIGNATURE, and CALLERS, and keeping every one whose code is genuinely about the user's query. ` + - `Each candidate is presented as: - - | | -body: - -callers: -- | -- ... ---- - -Output strict JSON: {"keep":["id1","id2",...]} listing EVERY id whose code is meaningfully related to the query, in your preferred order (most relevant first). - -RULES (follow exactly): -1. Evaluate EACH candidate INDEPENDENTLY. Multiple candidates can be valid matches — keep them all. -2. A name that contains a query word is not enough by itself — read what the code DOES. -3. Cross-reference the CALLERS and the SIGNATURE's parameter types against the query DOMAIN. 
If a function hashes data but is only called from a "publishDiagnostics" or "renderLog" path with a non-password parameter type, it is NOT about hashing passwords — DROP it. -4. Be GENEROUS, not restrictive: if a candidate's body AND callers AND signature are all plausibly about the query, KEEP it. The user wants signal, not a single "best" pick. -5. Drop a candidate when its body, signature, or callers reveal the operation is on the wrong KIND of data for the query. -6. Returning {"keep":[]} is valid ONLY when NO candidate is genuinely about the query. -7. Use ONLY the provided ids verbatim. Never invent or modify an id.` - -// ensureAssist lazily allocates the short-call context the first time -// an assist method is called. Safe to invoke before locking -// assistMu — the underlying sync.Once handles concurrent first calls. -// Subsequent callers MUST still take assistMu before touching -// assistCtx, since the context itself is single-stream. -func (s *Service) ensureAssist() error { - if err := s.ensureLoaded(); err != nil { - return err - } - s.assistOnce.Do(func() { - c, err := s.model.NewContext(assistCtxSize, 0) - if err != nil { - s.assistErr = fmt.Errorf("llm: assist context: %w", err) - return - } - s.assistCtx = c - }) - return s.assistErr -} - // ExpandQuery turns a natural-language search query into a small set -// of related identifier-style terms via one grammar-constrained -// inference pass. Result is cached by query string. Empty / blank -// input returns an empty result without touching the model. +// of related identifier-style terms via one structured completion. +// Result is cached by query string. Empty / blank input returns an +// empty result without touching the provider. // // The caller is expected to OR the returned terms with the original // query and rerank by combined BM25 score. func (s *Service) ExpandQuery(ctx context.Context, query string) (*llm.ExpandResult, error) { - _ = ctx query = strings.TrimSpace(query) if query == "" { return &llm.ExpandResult{Original: query}, nil } - if cached, ok := s.expandCache.Get(query); ok { return &llm.ExpandResult{Original: query, Terms: cached, Cached: true}, nil } - if err := s.ensureAssist(); err != nil { - return nil, err - } - - tmpl, err := agent.TemplateByName(s.cfg.Template) - if err != nil { - return nil, err + if s.provider == nil { + return nil, errServiceUnavailable } - prompt := buildAssistPrompt(tmpl, expandSystem, "Query: "+query) - raw, err := s.runAssist(prompt, expandGrammar, expandMaxTokens) + resp, err := s.provider.Complete(ctx, llm.CompletionRequest{ + Messages: []llm.Message{ + {Role: llm.RoleSystem, Content: llm.ExpandSystemPrompt(s.profile)}, + {Role: llm.RoleUser, Content: "Query: " + query}, + }, + MaxTokens: expandMaxTokens, + Shape: llm.ShapeExpandTerms, + }) if err != nil { return nil, err } - terms := parseStringList(raw, "terms") + terms := parseStringList(resp.Text, "terms") terms = dedupeFilter(terms, query) // Even an empty result is worth caching — re-issuing the prompt // won't change a model that consistently emits nothing useful. @@ -153,16 +58,15 @@ func (s *Service) ExpandQuery(ctx context.Context, query string) (*llm.ExpandRes return &llm.ExpandResult{Original: query, Terms: terms}, nil } -// RerankSymbols asks the model to reorder a candidate set by +// RerankSymbols asks the provider to reorder a candidate set by // relevance to the query. IDs the model drops are appended at the // tail in original input order so the caller never loses a candidate. 
-// Empty input returns an empty order without touching the model. +// Empty input returns an empty order without touching the provider. // -// Cache key includes the candidate ID set so two callers passing the -// same query against different candidate pools each get their own +// The cache key includes the candidate ID set so two callers passing +// the same query against different candidate pools each get their own // cache entry; ordering of input candidates does not affect the key. func (s *Service) RerankSymbols(ctx context.Context, query string, cands []llm.RerankCandidate) (*llm.RerankResult, error) { - _ = ctx query = strings.TrimSpace(query) if query == "" || len(cands) == 0 { return &llm.RerankResult{Order: candIDs(cands)}, nil @@ -172,18 +76,18 @@ func (s *Service) RerankSymbols(ctx context.Context, query string, cands []llm.R if cached, ok := s.rerankCache.Get(key); ok { return &llm.RerankResult{Order: cached, Cached: true}, nil } - if err := s.ensureAssist(); err != nil { - return nil, err + if s.provider == nil { + return nil, errServiceUnavailable } - tmpl, err := agent.TemplateByName(s.cfg.Template) - if err != nil { - return nil, err - } - user := buildRerankUser(query, cands) - prompt := buildAssistPrompt(tmpl, rerankSystem, user) - - raw, err := s.runAssist(prompt, rerankGrammar, rerankMaxTokens) + resp, err := s.provider.Complete(ctx, llm.CompletionRequest{ + Messages: []llm.Message{ + {Role: llm.RoleSystem, Content: llm.RerankSystemPrompt(s.profile)}, + {Role: llm.RoleUser, Content: buildRerankUser(query, cands)}, + }, + MaxTokens: rerankMaxTokens, + Shape: llm.ShapeRerankOrder, + }) if err != nil { // Surface the error but keep input order intact so the caller // can still return *something* — search-assist must never @@ -191,7 +95,7 @@ func (s *Service) RerankSymbols(ctx context.Context, query string, cands []llm.R return &llm.RerankResult{Order: candIDs(cands)}, err } - rawOrder := parseStringList(raw, "order") + rawOrder := parseStringList(resp.Text, "order") order := filterToInputAppend(rawOrder, cands) s.rerankCache.Set(key, order) return &llm.RerankResult{Order: order}, nil @@ -203,15 +107,14 @@ func (s *Service) RerankSymbols(ctx context.Context, query string, cands []llm.R // which is a load-bearing honest-negative signal the caller should // preserve rather than fall back to BM25 noise. // -// Cache key includes (query, sorted IDs, body hash) so a re-indexed -// codebase doesn't return stale verifications. Empty input short- -// circuits without touching the model. +// The cache key includes (query, sorted IDs, body hash) so a +// re-indexed codebase doesn't return stale verifications. Empty input +// short-circuits without touching the provider. // // On any inference or parse failure, returns the input order // unchanged with the error — the caller should treat that as "could // not verify" rather than "nothing matched". 
func (s *Service) VerifyRelevance(ctx context.Context, query string, cands []llm.VerifyCandidate) (*llm.VerifyResult, error) { - _ = ctx query = strings.TrimSpace(query) if query == "" || len(cands) == 0 { return &llm.VerifyResult{Keep: verifyIDs(cands)}, nil @@ -221,71 +124,37 @@ func (s *Service) VerifyRelevance(ctx context.Context, query string, cands []llm if cached, ok := s.verifyCache.Get(key); ok { return &llm.VerifyResult{Keep: cached, Cached: true}, nil } - if err := s.ensureAssist(); err != nil { - return nil, err - } - - tmpl, err := agent.TemplateByName(s.cfg.Template) - if err != nil { - return nil, err + if s.provider == nil { + return nil, errServiceUnavailable } - user := buildVerifyUser(query, cands) - prompt := buildAssistPrompt(tmpl, verifySystem, user) - raw, err := s.runAssist(prompt, verifyGrammar, verifyMaxTokens) + resp, err := s.provider.Complete(ctx, llm.CompletionRequest{ + Messages: []llm.Message{ + {Role: llm.RoleSystem, Content: llm.VerifySystemPrompt(s.profile)}, + {Role: llm.RoleUser, Content: buildVerifyUser(query, cands)}, + }, + MaxTokens: verifyMaxTokens, + Shape: llm.ShapeVerifyKeep, + }) if err != nil { // On failure, surface the error and keep all input candidates // — better to over-include than to silently drop them. return &llm.VerifyResult{Keep: verifyIDs(cands)}, err } - rawKeep := parseStringList(raw, "keep") + rawKeep := parseStringList(resp.Text, "keep") keep := filterKeepToInput(rawKeep, cands) s.verifyCache.Set(key, keep) return &llm.VerifyResult{Keep: keep}, nil } -// runAssist is the shared inference primitive for the two assist -// methods. Holds assistMu, resets KV cache, installs the grammar, -// generates with the jsonComplete early-stop predicate, and returns -// the raw model output trimmed of surrounding whitespace. -func (s *Service) runAssist(prompt, grammar string, maxTokens int) (string, error) { - s.assistMu.Lock() - defer s.assistMu.Unlock() - - if s.assistCtx == nil { - return "", errors.New("llm: assist context not initialised") - } - - s.assistCtx.Reset() - if err := s.assistCtx.SetGrammar(grammar); err != nil { - return "", fmt.Errorf("llm: install assist grammar: %w", err) - } - - var buf strings.Builder - _, err := s.assistCtx.Generate(prompt, maxTokens, func(piece string) bool { - buf.WriteString(piece) - return !assistJSONComplete(buf.String()) - }) - if err != nil { - return "", err - } - return strings.TrimSpace(buf.String()), nil -} - -// buildAssistPrompt is the single-turn equivalent of agent.initialPrompt: -// no tool list, no AssistEnd round-trip — just System + User + AssistPrime. -func buildAssistPrompt(tmpl agent.ChatTemplate, system, user string) string { - return tmpl.BOS + tmpl.System(system) + tmpl.User(user) + tmpl.AssistPrime -} - // buildVerifyUser formats the candidate list for the body-grounded // verify prompt. Each candidate ships with its body and a compact // callers block — the callers carry independent contextual signal // that lets the model distinguish "same operation, different data" // cases the body alone can't disambiguate. Bodies and signatures -// must be pre-truncated by the caller — this is a formatter, not -// the place to enforce length limits. +// must be pre-truncated by the caller — this is a formatter, not the +// place to enforce length limits. 
func buildVerifyUser(query string, cands []llm.VerifyCandidate) string { var b strings.Builder b.WriteString("Query: ") @@ -358,19 +227,6 @@ func buildRerankUser(query string, cands []llm.RerankCandidate) string { return b.String() } -// assistJSONComplete is the same shape as agent.jsonComplete: stop -// generation as soon as the top-level JSON object closes and parses. -// Replicated rather than exported from package agent to keep that -// package's surface minimal. -func assistJSONComplete(s string) bool { - s = strings.TrimSpace(s) - if !strings.HasPrefix(s, "{") || !strings.HasSuffix(s, "}") { - return false - } - var v any - return json.Unmarshal([]byte(s), &v) == nil -} - // parseStringList extracts a top-level JSON string array under the // given key. Returns nil on any parse failure — the caller decides // the fallback behaviour. @@ -406,8 +262,7 @@ func parseStringList(raw, key string) []string { // // Borderline / domain-bearing words like `encryption`, `algorithm`, // `security`, `key` are deliberately NOT here: they can be load-bearing -// in some codebases (a crypto library is a different story than a code -// intelligence tool). Keep this list short — over-filtering throws +// in some codebases. Keep this list short — over-filtering throws // away the only signal expansion has to offer. var expansionStoplist = map[string]bool{ "function": true, "functions": true, "method": true, "methods": true, @@ -419,23 +274,23 @@ var expansionStoplist = map[string]bool{ "data": true, "datum": true, "value": true, "values": true, "object": true, "objects": true, "item": true, "items": true, - "thing": true, "things": true, - "info": true, "information": true, + "thing": true, "things": true, + "info": true, "information": true, "content": true, "contents": true, "stuff": true, "general": true, "common": true, "basic": true, "simple": true, "main": true, - "text": true, + "text": true, // Generic verbs/nouns that slip through with NL queries — observed // in the wild: "where is the rerank logic for search results" pulled // in "logic" as an expansion term, which broadens BM25 enormously // against any *_logic or logical_* identifier. "logic": true, "logical": true, "process": true, "processing": true, - "handle": true, "handler": true, "handling": true, - "flow": true, "flows": true, - "action": true, "actions": true, - "helper": true, "helpers": true, - "util": true, "utils": true, "utility": true, "utilities": true, + "handle": true, "handler": true, "handling": true, + "flow": true, "flows": true, + "action": true, "actions": true, + "helper": true, "helpers": true, + "util": true, "utils": true, "utility": true, "utilities": true, } // minExpansionTermLen rejects terms shorter than this. Sub-3 char @@ -444,6 +299,11 @@ var expansionStoplist = map[string]bool{ // short identifiers like `js`, `db`, `ui` get through. const minExpansionTermLen = 3 +// maxExpansionTerms caps the per-call expansion regardless of model +// output. Each extra term adds a BM25 sweep + candidate-pool growth, +// so trimming aggressively saves both latency and rerank prompt size. +const maxExpansionTerms = 5 + // dedupeFilter trims, lowercases for comparison, and drops terms that // are empty, duplicates, the original query, in expansionStoplist, or // shorter than minExpansionTermLen. Preserves order of the surviving @@ -474,11 +334,6 @@ func dedupeFilter(terms []string, query string) []string { return out } -// maxExpansionTerms caps the per-call expansion regardless of model -// output. 
Each extra term adds a BM25 sweep + candidate-pool growth, -// so trimming aggressively saves both latency and rerank prompt size. -const maxExpansionTerms = 5 - // candIDs extracts just the ID slice from a candidate list, // preserving order. Returned for fallback paths so the caller still // gets a valid (if unhelpful) ordering. @@ -494,8 +349,8 @@ func candIDs(cands []llm.RerankCandidate) []string { } // verifyIDs is the VerifyCandidate equivalent of candIDs — used on -// fallback paths where we want to preserve every input ID rather -// than drop them silently. +// fallback paths where we want to preserve every input ID rather than +// drop them silently. func verifyIDs(cands []llm.VerifyCandidate) []string { if len(cands) == 0 { return nil @@ -510,8 +365,8 @@ func verifyIDs(cands []llm.VerifyCandidate) []string { // filterKeepToInput is the VerifyResult equivalent of // filterToInputAppend but with one critical difference: dropped IDs // are NOT appended at the tail. An empty result IS the load-bearing -// honest-negative signal, so callers must see exactly what the -// model decided to keep. +// honest-negative signal, so callers must see exactly what the model +// decided to keep. // // Hallucinated and duplicate IDs are still filtered defensively. func filterKeepToInput(modelKeep []string, cands []llm.VerifyCandidate) []string { diff --git a/internal/llm/svc/assist_e2e_test.go b/internal/llm/svc/assist_e2e_test.go index 3dd7359..db6f3d1 100644 --- a/internal/llm/svc/assist_e2e_test.go +++ b/internal/llm/svc/assist_e2e_test.go @@ -36,10 +36,13 @@ func TestE2E_AssistAgainstRealModel(t *testing.T) { } cfg := llm.Config{ - Model: modelPath, - Template: "chatml", - Ctx: 4096, + Provider: "local", MaxSteps: 16, + Local: llm.LocalConfig{ + Model: modelPath, + Template: "chatml", + Ctx: 4096, + }, }.ApplyDefaults() svcInst := NewService(cfg, llm.MockBackend{}) diff --git a/internal/llm/svc/assist_test.go b/internal/llm/svc/assist_test.go index 91f85e5..b54278b 100644 --- a/internal/llm/svc/assist_test.go +++ b/internal/llm/svc/assist_test.go @@ -1,5 +1,3 @@ -//go:build llama - package svc import ( @@ -254,23 +252,3 @@ func TestCandIDs(t *testing.T) { t.Fatalf("got=%v want=%v", got, want) } } - -func TestAssistJSONComplete(t *testing.T) { - cases := []struct { - in string - want bool - }{ - {"{}", true}, - {`{"a":1}`, true}, - {" {} ", true}, - {"{", false}, - {`{"a":`, false}, - {"not json", false}, - {"", false}, - } - for _, tc := range cases { - if got := assistJSONComplete(tc.in); got != tc.want { - t.Fatalf("in=%q got=%v want=%v", tc.in, got, tc.want) - } - } -} diff --git a/internal/llm/svc/cache.go b/internal/llm/svc/cache.go index 55987b8..dc4647a 100644 --- a/internal/llm/svc/cache.go +++ b/internal/llm/svc/cache.go @@ -1,5 +1,3 @@ -//go:build llama - package svc import ( diff --git a/internal/llm/svc/service.go b/internal/llm/svc/service.go index 7cc3be0..97a0c8f 100644 --- a/internal/llm/svc/service.go +++ b/internal/llm/svc/service.go @@ -1,172 +1,154 @@ -//go:build llama - -// Package svc is the runner layer that ties the LLM model (package -// llm) to the agent loop (package llm/agent). It lives in its own -// package to break the import cycle that would otherwise exist -// between `llm` (defines Context, Backend) and `llm/agent` (depends -// on those types). +// Package svc is the runner layer that ties an llm.Provider to the +// agent tool-loop (package llm/agent) and the search-assist passes +// (assist.go). 
It lives in its own package to break the import cycle +// that would otherwise exist between `llm` and `llm/agent`. +// +// svc is pure Go: the `-tags llama` build-tag split is contained +// entirely within the provider packages. The daemon links the same +// Service whether or not the tag is set — without it only the `local` +// provider is unavailable; the HTTP providers still work, and a +// disabled service degrades cleanly (Enabled() reports false). package svc import ( "context" "errors" - "fmt" "strings" - "sync" "time" "github.com/zzet/gortex/internal/llm" "github.com/zzet/gortex/internal/llm/agent" + "github.com/zzet/gortex/internal/llm/provider" ) -// Service is the reusable in-process LLM access point. Wraps a -// lazily-loaded llama.cpp model plus a Backend (typically an -// InProcessBackend pointing at the daemon's *query.Engine). Two -// consumption shapes: -// -// - Generate: one-shot prompt → text. Used by future wiki / doc -// generation features that don't need a tool-calling loop. -// - RunAgent: grammar-constrained agent loop that uses Backend's -// tools to navigate the graph and produce a synthesized answer. -// Used by the MCP `ask` tool handler. +// errServiceUnavailable is returned by operational methods when no +// provider could be constructed (disabled config, build without +// `-tags llama` for the local provider, missing API key, ...). +var errServiceUnavailable = errors.New("llm: service unavailable — no provider configured") + +// Service is the reusable LLM access point. It wraps a constructed +// llm.Provider plus a Backend (typically an InProcessBackend pointing +// at the daemon's *query.Engine). Three consumption shapes: // -// Both go through the same model and the same inference mutex — -// llama.cpp is single-stream on a given device. +// - Generate: one-shot prompt → text. Freeform completion. +// - RunAgent: the grammar/schema-constrained tool-calling loop that +// uses the Backend's tools to navigate the graph. Backs the MCP +// `ask` tool. +// - ExpandQuery / RerankSymbols / VerifyRelevance: the search-assist +// passes — short structured completions backing the `search_symbols` +// `assist` argument (see assist.go). // -// In addition to the full-size RunAgent / Generate contexts, Service -// keeps a pre-warmed *assist context* — a smaller llama context used -// for short single-shot grammar-constrained calls (ExpandQuery, -// RerankSymbols). The assist context has its own mutex so a long -// `ask` doesn't head-of-line block hot-path NL search calls; at the -// llama.cpp level the two contexts share the model weights but each -// holds its own KV cache. +// The active provider is chosen by llm.Config.Provider. The prompt +// tier (profile) is derived from the provider's Name() so the assist +// passes prompt small local models and hosted frontier models +// differently — see llm.ProfileForProvider. type Service struct { - cfg llm.Config - backend llm.Backend - - loadOnce sync.Once - model *llm.Model - loadErr error + cfg llm.Config + backend llm.Backend + provider llm.Provider + providerErr error + profile llm.PromptProfile - infer sync.Mutex - - assistOnce sync.Once - assistCtx *llm.Context - assistErr error - assistMu sync.Mutex expandCache *assistCache rerankCache *assistCache verifyCache *assistCache } -// NewService is cheap — it just stores the config and backend. The -// model is mmap'd and Metal kernels compiled lazily on the first -// Generate / RunAgent call, so daemon startup isn't slowed. 
+// NewService constructs the service and its provider. Provider +// construction is cheap for every backend — the local provider only +// validates its config here and defers the model mmap to the first +// call. A disabled or misconfigured config yields a Service whose +// Enabled() reports false; the construction error is retained and +// surfaced via ProviderErr. func NewService(cfg llm.Config, backend llm.Backend) *Service { - return &Service{ - cfg: cfg.ApplyDefaults(), + cfg = cfg.ApplyDefaults() + s := &Service{ + cfg: cfg, backend: backend, expandCache: newAssistCache(256), rerankCache: newAssistCache(256), verifyCache: newAssistCache(256), } + if cfg.IsEnabled() && backend != nil { + p, err := provider.New(cfg) + if err != nil { + s.providerErr = err + } else { + s.provider = p + s.profile = llm.ProfileForProvider(p.Name()) + } + } + return s } -// Enabled reports whether the service has a valid configuration -// (non-empty model path) and a backend. Callers should check this -// before registering features that depend on the service. +// Enabled reports whether the service can do real work — a provider +// was constructed and a backend is wired. Callers gate feature / +// tool registration on this. func (s *Service) Enabled() bool { - return s != nil && s.cfg.IsEnabled() && s.backend != nil + return s != nil && s.provider != nil && s.backend != nil } -func (s *Service) ensureLoaded() error { - s.loadOnce.Do(func() { - if !s.cfg.IsEnabled() { - s.loadErr = errors.New("llm: model path is empty") - return - } - m, err := llm.LoadModel(s.cfg.Model, s.cfg.GPULayers) - if err != nil { - s.loadErr = fmt.Errorf("llm: load model: %w", err) - return - } - s.model = m - }) - return s.loadErr +// ProviderErr returns the error from provider construction, if any. +// Enabled() is false whenever this is non-nil; the daemon entrypoint +// surfaces it as a startup warning so a misconfigured `llm:` block +// (unset API key, model file missing) isn't silently ignored. +func (s *Service) ProviderErr() error { + if s == nil { + return nil + } + return s.providerErr } -// Close releases the underlying model and any assist context. Safe -// to call multiple times. After Close, every operational method -// returns an error. -func (s *Service) Close() error { - // Order matters: drop the assist context first so its KV cache - // is freed before the model itself goes away. - s.assistMu.Lock() - if s.assistCtx != nil { - s.assistCtx.Close() - s.assistCtx = nil +// ProviderName returns the active provider's name, or "" when no +// provider was constructed. +func (s *Service) ProviderName() string { + if s == nil || s.provider == nil { + return "" } - s.assistMu.Unlock() + return s.provider.Name() +} - s.infer.Lock() - defer s.infer.Unlock() - if s.model != nil { - s.model.Close() - s.model = nil +// Close releases the provider's resources (model weights, idle HTTP +// connections). Safe to call multiple times and on a disabled service. +func (s *Service) Close() error { + if s == nil || s.provider == nil { + return nil } - return nil + return s.provider.Close() } -// Generate runs one-shot inference: prompt in, generated text out. -// No agent loop, no tools — just the model. Intended for future -// summarization / wiki generation use cases where the caller assembles -// the prompt with relevant code context itself. -// -// maxTokens caps the generation length; 0 falls back to a sensible -// default (1024). The model's chat template is NOT applied — pass a -// fully-formatted prompt. 
+// Generate runs one-shot freeform inference: prompt in, generated text +// out. No agent loop, no tools. maxTokens caps generation length; 0 +// falls back to a sensible default. func (s *Service) Generate(ctx context.Context, prompt string, maxTokens int) (string, error) { - _ = ctx // greedy inference is uninterruptible in the current wrapper - if err := s.ensureLoaded(); err != nil { - return "", err + if s.provider == nil { + return "", errServiceUnavailable } if maxTokens <= 0 { maxTokens = 1024 } - - s.infer.Lock() - defer s.infer.Unlock() - - llmCtx, err := s.model.NewContext(s.cfg.Ctx, 0) - if err != nil { - return "", fmt.Errorf("llm: new context: %w", err) - } - defer llmCtx.Close() - - var out strings.Builder - _, err = llmCtx.Generate(prompt, maxTokens, func(piece string) bool { - out.WriteString(piece) - return true + resp, err := s.provider.Complete(ctx, llm.CompletionRequest{ + Messages: []llm.Message{{Role: llm.RoleUser, Content: prompt}}, + MaxTokens: maxTokens, + Shape: llm.ShapeFreeform, }) if err != nil { - return out.String(), err + return "", err } - return out.String(), nil + return resp.Text, nil } -// RunAgent runs the grammar-constrained tool-calling agent loop. The -// agent issues tool calls against the configured Backend (typically an -// InProcessBackend wired to gortex's *query.Engine) and synthesizes a -// final answer via the model's final_answer tool. +// RunAgent runs the structured tool-calling agent loop. The agent +// issues tool calls against the configured Backend and synthesizes a +// final answer via the final_answer tool. // -// Returned AgentAnswer always has at least Answer/Error populated — -// non-nil even on error paths. +// The returned AgentAnswer always has at least Answer/Error populated +// — non-nil even on error paths. func (s *Service) RunAgent(ctx context.Context, opts llm.RunAgentOptions) (*llm.AgentAnswer, error) { - _ = ctx answer := &llm.AgentAnswer{Scope: opts.Scope, ChainMode: opts.Chain} - if err := s.ensureLoaded(); err != nil { - answer.Error = err.Error() - return answer, err + if s.provider == nil { + answer.Error = errServiceUnavailable.Error() + return answer, errServiceUnavailable } if strings.TrimSpace(opts.Question) == "" { err := errors.New("llm: question is empty") @@ -174,12 +156,6 @@ func (s *Service) RunAgent(ctx context.Context, opts llm.RunAgentOptions) (*llm. return answer, err } - tmpl, err := agent.TemplateByName(s.cfg.Template) - if err != nil { - answer.Error = err.Error() - return answer, err - } - systemExtras := opts.SystemExtras if systemExtras == "" { if opts.Chain { @@ -196,24 +172,14 @@ func (s *Service) RunAgent(ctx context.Context, opts llm.RunAgentOptions) (*llm. tools = agent.GortexTools(s.backend, opts.Scope) } - s.infer.Lock() - defer s.infer.Unlock() - - llmCtx, err := s.model.NewContext(s.cfg.Ctx, 0) - if err != nil { - answer.Error = err.Error() - return answer, err - } - defer llmCtx.Close() - - ag, err := agent.New(llmCtx, tools, tmpl) + ag, err := agent.New(s.provider, tools) if err != nil { answer.Error = err.Error() return answer, err } t0 := time.Now() - answerText, transcript, runErr := ag.Run(systemExtras, opts.Question, s.cfg.MaxSteps) + answerText, transcript, runErr := ag.Run(ctx, systemExtras, opts.Question, s.cfg.MaxSteps) answer.ElapsedMs = time.Since(t0).Milliseconds() answer.Answer = answerText @@ -239,8 +205,8 @@ func (s *Service) RunAgent(ctx context.Context, opts llm.RunAgentOptions) (*llm. 
return answer, runErr } -// promptSimple — P2-equivalent rules from the bench experiments. Tight -// system prompt for single-hop / cross-repo lookups. +// promptSimple — tight system-prompt extras for single-hop / +// cross-repo lookups. const promptSimple = `RULES (follow these exactly): - If the user gives you only a bare name (not a path-qualified id like "pkg/x.Foo"), you MUST first call search_symbols to resolve it to an id before calling get_callers. - For search_symbols, pass ONLY the bare symbol name as "query" — no prepositions, no package qualifiers, no extra words. @@ -249,9 +215,8 @@ const promptSimple = `RULES (follow these exactly): - Never call the same tool with the same args twice in a row. - When you have enough information, call final_answer summarising what you found.` -// promptChain — chain-mode rules with the explicit "no get_callers" -// direction warning that we proved closes Coder-7B's directional -// confusion in the bench. +// promptChain — chain-mode extras with the explicit "no get_callers" +// direction warning. const promptChain = `RULES (follow these exactly): - You are tracing a cross-system call chain. Output one tool call per turn. - DIRECTION MATTERS. Only these tools are correct in chain mode: diff --git a/internal/llm/svc/service_stub.go b/internal/llm/svc/service_stub.go deleted file mode 100644 index 52b936c..0000000 --- a/internal/llm/svc/service_stub.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !llama - -// Package svc — stub variant for builds without `-tags llama`. Every -// operational method returns errServiceUnavailable so callers don't -// need conditional imports; they just check Service.Enabled(). -package svc - -import ( - "context" - "errors" - - "github.com/zzet/gortex/internal/llm" -) - -var errServiceUnavailable = errors.New("llm: built without -tags llama; LLM service unavailable") - -// Service is the pure-Go stub. Same exported surface as the llama -// build's Service so non-llama compilation succeeds without -// conditional imports at every call site. Every operational method -// returns errServiceUnavailable; lifecycle methods are no-ops. -type Service struct{} - -// NewService returns a disabled stub service. The cfg and backend -// arguments are accepted for API compatibility but ignored. -func NewService(_ llm.Config, _ llm.Backend) *Service { return &Service{} } - -// Enabled reports whether the service can do real work. Always false -// in the stub build — callers should use this to gate tool -// registration / docs generation features. -func (s *Service) Enabled() bool { return false } - -// Generate is a no-op in the stub; returns errServiceUnavailable. -func (s *Service) Generate(_ context.Context, _ string, _ int) (string, error) { - return "", errServiceUnavailable -} - -// RunAgent is a no-op in the stub; returns errServiceUnavailable. -func (s *Service) RunAgent(_ context.Context, _ llm.RunAgentOptions) (*llm.AgentAnswer, error) { - return nil, errServiceUnavailable -} - -// ExpandQuery is a no-op in the stub; returns errServiceUnavailable. -// Callers should check Enabled() first and skip the call entirely. -func (s *Service) ExpandQuery(_ context.Context, _ string) (*llm.ExpandResult, error) { - return nil, errServiceUnavailable -} - -// RerankSymbols is a no-op in the stub; returns errServiceUnavailable. -// Callers should check Enabled() first and skip the call entirely. 
-func (s *Service) RerankSymbols(_ context.Context, _ string, _ []llm.RerankCandidate) (*llm.RerankResult, error) { - return nil, errServiceUnavailable -} - -// VerifyRelevance is a no-op in the stub; returns errServiceUnavailable. -// Callers should check Enabled() first and skip the call entirely. -func (s *Service) VerifyRelevance(_ context.Context, _ string, _ []llm.VerifyCandidate) (*llm.VerifyResult, error) { - return nil, errServiceUnavailable -} - -// Close is a no-op in the stub. -func (s *Service) Close() error { return nil } diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 7de0767..cbb0d50 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -123,12 +123,12 @@ type Server struct { // SetLSPDiagnosticsBroadcasting; nil until then. diagBroadcaster *diagnosticsBroadcaster - // llmService is the optional in-process LLM agent service backing - // the `ask` MCP tool (plus future internal callers like wiki/doc - // generation). nil until SetLLMService is called by the daemon - // entrypoint with a real service instance — in which case the - // `ask` tool is registered. In builds without `-tags llama`, the - // service is a stub that returns errServiceUnavailable. + // llmService is the optional LLM service backing the `ask` MCP tool + // and the `search_symbols` assist modes. nil until SetLLMService is + // called by the daemon entrypoint. The service wraps whichever + // provider `llm.provider` selects (local llama.cpp / Anthropic / + // OpenAI / Ollama); when the provider can't be constructed it + // reports Enabled() == false and the dependent tools stay absent. llmService *svc.Service // resourcesNotifier overrides the live mcpServer when pushing @@ -565,13 +565,14 @@ func NewServer(engine *query.Engine, g *graph.Graph, idx *indexer.Indexer, watch return s } -// SetLLMService attaches the in-process LLM service to the server -// and registers the `ask` MCP tool. Call after NewServer; without -// this, the `ask` MCP tool is not registered (clean degradation for -// builds / deployments without an LLM). +// SetLLMService attaches the LLM service to the server and registers +// the `ask` MCP tool. Call after NewServer; without this, the `ask` +// MCP tool is not registered (clean degradation for deployments +// without an LLM). // -// Safe to call with a stub service (built without `-tags llama`) — -// tool registration is a no-op in that case. +// Safe to call with a disabled service (no provider configured, or +// provider construction failed) — registerLLMTools gates on +// Service.Enabled() and skips registration in that case. // // Lifecycle: the server does NOT take ownership of the service; the // daemon entrypoint that constructed the service is responsible for @@ -587,16 +588,25 @@ func (s *Server) SetLLMService(service *svc.Service) { // it. A zero or disabled cfg is a no-op — safe to call // unconditionally. // -// Available in both `-tags llama` and pure-Go builds: the stub -// Service is what gets attached without the tag, and the stub -// registerLLMTools then skips registration. +// The provider is chosen by cfg.Provider. Selecting "local" in a +// binary built without `-tags llama` — or any provider with a missing +// model / API key — leaves the service disabled; the construction +// error is logged as a warning rather than failing daemon startup, so +// a misconfigured `llm:` block degrades cleanly (the `ask` tool and +// `search_symbols` assist modes are simply absent). 
func (s *Server) SetupLLM(cfg llm.Config) { cfg = cfg.MergeEnv() if !cfg.IsEnabled() { return } backend := svc.NewInProcessBackend(s.engine, s.effectiveContractRegistry) - s.SetLLMService(svc.NewService(cfg, backend)) + service := svc.NewService(cfg, backend) + s.SetLLMService(service) + if err := service.ProviderErr(); err != nil { + s.logger.Warn("LLM provider unavailable — `ask` tool and search assist disabled", + zap.String("provider", cfg.ProviderName()), + zap.Error(err)) + } } // InitFeedback initializes the feedback manager for cross-session feedback persistence. diff --git a/internal/mcp/tools_core.go b/internal/mcp/tools_core.go index 3f6908c..6f41338 100644 --- a/internal/mcp/tools_core.go +++ b/internal/mcp/tools_core.go @@ -576,7 +576,7 @@ func (s *Server) registerCoreTools() { mcp.WithString("project", mcp.Description("Filter results to repositories in a specific project")), mcp.WithString("ref", mcp.Description("Filter results to repositories with a specific reference tag")), mcp.WithString("kind", mcp.Description("Filter to one or more node kinds (comma-separated). Standard kinds: function, method, type, interface, variable, constant, field, file, package, import, contract. Coverage kinds: param, closure, enum_member, generic_param, module, table, column, config_key, flag, event, migration, fixture, todo, team, license, release.")), - mcp.WithString("assist", mcp.Description("LLM assist mode: \"auto\" (default — engages on natural-language queries, skips identifier lookups), \"on\" (force engage), \"off\" (bypass), \"deep\" (on + a body-grounded verification pass that reads candidate code and HONESTLY drops irrelevant matches — slower, may return empty results when nothing genuinely matches). Requires the daemon to be built with -tags llama and a configured model; otherwise behaves as \"off\".")), + mcp.WithString("assist", mcp.Description("LLM assist mode: \"auto\" (default — engages on natural-language queries, skips identifier lookups), \"on\" (force engage), \"off\" (bypass), \"deep\" (on + a body-grounded verification pass that reads candidate code and HONESTLY drops irrelevant matches — slower, may return empty results when nothing genuinely matches). Requires an LLM provider configured via `llm.provider` (local / anthropic / openai / ollama); behaves as \"off\" when none is available.")), ), s.handleSearchSymbols, ) diff --git a/internal/mcp/tools_llm.go b/internal/mcp/tools_llm.go index 7141142..944a323 100644 --- a/internal/mcp/tools_llm.go +++ b/internal/mcp/tools_llm.go @@ -1,5 +1,3 @@ -//go:build llama - package mcp import ( @@ -23,7 +21,7 @@ func (s *Server) registerLLMTools() { } s.mcpServer.AddTool( mcp.NewTool("ask", - mcp.WithDescription("Ask a local research agent (small GGUF model running in-process via llama.cpp) to navigate the gortex graph and return a synthesized answer. Use this instead of issuing many search_symbols / get_callers / contracts calls yourself when the question is open-ended or requires multi-hop reasoning across repos — the agent does that work locally and returns a filtered answer. Set chain=true for cross-system call-chain tracing (consumer → contract → provider → downstream)."), + mcp.WithDescription("Ask a research agent to navigate the gortex graph and return a synthesized answer. The agent runs on whichever LLM provider is configured (`llm.provider`): an in-process llama.cpp model, or a hosted Anthropic / OpenAI / Ollama backend. 
Use this instead of issuing many search_symbols / get_callers / contracts calls yourself when the question is open-ended or requires multi-hop reasoning across repos — the agent does that work and returns a filtered answer. Set chain=true for cross-system call-chain tracing (consumer → contract → provider → downstream)."), mcp.WithString("question", mcp.Required(), mcp.Description("Natural-language question about the indexed codebase. Examples: \"who calls NewServer in the mcp package?\", \"trace the path from web's /v1/stats consumer to the gortex handler\".")), mcp.WithString("repo", mcp.Description("Optional repo-prefix scope (e.g. \"gortex-cloud\"). Restricts the agent's tool calls to one repo. Leave empty for cross-repo questions.")), mcp.WithString("project", mcp.Description("Optional project scope.")), @@ -66,4 +64,3 @@ func (s *Server) handleAsk(ctx context.Context, req mcp.CallToolRequest) (*mcp.C } return mcp.NewToolResultText(string(out)), nil } - diff --git a/internal/mcp/tools_llm_stub.go b/internal/mcp/tools_llm_stub.go deleted file mode 100644 index 24a814b..0000000 --- a/internal/mcp/tools_llm_stub.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !llama - -package mcp - -// registerLLMTools is the no-op stub used when gortex is built -// without `-tags llama`. The real implementation in tools_llm.go -// registers the `ask` MCP tool and wires it to the LLM service. -// -// Method exists on *Server in both build variants so NewServer can -// call s.registerLLMTools() unconditionally. -func (s *Server) registerLLMTools() {} From f1de3baa2167e37c74b34f902fb80a1a22d527a3 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Thu, 14 May 2026 18:56:33 +0200 Subject: [PATCH 5/6] Actualise README --- README.md | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/README.md b/README.md index 3929664..6cb08c9 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ For Homebrew, package managers (`.deb` / `.rpm` / `.apk`), direct binary downloa - **Infrastructure graph layer** — first-class `KindResource` (Kubernetes Deployments, Services, Ingresses, ConfigMaps, Secrets, CronJobs), `KindKustomization` (overlay tree), and `KindImage` (Dockerfile FROM targets and K8s `container.image`) with `depends_on` / `configures` / `mounts` / `exposes` / `uses_env` edges. Cross-references with code-side `os.Getenv` calls automatically. Surfaced via `analyze` `kind: "k8s_resources" / "kustomize" / "images"` - **CPG-lite dataflow** — `value_flow` (intra-procedural assignment / return / range), `arg_of` (caller arg → callee param), and `returns_to` (callee → assignment LHS) edges built at index time. `flow_between` returns ranked dataflow paths between two symbol IDs; `taint_paths` does pattern-driven source→sink sweeps for security audits - **3 MCP prompts** — `pre_commit`, `orientation`, `safe_to_change` for guided workflows +- **LLM features (optional)** — opt-in `ask` research agent + LLM-assisted `search_symbols` ranking, behind a pluggable provider (`local` llama.cpp / Anthropic / OpenAI / Ollama). Off by default; the HTTP providers need no native dependencies. 
See [LLM Features](#llm-features-optional) - **Two-tier config** — global config (`~/.config/gortex/config.yaml`) for projects and repo lists, per-repo `.gortex.yaml` for guards, excludes, and local overrides - **Guard rules** — project-specific constraints (co-change, boundary) enforced via `check_guards` - **Watch mode** — surgical graph updates on file change across all tracked repos, live sync with agents @@ -565,6 +566,60 @@ go build -tags embeddings_onnx ./cmd/gortex/ # needs: brew install onnxruntime go build -tags embeddings_gomlx ./cmd/gortex/ # auto-downloads XLA plugin ``` +## LLM Features (optional) + +Gortex can delegate code-intelligence work to an LLM. Two features, both **off by default** and gated on configuring a provider: + +- **`ask` MCP tool** — a research agent that drives Gortex's own tools (search, callers, contracts, dependencies) to answer an open-ended question and returns a synthesized answer, instead of the calling agent issuing many tool calls itself. `chain: true` traces cross-system call chains. +- **`search_symbols` `assist` arg** — LLM-assisted ranking on `search_symbols`: `auto` (engage on natural-language queries only), `on`, `off`, `deep` (adds a body-grounded verification pass that reads candidate code + callers and honestly drops irrelevant matches). + +### Providers + +The backend is chosen by the `llm.provider` key. The three HTTP providers are pure Go — available in any build; only `local` needs a `-tags llama` build (it embeds llama.cpp). + +| `llm.provider` | Backend | Needs | +|----------------|---------|-------| +| `local` | in-process llama.cpp | a `-tags llama` build + a `.gguf` model file | +| `anthropic` | Anthropic Messages API | `ANTHROPIC_API_KEY` | +| `openai` | OpenAI Chat Completions | `OPENAI_API_KEY` | +| `ollama` | Ollama daemon | a running Ollama + a pulled model | + +### Configuration + +The `llm:` block goes in `~/.config/gortex/config.yaml` or a per-repo `.gortex.yaml` (repo-local wins per field, global fills the rest). Configure only the provider you use: + +```yaml +# ~/.config/gortex/config.yaml (or per-repo .gortex.yaml) +llm: + provider: local # local | anthropic | openai | ollama + max_steps: 16 # agent tool-loop cap (provider-agnostic) + + local: # provider: local — requires a `-tags llama` build + model: ~/models/qwen2.5-coder-7b-instruct-q4_k_m.gguf + ctx: 4096 # context window in tokens + gpu_layers: 999 # layers to offload to GPU (0 = CPU-only) + template: chatml # chatml | llama3 + + anthropic: # provider: anthropic + model: claude-sonnet-4-6 + api_key_env: ANTHROPIC_API_KEY # env var holding the key (this is the default) + # base_url: https://api.anthropic.com + + openai: # provider: openai + model: gpt-4o + api_key_env: OPENAI_API_KEY + + ollama: # provider: ollama + model: qwen2.5-coder:7b + host: http://localhost:11434 +``` + +Env overrides: `GORTEX_LLM_PROVIDER`, `GORTEX_LLM_MODEL` (targets the active provider's model), `GORTEX_LLM_MAX_STEPS`. API keys are read from the env var named by `api_key_env` — never stored in the config file. + +If the active provider can't be constructed (missing model or API key, or `local` without a `-tags llama` build), the daemon logs a warning and the LLM features stay absent — the rest of Gortex is unaffected. If the `ask` tool isn't in `tools/list`, no provider is configured. + +The `assist` prompts are tiered automatically — terser for hosted frontier models, rule-heavy for small local ones. 
`deep` mode in particular benefits from a 7B-class or hosted model; small local models are unreliable on its disambiguation cases. + ## Token Savings Gortex tracks how many tokens it saves compared to naive file reads — per-call, per-session, and cumulative across restarts: From 8b752faa9fae4dccc6e8066610fd36988a0b9aa9 Mon Sep 17 00:00:00 2001 From: Andrey Kumanyaev Date: Thu, 14 May 2026 19:01:24 +0200 Subject: [PATCH 6/6] Fix linter issues --- internal/llm/provider/anthropic/anthropic.go | 2 +- internal/llm/provider/anthropic/anthropic_test.go | 8 ++++---- internal/llm/provider/ollama/ollama.go | 2 +- internal/llm/provider/ollama/ollama_test.go | 10 +++++----- internal/llm/provider/openai/openai.go | 2 +- internal/llm/provider/openai/openai_test.go | 8 ++++---- internal/llm/provider/provider_test.go | 6 +++--- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/internal/llm/provider/anthropic/anthropic.go b/internal/llm/provider/anthropic/anthropic.go index 60d93f9..330398c 100644 --- a/internal/llm/provider/anthropic/anthropic.go +++ b/internal/llm/provider/anthropic/anthropic.go @@ -157,7 +157,7 @@ func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm if err != nil { return llm.CompletionResponse{}, fmt.Errorf("anthropic: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() payload, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) if err != nil { return llm.CompletionResponse{}, fmt.Errorf("anthropic: read response: %w", err) diff --git a/internal/llm/provider/anthropic/anthropic_test.go b/internal/llm/provider/anthropic/anthropic_test.go index cbe29a0..cfd4bab 100644 --- a/internal/llm/provider/anthropic/anthropic_test.go +++ b/internal/llm/provider/anthropic/anthropic_test.go @@ -49,7 +49,7 @@ func TestComplete_StructuredUsesForcedTool(t *testing.T) { if err != nil { t.Fatal(err) } - defer p.Close() + defer func() { _ = p.Close() }() resp, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{ @@ -90,7 +90,7 @@ func TestComplete_FreeformNoTools(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "k") p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() resp, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, @@ -118,7 +118,7 @@ func TestComplete_ToolResultBecomesUserTurn(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "k") p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() _, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{ @@ -150,7 +150,7 @@ func TestComplete_APIError(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "k") p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() _, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, diff --git a/internal/llm/provider/ollama/ollama.go b/internal/llm/provider/ollama/ollama.go index aa004c9..29de005 100644 --- a/internal/llm/provider/ollama/ollama.go +++ b/internal/llm/provider/ollama/ollama.go @@ -113,7 +113,7 @@ func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm if err != nil { return llm.CompletionResponse{}, fmt.Errorf("ollama: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() payload, err := 
io.ReadAll(io.LimitReader(resp.Body, 4<<20)) if err != nil { return llm.CompletionResponse{}, fmt.Errorf("ollama: read response: %w", err) diff --git a/internal/llm/provider/ollama/ollama_test.go b/internal/llm/provider/ollama/ollama_test.go index e7a6b62..4ecb071 100644 --- a/internal/llm/provider/ollama/ollama_test.go +++ b/internal/llm/provider/ollama/ollama_test.go @@ -22,7 +22,7 @@ func TestNew_DefaultsHost(t *testing.T) { if err != nil { t.Fatal(err) } - defer p.Close() + defer func() { _ = p.Close() }() if p.Name() != "ollama" { t.Errorf("Name()=%q", p.Name()) } @@ -44,7 +44,7 @@ func TestComplete_StructuredSendsFormatSchema(t *testing.T) { if err != nil { t.Fatal(err) } - defer p.Close() + defer func() { _ = p.Close() }() resp, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "verify"}}, @@ -79,7 +79,7 @@ func TestComplete_FreeformNoFormat(t *testing.T) { defer srv.Close() p, _ := New(llm.OllamaConfig{Model: "m", Host: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() resp, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, @@ -104,7 +104,7 @@ func TestComplete_APIError(t *testing.T) { defer srv.Close() p, _ := New(llm.OllamaConfig{Model: "m", Host: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() if _, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, @@ -121,7 +121,7 @@ func TestComplete_InlineErrorField(t *testing.T) { defer srv.Close() p, _ := New(llm.OllamaConfig{Model: "m", Host: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() if _, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, diff --git a/internal/llm/provider/openai/openai.go b/internal/llm/provider/openai/openai.go index 2f55887..94a26ab 100644 --- a/internal/llm/provider/openai/openai.go +++ b/internal/llm/provider/openai/openai.go @@ -120,7 +120,7 @@ func (p *Provider) Complete(ctx context.Context, req llm.CompletionRequest) (llm if err != nil { return llm.CompletionResponse{}, fmt.Errorf("openai: request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() payload, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) if err != nil { return llm.CompletionResponse{}, fmt.Errorf("openai: read response: %w", err) diff --git a/internal/llm/provider/openai/openai_test.go b/internal/llm/provider/openai/openai_test.go index fcf36b9..3128311 100644 --- a/internal/llm/provider/openai/openai_test.go +++ b/internal/llm/provider/openai/openai_test.go @@ -38,7 +38,7 @@ func TestComplete_StructuredUsesJSONSchema(t *testing.T) { if err != nil { t.Fatal(err) } - defer p.Close() + defer func() { _ = p.Close() }() resp, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "Query: auth"}}, @@ -75,7 +75,7 @@ func TestComplete_ToolCallShapeIsNonStrict(t *testing.T) { t.Setenv("OPENAI_API_KEY", "k") p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() _, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "go"}}, @@ -103,7 +103,7 @@ func TestComplete_FreeformNoResponseFormat(t *testing.T) { t.Setenv("OPENAI_API_KEY", "k") p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) - defer 
p.Close() + defer func() { _ = p.Close() }() resp, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, @@ -129,7 +129,7 @@ func TestComplete_APIError(t *testing.T) { t.Setenv("OPENAI_API_KEY", "k") p, _ := New(llm.RemoteConfig{Model: "m", BaseURL: srv.URL}) - defer p.Close() + defer func() { _ = p.Close() }() if _, err := p.Complete(context.Background(), llm.CompletionRequest{ Messages: []llm.Message{{Role: llm.RoleUser, Content: "hi"}}, diff --git a/internal/llm/provider/provider_test.go b/internal/llm/provider/provider_test.go index 4055f1a..24e5194 100644 --- a/internal/llm/provider/provider_test.go +++ b/internal/llm/provider/provider_test.go @@ -25,7 +25,7 @@ func TestNew_AnthropicOK(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - defer p.Close() + defer func() { _ = p.Close() }() if p.Name() != "anthropic" { t.Errorf("Name()=%q want anthropic", p.Name()) } @@ -37,7 +37,7 @@ func TestNew_OpenAIOK(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - defer p.Close() + defer func() { _ = p.Close() }() if p.Name() != "openai" { t.Errorf("Name()=%q want openai", p.Name()) } @@ -55,7 +55,7 @@ func TestNew_OllamaOK(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - defer p.Close() + defer func() { _ = p.Close() }() if p.Name() != "ollama" { t.Errorf("Name()=%q want ollama", p.Name()) }
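
For orientation, a minimal sketch of the new provider-based wiring end to end, assembled only from pieces shown in this series: the `llm.Config` shape from `assist_e2e_test.go`, `NewService`, `Enabled`/`ProviderErr`, and `RunAgent`. The model path, the example question, and the `Example_providerWiring` name are placeholders, and `llm.MockBackend{}` stands in for the daemon's real `InProcessBackend`; since the `internal/llm` packages are internal, a snippet like this would only compile from inside the gortex module (e.g. as an external test next to package svc).

```go
package svc_test

import (
	"context"
	"fmt"

	"github.com/zzet/gortex/internal/llm"
	"github.com/zzet/gortex/internal/llm/svc"
)

func Example_providerWiring() {
	// Same config shape as assist_e2e_test.go: "local" requires a
	// `-tags llama` build; "anthropic" / "openai" / "ollama" would use
	// the corresponding HTTP provider instead. The model path is a
	// placeholder.
	cfg := llm.Config{
		Provider: "local",
		MaxSteps: 16,
		Local: llm.LocalConfig{
			Model:    "/path/to/model.gguf",
			Template: "chatml",
			Ctx:      4096,
		},
	}.ApplyDefaults()

	// llm.MockBackend stands in for the daemon's InProcessBackend
	// wired to a *query.Engine.
	s := svc.NewService(cfg, llm.MockBackend{})
	defer func() { _ = s.Close() }()

	if !s.Enabled() {
		// Provider construction failure (missing model or API key, or a
		// non-llama build with provider "local") leaves the service
		// disabled rather than failing startup.
		fmt.Println("llm disabled:", s.ProviderErr())
		return
	}

	ans, err := s.RunAgent(context.Background(), llm.RunAgentOptions{
		Question: "who calls NewServer in the mcp package?",
	})
	if err != nil {
		fmt.Println("agent error:", err)
		return
	}
	fmt.Println(ans.Answer)
}
```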