From a9571d6aaeafb6a1feb0c3e35be9389e5f5e7c75 Mon Sep 17 00:00:00 2001 From: Vadim Comanescu Date: Fri, 27 Feb 2026 16:21:57 +0100 Subject: [PATCH 1/4] llm/preflight: harden codex app-server provider wiring (cherry picked from commit c0110c630859508ce3390cf976cd11e8a06fbb6b) --- internal/agent/profile.go | 23 + internal/agent/profile_registry.go | 8 +- internal/agent/profile_test.go | 44 + .../engine/api_client_from_runtime.go | 13 +- .../engine/api_client_from_runtime_test.go | 71 ++ .../attractor/engine/provider_preflight.go | 49 + .../engine/provider_preflight_test.go | 174 +++ internal/llm/providers/anthropic/adapter.go | 16 +- .../llm/providers/anthropic/adapter_test.go | 39 + .../llm/providers/codexappserver/adapter.go | 537 +++++++++ .../codexappserver/adapter_helpers_test.go | 280 +++++ .../providers/codexappserver/adapter_test.go | 334 +++++ .../providers/codexappserver/env_config.go | 103 ++ .../codexappserver/env_config_test.go | 109 ++ .../codexappserver/protocol_types.go | 39 + .../codexappserver/request_translator.go | 713 +++++++++++ .../request_translator_controls_test.go | 109 ++ .../codexappserver/request_translator_test.go | 237 ++++ .../codexappserver/response_translator.go | 507 ++++++++ .../response_translator_test.go | 146 +++ .../codexappserver/stream_translator.go | 495 ++++++++ .../codexappserver/stream_translator_test.go | 150 +++ .../codexappserver/translator_utils.go | 78 ++ .../llm/providers/codexappserver/transport.go | 1070 +++++++++++++++++ .../codexappserver/transport_helpers_test.go | 579 +++++++++ .../codexappserver/transport_timeout_test.go | 82 ++ internal/llm/providers/codexappserver/util.go | 124 ++ .../llm/providers/codexappserver/util_test.go | 121 ++ internal/llm/providers/google/adapter.go | 18 +- internal/llm/providers/google/adapter_test.go | 34 + internal/llm/providers/openai/adapter.go | 6 +- internal/llm/providers/openai/adapter_test.go | 34 + .../llm/providers/openaicompat/adapter.go | 33 +- 
.../providers/openaicompat/adapter_test.go | 40 +- internal/llm/rate_limit.go | 164 +++ internal/llm/rate_limit_test.go | 98 ++ internal/llm/stream.go | 10 +- internal/llm/types.go | 3 + internal/llmclient/env.go | 1 + internal/llmclient/env_test.go | 22 + internal/providerspec/builtin.go | 10 + internal/providerspec/spec.go | 1 + internal/providerspec/spec_test.go | 27 +- 43 files changed, 6716 insertions(+), 35 deletions(-) create mode 100644 internal/llm/providers/codexappserver/adapter.go create mode 100644 internal/llm/providers/codexappserver/adapter_helpers_test.go create mode 100644 internal/llm/providers/codexappserver/adapter_test.go create mode 100644 internal/llm/providers/codexappserver/env_config.go create mode 100644 internal/llm/providers/codexappserver/env_config_test.go create mode 100644 internal/llm/providers/codexappserver/protocol_types.go create mode 100644 internal/llm/providers/codexappserver/request_translator.go create mode 100644 internal/llm/providers/codexappserver/request_translator_controls_test.go create mode 100644 internal/llm/providers/codexappserver/request_translator_test.go create mode 100644 internal/llm/providers/codexappserver/response_translator.go create mode 100644 internal/llm/providers/codexappserver/response_translator_test.go create mode 100644 internal/llm/providers/codexappserver/stream_translator.go create mode 100644 internal/llm/providers/codexappserver/stream_translator_test.go create mode 100644 internal/llm/providers/codexappserver/translator_utils.go create mode 100644 internal/llm/providers/codexappserver/transport.go create mode 100644 internal/llm/providers/codexappserver/transport_helpers_test.go create mode 100644 internal/llm/providers/codexappserver/transport_timeout_test.go create mode 100644 internal/llm/providers/codexappserver/util.go create mode 100644 internal/llm/providers/codexappserver/util_test.go create mode 100644 internal/llm/rate_limit.go create mode 100644 
internal/llm/rate_limit_test.go diff --git a/internal/agent/profile.go b/internal/agent/profile.go index 0ccb2b1c..1724b39f 100644 --- a/internal/agent/profile.go +++ b/internal/agent/profile.go @@ -139,6 +139,29 @@ func NewOpenAIProfile(model string) ProviderProfile { } } +func NewCodexAppServerProfile(model string) ProviderProfile { + return &baseProfile{ + id: "codex-app-server", + model: strings.TrimSpace(model), + parallel: true, + contextWindow: 1_047_576, + basePrompt: openAIProfileBasePrompt, + docFiles: []string{"AGENTS.md", ".codex/instructions.md"}, + toolDefs: []llm.ToolDefinition{ + defReadFile(), + defApplyPatch(), + defWriteFile(), + defShell(), + defGrep(), + defGlob(), + defSpawnAgent(), + defSendInput(), + defWait(), + defCloseAgent(), + }, + } +} + func NewAnthropicProfile(model string) ProviderProfile { return &baseProfile{ id: "anthropic", diff --git a/internal/agent/profile_registry.go b/internal/agent/profile_registry.go index fbf92d8c..aa265c82 100644 --- a/internal/agent/profile_registry.go +++ b/internal/agent/profile_registry.go @@ -9,9 +9,11 @@ import ( var ( profileFactoriesMu sync.RWMutex profileFactories = map[string]func(string) ProviderProfile{ - "openai": NewOpenAIProfile, - "anthropic": NewAnthropicProfile, - "google": NewGeminiProfile, + "openai": NewOpenAIProfile, + "anthropic": NewAnthropicProfile, + "google": NewGeminiProfile, + "codex-app-server": NewCodexAppServerProfile, + "codex": NewCodexAppServerProfile, } ) diff --git a/internal/agent/profile_test.go b/internal/agent/profile_test.go index 0eed555a..c33ab9c8 100644 --- a/internal/agent/profile_test.go +++ b/internal/agent/profile_test.go @@ -41,6 +41,19 @@ func TestProviderProfiles_ToolsetsAndDocSelection(t *testing.T) { assertHasTool(t, gemini, "read_many_files") assertHasTool(t, gemini, "list_dir") assertMissingTool(t, gemini, "apply_patch") + + codex := NewCodexAppServerProfile("gpt-5-codex") + if codex.ID() != "codex-app-server" { + t.Fatalf("codex id: %q", 
codex.ID()) + } + if !codex.SupportsParallelToolCalls() { + t.Fatalf("codex profile should support parallel tool calls") + } + if codex.ContextWindowSize() != 1_047_576 { + t.Fatalf("codex context window: got %d want %d", codex.ContextWindowSize(), 1_047_576) + } + assertHasTool(t, codex, "apply_patch") + assertMissingTool(t, codex, "edit_file") } func TestProviderProfiles_ToolLists_MatchSpec(t *testing.T) { @@ -92,6 +105,21 @@ func TestProviderProfiles_ToolLists_MatchSpec(t *testing.T) { "close_agent", }) }) + t.Run("codex-app-server", func(t *testing.T) { + p := NewCodexAppServerProfile("gpt-5-codex") + assertToolListExact(t, p, []string{ + "read_file", + "apply_patch", + "write_file", + "shell", + "grep", + "glob", + "spawn_agent", + "send_input", + "wait", + "close_agent", + }) + }) } func TestProviderProfiles_BuildSystemPrompt_IncludesProviderSpecificBaseInstructions(t *testing.T) { @@ -181,4 +209,20 @@ func TestNewProfileForFamily_DefaultFamiliesAndRegistration(t *testing.T) { if _, err := NewProfileForFamily("missing-family", "m3"); err == nil { t.Fatalf("expected unsupported family error") } + + codex, err := NewProfileForFamily("codex-app-server", "gpt-5-codex") + if err != nil { + t.Fatalf("NewProfileForFamily(codex-app-server): %v", err) + } + if codex.ID() != "codex-app-server" { + t.Fatalf("codex profile id=%q want codex-app-server", codex.ID()) + } + + codexAlias, err := NewProfileForFamily("codex", "gpt-5-codex") + if err != nil { + t.Fatalf("NewProfileForFamily(codex): %v", err) + } + if codexAlias.ID() != "codex-app-server" { + t.Fatalf("codex alias profile id=%q want codex-app-server", codexAlias.ID()) + } } diff --git a/internal/attractor/engine/api_client_from_runtime.go b/internal/attractor/engine/api_client_from_runtime.go index 6d3fadc0..854d9412 100644 --- a/internal/attractor/engine/api_client_from_runtime.go +++ b/internal/attractor/engine/api_client_from_runtime.go @@ -8,6 +8,7 @@ import ( "github.com/danshapiro/kilroy/internal/llm" 
"github.com/danshapiro/kilroy/internal/llm/providers/anthropic" + "github.com/danshapiro/kilroy/internal/llm/providers/codexappserver" "github.com/danshapiro/kilroy/internal/llm/providers/google" "github.com/danshapiro/kilroy/internal/llm/providers/openai" "github.com/danshapiro/kilroy/internal/llm/providers/openaicompat" @@ -21,7 +22,15 @@ func newAPIClientFromProviderRuntimes(runtimes map[string]ProviderRuntime) (*llm if rt.Backend != BackendAPI { continue } - apiKey := strings.TrimSpace(os.Getenv(rt.API.DefaultAPIKeyEnv)) + apiKeyEnv := strings.TrimSpace(rt.API.DefaultAPIKeyEnv) + if rt.API.Protocol == providerspec.ProtocolCodexAppServer && apiKeyEnv == "" { + c.Register(codexappserver.NewAdapter(codexappserver.AdapterOptions{Provider: key})) + continue + } + if apiKeyEnv == "" { + continue + } + apiKey := strings.TrimSpace(os.Getenv(apiKeyEnv)) if apiKey == "" { continue } @@ -41,6 +50,8 @@ func newAPIClientFromProviderRuntimes(runtimes map[string]ProviderRuntime) (*llm OptionsKey: rt.API.ProviderOptionsKey, ExtraHeaders: rt.APIHeaders(), })) + case providerspec.ProtocolCodexAppServer: + c.Register(codexappserver.NewAdapter(codexappserver.AdapterOptions{Provider: key})) default: return nil, fmt.Errorf("unsupported api protocol %q for provider %s", rt.API.Protocol, key) } diff --git a/internal/attractor/engine/api_client_from_runtime_test.go b/internal/attractor/engine/api_client_from_runtime_test.go index 48e36838..3b5f1071 100644 --- a/internal/attractor/engine/api_client_from_runtime_test.go +++ b/internal/attractor/engine/api_client_from_runtime_test.go @@ -114,6 +114,77 @@ func TestNewAPIClientFromProviderRuntimes_RegistersMinimaxViaOpenAICompat(t *tes } } +func TestNewAPIClientFromProviderRuntimes_RegistersCodexAppServerProtocol(t *testing.T) { + runtimes := map[string]ProviderRuntime{ + "codex-app-server": { + Key: "codex-app-server", + Backend: BackendAPI, + API: providerspec.APISpec{ + Protocol: providerspec.ProtocolCodexAppServer, + DefaultAPIKeyEnv: 
"", + }, + }, + } + c, err := newAPIClientFromProviderRuntimes(runtimes) + if err != nil { + t.Fatalf("newAPIClientFromProviderRuntimes: %v", err) + } + if len(c.ProviderNames()) != 1 || c.ProviderNames()[0] != "codex-app-server" { + t.Fatalf("expected codex-app-server adapter, got %v", c.ProviderNames()) + } +} + +func TestNewAPIClientFromProviderRuntimes_CodexAppServerHonorsExplicitAPIKeyEnv(t *testing.T) { + runtimes := map[string]ProviderRuntime{ + "codex-app-server": { + Key: "codex-app-server", + Backend: BackendAPI, + API: providerspec.APISpec{ + Protocol: providerspec.ProtocolCodexAppServer, + DefaultAPIKeyEnv: "CODEX_APP_SERVER_TOKEN", + }, + }, + } + + t.Setenv("CODEX_APP_SERVER_TOKEN", "") + c, err := newAPIClientFromProviderRuntimes(runtimes) + if err != nil { + t.Fatalf("newAPIClientFromProviderRuntimes: %v", err) + } + if len(c.ProviderNames()) != 0 { + t.Fatalf("expected no adapters when explicit codex api key env is unset, got %v", c.ProviderNames()) + } + + t.Setenv("CODEX_APP_SERVER_TOKEN", "present") + c, err = newAPIClientFromProviderRuntimes(runtimes) + if err != nil { + t.Fatalf("newAPIClientFromProviderRuntimes: %v", err) + } + if len(c.ProviderNames()) != 1 || c.ProviderNames()[0] != "codex-app-server" { + t.Fatalf("expected codex-app-server adapter when explicit env is set, got %v", c.ProviderNames()) + } +} + +func TestNewAPIClientFromProviderRuntimes_CodexAppServerPreservesCustomProviderKey(t *testing.T) { + runtimes := map[string]ProviderRuntime{ + "my-codex-provider": { + Key: "my-codex-provider", + Backend: BackendAPI, + API: providerspec.APISpec{ + Protocol: providerspec.ProtocolCodexAppServer, + DefaultAPIKeyEnv: "", + }, + }, + } + c, err := newAPIClientFromProviderRuntimes(runtimes) + if err != nil { + t.Fatalf("newAPIClientFromProviderRuntimes: %v", err) + } + if len(c.ProviderNames()) != 1 || c.ProviderNames()[0] != "my-codex-provider" { + t.Fatalf("expected custom codex provider key to be preserved, got %v", c.ProviderNames()) + 
} +} + func TestResolveBuiltInBaseURLOverride_MinimaxUsesEnvOverride(t *testing.T) { t.Setenv("MINIMAX_BASE_URL", "http://127.0.0.1:8888") got := resolveBuiltInBaseURLOverride("minimax", "https://api.minimax.io") diff --git a/internal/attractor/engine/provider_preflight.go b/internal/attractor/engine/provider_preflight.go index 54b0c622..7c94d613 100644 --- a/internal/attractor/engine/provider_preflight.go +++ b/internal/attractor/engine/provider_preflight.go @@ -31,6 +31,9 @@ const ( defaultPreflightAPIPromptProbeRetries = 2 defaultPreflightAPIPromptProbeBaseDelay = 500 * time.Millisecond defaultPreflightAPIPromptProbeMaxDelay = 5 * time.Second + + codexAppServerCommandEnv = "CODEX_APP_SERVER_COMMAND" + codexAppServerDefaultCommand = "codex" ) type providerPreflightReport struct { @@ -179,7 +182,53 @@ func runProviderAPIPreflight(ctx context.Context, g *model.Graph, runtimes map[s }) return fmt.Errorf("preflight: provider %s missing runtime definition", provider) } + if rt.API.Protocol == providerspec.ProtocolCodexAppServer { + command := strings.TrimSpace(os.Getenv(codexAppServerCommandEnv)) + source := "default" + if command == "" { + command = codexAppServerDefaultCommand + } else { + source = "env" + } + resolvedPath, lookErr := exec.LookPath(command) + if lookErr != nil { + report.addCheck(providerPreflightCheck{ + Name: "provider_api_presence", + Provider: provider, + Status: preflightStatusFail, + Message: fmt.Sprintf("codex app server command %q is not available: %v", command, lookErr), + Details: map[string]any{ + "command": command, + "source": source, + }, + }) + return fmt.Errorf("preflight: provider %s codex app server command %q is not available: %w", provider, command, lookErr) + } + report.addCheck(providerPreflightCheck{ + Name: "provider_api_presence", + Provider: provider, + Status: preflightStatusPass, + Message: fmt.Sprintf("codex app server command %q is available", command), + Details: map[string]any{ + "command": command, + "resolved_path": 
resolvedPath, + "source": source, + }, + }) + } keyEnv := strings.TrimSpace(rt.API.DefaultAPIKeyEnv) + if rt.API.Protocol == providerspec.ProtocolCodexAppServer && keyEnv == "" { + report.addCheck(providerPreflightCheck{ + Name: "provider_api_credentials", + Provider: provider, + Status: preflightStatusPass, + Message: "api key env is not required for codex app server", + Details: map[string]any{ + "protocol": string(rt.API.Protocol), + }, + }) + continue + } if keyEnv == "" { report.addCheck(providerPreflightCheck{ Name: "provider_api_credentials", diff --git a/internal/attractor/engine/provider_preflight_test.go b/internal/attractor/engine/provider_preflight_test.go index eba16cd9..12d7943d 100644 --- a/internal/attractor/engine/provider_preflight_test.go +++ b/internal/attractor/engine/provider_preflight_test.go @@ -637,6 +637,164 @@ func TestPreflightReport_IncludesCLIProfileAndSource(t *testing.T) { } } +func TestRunWithConfig_PreflightCodexAppServer_DoesNotRequireEnvGate(t *testing.T) { + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") + prepareCodexAppServerCommandForPreflight(t) + + repo := initTestRepo(t) + catalog := writeCatalogForPreflight(t, `{ + "data": [ + {"id": "codex-app-server/gpt-5.3-codex"} + ] +}`) + + cfg := testPreflightConfigForProviders(repo, catalog, map[string]BackendKind{ + "codex-app-server": BackendAPI, + }) + cfg.LLM.CLIProfile = "real" + dot := singleProviderDot("codex-app-server", "gpt-5.3-codex") + + logsRoot := t.TempDir() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, err := RunWithConfig(ctx, dot, cfg, RunOptions{RunID: "preflight-codex-app-server-no-env-gate", LogsRoot: logsRoot}) + if err == nil { + t.Fatalf("expected downstream cxdb error, got nil") + } + if strings.Contains(err.Error(), "preflight:") { + t.Fatalf("unexpected preflight failure: %v", err) + } + + report := mustReadPreflightReport(t, logsRoot) + foundAPIPresencePass := false + foundCredentialsPass := false + 
for _, check := range report.Checks { + if check.Name == "provider_api_presence" && check.Provider == "codex-app-server" { + foundAPIPresencePass = true + if check.Status != "pass" { + t.Fatalf("provider_api_presence.status=%q want pass", check.Status) + } + } + if check.Name != "provider_api_credentials" || check.Provider != "codex-app-server" { + continue + } + foundCredentialsPass = true + if check.Status != "pass" { + t.Fatalf("provider_api_credentials.status=%q want pass", check.Status) + } + if !strings.Contains(check.Message, "not required") { + t.Fatalf("provider_api_credentials.message=%q want mention of non-required api key", check.Message) + } + } + if !foundAPIPresencePass { + t.Fatalf("expected provider_api_presence pass check for codex-app-server") + } + if !foundCredentialsPass { + t.Fatalf("expected provider_api_credentials pass check for codex-app-server") + } +} + +func TestRunWithConfig_PreflightCodexAppServer_ExplicitAPIKeyEnvIsEnforced(t *testing.T) { + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") + t.Setenv("CODEX_APP_SERVER_TOKEN", "") + prepareCodexAppServerCommandForPreflight(t) + + repo := initTestRepo(t) + catalog := writeCatalogForPreflight(t, `{ + "data": [ + {"id": "codex-app-server/gpt-5.3-codex"} + ] +}`) + + cfg := testPreflightConfigForProviders(repo, catalog, map[string]BackendKind{ + "codex-app-server": BackendAPI, + }) + cfg.LLM.CLIProfile = "real" + cfg.LLM.Providers["codex-app-server"] = ProviderConfig{ + Backend: BackendAPI, + API: ProviderAPIConfig{ + APIKeyEnv: "CODEX_APP_SERVER_TOKEN", + }, + } + dot := singleProviderDot("codex-app-server", "gpt-5.3-codex") + + logsRoot := t.TempDir() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, err := RunWithConfig(ctx, dot, cfg, RunOptions{RunID: "preflight-codex-app-server-explicit-env", LogsRoot: logsRoot}) + if err == nil { + t.Fatalf("expected preflight failure, got nil") + } + if !strings.Contains(err.Error(), "preflight: 
provider codex-app-server missing api key env CODEX_APP_SERVER_TOKEN") { + t.Fatalf("unexpected error: %v", err) + } + + report := mustReadPreflightReport(t, logsRoot) + foundCredentialsFail := false + for _, check := range report.Checks { + if check.Name != "provider_api_credentials" || check.Provider != "codex-app-server" { + continue + } + foundCredentialsFail = true + if check.Status != "fail" { + t.Fatalf("provider_api_credentials.status=%q want fail", check.Status) + } + if !strings.Contains(check.Message, "CODEX_APP_SERVER_TOKEN") { + t.Fatalf("provider_api_credentials.message=%q want mention of CODEX_APP_SERVER_TOKEN", check.Message) + } + } + if !foundCredentialsFail { + t.Fatalf("expected provider_api_credentials fail check for codex-app-server") + } +} + +func TestRunWithConfig_PreflightCodexAppServer_FailsWhenCommandUnavailable(t *testing.T) { + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") + t.Setenv("CODEX_APP_SERVER_COMMAND", "codex-missing-preflight") + + repo := initTestRepo(t) + catalog := writeCatalogForPreflight(t, `{ + "data": [ + {"id": "codex-app-server/gpt-5.3-codex"} + ] +}`) + + cfg := testPreflightConfigForProviders(repo, catalog, map[string]BackendKind{ + "codex-app-server": BackendAPI, + }) + cfg.LLM.CLIProfile = "real" + dot := singleProviderDot("codex-app-server", "gpt-5.3-codex") + + logsRoot := t.TempDir() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, err := RunWithConfig(ctx, dot, cfg, RunOptions{RunID: "preflight-codex-app-server-missing-command", LogsRoot: logsRoot}) + if err == nil { + t.Fatalf("expected preflight failure, got nil") + } + if !strings.Contains(err.Error(), `preflight: provider codex-app-server codex app server command "codex-missing-preflight" is not available`) { + t.Fatalf("unexpected error: %v", err) + } + + report := mustReadPreflightReport(t, logsRoot) + foundPresenceFail := false + for _, check := range report.Checks { + if check.Name != 
"provider_api_presence" || check.Provider != "codex-app-server" { + continue + } + foundPresenceFail = true + if check.Status != "fail" { + t.Fatalf("provider_api_presence.status=%q want fail", check.Status) + } + if !strings.Contains(check.Message, "codex-missing-preflight") { + t.Fatalf("provider_api_presence.message=%q want mention of codex-missing-preflight", check.Message) + } + } + if !foundPresenceFail { + t.Fatalf("expected provider_api_presence fail check for codex-app-server") + } +} + func TestRunWithConfig_PreflightPromptProbe_UsesOnlyAPIProvidersInGraph(t *testing.T) { repo := initTestRepo(t) catalog := writeCatalogForPreflight(t, `{ @@ -1943,3 +2101,19 @@ func TestUsedAPIProviders_ExcludesUncredentialedFailoverTarget(t *testing.T) { t.Fatalf("want [anthropic google] (both credentialed), got %v", got) } } + +func prepareCodexAppServerCommandForPreflight(t *testing.T) { + t.Helper() + binDir := t.TempDir() + commandName := "codex-preflight" + commandPath := filepath.Join(binDir, commandName) + script := `#!/usr/bin/env bash +set -euo pipefail +exit 0 +` + if err := os.WriteFile(commandPath, []byte(script), 0o755); err != nil { + t.Fatalf("write codex preflight helper binary: %v", err) + } + t.Setenv("PATH", binDir+":"+os.Getenv("PATH")) + t.Setenv("CODEX_APP_SERVER_COMMAND", commandName) +} diff --git a/internal/llm/providers/anthropic/adapter.go b/internal/llm/providers/anthropic/adapter.go index fc6784d3..f3519373 100644 --- a/internal/llm/providers/anthropic/adapter.go +++ b/internal/llm/providers/anthropic/adapter.go @@ -208,7 +208,9 @@ func (a *Adapter) Complete(ctx context.Context, req llm.Request) (llm.Response, return llm.Response{}, llm.ErrorFromHTTPStatus(a.Name(), resp.StatusCode, msg, raw, ra) } - return fromAnthropicResponse(a.Name(), raw, req.Model), nil + out := fromAnthropicResponse(a.Name(), raw, req.Model) + out.RateLimit = llm.ParseRateLimitInfo(resp.Header, time.Now()) + return out, nil } func (a *Adapter) completeViaStream(ctx 
context.Context, req llm.Request) (llm.Response, error) { @@ -388,6 +390,7 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro s := llm.NewChanStream(cancel) s.Send(llm.StreamEvent{Type: llm.StreamEventStreamStart}) + rateLimit := llm.ParseRateLimitInfo(resp.Header, time.Now()) go func() { defer func() { @@ -684,11 +687,12 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro msg := llm.Message{Role: llm.RoleAssistant, Content: parts} r := llm.Response{ - Provider: a.Name(), - Model: req.Model, - Message: msg, - Finish: finish, - Usage: usage, + Provider: a.Name(), + Model: req.Model, + Message: msg, + Finish: finish, + Usage: usage, + RateLimit: rateLimit, } if len(r.ToolCalls()) > 0 { r.Finish = llm.FinishReason{Reason: "tool_calls", Raw: "tool_use"} diff --git a/internal/llm/providers/anthropic/adapter_test.go b/internal/llm/providers/anthropic/adapter_test.go index 6eb47ec9..910bc460 100644 --- a/internal/llm/providers/anthropic/adapter_test.go +++ b/internal/llm/providers/anthropic/adapter_test.go @@ -18,6 +18,28 @@ import ( "github.com/danshapiro/kilroy/internal/llm" ) +func assertRateLimitInfo(t *testing.T, rl *llm.RateLimitInfo) { + t.Helper() + if rl == nil { + t.Fatalf("expected rate limit info, got nil") + } + if rl.RequestsRemaining == nil || *rl.RequestsRemaining != 9 { + t.Fatalf("requests_remaining: %#v", rl.RequestsRemaining) + } + if rl.RequestsLimit == nil || *rl.RequestsLimit != 10 { + t.Fatalf("requests_limit: %#v", rl.RequestsLimit) + } + if rl.TokensRemaining == nil || *rl.TokensRemaining != 90 { + t.Fatalf("tokens_remaining: %#v", rl.TokensRemaining) + } + if rl.TokensLimit == nil || *rl.TokensLimit != 100 { + t.Fatalf("tokens_limit: %#v", rl.TokensLimit) + } + if rl.ResetAt != "2025-01-01T00:00:10Z" { + t.Fatalf("reset_at: %q", rl.ResetAt) + } +} + func TestAdapter_Complete_MapsToMessagesAPI_AndSetsBetaHeaders(t *testing.T) { var gotBody map[string]any gotBeta := "" @@ -33,6 
+55,11 @@ func TestAdapter_Complete_MapsToMessagesAPI_AndSetsBetaHeaders(t *testing.T) { _ = json.Unmarshal(b, &gotBody) w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") _, _ = w.Write([]byte(`{ "id": "msg_1", "model": "claude-test", @@ -68,6 +95,7 @@ func TestAdapter_Complete_MapsToMessagesAPI_AndSetsBetaHeaders(t *testing.T) { if strings.TrimSpace(resp.Text()) != "Hello" { t.Fatalf("resp text: %q", resp.Text()) } + assertRateLimitInfo(t, resp.RateLimit) if gotBeta != "prompt-caching-2024-07-31" { t.Fatalf("anthropic-beta header: %q", gotBeta) } @@ -176,6 +204,11 @@ func TestAdapter_Stream_NormalizesDotsTodashesinModelID(t *testing.T) { gotModel, _ = body["model"].(string) w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") f, _ := w.(http.Flusher) write := func(event, data string) { _, _ = io.WriteString(w, "event: "+event+"\ndata: "+data+"\n\n") @@ -251,6 +284,11 @@ func TestAdapter_Stream_YieldsTextDeltasAndFinish(t *testing.T) { _ = json.Unmarshal(b, &gotBody) w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") f, _ := w.(http.Flusher) write := func(event string, data string) { _, _ = 
io.WriteString(w, "event: "+event+"\n") @@ -297,6 +335,7 @@ func TestAdapter_Stream_YieldsTextDeltasAndFinish(t *testing.T) { if finish == nil || strings.TrimSpace(finish.Text()) != "Hello" { t.Fatalf("finish response: %+v", finish) } + assertRateLimitInfo(t, finish.RateLimit) if gotBody == nil { t.Fatalf("server did not capture request body") } diff --git a/internal/llm/providers/codexappserver/adapter.go b/internal/llm/providers/codexappserver/adapter.go new file mode 100644 index 00000000..b45edc8d --- /dev/null +++ b/internal/llm/providers/codexappserver/adapter.go @@ -0,0 +1,537 @@ +package codexappserver + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + "sync" + "time" + + "github.com/danshapiro/kilroy/internal/llm" +) + +type codexTransport interface { + Initialize(ctx context.Context) error + Close() error + Complete(ctx context.Context, payload map[string]any) (map[string]any, error) + Stream(ctx context.Context, payload map[string]any) (*NotificationStream, error) + ListModels(ctx context.Context, params map[string]any) (modelListResponse, error) +} + +type AdapterOptions struct { + Provider string + Transport codexTransport + TransportOptions TransportOptions + TranslateRequest func(request llm.Request, streaming bool) (translateRequestResult, error) + TranslateResponse func(body map[string]any) (llm.Response, error) + TranslateStream func(events <-chan map[string]any) <-chan llm.StreamEvent +} + +type Adapter struct { + provider string + transportProvided codexTransport + transportOptions TransportOptions + + translateRequestFn func(request llm.Request, streaming bool) (translateRequestResult, error) + translateResponseFn func(body map[string]any) (llm.Response, error) + translateStreamFn func(events <-chan map[string]any) <-chan llm.StreamEvent + + transportMu sync.Mutex + transport codexTransport + + modelListMu sync.Mutex + modelList *modelListResponse +} + +func init() { + llm.RegisterEnvAdapterFactory(func() 
(llm.ProviderAdapter, bool, error) { + opts, ok := transportOptionsFromEnv() + if !ok { + return nil, false, nil + } + return NewAdapter(AdapterOptions{TransportOptions: opts}), true, nil + }) +} + +func NewAdapter(options AdapterOptions) *Adapter { + provider := strings.TrimSpace(options.Provider) + if provider == "" { + provider = providerName + } + return &Adapter{ + provider: provider, + transportProvided: options.Transport, + transportOptions: options.TransportOptions, + translateRequestFn: func(request llm.Request, streaming bool) (translateRequestResult, error) { + if options.TranslateRequest != nil { + return options.TranslateRequest(request, streaming) + } + return translateRequest(request, streaming) + }, + translateResponseFn: func(body map[string]any) (llm.Response, error) { + if options.TranslateResponse != nil { + return options.TranslateResponse(body) + } + return translateResponse(body) + }, + translateStreamFn: func(events <-chan map[string]any) <-chan llm.StreamEvent { + if options.TranslateStream != nil { + return options.TranslateStream(events) + } + return translateStream(events) + }, + } +} + +func NewFromEnv() (*Adapter, error) { + opts, _ := transportOptionsFromEnv() + return NewAdapter(AdapterOptions{TransportOptions: opts}), nil +} + +func (a *Adapter) Name() string { return a.provider } + +func (a *Adapter) Complete(ctx context.Context, req llm.Request) (llm.Response, error) { + resolved, err := resolveFileImages(req) + if err != nil { + return llm.Response{}, mapCodexError(err, a.provider, "complete") + } + + translated, err := a.translateRequestFn(resolved, false) + if err != nil { + return llm.Response{}, mapCodexError(err, a.provider, "complete") + } + + transport, err := a.getTransport() + if err != nil { + return llm.Response{}, mapCodexError(err, a.provider, "complete") + } + + result, err := transport.Complete(ctx, translated.Payload) + if err != nil { + return llm.Response{}, mapCodexError(err, a.provider, "complete") + } + if 
embedded := extractTurnError(result); embedded != nil { + return llm.Response{}, mapCodexError(embedded, a.provider, "complete") + } + + response, err := a.translateResponseFn(result) + if err != nil { + return llm.Response{}, mapCodexError(err, a.provider, "complete") + } + response.Provider = a.provider + if len(translated.Warnings) > 0 { + response.Warnings = append(response.Warnings, translated.Warnings...) + } + return response, nil +} + +func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, error) { + resolved, err := resolveFileImages(req) + if err != nil { + return nil, mapCodexError(err, a.provider, "stream") + } + translated, err := a.translateRequestFn(resolved, true) + if err != nil { + return nil, mapCodexError(err, a.provider, "stream") + } + + transport, err := a.getTransport() + if err != nil { + return nil, mapCodexError(err, a.provider, "stream") + } + + sctx, cancel := context.WithCancel(ctx) + stream, err := transport.Stream(sctx, translated.Payload) + if err != nil { + cancel() + return nil, mapCodexError(err, a.provider, "stream") + } + + out := llm.NewChanStream(cancel) + go func() { + defer cancel() + defer stream.Close() + defer out.CloseSend() + + translatedEvents := a.translateStreamFn(stream.Notifications) + warningsAttached := false + for event := range translatedEvents { + if !warningsAttached && event.Type == llm.StreamEventStreamStart { + warningsAttached = true + if len(translated.Warnings) > 0 { + event.Warnings = append(event.Warnings, translated.Warnings...) 
+ } + out.Send(event) + continue + } + out.Send(event) + } + if !warningsAttached && len(translated.Warnings) > 0 { + out.Send(llm.StreamEvent{Type: llm.StreamEventStreamStart, Warnings: translated.Warnings}) + } + if stream.Err != nil { + if streamErr, ok := <-stream.Err; ok && streamErr != nil { + out.Send(llm.StreamEvent{Type: llm.StreamEventError, Err: mapCodexError(streamErr, a.provider, "stream")}) + } + } + }() + + return out, nil +} + +func (a *Adapter) Initialize(ctx context.Context) error { + transport, err := a.getTransport() + if err != nil { + return mapCodexError(err, a.provider, "complete") + } + if err := transport.Initialize(ctx); err != nil { + return mapCodexError(err, a.provider, "complete") + } + return nil +} + +func (a *Adapter) ListModels(ctx context.Context, params map[string]any) (modelListResponse, error) { + a.modelListMu.Lock() + if a.modelList != nil { + cached := *a.modelList + a.modelListMu.Unlock() + return cached, nil + } + a.modelListMu.Unlock() + + transport, err := a.getTransport() + if err != nil { + return modelListResponse{}, mapCodexError(err, a.provider, "complete") + } + resp, err := transport.ListModels(ctx, params) + if err != nil { + return modelListResponse{}, mapCodexError(err, a.provider, "complete") + } + + a.modelListMu.Lock() + cp := resp + a.modelList = &cp + a.modelListMu.Unlock() + return resp, nil +} + +func (a *Adapter) GetDefaultModel(ctx context.Context) (*modelEntry, error) { + resp, err := a.ListModels(ctx, nil) + if err != nil { + return nil, err + } + for idx := range resp.Data { + if resp.Data[idx].IsDefault { + entry := resp.Data[idx] + return &entry, nil + } + } + return nil, nil +} + +func (a *Adapter) Close() error { + a.transportMu.Lock() + transport := a.transport + a.transportMu.Unlock() + if transport == nil { + return nil + } + if err := transport.Close(); err != nil { + return mapCodexError(err, a.provider, "complete") + } + return nil +} + +func (a *Adapter) getTransport() (codexTransport, 
error) { + if a.transportProvided != nil { + return a.transportProvided, nil + } + a.transportMu.Lock() + defer a.transportMu.Unlock() + if a.transport != nil { + return a.transport, nil + } + opts := a.transportOptions + if envOpts, ok := transportOptionsFromEnv(); ok { + if strings.TrimSpace(opts.Command) == "" { + opts.Command = envOpts.Command + } + if len(opts.Args) == 0 { + opts.Args = append([]string{}, envOpts.Args...) + } + } + a.transport = NewTransport(opts) + return a.transport, nil +} + +func resolveFileImages(req llm.Request) (llm.Request, error) { + resolved := req + resolved.Messages = make([]llm.Message, len(req.Messages)) + for mi, message := range req.Messages { + copyMessage := message + copyMessage.Content = make([]llm.ContentPart, len(message.Content)) + for pi, part := range message.Content { + copyPart := part + if part.Kind == llm.ContentImage && part.Image != nil && len(part.Image.Data) == 0 { + url := strings.TrimSpace(part.Image.URL) + if isResolvableImagePath(url) { + path := resolveImagePath(url) + bytes, err := os.ReadFile(path) + if err != nil { + return llm.Request{}, err + } + mediaType := strings.TrimSpace(part.Image.MediaType) + if mediaType == "" { + mediaType = llm.InferMimeTypeFromPath(path) + } + if mediaType == "" { + mediaType = "image/png" + } + copyPart.Image = &llm.ImageData{ + Data: bytes, + MediaType: mediaType, + Detail: part.Image.Detail, + } + } + } + copyMessage.Content[pi] = copyPart + } + resolved.Messages[mi] = copyMessage + } + return resolved, nil +} + +func isResolvableImagePath(url string) bool { + if strings.TrimSpace(url) == "" { + return false + } + if strings.HasPrefix(strings.TrimSpace(url), "file://") { + return true + } + return llm.IsLocalPath(url) +} + +func resolveImagePath(url string) string { + path := strings.TrimSpace(url) + if strings.HasPrefix(path, "file://") { + path = strings.TrimPrefix(path, "file://") + } + return llm.ExpandTilde(path) +} + +type normalizedErrorInfo struct { + Message 
string + Status int + HasStatus bool + Code string + RetryAfter *time.Duration + Raw any +} + +func extractTurnError(value map[string]any) any { + if value == nil { + return nil + } + if turnError, ok := value["turnError"]; ok && turnError != nil { + return turnError + } + if rootErr, ok := value["error"]; ok && rootErr != nil { + return rootErr + } + turn := asMap(value["turn"]) + if turn == nil { + return nil + } + if turnErr, ok := turn["error"]; ok && turnErr != nil { + return turnErr + } + status := strings.ToLower(strings.TrimSpace(asString(turn["status"]))) + if status == "failed" || status == "error" { + return turn + } + return nil +} + +func mapCodexError(raw any, provider string, contextKind string) error { + if raw == nil { + return nil + } + if rawMap := asMap(raw); rawMap != nil { + if embedded := extractTurnError(rawMap); embedded != nil { + raw = embedded + } + } + if err, ok := raw.(error); ok { + if mapped := llm.WrapContextError(provider, err); mapped != err { + return mapped + } + var llmErr llm.Error + if errorsAs(err, &llmErr) { + return err + } + } + + info := normalizeErrorInfo(raw) + code := normalizeCode(info.Code) + + if isTransportFailure(code, info.Message) { + if contextKind == "stream" { + return llm.NewStreamError(provider, info.Message) + } + return llm.NewNetworkError(provider, info.Message) + } + + if info.HasStatus { + return llm.ErrorFromHTTPStatus(provider, info.Status, info.Message, info.Raw, info.RetryAfter) + } + + if class := classifyByCode(code); class != "" { + switch class { + case "invalid_request": + return llm.ErrorFromHTTPStatus(provider, 400, info.Message, info.Raw, nil) + case "auth": + return llm.ErrorFromHTTPStatus(provider, 401, info.Message, info.Raw, nil) + case "rate_limit": + return llm.ErrorFromHTTPStatus(provider, 429, info.Message, info.Raw, info.RetryAfter) + case "server": + return llm.ErrorFromHTTPStatus(provider, 500, info.Message, info.Raw, nil) + } + } + + msg := strings.ToLower(info.Message) + 
switch { + case strings.Contains(msg, "context length"), strings.Contains(msg, "too many tokens"): + return llm.ErrorFromHTTPStatus(provider, 413, info.Message, info.Raw, nil) + case strings.Contains(msg, "content filter"), strings.Contains(msg, "safety"): + return llm.ErrorFromHTTPStatus(provider, 400, info.Message, info.Raw, nil) + case strings.Contains(msg, "quota"), strings.Contains(msg, "billing"): + return llm.ErrorFromHTTPStatus(provider, 429, info.Message, info.Raw, info.RetryAfter) + case strings.Contains(msg, "not found"), strings.Contains(msg, "does not exist"): + return llm.ErrorFromHTTPStatus(provider, 404, info.Message, info.Raw, nil) + case strings.Contains(msg, "unauthorized"), strings.Contains(msg, "invalid key"): + return llm.ErrorFromHTTPStatus(provider, 401, info.Message, info.Raw, nil) + case strings.Contains(msg, "model") && (strings.Contains(msg, "not supported") || strings.Contains(msg, "unsupported") || strings.Contains(msg, "unknown model")): + return llm.ErrorFromHTTPStatus(provider, 400, info.Message, info.Raw, nil) + } + + if contextKind == "stream" { + return llm.NewStreamError(provider, info.Message) + } + return llm.NewNetworkError(provider, info.Message) +} + +func normalizeErrorInfo(raw any) normalizedErrorInfo { + info := normalizedErrorInfo{ + Message: "codex-app-server request failed", + Raw: raw, + } + root := asMap(raw) + nested := asMap(root["error"]) + source := root + if nested != nil { + source = nested + } + if source == nil { + source = map[string]any{} + } + + if err, ok := raw.(error); ok { + info.Message = err.Error() + } + if message := firstNonEmpty(asString(source["message"]), asString(root["message"])); message != "" { + info.Message = unwrapJSONMessage(message) + } + + if statusVal, ok := source["status"]; ok { + info.Status = asInt(statusVal, 0) + info.HasStatus = true + } else if statusVal, ok := root["status"]; ok { + info.Status = asInt(statusVal, 0) + info.HasStatus = true + } + + info.Code = firstNonEmpty( 
+ asString(source["code"]), + asString(source["type"]), + asString(root["code"]), + asString(root["type"]), + ) + + retry := source["retryAfter"] + if retry == nil { + retry = source["retry_after"] + } + if retry == nil { + retry = root["retryAfter"] + } + if retry == nil { + retry = root["retry_after"] + } + if retry != nil { + seconds := asInt(retry, -1) + if seconds >= 0 { + d := time.Duration(seconds) * time.Second + info.RetryAfter = &d + } + } + return info +} + +func unwrapJSONMessage(message string) string { + trimmed := strings.TrimSpace(message) + if !strings.HasPrefix(trimmed, "{") { + return message + } + dec := json.NewDecoder(strings.NewReader(trimmed)) + dec.UseNumber() + var payload map[string]any + if err := dec.Decode(&payload); err != nil { + return message + } + if detail := firstNonEmpty(asString(payload["detail"]), asString(payload["message"])); detail != "" { + return detail + } + return message +} + +func isTransportFailure(code, message string) bool { + if code != "" { + switch code { + case "ECONNREFUSED", "ECONNRESET", "EPIPE", "ENOTFOUND", "EAI_AGAIN", "ETIMEDOUT", "EHOSTUNREACH", "ENETUNREACH", "ECONNABORTED": + return true + } + } + lower := strings.ToLower(message) + return strings.Contains(lower, "broken pipe") || + strings.Contains(lower, "econnrefused") || + strings.Contains(lower, "econnreset") || + strings.Contains(lower, "epipe") || + strings.Contains(lower, "spawn") +} + +func classifyByCode(code string) string { + if code == "" { + return "" + } + switch { + case strings.Contains(code, "INVALID_REQUEST"), strings.Contains(code, "BAD_REQUEST"), strings.Contains(code, "UNSUPPORTED"), strings.Contains(code, "INVALID_ARGUMENT"), strings.Contains(code, "INVALID_INPUT"): + return "invalid_request" + case strings.Contains(code, "UNAUTHENTICATED"), strings.Contains(code, "INVALID_API_KEY"), strings.Contains(code, "AUTHENTICATION"): + return "auth" + case strings.Contains(code, "RATE_LIMIT"), strings.Contains(code, 
"TOO_MANY_REQUESTS"), strings.Contains(code, "RESOURCE_EXHAUSTED"): + return "rate_limit" + case strings.Contains(code, "INTERNAL"), strings.Contains(code, "SERVER_ERROR"), strings.Contains(code, "UNAVAILABLE"): + return "server" + default: + return "" + } +} + +func errorsAs(err error, target any) bool { + if err == nil { + return false + } + return errors.As(err, target) +} diff --git a/internal/llm/providers/codexappserver/adapter_helpers_test.go b/internal/llm/providers/codexappserver/adapter_helpers_test.go new file mode 100644 index 00000000..52019b8a --- /dev/null +++ b/internal/llm/providers/codexappserver/adapter_helpers_test.go @@ -0,0 +1,280 @@ +package codexappserver + +import ( + "context" + "errors" + "os" + "strings" + "testing" + "time" + + "github.com/danshapiro/kilroy/internal/llm" +) + +func TestAdapterHelpers_PathAndTurnExtraction(t *testing.T) { + if isResolvableImagePath("") { + t.Fatalf("empty image path should not be resolvable") + } + if !isResolvableImagePath("file:///tmp/image.png") { + t.Fatalf("file:// image path should be resolvable") + } + if !strings.HasSuffix(resolveImagePath("file:///tmp/image.png"), "/tmp/image.png") { + t.Fatalf("resolved file path mismatch: %q", resolveImagePath("file:///tmp/image.png")) + } + + if got := extractTurnError(nil); got != nil { + t.Fatalf("nil map should not produce turn error, got %#v", got) + } + turnErr := map[string]any{"message": "bad turn"} + if got := asMap(extractTurnError(map[string]any{"turnError": turnErr})); got["message"] != "bad turn" { + t.Fatalf("turnError precedence mismatch: %#v", got) + } + rootErr := map[string]any{"message": "bad root"} + if got := asMap(extractTurnError(map[string]any{"error": rootErr})); got["message"] != "bad root" { + t.Fatalf("root error extraction mismatch: %#v", got) + } + turn := map[string]any{"id": "turn_1", "status": "failed"} + if got := asMap(extractTurnError(map[string]any{"turn": turn})); got["id"] != "turn_1" { + t.Fatalf("failed turn should be 
treated as error payload: %#v", got) + } +} + +func TestAdapterHelpers_NormalizeAndClassifyErrorInfo(t *testing.T) { + info := normalizeErrorInfo(map[string]any{ + "error": map[string]any{ + "message": `{"detail":"wrapped detail"}`, + "status": 429, + "code": "RATE_LIMIT", + "retryAfter": 7, + }, + }) + if info.Message != "wrapped detail" { + t.Fatalf("message: got %q want %q", info.Message, "wrapped detail") + } + if !info.HasStatus || info.Status != 429 { + t.Fatalf("status: got has=%v status=%d", info.HasStatus, info.Status) + } + if info.Code != "RATE_LIMIT" { + t.Fatalf("code: got %q", info.Code) + } + if info.RetryAfter == nil || *info.RetryAfter != 7*time.Second { + t.Fatalf("retryAfter: got %#v", info.RetryAfter) + } + + if got := unwrapJSONMessage("plain text"); got != "plain text" { + t.Fatalf("unwrap plain text: got %q", got) + } + if got := unwrapJSONMessage(`{"message":"msg fallback"}`); got != "msg fallback" { + t.Fatalf("unwrap json message fallback: got %q", got) + } + if !isTransportFailure("ECONNREFUSED", "ignored") { + t.Fatalf("expected transport failure by code") + } + if !isTransportFailure("", "broken pipe from child process") { + t.Fatalf("expected transport failure by message") + } + if isTransportFailure("INVALID_REQUEST", "input is malformed") { + t.Fatalf("did not expect invalid request to be transport failure") + } + + if got := classifyByCode("INVALID_REQUEST"); got != "invalid_request" { + t.Fatalf("classify invalid_request: got %q", got) + } + if got := classifyByCode("UNAUTHENTICATED"); got != "auth" { + t.Fatalf("classify auth: got %q", got) + } + if got := classifyByCode("RESOURCE_EXHAUSTED"); got != "rate_limit" { + t.Fatalf("classify rate_limit: got %q", got) + } + if got := classifyByCode("SERVER_ERROR"); got != "server" { + t.Fatalf("classify server: got %q", got) + } + if got := classifyByCode("SOMETHING_ELSE"); got != "" { + t.Fatalf("unexpected classification: %q", got) + } + + var target *llm.AuthenticationError + if 
errorsAs(nil, &target) { + t.Fatalf("errorsAs should return false for nil error") + } + authErr := llm.ErrorFromHTTPStatus("codex-app-server", 401, "bad key", nil, nil) + if !errorsAs(authErr, &target) { + t.Fatalf("expected errorsAs to match authentication error") + } +} + +func TestAdapterHelpers_MapCodexError_Branches(t *testing.T) { + if err := mapCodexError(nil, providerName, "complete"); err != nil { + t.Fatalf("nil error should stay nil, got %v", err) + } + + timeoutErr := mapCodexError(context.DeadlineExceeded, providerName, "complete") + var requestTimeout *llm.RequestTimeoutError + if !errors.As(timeoutErr, &requestTimeout) { + t.Fatalf("expected RequestTimeoutError from wrapped context deadline, got %T (%v)", timeoutErr, timeoutErr) + } + + transportErr := mapCodexError(map[string]any{ + "error": map[string]any{ + "code": "EPIPE", + "message": "broken pipe", + }, + }, providerName, "stream") + var streamErr *llm.StreamError + if !errors.As(transportErr, &streamErr) { + t.Fatalf("expected StreamError for stream transport failures, got %T (%v)", transportErr, transportErr) + } + + statusErr := mapCodexError(map[string]any{ + "error": map[string]any{ + "status": 404, + "message": "not found", + }, + }, providerName, "complete") + var notFound *llm.NotFoundError + if !errors.As(statusErr, ¬Found) { + t.Fatalf("expected NotFoundError from explicit status, got %T (%v)", statusErr, statusErr) + } + + classifiedErr := mapCodexError(map[string]any{ + "error": map[string]any{ + "code": "RATE_LIMIT", + "message": "too many requests", + }, + }, providerName, "complete") + var rateLimit *llm.RateLimitError + if !errors.As(classifiedErr, &rateLimit) { + t.Fatalf("expected RateLimitError from classified code, got %T (%v)", classifiedErr, classifiedErr) + } + + messageHintErr := mapCodexError(map[string]any{ + "message": "model not supported for this endpoint", + }, providerName, "complete") + var invalidRequest *llm.InvalidRequestError + if !errors.As(messageHintErr, 
&invalidRequest) { + t.Fatalf("expected InvalidRequestError from message hint, got %T (%v)", messageHintErr, messageHintErr) + } + + fallbackErr := mapCodexError(map[string]any{ + "message": "completely unknown failure", + }, providerName, "complete") + var networkErr *llm.NetworkError + if !errors.As(fallbackErr, &networkErr) { + t.Fatalf("expected NetworkError fallback for unknown complete errors, got %T (%v)", fallbackErr, fallbackErr) + } +} + +func TestAdapterHelpers_BasicLifecycleAndModelSelection(t *testing.T) { + t.Setenv(envCommand, "codex-test") + adapterFromEnv, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv: %v", err) + } + if adapterFromEnv == nil { + t.Fatalf("expected adapter from env") + } + if adapterFromEnv.Name() != providerName { + t.Fatalf("Name: got %q want %q", adapterFromEnv.Name(), providerName) + } + if adapterFromEnv.transportOptions.Command != "codex-test" { + t.Fatalf("transport command from env: got %q", adapterFromEnv.transportOptions.Command) + } + + initCalls := 0 + closeCalls := 0 + adapter := NewAdapter(AdapterOptions{ + Transport: &fakeTransport{ + initializeFn: func(ctx context.Context) error { + initCalls++ + return nil + }, + closeFn: func() error { + closeCalls++ + return nil + }, + listFn: func(ctx context.Context, params map[string]any) (modelListResponse, error) { + return modelListResponse{ + Data: []modelEntry{ + {ID: "model_a", Model: "codex-mini"}, + {ID: "model_b", Model: "codex-pro", IsDefault: true}, + }, + }, nil + }, + }, + }) + + if err := adapter.Initialize(context.Background()); err != nil { + t.Fatalf("Initialize: %v", err) + } + if initCalls != 1 { + t.Fatalf("initialize calls: got %d want 1", initCalls) + } + + def, err := adapter.GetDefaultModel(context.Background()) + if err != nil { + t.Fatalf("GetDefaultModel: %v", err) + } + if def == nil || def.ID != "model_b" { + t.Fatalf("default model mismatch: %#v", def) + } + + adapter.transport = &fakeTransport{ + closeFn: func() error { + 
closeCalls++ + return nil + }, + } + if err := adapter.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + if closeCalls != 1 { + t.Fatalf("close calls: got %d want 1", closeCalls) + } + + adapterNoTransport := NewAdapter(AdapterOptions{}) + if err := adapterNoTransport.Close(); err != nil { + t.Fatalf("Close without transport should succeed, got %v", err) + } +} + +func TestAdapterHelpers_ProviderOverride(t *testing.T) { + adapter := NewAdapter(AdapterOptions{Provider: "custom-codex-provider"}) + if got := adapter.Name(); got != "custom-codex-provider" { + t.Fatalf("Name: got %q want %q", got, "custom-codex-provider") + } +} + +func TestAdapterHelpers_GetTransport_CachesAndRespectsProvidedTransport(t *testing.T) { + provided := &fakeTransport{} + adapterWithProvided := NewAdapter(AdapterOptions{Transport: provided}) + got, err := adapterWithProvided.getTransport() + if err != nil { + t.Fatalf("getTransport provided: %v", err) + } + if got != provided { + t.Fatalf("expected provided transport to be returned") + } + + t.Setenv(envCommand, "") + t.Setenv(envArgs, "") + t.Setenv(envCommandArgs, "") + _ = os.Unsetenv(envCommand) + + adapter := NewAdapter(AdapterOptions{ + TransportOptions: TransportOptions{ + Command: "codex-custom", + Args: []string{"app-server", "--listen", "stdio://"}, + }, + }) + first, err := adapter.getTransport() + if err != nil { + t.Fatalf("getTransport first: %v", err) + } + second, err := adapter.getTransport() + if err != nil { + t.Fatalf("getTransport second: %v", err) + } + if first != second { + t.Fatalf("expected getTransport to cache created transport instance") + } +} diff --git a/internal/llm/providers/codexappserver/adapter_test.go b/internal/llm/providers/codexappserver/adapter_test.go new file mode 100644 index 00000000..d54bdb34 --- /dev/null +++ b/internal/llm/providers/codexappserver/adapter_test.go @@ -0,0 +1,334 @@ +package codexappserver + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + "testing" 
+ "time" + + "github.com/danshapiro/kilroy/internal/llm" +) + +type fakeTransport struct { + initializeFn func(ctx context.Context) error + closeFn func() error + completeFn func(ctx context.Context, payload map[string]any) (map[string]any, error) + streamFn func(ctx context.Context, payload map[string]any) (*NotificationStream, error) + listFn func(ctx context.Context, params map[string]any) (modelListResponse, error) +} + +func (f *fakeTransport) Initialize(ctx context.Context) error { + if f.initializeFn != nil { + return f.initializeFn(ctx) + } + return nil +} + +func (f *fakeTransport) Close() error { + if f.closeFn != nil { + return f.closeFn() + } + return nil +} + +func (f *fakeTransport) Complete(ctx context.Context, payload map[string]any) (map[string]any, error) { + if f.completeFn != nil { + return f.completeFn(ctx, payload) + } + return map[string]any{}, nil +} + +func (f *fakeTransport) Stream(ctx context.Context, payload map[string]any) (*NotificationStream, error) { + if f.streamFn != nil { + return f.streamFn(ctx, payload) + } + events := make(chan map[string]any) + errs := make(chan error) + close(events) + close(errs) + return &NotificationStream{Notifications: events, Err: errs, closeFn: func() {}}, nil +} + +func (f *fakeTransport) ListModels(ctx context.Context, params map[string]any) (modelListResponse, error) { + if f.listFn != nil { + return f.listFn(ctx, params) + } + return modelListResponse{Data: []modelEntry{}, NextCursor: nil}, nil +} + +func TestAdapterComplete_UsesTransportAndMergesWarnings(t *testing.T) { + var seenPayload map[string]any + transport := &fakeTransport{ + completeFn: func(ctx context.Context, payload map[string]any) (map[string]any, error) { + seenPayload = payload + return map[string]any{"turn": map[string]any{"id": "turn_1", "status": "completed", "items": []any{}}}, nil + }, + } + adapter := NewAdapter(AdapterOptions{ + Transport: transport, + TranslateRequest: func(request llm.Request, streaming bool) 
(translateRequestResult, error) { + return translateRequestResult{ + Payload: map[string]any{"input": []any{}, "threadId": defaultThreadID}, + Warnings: []llm.Warning{{Message: "Dropped unsupported audio", Code: "unsupported_part"}}, + }, nil + }, + TranslateResponse: func(body map[string]any) (llm.Response, error) { + return llm.Response{ + ID: "resp_1", + Model: "codex-mini", + Provider: providerName, + Message: llm.Assistant("done"), + Finish: llm.FinishReason{Reason: llm.FinishReasonStop}, + Usage: llm.Usage{InputTokens: 1, OutputTokens: 2, TotalTokens: 3}, + Warnings: []llm.Warning{{Message: "Deprecated field"}}, + }, nil + }, + }) + + resp, err := adapter.Complete(context.Background(), llm.Request{Model: "codex-mini", Messages: []llm.Message{llm.User("hello")}}) + if err != nil { + t.Fatalf("Complete: %v", err) + } + if seenPayload == nil { + t.Fatalf("transport payload not captured") + } + if len(resp.Warnings) != 2 { + t.Fatalf("warnings len: got %d want 2", len(resp.Warnings)) + } + if resp.Warnings[0].Message != "Deprecated field" || resp.Warnings[1].Code != "unsupported_part" { + t.Fatalf("warnings mismatch: %+v", resp.Warnings) + } +} + +func TestAdapterComplete_MapsTurnErrors(t *testing.T) { + transport := &fakeTransport{ + completeFn: func(ctx context.Context, payload map[string]any) (map[string]any, error) { + return map[string]any{ + "turn": map[string]any{ + "id": "turn_bad", + "status": "failed", + "error": map[string]any{ + "status": 429, + "code": "RATE_LIMITED", + "message": "too many requests", + }, + }, + }, nil + }, + } + adapter := NewAdapter(AdapterOptions{ + Transport: transport, + TranslateRequest: func(request llm.Request, streaming bool) (translateRequestResult, error) { + return translateRequestResult{Payload: map[string]any{"input": []any{}, "threadId": defaultThreadID}}, nil + }, + TranslateResponse: func(body map[string]any) (llm.Response, error) { + return llm.Response{}, nil + }, + }) + + _, err := 
adapter.Complete(context.Background(), llm.Request{Model: "codex-mini", Messages: []llm.Message{llm.User("hello")}}) + if err == nil { + t.Fatalf("expected complete error") + } + var rateLimit *llm.RateLimitError + if !errors.As(err, &rateLimit) { + t.Fatalf("expected RateLimitError, got %T (%v)", err, err) + } +} + +func TestAdapterStream_AttachesWarningsToStreamStart(t *testing.T) { + transport := &fakeTransport{ + streamFn: func(ctx context.Context, payload map[string]any) (*NotificationStream, error) { + events := make(chan map[string]any, 2) + errs := make(chan error, 1) + events <- map[string]any{"method": "turn/started", "params": map[string]any{"turn": map[string]any{"id": "turn_1", "status": "inProgress", "items": []any{}}}} + events <- map[string]any{"method": "turn/completed", "params": map[string]any{"turn": map[string]any{"id": "turn_1", "status": "completed", "items": []any{}}}} + close(events) + close(errs) + return &NotificationStream{Notifications: events, Err: errs, closeFn: func() {}}, nil + }, + } + adapter := NewAdapter(AdapterOptions{ + Transport: transport, + TranslateRequest: func(request llm.Request, streaming bool) (translateRequestResult, error) { + return translateRequestResult{ + Payload: map[string]any{"input": []any{}, "threadId": defaultThreadID}, + Warnings: []llm.Warning{{Message: "Tool output truncated", Code: "truncated"}}, + }, nil + }, + }) + + stream, err := adapter.Stream(context.Background(), llm.Request{Model: "codex-mini", Messages: []llm.Message{llm.User("hello")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer stream.Close() + + var start *llm.StreamEvent + for event := range stream.Events() { + if event.Type == llm.StreamEventStreamStart { + copyEvent := event + start = ©Event + } + } + if start == nil { + t.Fatalf("expected stream start event") + } + if len(start.Warnings) != 1 || start.Warnings[0].Code != "truncated" { + t.Fatalf("stream start warnings mismatch: %+v", start.Warnings) + } +} + +func 
TestAdapter_ListModelsCachesFirstResponse(t *testing.T) { + calls := 0 + transport := &fakeTransport{ + listFn: func(ctx context.Context, params map[string]any) (modelListResponse, error) { + calls++ + return modelListResponse{Data: []modelEntry{{ID: "1", Model: "codex-mini", IsDefault: true}}}, nil + }, + } + adapter := NewAdapter(AdapterOptions{Transport: transport}) + + first, err := adapter.ListModels(context.Background(), map[string]any{"limit": 1}) + if err != nil { + t.Fatalf("ListModels first: %v", err) + } + second, err := adapter.ListModels(context.Background(), map[string]any{"limit": 99}) + if err != nil { + t.Fatalf("ListModels second: %v", err) + } + if calls != 1 { + t.Fatalf("list call count: got %d want 1", calls) + } + if len(first.Data) != 1 || len(second.Data) != 1 { + t.Fatalf("cached model list mismatch: first=%+v second=%+v", first, second) + } +} + +func TestResolveFileImages_LoadsLocalPathData(t *testing.T) { + dir := t.TempDir() + path := dir + "/image.png" + if err := os.WriteFile(path, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + resolved, err := resolveFileImages(llm.Request{ + Model: "codex-mini", + Messages: []llm.Message{{ + Role: llm.RoleUser, + Content: []llm.ContentPart{{ + Kind: llm.ContentImage, + Image: &llm.ImageData{URL: path}, + }}, + }}, + }) + if err != nil { + t.Fatalf("resolveFileImages: %v", err) + } + part := resolved.Messages[0].Content[0] + if part.Image == nil || len(part.Image.Data) == 0 || part.Image.URL != "" { + t.Fatalf("resolved image mismatch: %+v", part.Image) + } + if part.Image.MediaType != "image/png" { + t.Fatalf("media type mismatch: got %q want %q", part.Image.MediaType, "image/png") + } +} + +func TestMapCodexError_UnsupportedModelMessageBecomesInvalidRequest(t *testing.T) { + err := mapCodexError(map[string]any{ + "turn": map[string]any{ + "id": "turn_unsupported", + "status": "failed", + "error": map[string]any{ + "message": "{\"detail\":\"The 
'nonexistent-model-xyz' model is not supported when using Codex with a ChatGPT account.\"}", + }, + }, + }, providerName, "complete") + + var invalid *llm.InvalidRequestError + if !errors.As(err, &invalid) { + t.Fatalf("expected InvalidRequestError, got %T (%v)", err, err) + } +} + +func TestAdapterStream_PropagatesTransportErrors(t *testing.T) { + transport := &fakeTransport{ + streamFn: func(ctx context.Context, payload map[string]any) (*NotificationStream, error) { + events := make(chan map[string]any) + errs := make(chan error, 1) + close(events) + errs <- context.DeadlineExceeded + close(errs) + return &NotificationStream{Notifications: events, Err: errs, closeFn: func() {}}, nil + }, + } + adapter := NewAdapter(AdapterOptions{ + Transport: transport, + TranslateRequest: func(request llm.Request, streaming bool) (translateRequestResult, error) { + return translateRequestResult{Payload: map[string]any{"input": []any{}, "threadId": defaultThreadID}}, nil + }, + TranslateStream: func(events <-chan map[string]any) <-chan llm.StreamEvent { + out := make(chan llm.StreamEvent) + close(out) + return out + }, + }) + + stream, err := adapter.Stream(context.Background(), llm.Request{Model: "codex-mini", Messages: []llm.Message{llm.User("hello")}}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + defer stream.Close() + + var gotErr error + timeout := time.After(2 * time.Second) +loop: + for { + select { + case event, ok := <-stream.Events(): + if !ok { + break loop + } + if event.Type == llm.StreamEventError { + gotErr = event.Err + } + case <-timeout: + t.Fatalf("timed out waiting for stream events") + } + } + if gotErr == nil { + t.Fatalf("expected stream error event") + } + var timeoutErr *llm.RequestTimeoutError + if !errors.As(gotErr, &timeoutErr) { + t.Fatalf("expected RequestTimeoutError, got %T (%v)", gotErr, gotErr) + } +} + +var _ codexTransport = (*fakeTransport)(nil) + +func TestNormalizeErrorInfo_UnwrapsJSONMessage(t *testing.T) { + info := 
normalizeErrorInfo(map[string]any{"message": `{"detail":"wrapped detail"}`}) + if info.Message != "wrapped detail" { + t.Fatalf("message: got %q want %q", info.Message, "wrapped detail") + } +} + +func TestParseToolCall_NormalizesArguments(t *testing.T) { + tool := parseToolCall(`{"id":"call_1","name":"search","arguments":{"q":"foo"}}`) + if tool == nil { + t.Fatalf("expected tool call") + } + if tool.ID != "call_1" || tool.Name != "search" { + t.Fatalf("tool metadata mismatch: %+v", tool) + } + if strings.TrimSpace(string(tool.Arguments)) != `{"q":"foo"}` { + t.Fatalf("tool args mismatch: %q", string(tool.Arguments)) + } + if !json.Valid(tool.Arguments) { + t.Fatalf("tool args should be valid json: %q", string(tool.Arguments)) + } +} diff --git a/internal/llm/providers/codexappserver/env_config.go b/internal/llm/providers/codexappserver/env_config.go new file mode 100644 index 00000000..28481aaf --- /dev/null +++ b/internal/llm/providers/codexappserver/env_config.go @@ -0,0 +1,103 @@ +package codexappserver + +import ( + "encoding/json" + "os" + "os/exec" + "regexp" + "strings" +) + +const ( + envCommand = "CODEX_APP_SERVER_COMMAND" + envArgs = "CODEX_APP_SERVER_ARGS" + envCommandArgs = "CODEX_APP_SERVER_COMMAND_ARGS" + envAutoDiscover = "CODEX_APP_SERVER_AUTO_DISCOVER" +) + +var ( + getenv = os.Getenv + lookPath = exec.LookPath + shellArgSplitRE = regexp.MustCompile(`(?:[^\s"']+|"[^"]*"|'[^']*')+`) +) + +func parseArgs(raw string) []string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil + } + if strings.HasPrefix(trimmed, "[") { + var parsed []string + if err := json.Unmarshal([]byte(trimmed), &parsed); err == nil { + out := make([]string, 0, len(parsed)) + for _, arg := range parsed { + if strings.TrimSpace(arg) != "" { + out = append(out, arg) + } + } + if len(out) > 0 { + return out + } + } + } + parts := shellArgSplitRE.FindAllString(trimmed, -1) + if len(parts) == 0 { + return nil + } + out := make([]string, 0, len(parts)) + for _, 
part := range parts { + if len(part) >= 2 { + if (strings.HasPrefix(part, "\"") && strings.HasSuffix(part, "\"")) || + (strings.HasPrefix(part, "'") && strings.HasSuffix(part, "'")) { + part = part[1 : len(part)-1] + } + } + if strings.TrimSpace(part) != "" { + out = append(out, part) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func transportOptionsFromEnv() (TransportOptions, bool) { + opts := TransportOptions{} + hasExplicitOverride := false + if cmd := strings.TrimSpace(getenv(envCommand)); cmd != "" { + opts.Command = cmd + hasExplicitOverride = true + } + argsRaw := getenv(envArgs) + if strings.TrimSpace(argsRaw) == "" { + argsRaw = getenv(envCommandArgs) + } + if args := parseArgs(argsRaw); len(args) > 0 { + opts.Args = args + hasExplicitOverride = true + } + if hasExplicitOverride { + return opts, true + } + + // If no explicit overrides are provided, only enable env registration when + // explicit auto-discovery is enabled and the default codex command is + // available on PATH. 
+ if !isTruthyEnvValue(getenv(envAutoDiscover)) { + return TransportOptions{}, false + } + if _, err := lookPath(defaultCommand); err == nil { + return opts, true + } + return TransportOptions{}, false +} + +func isTruthyEnvValue(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "y", "on": + return true + default: + return false + } +} diff --git a/internal/llm/providers/codexappserver/env_config_test.go b/internal/llm/providers/codexappserver/env_config_test.go new file mode 100644 index 00000000..dfba059d --- /dev/null +++ b/internal/llm/providers/codexappserver/env_config_test.go @@ -0,0 +1,109 @@ +package codexappserver + +import ( + "errors" + "testing" +) + +func TestParseArgs_JSONAndShellFormats(t *testing.T) { + if got := parseArgs(`["app-server","--listen","stdio://"]`); len(got) != 3 || got[0] != "app-server" { + t.Fatalf("json parse args mismatch: %#v", got) + } + if got := parseArgs(`app-server --listen "stdio://"`); len(got) != 3 || got[2] != "stdio://" { + t.Fatalf("shell parse args mismatch: %#v", got) + } +} + +func TestTransportOptionsFromEnv(t *testing.T) { + origGetenv := getenv + origLookPath := lookPath + t.Cleanup(func() { + getenv = origGetenv + lookPath = origLookPath + }) + values := map[string]string{ + envCommand: "codex-bin", + envArgs: `app-server --listen stdio://`, + } + getenv = func(key string) string { return values[key] } + + opts, ok := transportOptionsFromEnv() + if !ok { + t.Fatalf("expected enabled transport options") + } + if opts.Command != "codex-bin" { + t.Fatalf("command: got %q", opts.Command) + } + if len(opts.Args) != 3 { + t.Fatalf("args: %#v", opts.Args) + } +} + +func TestTransportOptionsFromEnv_DisabledWithoutExplicitOverridesOrOptIn(t *testing.T) { + origGetenv := getenv + origLookPath := lookPath + t.Cleanup(func() { + getenv = origGetenv + lookPath = origLookPath + }) + getenv = func(string) string { return "" } + lookPath = func(string) (string, error) { return 
"/usr/bin/codex", nil } + + opts, ok := transportOptionsFromEnv() + if ok { + t.Fatalf("expected transport options to remain disabled without explicit opt-in") + } + if opts.Command != "" { + t.Fatalf("command: got %q want empty", opts.Command) + } + if len(opts.Args) != 0 { + t.Fatalf("args: %#v", opts.Args) + } +} + +func TestTransportOptionsFromEnv_EnabledWhenAutoDiscoverOptInAndCodexPresent(t *testing.T) { + origGetenv := getenv + origLookPath := lookPath + t.Cleanup(func() { + getenv = origGetenv + lookPath = origLookPath + }) + values := map[string]string{ + envAutoDiscover: "1", + } + getenv = func(key string) string { return values[key] } + lookPath = func(string) (string, error) { return "/usr/bin/codex", nil } + + opts, ok := transportOptionsFromEnv() + if !ok { + t.Fatalf("expected transport options enabled with explicit auto-discover opt-in") + } + if opts.Command != "" { + t.Fatalf("command: got %q want empty", opts.Command) + } + if len(opts.Args) != 0 { + t.Fatalf("args: %#v", opts.Args) + } +} + +func TestTransportOptionsFromEnv_DisabledWhenAutoDiscoverOptInButCodexMissing(t *testing.T) { + origGetenv := getenv + origLookPath := lookPath + t.Cleanup(func() { + getenv = origGetenv + lookPath = origLookPath + }) + values := map[string]string{ + envAutoDiscover: "true", + } + getenv = func(key string) string { return values[key] } + lookPath = func(string) (string, error) { return "", errors.New("not found") } + + opts, ok := transportOptionsFromEnv() + if ok { + t.Fatalf("expected disabled transport options when codex is unavailable") + } + if opts.Command != "" || len(opts.Args) != 0 { + t.Fatalf("expected empty opts when disabled, got: %+v", opts) + } +} diff --git a/internal/llm/providers/codexappserver/protocol_types.go b/internal/llm/providers/codexappserver/protocol_types.go new file mode 100644 index 00000000..50f4d2d2 --- /dev/null +++ b/internal/llm/providers/codexappserver/protocol_types.go @@ -0,0 +1,39 @@ +package codexappserver + +type 
jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +type jsonRPCMessage struct { + JSONRPC string `json:"jsonrpc,omitempty"` + ID any `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params any `json:"params,omitempty"` + Result any `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type modelReasoningEffort struct { + Effort string `json:"effort"` + Description string `json:"description"` +} + +type modelEntry struct { + ID string `json:"id"` + Model string `json:"model"` + DisplayName string `json:"displayName"` + Hidden bool `json:"hidden"` + IsDefault bool `json:"isDefault"` + DefaultReasoningEffort string `json:"defaultReasoningEffort"` + ReasoningEffort []modelReasoningEffort `json:"reasoningEffort"` + InputModalities []string `json:"inputModalities"` + SupportsPersonality bool `json:"supportsPersonality"` + Upgrade any `json:"upgrade,omitempty"` +} + +type modelListResponse struct { + Data []modelEntry `json:"data"` + NextCursor any `json:"nextCursor"` +} diff --git a/internal/llm/providers/codexappserver/request_translator.go b/internal/llm/providers/codexappserver/request_translator.go new file mode 100644 index 00000000..819143e4 --- /dev/null +++ b/internal/llm/providers/codexappserver/request_translator.go @@ -0,0 +1,713 @@ +package codexappserver + +import ( + "bytes" + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/danshapiro/kilroy/internal/llm" +) + +const ( + transcriptBeginMarker = "[[[UNIFIED_TRANSCRIPT_V1_BEGIN]]]" + transcriptEndMarker = "[[[UNIFIED_TRANSCRIPT_V1_END]]]" + transcriptPayloadBeginMarker = "[[[UNIFIED_TRANSCRIPT_PAYLOAD_BEGIN]]]" + transcriptPayloadEndMarker = "[[[UNIFIED_TRANSCRIPT_PAYLOAD_END]]]" + toolCallBeginMarker = "[[TOOL_CALL]]" + toolCallEndMarker = "[[/TOOL_CALL]]" + defaultThreadID = "thread_stateless" + transcriptVersion = "unified.codex-app-server.request.v1" + defaultReasoningEffort = 
"high" +) + +var ( + jsonObjectOutputSchema = map[string]any{ + "type": "object", + "properties": map[string]any{}, + "additionalProperties": true, + } + supportedReasoningEfforts = map[string]struct{}{ + "none": {}, + "minimal": {}, + "low": {}, + "medium": {}, + "high": {}, + "xhigh": {}, + } + turnOptionKeyMap = map[string]string{ + "cwd": "cwd", + "approvalPolicy": "approvalPolicy", + "approval_policy": "approvalPolicy", + "sandboxPolicy": "sandboxPolicy", + "sandbox_policy": "sandboxPolicy", + "model": "model", + "effort": "effort", + "summary": "summary", + "personality": "personality", + "collaborationMode": "collaborationMode", + "collaboration_mode": "collaborationMode", + "outputSchema": "outputSchema", + "output_schema": "outputSchema", + } + controlOptionKeyMap = map[string]string{ + "temperature": "temperature", + "topP": "topP", + "top_p": "topP", + "maxTokens": "maxTokens", + "max_tokens": "maxTokens", + "stopSequences": "stopSequences", + "stop_sequences": "stopSequences", + "metadata": "metadata", + "reasoningEffort": "reasoningEffort", + "reasoning_effort": "reasoningEffort", + } + uriSchemeRE = regexp.MustCompile(`^[a-zA-Z][a-zA-Z\d+\-.]*:`) +) + +type resolvedToolChoice struct { + Mode string `json:"mode"` + ToolName string `json:"toolName,omitempty"` +} + +type transcriptControls struct { + Model string `json:"model"` + ToolChoice resolvedToolChoice `json:"toolChoice"` + ResponseFormat map[string]any `json:"responseFormat"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + MaxTokens *int `json:"maxTokens,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + ReasoningEff string `json:"reasoningEffort,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type transcriptPayload struct { + Version string `json:"version"` + ToolCallProtocol map[string]string `json:"toolCallProtocol"` + Controls transcriptControls `json:"controls"` + Tools []map[string]any 
`json:"tools"` + History []map[string]any `json:"history"` +} + +type translateRequestResult struct { + Payload map[string]any + Warnings []llm.Warning +} + +func translateRequest(request llm.Request, _ bool) (translateRequestResult, error) { + warnings := make([]llm.Warning, 0, 4) + toolChoice := normalizeToolChoice(request) + if err := validateToolChoice(toolChoice, request); err != nil { + return translateRequestResult{}, err + } + + var reasoningInput string + if request.ReasoningEffort != nil { + reasoningInput = *request.ReasoningEffort + } + reasoningEffort := normalizeReasoningEffort(reasoningInput, &warnings, "request.reasoningEffort") + if reasoningEffort == "" { + reasoningEffort = defaultReasoningEffort + } + + controls := transcriptControls{ + Model: request.Model, + ToolChoice: toolChoice, + ResponseFormat: responseFormatForTranscript(request.ResponseFormat), + Temperature: request.Temperature, + TopP: request.TopP, + MaxTokens: request.MaxTokens, + StopSequences: append([]string{}, request.StopSequences...), + ReasoningEff: reasoningEffort, + Metadata: metadataForTranscript(request.Metadata), + } + + history, imageInputs := translateMessages(request.Messages, &warnings) + + params := map[string]any{ + "threadId": defaultThreadID, + "model": request.Model, + "effort": controls.ReasoningEff, + } + if outputSchema := resolveOutputSchema(request.ResponseFormat); outputSchema != nil { + params["outputSchema"] = outputSchema + } + + applyProviderOptions(request, params, &controls, &warnings) + if model := strings.TrimSpace(asString(params["model"])); model != "" { + controls.Model = model + } + paramEffort := normalizeReasoningEffort(asString(params["effort"]), &warnings, "codex_app_server.effort") + if paramEffort == "" { + paramEffort = controls.ReasoningEff + } + if paramEffort == "" { + paramEffort = defaultReasoningEffort + } + params["effort"] = paramEffort + controls.ReasoningEff = paramEffort + + payload := transcriptPayload{ + Version: 
transcriptVersion, + ToolCallProtocol: map[string]string{ + "beginMarker": toolCallBeginMarker, + "endMarker": toolCallEndMarker, + }, + Controls: controls, + Tools: buildToolsSection(request), + History: history, + } + + transcript, err := buildTranscript(payload, toolChoice) + if err != nil { + return translateRequestResult{}, err + } + + input := make([]any, 0, 1+len(imageInputs)) + input = append(input, map[string]any{ + "type": "text", + "text": transcript, + "text_elements": []any{}, + }) + for _, in := range imageInputs { + input = append(input, in) + } + params["input"] = input + + return translateRequestResult{Payload: params, Warnings: warnings}, nil +} + +func normalizeToolChoice(request llm.Request) resolvedToolChoice { + if request.ToolChoice != nil { + mode := strings.TrimSpace(strings.ToLower(request.ToolChoice.Mode)) + if mode == "named" { + return resolvedToolChoice{Mode: "named", ToolName: strings.TrimSpace(request.ToolChoice.Name)} + } + if mode == "" { + mode = "auto" + } + return resolvedToolChoice{Mode: mode} + } + if len(request.Tools) > 0 { + return resolvedToolChoice{Mode: "auto"} + } + return resolvedToolChoice{Mode: "none"} +} + +func validateToolChoice(choice resolvedToolChoice, request llm.Request) error { + toolNames := make(map[string]struct{}, len(request.Tools)) + for _, tool := range request.Tools { + toolNames[strings.TrimSpace(tool.Name)] = struct{}{} + } + switch choice.Mode { + case "required": + if len(toolNames) == 0 { + return fmt.Errorf("toolChoice.mode=\"required\" requires at least one tool definition") + } + case "named": + if len(toolNames) == 0 { + return fmt.Errorf("toolChoice.mode=\"named\" requires tools, but no tools were provided") + } + if strings.TrimSpace(choice.ToolName) == "" { + return fmt.Errorf("toolChoice.mode=\"named\" requires a non-empty toolName") + } + if _, ok := toolNames[choice.ToolName]; !ok { + return fmt.Errorf("toolChoice.mode=\"named\" references unknown tool %q", choice.ToolName) + } + } + 
return nil +} + +func responseFormatForTranscript(format *llm.ResponseFormat) map[string]any { + if format == nil { + return map[string]any{"type": "text"} + } + out := map[string]any{"type": format.Type} + if format.JSONSchema != nil { + out["jsonSchema"] = format.JSONSchema + } + if format.Strict { + out["strict"] = true + } + return out +} + +func metadataForTranscript(metadata map[string]string) map[string]interface{} { + if len(metadata) == 0 { + return nil + } + out := make(map[string]interface{}, len(metadata)) + for key, value := range metadata { + out[key] = value + } + return out +} + +func resolveOutputSchema(responseFormat *llm.ResponseFormat) map[string]any { + if responseFormat == nil || strings.EqualFold(responseFormat.Type, "text") || strings.TrimSpace(responseFormat.Type) == "" { + return nil + } + if strings.EqualFold(responseFormat.Type, "json") { + return deepCopyMap(jsonObjectOutputSchema) + } + if responseFormat.JSONSchema == nil { + return nil + } + return deepCopyMap(responseFormat.JSONSchema) +} + +func buildToolsSection(request llm.Request) []map[string]any { + if len(request.Tools) == 0 { + return nil + } + out := make([]map[string]any, 0, len(request.Tools)) + for _, tool := range request.Tools { + out = append(out, map[string]any{ + "name": tool.Name, + "description": tool.Description, + "parameters": tool.Parameters, + }) + } + return out +} + +func translateMessages(messages []llm.Message, warnings *[]llm.Warning) ([]map[string]any, []map[string]any) { + history := make([]map[string]any, 0, len(messages)) + imageInputs := make([]map[string]any, 0, 2) + imageIndex := 0 + nextImageID := func() string { + imageIndex++ + return fmt.Sprintf("img_%04d", imageIndex) + } + + for messageIndex, message := range messages { + parts := make([]map[string]any, 0, len(message.Content)) + for partIndex, part := range message.Content { + parts = append(parts, translatePart(part, messageIndex, partIndex, warnings, &imageInputs, nextImageID)) + } + 
history = append(history, map[string]any{ + "index": messageIndex, + "role": string(message.Role), + "name": message.Name, + "toolCallId": message.ToolCallID, + "parts": parts, + }) + } + + return history, imageInputs +} + +func translatePart( + part llm.ContentPart, + messageIndex int, + partIndex int, + warnings *[]llm.Warning, + imageInputs *[]map[string]any, + nextImageID func() string, +) map[string]any { + switch part.Kind { + case llm.ContentText: + return map[string]any{ + "index": partIndex, + "kind": "text", + "text": part.Text, + } + case llm.ContentImage: + imageID := nextImageID() + if part.Image != nil { + if len(part.Image.Data) > 0 { + mediaType := strings.TrimSpace(part.Image.MediaType) + if mediaType == "" { + mediaType = "image/png" + } + *imageInputs = append(*imageInputs, map[string]any{ + "type": "image", + "url": llm.DataURI(mediaType, part.Image.Data), + }) + return map[string]any{ + "index": partIndex, + "kind": "image", + "assetId": imageID, + "inputType": "image", + "source": "inline_data", + "mediaType": mediaType, + "detail": part.Image.Detail, + } + } + if url := strings.TrimSpace(part.Image.URL); url != "" { + if isLikelyLocalPath(url) { + *imageInputs = append(*imageInputs, map[string]any{"type": "localImage", "path": url}) + return map[string]any{ + "index": partIndex, + "kind": "image", + "assetId": imageID, + "inputType": "localImage", + "source": "local_path", + "path": url, + "detail": part.Image.Detail, + } + } + *imageInputs = append(*imageInputs, map[string]any{"type": "image", "url": url}) + return map[string]any{ + "index": partIndex, + "kind": "image", + "assetId": imageID, + "inputType": "image", + "source": "remote_url", + "url": url, + "detail": part.Image.Detail, + } + } + } + *warnings = append(*warnings, llm.Warning{ + Code: "unsupported_part", + Message: "Image content parts without data or url cannot be attached and were translated to fallback text", + }) + return map[string]any{ + "index": partIndex, + "kind": 
"image", + "assetId": imageID, + "fallback": "missing_image_data_or_url", + } + case llm.ContentAudio: + *warnings = append(*warnings, warningForFallback("Audio")) + byteLength := 0 + url := "" + mediaType := "" + if part.Audio != nil { + byteLength = len(part.Audio.Data) + url = part.Audio.URL + mediaType = part.Audio.MediaType + } + return map[string]any{ + "index": partIndex, + "kind": "audio", + "fallback": map[string]any{ + "url": url, + "mediaType": mediaType, + "byteLength": byteLength, + }, + } + case llm.ContentDocument: + *warnings = append(*warnings, warningForFallback("Document")) + byteLength := 0 + url := "" + mediaType := "" + filename := "" + if part.Document != nil { + byteLength = len(part.Document.Data) + url = part.Document.URL + mediaType = part.Document.MediaType + filename = part.Document.FileName + } + return map[string]any{ + "index": partIndex, + "kind": "document", + "fallback": map[string]any{ + "url": url, + "mediaType": mediaType, + "fileName": filename, + "byteLength": byteLength, + }, + } + case llm.ContentToolCall: + if part.ToolCall == nil { + break + } + value, raw := normalizeToolArguments(part.ToolCall.Arguments) + protocolPayload := map[string]any{ + "id": part.ToolCall.ID, + "name": part.ToolCall.Name, + "arguments": value, + } + protocolJSON, _ := json.Marshal(protocolPayload) + return map[string]any{ + "index": partIndex, + "kind": "tool_call", + "id": part.ToolCall.ID, + "name": part.ToolCall.Name, + "arguments": value, + "rawArguments": raw, + "protocolBlock": strings.Join([]string{ + toolCallBeginMarker, + string(protocolJSON), + toolCallEndMarker, + }, "\n"), + } + case llm.ContentToolResult: + if part.ToolResult == nil { + break + } + item := map[string]any{ + "index": partIndex, + "kind": "tool_result", + "toolCallId": part.ToolResult.ToolCallID, + "content": part.ToolResult.Content, + "isError": part.ToolResult.IsError, + } + if len(part.ToolResult.ImageData) > 0 { + mediaType := 
strings.TrimSpace(part.ToolResult.ImageMediaType) + if mediaType == "" { + mediaType = "image/png" + } + item["imageDataUri"] = llm.DataURI(mediaType, part.ToolResult.ImageData) + item["imageMediaType"] = mediaType + } + return item + case llm.ContentThinking: + if part.Thinking == nil { + break + } + return map[string]any{ + "index": partIndex, + "kind": "thinking", + "text": part.Thinking.Text, + "signature": part.Thinking.Signature, + "redacted": false, + } + case llm.ContentRedThinking: + if part.Thinking == nil { + break + } + return map[string]any{ + "index": partIndex, + "kind": "redacted_thinking", + "text": part.Thinking.Text, + "signature": part.Thinking.Signature, + "redacted": true, + } + default: + if kind := strings.TrimSpace(string(part.Kind)); kind != "" { + *warnings = append(*warnings, warningForFallback(fmt.Sprintf("Custom (%s)", kind))) + fallback := map[string]any{ + "index": partIndex, + "kind": kind, + "fallbackKind": "custom", + } + if part.Data != nil { + fallback["data"] = part.Data + } + return fallback + } + *warnings = append(*warnings, llm.Warning{ + Code: "unsupported_part", + Message: fmt.Sprintf("Unknown content part kind at message index %d was translated to fallback text", messageIndex), + }) + return map[string]any{ + "index": partIndex, + "kind": "unknown", + "fallback": true, + } + } + + *warnings = append(*warnings, llm.Warning{ + Code: "unsupported_part", + Message: fmt.Sprintf("Content part kind %q at message index %d was empty and translated to fallback text", part.Kind, messageIndex), + }) + return map[string]any{ + "index": partIndex, + "kind": string(part.Kind), + "fallback": true, + } +} + +func normalizeToolArguments(arguments json.RawMessage) (any, string) { + trimmed := strings.TrimSpace(string(arguments)) + if trimmed == "" { + return map[string]any{}, "{}" + } + dec := json.NewDecoder(bytes.NewReader([]byte(trimmed))) + dec.UseNumber() + var parsed any + if err := dec.Decode(&parsed); err != nil { + return trimmed, 
trimmed + } + return parsed, trimmed +} + +func warningForFallback(kind string) llm.Warning { + return llm.Warning{ + Code: "unsupported_part", + Message: fmt.Sprintf("%s content parts are not natively supported by codex-app-server and were translated to deterministic transcript fallback text", kind), + } +} + +func warningForUnsupportedProviderOption(key string) llm.Warning { + return llm.Warning{ + Code: "unsupported_option", + Message: fmt.Sprintf("Provider option codex_app_server.%s is not supported and was ignored", key), + } +} + +func normalizeReasoningEffort(value string, warnings *[]llm.Warning, source string) string { + normalized := strings.ToLower(strings.TrimSpace(value)) + if normalized == "" { + return "" + } + if _, ok := supportedReasoningEfforts[normalized]; ok { + return normalized + } + if warnings != nil { + *warnings = append(*warnings, llm.Warning{ + Code: "unsupported_option", + Message: fmt.Sprintf( + "%s value %q is unsupported and was ignored (expected none, minimal, low, medium, high, or xhigh)", + source, + value, + ), + }) + } + return "" +} + +func applyProviderOptions( + request llm.Request, + params map[string]any, + controls *transcriptControls, + warnings *[]llm.Warning, +) { + options := codexProviderOptions(request.ProviderOptions) + if len(options) == 0 { + return + } + + for key, value := range options { + if turnKey, ok := turnOptionKeyMap[key]; ok { + params[turnKey] = value + continue + } + if controlKey, ok := controlOptionKeyMap[key]; ok { + applyControlOverride(controlKey, value, controls, warnings, key) + continue + } + *warnings = append(*warnings, warningForUnsupportedProviderOption(key)) + } +} + +func codexProviderOptions(options map[string]any) map[string]any { + if len(options) == 0 { + return nil + } + for _, key := range []string{"codex_app_server", "codex-app-server", "codexappserver"} { + if raw, ok := options[key]; ok { + if m := asMap(raw); m != nil { + return m + } + } + } + return nil +} + +func 
applyControlOverride( + key string, + value any, + controls *transcriptControls, + warnings *[]llm.Warning, + rawKey string, +) { + switch key { + case "temperature": + if f, ok := value.(float64); ok { + controls.Temperature = &f + return + } + case "topP": + if f, ok := value.(float64); ok { + controls.TopP = &f + return + } + case "maxTokens": + if n, ok := value.(float64); ok { + i := int(n) + controls.MaxTokens = &i + return + } + if n, ok := value.(int); ok { + controls.MaxTokens = &n + return + } + case "stopSequences": + if arr, ok := value.([]any); ok { + out := make([]string, 0, len(arr)) + for _, item := range arr { + s := asString(item) + if s == "" { + *outWarning(warnings) = append(*outWarning(warnings), warningForUnsupportedProviderOption(rawKey)) + return + } + out = append(out, s) + } + controls.StopSequences = out + return + } + if arr, ok := value.([]string); ok { + controls.StopSequences = append([]string{}, arr...) + return + } + case "metadata": + if rec := asMap(value); rec != nil { + if controls.Metadata == nil { + controls.Metadata = map[string]interface{}{} + } + for mk, mv := range rec { + controls.Metadata[mk] = mv + } + return + } + case "reasoningEffort": + normalized := normalizeReasoningEffort(asString(value), warnings, fmt.Sprintf("codex_app_server.%s", rawKey)) + if normalized != "" { + controls.ReasoningEff = normalized + } + return + } + *warnings = append(*warnings, warningForUnsupportedProviderOption(rawKey)) +} + +func outWarning(w *[]llm.Warning) *[]llm.Warning { return w } + +func buildTranscript(payload transcriptPayload, choice resolvedToolChoice) (string, error) { + toolChoiceLine := "Tool choice policy: " + choice.Mode + if choice.Mode == "named" { + toolChoiceLine = fmt.Sprintf("Tool choice policy: named (%s)", choice.ToolName) + } + payloadJSON, err := json.Marshal(payload) + if err != nil { + return "", err + } + lines := []string{ + transcriptBeginMarker, + "Stateless transcript payload for unified-llm 
codex-app-server translation.", + "Treat the payload as the authoritative full conversation history.", + "When emitting tool calls, use deterministic protocol blocks exactly:", + toolCallBeginMarker, + `{"id":"call_","name":"","arguments":{}}`, + toolCallEndMarker, + "Do not wrap tool-call protocol blocks in markdown fences.", + toolChoiceLine, + transcriptPayloadBeginMarker, + string(payloadJSON), + transcriptPayloadEndMarker, + transcriptEndMarker, + } + return strings.Join(lines, "\n"), nil +} + +func isLikelyLocalPath(url string) bool { + url = strings.TrimSpace(url) + if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { + return false + } + if strings.HasPrefix(url, "data:") { + return false + } + if strings.HasPrefix(url, "file://") { + return true + } + if uriSchemeRE.MatchString(url) { + return false + } + return true +} diff --git a/internal/llm/providers/codexappserver/request_translator_controls_test.go b/internal/llm/providers/codexappserver/request_translator_controls_test.go new file mode 100644 index 00000000..4c7348c0 --- /dev/null +++ b/internal/llm/providers/codexappserver/request_translator_controls_test.go @@ -0,0 +1,109 @@ +package codexappserver + +import ( + "testing" + + "github.com/danshapiro/kilroy/internal/llm" +) + +func TestRequestTranslator_ApplyControlOverride_SupportedValues(t *testing.T) { + controls := &transcriptControls{} + warnings := []llm.Warning{} + + applyControlOverride("temperature", 0.7, controls, &warnings, "temperature") + applyControlOverride("topP", 0.9, controls, &warnings, "topP") + applyControlOverride("maxTokens", 42.0, controls, &warnings, "maxTokens") + applyControlOverride("maxTokens", 64, controls, &warnings, "maxTokens") + applyControlOverride("stopSequences", []any{"END", "STOP"}, controls, &warnings, "stopSequences") + applyControlOverride("metadata", map[string]any{"team": "qa"}, controls, &warnings, "metadata") + applyControlOverride("reasoningEffort", "medium", controls, 
&warnings, "reasoningEffort") + + if controls.Temperature == nil || *controls.Temperature != 0.7 { + t.Fatalf("temperature: %#v", controls.Temperature) + } + if controls.TopP == nil || *controls.TopP != 0.9 { + t.Fatalf("topP: %#v", controls.TopP) + } + if controls.MaxTokens == nil || *controls.MaxTokens != 64 { + t.Fatalf("maxTokens: %#v", controls.MaxTokens) + } + if len(controls.StopSequences) != 2 || controls.StopSequences[0] != "END" || controls.StopSequences[1] != "STOP" { + t.Fatalf("stopSequences: %#v", controls.StopSequences) + } + if controls.Metadata["team"] != "qa" { + t.Fatalf("metadata: %#v", controls.Metadata) + } + if controls.ReasoningEff != "medium" { + t.Fatalf("reasoningEffort: %q", controls.ReasoningEff) + } + if len(warnings) != 0 { + t.Fatalf("did not expect warnings for supported values: %+v", warnings) + } +} + +func TestRequestTranslator_ApplyControlOverride_InvalidValuesEmitWarnings(t *testing.T) { + controls := &transcriptControls{} + warnings := []llm.Warning{} + + applyControlOverride("stopSequences", []any{"END", 2}, controls, &warnings, "stopSequences") + applyControlOverride("reasoningEffort", "invalid-effort", controls, &warnings, "reasoningEffort") + applyControlOverride("temperature", "not-a-number", controls, &warnings, "temperature") + applyControlOverride("unknownKey", true, controls, &warnings, "unknownKey") + + if len(controls.StopSequences) != 0 { + t.Fatalf("stopSequences should remain unset on invalid input: %#v", controls.StopSequences) + } + if controls.ReasoningEff != "" { + t.Fatalf("reasoningEffort should remain empty for invalid value: %q", controls.ReasoningEff) + } + if controls.Temperature != nil { + t.Fatalf("temperature should remain nil on invalid input: %#v", controls.Temperature) + } + if len(warnings) < 4 { + t.Fatalf("expected warnings for invalid values, got %d (%+v)", len(warnings), warnings) + } +} + +func TestRequestTranslator_ApplyProviderOptions_MapsKnownKeysAndWarnsUnknown(t *testing.T) { + params 
:= map[string]any{} + controls := &transcriptControls{} + warnings := []llm.Warning{} + + applyProviderOptions(llm.Request{ + ProviderOptions: map[string]any{ + "codex_app_server": map[string]any{ + "cwd": "/tmp/project", + "approval_policy": "never", + "temperature": 0.2, + "reasoning_effort": "high", + "unsupportedX": true, + }, + }, + }, params, controls, &warnings) + + if params["cwd"] != "/tmp/project" { + t.Fatalf("cwd mapping: %#v", params["cwd"]) + } + if params["approvalPolicy"] != "never" { + t.Fatalf("approvalPolicy mapping: %#v", params["approvalPolicy"]) + } + if controls.Temperature == nil || *controls.Temperature != 0.2 { + t.Fatalf("temperature override: %#v", controls.Temperature) + } + if controls.ReasoningEff != "high" { + t.Fatalf("reasoningEffort override: %q", controls.ReasoningEff) + } + if len(warnings) == 0 { + t.Fatalf("expected warning for unsupported provider option") + } +} + +func TestRequestTranslator_WarningHelpers(t *testing.T) { + w := warningForUnsupportedProviderOption("x_opt") + if w.Code != "unsupported_option" { + t.Fatalf("warning code: %q", w.Code) + } + if out := outWarning(&[]llm.Warning{}); out == nil { + t.Fatalf("outWarning should return same pointer") + } +} diff --git a/internal/llm/providers/codexappserver/request_translator_test.go b/internal/llm/providers/codexappserver/request_translator_test.go new file mode 100644 index 00000000..c8956a8e --- /dev/null +++ b/internal/llm/providers/codexappserver/request_translator_test.go @@ -0,0 +1,237 @@ +package codexappserver + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/danshapiro/kilroy/internal/llm" +) + +func mustTranscriptPayload(t *testing.T, params map[string]any) map[string]any { + t.Helper() + input := asSlice(params["input"]) + if len(input) == 0 { + t.Fatalf("missing input") + } + textItem := asMap(input[0]) + if asString(textItem["type"]) != "text" { + t.Fatalf("first input item is not text: %#v", textItem) + } + transcript := 
asString(textItem["text"]) + if !strings.Contains(transcript, transcriptPayloadBeginMarker) || !strings.Contains(transcript, transcriptPayloadEndMarker) { + t.Fatalf("missing transcript payload markers") + } + start := strings.Index(transcript, transcriptPayloadBeginMarker+"\n") + if start < 0 { + t.Fatalf("missing payload start marker") + } + start += len(transcriptPayloadBeginMarker) + 1 + end := strings.Index(transcript[start:], "\n"+transcriptPayloadEndMarker) + if end < 0 { + t.Fatalf("missing payload end marker") + } + payloadJSON := transcript[start : start+end] + var payload map[string]any + if err := json.Unmarshal([]byte(payloadJSON), &payload); err != nil { + t.Fatalf("payload json unmarshal: %v", err) + } + return payload +} + +func TestTranslateRequest_FullSurface(t *testing.T) { + temperature := 0.3 + topP := 0.8 + maxTokens := 300 + reasoning := "high" + + request := llm.Request{ + Model: "gpt-5.2-codex", + Messages: []llm.Message{ + llm.System("System guardrails"), + llm.Developer("Developer instruction"), + { + Role: llm.RoleUser, + Content: []llm.ContentPart{ + {Kind: llm.ContentText, Text: "What is in this image?"}, + {Kind: llm.ContentImage, Image: &llm.ImageData{URL: "https://example.com/cat.png", Detail: "high"}}, + }, + }, + { + Role: llm.RoleAssistant, + Content: []llm.ContentPart{ + {Kind: llm.ContentText, Text: "Let me inspect it."}, + {Kind: llm.ContentToolCall, ToolCall: &llm.ToolCallData{ID: "call_weather", Name: "get_weather", Arguments: json.RawMessage(`{"city":"SF"}`)}}, + }, + }, + llm.ToolResultNamed("call_weather", "get_weather", map[string]any{"temperature": "72F"}, false), + }, + Tools: []llm.ToolDefinition{{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{"city": map[string]any{"type": "string"}}, + "required": []any{"city"}, + }, + }}, + ToolChoice: &llm.ToolChoice{Mode: "named", Name: "get_weather"}, + ResponseFormat: 
&llm.ResponseFormat{ + Type: "json_schema", + JSONSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"answer": map[string]any{"type": "string"}}, + "required": []any{"answer"}, + }, + }, + Temperature: &temperature, + TopP: &topP, + MaxTokens: &maxTokens, + StopSequences: []string{""}, + ReasoningEffort: &reasoning, + Metadata: map[string]string{ + "traceId": "trace-123", + "tenant": "acme", + }, + ProviderOptions: map[string]any{ + "codex_app_server": map[string]any{ + "cwd": "/tmp/project", + "summary": "concise", + "personality": "pragmatic", + }, + }, + } + + translated, err := translateRequest(request, false) + if err != nil { + t.Fatalf("translateRequest: %v", err) + } + if len(translated.Warnings) != 0 { + t.Fatalf("unexpected warnings: %+v", translated.Warnings) + } + params := translated.Payload + if got := asString(params["threadId"]); got != defaultThreadID { + t.Fatalf("threadId: got %q want %q", got, defaultThreadID) + } + if got := asString(params["model"]); got != "gpt-5.2-codex" { + t.Fatalf("model: got %q", got) + } + if got := asString(params["cwd"]); got != "/tmp/project" { + t.Fatalf("cwd: got %q", got) + } + if got := asString(params["summary"]); got != "concise" { + t.Fatalf("summary: got %q", got) + } + if got := asString(params["personality"]); got != "pragmatic" { + t.Fatalf("personality: got %q", got) + } + if params["outputSchema"] == nil { + t.Fatalf("expected outputSchema to be set") + } + + input := asSlice(params["input"]) + if len(input) != 2 { + t.Fatalf("input len: got %d want 2", len(input)) + } + imageInput := asMap(input[1]) + if asString(imageInput["type"]) != "image" || asString(imageInput["url"]) != "https://example.com/cat.png" { + t.Fatalf("image input mismatch: %#v", imageInput) + } + + payload := mustTranscriptPayload(t, params) + if got := asString(payload["version"]); got != transcriptVersion { + t.Fatalf("payload version: got %q want %q", got, transcriptVersion) + } + controls := 
asMap(payload["controls"]) + if got := asString(controls["model"]); got != "gpt-5.2-codex" { + t.Fatalf("controls.model: got %q", got) + } + if got := asString(asMap(controls["toolChoice"])["mode"]); got != "named" { + t.Fatalf("tool choice mode: got %q", got) + } + if got := asString(asMap(controls["toolChoice"])["toolName"]); got != "get_weather" { + t.Fatalf("tool choice name: got %q", got) + } +} + +func TestTranslateRequest_FallbackWarnings(t *testing.T) { + request := llm.Request{ + Model: "codex-mini", + Messages: []llm.Message{{ + Role: llm.RoleUser, + Content: []llm.ContentPart{ + {Kind: llm.ContentAudio, Audio: &llm.AudioData{URL: "https://example.com/a.wav", MediaType: "audio/wav"}}, + {Kind: llm.ContentDocument, Document: &llm.DocumentData{URL: "https://example.com/r.pdf", MediaType: "application/pdf", FileName: "r.pdf"}}, + {Kind: llm.ContentKind("custom_note"), Data: map[string]any{"topic": "ops", "priority": "high"}}, + }, + }}, + } + + translated, err := translateRequest(request, false) + if err != nil { + t.Fatalf("translateRequest: %v", err) + } + if len(translated.Warnings) != 3 { + t.Fatalf("warning len: got %d want 3", len(translated.Warnings)) + } + for _, w := range translated.Warnings { + if w.Code != "unsupported_part" { + t.Fatalf("warning code: got %q want unsupported_part", w.Code) + } + } + if translated.Warnings[2].Message != "Custom (custom_note) content parts are not natively supported by codex-app-server and were translated to deterministic transcript fallback text" { + t.Fatalf("custom warning mismatch: %q", translated.Warnings[2].Message) + } + + payload := mustTranscriptPayload(t, translated.Payload) + history := asSlice(payload["history"]) + if len(history) != 1 { + t.Fatalf("history len: got %d want 1", len(history)) + } + parts := asSlice(asMap(history[0])["parts"]) + if len(parts) != 3 { + t.Fatalf("parts len: got %d want 3", len(parts)) + } + customPart := asMap(parts[2]) + if got := asString(customPart["fallbackKind"]); got 
!= "custom" { + t.Fatalf("fallbackKind: got %q want custom", got) + } + customData := asMap(customPart["data"]) + if asString(customData["topic"]) != "ops" || asString(customData["priority"]) != "high" { + t.Fatalf("custom data mismatch: %#v", customData) + } +} + +func TestTranslateRequest_ValidatesToolChoice(t *testing.T) { + req := llm.Request{ + Model: "codex-mini", + Messages: []llm.Message{llm.User("Need tools")}, + ToolChoice: &llm.ToolChoice{Mode: "required"}, + } + if _, err := translateRequest(req, false); err == nil { + t.Fatalf("expected error for required tool choice without tools") + } + + req = llm.Request{ + Model: "codex-mini", + Messages: []llm.Message{llm.User("Need weather")}, + Tools: []llm.ToolDefinition{{ + Name: "lookup_weather", + Parameters: map[string]any{"type": "object"}, + }}, + ToolChoice: &llm.ToolChoice{Mode: "named", Name: "missing_tool"}, + } + if _, err := translateRequest(req, false); err == nil { + t.Fatalf("expected error for named tool choice without matching tool") + } +} + +func TestTranslateRequest_DefaultReasoningEffort(t *testing.T) { + translated, err := translateRequest(llm.Request{Model: "codex-mini", Messages: []llm.Message{llm.User("Hello")}}, false) + if err != nil { + t.Fatalf("translateRequest: %v", err) + } + if got := asString(translated.Payload["effort"]); got != "high" { + t.Fatalf("effort: got %q want high", got) + } +} diff --git a/internal/llm/providers/codexappserver/response_translator.go b/internal/llm/providers/codexappserver/response_translator.go new file mode 100644 index 00000000..e5d1f434 --- /dev/null +++ b/internal/llm/providers/codexappserver/response_translator.go @@ -0,0 +1,507 @@ +package codexappserver + +import ( + "bytes" + "encoding/json" + "fmt" + "regexp" + "sort" + "strings" + "time" + + "github.com/danshapiro/kilroy/internal/llm" +) + +type normalizedNotification struct { + Method string + Params map[string]any +} + +type itemDeltas struct { + agentByID map[string]string + 
reasoningSummaryByID map[string]map[int]string + reasoningContentByID map[string]map[int]string +} + +var toolProtocolRE = regexp.MustCompile(`(?is)\[\[TOOL_CALL\]\]([\s\S]*?)\[\[/TOOL_CALL\]\]`) + +func translateResponse(body map[string]any) (llm.Response, error) { + notifications := extractNotifications(body) + turn := asMap(body["turn"]) + items := collectItems(turn, notifications) + + content := make([]llm.ContentPart, 0, 8) + for _, item := range items { + switch asString(item["type"]) { + case "reasoning": + content = append(content, translateReasoning(item)...) + case "agentMessage": + content = append(content, translateAgentMessage(item)...) + } + } + if len(content) == 0 { + if fallback := asString(body["text"]); fallback != "" { + content = append(content, llm.ContentPart{Kind: llm.ContentText, Text: fallback}) + } + } + + rawStatus := firstNonEmpty(asString(turn["status"]), asString(body["status"])) + hasToolCalls := false + for _, part := range content { + if part.Kind == llm.ContentToolCall { + hasToolCalls = true + break + } + } + + usage := translateUsage(body, notifications) + warnings := extractWarnings(body, notifications) + response := llm.Response{ + ID: firstNonEmpty(asString(turn["id"]), asString(body["id"])), + Model: extractModel(body, notifications), + Provider: "codex-app-server", + Message: llm.Message{ + Role: llm.RoleAssistant, + Content: content, + }, + Finish: llm.FinishReason{ + Reason: mapFinishReason(rawStatus, hasToolCalls), + Raw: rawStatus, + }, + Usage: usage, + Raw: body, + Warnings: warnings, + } + return response, nil +} + +func collectItems(turn map[string]any, notifications []normalizedNotification) []map[string]any { + deltas := collectDeltas(notifications) + orderedIDs := make([]string, 0, 16) + byID := make(map[string]map[string]any) + + upsert := func(item map[string]any) { + id := strings.TrimSpace(asString(item["id"])) + if id == "" { + return + } + if _, exists := byID[id]; !exists { + orderedIDs = 
append(orderedIDs, id) + } + byID[id] = item + } + + for _, notification := range notifications { + if notification.Method != "item/completed" { + continue + } + item := asMap(notification.Params["item"]) + if item != nil { + upsert(item) + } + } + for _, itemRaw := range asSlice(turn["items"]) { + item := asMap(itemRaw) + if item != nil { + upsert(item) + } + } + for itemID, text := range deltas.agentByID { + if _, exists := byID[itemID]; exists { + continue + } + upsert(map[string]any{"id": itemID, "type": "agentMessage", "text": text}) + } + for itemID, summaryMap := range deltas.reasoningSummaryByID { + if _, exists := byID[itemID]; exists { + continue + } + upsert(map[string]any{ + "id": itemID, + "type": "reasoning", + "summary": mapByIndex(summaryMap), + "content": mapByIndex(deltas.reasoningContentByID[itemID]), + }) + } + for itemID, contentMap := range deltas.reasoningContentByID { + if _, exists := byID[itemID]; exists { + continue + } + upsert(map[string]any{ + "id": itemID, + "type": "reasoning", + "summary": mapByIndex(deltas.reasoningSummaryByID[itemID]), + "content": mapByIndex(contentMap), + }) + } + + out := make([]map[string]any, 0, len(orderedIDs)) + for _, id := range orderedIDs { + if item := byID[id]; item != nil { + out = append(out, item) + } + } + return out +} + +func collectDeltas(notifications []normalizedNotification) itemDeltas { + agentByID := map[string]string{} + reasoningSummaryByID := map[string]map[int]string{} + reasoningContentByID := map[string]map[int]string{} + + appendByIndex := func(target map[string]map[int]string, itemID string, idx int, delta string) { + if _, ok := target[itemID]; !ok { + target[itemID] = map[int]string{} + } + target[itemID][idx] = target[itemID][idx] + delta + } + + for _, notification := range notifications { + switch notification.Method { + case "item/agentMessage/delta": + itemID := asString(notification.Params["itemId"]) + if itemID == "" { + continue + } + agentByID[itemID] = agentByID[itemID] 
+ asString(notification.Params["delta"]) + case "item/reasoning/summaryTextDelta": + itemID := asString(notification.Params["itemId"]) + if itemID == "" { + continue + } + appendByIndex(reasoningSummaryByID, itemID, asInt(notification.Params["summaryIndex"], 0), asString(notification.Params["delta"])) + case "item/reasoning/textDelta": + itemID := asString(notification.Params["itemId"]) + if itemID == "" { + continue + } + appendByIndex(reasoningContentByID, itemID, asInt(notification.Params["contentIndex"], 0), asString(notification.Params["delta"])) + } + } + + return itemDeltas{ + agentByID: agentByID, + reasoningSummaryByID: reasoningSummaryByID, + reasoningContentByID: reasoningContentByID, + } +} + +func mapByIndex(in map[int]string) []any { + if len(in) == 0 { + return nil + } + keys := make([]int, 0, len(in)) + for idx := range in { + keys = append(keys, idx) + } + sort.Ints(keys) + out := make([]any, 0, len(keys)) + for _, idx := range keys { + out = append(out, in[idx]) + } + return out +} + +func translateReasoning(item map[string]any) []llm.ContentPart { + parts := make([]llm.ContentPart, 0, 4) + for _, source := range []any{item["summary"], item["content"]} { + for _, chunk := range asSlice(source) { + text := asString(chunk) + if strings.TrimSpace(text) == "" { + continue + } + parts = append(parts, splitReasoningChunk(text)...) 
+ } + } + return parts +} + +func splitReasoningChunk(text string) []llm.ContentPart { + segments := splitReasoningSegments(text) + out := make([]llm.ContentPart, 0, len(segments)) + for _, segment := range segments { + trimmed := strings.TrimSpace(segment.Text) + if trimmed == "" { + continue + } + thinking := &llm.ThinkingData{Text: trimmed, Redacted: segment.Redacted} + if segment.Redacted { + out = append(out, llm.ContentPart{Kind: llm.ContentRedThinking, Thinking: thinking}) + continue + } + out = append(out, llm.ContentPart{Kind: llm.ContentThinking, Thinking: thinking}) + } + return out +} + +func translateAgentMessage(item map[string]any) []llm.ContentPart { + text := asString(item["text"]) + if text == "" { + return nil + } + matches := toolProtocolRE.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return []llm.ContentPart{{Kind: llm.ContentText, Text: text}} + } + + parts := make([]llm.ContentPart, 0, len(matches)*2+1) + cursor := 0 + for _, m := range matches { + if len(m) < 4 { + continue + } + start := m[0] + end := m[1] + payloadStart := m[2] + payloadEnd := m[3] + if start > cursor { + prefix := text[cursor:start] + if prefix != "" { + parts = append(parts, llm.ContentPart{Kind: llm.ContentText, Text: prefix}) + } + } + payload := strings.TrimSpace(text[payloadStart:payloadEnd]) + if toolCall := parseToolCall(payload); toolCall != nil { + parts = append(parts, llm.ContentPart{Kind: llm.ContentToolCall, ToolCall: toolCall}) + } else { + block := text[start:end] + if block != "" { + parts = append(parts, llm.ContentPart{Kind: llm.ContentText, Text: block}) + } + } + cursor = end + } + if cursor < len(text) { + suffix := text[cursor:] + if suffix != "" { + parts = append(parts, llm.ContentPart{Kind: llm.ContentText, Text: suffix}) + } + } + return parts +} + +func parseToolCall(payload string) *llm.ToolCallData { + if strings.TrimSpace(payload) == "" { + return nil + } + m, ok := parseJSONRecord(payload) + if !ok { + return nil + } + 
name := strings.TrimSpace(asString(m["name"])) + if name == "" { + return nil + } + id := strings.TrimSpace(asString(m["id"])) + if id == "" { + id = fmt.Sprintf("call_%d", time.Now().UnixNano()) + } + typ := strings.TrimSpace(asString(m["type"])) + + argsRaw := m["arguments"] + arguments, rawStr := normalizeParsedArguments(argsRaw) + _ = rawStr + + toolCall := &llm.ToolCallData{ + ID: id, + Name: name, + Arguments: arguments, + } + if typ != "" { + toolCall.Type = typ + } + return toolCall +} + +func normalizeParsedArguments(value any) (json.RawMessage, string) { + if s, ok := value.(string); ok { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return json.RawMessage("{}"), "{}" + } + if json.Valid([]byte(trimmed)) { + return json.RawMessage(trimmed), trimmed + } + encoded, _ := json.Marshal(trimmed) + return json.RawMessage(encoded), trimmed + } + if value == nil { + return json.RawMessage("{}"), "{}" + } + b, err := json.Marshal(value) + if err != nil || len(b) == 0 { + return json.RawMessage("{}"), "{}" + } + return json.RawMessage(b), string(b) +} + +func extractModel(body map[string]any, notifications []normalizedNotification) string { + if model := firstNonEmpty(asString(body["model"]), asString(body["modelId"]), asString(body["model_name"])); model != "" { + return model + } + for idx := len(notifications) - 1; idx >= 0; idx-- { + notification := notifications[idx] + if notification.Method != "model/rerouted" { + continue + } + if model := asString(notification.Params["toModel"]); model != "" { + return model + } + } + return "" +} + +func translateUsage(body map[string]any, notifications []normalizedNotification) llm.Usage { + var usageSource map[string]any + var rawUsage map[string]any + + for idx := len(notifications) - 1; idx >= 0; idx-- { + notification := notifications[idx] + if notification.Method != "thread/tokenUsage/updated" { + continue + } + tokenUsage := asMap(notification.Params["tokenUsage"]) + if tokenUsage == nil { + continue + } + 
rawUsage = tokenUsage + usageSource = asMap(tokenUsage["last"]) + if usageSource == nil { + usageSource = tokenUsage + } + break + } + if usageSource == nil { + tokenUsage := asMap(body["tokenUsage"]) + if tokenUsage != nil { + rawUsage = tokenUsage + usageSource = asMap(tokenUsage["last"]) + if usageSource == nil { + usageSource = tokenUsage + } + } + } + if usageSource == nil { + usage := asMap(body["usage"]) + if usage != nil { + rawUsage = usage + usageSource = usage + } + } + + usage := llm.Usage{ + InputTokens: asInt(usageSource["inputTokens"], asInt(usageSource["input_tokens"], 0)), + OutputTokens: asInt(usageSource["outputTokens"], asInt(usageSource["output_tokens"], 0)), + TotalTokens: asInt(usageSource["totalTokens"], asInt(usageSource["total_tokens"], 0)), + } + reasoningTokens := asInt(usageSource["reasoningOutputTokens"], asInt(usageSource["reasoning_tokens"], -1)) + cacheReadTokens := asInt(usageSource["cachedInputTokens"], asInt(usageSource["cache_read_input_tokens"], -1)) + cacheWriteTokens := asInt(usageSource["cacheWriteTokens"], asInt(usageSource["cache_write_input_tokens"], -1)) + if usage.TotalTokens <= 0 { + usage.TotalTokens = usage.InputTokens + usage.OutputTokens + } + if reasoningTokens >= 0 { + usage.ReasoningTokens = intPtr(reasoningTokens) + } + if cacheReadTokens >= 0 { + usage.CacheReadTokens = intPtr(cacheReadTokens) + } + if cacheWriteTokens >= 0 { + usage.CacheWriteTokens = intPtr(cacheWriteTokens) + } + if rawUsage != nil { + usage.Raw = rawUsage + } + return usage +} + +func extractWarnings(body map[string]any, notifications []normalizedNotification) []llm.Warning { + warnings := make([]llm.Warning, 0, 4) + for _, warningValue := range asSlice(body["warnings"]) { + warning := asMap(warningValue) + if warning == nil { + continue + } + message := strings.TrimSpace(asString(warning["message"])) + if message == "" { + continue + } + code := strings.TrimSpace(asString(warning["code"])) + warnings = append(warnings, 
llm.Warning{Message: message, Code: code}) + } + for _, notification := range notifications { + if notification.Method != "deprecationNotice" && notification.Method != "configWarning" { + continue + } + message := firstNonEmpty( + asString(notification.Params["message"]), + asString(notification.Params["notice"]), + asString(notification.Params["warning"]), + ) + if message == "" { + continue + } + warnings = append(warnings, llm.Warning{Message: message, Code: notification.Method}) + } + return warnings +} + +func extractNotifications(body map[string]any) []normalizedNotification { + notifications := make([]normalizedNotification, 0, 16) + sources := make([]any, 0) + sources = append(sources, asSlice(body["notifications"])...) + sources = append(sources, asSlice(body["events"])...) + sources = append(sources, asSlice(body["rawNotifications"])...) + + for _, raw := range sources { + entry := asMap(raw) + if entry == nil { + continue + } + method := firstNonEmpty(asString(entry["method"]), asString(entry["event"]), asString(entry["type"])) + if method == "" { + continue + } + params := asMap(entry["params"]) + if params == nil { + if dataString, ok := entry["data"].(string); ok { + if parsed, ok := parseJSONRecord(dataString); ok { + params = parsed + } + } else { + params = asMap(entry["data"]) + } + } + if params == nil { + params = map[string]any{} + } + notifications = append(notifications, normalizedNotification{Method: method, Params: params}) + } + + return notifications +} + +func parseJSONRecord(in string) (map[string]any, bool) { + dec := json.NewDecoder(strings.NewReader(strings.TrimSpace(in))) + dec.UseNumber() + var parsed map[string]any + if err := dec.Decode(&parsed); err != nil { + return nil, false + } + if parsed == nil { + return nil, false + } + return parsed, true +} + +func intPtr(v int) *int { return &v } + +func parseJSONAny(in string) any { + dec := json.NewDecoder(bytes.NewReader([]byte(in))) + dec.UseNumber() + var parsed any + if err := 
dec.Decode(&parsed); err != nil { + return nil + } + return parsed +} diff --git a/internal/llm/providers/codexappserver/response_translator_test.go b/internal/llm/providers/codexappserver/response_translator_test.go new file mode 100644 index 00000000..e1aa015b --- /dev/null +++ b/internal/llm/providers/codexappserver/response_translator_test.go @@ -0,0 +1,146 @@ +package codexappserver + +import ( + "strings" + "testing" + + "github.com/danshapiro/kilroy/internal/llm" +) + +func testNotification(method string, params map[string]any) map[string]any { + return map[string]any{"method": method, "params": params} +} + +func TestTranslateResponse_ToolProtocolReasoningAndUsage(t *testing.T) { + body := map[string]any{ + "id": "resp_codex_1", + "model": "codex-mini", + "turn": map[string]any{ + "id": "turn_1", + "status": "completed", + }, + "notifications": []any{ + testNotification("item/completed", map[string]any{ + "item": map[string]any{ + "id": "reasoning_1", + "type": "reasoning", + "summary": []any{"Plan steps"}, + "content": []any{"Visible [[REDACTED_REASONING]]secret[[/REDACTED_REASONING]] done"}, + }, + }), + testNotification("item/completed", map[string]any{ + "item": map[string]any{ + "id": "agent_1", + "type": "agentMessage", + "text": "Before [[TOOL_CALL]]{\"id\":\"call_1\",\"name\":\"search\",\"arguments\":{\"q\":\"foo\"}}[[/TOOL_CALL]] After", + }, + }), + testNotification("thread/tokenUsage/updated", map[string]any{ + "tokenUsage": map[string]any{ + "last": map[string]any{ + "inputTokens": 11, + "outputTokens": 7, + "totalTokens": 18, + "reasoningOutputTokens": 3, + "cachedInputTokens": 2, + }, + }, + }), + }, + } + + response, err := translateResponse(body) + if err != nil { + t.Fatalf("translateResponse: %v", err) + } + if response.ID != "turn_1" { + t.Fatalf("response id: got %q want turn_1", response.ID) + } + if response.Model != "codex-mini" { + t.Fatalf("response model: got %q", response.Model) + } + if response.Provider != providerName { + 
t.Fatalf("response provider: got %q", response.Provider) + } + if response.Finish.Reason != llm.FinishReasonToolCalls { + t.Fatalf("finish reason: got %q want %q", response.Finish.Reason, llm.FinishReasonToolCalls) + } + if response.Usage.InputTokens != 11 || response.Usage.OutputTokens != 7 || response.Usage.TotalTokens != 18 { + t.Fatalf("usage mismatch: %+v", response.Usage) + } + if response.Usage.ReasoningTokens == nil || *response.Usage.ReasoningTokens != 3 { + t.Fatalf("reasoning tokens mismatch: %+v", response.Usage) + } + if response.Usage.CacheReadTokens == nil || *response.Usage.CacheReadTokens != 2 { + t.Fatalf("cache read tokens mismatch: %+v", response.Usage) + } + + if len(response.Message.Content) < 4 { + t.Fatalf("expected content parts, got %+v", response.Message.Content) + } + foundToolCall := false + for _, part := range response.Message.Content { + if part.Kind == llm.ContentToolCall && part.ToolCall != nil { + foundToolCall = true + if part.ToolCall.ID != "call_1" || part.ToolCall.Name != "search" { + t.Fatalf("tool call mismatch: %+v", part.ToolCall) + } + if strings.TrimSpace(string(part.ToolCall.Arguments)) != `{"q":"foo"}` { + t.Fatalf("tool arguments mismatch: %q", string(part.ToolCall.Arguments)) + } + } + } + if !foundToolCall { + t.Fatalf("expected tool call part in response content") + } +} + +func TestTranslateResponse_FinishReasonMapping(t *testing.T) { + interrupted, err := translateResponse(map[string]any{ + "model": "codex-mini", + "turn": map[string]any{"id": "turn_2", "status": "interrupted"}, + }) + if err != nil { + t.Fatalf("translateResponse interrupted: %v", err) + } + if interrupted.Finish.Reason != llm.FinishReasonLength { + t.Fatalf("interrupted finish reason: got %q", interrupted.Finish.Reason) + } + + failed, err := translateResponse(map[string]any{ + "model": "codex-mini", + "turn": map[string]any{"id": "turn_3", "status": "failed"}, + }) + if err != nil { + t.Fatalf("translateResponse failed: %v", err) + } + if 
failed.Finish.Reason != llm.FinishReasonError { + t.Fatalf("failed finish reason: got %q", failed.Finish.Reason) + } +} + +func TestTranslateResponse_ReconstructsFromDeltas(t *testing.T) { + body := map[string]any{ + "model": "codex-mini", + "turn": map[string]any{"id": "turn_4", "status": "completed"}, + "notifications": []any{ + testNotification("item/agentMessage/delta", map[string]any{"itemId": "agent_delta", "delta": "Hello "}), + testNotification("item/agentMessage/delta", map[string]any{"itemId": "agent_delta", "delta": "world"}), + testNotification("item/reasoning/summaryTextDelta", map[string]any{"itemId": "reason_delta", "summaryIndex": 0, "delta": "Need to inspect state"}), + }, + } + + response, err := translateResponse(body) + if err != nil { + t.Fatalf("translateResponse: %v", err) + } + if len(response.Message.Content) != 2 { + t.Fatalf("content len: got %d want 2 (%+v)", len(response.Message.Content), response.Message.Content) + } + if response.Message.Content[0].Kind != llm.ContentText || response.Message.Content[0].Text != "Hello world" { + t.Fatalf("agent text reconstruction mismatch: %+v", response.Message.Content[0]) + } + if response.Message.Content[1].Kind != llm.ContentThinking || response.Message.Content[1].Thinking == nil || response.Message.Content[1].Thinking.Text != "Need to inspect state" { + t.Fatalf("reasoning reconstruction mismatch: %+v", response.Message.Content[1]) + } +} diff --git a/internal/llm/providers/codexappserver/stream_translator.go b/internal/llm/providers/codexappserver/stream_translator.go new file mode 100644 index 00000000..2a1357cd --- /dev/null +++ b/internal/llm/providers/codexappserver/stream_translator.go @@ -0,0 +1,495 @@ +package codexappserver + +import ( + "strconv" + "strings" + + "github.com/danshapiro/kilroy/internal/llm" +) + +const ( + toolProtocolStartToken = "[[TOOL_CALL]]" + toolProtocolStartTokenLower = "[[tool_call]]" + toolProtocolEndToken = "[[/TOOL_CALL]]" + toolProtocolEndTokenLower = 
"[[/tool_call]]" +) + +const maxStartReserve = len(toolProtocolStartToken) - 1 + +type parsedSegment struct { + Kind string + Text string + ToolCall *llm.ToolCallData +} + +type toolProtocolStreamParser struct { + buffer string + insideBlock bool + opening string +} + +func (p *toolProtocolStreamParser) feed(delta string) []parsedSegment { + p.buffer += delta + return p.drain(false) +} + +func (p *toolProtocolStreamParser) flush() []parsedSegment { + return p.drain(true) +} + +func (p *toolProtocolStreamParser) drain(finalize bool) []parsedSegment { + segments := make([]parsedSegment, 0, 4) + + for { + if p.insideBlock { + endIdx := strings.Index(strings.ToLower(p.buffer), toolProtocolEndTokenLower) + if endIdx < 0 { + if !finalize { + break + } + if p.buffer != "" || p.opening != "" { + segments = append(segments, parsedSegment{Kind: "text", Text: p.opening + p.buffer}) + } + p.buffer = "" + p.insideBlock = false + p.opening = "" + continue + } + + payload := p.buffer[:endIdx] + closing := p.buffer[endIdx : endIdx+len(toolProtocolEndToken)] + p.buffer = p.buffer[endIdx+len(toolProtocolEndToken):] + p.insideBlock = false + + if toolCall := parseToolCall(payload); toolCall != nil { + segments = append(segments, parsedSegment{Kind: "tool_call", ToolCall: toolCall}) + } else { + segments = append(segments, parsedSegment{Kind: "text", Text: p.opening + payload + closing}) + } + p.opening = "" + continue + } + + lower := strings.ToLower(p.buffer) + startIdx := strings.Index(lower, toolProtocolStartTokenLower) + if startIdx < 0 { + if p.buffer == "" { + break + } + if finalize { + segments = append(segments, parsedSegment{Kind: "text", Text: p.buffer}) + p.buffer = "" + break + } + if len(p.buffer) <= maxStartReserve { + break + } + safeText := p.buffer[:len(p.buffer)-maxStartReserve] + p.buffer = p.buffer[len(p.buffer)-maxStartReserve:] + if safeText != "" { + segments = append(segments, parsedSegment{Kind: "text", Text: safeText}) + } + break + } + + if startIdx > 0 { 
+ segments = append(segments, parsedSegment{Kind: "text", Text: p.buffer[:startIdx]}) + } + + p.opening = p.buffer[startIdx : startIdx+len(toolProtocolStartToken)] + p.buffer = p.buffer[startIdx+len(toolProtocolStartToken):] + p.insideBlock = true + } + + return segments +} + +type textStreamState struct { + TextStarted bool + Parser *toolProtocolStreamParser +} + +func translateStream(events <-chan map[string]any) <-chan llm.StreamEvent { + out := make(chan llm.StreamEvent, 64) + go func() { + defer close(out) + + streamStarted := false + streamID := "" + model := "" + emittedToolCalls := false + var latestUsage *llm.Usage + + textStates := make(map[string]*textStreamState) + reasoningByItem := make(map[string]map[string]struct{}) + activeReasoningIDs := make(map[string]struct{}) + + closeReasoningForItem := func(itemID string) []llm.StreamEvent { + itemSet := reasoningByItem[itemID] + if len(itemSet) == 0 { + return nil + } + outEvents := make([]llm.StreamEvent, 0, len(itemSet)) + for reasoningID := range itemSet { + if _, ok := activeReasoningIDs[reasoningID]; !ok { + continue + } + delete(activeReasoningIDs, reasoningID) + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventReasoningEnd, + ReasoningID: reasoningID, + }) + } + delete(reasoningByItem, itemID) + return outEvents + } + + closeAllReasoning := func() []llm.StreamEvent { + if len(activeReasoningIDs) == 0 { + return nil + } + keys := make([]string, 0, len(activeReasoningIDs)) + for reasoningID := range activeReasoningIDs { + keys = append(keys, reasoningID) + } + outEvents := make([]llm.StreamEvent, 0, len(keys)) + for _, reasoningID := range keys { + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventReasoningEnd, + ReasoningID: reasoningID, + }) + delete(activeReasoningIDs, reasoningID) + } + reasoningByItem = map[string]map[string]struct{}{} + return outEvents + } + + ensureReasoningStarted := func(itemID, reasoningID string) []llm.StreamEvent { + if _, ok := 
reasoningByItem[itemID]; !ok { + reasoningByItem[itemID] = map[string]struct{}{} + } + reasoningByItem[itemID][reasoningID] = struct{}{} + if _, ok := activeReasoningIDs[reasoningID]; ok { + return nil + } + activeReasoningIDs[reasoningID] = struct{}{} + return []llm.StreamEvent{{ + Type: llm.StreamEventReasoningStart, + ReasoningID: reasoningID, + }} + } + + emitAgentSegments := func(itemID string, segments []parsedSegment) []llm.StreamEvent { + state := textStates[itemID] + if state == nil { + return nil + } + outEvents := make([]llm.StreamEvent, 0, len(segments)*3) + for _, segment := range segments { + switch segment.Kind { + case "text": + if segment.Text == "" { + continue + } + if !state.TextStarted { + state.TextStarted = true + outEvents = append(outEvents, llm.StreamEvent{Type: llm.StreamEventTextStart, TextID: itemID}) + } + outEvents = append(outEvents, llm.StreamEvent{Type: llm.StreamEventTextDelta, TextID: itemID, Delta: segment.Text}) + case "tool_call": + if segment.ToolCall == nil { + continue + } + emittedToolCalls = true + if state.TextStarted { + state.TextStarted = false + outEvents = append(outEvents, llm.StreamEvent{Type: llm.StreamEventTextEnd, TextID: itemID}) + } + call := *segment.ToolCall + outEvents = append(outEvents, + llm.StreamEvent{Type: llm.StreamEventToolCallStart, ToolCall: &llm.ToolCallData{ID: call.ID, Name: call.Name, Type: firstNonEmpty(call.Type, "function")}}, + llm.StreamEvent{Type: llm.StreamEventToolCallDelta, ToolCall: &llm.ToolCallData{ID: call.ID, Name: call.Name, Type: firstNonEmpty(call.Type, "function"), Arguments: call.Arguments}}, + llm.StreamEvent{Type: llm.StreamEventToolCallEnd, ToolCall: &llm.ToolCallData{ID: call.ID, Name: call.Name, Type: firstNonEmpty(call.Type, "function"), Arguments: call.Arguments}}, + ) + } + } + return outEvents + } + + flushAgentState := func(itemID string) []llm.StreamEvent { + state := textStates[itemID] + if state == nil { + return nil + } + outEvents := emitAgentSegments(itemID, 
state.Parser.flush()) + if state.TextStarted { + state.TextStarted = false + outEvents = append(outEvents, llm.StreamEvent{Type: llm.StreamEventTextEnd, TextID: itemID}) + } + delete(textStates, itemID) + return outEvents + } + + for rawEvent := range events { + notification, ok := normalizeNotification(rawEvent) + if !ok { + continue + } + outEvents := make([]llm.StreamEvent, 0, 6) + + switch notification.Method { + case "turn/started": + turn := asMap(notification.Params["turn"]) + if turnID := firstNonEmpty(asString(turn["id"]), asString(notification.Params["turnId"])); turnID != "" { + streamID = turnID + } + emittedToolCalls = false + if reroutedModel := asString(notification.Params["model"]); reroutedModel != "" { + model = reroutedModel + } + if !streamStarted { + streamStarted = true + outEvents = append(outEvents, llm.StreamEvent{Type: llm.StreamEventStreamStart, ID: streamID, Model: model}) + } + + case "item/agentMessage/delta": + itemID := asString(notification.Params["itemId"]) + if itemID == "" { + break + } + if _, ok := textStates[itemID]; !ok { + textStates[itemID] = &textStreamState{Parser: &toolProtocolStreamParser{}} + } + delta := asString(notification.Params["delta"]) + outEvents = append(outEvents, emitAgentSegments(itemID, textStates[itemID].Parser.feed(delta))...) 
+ + case "item/reasoning/summaryPartAdded": + itemID := asString(notification.Params["itemId"]) + summaryIndex := asInt(notification.Params["summaryIndex"], -1) + if itemID == "" || summaryIndex < 0 { + break + } + nextReasoningID := fmtReasoningID(itemID, "summary", summaryIndex) + if existing := reasoningByItem[itemID]; len(existing) > 0 { + for reasoningID := range existing { + if !strings.HasPrefix(reasoningID, itemID+":summary:") || reasoningID == nextReasoningID { + continue + } + if _, ok := activeReasoningIDs[reasoningID]; ok { + delete(activeReasoningIDs, reasoningID) + outEvents = append(outEvents, llm.StreamEvent{Type: llm.StreamEventReasoningEnd, ReasoningID: reasoningID}) + } + } + } + outEvents = append(outEvents, ensureReasoningStarted(itemID, nextReasoningID)...) + + case "item/reasoning/summaryTextDelta": + itemID := asString(notification.Params["itemId"]) + if itemID == "" { + break + } + reasoningID := fmtReasoningID(itemID, "summary", asInt(notification.Params["summaryIndex"], 0)) + outEvents = append(outEvents, ensureReasoningStarted(itemID, reasoningID)...) + for _, segment := range splitReasoningSegments(asString(notification.Params["delta"])) { + if segment.Text == "" { + continue + } + var redacted *bool + if segment.Redacted { + v := true + redacted = &v + } + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventReasoningDelta, + ReasoningDelta: segment.Text, + ReasoningID: reasoningID, + Redacted: redacted, + }) + } + + case "item/reasoning/textDelta": + itemID := asString(notification.Params["itemId"]) + if itemID == "" { + break + } + reasoningID := fmtReasoningID(itemID, "content", asInt(notification.Params["contentIndex"], 0)) + outEvents = append(outEvents, ensureReasoningStarted(itemID, reasoningID)...) 
+ for _, segment := range splitReasoningSegments(asString(notification.Params["delta"])) { + if segment.Text == "" { + continue + } + var redacted *bool + if segment.Redacted { + v := true + redacted = &v + } + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventReasoningDelta, + ReasoningDelta: segment.Text, + ReasoningID: reasoningID, + Redacted: redacted, + }) + } + + case "item/completed": + item := asMap(notification.Params["item"]) + if item == nil { + break + } + itemID := asString(item["id"]) + itemType := asString(item["type"]) + if itemID == "" { + break + } + if itemType == "agentMessage" { + outEvents = append(outEvents, flushAgentState(itemID)...) + break + } + if itemType == "reasoning" { + outEvents = append(outEvents, closeReasoningForItem(itemID)...) + } + + case "thread/tokenUsage/updated": + latestUsage = usageFromTokenUsage(asMap(notification.Params["tokenUsage"])) + + case "error": + errorData := asMap(notification.Params["error"]) + message := firstNonEmpty(asString(errorData["message"]), "Unknown stream error") + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventError, + Err: llm.NewStreamError("codex-app-server", message), + Raw: notification.Params, + }) + + case "turn/completed": + turn := asMap(notification.Params["turn"]) + if turnID := asString(turn["id"]); turnID != "" { + streamID = turnID + } + for itemID := range textStates { + outEvents = append(outEvents, flushAgentState(itemID)...) + } + outEvents = append(outEvents, closeAllReasoning()...) 
+ status := asString(turn["status"]) + if status == "failed" { + turnError := asMap(turn["error"]) + message := firstNonEmpty(asString(turnError["message"]), "Turn failed") + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventError, + Err: llm.NewStreamError("codex-app-server", message), + Raw: notification.Params, + }) + } + if turnUsage := usageFromTokenUsage(asMap(turn["tokenUsage"])); turnUsage != nil { + latestUsage = turnUsage + } else if turnUsage := usageFromTokenUsage(asMap(turn["token_usage"])); turnUsage != nil { + latestUsage = turnUsage + } + finish := llm.FinishReason{Reason: mapFinishReason(status, emittedToolCalls), Raw: status} + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventFinish, + FinishReason: &finish, + Usage: latestUsage, + Raw: notification.Params, + }) + + default: + outEvents = append(outEvents, llm.StreamEvent{ + Type: llm.StreamEventProviderEvent, + EventType: notification.Method, + Raw: notification.Params, + }) + } + + if !streamStarted && notification.Method != "turn/started" { + hasTranslated := false + for _, event := range outEvents { + if event.Type != llm.StreamEventProviderEvent { + hasTranslated = true + break + } + } + if hasTranslated { + streamStarted = true + start := llm.StreamEvent{Type: llm.StreamEventStreamStart, ID: streamID, Model: model} + outEvents = append([]llm.StreamEvent{start}, outEvents...) 
+ } + } + + for _, event := range outEvents { + out <- event + } + } + }() + return out +} + +func usageFromTokenUsage(tokenUsage map[string]any) *llm.Usage { + if tokenUsage == nil { + return nil + } + last := asMap(tokenUsage["last"]) + if last == nil { + last = tokenUsage + } + usage := llm.Usage{ + InputTokens: asInt(last["inputTokens"], asInt(last["input_tokens"], 0)), + OutputTokens: asInt(last["outputTokens"], asInt(last["output_tokens"], 0)), + TotalTokens: asInt(last["totalTokens"], asInt(last["total_tokens"], 0)), + Raw: tokenUsage, + } + if usage.TotalTokens <= 0 { + usage.TotalTokens = usage.InputTokens + usage.OutputTokens + } + reasoningTokens := asInt(last["reasoningOutputTokens"], asInt(last["reasoning_tokens"], -1)) + if reasoningTokens >= 0 { + usage.ReasoningTokens = intPtr(reasoningTokens) + } + cacheReadTokens := asInt(last["cachedInputTokens"], asInt(last["cache_read_input_tokens"], -1)) + if cacheReadTokens >= 0 { + usage.CacheReadTokens = intPtr(cacheReadTokens) + } + cacheWriteTokens := asInt(last["cacheWriteTokens"], asInt(last["cache_write_input_tokens"], -1)) + if cacheWriteTokens >= 0 { + usage.CacheWriteTokens = intPtr(cacheWriteTokens) + } + return &usage +} + +func normalizeNotification(rawEvent map[string]any) (normalizedNotification, bool) { + if method := strings.TrimSpace(asString(rawEvent["method"])); method != "" { + params := asMap(rawEvent["params"]) + if params == nil { + params = map[string]any{} + } + return normalizedNotification{Method: method, Params: params}, true + } + + if event := strings.TrimSpace(asString(rawEvent["event"])); event != "" { + params := map[string]any{} + switch data := rawEvent["data"].(type) { + case string: + if parsed, ok := parseJSONRecord(data); ok { + params = parsed + } + default: + if rec := asMap(data); rec != nil { + params = rec + } + } + return normalizedNotification{Method: event, Params: params}, true + } + + typ := strings.TrimSpace(asString(rawEvent["type"])) + if 
strings.Contains(typ, "/") { + params := deepCopyMap(rawEvent) + delete(params, "type") + return normalizedNotification{Method: typ, Params: params}, true + } + + return normalizedNotification{}, false +} + +func fmtReasoningID(itemID, segment string, idx int) string { + return itemID + ":" + segment + ":" + strconv.Itoa(idx) +} diff --git a/internal/llm/providers/codexappserver/stream_translator_test.go b/internal/llm/providers/codexappserver/stream_translator_test.go new file mode 100644 index 00000000..11be1152 --- /dev/null +++ b/internal/llm/providers/codexappserver/stream_translator_test.go @@ -0,0 +1,150 @@ +package codexappserver + +import ( + "testing" + + "github.com/danshapiro/kilroy/internal/llm" +) + +func streamNotification(method string, params map[string]any) map[string]any { + return map[string]any{"method": method, "params": params} +} + +func collectStreamEvents(events []map[string]any) []llm.StreamEvent { + in := make(chan map[string]any, len(events)) + for _, event := range events { + in <- event + } + close(in) + + out := make([]llm.StreamEvent, 0, 16) + for event := range translateStream(in) { + out = append(out, event) + } + return out +} + +func TestTranslateStream_TextAndFinishUsage(t *testing.T) { + events := collectStreamEvents([]map[string]any{ + streamNotification("turn/started", map[string]any{"turn": map[string]any{"id": "turn_1", "status": "inProgress", "items": []any{}}}), + streamNotification("item/agentMessage/delta", map[string]any{"itemId": "agent_1", "delta": "Hello"}), + streamNotification("item/agentMessage/delta", map[string]any{"itemId": "agent_1", "delta": " world"}), + streamNotification("item/completed", map[string]any{"item": map[string]any{"id": "agent_1", "type": "agentMessage", "text": "Hello world"}}), + streamNotification("thread/tokenUsage/updated", map[string]any{"tokenUsage": map[string]any{"last": map[string]any{"inputTokens": 12, "outputTokens": 8, "totalTokens": 20, "reasoningOutputTokens": 2, 
"cachedInputTokens": 3}}}), + streamNotification("turn/completed", map[string]any{"turn": map[string]any{"id": "turn_1", "status": "completed", "items": []any{}}}), + }) + + if len(events) == 0 { + t.Fatalf("expected events") + } + if events[0].Type != llm.StreamEventStreamStart { + t.Fatalf("first event type: got %q want %q", events[0].Type, llm.StreamEventStreamStart) + } + + text := "" + var finish *llm.StreamEvent + for idx := range events { + event := events[idx] + if event.Type == llm.StreamEventTextDelta { + text += event.Delta + } + if event.Type == llm.StreamEventFinish { + finish = &events[idx] + } + } + if text != "Hello world" { + t.Fatalf("text delta mismatch: got %q want %q", text, "Hello world") + } + if finish == nil { + t.Fatalf("expected finish event") + } + if finish.FinishReason == nil || finish.FinishReason.Reason != llm.FinishReasonStop { + t.Fatalf("finish reason mismatch: %+v", finish.FinishReason) + } + if finish.Usage == nil || finish.Usage.TotalTokens != 20 { + t.Fatalf("finish usage mismatch: %+v", finish.Usage) + } +} + +func TestTranslateStream_ParsesToolCallProtocol(t *testing.T) { + events := collectStreamEvents([]map[string]any{ + streamNotification("turn/started", map[string]any{"turn": map[string]any{"id": "turn_2", "status": "inProgress", "items": []any{}}}), + streamNotification("item/agentMessage/delta", map[string]any{"itemId": "agent_2", "delta": "Lead [[TOOL_CALL]]{\"id\":\"call_abc\",\"name\":\"lookup\",\"arguments\":{\"x\":1}}[[/TOOL_CALL]] tail"}), + streamNotification("item/completed", map[string]any{"item": map[string]any{"id": "agent_2", "type": "agentMessage", "text": ""}}), + streamNotification("turn/completed", map[string]any{"turn": map[string]any{"id": "turn_2", "status": "completed", "items": []any{}}}), + }) + + seenStart := false + seenDelta := false + seenEnd := false + finish := llm.FinishReason{} + for _, event := range events { + switch event.Type { + case llm.StreamEventToolCallStart: + seenStart = true + 
if event.ToolCall == nil || event.ToolCall.ID != "call_abc" || event.ToolCall.Name != "lookup" { + t.Fatalf("tool call start mismatch: %+v", event.ToolCall) + } + case llm.StreamEventToolCallDelta: + seenDelta = true + if event.ToolCall == nil || string(event.ToolCall.Arguments) != `{"x":1}` { + t.Fatalf("tool call delta mismatch: %+v", event.ToolCall) + } + case llm.StreamEventToolCallEnd: + seenEnd = true + case llm.StreamEventFinish: + if event.FinishReason != nil { + finish = *event.FinishReason + } + } + } + if !seenStart || !seenDelta || !seenEnd { + t.Fatalf("tool call events missing: start=%t delta=%t end=%t", seenStart, seenDelta, seenEnd) + } + if finish.Reason != llm.FinishReasonToolCalls { + t.Fatalf("finish reason mismatch: got %q want %q", finish.Reason, llm.FinishReasonToolCalls) + } +} + +func TestTranslateStream_FailedTurnEmitsErrorAndFinish(t *testing.T) { + events := collectStreamEvents([]map[string]any{ + streamNotification("turn/started", map[string]any{"turn": map[string]any{"id": "turn_3", "status": "inProgress", "items": []any{}}}), + streamNotification("error", map[string]any{"error": map[string]any{"message": "upstream overloaded"}}), + streamNotification("turn/completed", map[string]any{"turn": map[string]any{"id": "turn_3", "status": "failed", "error": map[string]any{"message": "turn failed hard"}, "items": []any{}}}), + }) + + errorCount := 0 + finishCount := 0 + for _, event := range events { + if event.Type == llm.StreamEventError { + errorCount++ + } + if event.Type == llm.StreamEventFinish { + finishCount++ + if event.FinishReason == nil || event.FinishReason.Reason != llm.FinishReasonError { + t.Fatalf("finish reason mismatch: %+v", event.FinishReason) + } + } + } + if errorCount != 2 { + t.Fatalf("error count: got %d want 2", errorCount) + } + if finishCount != 1 { + t.Fatalf("finish count: got %d want 1", finishCount) + } +} + +func TestTranslateStream_ProviderEventPassthrough(t *testing.T) { + events := 
collectStreamEvents([]map[string]any{ + streamNotification("model/rerouted", map[string]any{"fromModel": "codex-mini", "toModel": "codex-pro"}), + }) + if len(events) != 1 { + t.Fatalf("event count: got %d want 1", len(events)) + } + if events[0].Type != llm.StreamEventProviderEvent { + t.Fatalf("event type: got %q want %q", events[0].Type, llm.StreamEventProviderEvent) + } + if events[0].EventType != "model/rerouted" { + t.Fatalf("event_type: got %q want %q", events[0].EventType, "model/rerouted") + } +} diff --git a/internal/llm/providers/codexappserver/translator_utils.go b/internal/llm/providers/codexappserver/translator_utils.go new file mode 100644 index 00000000..b2023e2d --- /dev/null +++ b/internal/llm/providers/codexappserver/translator_utils.go @@ -0,0 +1,78 @@ +package codexappserver + +import "regexp" + +type reasoningSegment struct { + Text string + Redacted bool +} + +var redactedReasoningRE = regexp.MustCompile(`(?is)([\s\S]*?)|\[\[REDACTED_REASONING\]\]([\s\S]*?)\[\[/REDACTED_REASONING\]\]`) + +func splitReasoningSegments(text string) []reasoningSegment { + if text == "" { + return nil + } + segments := make([]reasoningSegment, 0, 4) + cursor := 0 + matches := redactedReasoningRE.FindAllStringSubmatchIndex(text, -1) + for _, m := range matches { + if len(m) < 6 { + continue + } + start := m[0] + end := m[1] + if start > cursor { + visible := text[cursor:start] + if visible != "" { + segments = append(segments, reasoningSegment{Text: visible}) + } + } + redacted := "" + if m[2] >= 0 && m[3] >= 0 { + redacted = text[m[2]:m[3]] + } else if m[4] >= 0 && m[5] >= 0 { + redacted = text[m[4]:m[5]] + } + if redacted != "" { + segments = append(segments, reasoningSegment{Text: redacted, Redacted: true}) + } + cursor = end + } + if cursor < len(text) { + tail := text[cursor:] + if tail != "" { + prefixRe := regexp.MustCompile(`(?is)^(?:\[REDACTED\]\s*|REDACTED:\s*)([\s\S]+)$`) + if sm := prefixRe.FindStringSubmatch(tail); len(sm) == 2 && sm[1] != "" { + 
segments = append(segments, reasoningSegment{Text: sm[1], Redacted: true}) + } else { + segments = append(segments, reasoningSegment{Text: tail}) + } + } + } + return segments +} + +func mapFinishReason(rawStatus string, hasToolCalls bool) string { + if hasToolCalls { + return llmFinishReasonToolCalls + } + switch rawStatus { + case "completed": + return llmFinishReasonStop + case "interrupted": + return llmFinishReasonLength + case "failed": + return llmFinishReasonError + default: + return llmFinishReasonOther + } +} + +const ( + llmFinishReasonStop = "stop" + llmFinishReasonLength = "length" + llmFinishReasonToolCalls = "tool_calls" + llmFinishReasonError = "error" + llmFinishReasonOther = "other" +) diff --git a/internal/llm/providers/codexappserver/transport.go b/internal/llm/providers/codexappserver/transport.go new file mode 100644 index 00000000..8731e888 --- /dev/null +++ b/internal/llm/providers/codexappserver/transport.go @@ -0,0 +1,1070 @@ +package codexappserver + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/danshapiro/kilroy/internal/llm" +) + +const ( + providerName = "codex-app-server" + defaultCommand = "codex" + defaultConnectTimeout = 15 * time.Second + // No provider-imposed request cap by default; execution deadlines should come + // from caller context (for example stage/runtime policy timeouts). 
+ defaultRequestTimeout = 0 + defaultShutdownTimeout = 5 * time.Second + defaultInterruptTimeout = 2 * time.Second + defaultStderrTailLimit = 16 * 1024 + maxJSONRPCLineSize = 16 * 1024 * 1024 +) + +var defaultCommandArgs = []string{"app-server", "--listen", "stdio://"} + +type TransportOptions struct { + Command string + Args []string + CWD string + Env map[string]string + InitializeParams map[string]any + ConnectTimeout time.Duration + RequestTimeout time.Duration + ShutdownTimeout time.Duration + StderrTailLimit int +} + +type NotificationStream struct { + Notifications <-chan map[string]any + Err <-chan error + closeFn func() +} + +func (s *NotificationStream) Close() { + if s == nil || s.closeFn == nil { + return + } + s.closeFn() +} + +type pendingRequest struct { + method string + respCh chan pendingResult +} + +type pendingResult struct { + result any + err error +} + +type stdioTransport struct { + opts TransportOptions + + mu sync.Mutex + writeMu sync.Mutex + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser + procDone chan struct{} + + closed bool + shuttingDown bool + initialized bool + initWait chan struct{} + initErr error + + nextID int64 + + pending map[string]*pendingRequest + listeners map[int]func(map[string]any) + nextLID int + + stderrTail string +} + +func NewTransport(opts TransportOptions) *stdioTransport { + opts.Command = strings.TrimSpace(opts.Command) + if opts.Command == "" { + opts.Command = defaultCommand + } + if len(opts.Args) == 0 { + opts.Args = append([]string{}, defaultCommandArgs...) 
+ } + if opts.ConnectTimeout <= 0 { + opts.ConnectTimeout = defaultConnectTimeout + } + if opts.RequestTimeout <= 0 { + opts.RequestTimeout = defaultRequestTimeout + } + if opts.ShutdownTimeout <= 0 { + opts.ShutdownTimeout = defaultShutdownTimeout + } + if opts.StderrTailLimit <= 0 { + opts.StderrTailLimit = defaultStderrTailLimit + } + if opts.InitializeParams == nil { + opts.InitializeParams = map[string]any{ + "clientInfo": map[string]any{ + "name": "unified_llm", + "title": "Unified LLM", + "version": "0.1.0", + }, + } + } + return &stdioTransport{ + opts: opts, + pending: map[string]*pendingRequest{}, + listeners: map[int]func(map[string]any){}, + } +} + +func (t *stdioTransport) Initialize(ctx context.Context) error { + return t.ensureInitialized(ctx) +} + +func (t *stdioTransport) Complete(ctx context.Context, payload map[string]any) (map[string]any, error) { + return t.runTurn(ctx, payload) +} + +func (t *stdioTransport) Stream(ctx context.Context, payload map[string]any) (*NotificationStream, error) { + events := make(chan map[string]any, 128) + errs := make(chan error, 1) + sctx, cancel := context.WithCancel(ctx) + + go func() { + defer close(events) + defer close(errs) + + if err := t.ensureInitialized(sctx); err != nil { + errs <- err + return + } + + turnTemplate, err := parseTurnStartPayload(payload) + if err != nil { + errs <- err + return + } + + requestCtx, requestCancel := contextWithRequestTimeout(sctx, t.opts.RequestTimeout) + defer requestCancel() + + threadResp, err := t.startThread(requestCtx, toThreadStartParams(turnTemplate)) + if err != nil { + errs <- err + return + } + thread := asMap(threadResp["thread"]) + threadID := asString(thread["id"]) + if threadID == "" { + errs <- llm.ErrorFromHTTPStatus(providerName, 400, "thread/start response missing thread.id", threadResp, nil) + return + } + + turnParams := deepCopyMap(turnTemplate) + turnParams["threadId"] = threadID + + var ( + stateMu sync.Mutex + turnID string + completed = make(chan 
struct{}, 1) + ) + + sendNotification := func(notification map[string]any) { + select { + case events <- deepCopyMap(notification): + case <-requestCtx.Done(): + } + } + + unsubscribe := t.subscribe(func(notification map[string]any) { + stateMu.Lock() + currentTurnID := turnID + stateMu.Unlock() + if !notificationBelongsToTurn(notification, threadID, currentTurnID) { + return + } + sendNotification(notification) + notificationTurnID := extractTurnID(notification) + if notificationTurnID != "" { + stateMu.Lock() + if turnID == "" { + turnID = notificationTurnID + } + currentTurnID = turnID + stateMu.Unlock() + } + if asString(notification["method"]) == "turn/completed" { + if currentTurnID == "" || notificationTurnID == "" || notificationTurnID == currentTurnID { + select { + case completed <- struct{}{}: + default: + } + } + } + }) + defer unsubscribe() + + turnResp, err := t.startTurn(requestCtx, turnParams) + if err != nil { + errs <- err + return + } + turn := asMap(turnResp["turn"]) + if tid := asString(turn["id"]); tid != "" { + stateMu.Lock() + turnID = tid + stateMu.Unlock() + } + if isTerminalTurnStatus(asString(turn["status"])) { + sendNotification(map[string]any{ + "method": "turn/completed", + "params": map[string]any{ + "threadId": threadID, + "turn": turn, + }, + }) + select { + case completed <- struct{}{}: + default: + } + } + + select { + case <-completed: + return + case <-requestCtx.Done(): + stateMu.Lock() + currentTurnID := turnID + stateMu.Unlock() + if currentTurnID != "" { + go t.interruptTurnBestEffort(threadID, currentTurnID) + } + errs <- llm.WrapContextError(providerName, requestCtx.Err()) + return + } + }() + + return &NotificationStream{Notifications: events, Err: errs, closeFn: cancel}, nil +} + +func (t *stdioTransport) ListModels(ctx context.Context, params map[string]any) (modelListResponse, error) { + if err := t.ensureInitialized(ctx); err != nil { + return modelListResponse{}, err + } + if params == nil { + params = 
map[string]any{} + } + result, err := t.sendRequest(ctx, "model/list", params, t.opts.RequestTimeout) + if err != nil { + return modelListResponse{}, err + } + b, err := json.Marshal(result) + if err != nil { + return modelListResponse{}, err + } + dec := json.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + var out modelListResponse + if err := dec.Decode(&out); err != nil { + return modelListResponse{}, err + } + if out.Data == nil { + out.Data = []modelEntry{} + } + return out, nil +} + +func (t *stdioTransport) Close() error { + t.mu.Lock() + if t.closed { + t.mu.Unlock() + return nil + } + t.closed = true + t.mu.Unlock() + + t.rejectAllPending(llm.NewNetworkError(providerName, "Codex transport closed")) + return t.shutdownProcess() +} + +func (t *stdioTransport) runTurn(ctx context.Context, payload map[string]any) (map[string]any, error) { + if err := t.ensureInitialized(ctx); err != nil { + return nil, err + } + + turnTemplate, err := parseTurnStartPayload(payload) + if err != nil { + return nil, err + } + + requestCtx, requestCancel := contextWithRequestTimeout(ctx, t.opts.RequestTimeout) + defer requestCancel() + + threadResp, err := t.startThread(requestCtx, toThreadStartParams(turnTemplate)) + if err != nil { + return nil, err + } + thread := asMap(threadResp["thread"]) + threadID := asString(thread["id"]) + if threadID == "" { + return nil, llm.ErrorFromHTTPStatus(providerName, 400, "thread/start response missing thread.id", threadResp, nil) + } + + turnParams := deepCopyMap(turnTemplate) + turnParams["threadId"] = threadID + + var ( + stateMu sync.Mutex + notifications []map[string]any + turnID string + completed = make(chan struct{}, 1) + ) + + unsubscribe := t.subscribe(func(notification map[string]any) { + stateMu.Lock() + currentTurnID := turnID + stateMu.Unlock() + if !notificationBelongsToTurn(notification, threadID, currentTurnID) { + return + } + stateMu.Lock() + notifications = append(notifications, deepCopyMap(notification)) + 
notificationTurnID := extractTurnID(notification) + if turnID == "" && notificationTurnID != "" { + turnID = notificationTurnID + } + currentTurnID = turnID + stateMu.Unlock() + if asString(notification["method"]) == "turn/completed" { + if currentTurnID == "" || notificationTurnID == "" || notificationTurnID == currentTurnID { + select { + case completed <- struct{}{}: + default: + } + } + } + }) + defer unsubscribe() + + turnResp, err := t.startTurn(requestCtx, turnParams) + if err != nil { + return nil, err + } + turn := asMap(turnResp["turn"]) + if tid := asString(turn["id"]); tid != "" { + stateMu.Lock() + turnID = tid + stateMu.Unlock() + } + if isTerminalTurnStatus(asString(turn["status"])) { + select { + case completed <- struct{}{}: + default: + } + } + + select { + case <-completed: + case <-requestCtx.Done(): + stateMu.Lock() + currentTurnID := turnID + stateMu.Unlock() + if currentTurnID != "" { + go t.interruptTurnBestEffort(threadID, currentTurnID) + } + return nil, llm.WrapContextError(providerName, requestCtx.Err()) + } + + stateMu.Lock() + capturedNotifications := append([]map[string]any{}, notifications...) 
+ capturedTurnID := turnID + stateMu.Unlock() + + completedTurn := findCompletedTurn(capturedNotifications, capturedTurnID) + if completedTurn == nil { + completedTurn = turn + } + result := map[string]any{ + "thread": thread, + "turn": completedTurn, + "threadId": threadID, + "turnId": firstNonEmpty(capturedTurnID, asString(completedTurn["id"])), + "notifications": capturedNotifications, + "threadResponse": threadResp, + "turnResponse": turnResp, + } + return result, nil +} + +func (t *stdioTransport) startThread(ctx context.Context, params map[string]any) (map[string]any, error) { + return t.sendRequest(ctx, "thread/start", params, t.opts.RequestTimeout) +} + +func (t *stdioTransport) startTurn(ctx context.Context, params map[string]any) (map[string]any, error) { + return t.sendRequest(ctx, "turn/start", params, t.opts.RequestTimeout) +} + +func (t *stdioTransport) interruptTurn(ctx context.Context, params map[string]any) error { + _, err := t.sendRequest(ctx, "turn/interrupt", params, t.opts.RequestTimeout) + return err +} + +func (t *stdioTransport) interruptTurnBestEffort(threadID, turnID string) { + if strings.TrimSpace(threadID) == "" || strings.TrimSpace(turnID) == "" { + return + } + timeout := t.interruptTimeout() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + _ = t.interruptTurn(ctx, map[string]any{"threadId": threadID, "turnId": turnID}) +} + +func (t *stdioTransport) interruptTimeout() time.Duration { + timeout := t.opts.RequestTimeout + if timeout <= 0 { + timeout = defaultInterruptTimeout + } + if t.opts.ShutdownTimeout > 0 && t.opts.ShutdownTimeout < timeout { + timeout = t.opts.ShutdownTimeout + } + return timeout +} + +func (t *stdioTransport) ensureInitialized(ctx context.Context) error { + t.mu.Lock() + if t.closed { + t.mu.Unlock() + return llm.NewNetworkError(providerName, "Codex transport is closed") + } + if t.initialized { + t.mu.Unlock() + return nil + } + if t.initWait != nil { + wait := t.initWait 
+ t.mu.Unlock() + select { + case <-wait: + t.mu.Lock() + err := t.initErr + t.mu.Unlock() + return err + case <-ctx.Done(): + return llm.WrapContextError(providerName, ctx.Err()) + } + } + + wait := make(chan struct{}) + t.initWait = wait + t.mu.Unlock() + + err := t.startAndInitialize(ctx) + + t.mu.Lock() + if err == nil { + t.initialized = true + } + t.initErr = err + close(wait) + t.initWait = nil + t.mu.Unlock() + return err +} + +func (t *stdioTransport) startAndInitialize(ctx context.Context) error { + if err := t.spawnProcess(); err != nil { + return err + } + connCtx, cancel := contextWithRequestTimeout(ctx, t.opts.ConnectTimeout) + defer cancel() + if _, err := t.sendRequest(connCtx, "initialize", t.opts.InitializeParams, t.opts.ConnectTimeout); err != nil { + _ = t.shutdownProcess() + return err + } + if err := t.sendNotification(connCtx, "initialized", nil); err != nil { + _ = t.shutdownProcess() + return err + } + return nil +} + +func (t *stdioTransport) spawnProcess() error { + t.mu.Lock() + if t.cmd != nil && processAlive(t.cmd) { + t.mu.Unlock() + return nil + } + if t.closed { + t.mu.Unlock() + return llm.NewNetworkError(providerName, "Codex transport is closed") + } + t.mu.Unlock() + + cmd := exec.Command(t.opts.Command, t.opts.Args...) 
+ if strings.TrimSpace(t.opts.CWD) != "" { + cmd.Dir = t.opts.CWD + } + if len(t.opts.Env) > 0 { + env := os.Environ() + for key, value := range t.opts.Env { + env = append(env, key+"="+value) + } + cmd.Env = env + } + + stdin, err := cmd.StdinPipe() + if err != nil { + return llm.NewNetworkError(providerName, fmt.Sprintf("failed to open stdin pipe: %v", err)) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + _ = stdin.Close() + return llm.NewNetworkError(providerName, fmt.Sprintf("failed to open stdout pipe: %v", err)) + } + stderr, err := cmd.StderrPipe() + if err != nil { + _ = stdin.Close() + _ = stdout.Close() + return llm.NewNetworkError(providerName, fmt.Sprintf("failed to open stderr pipe: %v", err)) + } + + if err := cmd.Start(); err != nil { + _ = stdin.Close() + _ = stdout.Close() + _ = stderr.Close() + return llm.NewNetworkError(providerName, fmt.Sprintf("failed to spawn codex app-server: %v", err)) + } + + procDone := make(chan struct{}) + t.mu.Lock() + t.cmd = cmd + t.stdin = stdin + t.stdout = stdout + t.stderr = stderr + t.procDone = procDone + t.stderrTail = "" + t.initialized = false + t.mu.Unlock() + + go t.readStdout(cmd, stdout) + go t.readStderr(stderr) + go t.waitForExit(cmd, procDone) + + return nil +} + +func (t *stdioTransport) readStdout(cmd *exec.Cmd, stdout io.Reader) { + scanner := bufio.NewScanner(stdout) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, maxJSONRPCLineSize) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + dec := json.NewDecoder(strings.NewReader(line)) + dec.UseNumber() + var message map[string]any + if err := dec.Decode(&message); err != nil { + continue + } + t.handleIncomingMessage(message) + } + if err := scanner.Err(); err != nil { + t.handleUnexpectedProcessTermination(llm.NewNetworkError(providerName, fmt.Sprintf("Codex stdout read error: %v", err))) + } + _ = cmd +} + +func (t *stdioTransport) readStderr(stderr io.Reader) { + buf := 
make([]byte, 4096) + for { + n, err := stderr.Read(buf) + if n > 0 { + t.appendStderrTail(string(buf[:n])) + } + if err != nil { + return + } + } +} + +func (t *stdioTransport) appendStderrTail(chunk string) { + if chunk == "" { + return + } + t.mu.Lock() + t.stderrTail += chunk + if len(t.stderrTail) > t.opts.StderrTailLimit { + t.stderrTail = t.stderrTail[len(t.stderrTail)-t.opts.StderrTailLimit:] + } + t.mu.Unlock() +} + +func (t *stdioTransport) waitForExit(cmd *exec.Cmd, done chan struct{}) { + err := cmd.Wait() + t.mu.Lock() + shuttingDown := t.shuttingDown + closed := t.closed + stderrTail := strings.TrimSpace(t.stderrTail) + if t.cmd == cmd { + t.cmd = nil + t.stdin = nil + t.stdout = nil + t.stderr = nil + t.procDone = nil + t.initialized = false + } + t.shuttingDown = false + t.mu.Unlock() + close(done) + + if shuttingDown || closed { + return + } + message := "Codex app-server exited unexpectedly" + if err != nil { + message = fmt.Sprintf("Codex app-server exited unexpectedly: %v", err) + } + if stderrTail != "" { + message = message + ". 
stderr: " + stderrTail
	}
	t.handleUnexpectedProcessTermination(llm.NewNetworkError(providerName, message))
}

// handleUnexpectedProcessTermination fails every in-flight request when the
// child process dies outside of an orderly shutdown.
func (t *stdioTransport) handleUnexpectedProcessTermination(err error) {
	t.rejectAllPending(err)
}

// handleIncomingMessage routes one decoded JSON-RPC message: responses are
// matched to pending requests by id, server-initiated requests are answered
// off the read loop, and bare notifications are fanned out to subscribers.
func (t *stdioTransport) handleIncomingMessage(message map[string]any) {
	id, hasID := message["id"]
	_, hasResult := message["result"]
	rpcErr := asMap(message["error"])

	switch {
	case hasID && hasResult:
		t.resolvePendingRequest(id, pendingResult{result: message["result"]})
		return
	case hasID && rpcErr != nil:
		t.resolvePendingRequest(id, pendingResult{err: t.toRPCError(asString(message["method"]), rpcErr)})
		return
	}

	method := strings.TrimSpace(asString(message["method"]))
	if method == "" {
		return
	}
	if hasID {
		// Server->client request; answer asynchronously so the read loop
		// never blocks on handler work.
		go t.handleServerRequest(id, method, message["params"])
		return
	}
	notification := map[string]any{"method": method}
	if params := asMap(message["params"]); params != nil {
		notification["params"] = params
	}
	t.emitNotification(notification)
}

// emitNotification delivers a notification to every current subscriber,
// shielding the transport from panicking listeners.
func (t *stdioTransport) emitNotification(notification map[string]any) {
	t.mu.Lock()
	snapshot := make([]func(map[string]any), 0, len(t.listeners))
	for _, listener := range t.listeners {
		snapshot = append(snapshot, listener)
	}
	t.mu.Unlock()

	deliver := func(listener func(map[string]any)) {
		defer func() { _ = recover() }()
		listener(notification)
	}
	for _, listener := range snapshot {
		deliver(listener)
	}
}

// subscribe registers a notification listener and returns its removal func.
func (t *stdioTransport) subscribe(listener func(map[string]any)) func() {
	t.mu.Lock()
	id := t.nextLID
	t.nextLID++
	t.listeners[id] = listener
	t.mu.Unlock()

	return func() {
		t.mu.Lock()
		defer t.mu.Unlock()
		delete(t.listeners, id)
	}
}

// resolvePendingRequest completes the request registered under id, if any;
// late or duplicate responses are silently dropped.
func (t *stdioTransport) resolvePendingRequest(id any, result pendingResult) {
	key := rpcIDKey(id)
	t.mu.Lock()
	waiter := t.pending[key]
	if waiter != nil {
		delete(t.pending, key)
	}
	t.mu.Unlock()
	if waiter == nil {
		return
	}
	// Buffered channel; the default arm guards against a double resolve.
	select {
	case waiter.respCh <- result:
	default:
	}
}

// rejectAllPending fails every outstanding request with err and resets the
// pending table so future requests start clean.
func (t *stdioTransport) rejectAllPending(err error) {
	t.mu.Lock()
	drained := t.pending
	t.pending = map[string]*pendingRequest{}
	t.mu.Unlock()

	for _, waiter := range drained {
		if waiter == nil {
			continue
		}
		select {
		case waiter.respCh <- pendingResult{err: err}:
		default:
		}
	}
}

// sendRequest issues a JSON-RPC request and blocks until its response, the
// request timeout, or context cancellation — whichever comes first. The
// result is always normalized to a map.
func (t *stdioTransport) sendRequest(ctx context.Context, method string, params any, timeout time.Duration) (map[string]any, error) {
	requestCtx, cancel := contextWithRequestTimeout(ctx, timeout)
	defer cancel()
	if err := requestCtx.Err(); err != nil {
		return nil, llm.WrapContextError(providerName, err)
	}

	id := atomic.AddInt64(&t.nextID, 1)
	key := rpcIDKey(id)
	resultCh := make(chan pendingResult, 1)

	t.mu.Lock()
	if t.closed {
		t.mu.Unlock()
		return nil, llm.NewNetworkError(providerName, "Codex transport is closed")
	}
	t.pending[key] = &pendingRequest{method: method, respCh: resultCh}
	t.mu.Unlock()

	// dropPending removes our registration on any failure/timeout path.
	dropPending := func() {
		t.mu.Lock()
		delete(t.pending, key)
		t.mu.Unlock()
	}

	request := map[string]any{"id": id, "method": method}
	if params != nil {
		request["params"] = params
	}
	if err := t.writeJSONLine(request); err != nil {
		dropPending()
		return nil, err
	}

	select {
	case outcome := <-resultCh:
		if outcome.err != nil {
			return nil, outcome.err
		}
		if record := asMap(outcome.result); record != nil {
			return record, nil
		}
		if outcome.result == nil {
			return map[string]any{}, nil
		}
		// Non-map results are round-tripped through JSON into a map.
		encoded, err := json.Marshal(outcome.result)
		if err != nil {
			return nil, llm.NewNetworkError(providerName, fmt.Sprintf("invalid RPC result for %s: %v", method, err))
		}
		return decodeJSONToMap(encoded), nil
	case <-requestCtx.Done():
		dropPending()
		return nil, llm.WrapContextError(providerName, requestCtx.Err())
	}
}

// sendNotification writes a fire-and-forget JSON-RPC notification.
func (t *stdioTransport) sendNotification(ctx context.Context, method string, params any) error {
	if err := ctx.Err(); err != nil {
		return llm.WrapContextError(providerName, err)
	}
	message := map[string]any{"method": method}
	if params !=
nil { + message["params"] = params + } + return t.writeJSONLine(message) +} + +func (t *stdioTransport) writeJSONLine(message map[string]any) error { + b, err := json.Marshal(message) + if err != nil { + return llm.NewNetworkError(providerName, fmt.Sprintf("failed to marshal RPC message: %v", err)) + } + line := append(b, '\n') + + t.mu.Lock() + stdin := t.stdin + cmd := t.cmd + t.mu.Unlock() + if stdin == nil || cmd == nil || !processAlive(cmd) { + return llm.NewNetworkError(providerName, "Codex app-server stdin is not writable") + } + + t.writeMu.Lock() + defer t.writeMu.Unlock() + if _, err := stdin.Write(line); err != nil { + return llm.NewNetworkError(providerName, fmt.Sprintf("failed to write to codex app-server: %v", err)) + } + return nil +} + +func (t *stdioTransport) toRPCError(method string, errObj map[string]any) error { + code := asInt(errObj["code"], 0) + message := firstNonEmpty(asString(errObj["message"]), "RPC error") + wrapped := fmt.Sprintf("Codex RPC %s failed (%d): %s", method, code, message) + switch code { + case -32700, -32600, -32601, -32602: + return llm.ErrorFromHTTPStatus(providerName, 400, wrapped, errObj["data"], nil) + default: + return llm.ErrorFromHTTPStatus(providerName, 500, wrapped, errObj["data"], nil) + } +} + +func (t *stdioTransport) handleServerRequest(id any, method string, params any) { + sendSuccess := func(result any) { + _ = t.writeJSONLine(map[string]any{"id": id, "result": result}) + } + sendError := func(code int, message string, data any) { + errObj := map[string]any{"code": code, "message": message} + if data != nil { + errObj["data"] = data + } + _ = t.writeJSONLine(map[string]any{"id": id, "error": errObj}) + } + + switch method { + case "item/tool/call": + sendSuccess(map[string]any{"contentItems": []any{}, "success": false}) + case "item/tool/requestUserInput": + sendSuccess(buildDefaultUserInputResponse(params)) + case "item/commandExecution/requestApproval": + sendSuccess(map[string]any{"decision": 
"decline"}) + case "item/fileChange/requestApproval": + sendSuccess(map[string]any{"decision": "decline"}) + case "applyPatchApproval": + sendSuccess(map[string]any{"decision": "denied"}) + case "execCommandApproval": + sendSuccess(map[string]any{"decision": "denied"}) + case "account/chatgptAuthTokens/refresh": + sendError(-32001, "External ChatGPT auth token refresh is not configured", nil) + default: + sendError(-32601, "Method not found: "+method, nil) + } +} + +func buildDefaultUserInputResponse(params any) map[string]any { + answers := map[string]any{} + p := asMap(params) + if p == nil { + return map[string]any{"answers": answers} + } + for _, questionRaw := range asSlice(p["questions"]) { + question := asMap(questionRaw) + if question == nil { + continue + } + id := asString(question["id"]) + if id == "" { + continue + } + answers[id] = map[string]any{"answers": []any{}} + } + return map[string]any{"answers": answers} +} + +func (t *stdioTransport) shutdownProcess() error { + t.mu.Lock() + cmd := t.cmd + stdin := t.stdin + done := t.procDone + t.shuttingDown = true + t.mu.Unlock() + + if cmd == nil { + return nil + } + if stdin != nil { + _ = stdin.Close() + } + if cmd.Process != nil { + _ = cmd.Process.Signal(os.Interrupt) + } + + if done != nil { + select { + case <-done: + return nil + case <-time.After(t.opts.ShutdownTimeout): + } + } + + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + if done != nil { + select { + case <-done: + case <-time.After(time.Second): + } + } + return nil +} + +func parseTurnStartPayload(payload map[string]any) (map[string]any, error) { + if payload == nil { + return nil, llm.ErrorFromHTTPStatus(providerName, 400, "codex-app-server turn payload must be an object", nil, nil) + } + input := asSlice(payload["input"]) + if input == nil { + return nil, llm.ErrorFromHTTPStatus(providerName, 400, "codex-app-server turn payload is missing input array", payload, nil) + } + out := deepCopyMap(payload) + if 
strings.TrimSpace(asString(out["threadId"])) == "" { + out["threadId"] = defaultThreadID + } + out["input"] = input + return out, nil +} + +func toThreadStartParams(turn map[string]any) map[string]any { + thread := map[string]any{} + for _, key := range []string{"model", "cwd", "approvalPolicy", "personality"} { + if v, ok := turn[key]; ok && v != nil { + thread[key] = v + } + } + if sandbox := asString(turn["sandbox"]); sandbox == "read-only" || sandbox == "workspace-write" || sandbox == "danger-full-access" { + thread["sandbox"] = sandbox + } + return thread +} + +func isTerminalTurnStatus(status string) bool { + switch strings.TrimSpace(status) { + case "completed", "failed", "interrupted": + return true + default: + return false + } +} + +func notificationBelongsToTurn(notification map[string]any, threadID, turnID string) bool { + notificationThreadID := extractThreadID(notification) + if notificationThreadID != "" && notificationThreadID != threadID { + return false + } + if turnID == "" { + return true + } + notificationTurnID := extractTurnID(notification) + if notificationTurnID != "" && notificationTurnID != turnID { + return false + } + return true +} + +func extractThreadID(notification map[string]any) string { + params := asMap(notification["params"]) + if params == nil { + return "" + } + if threadID := asString(params["threadId"]); threadID != "" { + return threadID + } + if threadID := asString(params["thread_id"]); threadID != "" { + return threadID + } + return "" +} + +func extractTurnID(notification map[string]any) string { + params := asMap(notification["params"]) + if params == nil { + return "" + } + if turnID := asString(params["turnId"]); turnID != "" { + return turnID + } + if turnID := asString(params["turn_id"]); turnID != "" { + return turnID + } + turn := asMap(params["turn"]) + if turn != nil { + if turnID := asString(turn["id"]); turnID != "" { + return turnID + } + } + return "" +} + +func findCompletedTurn(notifications 
[]map[string]any, turnID string) map[string]any { + for idx := len(notifications) - 1; idx >= 0; idx-- { + notification := notifications[idx] + if asString(notification["method"]) != "turn/completed" { + continue + } + notificationTurnID := extractTurnID(notification) + if turnID != "" && notificationTurnID != "" && notificationTurnID != turnID { + continue + } + turn := asMap(asMap(notification["params"])["turn"]) + if turn == nil { + continue + } + if asString(turn["id"]) == "" { + continue + } + if asSlice(turn["items"]) == nil { + continue + } + return turn + } + return nil +} + +func contextWithRequestTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout <= 0 { + return context.WithCancel(ctx) + } + if deadline, ok := ctx.Deadline(); ok { + if time.Until(deadline) <= timeout { + return context.WithCancel(ctx) + } + } + return context.WithTimeout(ctx, timeout) +} + +func processAlive(cmd *exec.Cmd) bool { + if cmd == nil { + return false + } + if cmd.ProcessState != nil && cmd.ProcessState.Exited() { + return false + } + if cmd.Process == nil { + return false + } + return true +} + +func rpcIDKey(id any) string { + return strings.TrimSpace(fmt.Sprintf("%v", id)) +} diff --git a/internal/llm/providers/codexappserver/transport_helpers_test.go b/internal/llm/providers/codexappserver/transport_helpers_test.go new file mode 100644 index 00000000..330c8bac --- /dev/null +++ b/internal/llm/providers/codexappserver/transport_helpers_test.go @@ -0,0 +1,579 @@ +package codexappserver + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "strings" + "sync" + "testing" + "time" + + "github.com/danshapiro/kilroy/internal/llm" +) + +type recordingWriteCloser struct { + mu sync.Mutex + buf bytes.Buffer + err error + close bool +} + +func (w *recordingWriteCloser) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + if w.err != nil { + return 0, w.err + } + return 
w.buf.Write(p) +} + +func (w *recordingWriteCloser) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + w.close = true + return nil +} + +func (w *recordingWriteCloser) lines() []string { + w.mu.Lock() + defer w.mu.Unlock() + raw := strings.TrimSpace(w.buf.String()) + if raw == "" { + return nil + } + return strings.Split(raw, "\n") +} + +func aliveCmd(t *testing.T) *exec.Cmd { + t.Helper() + proc, err := os.FindProcess(os.Getpid()) + if err != nil { + t.Fatalf("FindProcess: %v", err) + } + return &exec.Cmd{Process: proc} +} + +func TestTransport_ParseTurnStartPayload_ValidatesAndDefaults(t *testing.T) { + if _, err := parseTurnStartPayload(nil); err == nil { + t.Fatalf("expected error for nil payload") + } + + if _, err := parseTurnStartPayload(map[string]any{"threadId": "thread_1"}); err == nil { + t.Fatalf("expected error for missing input array") + } + + in := map[string]any{ + "input": []any{ + map[string]any{"type": "message", "role": "user"}, + }, + } + out, err := parseTurnStartPayload(in) + if err != nil { + t.Fatalf("parseTurnStartPayload: %v", err) + } + if got := asString(out["threadId"]); got != defaultThreadID { + t.Fatalf("threadId default: got %q want %q", got, defaultThreadID) + } + if asSlice(out["input"]) == nil { + t.Fatalf("input array was not preserved: %#v", out["input"]) + } + if _, ok := in["threadId"]; ok { + t.Fatalf("expected input map to remain unmodified; got %#v", in) + } +} + +func TestTransport_ToThreadStartParams_FiltersFields(t *testing.T) { + turn := map[string]any{ + "model": "codex-mini", + "cwd": "/tmp/repo", + "approvalPolicy": "never", + "personality": "strict", + "sandbox": "danger-full-access", + "ignored": true, + } + got := toThreadStartParams(turn) + if got["model"] != "codex-mini" || got["cwd"] != "/tmp/repo" || got["approvalPolicy"] != "never" || got["personality"] != "strict" { + t.Fatalf("unexpected mapped thread params: %#v", got) + } + if got["sandbox"] != "danger-full-access" { + t.Fatalf("sandbox: got %#v", 
got["sandbox"]) + } + if _, ok := got["ignored"]; ok { + t.Fatalf("did not expect ignored key in thread params: %#v", got) + } + + turn["sandbox"] = "unsupported" + got = toThreadStartParams(turn) + if _, ok := got["sandbox"]; ok { + t.Fatalf("unexpected sandbox for unsupported mode: %#v", got["sandbox"]) + } +} + +func TestTransport_TurnStatusAndNotificationMatching(t *testing.T) { + if !isTerminalTurnStatus("completed") || !isTerminalTurnStatus("failed") || !isTerminalTurnStatus("interrupted") { + t.Fatalf("expected terminal statuses to match") + } + if isTerminalTurnStatus("running") { + t.Fatalf("did not expect running to be terminal") + } + + n1 := map[string]any{"params": map[string]any{"threadId": "thread_1", "turnId": "turn_1"}} + if got := extractThreadID(n1); got != "thread_1" { + t.Fatalf("extractThreadID: got %q", got) + } + if got := extractTurnID(n1); got != "turn_1" { + t.Fatalf("extractTurnID: got %q", got) + } + if !notificationBelongsToTurn(n1, "thread_1", "turn_1") { + t.Fatalf("expected notification to belong to matching thread/turn") + } + if notificationBelongsToTurn(n1, "thread_2", "turn_1") { + t.Fatalf("expected thread mismatch to reject notification") + } + if notificationBelongsToTurn(n1, "thread_1", "turn_2") { + t.Fatalf("expected turn mismatch to reject notification") + } + + n2 := map[string]any{"params": map[string]any{"thread_id": "thread_1", "turn": map[string]any{"id": "turn_1"}}} + if got := extractThreadID(n2); got != "thread_1" { + t.Fatalf("extractThreadID snake_case: got %q", got) + } + if got := extractTurnID(n2); got != "turn_1" { + t.Fatalf("extractTurnID nested turn.id: got %q", got) + } + if !notificationBelongsToTurn(n2, "thread_1", "") { + t.Fatalf("expected turn-less matching to pass when thread matches") + } +} + +func TestTransport_FindCompletedTurn_ReturnsLatestMatching(t *testing.T) { + notifications := []map[string]any{ + {"method": "turn/progress", "params": map[string]any{"threadId": "thread_1"}}, + {"method": 
"turn/completed", "params": map[string]any{"turnId": "turn_old", "turn": map[string]any{"id": "turn_old", "items": []any{"a"}}}}, + {"method": "turn/completed", "params": map[string]any{"turnId": "turn_new", "turn": map[string]any{"id": "turn_new"}}}, // missing items + {"method": "turn/completed", "params": map[string]any{"turnId": "turn_new", "turn": map[string]any{"id": "turn_new", "items": []any{"x"}}}}, + } + + got := findCompletedTurn(notifications, "turn_new") + if got == nil { + t.Fatalf("expected completed turn") + } + if asString(got["id"]) != "turn_new" { + t.Fatalf("completed turn id: got %#v", got["id"]) + } + if len(asSlice(got["items"])) != 1 { + t.Fatalf("completed turn items: %#v", got["items"]) + } + + if miss := findCompletedTurn(notifications, "turn_missing"); miss != nil { + t.Fatalf("expected nil for missing turn id; got %#v", miss) + } +} + +func TestTransport_ProcessAliveAndRPCIDKey(t *testing.T) { + if processAlive(nil) { + t.Fatalf("nil command should not be alive") + } + + cmd := aliveCmd(t) + if !processAlive(cmd) { + t.Fatalf("expected command with process to be alive") + } + + finished := exec.Command(os.Args[0], "-test.run=TestTransport_HelperProcess") + finished.Env = append(os.Environ(), + "GO_WANT_TRANSPORT_HELPER=1", + "GO_TRANSPORT_HELPER_MODE=exit", + ) + if err := finished.Run(); err != nil { + t.Fatalf("run finished command: %v", err) + } + if processAlive(finished) { + t.Fatalf("expected exited command to be not alive") + } + + if got := rpcIDKey(" x "); got != "x" { + t.Fatalf("rpcIDKey string trim: got %q", got) + } + if got := rpcIDKey(12); got != "12" { + t.Fatalf("rpcIDKey int conversion: got %q", got) + } +} + +func TestTransport_ResolveAndRejectPendingRequests(t *testing.T) { + tp := &stdioTransport{pending: map[string]*pendingRequest{}} + + respCh := make(chan pendingResult, 1) + tp.pending["1"] = &pendingRequest{method: "turn/start", respCh: respCh} + tp.resolvePendingRequest(1, pendingResult{result: 
map[string]any{"ok": true}}) + + select { + case got := <-respCh: + if asMap(got.result)["ok"] != true { + t.Fatalf("unexpected pending result: %#v", got.result) + } + default: + t.Fatalf("expected resolved pending result") + } + if _, ok := tp.pending["1"]; ok { + t.Fatalf("expected pending request to be removed after resolution") + } + + errCh := make(chan pendingResult, 1) + tp.pending["2"] = &pendingRequest{method: "turn/start", respCh: errCh} + wantErr := errors.New("transport failed") + tp.rejectAllPending(wantErr) + + select { + case got := <-errCh: + if !errors.Is(got.err, wantErr) { + t.Fatalf("rejectAllPending err: got %v want %v", got.err, wantErr) + } + default: + t.Fatalf("expected rejected pending error") + } + if len(tp.pending) != 0 { + t.Fatalf("expected pending map to be cleared; got %#v", tp.pending) + } +} + +func TestTransport_EmitNotificationAndSubscribeLifecycle(t *testing.T) { + tp := &stdioTransport{listeners: map[int]func(map[string]any){}} + + got := make(chan string, 2) + unsubscribe := tp.subscribe(func(notification map[string]any) { + got <- asString(notification["method"]) + }) + _ = tp.subscribe(func(map[string]any) { + panic("listener panic should be recovered") + }) + + tp.emitNotification(map[string]any{"method": "turn/progress"}) + select { + case method := <-got: + if method != "turn/progress" { + t.Fatalf("unexpected method: %q", method) + } + default: + t.Fatalf("expected listener notification") + } + + unsubscribe() + tp.emitNotification(map[string]any{"method": "turn/completed"}) + select { + case method := <-got: + t.Fatalf("did not expect method after unsubscribe: %q", method) + default: + } +} + +func TestTransport_ToRPCError_MapsJSONRPCCodes(t *testing.T) { + tp := &stdioTransport{} + badReq := tp.toRPCError("turn/start", map[string]any{ + "code": -32601, + "message": "method not found", + "data": map[string]any{"hint": "check schema"}, + }) + var llmErr llm.Error + if !errors.As(badReq, &llmErr) { + t.Fatalf("expected 
llm.Error, got %T", badReq) + } + if llmErr.StatusCode() != 400 { + t.Fatalf("status code: got %d want 400", llmErr.StatusCode()) + } + + serverErr := tp.toRPCError("turn/start", map[string]any{ + "code": -32000, + "message": "internal", + }) + if !errors.As(serverErr, &llmErr) { + t.Fatalf("expected llm.Error, got %T", serverErr) + } + if llmErr.StatusCode() != 500 { + t.Fatalf("status code: got %d want 500", llmErr.StatusCode()) + } +} + +func TestTransport_WriteJSONLine_ValidationAndSuccessPaths(t *testing.T) { + tp := &stdioTransport{} + if err := tp.writeJSONLine(map[string]any{"bad": func() {}}); err == nil { + t.Fatalf("expected marshal error") + } + + if err := tp.writeJSONLine(map[string]any{"method": "x"}); err == nil { + t.Fatalf("expected non-writable stdin error without process") + } + + writer := &recordingWriteCloser{} + tp = &stdioTransport{cmd: aliveCmd(t), stdin: writer} + if err := tp.writeJSONLine(map[string]any{"method": "turn/start"}); err != nil { + t.Fatalf("writeJSONLine success: %v", err) + } + lines := writer.lines() + if len(lines) != 1 || !strings.Contains(lines[0], `"method":"turn/start"`) { + t.Fatalf("unexpected written line: %#v", lines) + } + + failingWriter := &recordingWriteCloser{err: errors.New("boom")} + tp = &stdioTransport{cmd: aliveCmd(t), stdin: failingWriter} + if err := tp.writeJSONLine(map[string]any{"method": "turn/start"}); err == nil { + t.Fatalf("expected write failure") + } +} + +func TestTransport_SendRequest_CoversClosedWriteSuccessAndTimeout(t *testing.T) { + tp := &stdioTransport{ + closed: true, + pending: map[string]*pendingRequest{}, + } + if _, err := tp.sendRequest(context.Background(), "turn/start", nil, 50*time.Millisecond); err == nil { + t.Fatalf("expected closed transport error") + } + + tp = &stdioTransport{ + pending: map[string]*pendingRequest{}, + } + if _, err := tp.sendRequest(context.Background(), "turn/start", nil, 50*time.Millisecond); err == nil { + t.Fatalf("expected write error when 
process is unavailable") + } + if len(tp.pending) != 0 { + t.Fatalf("expected pending to be cleaned after write failure; got %#v", tp.pending) + } + + writer := &recordingWriteCloser{} + tp = &stdioTransport{ + pending: map[string]*pendingRequest{}, + cmd: aliveCmd(t), + stdin: writer, + } + resultCh := make(chan struct { + value map[string]any + err error + }, 1) + go func() { + value, err := tp.sendRequest(context.Background(), "turn/start", map[string]any{"input": []any{}}, 250*time.Millisecond) + resultCh <- struct { + value map[string]any + err error + }{value: value, err: err} + }() + + deadline := time.Now().Add(150 * time.Millisecond) + for { + tp.mu.Lock() + _, ok := tp.pending["1"] + tp.mu.Unlock() + if ok { + break + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for pending request registration") + } + time.Sleep(2 * time.Millisecond) + } + tp.resolvePendingRequest(1, pendingResult{result: map[string]any{"ok": true}}) + got := <-resultCh + if got.err != nil { + t.Fatalf("sendRequest success path returned error: %v", got.err) + } + if asMap(got.value)["ok"] != true { + t.Fatalf("sendRequest success payload: %#v", got.value) + } + + timeoutTransport := &stdioTransport{ + pending: map[string]*pendingRequest{}, + cmd: aliveCmd(t), + stdin: &recordingWriteCloser{}, + } + _, err := timeoutTransport.sendRequest(context.Background(), "turn/start", nil, 25*time.Millisecond) + if err == nil { + t.Fatalf("expected timeout error") + } + tpErr := &llm.RequestTimeoutError{} + if !errors.As(err, &tpErr) { + t.Fatalf("expected RequestTimeoutError, got %T (%v)", err, err) + } + timeoutTransport.mu.Lock() + pendingLen := len(timeoutTransport.pending) + timeoutTransport.mu.Unlock() + if pendingLen != 0 { + t.Fatalf("expected timeout path to clear pending map, got len=%d", pendingLen) + } +} + +func TestTransport_HandleIncomingMessage_ResolvesPendingAndForwardsNotifications(t *testing.T) { + tp := &stdioTransport{ + pending: map[string]*pendingRequest{}, 
+ listeners: map[int]func(map[string]any){}, + } + + okCh := make(chan pendingResult, 1) + tp.pending["1"] = &pendingRequest{method: "turn/start", respCh: okCh} + tp.handleIncomingMessage(map[string]any{"id": 1, "result": map[string]any{"threadId": "thread_1"}}) + select { + case got := <-okCh: + if asMap(got.result)["threadId"] != "thread_1" { + t.Fatalf("unexpected result payload: %#v", got.result) + } + default: + t.Fatalf("expected pending result resolution") + } + + errCh := make(chan pendingResult, 1) + tp.pending["2"] = &pendingRequest{method: "turn/start", respCh: errCh} + tp.handleIncomingMessage(map[string]any{ + "id": 2, + "method": "turn/start", + "error": map[string]any{"code": -32601, "message": "method not found"}, + }) + select { + case got := <-errCh: + if got.err == nil { + t.Fatalf("expected rpc error result") + } + default: + t.Fatalf("expected pending error resolution") + } + + notifications := make(chan string, 1) + unsubscribe := tp.subscribe(func(notification map[string]any) { + notifications <- asString(notification["method"]) + }) + tp.handleIncomingMessage(map[string]any{ + "method": "turn/progress", + "params": map[string]any{"threadId": "thread_1"}, + }) + select { + case method := <-notifications: + if method != "turn/progress" { + t.Fatalf("unexpected notification method: %q", method) + } + default: + t.Fatalf("expected notification fan-out") + } + unsubscribe() +} + +func TestTransport_HandleServerRequest_SupportsKnownMethods(t *testing.T) { + writer := &recordingWriteCloser{} + tp := &stdioTransport{ + cmd: aliveCmd(t), + stdin: writer, + } + + cases := []struct { + id int + method string + params any + wantField string + wantValue string + wantErr int + }{ + {id: 1, method: "item/tool/call", wantField: "success", wantValue: "false"}, + {id: 2, method: "item/tool/requestUserInput", params: map[string]any{"questions": []any{map[string]any{"id": "q1"}}}, wantField: "answers", wantValue: "map"}, + {id: 3, method: 
"item/commandExecution/requestApproval", wantField: "decision", wantValue: "decline"}, + {id: 4, method: "item/fileChange/requestApproval", wantField: "decision", wantValue: "decline"}, + {id: 5, method: "applyPatchApproval", wantField: "decision", wantValue: "denied"}, + {id: 6, method: "execCommandApproval", wantField: "decision", wantValue: "denied"}, + {id: 7, method: "account/chatgptAuthTokens/refresh", wantErr: -32001}, + {id: 8, method: "unknown/method", wantErr: -32601}, + } + + for _, tc := range cases { + tp.handleServerRequest(tc.id, tc.method, tc.params) + } + + lines := writer.lines() + if len(lines) != len(cases) { + t.Fatalf("written lines: got %d want %d (%#v)", len(lines), len(cases), lines) + } + for i, tc := range cases { + var msg map[string]any + if err := json.Unmarshal([]byte(lines[i]), &msg); err != nil { + t.Fatalf("unmarshal response line %d: %v", i, err) + } + if got := asInt(msg["id"], 0); got != tc.id { + t.Fatalf("line %d id: got %d want %d", i, got, tc.id) + } + if tc.wantErr != 0 { + errObj := asMap(msg["error"]) + if got := asInt(errObj["code"], 0); got != tc.wantErr { + t.Fatalf("line %d error code: got %d want %d", i, got, tc.wantErr) + } + continue + } + result := asMap(msg["result"]) + switch tc.wantValue { + case "false": + if got := fmt.Sprint(result[tc.wantField]); got != "false" { + t.Fatalf("line %d result[%q]: got %q want false", i, tc.wantField, got) + } + case "map": + if asMap(result[tc.wantField]) == nil { + t.Fatalf("line %d expected map field %q in result: %#v", i, tc.wantField, result) + } + default: + if got := asString(result[tc.wantField]); got != tc.wantValue { + t.Fatalf("line %d result[%q]: got %q want %q", i, tc.wantField, got, tc.wantValue) + } + } + } +} + +func TestTransport_ShutdownProcess_WithRunningProcess(t *testing.T) { + cmd := exec.Command(os.Args[0], "-test.run=TestTransport_HelperProcess") + cmd.Env = append(os.Environ(), + "GO_WANT_TRANSPORT_HELPER=1", + "GO_TRANSPORT_HELPER_MODE=stdin", + ) + 
stdin, err := cmd.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe: %v", err) + } + if err := cmd.Start(); err != nil { + t.Fatalf("Start: %v", err) + } + done := make(chan struct{}) + go func() { + _ = cmd.Wait() + close(done) + }() + + tp := &stdioTransport{ + cmd: cmd, + stdin: stdin, + procDone: done, + opts: TransportOptions{ + ShutdownTimeout: 250 * time.Millisecond, + }, + } + if err := tp.shutdownProcess(); err != nil { + t.Fatalf("shutdownProcess: %v", err) + } + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatalf("helper process did not exit after shutdown") + } +} + +func TestTransport_HelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_TRANSPORT_HELPER") != "1" { + return + } + switch os.Getenv("GO_TRANSPORT_HELPER_MODE") { + case "stdin": + _, _ = io.Copy(io.Discard, os.Stdin) + case "exit": + return + default: + return + } +} diff --git a/internal/llm/providers/codexappserver/transport_timeout_test.go b/internal/llm/providers/codexappserver/transport_timeout_test.go new file mode 100644 index 00000000..e51afdec --- /dev/null +++ b/internal/llm/providers/codexappserver/transport_timeout_test.go @@ -0,0 +1,82 @@ +package codexappserver + +import ( + "context" + "testing" + "time" +) + +func TestNewTransport_DefaultRequestTimeoutIsDisabled(t *testing.T) { + transport := NewTransport(TransportOptions{}) + if transport.opts.RequestTimeout != 0 { + t.Fatalf("request timeout: got %v want 0 (disabled)", transport.opts.RequestTimeout) + } +} + +func TestContextWithRequestTimeout_DisabledDoesNotInjectDeadline(t *testing.T) { + ctx := context.Background() + + derivedCtx, cancel := contextWithRequestTimeout(ctx, 0) + defer cancel() + + if _, ok := derivedCtx.Deadline(); ok { + t.Fatalf("expected no deadline when timeout is disabled") + } +} + +func TestContextWithRequestTimeout_PositiveTimeoutInjectsDeadline(t *testing.T) { + ctx := context.Background() + + derivedCtx, cancel := contextWithRequestTimeout(ctx, 
500*time.Millisecond) + defer cancel() + + deadline, ok := derivedCtx.Deadline() + if !ok { + t.Fatalf("expected derived context deadline") + } + remaining := time.Until(deadline) + if remaining <= 0 || remaining > time.Second { + t.Fatalf("derived deadline remaining=%v, expected within (0, 1s]", remaining) + } +} + +func TestContextWithRequestTimeout_ParentDeadlineTakesPrecedence(t *testing.T) { + parentCtx, cancelParent := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancelParent() + + derivedCtx, cancelDerived := contextWithRequestTimeout(parentCtx, 5*time.Second) + defer cancelDerived() + + deadline, ok := derivedCtx.Deadline() + if !ok { + t.Fatalf("expected derived context deadline from parent") + } + remaining := time.Until(deadline) + if remaining <= 0 || remaining > time.Second { + t.Fatalf("derived deadline remaining=%v, expected parent-sized deadline", remaining) + } +} + +func TestInterruptTimeout_DefaultsWhenRequestTimeoutDisabled(t *testing.T) { + transport := NewTransport(TransportOptions{}) + if got := transport.interruptTimeout(); got != defaultInterruptTimeout { + t.Fatalf("interrupt timeout: got %v want %v", got, defaultInterruptTimeout) + } +} + +func TestInterruptTimeout_UsesRequestTimeoutWhenSet(t *testing.T) { + transport := NewTransport(TransportOptions{RequestTimeout: 3 * time.Second}) + if got := transport.interruptTimeout(); got != 3*time.Second { + t.Fatalf("interrupt timeout: got %v want %v", got, 3*time.Second) + } +} + +func TestInterruptTimeout_ClampsToShutdownTimeout(t *testing.T) { + transport := NewTransport(TransportOptions{ + RequestTimeout: 7 * time.Second, + ShutdownTimeout: 1 * time.Second, + }) + if got := transport.interruptTimeout(); got != 1*time.Second { + t.Fatalf("interrupt timeout: got %v want %v", got, 1*time.Second) + } +} diff --git a/internal/llm/providers/codexappserver/util.go b/internal/llm/providers/codexappserver/util.go new file mode 100644 index 00000000..9ff0859a --- /dev/null 
+++ b/internal/llm/providers/codexappserver/util.go @@ -0,0 +1,124 @@ +package codexappserver + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" +) + +func asMap(v any) map[string]any { + m, _ := v.(map[string]any) + return m +} + +func asSlice(v any) []any { + a, _ := v.([]any) + return a +} + +func asString(v any) string { + switch x := v.(type) { + case string: + return x + case json.Number: + return x.String() + default: + return "" + } +} + +func asInt(v any, def int) int { + switch x := v.(type) { + case int: + return x + case int8: + return int(x) + case int16: + return int(x) + case int32: + return int(x) + case int64: + return int(x) + case float64: + return int(x) + case float32: + return int(x) + case json.Number: + i, err := x.Int64() + if err == nil { + return int(i) + } + case string: + n := strings.TrimSpace(x) + if n == "" { + return def + } + var i int + if _, err := fmt.Sscanf(n, "%d", &i); err == nil { + return i + } + } + return def +} + +func asBool(v any, def bool) bool { + b, ok := v.(bool) + if ok { + return b + } + return def +} + +func deepCopyMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + b, err := json.Marshal(in) + if err != nil { + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out + } + return decodeJSONToMap(b) +} + +func decodeJSONToMap(b []byte) map[string]any { + dec := json.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + var out map[string]any + if err := dec.Decode(&out); err != nil { + return map[string]any{} + } + if out == nil { + return map[string]any{} + } + return out +} + +func normalizeCode(value string) string { + value = strings.TrimSpace(strings.ToUpper(value)) + if value == "" { + return "" + } + var b strings.Builder + for _, r := range value { + if (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + b.WriteRune(r) + continue + } + b.WriteByte('_') + } + return b.String() +} + +func firstNonEmpty(values ...string) string { + for _, v 
:= range values { + if s := strings.TrimSpace(v); s != "" { + return s + } + } + return "" +} diff --git a/internal/llm/providers/codexappserver/util_test.go b/internal/llm/providers/codexappserver/util_test.go new file mode 100644 index 00000000..adfabad2 --- /dev/null +++ b/internal/llm/providers/codexappserver/util_test.go @@ -0,0 +1,121 @@ +package codexappserver + +import ( + "encoding/json" + "testing" +) + +func TestUtil_AsHelpersAndNormalization(t *testing.T) { + m := map[string]any{"x": 1} + if got := asMap(m); got["x"] != 1 { + t.Fatalf("asMap mismatch: %#v", got) + } + if got := asMap("not-a-map"); got != nil { + t.Fatalf("expected nil for invalid map cast, got %#v", got) + } + + s := []any{"a", 1} + if got := asSlice(s); len(got) != 2 { + t.Fatalf("asSlice mismatch: %#v", got) + } + if got := asSlice("not-a-slice"); got != nil { + t.Fatalf("expected nil for invalid slice cast, got %#v", got) + } + + if got := asString("abc"); got != "abc" { + t.Fatalf("asString string mismatch: %q", got) + } + if got := asString(json.Number("123")); got != "123" { + t.Fatalf("asString json.Number mismatch: %q", got) + } + if got := asString(99); got != "" { + t.Fatalf("expected empty string for unsupported type, got %q", got) + } + + if got := normalizeCode(" invalid-request "); got != "INVALID_REQUEST" { + t.Fatalf("normalizeCode mismatch: %q", got) + } + if got := firstNonEmpty("", " ", "x", "y"); got != "x" { + t.Fatalf("firstNonEmpty mismatch: %q", got) + } +} + +func TestUtil_AsIntAndBool(t *testing.T) { + if got := asInt(int8(2), 0); got != 2 { + t.Fatalf("asInt int8 mismatch: %d", got) + } + if got := asInt(int16(3), 0); got != 3 { + t.Fatalf("asInt int16 mismatch: %d", got) + } + if got := asInt(int32(4), 0); got != 4 { + t.Fatalf("asInt int32 mismatch: %d", got) + } + if got := asInt(int64(5), 0); got != 5 { + t.Fatalf("asInt int64 mismatch: %d", got) + } + if got := asInt(float32(6.9), 0); got != 6 { + t.Fatalf("asInt float32 mismatch: %d", got) + } + if got 
:= asInt(float64(7.9), 0); got != 7 { + t.Fatalf("asInt float64 mismatch: %d", got) + } + if got := asInt(json.Number("8"), 0); got != 8 { + t.Fatalf("asInt json.Number mismatch: %d", got) + } + if got := asInt(" 9 ", 0); got != 9 { + t.Fatalf("asInt string numeric mismatch: %d", got) + } + if got := asInt(" ", 42); got != 42 { + t.Fatalf("asInt empty string should use default: %d", got) + } + if got := asInt("not-numeric", 42); got != 42 { + t.Fatalf("asInt invalid string should use default: %d", got) + } + + if got := asBool(true, false); !got { + t.Fatalf("asBool true mismatch") + } + if got := asBool("bad", true); !got { + t.Fatalf("asBool fallback mismatch") + } +} + +func TestUtil_DeepCopyAndDecodeJSONToMap(t *testing.T) { + orig := map[string]any{"nested": map[string]any{"x": 1}} + cp := deepCopyMap(orig) + if cp == nil { + t.Fatalf("deepCopyMap returned nil") + } + nested := asMap(cp["nested"]) + nested["x"] = 9 + if asMap(orig["nested"])["x"] != 1 { + t.Fatalf("deepCopyMap should not alias nested map") + } + + cp = deepCopyMap(nil) + if cp != nil { + t.Fatalf("deepCopyMap(nil) should return nil, got %#v", cp) + } + + withUnmarshalable := map[string]any{ + "f": func() {}, + "x": "ok", + } + cp = deepCopyMap(withUnmarshalable) + if cp["x"] != "ok" { + t.Fatalf("deepCopyMap fallback copy mismatch: %#v", cp) + } + if _, ok := cp["f"]; !ok { + t.Fatalf("deepCopyMap fallback should preserve unmarshalable key") + } + + if got := decodeJSONToMap([]byte(`{"x":1}`)); asString(got["x"]) != "1" { + t.Fatalf("decodeJSONToMap valid json mismatch: %#v", got) + } + if got := decodeJSONToMap([]byte(`not-json`)); len(got) != 0 { + t.Fatalf("decodeJSONToMap invalid json should return empty map, got %#v", got) + } + if got := decodeJSONToMap([]byte(`null`)); len(got) != 0 { + t.Fatalf("decodeJSONToMap null should return empty map, got %#v", got) + } +} diff --git a/internal/llm/providers/google/adapter.go b/internal/llm/providers/google/adapter.go index e8c9334d..e377ce4f 
100644 --- a/internal/llm/providers/google/adapter.go +++ b/internal/llm/providers/google/adapter.go @@ -203,7 +203,9 @@ func (a *Adapter) Complete(ctx context.Context, req llm.Request) (llm.Response, return llm.Response{}, llm.ErrorFromHTTPStatus(a.Name(), resp.StatusCode, msg, raw, ra) } - return fromGeminiResponse(a.Name(), raw, req.Model), nil + out := fromGeminiResponse(a.Name(), raw, req.Model) + out.RateLimit = llm.ParseRateLimitInfo(resp.Header, time.Now()) + return out, nil } func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, error) { @@ -338,6 +340,7 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro cancel() return nil, llm.ErrorFromHTTPStatus(a.Name(), resp.StatusCode, msg, raw, ra) } + rateLimit := llm.ParseRateLimitInfo(resp.Header, time.Now()) s := llm.NewChanStream(cancel) s.Send(llm.StreamEvent{Type: llm.StreamEventStreamStart}) @@ -435,12 +438,13 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro flushTextPart() msg := llm.Message{Role: llm.RoleAssistant, Content: contentParts} r := llm.Response{ - Provider: a.Name(), - Model: req.Model, - Message: msg, - Finish: finish, - Usage: usage, - Raw: raw, + Provider: a.Name(), + Model: req.Model, + Message: msg, + Finish: finish, + Usage: usage, + RateLimit: rateLimit, + Raw: raw, } if r.Finish.Reason == "" { if len(r.ToolCalls()) > 0 { diff --git a/internal/llm/providers/google/adapter_test.go b/internal/llm/providers/google/adapter_test.go index e800fecf..3490a109 100644 --- a/internal/llm/providers/google/adapter_test.go +++ b/internal/llm/providers/google/adapter_test.go @@ -17,6 +17,28 @@ import ( "github.com/danshapiro/kilroy/internal/llm" ) +func assertRateLimitInfo(t *testing.T, rl *llm.RateLimitInfo) { + t.Helper() + if rl == nil { + t.Fatalf("expected rate limit info, got nil") + } + if rl.RequestsRemaining == nil || *rl.RequestsRemaining != 9 { + t.Fatalf("requests_remaining: %#v", rl.RequestsRemaining) 
+ } + if rl.RequestsLimit == nil || *rl.RequestsLimit != 10 { + t.Fatalf("requests_limit: %#v", rl.RequestsLimit) + } + if rl.TokensRemaining == nil || *rl.TokensRemaining != 90 { + t.Fatalf("tokens_remaining: %#v", rl.TokensRemaining) + } + if rl.TokensLimit == nil || *rl.TokensLimit != 100 { + t.Fatalf("tokens_limit: %#v", rl.TokensLimit) + } + if rl.ResetAt != "2025-01-01T00:00:10Z" { + t.Fatalf("reset_at: %q", rl.ResetAt) + } +} + func TestAdapter_Complete_MapsToGeminiGenerateContent(t *testing.T) { var gotBody map[string]any gotKey := "" @@ -34,6 +56,11 @@ func TestAdapter_Complete_MapsToGeminiGenerateContent(t *testing.T) { _ = json.Unmarshal(b, &gotBody) w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") _, _ = w.Write([]byte(`{ "candidates": [{"content": {"parts": [{"text":"Hello"}]}, "finishReason":"STOP"}], "usageMetadata": {"promptTokenCount": 1, "candidatesTokenCount": 2, "totalTokenCount": 3} @@ -64,6 +91,7 @@ func TestAdapter_Complete_MapsToGeminiGenerateContent(t *testing.T) { if strings.TrimSpace(resp.Text()) != "Hello" { t.Fatalf("resp text: %q", resp.Text()) } + assertRateLimitInfo(t, resp.RateLimit) if gotKey != "k" { t.Fatalf("key param: %q", gotKey) } @@ -529,6 +557,11 @@ func TestAdapter_Stream_YieldsTextDeltasAndFinish(t *testing.T) { _ = json.Unmarshal(b, &gotBody) w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") f, _ := w.(http.Flusher) write := func(data 
string) { _, _ = io.WriteString(w, "data: "+data+"\n\n") @@ -570,6 +603,7 @@ func TestAdapter_Stream_YieldsTextDeltasAndFinish(t *testing.T) { if finish == nil || strings.TrimSpace(finish.Text()) != "Hello" { t.Fatalf("finish response: %+v", finish) } + assertRateLimitInfo(t, finish.RateLimit) if gotKey != "k" { t.Fatalf("key param: %q", gotKey) } diff --git a/internal/llm/providers/openai/adapter.go b/internal/llm/providers/openai/adapter.go index e5426298..dc3cd914 100644 --- a/internal/llm/providers/openai/adapter.go +++ b/internal/llm/providers/openai/adapter.go @@ -157,7 +157,9 @@ func (a *Adapter) Complete(ctx context.Context, req llm.Request) (llm.Response, return llm.Response{}, llm.ErrorFromHTTPStatus(a.Name(), resp.StatusCode, msg, raw, ra) } - return fromResponses(a.Name(), raw, req.Model), nil + out := fromResponses(a.Name(), raw, req.Model) + out.RateLimit = llm.ParseRateLimitInfo(resp.Header, time.Now()) + return out, nil } func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, error) { @@ -255,6 +257,7 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro s := llm.NewChanStream(cancel) // STREAM_START s.Send(llm.StreamEvent{Type: llm.StreamEventStreamStart}) + rateLimit := llm.ParseRateLimitInfo(resp.Header, time.Now()) go func() { defer func() { @@ -399,6 +402,7 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro rawResp = payload } r := fromResponses(a.Name(), rawResp, req.Model) + r.RateLimit = rateLimit // Ensure text segment is closed. 
if textStarted { s.Send(llm.StreamEvent{Type: llm.StreamEventTextEnd, TextID: textID}) diff --git a/internal/llm/providers/openai/adapter_test.go b/internal/llm/providers/openai/adapter_test.go index 2af1ef16..bb3daf3b 100644 --- a/internal/llm/providers/openai/adapter_test.go +++ b/internal/llm/providers/openai/adapter_test.go @@ -16,6 +16,28 @@ import ( "github.com/danshapiro/kilroy/internal/llm" ) +func assertRateLimitInfo(t *testing.T, rl *llm.RateLimitInfo) { + t.Helper() + if rl == nil { + t.Fatalf("expected rate limit info, got nil") + } + if rl.RequestsRemaining == nil || *rl.RequestsRemaining != 9 { + t.Fatalf("requests_remaining: %#v", rl.RequestsRemaining) + } + if rl.RequestsLimit == nil || *rl.RequestsLimit != 10 { + t.Fatalf("requests_limit: %#v", rl.RequestsLimit) + } + if rl.TokensRemaining == nil || *rl.TokensRemaining != 90 { + t.Fatalf("tokens_remaining: %#v", rl.TokensRemaining) + } + if rl.TokensLimit == nil || *rl.TokensLimit != 100 { + t.Fatalf("tokens_limit: %#v", rl.TokensLimit) + } + if rl.ResetAt != "2025-01-01T00:00:10Z" { + t.Fatalf("reset_at: %q", rl.ResetAt) + } +} + func TestAdapter_Complete_MapsToResponsesAPI(t *testing.T) { var gotBody map[string]any @@ -29,6 +51,11 @@ func TestAdapter_Complete_MapsToResponsesAPI(t *testing.T) { _ = json.Unmarshal(b, &gotBody) w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") _, _ = w.Write([]byte(`{ "id": "resp_1", "model": "gpt-5.2", @@ -67,6 +94,7 @@ func TestAdapter_Complete_MapsToResponsesAPI(t *testing.T) { if strings.TrimSpace(resp.Text()) != "Hello" { t.Fatalf("resp text: %q", resp.Text()) } + assertRateLimitInfo(t, resp.RateLimit) // Assert request mapping. 
if gotBody == nil { @@ -333,6 +361,11 @@ func TestAdapter_Stream_YieldsTextDeltasAndFinish(t *testing.T) { _ = json.Unmarshal(b, &gotBody) w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") f, _ := w.(http.Flusher) write := func(event string, data string) { @@ -377,6 +410,7 @@ func TestAdapter_Stream_YieldsTextDeltasAndFinish(t *testing.T) { if finish == nil || strings.TrimSpace(finish.Text()) != "Hello" { t.Fatalf("finish response: %+v", finish) } + assertRateLimitInfo(t, finish.RateLimit) if gotBody == nil { t.Fatalf("server did not capture request body") diff --git a/internal/llm/providers/openaicompat/adapter.go b/internal/llm/providers/openaicompat/adapter.go index 6f6c0d61..db445085 100644 --- a/internal/llm/providers/openaicompat/adapter.go +++ b/internal/llm/providers/openaicompat/adapter.go @@ -118,6 +118,7 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro _, perr := parseChatCompletionsResponse(a.cfg.Provider, req.Model, resp) return nil, perr } + rateLimit := llm.ParseRateLimitInfo(resp.Header, time.Now()) s := llm.NewChanStream(cancelAll) go func() { @@ -127,9 +128,10 @@ func (a *Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, erro s.Send(llm.StreamEvent{Type: llm.StreamEventStreamStart}) state := &chatStreamState{ - Provider: a.cfg.Provider, - Model: req.Model, - TextID: "assistant_text", + Provider: a.cfg.Provider, + Model: req.Model, + TextID: "assistant_text", + RateLimit: rateLimit, } err := llm.ParseSSE(sctx, resp.Body, func(ev llm.SSEEvent) error { @@ -235,7 +237,12 @@ func parseChatCompletionsResponse(provider, model string, resp *http.Response) ( if err := dec.Decode(&raw); err != nil { return 
llm.Response{}, llm.WrapContextError(provider, err) } - return fromChatCompletions(provider, model, raw) + out, err := fromChatCompletions(provider, model, raw) + if err != nil { + return llm.Response{}, err + } + out.RateLimit = llm.ParseRateLimitInfo(resp.Header, time.Now()) + return out, nil } func toChatCompletionsMessages(msgs []llm.Message) []map[string]any { @@ -441,9 +448,10 @@ func normalizeFinishReason(in string) string { } type chatStreamState struct { - Provider string - Model string - TextID string + Provider string + Model string + TextID string + RateLimit *llm.RateLimitInfo Text strings.Builder TextOpen bool @@ -490,11 +498,12 @@ func (st *chatStreamState) FinalResponse() llm.Response { finish = llm.FinishReason{Reason: "stop", Raw: "stop"} } return llm.Response{ - Provider: st.Provider, - Model: st.Model, - Message: msg, - Finish: finish, - Usage: st.Usage, + Provider: st.Provider, + Model: st.Model, + Message: msg, + Finish: finish, + Usage: st.Usage, + RateLimit: st.RateLimit, } } diff --git a/internal/llm/providers/openaicompat/adapter_test.go b/internal/llm/providers/openaicompat/adapter_test.go index 7c75997b..7df3f6be 100644 --- a/internal/llm/providers/openaicompat/adapter_test.go +++ b/internal/llm/providers/openaicompat/adapter_test.go @@ -12,11 +12,38 @@ import ( "github.com/danshapiro/kilroy/internal/llm" ) +func assertRateLimitInfo(t *testing.T, rl *llm.RateLimitInfo) { + t.Helper() + if rl == nil { + t.Fatalf("expected rate limit info, got nil") + } + if rl.RequestsRemaining == nil || *rl.RequestsRemaining != 9 { + t.Fatalf("requests_remaining: %#v", rl.RequestsRemaining) + } + if rl.RequestsLimit == nil || *rl.RequestsLimit != 10 { + t.Fatalf("requests_limit: %#v", rl.RequestsLimit) + } + if rl.TokensRemaining == nil || *rl.TokensRemaining != 90 { + t.Fatalf("tokens_remaining: %#v", rl.TokensRemaining) + } + if rl.TokensLimit == nil || *rl.TokensLimit != 100 { + t.Fatalf("tokens_limit: %#v", rl.TokensLimit) + } + if rl.ResetAt != 
"2025-01-01T00:00:10Z" { + t.Fatalf("reset_at: %q", rl.ResetAt) + } +} + func TestAdapter_Complete_ChatCompletionsMapsToolCalls(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/chat/completions" { t.Fatalf("path: %s", r.URL.Path) } + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") _, _ = w.Write([]byte(`{"id":"c1","model":"m","choices":[{"finish_reason":"tool_calls","message":{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"file_path\":\"README.md\"}"}}]}}],"usage":{"prompt_tokens":10,"completion_tokens":3,"total_tokens":13}}`)) })) defer srv.Close() @@ -39,11 +66,17 @@ func TestAdapter_Complete_ChatCompletionsMapsToolCalls(t *testing.T) { if len(resp.ToolCalls()) != 1 { t.Fatalf("tool call mapping failed") } + assertRateLimitInfo(t, resp.RateLimit) } func TestAdapter_Stream_EmitsFinishEvent(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("x-ratelimit-remaining-requests", "9") + w.Header().Set("x-ratelimit-limit-requests", "10") + w.Header().Set("x-ratelimit-remaining-tokens", "90") + w.Header().Set("x-ratelimit-limit-tokens", "100") + w.Header().Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") _, _ = w.Write([]byte("data: {\"id\":\"c2\",\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":null}]}\n\n")) _, _ = w.Write([]byte("data: {\"id\":\"c2\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1,\"total_tokens\":2}}\n\n")) _, _ = w.Write([]byte("data: 
[DONE]\n\n")) @@ -61,16 +94,17 @@ func TestAdapter_Stream_EmitsFinishEvent(t *testing.T) { } defer stream.Close() - sawFinish := false + var finish *llm.Response for ev := range stream.Events() { if ev.Type == llm.StreamEventFinish { - sawFinish = true + finish = ev.Response break } } - if !sawFinish { + if finish == nil { t.Fatalf("expected finish event") } + assertRateLimitInfo(t, finish.RateLimit) } func TestAdapter_Stream_MapsToolCallDeltasToEventsAndFinalResponse(t *testing.T) { diff --git a/internal/llm/rate_limit.go b/internal/llm/rate_limit.go new file mode 100644 index 00000000..247a8ba3 --- /dev/null +++ b/internal/llm/rate_limit.go @@ -0,0 +1,164 @@ +package llm + +import ( + "net/http" + "regexp" + "strconv" + "strings" + "time" +) + +var firstIntRe = regexp.MustCompile(`[-+]?\d+`) + +// ParseRateLimitInfo extracts informational rate limit metadata from response headers. +// The result is best-effort and intended for observability, not proactive throttling. +func ParseRateLimitInfo(headers http.Header, now time.Time) *RateLimitInfo { + if headers == nil { + return nil + } + + reqRemaining := parseHeaderInt(headers, + "x-ratelimit-remaining-requests", + "x-ratelimit-remaining-request", + ) + reqLimit := parseHeaderInt(headers, + "x-ratelimit-limit-requests", + "x-ratelimit-limit-request", + ) + tokRemaining := parseHeaderInt(headers, + "x-ratelimit-remaining-tokens", + "x-ratelimit-remaining-token", + ) + tokLimit := parseHeaderInt(headers, + "x-ratelimit-limit-tokens", + "x-ratelimit-limit-token", + ) + + // Fallback headers that do not distinguish requests vs tokens. 
+ if reqRemaining == nil && tokRemaining == nil { + reqRemaining = parseHeaderInt(headers, + "x-ratelimit-remaining", + "ratelimit-remaining", + ) + } + if reqLimit == nil && tokLimit == nil { + reqLimit = parseHeaderInt(headers, + "x-ratelimit-limit", + "ratelimit-limit", + ) + } + + resetRaw := firstHeaderValue(headers, + "x-ratelimit-reset-requests", + "x-ratelimit-reset-request", + "x-ratelimit-reset-tokens", + "x-ratelimit-reset-token", + "x-ratelimit-reset", + "ratelimit-reset", + ) + resetAt := parseRateLimitReset(resetRaw, now) + + if reqRemaining == nil && reqLimit == nil && tokRemaining == nil && tokLimit == nil && resetAt == "" { + return nil + } + return &RateLimitInfo{ + RequestsRemaining: reqRemaining, + RequestsLimit: reqLimit, + TokensRemaining: tokRemaining, + TokensLimit: tokLimit, + ResetAt: resetAt, + } +} + +func firstHeaderValue(headers http.Header, keys ...string) string { + for _, key := range keys { + v := strings.TrimSpace(headers.Get(key)) + if v != "" { + return v + } + } + return "" +} + +func parseHeaderInt(headers http.Header, keys ...string) *int { + for _, key := range keys { + if n, ok := parseIntLikeHeaderValue(headers.Get(key)); ok { + return &n + } + } + return nil +} + +func parseIntLikeHeaderValue(v string) (int, bool) { + v = strings.TrimSpace(v) + if v == "" { + return 0, false + } + if i, err := strconv.Atoi(v); err == nil { + return i, true + } + token := v + if idx := strings.IndexAny(token, ",;"); idx >= 0 { + token = strings.TrimSpace(token[:idx]) + if i, err := strconv.Atoi(token); err == nil { + return i, true + } + } + if f, err := strconv.ParseFloat(token, 64); err == nil { + return int(f), true + } + if m := firstIntRe.FindString(v); m != "" { + if i, err := strconv.Atoi(m); err == nil { + return i, true + } + } + return 0, false +} + +func parseRateLimitReset(v string, now time.Time) string { + v = strings.TrimSpace(v) + if v == "" { + return "" + } + if t, err := http.ParseTime(v); err == nil { + return 
t.UTC().Format(time.RFC3339) + } + if d, err := time.ParseDuration(v); err == nil { + if d < 0 { + d = 0 + } + return now.Add(d).UTC().Format(time.RFC3339) + } + if f, ok := parseFloatLikeHeaderValue(v); ok { + switch { + case f >= 1e12: + // Unix epoch in milliseconds. + return time.UnixMilli(int64(f)).UTC().Format(time.RFC3339) + case f >= 1e9: + // Unix epoch in seconds. + return time.Unix(int64(f), 0).UTC().Format(time.RFC3339) + case f >= 0: + // Relative seconds. + return now.Add(time.Duration(f * float64(time.Second))).UTC().Format(time.RFC3339) + } + } + return "" +} + +func parseFloatLikeHeaderValue(v string) (float64, bool) { + v = strings.TrimSpace(v) + if v == "" { + return 0, false + } + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f, true + } + token := v + if idx := strings.IndexAny(token, ",;"); idx >= 0 { + token = strings.TrimSpace(token[:idx]) + if f, err := strconv.ParseFloat(token, 64); err == nil { + return f, true + } + } + return 0, false +} diff --git a/internal/llm/rate_limit_test.go b/internal/llm/rate_limit_test.go new file mode 100644 index 00000000..8c709351 --- /dev/null +++ b/internal/llm/rate_limit_test.go @@ -0,0 +1,98 @@ +package llm + +import ( + "net/http" + "testing" + "time" +) + +func TestParseRateLimitInfo_ProviderHeaders(t *testing.T) { + h := http.Header{} + h.Set("x-ratelimit-remaining-requests", "9") + h.Set("x-ratelimit-limit-requests", "10") + h.Set("x-ratelimit-remaining-tokens", "90") + h.Set("x-ratelimit-limit-tokens", "100") + h.Set("x-ratelimit-reset-requests", "Wed, 01 Jan 2025 00:00:10 GMT") + + got := ParseRateLimitInfo(h, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)) + if got == nil { + t.Fatalf("expected rate limit info") + } + if got.RequestsRemaining == nil || *got.RequestsRemaining != 9 { + t.Fatalf("requests_remaining: %#v", got.RequestsRemaining) + } + if got.RequestsLimit == nil || *got.RequestsLimit != 10 { + t.Fatalf("requests_limit: %#v", got.RequestsLimit) + } + if 
got.TokensRemaining == nil || *got.TokensRemaining != 90 { + t.Fatalf("tokens_remaining: %#v", got.TokensRemaining) + } + if got.TokensLimit == nil || *got.TokensLimit != 100 { + t.Fatalf("tokens_limit: %#v", got.TokensLimit) + } + if got.ResetAt != "2025-01-01T00:00:10Z" { + t.Fatalf("reset_at: %q", got.ResetAt) + } +} + +func TestParseRateLimitInfo_FallbackHeaders(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + h := http.Header{} + h.Set("x-ratelimit-remaining", "5") + h.Set("x-ratelimit-limit", "8") + h.Set("ratelimit-reset", "5") + + got := ParseRateLimitInfo(h, now) + if got == nil { + t.Fatalf("expected rate limit info") + } + if got.RequestsRemaining == nil || *got.RequestsRemaining != 5 { + t.Fatalf("requests_remaining: %#v", got.RequestsRemaining) + } + if got.RequestsLimit == nil || *got.RequestsLimit != 8 { + t.Fatalf("requests_limit: %#v", got.RequestsLimit) + } + if got.TokensRemaining != nil { + t.Fatalf("tokens_remaining: %#v", got.TokensRemaining) + } + if got.TokensLimit != nil { + t.Fatalf("tokens_limit: %#v", got.TokensLimit) + } + if got.ResetAt != "2025-01-01T00:00:05Z" { + t.Fatalf("reset_at: %q", got.ResetAt) + } +} + +func TestParseRateLimitInfo_ResetFormats(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + cases := []struct { + name string + value string + expect string + }{ + {name: "duration", value: "2s", expect: "2025-01-01T00:00:02Z"}, + {name: "epoch_seconds", value: "1735689610", expect: "2025-01-01T00:00:10Z"}, + {name: "epoch_millis", value: "1735689610000", expect: "2025-01-01T00:00:10Z"}, + {name: "relative_seconds_float", value: "1.5", expect: "2025-01-01T00:00:01Z"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h := http.Header{} + h.Set("x-ratelimit-reset", tc.value) + got := ParseRateLimitInfo(h, now) + if got == nil { + t.Fatalf("expected rate limit info") + } + if got.ResetAt != tc.expect { + t.Fatalf("reset_at: got %q want %q", got.ResetAt, tc.expect) 
+ } + }) + } +} + +func TestParseRateLimitInfo_Empty(t *testing.T) { + got := ParseRateLimitInfo(http.Header{}, time.Now()) + if got != nil { + t.Fatalf("expected nil, got %#v", got) + } +} diff --git a/internal/llm/stream.go b/internal/llm/stream.go index e8f41d89..42584d15 100644 --- a/internal/llm/stream.go +++ b/internal/llm/stream.go @@ -29,12 +29,19 @@ const ( type StreamEvent struct { Type StreamEventType `json:"type"` + // Stream start metadata + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Warnings []Warning `json:"warnings,omitempty"` + // Text events Delta string `json:"delta,omitempty"` TextID string `json:"text_id,omitempty"` // Reasoning events ReasoningDelta string `json:"reasoning_delta,omitempty"` + ReasoningID string `json:"reasoning_id,omitempty"` + Redacted *bool `json:"redacted,omitempty"` // Tool call events ToolCall *ToolCallData `json:"tool_call,omitempty"` @@ -48,5 +55,6 @@ type StreamEvent struct { Err error `json:"-"` // Passthrough - Raw map[string]any `json:"raw,omitempty"` + EventType string `json:"event_type,omitempty"` + Raw map[string]any `json:"raw,omitempty"` } diff --git a/internal/llm/types.go b/internal/llm/types.go index 634d8227..cd23be52 100644 --- a/internal/llm/types.go +++ b/internal/llm/types.go @@ -94,6 +94,9 @@ type ContentPart struct { ToolCall *ToolCallData `json:"tool_call,omitempty"` ToolResult *ToolResultData `json:"tool_result,omitempty"` Thinking *ThinkingData `json:"thinking,omitempty"` + + // Data carries provider-specific payload for custom content kinds. 
+ Data any `json:"data,omitempty"` } type ImageData struct { diff --git a/internal/llmclient/env.go b/internal/llmclient/env.go index 3df1b0bc..8b7142f1 100644 --- a/internal/llmclient/env.go +++ b/internal/llmclient/env.go @@ -3,6 +3,7 @@ package llmclient import ( "github.com/danshapiro/kilroy/internal/llm" _ "github.com/danshapiro/kilroy/internal/llm/providers/anthropic" + _ "github.com/danshapiro/kilroy/internal/llm/providers/codexappserver" _ "github.com/danshapiro/kilroy/internal/llm/providers/google" _ "github.com/danshapiro/kilroy/internal/llm/providers/openai" ) diff --git a/internal/llmclient/env_test.go b/internal/llmclient/env_test.go index 65ae3f06..27aabc93 100644 --- a/internal/llmclient/env_test.go +++ b/internal/llmclient/env_test.go @@ -7,8 +7,30 @@ func TestNewFromEnv_ErrorsWhenNoProvidersConfigured(t *testing.T) { t.Setenv("ANTHROPIC_API_KEY", "") t.Setenv("GEMINI_API_KEY", "") t.Setenv("GOOGLE_API_KEY", "") + t.Setenv("CODEX_APP_SERVER_COMMAND", "") + t.Setenv("CODEX_APP_SERVER_ARGS", "") + t.Setenv("CODEX_APP_SERVER_COMMAND_ARGS", "") + t.Setenv("CODEX_APP_SERVER_AUTO_DISCOVER", "") _, err := NewFromEnv() if err == nil { t.Fatalf("expected error, got nil") } } +func TestNewFromEnv_RegistersCodexAppServerWhenCommandOverrideIsSet(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "") + t.Setenv("ANTHROPIC_API_KEY", "") + t.Setenv("GEMINI_API_KEY", "") + t.Setenv("GOOGLE_API_KEY", "") + t.Setenv("CODEX_APP_SERVER_COMMAND", "codex") + t.Setenv("CODEX_APP_SERVER_ARGS", "") + t.Setenv("CODEX_APP_SERVER_COMMAND_ARGS", "") + t.Setenv("CODEX_APP_SERVER_AUTO_DISCOVER", "") + c, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv: %v", err) + } + names := c.ProviderNames() + if len(names) != 1 || names[0] != "codex-app-server" { + t.Fatalf("provider names: got %v want [codex-app-server]", names) + } +} diff --git a/internal/providerspec/builtin.go b/internal/providerspec/builtin.go index 789a7f1d..450a51fd 100644 --- a/internal/providerspec/builtin.go 
+++ b/internal/providerspec/builtin.go @@ -19,6 +19,16 @@ var builtinSpecs = map[string]Spec{ CapabilityAll: []string{"--json"}, }, }, + "codex-app-server": { + Key: "codex-app-server", + Aliases: []string{"codex_app_server"}, + API: &APISpec{ + Protocol: ProtocolCodexAppServer, + DefaultAPIKeyEnv: "", + ProviderOptionsKey: "codex_app_server", + ProfileFamily: "codex-app-server", + }, + }, "anthropic": { Key: "anthropic", API: &APISpec{ diff --git a/internal/providerspec/spec.go b/internal/providerspec/spec.go index 63d53161..dd388401 100644 --- a/internal/providerspec/spec.go +++ b/internal/providerspec/spec.go @@ -12,6 +12,7 @@ const ( ProtocolOpenAIChatCompletions APIProtocol = "openai_chat_completions" ProtocolAnthropicMessages APIProtocol = "anthropic_messages" ProtocolGoogleGenerateContent APIProtocol = "google_generate_content" + ProtocolCodexAppServer APIProtocol = "codex_app_server" ) type APISpec struct { diff --git a/internal/providerspec/spec_test.go b/internal/providerspec/spec_test.go index 2788bad8..1affea88 100644 --- a/internal/providerspec/spec_test.go +++ b/internal/providerspec/spec_test.go @@ -4,7 +4,7 @@ import "testing" func TestBuiltinSpecsIncludeCoreAndNewProviders(t *testing.T) { s := Builtins() - for _, key := range []string{"openai", "anthropic", "google", "kimi", "zai", "cerebras", "minimax", "inception"} { + for _, key := range []string{"openai", "codex-app-server", "anthropic", "google", "kimi", "zai", "cerebras", "minimax", "inception"} { if _, ok := s[key]; !ok { t.Fatalf("missing builtin provider %q", key) } @@ -33,6 +33,9 @@ func TestCanonicalProviderKey_Aliases(t *testing.T) { if got := CanonicalProviderKey("minimax-ai"); got != "minimax" { t.Fatalf("minimax-ai alias: got %q want %q", got, "minimax") } + if got := CanonicalProviderKey("codex_app_server"); got != "codex-app-server" { + t.Fatalf("codex_app_server alias: got %q want %q", got, "codex-app-server") + } if got := CanonicalProviderKey("inceptionlabs"); got != "inception" 
{ t.Fatalf("inceptionlabs alias: got %q want %q", got, "inception") } @@ -44,6 +47,28 @@ func TestCanonicalProviderKey_Aliases(t *testing.T) { } } +func TestBuiltinCodexAppServerDefaults(t *testing.T) { + spec, ok := Builtin("codex-app-server") + if !ok { + t.Fatalf("expected codex-app-server builtin") + } + if spec.API == nil { + t.Fatalf("expected codex-app-server api spec") + } + if got := spec.API.Protocol; got != ProtocolCodexAppServer { + t.Fatalf("codex-app-server protocol: got %q want %q", got, ProtocolCodexAppServer) + } + if got := spec.API.DefaultAPIKeyEnv; got != "" { + t.Fatalf("codex-app-server api_key_env: got %q want empty", got) + } + if got := spec.API.ProviderOptionsKey; got != "codex_app_server" { + t.Fatalf("codex-app-server provider_options_key: got %q want %q", got, "codex_app_server") + } + if got := spec.API.ProfileFamily; got != "codex-app-server" { + t.Fatalf("codex-app-server profile_family: got %q want %q", got, "codex-app-server") + } +} + func TestBuiltinCerebrasDefaultsToOpenAICompatAPI(t *testing.T) { spec, ok := Builtin("cerebras") if !ok { From d662f9bdf58d98b478199cbeb8755db08538261c Mon Sep 17 00:00:00 2001 From: Vadim Comanescu Date: Fri, 27 Feb 2026 18:30:55 +0100 Subject: [PATCH 2/4] engine: allow spark model on codex app server api (cherry picked from commit 1a57a2ca640d65457fa663e9548d41c9be22ab53) --- internal/attractor/engine/cli_only_models.go | 4 +- .../attractor/engine/cli_only_models_test.go | 25 +++++++-- internal/attractor/engine/codergen_router.go | 4 +- .../engine/codergen_router_cli_only_test.go | 10 +++- .../attractor/engine/provider_preflight.go | 5 +- .../engine/provider_preflight_test.go | 52 +++++++++++++++---- 6 files changed, 77 insertions(+), 23 deletions(-) diff --git a/internal/attractor/engine/cli_only_models.go b/internal/attractor/engine/cli_only_models.go index 02b1a6c3..27848d01 100644 --- a/internal/attractor/engine/cli_only_models.go +++ b/internal/attractor/engine/cli_only_models.go @@ -4,9 +4,7 
@@ import "strings" // cliOnlyModelIDs lists models that MUST route through CLI backend regardless // of provider backend configuration. These models have no API endpoint. -var cliOnlyModelIDs = map[string]bool{ - "gpt-5.3-codex-spark": true, -} +var cliOnlyModelIDs = map[string]bool{} // isCLIOnlyModel returns true if the given model ID (with or without provider // prefix) must be routed exclusively through the CLI backend. diff --git a/internal/attractor/engine/cli_only_models_test.go b/internal/attractor/engine/cli_only_models_test.go index 0b9adff0..972f473b 100644 --- a/internal/attractor/engine/cli_only_models_test.go +++ b/internal/attractor/engine/cli_only_models_test.go @@ -7,10 +7,10 @@ func TestIsCLIOnlyModel(t *testing.T) { model string want bool }{ - {"gpt-5.3-codex-spark", true}, - {"GPT-5.3-CODEX-SPARK", true}, // case-insensitive - {"openai/gpt-5.3-codex-spark", true}, // with provider prefix - {"gpt-5.3-codex", false}, // regular codex + {"gpt-5.3-codex-spark", false}, + {"GPT-5.3-CODEX-SPARK", false}, // case-insensitive + {"openai/gpt-5.3-codex-spark", false}, // with provider prefix + {"gpt-5.3-codex", false}, // regular codex {"gpt-5.2-codex", false}, {"claude-opus-4-6", false}, {"", false}, @@ -21,3 +21,20 @@ func TestIsCLIOnlyModel(t *testing.T) { } } } + +func TestIsCLIOnlyModel_UsesConfiguredRegistry(t *testing.T) { + orig := cliOnlyModelIDs + cliOnlyModelIDs = map[string]bool{ + "test-cli-only-model": true, + } + t.Cleanup(func() { + cliOnlyModelIDs = orig + }) + + if got := isCLIOnlyModel("test-cli-only-model"); !got { + t.Fatalf("isCLIOnlyModel(test-cli-only-model) = %v, want true", got) + } + if got := isCLIOnlyModel("openai/test-cli-only-model"); !got { + t.Fatalf("isCLIOnlyModel(openai/test-cli-only-model) = %v, want true", got) + } +} diff --git a/internal/attractor/engine/codergen_router.go b/internal/attractor/engine/codergen_router.go index 2b2228d4..7d3b2e14 100644 --- a/internal/attractor/engine/codergen_router.go +++ 
b/internal/attractor/engine/codergen_router.go @@ -94,8 +94,8 @@ func (r *CodergenRouter) Run(ctx context.Context, exec *Execution, node *model.N return "", nil, fmt.Errorf("no backend configured for provider %s", prov) } - // CLI-only model override: models like gpt-5.3-codex-spark have no API - // endpoint. Force CLI backend regardless of provider configuration. + // CLI-only model override: force CLI backend when a model is marked + // CLI-only in the registry. if isCLIOnlyModel(modelID) && backend != BackendCLI { warnEngine(exec, fmt.Sprintf("cli-only model override: node=%s model=%s backend=%s->cli", node.ID, modelID, backend)) backend = BackendCLI diff --git a/internal/attractor/engine/codergen_router_cli_only_test.go b/internal/attractor/engine/codergen_router_cli_only_test.go index cf351e2d..e855e4ea 100644 --- a/internal/attractor/engine/codergen_router_cli_only_test.go +++ b/internal/attractor/engine/codergen_router_cli_only_test.go @@ -10,6 +10,14 @@ import ( ) func TestCLIOnlyModelOverride_SwitchesBackendAndWarns(t *testing.T) { + orig := cliOnlyModelIDs + cliOnlyModelIDs = map[string]bool{ + "test-cli-only-model": true, + } + t.Cleanup(func() { + cliOnlyModelIDs = orig + }) + // Set up router with openai configured as API backend. runtimes := map[string]ProviderRuntime{ "openai": {Key: "openai", Backend: BackendAPI}, @@ -24,7 +32,7 @@ func TestCLIOnlyModelOverride_SwitchesBackendAndWarns(t *testing.T) { // Create a node using the CLI-only model. 
node := model.NewNode("spark-test") node.Attrs["llm_provider"] = "openai" - node.Attrs["llm_model"] = "gpt-5.3-codex-spark" + node.Attrs["llm_model"] = "test-cli-only-model" node.Attrs["shape"] = "box" // Create an execution with temp dirs to isolate artifacts and an Engine diff --git a/internal/attractor/engine/provider_preflight.go b/internal/attractor/engine/provider_preflight.go index 7c94d613..1be7fc92 100644 --- a/internal/attractor/engine/provider_preflight.go +++ b/internal/attractor/engine/provider_preflight.go @@ -93,9 +93,8 @@ func runProviderCLIPreflight(ctx context.Context, g *model.Graph, runtimes map[s _ = writePreflightReport(opts.LogsRoot, report) }() - // Validate CLI-only models: fail early if a CLI-only model (e.g., - // gpt-5.3-codex-spark) is used but its provider is not configured with - // backend=cli. + // Validate CLI-only models: fail early if a configured CLI-only model is + // used but its provider is not configured with backend=cli. if err := validateCLIOnlyModels(g, runtimes, opts.ForceModels, report); err != nil { return report, err } diff --git a/internal/attractor/engine/provider_preflight_test.go b/internal/attractor/engine/provider_preflight_test.go index 12d7943d..981a2f3f 100644 --- a/internal/attractor/engine/provider_preflight_test.go +++ b/internal/attractor/engine/provider_preflight_test.go @@ -1922,18 +1922,26 @@ exit 1 } func TestProviderPreflight_CLIOnlyModelWithAPIBackend_Fails(t *testing.T) { + orig := cliOnlyModelIDs + cliOnlyModelIDs = map[string]bool{ + "test-cli-only-model": true, + } + t.Cleanup(func() { + cliOnlyModelIDs = orig + }) + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") repo := initTestRepo(t) catalog := writeCatalogForPreflight(t, `{ "data": [ - {"id": "openai/gpt-5.3-codex-spark"} + {"id": "openai/test-cli-only-model"} ] }`) // openai configured as API backend — should fail for CLI-only model. 
cfg := testPreflightConfigForProviders(repo, catalog, map[string]BackendKind{ "openai": BackendAPI, }) - dot := singleProviderDot("openai", "gpt-5.3-codex-spark") + dot := singleProviderDot("openai", "test-cli-only-model") logsRoot := t.TempDir() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -1963,18 +1971,26 @@ func TestProviderPreflight_CLIOnlyModelWithAPIBackend_Fails(t *testing.T) { } func TestProviderPreflight_CLIOnlyModelWithCLIBackend_Passes(t *testing.T) { + orig := cliOnlyModelIDs + cliOnlyModelIDs = map[string]bool{ + "test-cli-only-model": true, + } + t.Cleanup(func() { + cliOnlyModelIDs = orig + }) + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") repo := initTestRepo(t) catalog := writeCatalogForPreflight(t, `{ "data": [ - {"id": "openai/gpt-5.3-codex-spark"} + {"id": "openai/test-cli-only-model"} ] }`) // openai configured as CLI backend — should pass the CLI-only check. cfg := testPreflightConfigForProviders(repo, catalog, map[string]BackendKind{ "openai": BackendCLI, }) - dot := singleProviderDot("openai", "gpt-5.3-codex-spark") + dot := singleProviderDot("openai", "test-cli-only-model") logsRoot := t.TempDir() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -1999,11 +2015,19 @@ func TestProviderPreflight_CLIOnlyModelWithCLIBackend_Passes(t *testing.T) { } func TestProviderPreflight_CLIOnlyModel_ForceModelOverridesToRegular_NoFail(t *testing.T) { + orig := cliOnlyModelIDs + cliOnlyModelIDs = map[string]bool{ + "test-cli-only-model": true, + } + t.Cleanup(func() { + cliOnlyModelIDs = orig + }) + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") repo := initTestRepo(t) catalog := writeCatalogForPreflight(t, `{ "data": [ - {"id": "openai/gpt-5.3-codex-spark"}, + {"id": "openai/test-cli-only-model"}, {"id": "openai/gpt-5.2-codex"} ] }`) @@ -2012,7 +2036,7 @@ func TestProviderPreflight_CLIOnlyModel_ForceModelOverridesToRegular_NoFail(t *t cfg := testPreflightConfigForProviders(repo, catalog, 
map[string]BackendKind{ "openai": BackendAPI, }) - dot := singleProviderDot("openai", "gpt-5.3-codex-spark") + dot := singleProviderDot("openai", "test-cli-only-model") logsRoot := t.TempDir() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -2023,20 +2047,28 @@ func TestProviderPreflight_CLIOnlyModel_ForceModelOverridesToRegular_NoFail(t *t AllowTestShim: true, ForceModels: map[string]string{"openai": "gpt-5.2-codex"}, }) - // Should NOT fail with CLI-only error — force-model replaces Spark with - // a regular model. + // Should NOT fail with CLI-only error — force-model replaces the + // CLI-only model with a regular model. if err != nil && strings.Contains(err.Error(), "CLI-only") { t.Fatalf("force-model to regular model should bypass CLI-only check, got: %v", err) } } func TestProviderPreflight_ForceModelInjectsCLIOnly_WithAPIBackend_Fails(t *testing.T) { + orig := cliOnlyModelIDs + cliOnlyModelIDs = map[string]bool{ + "test-cli-only-model": true, + } + t.Cleanup(func() { + cliOnlyModelIDs = orig + }) + t.Setenv("KILROY_PREFLIGHT_PROMPT_PROBES", "off") repo := initTestRepo(t) catalog := writeCatalogForPreflight(t, `{ "data": [ {"id": "openai/gpt-5.2-codex"}, - {"id": "openai/gpt-5.3-codex-spark"} + {"id": "openai/test-cli-only-model"} ] }`) // openai configured as API backend, graph uses a regular model, but @@ -2053,7 +2085,7 @@ func TestProviderPreflight_ForceModelInjectsCLIOnly_WithAPIBackend_Fails(t *test RunID: "force-cli-only-api-fail", LogsRoot: logsRoot, AllowTestShim: true, - ForceModels: map[string]string{"openai": "gpt-5.3-codex-spark"}, + ForceModels: map[string]string{"openai": "test-cli-only-model"}, }) if err == nil { t.Fatal("expected preflight error when force-model injects CLI-only model with API backend, got nil") From 7fda33cada815b104b8b990d587bfcd8fe4bc665 Mon Sep 17 00:00:00 2001 From: Vadim Comanescu Date: Sat, 28 Feb 2026 15:38:51 +0100 Subject: [PATCH 3/4] engine: harden codex app-server autonomy wiring 
(cherry picked from commit 23f3ce1e2329597b204b530bc23b1f817b5266c2) --- .../engine/codergen_failover_test.go | 69 +++++++++++++++++++ internal/attractor/engine/codergen_router.go | 47 +++++++++++-- .../codexappserver/request_translator.go | 2 + .../request_translator_controls_test.go | 16 +++++ 4 files changed, 130 insertions(+), 4 deletions(-) diff --git a/internal/attractor/engine/codergen_failover_test.go b/internal/attractor/engine/codergen_failover_test.go index 5262547e..ecd9b312 100644 --- a/internal/attractor/engine/codergen_failover_test.go +++ b/internal/attractor/engine/codergen_failover_test.go @@ -392,3 +392,72 @@ func TestShouldFailoverLLMError_GetwdBootstrapErrorDoesNotFailover(t *testing.T) t.Fatalf("getwd bootstrap errors should not trigger failover") } } + +func TestAgentLoopProviderOptions_CodexAppServer_UsesFullAutonomousPermissions(t *testing.T) { + got := agentLoopProviderOptions("codex_app_server", "/tmp/worktree") + if len(got) != 1 { + t.Fatalf("provider options length=%d want 1", len(got)) + } + raw, ok := got["codex_app_server"] + if !ok { + t.Fatalf("missing codex_app_server provider options: %#v", got) + } + opts, ok := raw.(map[string]any) + if !ok { + t.Fatalf("codex_app_server options type=%T want map[string]any", raw) + } + if gotCwd := fmt.Sprint(opts["cwd"]); gotCwd != "/tmp/worktree" { + t.Fatalf("cwd=%q want %q", gotCwd, "/tmp/worktree") + } + if gotApproval := fmt.Sprint(opts["approvalPolicy"]); gotApproval != "never" { + t.Fatalf("approvalPolicy=%q want %q", gotApproval, "never") + } + if gotSandbox := fmt.Sprint(opts["sandbox"]); gotSandbox != "danger-full-access" { + t.Fatalf("sandbox=%q want %q", gotSandbox, "danger-full-access") + } + rawSandboxPolicy, ok := opts["sandboxPolicy"] + if !ok { + t.Fatalf("missing sandboxPolicy in codex options: %#v", opts) + } + sandboxPolicy, ok := rawSandboxPolicy.(map[string]any) + if !ok { + t.Fatalf("sandboxPolicy type=%T want map[string]any", rawSandboxPolicy) + } + if gotType := 
fmt.Sprint(sandboxPolicy["type"]); gotType != "dangerFullAccess" { + t.Fatalf("sandboxPolicy.type=%q want %q", gotType, "dangerFullAccess") + } +} + +func TestAgentLoopProviderOptions_Cerebras_PreservesReasoningHistory(t *testing.T) { + got := agentLoopProviderOptions("cerebras", "") + raw, ok := got["cerebras"] + if !ok { + t.Fatalf("missing cerebras provider options: %#v", got) + } + opts, ok := raw.(map[string]any) + if !ok { + t.Fatalf("cerebras options type=%T want map[string]any", raw) + } + clearThinking, ok := opts["clear_thinking"].(bool) + if !ok { + t.Fatalf("clear_thinking type=%T want bool", opts["clear_thinking"]) + } + if clearThinking { + t.Fatalf("clear_thinking=%v want false", clearThinking) + } +} + +func TestAgentLoopProviderOptions_CodexAppServer_OmitsCwdWhenWorktreeEmpty(t *testing.T) { + got := agentLoopProviderOptions("codex-app-server", "") + raw, ok := got["codex_app_server"] + if !ok { + t.Fatalf("missing codex_app_server provider options: %#v", got) + } + opts, ok := raw.(map[string]any) + if !ok { + t.Fatalf("codex_app_server options type=%T want map[string]any", raw) + } + if _, exists := opts["cwd"]; exists { + t.Fatalf("expected cwd to be omitted when worktreeDir is empty: %#v", opts["cwd"]) + } +} diff --git a/internal/attractor/engine/codergen_router.go b/internal/attractor/engine/codergen_router.go index 7d3b2e14..3cd9a947 100644 --- a/internal/attractor/engine/codergen_router.go +++ b/internal/attractor/engine/codergen_router.go @@ -238,15 +238,24 @@ func (r *CodergenRouter) runAPI(ctx context.Context, execCtx *Execution, node *m if reasoning != "" { sessCfg.ReasoningEffort = reasoning } - if maxTokensPtr != nil { - sessCfg.MaxTokens = maxTokensPtr + if providerOptions := agentLoopProviderOptions(prov, execCtx.WorktreeDir); len(providerOptions) > 0 { + sessCfg.ProviderOptions = providerOptions } // Cerebras GLM 4.7: preserve reasoning across agent-loop turns. 
// clear_thinking defaults to true on the API, which strips prior // reasoning context — counterproductive for multi-step agentic work. if normalizeProviderKey(prov) == "cerebras" { - sessCfg.ProviderOptions = map[string]any{ - "cerebras": map[string]any{"clear_thinking": false}, + if sessCfg.ProviderOptions == nil { + sessCfg.ProviderOptions = map[string]any{} + } + if existing, ok := sessCfg.ProviderOptions["cerebras"]; ok { + if cerebrasOpts, ok := existing.(map[string]any); ok { + cerebrasOpts["clear_thinking"] = false + } else { + sessCfg.ProviderOptions["cerebras"] = map[string]any{"clear_thinking": false} + } + } else { + sessCfg.ProviderOptions["cerebras"] = map[string]any{"clear_thinking": false} } } if v := parseInt(node.Attr("max_agent_turns", ""), 0); v > 0 { @@ -383,6 +392,36 @@ func (r *CodergenRouter) runAPI(ctx context.Context, execCtx *Execution, node *m } } +func agentLoopProviderOptions(provider string, worktreeDir string) map[string]any { + switch normalizeProviderKey(provider) { + case "cerebras": + // Cerebras GLM 4.7: preserve reasoning across agent-loop turns. + // clear_thinking defaults to true on the API, which strips prior + // reasoning context, this is counterproductive for multi-step agentic work. + return map[string]any{ + "cerebras": map[string]any{"clear_thinking": false}, + } + case "codex-app-server": + opts := map[string]any{ + "approvalPolicy": "never", + // Keep both thread-level and turn-level sandbox knobs set. + // App-server surfaces use different fields/casing across thread/start and turn/start. 
+ "sandbox": "danger-full-access", + "sandboxPolicy": map[string]any{ + "type": "dangerFullAccess", + }, + } + if wt := strings.TrimSpace(worktreeDir); wt != "" { + opts["cwd"] = wt + } + return map[string]any{ + "codex_app_server": opts, + } + default: + return nil + } +} + type providerModel struct { Provider string Model string diff --git a/internal/llm/providers/codexappserver/request_translator.go b/internal/llm/providers/codexappserver/request_translator.go index 819143e4..afba28c7 100644 --- a/internal/llm/providers/codexappserver/request_translator.go +++ b/internal/llm/providers/codexappserver/request_translator.go @@ -40,6 +40,8 @@ var ( "cwd": "cwd", "approvalPolicy": "approvalPolicy", "approval_policy": "approvalPolicy", + "sandbox": "sandbox", + "sandbox_mode": "sandbox", "sandboxPolicy": "sandboxPolicy", "sandbox_policy": "sandboxPolicy", "model": "model", diff --git a/internal/llm/providers/codexappserver/request_translator_controls_test.go b/internal/llm/providers/codexappserver/request_translator_controls_test.go index 4c7348c0..f44c17de 100644 --- a/internal/llm/providers/codexappserver/request_translator_controls_test.go +++ b/internal/llm/providers/codexappserver/request_translator_controls_test.go @@ -74,6 +74,8 @@ func TestRequestTranslator_ApplyProviderOptions_MapsKnownKeysAndWarnsUnknown(t * "codex_app_server": map[string]any{ "cwd": "/tmp/project", "approval_policy": "never", + "sandbox_mode": "danger-full-access", + "sandboxPolicy": map[string]any{"type": "dangerFullAccess"}, "temperature": 0.2, "reasoning_effort": "high", "unsupportedX": true, @@ -87,6 +89,20 @@ func TestRequestTranslator_ApplyProviderOptions_MapsKnownKeysAndWarnsUnknown(t * if params["approvalPolicy"] != "never" { t.Fatalf("approvalPolicy mapping: %#v", params["approvalPolicy"]) } + if params["sandbox"] != "danger-full-access" { + t.Fatalf("sandbox mapping: %#v", params["sandbox"]) + } + rawSandboxPolicy, ok := params["sandboxPolicy"] + if !ok { + t.Fatalf("missing 
sandboxPolicy mapping: %#v", params) + } + sandboxPolicy, ok := rawSandboxPolicy.(map[string]any) + if !ok { + t.Fatalf("sandboxPolicy type=%T want map[string]any", rawSandboxPolicy) + } + if sandboxPolicy["type"] != "dangerFullAccess" { + t.Fatalf("sandboxPolicy.type=%#v want %q", sandboxPolicy["type"], "dangerFullAccess") + } if controls.Temperature == nil || *controls.Temperature != 0.2 { t.Fatalf("temperature override: %#v", controls.Temperature) } From 6be757a9721cbd81850c7c80db90f14fa188d1fe Mon Sep 17 00:00:00 2001 From: Vadim Comanescu Date: Wed, 4 Mar 2026 19:51:16 +0100 Subject: [PATCH 4/4] fix(codex-app-server): handle process-exit waits and status parsing --- .../llm/providers/codexappserver/adapter.go | 66 +++++++- .../providers/codexappserver/adapter_test.go | 24 +++ .../llm/providers/codexappserver/transport.go | 146 +++++++++++++++--- .../codexappserver/transport_helpers_test.go | 51 ++++++ 4 files changed, 262 insertions(+), 25 deletions(-) diff --git a/internal/llm/providers/codexappserver/adapter.go b/internal/llm/providers/codexappserver/adapter.go index b45edc8d..d5968272 100644 --- a/internal/llm/providers/codexappserver/adapter.go +++ b/internal/llm/providers/codexappserver/adapter.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "os" + "strconv" "strings" "sync" "time" @@ -445,11 +446,15 @@ func normalizeErrorInfo(raw any) normalizedErrorInfo { } if statusVal, ok := source["status"]; ok { - info.Status = asInt(statusVal, 0) - info.HasStatus = true + if status, hasStatus := parseHTTPStatus(statusVal); hasStatus { + info.Status = status + info.HasStatus = true + } } else if statusVal, ok := root["status"]; ok { - info.Status = asInt(statusVal, 0) - info.HasStatus = true + if status, hasStatus := parseHTTPStatus(statusVal); hasStatus { + info.Status = status + info.HasStatus = true + } } info.Code = firstNonEmpty( @@ -496,6 +501,59 @@ func unwrapJSONMessage(message string) string { return message } +func parseHTTPStatus(raw any) (int, bool) 
{ + switch value := raw.(type) { + case int: + return normalizeHTTPStatus(value) + case int8: + return normalizeHTTPStatus(int(value)) + case int16: + return normalizeHTTPStatus(int(value)) + case int32: + return normalizeHTTPStatus(int(value)) + case int64: + return normalizeHTTPStatus(int(value)) + case uint: + return normalizeHTTPStatus(int(value)) + case uint8: + return normalizeHTTPStatus(int(value)) + case uint16: + return normalizeHTTPStatus(int(value)) + case uint32: + return normalizeHTTPStatus(int(value)) + case uint64: + return normalizeHTTPStatus(int(value)) + case float32: + return normalizeHTTPStatus(int(value)) + case float64: + return normalizeHTTPStatus(int(value)) + case json.Number: + if parsed, err := value.Int64(); err == nil { + return normalizeHTTPStatus(int(parsed)) + } + return 0, false + case string: + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return 0, false + } + parsed, err := strconv.Atoi(trimmed) + if err != nil { + return 0, false + } + return normalizeHTTPStatus(parsed) + default: + return 0, false + } +} + +func normalizeHTTPStatus(status int) (int, bool) { + if status < 100 || status > 599 { + return 0, false + } + return status, true +} + func isTransportFailure(code, message string) bool { if code != "" { switch code { diff --git a/internal/llm/providers/codexappserver/adapter_test.go b/internal/llm/providers/codexappserver/adapter_test.go index d54bdb34..3c544431 100644 --- a/internal/llm/providers/codexappserver/adapter_test.go +++ b/internal/llm/providers/codexappserver/adapter_test.go @@ -317,6 +317,30 @@ func TestNormalizeErrorInfo_UnwrapsJSONMessage(t *testing.T) { } } +func TestNormalizeErrorInfo_IgnoresSymbolicStatus(t *testing.T) { + info := normalizeErrorInfo(map[string]any{ + "error": map[string]any{ + "status": "RESOURCE_EXHAUSTED", + "message": "rate limited", + }, + }) + if info.HasStatus { + t.Fatalf("expected symbolic status to be ignored, got status=%d", info.Status) + } +} + +func 
TestNormalizeErrorInfo_ParsesNumericStatusString(t *testing.T) { + info := normalizeErrorInfo(map[string]any{ + "error": map[string]any{ + "status": "429", + "message": "rate limited", + }, + }) + if !info.HasStatus || info.Status != 429 { + t.Fatalf("expected HTTP status 429, got hasStatus=%v status=%d", info.HasStatus, info.Status) + } +} + func TestParseToolCall_NormalizesArguments(t *testing.T) { tool := parseToolCall(`{"id":"call_1","name":"search","arguments":{"q":"foo"}}`) if tool == nil { diff --git a/internal/llm/providers/codexappserver/transport.go b/internal/llm/providers/codexappserver/transport.go index 8731e888..161dbf29 100644 --- a/internal/llm/providers/codexappserver/transport.go +++ b/internal/llm/providers/codexappserver/transport.go @@ -67,6 +67,55 @@ type pendingResult struct { err error } +type processLifecycle struct { + done chan struct{} + once sync.Once + mu sync.Mutex + err error +} + +type turnWaitOutcome int + +const ( + turnWaitCompleted turnWaitOutcome = iota + turnWaitContextDone + turnWaitProcessTerminated +) + +func newProcessLifecycle() *processLifecycle { + return &processLifecycle{done: make(chan struct{})} +} + +func (l *processLifecycle) finish(err error) { + if l == nil { + return + } + l.mu.Lock() + if l.err == nil { + l.err = err + } + l.mu.Unlock() + l.once.Do(func() { + close(l.done) + }) +} + +func (l *processLifecycle) doneCh() <-chan struct{} { + if l == nil { + return nil + } + return l.done +} + +func (l *processLifecycle) processError() error { + if l == nil { + return nil + } + l.mu.Lock() + defer l.mu.Unlock() + return l.err +} + type stdioTransport struct { opts TransportOptions @@ -77,6 +126,7 @@ type stdioTransport struct { stdout io.ReadCloser stderr io.ReadCloser procDone chan struct{} + life *processLifecycle closed bool shuttingDown bool @@ -150,6 +200,11 @@ func (t *stdioTransport) Stream(ctx context.Context, payload map[string]any) (*N errs <- err return } + life := t.currentProcessLifecycle() + if life 
== nil { + errs <- llm.NewNetworkError(providerName, "Codex app-server process is unavailable") + return + } turnTemplate, err := parseTurnStartPayload(payload) if err != nil { @@ -241,19 +296,20 @@ func (t *stdioTransport) Stream(ctx context.Context, payload map[string]any) (*N } } - select { - case <-completed: + outcome, waitErr := t.waitForTurnCompletion(requestCtx, completed, life) + if waitErr == nil { return - case <-requestCtx.Done(): + } + if outcome == turnWaitContextDone { stateMu.Lock() currentTurnID := turnID stateMu.Unlock() if currentTurnID != "" { go t.interruptTurnBestEffort(threadID, currentTurnID) } - errs <- llm.WrapContextError(providerName, requestCtx.Err()) - return } + errs <- waitErr + return }() return &NotificationStream{Notifications: events, Err: errs, closeFn: cancel}, nil @@ -303,6 +359,10 @@ func (t *stdioTransport) runTurn(ctx context.Context, payload map[string]any) (m if err := t.ensureInitialized(ctx); err != nil { return nil, err } + life := t.currentProcessLifecycle() + if life == nil { + return nil, llm.NewNetworkError(providerName, "Codex app-server process is unavailable") + } turnTemplate, err := parseTurnStartPayload(payload) if err != nil { @@ -375,16 +435,17 @@ func (t *stdioTransport) runTurn(ctx context.Context, payload map[string]any) (m } } - select { - case <-completed: - case <-requestCtx.Done(): - stateMu.Lock() - currentTurnID := turnID - stateMu.Unlock() - if currentTurnID != "" { - go t.interruptTurnBestEffort(threadID, currentTurnID) + outcome, waitErr := t.waitForTurnCompletion(requestCtx, completed, life) + if waitErr != nil { + if outcome == turnWaitContextDone { + stateMu.Lock() + currentTurnID := turnID + stateMu.Unlock() + if currentTurnID != "" { + go t.interruptTurnBestEffort(threadID, currentTurnID) + } } - return nil, llm.WrapContextError(providerName, requestCtx.Err()) + return nil, waitErr } stateMu.Lock() @@ -442,6 +503,35 @@ func (t *stdioTransport) interruptTimeout() time.Duration { return 
timeout } +func (t *stdioTransport) currentProcessLifecycle() *processLifecycle { + t.mu.Lock() + defer t.mu.Unlock() + return t.life +} + +func (t *stdioTransport) processTerminationError(life *processLifecycle) error { + if life != nil { + if err := life.processError(); err != nil { + return err + } + } + return llm.NewNetworkError(providerName, "Codex app-server process exited") +} + +func (t *stdioTransport) waitForTurnCompletion(ctx context.Context, completed <-chan struct{}, life *processLifecycle) (turnWaitOutcome, error) { + if life == nil { + return turnWaitProcessTerminated, llm.NewNetworkError(providerName, "Codex app-server process is unavailable") + } + select { + case <-completed: + return turnWaitCompleted, nil + case <-ctx.Done(): + return turnWaitContextDone, llm.WrapContextError(providerName, ctx.Err()) + case <-life.doneCh(): + return turnWaitProcessTerminated, t.processTerminationError(life) + } +} + func (t *stdioTransport) ensureInitialized(ctx context.Context) error { t.mu.Lock() if t.closed { @@ -548,24 +638,26 @@ func (t *stdioTransport) spawnProcess() error { } procDone := make(chan struct{}) + life := newProcessLifecycle() t.mu.Lock() t.cmd = cmd t.stdin = stdin t.stdout = stdout t.stderr = stderr t.procDone = procDone + t.life = life t.stderrTail = "" t.initialized = false t.mu.Unlock() - go t.readStdout(cmd, stdout) + go t.readStdout(cmd, stdout, life) go t.readStderr(stderr) - go t.waitForExit(cmd, procDone) + go t.waitForExit(cmd, procDone, life) return nil } -func (t *stdioTransport) readStdout(cmd *exec.Cmd, stdout io.Reader) { +func (t *stdioTransport) readStdout(cmd *exec.Cmd, stdout io.Reader, life *processLifecycle) { scanner := bufio.NewScanner(stdout) buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, maxJSONRPCLineSize) @@ -583,7 +675,7 @@ func (t *stdioTransport) readStdout(cmd *exec.Cmd, stdout io.Reader) { t.handleIncomingMessage(message) } if err := scanner.Err(); err != nil { - 
t.handleUnexpectedProcessTermination(llm.NewNetworkError(providerName, fmt.Sprintf("Codex stdout read error: %v", err))) + t.handleUnexpectedProcessTermination(life, llm.NewNetworkError(providerName, fmt.Sprintf("Codex stdout read error: %v", err))) } _ = cmd } @@ -613,7 +705,7 @@ func (t *stdioTransport) appendStderrTail(chunk string) { t.mu.Unlock() } -func (t *stdioTransport) waitForExit(cmd *exec.Cmd, done chan struct{}) { +func (t *stdioTransport) waitForExit(cmd *exec.Cmd, done chan struct{}, life *processLifecycle) { err := cmd.Wait() t.mu.Lock() shuttingDown := t.shuttingDown @@ -625,12 +717,23 @@ func (t *stdioTransport) waitForExit(cmd *exec.Cmd, done chan struct{}) { t.stdout = nil t.stderr = nil t.procDone = nil + t.life = nil t.initialized = false } t.shuttingDown = false t.mu.Unlock() close(done) + exitMessage := "Codex app-server process exited" + if err != nil { + exitMessage = fmt.Sprintf("Codex app-server process exited: %v", err) + } + if stderrTail != "" { + exitMessage = exitMessage + ". stderr: " + stderrTail + } + exitErr := llm.NewNetworkError(providerName, exitMessage) + life.finish(exitErr) + if shuttingDown || closed { return } @@ -641,10 +744,11 @@ func (t *stdioTransport) waitForExit(cmd *exec.Cmd, done chan struct{}) { if stderrTail != "" { message = message + ". 
stderr: " + stderrTail } - t.handleUnexpectedProcessTermination(llm.NewNetworkError(providerName, message)) + t.handleUnexpectedProcessTermination(life, llm.NewNetworkError(providerName, message)) } -func (t *stdioTransport) handleUnexpectedProcessTermination(err error) { +func (t *stdioTransport) handleUnexpectedProcessTermination(life *processLifecycle, err error) { + life.finish(err) t.rejectAllPending(err) } diff --git a/internal/llm/providers/codexappserver/transport_helpers_test.go b/internal/llm/providers/codexappserver/transport_helpers_test.go index 330c8bac..594d8d2d 100644 --- a/internal/llm/providers/codexappserver/transport_helpers_test.go +++ b/internal/llm/providers/codexappserver/transport_helpers_test.go @@ -577,3 +577,54 @@ func TestTransport_HelperProcess(t *testing.T) { return } } + +func TestProcessLifecycle_FinishIsIdempotent(t *testing.T) { + life := newProcessLifecycle() + firstErr := errors.New("first exit") + secondErr := errors.New("second exit") + + life.finish(firstErr) + life.finish(secondErr) + + select { + case <-life.doneCh(): + default: + t.Fatalf("expected lifecycle done channel to close") + } + if !errors.Is(life.processError(), firstErr) { + t.Fatalf("expected first process error to win, got %v", life.processError()) + } +} + +func TestTransport_WaitForTurnCompletion_UnblocksOnProcessExit(t *testing.T) { + tp := &stdioTransport{} + life := newProcessLifecycle() + completed := make(chan struct{}) + resultCh := make(chan struct { + outcome turnWaitOutcome + err error + }, 1) + + go func() { + outcome, err := tp.waitForTurnCompletion(context.Background(), completed, life) + resultCh <- struct { + outcome turnWaitOutcome + err error + }{outcome: outcome, err: err} + }() + + processErr := llm.NewNetworkError(providerName, "Codex app-server exited unexpectedly") + life.finish(processErr) + + select { + case result := <-resultCh: + if result.outcome != turnWaitProcessTerminated { + t.Fatalf("wait outcome: got %v want %v", 
result.outcome, turnWaitProcessTerminated) + } + if !errors.Is(result.err, processErr) { + t.Fatalf("wait error: got %v want %v", result.err, processErr) + } + case <-time.After(500 * time.Millisecond): + t.Fatalf("timed out waiting for process-exit unblock") + } +}