diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 00243f8..d2a7ba1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,7 +29,7 @@ jobs: - name: Generate release notes run: | - awk '/^## \[Unreleased\]/{skip=1; next} /^## \[/{skip=0; c++; if(c>1)exit; if(c==1)next} !skip && c>0' CHANGELOG.md | tail -c +2 > /tmp/release-notes.txt + awk '/^## \[/{c++; if(c>1)exit; if(c==1){next}} c>0' CHANGELOG.md | tail -c +2 > /tmp/release-notes.txt - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 diff --git a/.golangci.yml b/.golangci.yml index 3d9df46..acd16d3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -138,9 +138,31 @@ issues: - path: main\.go linters: - goconst + - path: cmd/delete\.go + linters: + - nlreturn + - path: cmd/harness\.go + linters: + - goconst + - nlreturn - path: cmd/version\.go linters: - goconst + - path: cmd/list\.go + linters: + - nlreturn + - path: cmd/util\.go + linters: + - nlreturn + - path: internal/errors/errors\.go + linters: + - nlreturn + - path: internal/ui/prompt\.go + linters: + - nlreturn + - path: internal/providers/registry\.go + linters: + - nlreturn - text: "exported (function|method|type|const)" linters: - stylecheck @@ -163,6 +185,10 @@ issues: linters: - funlen # Context-aware functions with multiple cancellation checks + - text: "migrateConfigFile" + linters: + - funlen + - cyclop - text: "LoadConfig" linters: - funlen diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 44e2cf9..112d73e 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -48,3 +48,9 @@ changelog: # Disable auto-generated changelog since CHANGELOG.md is maintained manually disable: true +release: + footer: | + + --- + + Released by [GoReleaser](https://github.com/goreleaser/goreleaser). diff --git a/AGENTS.md b/AGENTS.md index 1f7daa6..8e32aab 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,65 +1,43 @@ # Kairo Development -## Architecture +## WHAT -Go CLI that wraps Claude/Qwen Code providers with X25519 encryption via age. +Go CLI wrapper for Claude/Qwen Code API providers with X25519 encryption. -**Stack:** Go 1.26+, Cobra, filippo.io/age, YAML, Go testing +**Tech Stack:** Go 1.26+, Cobra, age (filippo.io/age), YAML, Go testing -```text -main.go # Bootstrap -cmd/ # Cobra commands (root, setup, execution, update, version) -internal/ - config/ # YAML loading, caching, migration - crypto/ # X25519 key gen, age encrypt/decrypt - errors/ # Typed error hierarchy with context - providers/ # Provider registry and resolution - secrets/ # Encrypted secret storage - ui/ # ANSI-aware terminal output (build-tagged per OS) - validate/ # API key, URL, model, cross-provider validation - version/ # Build-time metadata - wrapper/ # Shell wrapper script generation -tests/integration/ # End-to-end workflow tests -docs/ # Architecture, guides, reference -``` - -See `internal/README.md` for package contracts and data flow. -See `cmd/README.md` for command structure and CLIContext details. +**Key Directories:** -## Conventions +- `cmd/` - CLI commands (Cobra) - see `cmd/root.go:1` +- `internal/` - Business logic (config, crypto, providers, ui, errors, validate) +- `docs/` - Architecture, guides, reference documentation -- Internal packages (`internal/*`) have zero CLI dependencies. Keep them pure Go. -- Injectable function variables for testability in `cmd/` — external calls (exec, HTTP, prompts) are assigned to package-level `var` funcs, overridden in tests. See `cmd/update.go` for the pattern. -- Propagate `context.Context` through all I/O-bound call chains. Check cancellation between sequential operations (file writes, network calls). -- Error wrapping uses the typed `internal/errors` package — `kairoerrors.WrapError(kind, msg, err)` with `.WithContext(key, val)` for diagnostics. -- Build tags for platform-specific code (e.g., `//go:build !windows` in `internal/ui/`). -- Coverage threshold: 70% enforced in CI. +**Entry Points:** -## Constraints +- `main.go:1` - Application bootstrap +- `cmd/root.go:1` - Root command and CLI setup +- `internal/README.md` - Package contracts, data flow -- Keep `migrateConfigFile` and similar context-aware functions decomposed below cyclop=10. The linter catches this, but the pattern is: extract substeps into named helpers (`statOldConfig`, `readAndValidateConfig`, `finalizeMigration`). -- `nlreturn` + `whitespace` linters coexist — do not place blank lines at the start of blocks (whitespace rejects) but do place blank lines before returns in the main function flow (nlreturn requires). Short error-guard returns inside `if` blocks are exempt from nlreturn's blank-line rule. -- Functions with `CheckContext` calls between sequential I/O steps naturally grow complex. Prefer extracting substeps into helpers early. -- Update checks hit the GitHub Releases API (`api.github.com`) — unauthenticated, 60 req/hr/IP limit. Do not add additional unauthenticated GitHub API calls. - -## Commands +## HOW ```bash -just build # Binary to dist/ -just test # All tests with -race -just test-coverage # Coverage report -just lint # gofmt, go vet, golangci-lint -just pre-release # Format, lint, pre-commit hooks, test -just release # Create release (GITHUB_TOKEN required) +just build # Binary to dist/ +just test # All tests with race detector +just test-coverage # Coverage report +just lint # gofmt, go vet, golangci-lint +just pre-release # Format, lint, pre-commit hooks, test +just release # Create release (requires GITHUB_TOKEN) ``` -CI runs: `golangci-lint run ./...`, `go test -race -coverprofile=coverage.out ./...`, `go mod tidy` check. +## Docs + +Read these if relevant to your task: -Pre-commit hooks (`.pre-commit-config.yaml`): golangci-lint, go test, go mod tidy check. +- `docs/architecture/README.md` - System design +- `docs/guides/development-guide.md` - Adding commands, CI workflows -## Patterns +## Notes -- **Adding a new provider:** See `docs/guides/development-guide.md` for the full workflow. -- **Architecture decisions:** See `docs/architecture/adr/` for ADRs (X25519 choice, Cobra selection, age library). -- **Wrapper scripts:** See `docs/architecture/wrapper-scripts.md` for harness execution details. -- **CI workflows:** See `.github/workflows/ci.yml` (lint, test, cross-platform build), `release.yml`, `vulnerability-scan.yml`. +- Internal packages have no CLI dependencies - keep them pure +- Use `just --list` to see all available commands +- Run `just pre-release` before committing diff --git a/CHANGELOG.md b/CHANGELOG.md index bacba55..69ce769 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,24 +5,6 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] - -### Added -- `--no-update-check` flag on `kairo version` to skip GitHub API calls - -### Changed -- CLI commands now propagate the Cobra command context to child operations and TUI interactions instead of using `context.Background()`, enabling proper cancellation on Ctrl+C -- Config loading in `delete` and `list` commands now uses the config cache for consistency -- Root command decomposed into `runRoot`, `loadRootConfig`, and `dispatchExecution` for maintainability -- Removed `DecryptSecrets` (string return) in favor of `DecryptSecretsBytes` with `ClearMemory` to allow secure memory clearing of decrypted key material - -### Fixed -- Temp auth directory no longer leaks on error paths — cleanup now runs before `os.Exit` -- Hyphens in custom provider names no longer produce invalid environment variable names (e.g., `MY-PROVIDER_API_KEY` → `MY_PROVIDER_API_KEY`) -- `mustParseCIDR` no longer panics on invalid CIDR constants — uses `log.Fatalf` with a clear message -- HTTP client in update command now has TLS handshake and response header timeouts for better protection against slow/misbehaving servers -- ANSI color codes are now stripped on Windows terminals that don't support them - ## [2.3.7] - 2026-04-25 ### Added diff --git a/cmd/coverage_config_test.go b/cmd/coverage_config_test.go deleted file mode 100644 index c197db2..0000000 --- a/cmd/coverage_config_test.go +++ /dev/null @@ -1,399 +0,0 @@ -package cmd - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/dkmnx/kairo/internal/config" - "github.com/dkmnx/kairo/internal/constants" - "github.com/dkmnx/kairo/internal/crypto" - "github.com/dkmnx/kairo/internal/providers" - "github.com/spf13/cobra" - "github.com/yarlson/tap" -) - -// --- Tests for LoadSecrets --- - -func TestLoadSecrets_NoSecretsFile(t *testing.T) { - tmpDir := withTempConfigDir(t) - result, err := LoadSecrets(context.Background(), tmpDir) - if err != nil { - t.Fatalf("LoadSecrets() error = %v", err) - } - if len(result.Secrets) != 0 { - t.Errorf("Expected empty secrets, got %d entries", len(result.Secrets)) - } - if result.SecretsPath == "" { - t.Error("SecretsPath should be set") - } -} - -func TestLoadSecrets_WithSecrets(t *testing.T) { - tmpDir := withTempConfigDir(t) - ctx := context.Background() - if err := crypto.EnsureKeyExists(ctx, tmpDir); err != nil { - t.Fatal(err) - } - secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) - keyPath := filepath.Join(tmpDir, constants.KeyFileName) - if err := crypto.EncryptSecrets(ctx, secretsPath, keyPath, "ZAI_API_KEY=test-key\n"); err != nil { - t.Fatal(err) - } - - result, err := LoadSecrets(ctx, tmpDir) - if err != nil { - t.Fatalf("LoadSecrets() error = %v", err) - } - if result.Secrets["ZAI_API_KEY"] != "test-key" { - t.Errorf("ZAI_API_KEY = %q, want %q", result.Secrets["ZAI_API_KEY"], "test-key") - } -} - -// --- Tests for SaveSecrets --- - -func TestSaveSecrets(t *testing.T) { - tmpDir := t.TempDir() - ctx := context.Background() - if err := crypto.EnsureKeyExists(ctx, tmpDir); err != nil { - t.Fatal(err) - } - secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) - keyPath := filepath.Join(tmpDir, constants.KeyFileName) - - err := SaveSecrets(ctx, secretsPath, keyPath, map[string]string{"TEST_API_KEY": "test-value"}) - if err != nil { - t.Fatalf("SaveSecrets() error = %v", err) - } - - result, err := LoadSecrets(ctx, tmpDir) - if err != nil { - t.Fatalf("LoadSecrets() error = %v", err) - } - if result.Secrets["TEST_API_KEY"] != "test-value" { - t.Errorf("TEST_API_KEY = %q, want %q", result.Secrets["TEST_API_KEY"], "test-value") - } -} - -// --- Tests for EnsureConfigDir --- - -func TestEnsureConfigDir(t *testing.T) { - origDir := getConfigDir() - defer setConfigDir(origDir) - - tmpDir := t.TempDir() - configDir := filepath.Join(tmpDir, "kairo-test-config") - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(configDir) - - err := EnsureConfigDir(cliCtx, configDir) - if err != nil { - t.Fatalf("EnsureConfigDir() error = %v", err) - } - if _, err := os.Stat(configDir); os.IsNotExist(err) { - t.Error("config directory should exist") - } -} - -// --- Tests for GetProviderDefinition --- - -func TestGetProviderDefinition(t *testing.T) { - tests := []struct { - name string - provider string - wantName string - wantBaseURL string - wantModel string - }{ - {"zai builtin", "zai", "Z.AI", "https://api.z.ai/api/anthropic", "glm-5.1"}, - {"minimax builtin", "minimax", "MiniMax", "https://api.minimax.io/anthropic", "MiniMax-M2.7"}, - {"unknown uses input", "myprovider", "myprovider", "", ""}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - def := GetProviderDefinition(tt.provider) - if def.Name != tt.wantName { - t.Errorf("Name = %q, want %q", def.Name, tt.wantName) - } - if tt.wantBaseURL != "" && def.BaseURL != tt.wantBaseURL { - t.Errorf("BaseURL = %q, want %q", def.BaseURL, tt.wantBaseURL) - } - if tt.wantModel != "" && def.Model != tt.wantModel { - t.Errorf("Model = %q, want %q", def.Model, tt.wantModel) - } - }) - } -} - -// --- Tests for validateConfiguredModel --- - -func TestValidateConfiguredModel(t *testing.T) { - tests := []struct { - name string - cfg modelValidationConfig - wantErr bool - }{ - {"empty model builtin ok", modelValidationConfig{Model: "", ProviderName: "zai", DisplayName: "Z.AI"}, false}, - {"valid model builtin", modelValidationConfig{Model: "glm-5.1", ProviderName: "zai", DisplayName: "Z.AI"}, false}, - {"empty custom requires model", modelValidationConfig{Model: " ", ProviderName: "custom-provider", DisplayName: "custom-provider"}, true}, - {"valid model custom", modelValidationConfig{Model: "my-model", ProviderName: "custom-provider", DisplayName: "custom-provider"}, false}, - {"model too long for built-in", modelValidationConfig{Model: strings.Repeat("a", 101), ProviderName: "zai", DisplayName: "Z.AI"}, true}, - {"model with invalid chars for built-in", modelValidationConfig{Model: "model@invalid", ProviderName: "zai", DisplayName: "Z.AI"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateConfiguredModel(tt.cfg) - if (err != nil) != tt.wantErr { - t.Errorf("validateConfiguredModel() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -// --- Tests for BuildProviderConfig --- - -func TestBuildProviderConfig_NewProvider(t *testing.T) { - def := providers.ProviderDefinition{Name: "Z.AI", BaseURL: "https://api.z.ai/api/anthropic", Model: "glm-5.1"} - cfg := BuildProviderConfig(ProviderBuildConfig{ - Definition: def, BaseURL: "https://custom.api.com", Model: "custom-model", Exists: false, Existing: nil, - }) - if cfg.Name != "Z.AI" { - t.Errorf("Name = %q, want %q", cfg.Name, "Z.AI") - } - if cfg.BaseURL != "https://custom.api.com" { - t.Errorf("BaseURL = %q, want %q", cfg.BaseURL, "https://custom.api.com") - } - if cfg.Model != "custom-model" { - t.Errorf("Model = %q, want %q", cfg.Model, "custom-model") - } -} - -func TestBuildProviderConfig_EditExisting(t *testing.T) { - existing := &config.Provider{ - Name: "Z.AI", BaseURL: "https://old.api.com", Model: "old-model", EnvVars: []string{"EXTRA_VAR=extra"}, - } - def := providers.ProviderDefinition{Name: "Z.AI", BaseURL: "https://api.z.ai/api/anthropic", Model: "glm-5.1"} - cfg := BuildProviderConfig(ProviderBuildConfig{ - Definition: def, BaseURL: "https://new.api.com", Model: "new-model", Exists: true, Existing: existing, - }) - if cfg.BaseURL != "https://new.api.com" { - t.Errorf("BaseURL = %q, want %q", cfg.BaseURL, "https://new.api.com") - } - if cfg.Model != "new-model" { - t.Errorf("Model = %q, want %q", cfg.Model, "new-model") - } - if len(cfg.EnvVars) != 1 || cfg.EnvVars[0] != "EXTRA_VAR=extra" { - t.Errorf("EnvVars = %v, want [EXTRA_VAR=extra]", cfg.EnvVars) - } -} - -// --- Tests for runResetSecrets --- - -func TestRunResetSecrets_Cancelled(t *testing.T) { - origConfirm := confirmUIFn - confirmUIFn = func(prompt string) (bool, error) { return false, nil } - defer func() { confirmUIFn = origConfirm }() - - tmpDir := withTempConfigDir(t) - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - secretsResult := SecretsResult{ - Secrets: make(map[string]string), - SecretsPath: filepath.Join(tmpDir, constants.SecretsFileName), - KeyPath: filepath.Join(tmpDir, constants.KeyFileName), - } - - err := runResetSecrets(cliCtx, tmpDir, secretsResult) - if err == nil { - t.Error("runResetSecrets() should return error when cancelled") - } - if err.Error() != "operation cancelled by user" { - t.Errorf("error = %q, want 'operation cancelled by user'", err.Error()) - } -} - -func TestRunResetSecrets_Confirmed(t *testing.T) { - origConfirm := confirmUIFn - confirmUIFn = func(prompt string) (bool, error) { return true, nil } - defer func() { confirmUIFn = origConfirm }() - - tmpDir := withTempConfigDir(t) - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { - t.Fatal(err) - } - - secretsResult := SecretsResult{ - Secrets: make(map[string]string), - SecretsPath: filepath.Join(tmpDir, constants.SecretsFileName), - KeyPath: filepath.Join(tmpDir, constants.KeyFileName), - } - - err := runResetSecrets(cliCtx, tmpDir, secretsResult) - if err != nil { - t.Fatalf("runResetSecrets() error = %v", err) - } - keyPath := filepath.Join(tmpDir, constants.KeyFileName) - if _, err := os.Stat(keyPath); os.IsNotExist(err) { - t.Error("new key should exist after reset") - } -} - -// --- Tests for requireConfigDirWritable --- - -func TestRequireConfigDirWritable_CreatesDir(t *testing.T) { - origDir := getConfigDir() - defer setConfigDir(origDir) - - tmpDir := t.TempDir() - testDir := filepath.Join(tmpDir, "kairo-config") - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(testDir) - cmd := newCommandWithContext(cliCtx) - - result := requireConfigDirWritable(cmd) - if result == "" { - t.Error("requireConfigDirWritable() should return path when dir can be created") - } - if _, err := os.Stat(testDir); os.IsNotExist(err) { - t.Error("directory should be created") - } -} - -func TestRequireConfigDirWritable_NoConfigDir(t *testing.T) { - // When configDir is empty and GetConfigDir has no override, - // it falls back to the platform default. So it will find a dir. - // Test that calling requireConfigDirWritable with a nil context returns empty. - cmd := &cobra.Command{} - // No CLIContext set, GetCLIContext falls back to defaultCLIContext - result := requireConfigDirWritable(cmd) - // With no override, GetConfigDir returns the platform default, which is writable, - // so result will be a non-empty path. We just verify it doesn't panic. - _ = result -} - -// --- Tests for loadConfigOrExit --- - -func TestLoadConfigOrExit_NoConfig(t *testing.T) { - tmpDir := withTempConfigDir(t) - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - cmd := newCommandWithContext(cliCtx) - result := loadConfigOrExit(cmd) - if result != nil { - t.Error("loadConfigOrExit() should return nil when no config exists") - } -} - -func TestLoadConfigOrExit_WithConfig(t *testing.T) { - tmpDir := withTempConfigDir(t) - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - cfg := &config.Config{ - DefaultProvider: "zai", - Providers: map[string]config.Provider{ - "zai": {Name: "Z.AI", BaseURL: "https://api.z.ai/api/anthropic", Model: "glm-5.1"}, - }, - DefaultModels: make(map[string]string), - } - mustCreateConfig(t, tmpDir, cfg) - - cmd := newCommandWithContext(cliCtx) - result := loadConfigOrExit(cmd) - if result == nil { - t.Fatal("loadConfigOrExit() should return config when it exists") - } - if result.DefaultProvider != "zai" { - t.Errorf("DefaultProvider = %q, want %q", result.DefaultProvider, "zai") - } -} - -// --- Tests for configureProvider integration --- - -func TestConfigureProvider_InvalidName(t *testing.T) { - _, err := configureProvider(context.Background(), ProviderSetup{ - ProviderName: "123invalid", - Cfg: &config.Config{Providers: make(map[string]config.Provider)}, - }) - if err == nil { - t.Error("configureProvider() should return error for invalid provider name") - } -} - -func TestConfigureProvider_EmptyAPIKey(t *testing.T) { - withMockedTAP(t) - tapPasswordFn = func(ctx context.Context, opts tap.PasswordOptions) string { return "" } - - tmpDir := t.TempDir() - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { - t.Fatal(err) - } - - _, err := configureProvider(context.Background(), ProviderSetup{ - CLIContext: NewCLIContext(), ConfigDir: tmpDir, - Cfg: &config.Config{Providers: make(map[string]config.Provider), DefaultModels: make(map[string]string)}, - ProviderName: "zai", Secrets: make(map[string]string), - SecretsPath: filepath.Join(tmpDir, constants.SecretsFileName), - KeyPath: filepath.Join(tmpDir, constants.KeyFileName), IsEdit: false, - }) - if err == nil { - t.Error("configureProvider() should return error for empty API key") - } -} - -func TestConfigureProvider_Success(t *testing.T) { - withMockedTAP(t) - apiKey := "sk-test-api-key-that-is-long-enough-1234567890" - tapPasswordFn = func(ctx context.Context, opts tap.PasswordOptions) string { return apiKey } - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return "" } - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return true } - tapIntroFn = func(title string, opts ...tap.MessageOptions) {} - tapMessageFn = func(message string, opts ...tap.MessageOptions) {} - tapOutroFn = func(message string, opts ...tap.MessageOptions) {} - - tmpDir := t.TempDir() - if err := crypto.EnsureKeyExists(context.Background(), tmpDir); err != nil { - t.Fatal(err) - } - - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - result, err := configureProvider(context.Background(), ProviderSetup{ - CLIContext: cliCtx, ConfigDir: tmpDir, - Cfg: &config.Config{Providers: make(map[string]config.Provider), DefaultModels: make(map[string]string)}, - ProviderName: "zai", Secrets: make(map[string]string), - SecretsPath: filepath.Join(tmpDir, constants.SecretsFileName), - KeyPath: filepath.Join(tmpDir, constants.KeyFileName), IsEdit: false, - }) - if err != nil { - t.Fatalf("configureProvider() error = %v", err) - } - if result != "zai" { - t.Errorf("configureProvider() = %q, want %q", result, "zai") - } - - loadedCfg := mustLoadConfig(t, tmpDir) - zaiProvider, ok := loadedCfg.Providers["zai"] - if !ok { - t.Fatal("zai provider not found in config") - } - if zaiProvider.Name != "Z.AI" { - t.Errorf("Provider Name = %q, want %q", zaiProvider.Name, "Z.AI") - } - - secretsResult, err := LoadSecrets(context.Background(), tmpDir) - if err != nil { - t.Fatalf("LoadSecrets() error = %v", err) - } - if secretsResult.Secrets["ZAI_API_KEY"] != apiKey { - t.Errorf("ZAI_API_KEY = %q, want %q", secretsResult.Secrets["ZAI_API_KEY"], apiKey) - } -} diff --git a/cmd/coverage_env_test.go b/cmd/coverage_env_test.go deleted file mode 100644 index b77c408..0000000 --- a/cmd/coverage_env_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package cmd - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/dkmnx/kairo/internal/constants" - "github.com/dkmnx/kairo/internal/crypto" -) - -// --- Tests for RequiresAPIKey --- - -func TestRequiresAPIKey(t *testing.T) { - tests := []struct { - name string - provider string - want bool - }{ - {"zai requires key", "zai", true}, - {"minimax requires key", "minimax", true}, - {"deepseek requires key", "deepseek", true}, - {"kimi requires key", "kimi", true}, - {"custom requires key", "custom", true}, - {"unknown requires key", "unknown-provider", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if result := RequiresAPIKey(tt.provider); result != tt.want { - t.Errorf("RequiresAPIKey(%q) = %v, want %v", tt.provider, result, tt.want) - } - }) - } -} - -// --- Tests for BuildProviderEnv --- - -func TestBuildProviderEnv_BasicEnv(t *testing.T) { - tmpDir := withTempConfigDir(t) - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - provider := EnvProvider{BaseURL: "https://api.test.com", Model: "test-model", EnvVars: []string{"TEST_VAR=value123"}} - result, err := BuildProviderEnv(cliCtx, tmpDir, provider, "test") - if err != nil { - t.Fatalf("BuildProviderEnv() error = %v", err) - } - - envMap := make(map[string]string) - for _, env := range result.ProviderEnv { - parts := strings.SplitN(env, "=", 2) - if len(parts) == 2 { - envMap[parts[0]] = parts[1] - } - } - if envMap["ANTHROPIC_BASE_URL"] != "https://api.test.com" { - t.Errorf("ANTHROPIC_BASE_URL = %q, want %q", envMap["ANTHROPIC_BASE_URL"], "https://api.test.com") - } - if envMap["ANTHROPIC_MODEL"] != "test-model" { - t.Errorf("ANTHROPIC_MODEL = %q, want %q", envMap["ANTHROPIC_MODEL"], "test-model") - } - if envMap["TEST_VAR"] != "value123" { - t.Errorf("TEST_VAR = %q, want %q", envMap["TEST_VAR"], "value123") - } -} - -func TestBuildProviderEnv_SecretsLoadFallback(t *testing.T) { - tmpDir := withTempConfigDir(t) - cliCtx := NewCLIContext() - cliCtx.SetConfigDir(tmpDir) - - // When no secrets file exists, LoadSecrets returns empty map, no error. - // So BuildProviderEnv succeeds with empty secrets. - provider := EnvProvider{BaseURL: "https://api.test.com", Model: "test-model"} - result, err := BuildProviderEnv(cliCtx, tmpDir, provider, "zai") - if err != nil { - t.Fatalf("BuildProviderEnv() should not error when secrets file doesn't exist, got: %v", err) - } - if len(result.Secrets) != 0 { - t.Errorf("Expected empty secrets, got %d entries", len(result.Secrets)) - } -} - -// --- Tests for BuildBuiltInEnvVars --- - -func TestBuildBuiltInEnvVars(t *testing.T) { - envVars := BuildBuiltInEnvVars(EnvProvider{BaseURL: "https://api.example.com", Model: "gpt-4"}) - expectedKeys := []string{ - "ANTHROPIC_BASE_URL", "ANTHROPIC_MODEL", "ANTHROPIC_DEFAULT_HAIKU_MODEL", - "ANTHROPIC_DEFAULT_SONNET_MODEL", "ANTHROPIC_DEFAULT_OPUS_MODEL", - "ANTHROPIC_SMALL_FAST_MODEL", "NODE_OPTIONS", - } - envMap := make(map[string]string) - for _, env := range envVars { - parts := strings.SplitN(env, "=", 2) - if len(parts) == 2 { - envMap[parts[0]] = parts[1] - } - } - for _, key := range expectedKeys { - if _, ok := envMap[key]; !ok { - t.Errorf("Expected env var %q not found", key) - } - } - if envMap["NODE_OPTIONS"] != "--no-deprecation" { - t.Errorf("NODE_OPTIONS = %q, want %q", envMap["NODE_OPTIONS"], "--no-deprecation") - } -} - -// --- Tests for APIKeyEnvVarName --- - -func TestAPIKeyEnvVarName(t *testing.T) { - tests := []struct { - provider string - want string - }{ - {"zai", "ZAI_API_KEY"}, - {"minimax", "MINIMAX_API_KEY"}, - {"deepseek", "DEEPSEEK_API_KEY"}, - } - for _, tt := range tests { - t.Run(tt.provider, func(t *testing.T) { - if result := APIKeyEnvVarName(tt.provider); result != tt.want { - t.Errorf("APIKeyEnvVarName(%q) = %q, want %q", tt.provider, result, tt.want) - } - }) - } -} - -// --- Tests for BuildSecretsEnvVars --- - -func TestBuildSecretsEnvVars(t *testing.T) { - secrets := map[string]string{"ZAI_API_KEY": "sk-abc123", "EXTRA_VAR": "extra-value"} - envVars := BuildSecretsEnvVars(secrets) - if len(envVars) != 2 { - t.Errorf("returned %d vars, want 2", len(envVars)) - } - foundKeys := make(map[string]bool) - for _, env := range envVars { - parts := strings.SplitN(env, "=", 2) - if len(parts) == 2 { - foundKeys[parts[0]] = true - } - } - if !foundKeys["ZAI_API_KEY"] || !foundKeys["EXTRA_VAR"] { - t.Error("Missing expected keys in env vars") - } -} - -// --- Tests for ResetSecretsFiles --- - -func TestResetSecretsFiles(t *testing.T) { - tmpDir := t.TempDir() - ctx := context.Background() - keyPath := filepath.Join(tmpDir, constants.KeyFileName) - secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) - - if err := crypto.GenerateKey(ctx, keyPath); err != nil { - t.Fatal(err) - } - if err := crypto.EncryptSecrets(ctx, secretsPath, keyPath, "TEST_API_KEY=test-secret\n"); err != nil { - t.Fatal(err) - } - - err := ResetSecretsFiles(ctx, tmpDir, secretsPath, keyPath) - if err != nil { - t.Fatalf("ResetSecretsFiles() error = %v", err) - } - if _, err := os.Stat(keyPath); os.IsNotExist(err) { - t.Fatal("new key file should exist after reset") - } - if _, err := os.Stat(secretsPath); !os.IsNotExist(err) { - t.Error("old secrets file should be removed after reset") - } -} - -func TestResetSecretsFiles_NoExistingFiles(t *testing.T) { - tmpDir := t.TempDir() - ctx := context.Background() - keyPath := filepath.Join(tmpDir, constants.KeyFileName) - secretsPath := filepath.Join(tmpDir, constants.SecretsFileName) - - err := ResetSecretsFiles(ctx, tmpDir, secretsPath, keyPath) - if err != nil { - t.Fatalf("ResetSecretsFiles() error = %v", err) - } - if _, err := os.Stat(keyPath); os.IsNotExist(err) { - t.Fatal("new key file should exist after reset") - } -} diff --git a/cmd/coverage_test.go b/cmd/coverage_test.go index 855fbe8..8ea9df3 100644 --- a/cmd/coverage_test.go +++ b/cmd/coverage_test.go @@ -1,8 +1,8 @@ package cmd import ( - "context" "os" + "strings" "testing" "github.com/dkmnx/kairo/internal/config" @@ -14,7 +14,29 @@ func TestHandleSecretsError(t *testing.T) { handleSecretsError(testErr) } -// TestBuildProviderListOptions is now in setup_prompts_test.go with table-driven tests. +func TestBuildProviderListOptions(t *testing.T) { + providerList := []string{"anthropic", "zai", "minimax"} + options := buildProviderListOptions(providerList) + + if len(options) != 3 { + t.Errorf("expected 3 options, got %d", len(options)) + } + + expectedProviders := map[string]bool{ + "anthropic": true, + "zai": true, + "minimax": true, + } + + for _, opt := range options { + if !expectedProviders[opt.Value] { + t.Errorf("unexpected provider: %s", opt.Value) + } + if opt.Label != opt.Value { + t.Errorf("label should match value for %s", opt.Value) + } + } +} func TestBuildProviderConfig(t *testing.T) { t.Run("new provider", func(t *testing.T) { @@ -64,8 +86,65 @@ func TestBuildProviderConfig(t *testing.T) { }) } -// TestBuildSecretsEnvVars, TestBuildBuiltInEnvVars, and TestAPIKeyEnvVarName -// are now in coverage_env_test.go with improved coverage. +func TestBuildSecretsEnvVars(t *testing.T) { + secrets := map[string]string{ + "ANTHROPIC_API_KEY": "test-key-123", + "ZAI_API_KEY": "zai-key-456", + } + + envVars := BuildSecretsEnvVars(secrets) + + if len(envVars) != 2 { + t.Errorf("expected 2 env vars, got %d", len(envVars)) + } + + expectedVars := map[string]bool{ + "ANTHROPIC_API_KEY=test-key-123": true, + "ZAI_API_KEY=zai-key-456": true, + } + + for _, envVar := range envVars { + if !expectedVars[envVar] { + t.Errorf("unexpected env var: %s", envVar) + } + } +} + +func TestBuildBuiltInEnvVars(t *testing.T) { + provider := EnvProvider{ + BaseURL: "https://api.test.com", + Model: "test-model", + } + + envVars := BuildBuiltInEnvVars(provider) + + expectedKeys := []string{ + "ANTHROPIC_BASE_URL", + "ANTHROPIC_MODEL", + "ANTHROPIC_DEFAULT_HAIKU_MODEL", + "ANTHROPIC_DEFAULT_SONNET_MODEL", + "ANTHROPIC_DEFAULT_OPUS_MODEL", + "ANTHROPIC_SMALL_FAST_MODEL", + } + + envMap := make(map[string]string) + for _, env := range envVars { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + for _, key := range expectedKeys { + if _, exists := envMap[key]; !exists { + t.Errorf("BuildBuiltInEnvVars() missing expected key %s", key) + } + } + + if envMap["ANTHROPIC_BASE_URL"] != provider.BaseURL { + t.Errorf("ANTHROPIC_BASE_URL = %s, want %s", envMap["ANTHROPIC_BASE_URL"], provider.BaseURL) + } +} func TestSplitArgs(t *testing.T) { tests := []struct { @@ -113,10 +192,30 @@ func TestSplitArgs(t *testing.T) { } } -// TestAPIKeyEnvVarName is now in coverage_env_test.go with additional test cases. +func TestAPIKeyEnvVarName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"anthropic", "anthropic", "ANTHROPIC_API_KEY"}, + {"zai", "zai", "ZAI_API_KEY"}, + {"minimax", "minimax", "MINIMAX_API_KEY"}, + {"UPPERCASE", "UPPERCASE", "UPPERCASE_API_KEY"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := APIKeyEnvVarName(tt.input) + if got != tt.expected { + t.Errorf("APIKeyEnvVarName(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} func TestResolveProviderName(t *testing.T) { - name, err := ResolveProviderName(context.Background(), "anthropic") + name, err := ResolveProviderName("anthropic") if err != nil { t.Errorf("unexpected error: %v", err) } @@ -125,4 +224,14 @@ func TestResolveProviderName(t *testing.T) { } } -// TestGetProviderDefinition is now in coverage_config_test.go with table-driven tests. +func TestGetProviderDefinition(t *testing.T) { + def := GetProviderDefinition("anthropic") + if def.Name == "" { + t.Error("expected non-empty provider definition") + } + + def = GetProviderDefinition("custom-provider") + if def.Name != "custom-provider" { + t.Errorf("expected 'custom-provider', got %q", def.Name) + } +} diff --git a/cmd/delete.go b/cmd/delete.go index 255b017..41581c7 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -28,15 +28,13 @@ var deleteCmd = &cobra.Command{ return } - cfg, err := cliCtx.GetConfigCache().Get(cliCtx.GetRootCtx(), dir) + cfg, err := config.LoadConfig(cliCtx.GetRootCtx(), dir) if err != nil { if os.IsNotExist(err) { ui.PrintWarn("No providers configured") - return } handleConfigError(cmd, err) - return } @@ -45,7 +43,6 @@ var deleteCmd = &cobra.Command{ if len(cfg.Providers) == 0 { ui.PrintWarn("No providers configured") ui.PrintInfo("Run 'kairo setup' to get started") - return } @@ -61,18 +58,17 @@ var deleteCmd = &cobra.Command{ fmt.Println() - tapIntroFn("Delete Provider", tap.MessageOptions{ + tap.Intro("Delete Provider", tap.MessageOptions{ Hint: "Remove a configured provider from Kairo", }) - selected := tapSelectFn(cmd.Context(), tap.SelectOptions[string]{ + selected := tap.Select(context.Background(), tap.SelectOptions[string]{ Message: "Select provider to delete", Options: options, }) target = selected if target == "" { ui.PrintInfo("Operation cancelled") - return } } else { @@ -83,16 +79,14 @@ var deleteCmd = &cobra.Command{ if !ok { ui.PrintError(fmt.Sprintf("Provider '%s' not configured", target)) ui.PrintInfo("Run 'kairo list' to see configured providers") - return } - confirmed := tapConfirmFn(cmd.Context(), tap.ConfirmOptions{ + confirmed := tap.Confirm(context.Background(), tap.ConfirmOptions{ Message: fmt.Sprintf("Are you sure you want to delete '%s'?", target), }) if !confirmed { ui.PrintInfo("Operation cancelled") - return } @@ -104,7 +98,6 @@ var deleteCmd = &cobra.Command{ if err := config.SaveConfig(cliCtx.GetRootCtx(), dir, cfg); err != nil { ui.PrintError(fmt.Sprintf("Saving config: %v", err)) - return } @@ -116,11 +109,10 @@ var deleteCmd = &cobra.Command{ if err := deleteProviderSecrets(cliCtx.GetRootCtx(), secretsPath, keyPath, target); err != nil { ui.PrintError(fmt.Sprintf("Failed to clean up secrets for '%s': %v", target, err)) ui.PrintInfo("Provider removed from config but its secrets could not be deleted — manual cleanup may be required") - return } - tapOutroFn(fmt.Sprintf("Provider '%s' deleted successfully", target)) + tap.Outro(fmt.Sprintf("Provider '%s' deleted successfully", target)) }, } @@ -152,7 +144,6 @@ func deleteProviderSecrets(ctx context.Context, secretsPath, keyPath, providerNa if removeErr := os.Remove(secretsPath); removeErr != nil { return fmt.Errorf("could not remove empty secrets file: %w", removeErr) } - return nil } diff --git a/cmd/delete_test.go b/cmd/delete_test.go index 6bdf935..cd636ba 100644 --- a/cmd/delete_test.go +++ b/cmd/delete_test.go @@ -221,19 +221,18 @@ func TestDeleteProviderSecretsPreservesMalformedLines(t *testing.T) { t.Fatalf("deleteProviderSecrets() error = %v", err) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets() error = %v", err) } - defer crypto.ClearMemory(decrypted) - if !strings.Contains(string(decrypted), "VALID_KEY=valid_value") { + if !strings.Contains(decrypted, "VALID_KEY=valid_value") { t.Error("decrypted content should still contain VALID_KEY=valid_value") } - if !strings.Contains(string(decrypted), "malformed_without_equals") { + if !strings.Contains(decrypted, "malformed_without_equals") { t.Error("decrypted content should still contain malformed_without_equals") } - if strings.Contains(string(decrypted), "PROVIDER_TO_DELETE_API_KEY=secret") { + if strings.Contains(decrypted, "PROVIDER_TO_DELETE_API_KEY=secret") { t.Error("decrypted content should NOT contain PROVIDER_TO_DELETE_API_KEY") } } diff --git a/cmd/execution_env.go b/cmd/execution_env.go index dc8909d..49db9ee 100644 --- a/cmd/execution_env.go +++ b/cmd/execution_env.go @@ -36,9 +36,7 @@ func BuildSecretsEnvVars(secrets map[string]string) []string { } func APIKeyEnvVarName(providerName string) string { - sanitized := strings.ReplaceAll(providerName, "-", "_") - - return fmt.Sprintf("%s_API_KEY", strings.ToUpper(sanitized)) + return fmt.Sprintf("%s_API_KEY", strings.ToUpper(providerName)) } func RequiresAPIKey(providerName string) bool { diff --git a/cmd/execution_harness.go b/cmd/execution_harness.go index 3e43673..31cd593 100644 --- a/cmd/execution_harness.go +++ b/cmd/execution_harness.go @@ -26,7 +26,7 @@ type HarnessRun struct { EnvVarName string } -func runHarnessWithWrapper(parentCtx context.Context, params HarnessRun) error { +func runHarnessWithWrapper(params HarnessRun) error { harnessPath, err := lookPath(params.HarnessBinary) if err != nil { return fmt.Errorf("'%s' command not found in PATH", params.HarnessBinary) @@ -47,7 +47,7 @@ func runHarnessWithWrapper(parentCtx context.Context, params HarnessRun) error { ui.ClearScreen() ui.PrintBanner(kairoversion.Version, params.Provider.Model, params.Provider.Name) - ctx, cancel := context.WithCancel(parentCtx) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() setupSignalHandler(cancel) @@ -108,18 +108,16 @@ func executeWithAuth(cfg ExecutionConfig) { ) run.EnvVarName = "ANTHROPIC_API_KEY" - if err := runHarnessWithWrapper(cfg.Cmd.Context(), run); err != nil { + if err := runHarnessWithWrapper(run); err != nil { cfg.Cmd.Printf("Error running Qwen: %v\n", err) - cleanup() exitProcess(1) } return } - if err := runHarnessWithWrapper(cfg.Cmd.Context(), run); err != nil { + if err := runHarnessWithWrapper(run); err != nil { cfg.Cmd.Printf("Error running Claude: %v\n", err) - cleanup() exitProcess(1) } } @@ -148,7 +146,7 @@ func executeWithoutAuth(cfg ExecutionConfig) { ui.ClearScreen() ui.PrintBanner(kairoversion.Version, cfg.Provider.Model, cfg.Provider.Name) - ctx, cancel := context.WithCancel(cfg.Cmd.Context()) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() setupSignalHandler(cancel) diff --git a/cmd/execution_test.go b/cmd/execution_test.go index 746fce1..b81e386 100644 --- a/cmd/execution_test.go +++ b/cmd/execution_test.go @@ -38,7 +38,7 @@ func TestRunHarnessWithWrapper_HarnessNotFound(t *testing.T) { }, } - err := runHarnessWithWrapper(context.Background(), run) + err := runHarnessWithWrapper(run) if err == nil { t.Fatal("runHarnessWithWrapper() should return error when harness not found") } @@ -70,7 +70,7 @@ func TestRunHarnessWithWrapper_WrapperGenerationFails(t *testing.T) { }, } - err := runHarnessWithWrapper(context.Background(), run) + err := runHarnessWithWrapper(run) if err == nil { t.Fatal("runHarnessWithWrapper() should return error when wrapper generation fails") } @@ -123,7 +123,7 @@ func TestRunHarnessWithWrapper_Success(t *testing.T) { }, } - err := runHarnessWithWrapper(context.Background(), run) + err := runHarnessWithWrapper(run) if err != nil { t.Fatalf("runHarnessWithWrapper() should succeed, got error: %v", err) } @@ -202,7 +202,6 @@ func TestExecuteWithAuth_TokenFileWriteFails(t *testing.T) { } cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -267,7 +266,6 @@ func TestExecuteWithAuth_QwenHarness(t *testing.T) { exitProcess = func(int) {} cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -331,7 +329,6 @@ func TestExecuteWithAuth_ClaudeHarness(t *testing.T) { exitProcess = func(int) {} cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -402,7 +399,6 @@ func TestExecuteWithAuth_YoloModeClaude(t *testing.T) { exitProcess = func(int) {} cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -481,7 +477,6 @@ func TestExecuteWithAuth_YoloModeQwen(t *testing.T) { exitProcess = func(int) {} cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -549,7 +544,7 @@ func TestApiKeyEnvVarName(t *testing.T) { {"lowercase provider", "anthropic", "ANTHROPIC_API_KEY"}, {"uppercase provider", "ANTHROPIC", "ANTHROPIC_API_KEY"}, {"mixed case provider", "MiniMax", "MINIMAX_API_KEY"}, - {"provider with hyphen", "my-provider", "MY_PROVIDER_API_KEY"}, + {"provider with hyphen", "my-provider", "MY-PROVIDER_API_KEY"}, {"provider with underscore", "my_provider", "MY_PROVIDER_API_KEY"}, } @@ -565,7 +560,6 @@ func TestApiKeyEnvVarName(t *testing.T) { func TestExecuteWithoutAuth_QwenNoAPIKey(t *testing.T) { cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -594,7 +588,6 @@ func TestExecuteWithoutAuth_HarnessNotFound(t *testing.T) { } cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -644,7 +637,6 @@ func TestExecuteWithoutAuth_ExecutionFails(t *testing.T) { } cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -692,7 +684,6 @@ func TestExecuteWithoutAuth_YoloModeClaude(t *testing.T) { exitProcess = func(int) {} cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) @@ -789,7 +780,6 @@ func TestBuildProviderEnvironment_WithProviderEnvVars(t *testing.T) { func TestExecuteWithoutAuth_QwenNoAuth(t *testing.T) { cmd := &cobra.Command{} - cmd.SetContext(context.Background()) var output bytes.Buffer cmd.SetOut(&output) diff --git a/cmd/harness.go b/cmd/harness.go index 7ca775c..5feee51 100644 --- a/cmd/harness.go +++ b/cmd/harness.go @@ -32,7 +32,6 @@ var harnessGetCmd = &cobra.Command{ if cfg.DefaultHarness == "" { ui.PrintInfo("No default harness configured (using claude)") - return } @@ -51,7 +50,6 @@ var harnessSetCmd = &cobra.Command{ if !isValidHarness(harnessName) { ui.PrintError(fmt.Sprintf("Invalid harness: '%s'", args[0])) ui.PrintInfo("Valid harnesses: claude, qwen") - return } @@ -63,7 +61,6 @@ var harnessSetCmd = &cobra.Command{ cfg, err := GetCLIContext(cmd).GetConfigCache().Get(GetCLIContext(cmd).GetRootCtx(), dir) if err != nil && !errors.Is(err, kairoerrors.ErrConfigNotFound) { handleConfigError(cmd, err) - return } if err != nil { @@ -76,7 +73,6 @@ var harnessSetCmd = &cobra.Command{ cfg.DefaultHarness = harnessName if err := config.SaveConfig(GetCLIContext(cmd).GetRootCtx(), dir, cfg); err != nil { ui.PrintError(fmt.Sprintf("Error saving config: %v", err)) - return } @@ -108,10 +104,8 @@ func getHarness(flagHarness, configHarness string) string { } if !isValidHarness(harness) { ui.PrintWarn(fmt.Sprintf("Unknown harness '%s', using 'claude'", harness)) - return harnessClaude } - return harness } diff --git a/cmd/integration_test.go b/cmd/integration_test.go index bf06f5e..33109c7 100644 --- a/cmd/integration_test.go +++ b/cmd/integration_test.go @@ -74,13 +74,12 @@ func TestFullProviderConfigurationWorkflow(t *testing.T) { t.Errorf("loaded %d providers, want %d", len(loadedCfg.Providers), len(providersToTest)) } - decryptedContent, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decryptedContent, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decryptedContent) - parsedSecrets := secretspkg.Parse(string(decryptedContent)) + parsedSecrets := secretspkg.Parse(decryptedContent) for _, p := range providersToTest { for k := range p.envVars { if _, ok := parsedSecrets[k]; !ok { @@ -115,13 +114,12 @@ func TestCustomProviderConfigPersistence(t *testing.T) { } // Test config persistence without rotation - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decrypted) - if !strings.Contains(string(decrypted), customKey) { + if !strings.Contains(decrypted, customKey) { t.Errorf("decrypted secrets should contain %q, got: %q", customKey, decrypted) } } @@ -201,17 +199,16 @@ func TestE2ECompleteWorkflow(t *testing.T) { t.Error("zai provider should exist in config") } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decrypted) - if !strings.Contains(string(decrypted), "ZAI_API_KEY") { + if !strings.Contains(decrypted, "ZAI_API_KEY") { t.Error("secrets should contain ZAI_API_KEY") } - if !strings.Contains(string(decrypted), "sk-zai-test-key-12345") { + if !strings.Contains(decrypted, "sk-zai-test-key-12345") { t.Error("secrets should contain the API key") } diff --git a/cmd/list.go b/cmd/list.go index 9b90686..274171e 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -20,27 +20,23 @@ var listCmd = &cobra.Command{ dir := requireConfigDir(cmd) if dir == "" { ui.PrintInfo("Run 'kairo setup' to configure providers") - return } - cfg, err := cliCtx.GetConfigCache().Get(cliCtx.GetRootCtx(), dir) + cfg, err := config.LoadConfig(cliCtx.GetRootCtx(), dir) if err != nil { if os.IsNotExist(err) { ui.PrintWarn("No providers configured") ui.PrintInfo("Run 'kairo setup' to get started") - return } handleConfigError(cmd, err) - return } if len(cfg.Providers) == 0 { ui.PrintWarn("No providers configured") ui.PrintInfo("Run 'kairo setup' to get started") - return } @@ -93,9 +89,7 @@ func sortProviderNames(providers map[string]config.Provider, defaultProvider str if names[j] == defaultProvider { return false } - return names[i] < names[j] }) - return names } diff --git a/cmd/root.go b/cmd/root.go index 7087cdb..0c43dea 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -52,101 +52,92 @@ encrypted secrets management using age encryption. Version: %s (commit: %s, date: %s)`, kairoversion.Version, kairoversion.Commit, kairoversion.Date), Args: cobra.MinimumNArgs(0), - Run: runRoot, -} - -func runRoot(cmd *cobra.Command, args []string) { - cliCtx := GetCLIContext(cmd) - configDir := cliCtx.GetConfigDir() - if configDir == "" { - cmd.Println("Error: config directory not found") - if err := cmd.Help(); err != nil { - cmd.Println(err) + Run: func(cmd *cobra.Command, args []string) { + cliCtx := GetCLIContext(cmd) + configDir := cliCtx.GetConfigDir() + if configDir == "" { + cmd.Println("Error: config directory not found") + if err := cmd.Help(); err != nil { + cmd.Println(err) + } + + return } - return - } - - cfg, err := loadRootConfig(cmd, cliCtx, configDir) - if err != nil { - return - } - - if len(cfg.Providers) == 0 { - cmd.Println("No providers configured. Run 'kairo setup' to get started.") + cfg, err := cliCtx.GetConfigCache().Get(cliCtx.GetRootCtx(), configDir) + if err != nil { + if os.IsNotExist(err) { + cmd.Println("No providers configured. Run 'kairo setup' to get started.") - return - } - - _, harnessArgs, providerName := resolveProviderAndArgs(cmd, cfg, args) - if providerName == "" { - return - } - - provider, ok := cfg.Providers[providerName] - if !ok { - cmd.Printf("Error: provider '%s' not configured\n", providerName) - cmd.Println("Run 'kairo list' to see configured providers") - - return - } + return + } + handleConfigError(cmd, err) - dispatchExecution(cliCtx, cmd, configDir, provider, providerName, harnessArgs, cfg.DefaultHarness) -} + return + } -func loadRootConfig(cmd *cobra.Command, cliCtx *CLIContext, configDir string) (*config.Config, error) { - cfg, err := cliCtx.GetConfigCache().Get(cliCtx.GetRootCtx(), configDir) - if err != nil { - if os.IsNotExist(err) { + if len(cfg.Providers) == 0 { cmd.Println("No providers configured. Run 'kairo setup' to get started.") - return nil, err + return } - handleConfigError(cmd, err) - return nil, err - } + _, harnessArgs, providerName := resolveProviderAndArgs(cmd, cfg, args) + if providerName == "" { + return + } - return cfg, nil -} + provider, ok := cfg.Providers[providerName] + if !ok { + cmd.Printf("Error: provider '%s' not configured\n", providerName) + cmd.Println("Run 'kairo list' to see configured providers") -func dispatchExecution( - cliCtx *CLIContext, cmd *cobra.Command, configDir string, - provider config.Provider, providerName string, harnessArgs []string, defaultHarness string, -) { - harnessToUse := getHarness(harnessFlag, defaultHarness) - harnessBinary := getHarnessBinary(harnessToUse) - - envResult, err := BuildProviderEnv(cliCtx, configDir, EnvProvider{ - BaseURL: provider.BaseURL, - Model: provider.Model, - EnvVars: provider.EnvVars, - }, providerName) - if err != nil { - handleSecretsError(err) - - return - } + return + } - execCfg := ExecutionConfig{ - Cmd: cmd, - ProviderEnv: envResult.ProviderEnv, - HarnessToUse: harnessToUse, - HarnessBinary: harnessBinary, - Provider: provider, - HarnessArgs: harnessArgs, - Yolo: yoloFlag, - } + harnessToUse := getHarness(harnessFlag, cfg.DefaultHarness) + harnessBinary := getHarnessBinary(harnessToUse) - apiKeyKey := APIKeyEnvVarName(providerName) - if apiKey, hasKey := envResult.Secrets[apiKeyKey]; hasKey { - execCfg.APIKey = apiKey - executeWithAuth(execCfg) + envResult, err := BuildProviderEnv(cliCtx, configDir, EnvProvider{ + BaseURL: provider.BaseURL, + Model: provider.Model, + EnvVars: provider.EnvVars, + }, providerName) + if err != nil { + handleSecretsError(err) - return - } + return + } + + providerEnv := envResult.ProviderEnv + secrets := envResult.Secrets + + apiKeyKey := APIKeyEnvVarName(providerName) + if apiKey, hasKey := secrets[apiKeyKey]; hasKey { + executeWithAuth(ExecutionConfig{ + Cmd: cmd, + ProviderEnv: providerEnv, + HarnessToUse: harnessToUse, + HarnessBinary: harnessBinary, + Provider: provider, + HarnessArgs: harnessArgs, + APIKey: apiKey, + Yolo: yoloFlag, + }) + + return + } - executeWithoutAuth(execCfg) + executeWithoutAuth(ExecutionConfig{ + Cmd: cmd, + ProviderEnv: providerEnv, + HarnessToUse: harnessToUse, + HarnessBinary: harnessBinary, + Provider: provider, + HarnessArgs: harnessArgs, + Yolo: yoloFlag, + }) + }, } func Execute() error { diff --git a/cmd/setup.go b/cmd/setup.go index a748dc4..ed4ff17 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "errors" "fmt" @@ -13,11 +12,8 @@ import ( var setupResetSecrets bool -// Injectable ui.Confirm wrapper for testability. -var confirmUIFn = ui.Confirm - -func configureProvider(ctx context.Context, params ProviderSetup) (string, error) { - validatedName, err := ResolveProviderName(ctx, params.ProviderName) +func configureProvider(params ProviderSetup) (string, error) { + validatedName, err := ResolveProviderName(params.ProviderName) if err != nil { return "", err } @@ -36,17 +32,17 @@ func configureProvider(ctx context.Context, params ProviderSetup) (string, error displayProviderHeader(promptCfg) - apiKey := promptForAPIKey(ctx, promptCfg) + apiKey := promptForAPIKey(promptCfg) if err := validate.ValidateAPIKey(apiKey, definition.Name); err != nil { return "", err } - baseURL := promptForBaseURL(ctx, promptCfg) + baseURL := promptForBaseURL(promptCfg) if err := validate.ValidateURL(baseURL, definition.Name); err != nil { return "", err } - model := promptForModel(ctx, promptCfg) + model := promptForModel(promptCfg) if err := validateConfiguredModel(modelValidationConfig{ Model: model, ProviderName: validatedName, @@ -80,7 +76,7 @@ func configureProvider(ctx context.Context, params ProviderSetup) (string, error return "", err } - tapOutroFn(fmt.Sprintf("%s configured successfully", provider.Name), tap.MessageOptions{ + tap.Outro(fmt.Sprintf("%s configured successfully", provider.Name), tap.MessageOptions{ Hint: fmt.Sprintf("Run 'kairo %s' to use this provider", validatedName), }) @@ -92,7 +88,7 @@ func runResetSecrets(cliCtx *CLIContext, configDir string, secretsResult Secrets ui.PrintInfo("You will need to re-enter all API keys.") ui.PrintInfo("") - confirmed, err := confirmUIFn("Continue") + confirmed, err := ui.Confirm("Continue") if err != nil || !confirmed { return errors.New("operation cancelled by user") } @@ -157,7 +153,7 @@ var setupCmd = &cobra.Command{ ui.PrintWarn(w) } - providerName := promptForProvider(cmd.Context(), cfg) + providerName := promptForProvider(cfg) if providerName == "" { ui.PrintInfo("Setup cancelled") @@ -165,7 +161,7 @@ var setupCmd = &cobra.Command{ } _, exists := cfg.Providers[providerName] - if _, err := configureProvider(cmd.Context(), ProviderSetup{ + if _, err := configureProvider(ProviderSetup{ CLIContext: cliCtx, ConfigDir: configDir, Cfg: cfg, diff --git a/cmd/setup_config.go b/cmd/setup_config.go index 52236c8..558c661 100644 --- a/cmd/setup_config.go +++ b/cmd/setup_config.go @@ -41,10 +41,7 @@ func LoadConfig(cliCtx *CLIContext, configDir string) (*config.Config, error) { } type AddProviderParams struct { - CLIContext interface { - InvalidateCache(dir string) - GetRootCtx() context.Context - } + CLIContext interface{ InvalidateCache(dir string) } ConfigDir string Cfg *config.Config ProviderName string @@ -57,7 +54,7 @@ func AddAndSaveProvider(params AddProviderParams) error { if params.SetAsDefault && params.Cfg.DefaultProvider == "" { params.Cfg.DefaultProvider = params.ProviderName } - if err := config.SaveConfig(params.CLIContext.GetRootCtx(), params.ConfigDir, params.Cfg); err != nil { + if err := config.SaveConfig(context.Background(), params.ConfigDir, params.Cfg); err != nil { return kairoerrors.WrapError(kairoerrors.ConfigError, "saving config", err) } diff --git a/cmd/setup_prompts.go b/cmd/setup_prompts.go index 993c6cd..94d2737 100644 --- a/cmd/setup_prompts.go +++ b/cmd/setup_prompts.go @@ -12,58 +12,6 @@ import ( const setupNewProvider = "Setup new provider" -// Injectable tap function variables for testability. -// These follow the same pattern as lookPath, execCommandContext, etc. -var ( - tapSelectFn = defaultTapSelect - tapTextFn = tapText - tapPasswordFn = tapPassword - tapConfirmFn = tapConfirm - tapIntroFn = tapIntroFunc - tapOutroFn = tapOutroFunc - tapMessageFn = tapMessageFunc -) - -func defaultTapSelect(ctx context.Context, opts tap.SelectOptions[string]) string { - return tap.Select(ctx, opts) -} - -func tapText(ctx context.Context, opts tap.TextOptions) string { - return tap.Text(ctx, opts) -} - -func tapPassword(ctx context.Context, opts tap.PasswordOptions) string { - return tap.Password(ctx, opts) -} - -func tapConfirm(ctx context.Context, opts tap.ConfirmOptions) bool { - return tap.Confirm(ctx, opts) -} - -func tapIntroFunc(title string, opts ...tap.MessageOptions) { - if len(opts) > 0 { - tap.Intro(title, opts[0]) - } else { - tap.Intro(title) - } -} - -func tapOutroFunc(message string, opts ...tap.MessageOptions) { - if len(opts) > 0 { - tap.Outro(message, opts[0]) - } else { - tap.Outro(message) - } -} - -func tapMessageFunc(message string, opts ...tap.MessageOptions) { - if len(opts) > 0 { - tap.Message(message, opts[0]) - } else { - tap.Message(message) - } -} - func buildProviderListOptions(providerList []string) []tap.SelectOption[string] { options := make([]tap.SelectOption[string], len(providerList)) for i, name := range providerList { @@ -73,7 +21,9 @@ func buildProviderListOptions(providerList []string) []tap.SelectOption[string] return options } -func promptForProvider(ctx context.Context, cfg *config.Config) string { +func promptForProvider(cfg *config.Config) string { + ctx := context.Background() + if len(cfg.Providers) == 0 { return promptForNewProvider(ctx) } @@ -85,7 +35,7 @@ func promptForNewProvider(ctx context.Context) string { allProviders := append(providers.GetProviderList(), "custom") options := buildProviderListOptions(allProviders) - return tapSelectFn(ctx, tap.SelectOptions[string]{ + return tap.Select(ctx, tap.SelectOptions[string]{ Message: "Select provider to configure", Options: options, }) @@ -101,11 +51,11 @@ func promptForExistingOrNewProvider(ctx context.Context, cfg *config.Config) str fmt.Println() - tapIntroFn("Setup Provider", tap.MessageOptions{ + tap.Intro("Setup Provider", tap.MessageOptions{ Hint: "Configure new provider or edit existing from Kairo", }) - selected := tapSelectFn(ctx, tap.SelectOptions[string]{ + selected := tap.Select(ctx, tap.SelectOptions[string]{ Message: "Select provider to edit or setup new", Options: options, }) @@ -128,24 +78,26 @@ type providerPromptConfig struct { func displayProviderHeader(cfg providerPromptConfig) { if cfg.IsEdit && cfg.Exists { - tapMessageFn(fmt.Sprintf("Editing %s", cfg.Provider.Name), tap.MessageOptions{ + tap.Message(fmt.Sprintf("Editing %s", cfg.Provider.Name), tap.MessageOptions{ Hint: "Press Enter to keep current values", }) } } -func promptForAPIKey(ctx context.Context, cfg providerPromptConfig) string { +func promptForAPIKey(cfg providerPromptConfig) string { + ctx := context.Background() + if !cfg.IsEdit || !cfg.Exists { - return tapPasswordFn(ctx, tap.PasswordOptions{Message: "API Key"}) + return tap.Password(ctx, tap.PasswordOptions{Message: "API Key"}) } existingKey := cfg.Secrets[APIKeyEnvVarName(cfg.ProviderName)] if existingKey == "" { - return tapPasswordFn(ctx, tap.PasswordOptions{Message: "API Key"}) + return tap.Password(ctx, tap.PasswordOptions{Message: "API Key"}) } - if tapConfirmFn(ctx, tap.ConfirmOptions{Message: "Modify API key?"}) { - return tapPasswordFn(ctx, tap.PasswordOptions{Message: "New API Key"}) + if tap.Confirm(ctx, tap.ConfirmOptions{Message: "Modify API key?"}) { + return tap.Password(ctx, tap.PasswordOptions{Message: "New API Key"}) } return existingKey @@ -159,12 +111,14 @@ type promptFieldConfig struct { Exists bool } -func promptForField(ctx context.Context, cfg promptFieldConfig) string { +func promptForField(cfg promptFieldConfig) string { + ctx := context.Background() + if cfg.IsEdit && cfg.Exists { return promptForFieldEdit(ctx, cfg) } - result := strings.TrimSpace(tapTextFn(ctx, tap.TextOptions{ + result := strings.TrimSpace(tap.Text(ctx, tap.TextOptions{ Message: cfg.Label, DefaultValue: cfg.DefaultValue, Placeholder: cfg.DefaultValue, @@ -184,10 +138,10 @@ func promptForFieldEdit(ctx context.Context, cfg promptFieldConfig) string { } if effectiveDefault != "" { - if tapConfirmFn(ctx, tap.ConfirmOptions{ + if tap.Confirm(ctx, tap.ConfirmOptions{ Message: fmt.Sprintf("Modify %s? (current: %s)", cfg.Label, effectiveDefault), }) { - return strings.TrimSpace(tapTextFn(ctx, tap.TextOptions{ + return strings.TrimSpace(tap.Text(ctx, tap.TextOptions{ Message: fmt.Sprintf("New %s", cfg.Label), DefaultValue: effectiveDefault, Placeholder: effectiveDefault, @@ -197,14 +151,14 @@ func promptForFieldEdit(ctx context.Context, cfg promptFieldConfig) string { return effectiveDefault } - return strings.TrimSpace(tapTextFn(ctx, tap.TextOptions{ + return strings.TrimSpace(tap.Text(ctx, tap.TextOptions{ Message: cfg.Label, Placeholder: cfg.DefaultValue, })) } -func promptForBaseURL(ctx context.Context, cfg providerPromptConfig) string { - return promptForField(ctx, promptFieldConfig{ +func promptForBaseURL(cfg providerPromptConfig) string { + return promptForField(promptFieldConfig{ Label: "Base URL", CurrentValue: cfg.Provider.BaseURL, DefaultValue: cfg.Definition.BaseURL, @@ -213,8 +167,8 @@ func promptForBaseURL(ctx context.Context, cfg providerPromptConfig) string { }) } -func promptForModel(ctx context.Context, cfg providerPromptConfig) string { - return promptForField(ctx, promptFieldConfig{ +func promptForModel(cfg providerPromptConfig) string { + return promptForField(promptFieldConfig{ Label: "Model", CurrentValue: cfg.Provider.Model, DefaultValue: cfg.Definition.Model, diff --git a/cmd/setup_prompts_test.go b/cmd/setup_prompts_test.go index 255090e..63bca95 100644 --- a/cmd/setup_prompts_test.go +++ b/cmd/setup_prompts_test.go @@ -1,412 +1,83 @@ package cmd import ( - "context" - "strings" "testing" "github.com/dkmnx/kairo/internal/config" - "github.com/dkmnx/kairo/internal/providers" - "github.com/yarlson/tap" ) -// --- Test helper restore utilities --- - -type tapFuncs struct { - selectFn func(ctx context.Context, opts tap.SelectOptions[string]) string - textFn func(ctx context.Context, opts tap.TextOptions) string - passwordFn func(ctx context.Context, opts tap.PasswordOptions) string - confirmFn func(ctx context.Context, opts tap.ConfirmOptions) bool - introFn func(title string, opts ...tap.MessageOptions) - outroFn func(message string, opts ...tap.MessageOptions) - messageFn func(message string, opts ...tap.MessageOptions) -} - -// withMockedTAP saves all tap function variables, replaces them with mocks, -// and defers restoration. Returns the original state for inspection. -func withMockedTAP(t *testing.T) *tapFuncs { - t.Helper() - origins := &tapFuncs{ - selectFn: tapSelectFn, - textFn: tapTextFn, - passwordFn: tapPasswordFn, - confirmFn: tapConfirmFn, - introFn: tapIntroFn, - outroFn: tapOutroFn, - messageFn: tapMessageFn, - } +func TestPromptFieldConfig(t *testing.T) { + t.Run("struct fields", func(t *testing.T) { + cfg := promptFieldConfig{ + Label: "Test Label", + CurrentValue: "current", + DefaultValue: "default", + IsEdit: true, + Exists: true, + } - t.Cleanup(func() { - tapSelectFn = origins.selectFn - tapTextFn = origins.textFn - tapPasswordFn = origins.passwordFn - tapConfirmFn = origins.confirmFn - tapIntroFn = origins.introFn - tapOutroFn = origins.outroFn - tapMessageFn = origins.messageFn + if cfg.Label != "Test Label" { + t.Errorf("Label = %q, want 'Test Label'", cfg.Label) + } + if cfg.CurrentValue != "current" { + t.Errorf("CurrentValue = %q, want 'current'", cfg.CurrentValue) + } + if cfg.DefaultValue != "default" { + t.Errorf("DefaultValue = %q, want 'default'", cfg.DefaultValue) + } + if !cfg.IsEdit { + t.Error("IsEdit should be true") + } + if !cfg.Exists { + t.Error("Exists should be true") + } }) - - return origins } -// --- Tests for buildProviderListOptions --- - -func TestBuildProviderListOptions(t *testing.T) { +func TestDisplayProviderHeader_NoPanic(t *testing.T) { tests := []struct { - name string - input []string - wantLen int - wantValues []string + name string + config providerPromptConfig }{ - {name: "empty list", input: []string{}, wantLen: 0, wantValues: []string{}}, - {name: "single provider", input: []string{"zai"}, wantLen: 1, wantValues: []string{"zai"}}, - {name: "multiple providers", input: []string{"zai", "minimax", "deepseek"}, wantLen: 3, wantValues: []string{"zai", "minimax", "deepseek"}}, + { + name: "edit mode existing provider", + config: providerPromptConfig{ + ProviderName: "Test Provider", + Provider: config.Provider{ + Name: "Test Provider", + BaseURL: "https://test.com", + Model: "test-model", + }, + IsEdit: true, + Exists: true, + }, + }, + { + name: "new provider mode", + config: providerPromptConfig{ + ProviderName: "New Provider", + Provider: config.Provider{ + Name: "New Provider", + BaseURL: "https://new.com", + Model: "new-model", + }, + IsEdit: false, + Exists: false, + }, + }, + { + name: "empty provider", + config: providerPromptConfig{ + Provider: config.Provider{}, + IsEdit: true, + Exists: true, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - options := buildProviderListOptions(tt.input) - if len(options) != tt.wantLen { - t.Errorf("buildProviderListOptions() returned %d options, want %d", len(options), tt.wantLen) - } - for i, opt := range options { - if i < len(tt.wantValues) { - if opt.Value != tt.wantValues[i] { - t.Errorf("options[%d].Value = %q, want %q", i, opt.Value, tt.wantValues[i]) - } - if opt.Label != tt.wantValues[i] { - t.Errorf("options[%d].Label = %q, want %q", i, opt.Label, tt.wantValues[i]) - } - } - } + displayProviderHeader(tt.config) }) } } - -// --- Tests for displayProviderHeader --- - -func TestDisplayProviderHeader_EditExisting(t *testing.T) { - messages := []string{} - withMockedTAP(t) - tapMessageFn = func(message string, opts ...tap.MessageOptions) { - messages = append(messages, message) - } - - cfg := providerPromptConfig{ - ProviderName: "zai", - Provider: config.Provider{Name: "Z.AI"}, - IsEdit: true, Exists: true, - } - displayProviderHeader(cfg) - - if len(messages) != 1 { - t.Fatalf("Expected 1 message, got %d", len(messages)) - } - if !strings.Contains(messages[0], "Editing Z.AI") { - t.Errorf("Expected message containing 'Editing Z.AI', got %q", messages[0]) - } -} - -func TestDisplayProviderHeader_NewProvider(t *testing.T) { - called := false - withMockedTAP(t) - tapMessageFn = func(message string, opts ...tap.MessageOptions) { called = true } - - cfg := providerPromptConfig{ - ProviderName: "zai", Provider: config.Provider{Name: "Z.AI"}, - IsEdit: false, Exists: false, - } - displayProviderHeader(cfg) - if called { - t.Error("displayProviderHeader should not call Message for new provider") - } -} - -func TestDisplayProviderHeader_EditNotExisting(t *testing.T) { - called := false - withMockedTAP(t) - tapMessageFn = func(message string, opts ...tap.MessageOptions) { called = true } - - cfg := providerPromptConfig{ - ProviderName: "zai", Provider: config.Provider{Name: "Z.AI"}, - IsEdit: true, Exists: false, - } - displayProviderHeader(cfg) - if called { - t.Error("displayProviderHeader should not call Message when Exists=false") - } -} - -// --- Tests for promptForAPIKey --- - -func TestPromptForAPIKey_NewProvider(t *testing.T) { - withMockedTAP(t) - apiKey := "test-api-key-12345678901234567890" - tapPasswordFn = func(ctx context.Context, opts tap.PasswordOptions) string { return apiKey } - - cfg := providerPromptConfig{ProviderName: "zai", IsEdit: false, Exists: false} - result := promptForAPIKey(context.Background(), cfg) - if result != apiKey { - t.Errorf("promptForAPIKey() = %q, want %q", result, apiKey) - } -} - -func TestPromptForAPIKey_EditKeepExisting(t *testing.T) { - withMockedTAP(t) - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return false } - - cfg := providerPromptConfig{ - ProviderName: "zai", - Secrets: map[string]string{"ZAI_API_KEY": "existing-key-12345678901234567890"}, - IsEdit: true, Exists: true, - } - result := promptForAPIKey(context.Background(), cfg) - if result != "existing-key-12345678901234567890" { - t.Errorf("promptForAPIKey() = %q, want existing key", result) - } -} - -func TestPromptForAPIKey_EditModifyKey(t *testing.T) { - withMockedTAP(t) - newKey := "new-api-key-123456789012345678901" - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return true } - tapPasswordFn = func(ctx context.Context, opts tap.PasswordOptions) string { return newKey } - - cfg := providerPromptConfig{ - ProviderName: "zai", - Secrets: map[string]string{"ZAI_API_KEY": "existing-key-12345678901234567890"}, - IsEdit: true, Exists: true, - } - result := promptForAPIKey(context.Background(), cfg) - if result != newKey { - t.Errorf("promptForAPIKey() = %q, want %q", result, newKey) - } -} - -func TestPromptForAPIKey_EditNoExistingKey(t *testing.T) { - withMockedTAP(t) - apiKey := "fresh-api-key-12345678901234567890" - tapPasswordFn = func(ctx context.Context, opts tap.PasswordOptions) string { return apiKey } - - cfg := providerPromptConfig{ - ProviderName: "zai", Secrets: map[string]string{}, - IsEdit: true, Exists: true, - } - result := promptForAPIKey(context.Background(), cfg) - if result != apiKey { - t.Errorf("promptForAPIKey() = %q, want %q", result, apiKey) - } -} - -// --- Tests for promptForField --- - -func TestPromptForField_NewProvider(t *testing.T) { - withMockedTAP(t) - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return "custom-base-url" } - - cfg := promptFieldConfig{Label: "Base URL", DefaultValue: "https://api.default.com", IsEdit: false} - result := promptForField(context.Background(), cfg) - if result != "custom-base-url" { - t.Errorf("promptForField() = %q, want %q", result, "custom-base-url") - } -} - -func TestPromptForField_DefaultOnEmpty(t *testing.T) { - withMockedTAP(t) - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return "" } - - cfg := promptFieldConfig{Label: "Base URL", DefaultValue: "https://api.default.com", IsEdit: false} - result := promptForField(context.Background(), cfg) - if result != "https://api.default.com" { - t.Errorf("promptForField() = %q, want default %q", result, "https://api.default.com") - } -} - -func TestPromptForField_EditKeep(t *testing.T) { - withMockedTAP(t) - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return false } - - cfg := promptFieldConfig{ - Label: "Base URL", CurrentValue: "https://current.com", - DefaultValue: "https://api.default.com", IsEdit: true, Exists: true, - } - result := promptForField(context.Background(), cfg) - if result != "https://current.com" { - t.Errorf("promptForField() = %q, want current value %q", result, "https://current.com") - } -} - -func TestPromptForField_EditModify(t *testing.T) { - withMockedTAP(t) - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return true } - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return " https://modified.com " } - - cfg := promptFieldConfig{ - Label: "Base URL", CurrentValue: "https://current.com", - DefaultValue: "https://api.default.com", IsEdit: true, Exists: true, - } - result := promptForField(context.Background(), cfg) - if result != "https://modified.com" { - t.Errorf("promptForField() = %q, want %q", result, "https://modified.com") - } -} - -// --- Tests for promptForBaseURL and promptForModel --- - -func TestPromptForBaseURL(t *testing.T) { - withMockedTAP(t) - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return "https://custom.api.com/anthropic" } - - cfg := providerPromptConfig{ - ProviderName: "custom", - Definition: providers.ProviderDefinition{Name: "Custom", BaseURL: "https://default.com"}, - IsEdit: false, - } - result := promptForBaseURL(context.Background(), cfg) - if result != "https://custom.api.com/anthropic" { - t.Errorf("promptForBaseURL() = %q, want %q", result, "https://custom.api.com/anthropic") - } -} - -func TestPromptForModel(t *testing.T) { - withMockedTAP(t) - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return "custom-model-v2" } - - cfg := providerPromptConfig{ - ProviderName: "custom", - Definition: providers.ProviderDefinition{Name: "Custom", Model: "default-model"}, - IsEdit: false, - } - result := promptForModel(context.Background(), cfg) - if result != "custom-model-v2" { - t.Errorf("promptForModel() = %q, want %q", result, "custom-model-v2") - } -} - -// --- Tests for promptForProvider --- - -func TestPromptForProvider_NoProviders(t *testing.T) { - withMockedTAP(t) - tapSelectFn = func(ctx context.Context, opts tap.SelectOptions[string]) string { return "zai" } - - cfg := &config.Config{Providers: make(map[string]config.Provider)} - result := promptForProvider(context.Background(), cfg) - if result != "zai" { - t.Errorf("promptForProvider(context.Background(), ) = %q, want %q", result, "zai") - } -} - -func TestPromptForProvider_SelectNewProvider(t *testing.T) { - withMockedTAP(t) - callCount := 0 - tapSelectFn = func(ctx context.Context, opts tap.SelectOptions[string]) string { - callCount++ - if callCount == 1 { - return setupNewProvider - } - return "deepseek" - } - tapIntroFn = func(title string, opts ...tap.MessageOptions) {} - - cfg := &config.Config{Providers: map[string]config.Provider{"zai": {Name: "Z.AI"}}} - result := promptForProvider(context.Background(), cfg) - if result != "deepseek" { - t.Errorf("promptForProvider(context.Background(), ) = %q, want %q", result, "deepseek") - } -} - -func TestPromptForProvider_Cancel(t *testing.T) { - withMockedTAP(t) - tapSelectFn = func(ctx context.Context, opts tap.SelectOptions[string]) string { return "" } - - cfg := &config.Config{Providers: map[string]config.Provider{"zai": {Name: "Z.AI"}}} - result := promptForProvider(context.Background(), cfg) - if result != "" { - t.Errorf("promptForProvider(context.Background(), ) should return empty string on cancel, got %q", result) - } -} - -// --- Tests for promptForNewProvider --- - -func TestPromptForNewProvider(t *testing.T) { - withMockedTAP(t) - tapSelectFn = func(ctx context.Context, opts tap.SelectOptions[string]) string { - providerNames := make([]string, len(opts.Options)) - for i, opt := range opts.Options { - providerNames[i] = opt.Value - } - for _, name := range providers.GetProviderList() { - found := false - for _, pn := range providerNames { - if pn == name { - found = true - break - } - } - if !found { - t.Errorf("Expected provider %q in options", name) - } - } - foundCustom := false - for _, pn := range providerNames { - if pn == "custom" { - foundCustom = true - break - } - } - if !foundCustom { - t.Error("Expected 'custom' in options") - } - return "minimax" - } - - result := promptForNewProvider(context.Background()) - if result != "minimax" { - t.Errorf("promptForNewProvider() = %q, want %q", result, "minimax") - } -} - -// --- Tests for promptForFieldEdit --- - -func TestPromptForFieldEdit_ConfirmModify(t *testing.T) { - withMockedTAP(t) - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return true } - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return " https://modified.com " } - - cfg := promptFieldConfig{ - Label: "Base URL", CurrentValue: "https://current.com", - DefaultValue: "https://default.com", IsEdit: true, Exists: true, - } - result := promptForFieldEdit(context.Background(), cfg) - if result != "https://modified.com" { - t.Errorf("promptForFieldEdit() = %q, want %q", result, "https://modified.com") - } -} - -func TestPromptForFieldEdit_DeclineModify(t *testing.T) { - withMockedTAP(t) - tapConfirmFn = func(ctx context.Context, opts tap.ConfirmOptions) bool { return false } - - cfg := promptFieldConfig{ - Label: "Base URL", CurrentValue: "https://current.com", - DefaultValue: "https://default.com", IsEdit: true, Exists: true, - } - result := promptForFieldEdit(context.Background(), cfg) - if result != "https://current.com" { - t.Errorf("promptForFieldEdit() = %q, want %q", result, "https://current.com") - } -} - -func TestPromptForFieldEdit_NoCurrentNoDefault(t *testing.T) { - withMockedTAP(t) - tapTextFn = func(ctx context.Context, opts tap.TextOptions) string { return "user-entered-value" } - - cfg := promptFieldConfig{ - Label: "Base URL", CurrentValue: "", - DefaultValue: "", IsEdit: true, Exists: true, - } - result := promptForFieldEdit(context.Background(), cfg) - if result != "user-entered-value" { - t.Errorf("promptForFieldEdit() = %q, want %q", result, "user-entered-value") - } -} diff --git a/cmd/setup_provider.go b/cmd/setup_provider.go index 586f6eb..5799cd8 100644 --- a/cmd/setup_provider.go +++ b/cmd/setup_provider.go @@ -46,12 +46,12 @@ func GetProviderDefinition(providerName string) providers.ProviderDefinition { return definition } -func ResolveProviderName(ctx context.Context, providerName string) (string, error) { +func ResolveProviderName(providerName string) (string, error) { if providerName != "custom" { return providerName, nil } - customName := tapTextFn(ctx, tap.TextOptions{ + customName := tap.Text(context.Background(), tap.TextOptions{ Message: "Provider name", }) @@ -65,7 +65,7 @@ type modelValidationConfig struct { } func validateConfiguredModel(cfg modelValidationConfig) error { - if err := validate.ValidateProviderModel(cfg.ProviderName, cfg.Model); err != nil { + if err := validate.ValidateProviderModel(cfg.Model, cfg.DisplayName); err != nil { return err } if providers.IsBuiltInProvider(cfg.ProviderName) || strings.TrimSpace(cfg.Model) != "" { diff --git a/cmd/setup_test.go b/cmd/setup_test.go index c5c2b0f..fd054b8 100644 --- a/cmd/setup_test.go +++ b/cmd/setup_test.go @@ -226,13 +226,12 @@ func TestParseSecretsForIntegration(t *testing.T) { t.Fatal(err) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decrypted) - secretsMap := secretspkg.Parse(string(decrypted)) + secretsMap := secretspkg.Parse(decrypted) if len(secretsMap) != 3 { t.Errorf("ParseSecrets() returned %d entries, want 3", len(secretsMap)) @@ -270,13 +269,12 @@ func TestSecretsPreservationWhenAddingProvider(t *testing.T) { t.Fatal(err) } - secretsContent, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + secretsContent, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(secretsContent) - secretsMap := secretspkg.Parse(string(secretsContent)) + secretsMap := secretspkg.Parse(secretsContent) if len(secretsMap) != 2 { t.Errorf("ParseSecrets() returned %d entries, want 2", len(secretsMap)) } @@ -300,13 +298,12 @@ func TestSecretsPreservationWhenAddingProvider(t *testing.T) { t.Fatalf("EncryptSecrets(context.Background(), ) error = %v", err) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decrypted) - secretsMap = secretspkg.Parse(string(decrypted)) + secretsMap = secretspkg.Parse(decrypted) if len(secretsMap) != 3 { t.Errorf("After adding provider, expected 3 secrets, got %d", len(secretsMap)) } @@ -512,22 +509,21 @@ func TestCustomProviderKeyFormat(t *testing.T) { t.Fatal(err) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decrypted) expectedKey := fmt.Sprintf("%s_API_KEY=", customName) - if !strings.Contains(string(decrypted), expectedKey) { - t.Errorf("Decrypted secrets should contain %q, got: %q", expectedKey, string(decrypted)) + if !strings.Contains(decrypted, expectedKey) { + t.Errorf("Decrypted secrets should contain %q, got: %q", expectedKey, decrypted) } - if !strings.Contains(string(decrypted), "myprovider_API_KEY=sk-test-key-12345") { - t.Errorf("Decrypted secrets should contain 'myprovider_API_KEY=sk-test-key-12345', got: %q", string(decrypted)) + if !strings.Contains(decrypted, "myprovider_API_KEY=sk-test-key-12345") { + t.Errorf("Decrypted secrets should contain 'myprovider_API_KEY=sk-test-key-12345', got: %q", decrypted) } - for _, line := range strings.Split(string(decrypted), "\n") { + for _, line := range strings.Split(decrypted, "\n") { if strings.HasPrefix(line, expectedKey) { if strings.HasPrefix(line, "CUSTOM_") { t.Errorf("Custom provider key should NOT have CUSTOM_ prefix, got: %q", line) @@ -565,18 +561,17 @@ func TestCustomProviderKeyLookupInSwitch(t *testing.T) { t.Fatal(err) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes() error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer crypto.ClearMemory(decrypted) prefix := fmt.Sprintf("%s_API_KEY=", providerName) - if !strings.HasPrefix(string(decrypted), prefix) { - t.Errorf("Secrets should start with %q, got: %q", prefix, string(decrypted)) + if !strings.HasPrefix(decrypted, prefix) { + t.Errorf("Secrets should start with %q, got: %q", prefix, decrypted) } - for _, line := range strings.Split(string(decrypted), "\n") { + for _, line := range strings.Split(decrypted, "\n") { if line == "" { continue } @@ -1213,7 +1208,7 @@ func TestResolveProviderName_NonCustom(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ResolveProviderName(context.Background(), tt.providerName) + got, err := ResolveProviderName(tt.providerName) if (err != nil) != tt.wantErr { t.Errorf("ResolveProviderName() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/cmd/test_helpers_test.go b/cmd/test_helpers_test.go deleted file mode 100644 index a1c8325..0000000 --- a/cmd/test_helpers_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package cmd - -import ( - "context" - "testing" - - "github.com/dkmnx/kairo/internal/config" - "github.com/spf13/cobra" -) - -// withTempConfigDir creates a temporary config directory, sets it as the active -// config dir, and defers restoration. This replaces the 44+ repeated -// originalConfigDir/defer setConfigDir patterns across test files. -func withTempConfigDir(t *testing.T) string { - t.Helper() - originalConfigDir := getConfigDir() - tmpDir := t.TempDir() - setConfigDir(tmpDir) - t.Cleanup(func() { setConfigDir(originalConfigDir) }) - return tmpDir -} - -// newCommandWithContext creates a cobra.Command with a CLIContext attached. -func newCommandWithContext(cliCtx *CLIContext) *cobra.Command { - cmd := &cobra.Command{} - cmd.SetContext(WithCLIContext(context.Background(), cliCtx)) - return cmd -} - -// saveConfig creates and saves a config in the given temp dir for testing. -func saveConfig(t *testing.T, dir string, cfg *config.Config) { - t.Helper() - if err := config.SaveConfig(context.Background(), dir, cfg); err != nil { - t.Fatalf("saveConfig: %v", err) - } -} - -// mustCreateConfig creates a minimal valid config in the given directory. -func mustCreateConfig(t *testing.T, dir string, cfg *config.Config) { - t.Helper() - if cfg.Providers == nil { - cfg.Providers = make(map[string]config.Provider) - } - if cfg.DefaultModels == nil { - cfg.DefaultModels = make(map[string]string) - } - saveConfig(t, dir, cfg) -} - -// mustLoadConfig loads config from the given directory or fails the test. -func mustLoadConfig(t *testing.T, dir string) *config.Config { - t.Helper() - cfg, err := config.LoadConfig(context.Background(), dir) - if err != nil { - t.Fatalf("mustLoadConfig: %v", err) - } - return cfg -} diff --git a/cmd/update.go b/cmd/update.go index 9421cd0..5397cc6 100644 --- a/cmd/update.go +++ b/cmd/update.go @@ -32,10 +32,6 @@ const installScriptExtPS1 = ".ps1" var httpClient = &http.Client{ Timeout: requestTimeout, - Transport: &http.Transport{ - TLSHandshakeTimeout: 5 * time.Second, - ResponseHeaderTimeout: 5 * time.Second, - }, } var ( diff --git a/cmd/util.go b/cmd/util.go index 228b88a..6f81e22 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -17,7 +17,6 @@ func requireConfigDir(cmd *cobra.Command) string { if dir == "" { ui.PrintError("Config directory not found") } - return dir } @@ -28,10 +27,8 @@ func requireConfigDirWritable(cmd *cobra.Command) string { } if err := os.MkdirAll(dir, 0700); err != nil { ui.PrintError("Error creating config directory: " + err.Error()) - return "" } - return dir } @@ -47,14 +44,11 @@ func loadConfigOrExit(cmd *cobra.Command) *config.Config { if os.IsNotExist(err) { ui.PrintWarn("No providers configured") ui.PrintInfo("Run 'kairo setup' to get started") - return nil } handleConfigError(cmd, err) - return nil } - return cfg } @@ -80,7 +74,6 @@ func parseIntOrZero(input string) int { } result = result*10 + int(c-'0') } - return result } diff --git a/cmd/version.go b/cmd/version.go index 328852c..ee95bd5 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -7,8 +7,6 @@ import ( "github.com/spf13/cobra" ) -var versionNoUpdateCheck bool - var versionCmd = &cobra.Command{ Use: "version", Short: "Show version information", @@ -26,7 +24,7 @@ var versionCmd = &cobra.Command{ } } - if version.Version != "dev" && !versionNoUpdateCheck { + if version.Version != "dev" { checkForUpdates(cmd) } }, @@ -49,7 +47,5 @@ func checkForUpdates(cmd *cobra.Command) { } func init() { - versionCmd.Flags().BoolVar(&versionNoUpdateCheck, "no-update-check", false, - "Skip checking for updates") rootCmd.AddCommand(versionCmd) } diff --git a/cmd/version_test.go b/cmd/version_test.go index d169365..369fd7e 100644 --- a/cmd/version_test.go +++ b/cmd/version_test.go @@ -178,19 +178,3 @@ func TestCheckForUpdatesAPIError(t *testing.T) { t.Errorf("checkForUpdates() should NOT mention update on API error, got: %q", output) } } - -func TestVersionNoUpdateCheckFlag(t *testing.T) { - originalVersion := version.Version - version.Version = "v1.0.0" - defer func() { version.Version = originalVersion }() - - originalGetter := envGetter - envGetter = func(key string) (string, bool) { return "", false } - defer func() { envGetter = originalGetter }() - - rootCmd.SetArgs([]string{"version", "--no-update-check"}) - err := rootCmd.Execute() - if err != nil { - t.Fatalf("version --no-update-check error = %v", err) - } -} diff --git a/go.mod b/go.mod index 184c434..90f8d39 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/Masterminds/semver/v3 v3.4.0 github.com/spf13/cobra v1.10.2 github.com/yarlson/tap v0.13.1 - golang.org/x/sys v0.42.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -22,6 +21,7 @@ require ( github.com/mattn/go-tty v0.0.7 // indirect github.com/spf13/pflag v1.0.10 // indirect golang.org/x/crypto v0.49.0 // indirect + golang.org/x/sys v0.42.0 // indirect golang.org/x/term v0.41.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index adc5875..88dea30 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -272,7 +272,7 @@ func TestMigrateConfigFile(t *testing.T) { migrated, err := migrateConfigFile(context.Background(), tmpDir) if err != nil { - t.Fatalf("migrateConfigFile() error = %v", err) + t.Fatalf("migrateConfigFile(context.Background(), ) error = %v", err) } if migrated { t.Error("Expected no migration when old config doesn't exist") @@ -305,7 +305,7 @@ providers: migrated, err := migrateConfigFile(context.Background(), tmpDir) if err != nil { - t.Fatalf("migrateConfigFile() error = %v", err) + t.Fatalf("migrateConfigFile(context.Background(), ) error = %v", err) } if !migrated { t.Error("Expected migration to occur") @@ -356,7 +356,7 @@ providers: migrated, err := migrateConfigFile(context.Background(), tmpDir) if err != nil { - t.Fatalf("migrateConfigFile() error = %v", err) + t.Fatalf("migrateConfigFile(context.Background(), ) error = %v", err) } if migrated { t.Error("Should not migrate when new config already exists") @@ -421,7 +421,7 @@ providers: migrated, err := migrateConfigFile(context.Background(), tmpDir) if err != nil { - t.Fatalf("migrateConfigFile() error = %v", err) + t.Fatalf("migrateConfigFile(context.Background(), ) error = %v", err) } if !migrated { t.Error("Expected migration to occur") diff --git a/internal/config/loader.go b/internal/config/loader.go index 7aa03f2..8357292 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -24,108 +24,57 @@ type Provider struct { EnvVars []string `yaml:"env_vars"` } -func checkCtx(ctx context.Context) error { - return kairoerrors.CheckContext(ctx) -} - func migrateConfigFile(ctx context.Context, configDir string) (bool, error) { - if err := checkCtx(ctx); err != nil { - return false, err - } - oldConfigPath := filepath.Join(configDir, "config") newConfigPath := filepath.Join(configDir, "config.yaml") - oldInfo, err := statOldConfig(oldConfigPath) - if err != nil || oldInfo == nil { - return false, err - } - - if newExists, err := checkNewConfig(newConfigPath); err != nil || newExists { - return false, err - } - - data, err := readAndValidateConfig(oldConfigPath) - if err != nil { - return false, err - } - - if err := writeMigratedConfig(ctx, newConfigPath, data, oldInfo.Mode()); err != nil { - return false, err - } - - if err := finalizeMigration(ctx, oldConfigPath, newConfigPath); err != nil { + if err := kairoerrors.CheckContext(ctx); err != nil { return false, err } - return true, nil -} - -func statOldConfig(oldConfigPath string) (os.FileInfo, error) { oldInfo, err := os.Stat(oldConfigPath) if err != nil { if os.IsNotExist(err) { - return nil, nil + return false, nil } - return nil, kairoerrors.WrapError(kairoerrors.FileSystemError, + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, "failed to check old config file", err) } - return oldInfo, nil -} - -func checkNewConfig(newConfigPath string) (bool, error) { if _, err := os.Stat(newConfigPath); err == nil { - return true, nil + return false, nil } else if !os.IsNotExist(err) { return false, kairoerrors.WrapError(kairoerrors.FileSystemError, "failed to check new config file", err) } - return false, nil -} - -func readAndValidateConfig(oldConfigPath string) ([]byte, error) { data, err := os.ReadFile(oldConfigPath) if err != nil { - return nil, kairoerrors.WrapError(kairoerrors.FileSystemError, + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, "failed to read old config file", err) } var tempCfg Config if err := yaml.Unmarshal(data, &tempCfg); err != nil { - return nil, kairoerrors.WrapError(kairoerrors.ConfigError, + return false, kairoerrors.WrapError(kairoerrors.ConfigError, "old config file is not valid YAML, cannot migrate", err) } - return data, nil -} - -func writeMigratedConfig(ctx context.Context, newConfigPath string, data []byte, mode os.FileMode) error { - if err := checkCtx(ctx); err != nil { - return err - } - - return os.WriteFile(newConfigPath, data, mode) -} - -func finalizeMigration(ctx context.Context, oldConfigPath, newConfigPath string) error { - if err := checkCtx(ctx); err != nil { - os.Remove(newConfigPath) - - return err + if err := os.WriteFile(newConfigPath, data, oldInfo.Mode()); err != nil { + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, + "failed to write migrated config file", err) } backupPath := oldConfigPath + ".backup" if err := os.Rename(oldConfigPath, backupPath); err != nil { os.Remove(newConfigPath) - return kairoerrors.WrapError(kairoerrors.FileSystemError, + return false, kairoerrors.WrapError(kairoerrors.FileSystemError, "failed to backup old config file", err) } - return nil + return true, nil } func LoadConfig(ctx context.Context, configDir string) (*Config, error) { @@ -144,6 +93,10 @@ func LoadConfig(ctx context.Context, configDir string) (*Config, error) { WithContext("hint", "ensure you have write permissions in the config directory") } + if err := kairoerrors.CheckContext(ctx); err != nil { + return nil, err + } + data, err := os.ReadFile(configPath) if err != nil { if os.IsNotExist(err) { diff --git a/internal/crypto/age.go b/internal/crypto/age.go index e6ae36b..5ef8491 100644 --- a/internal/crypto/age.go +++ b/internal/crypto/age.go @@ -136,6 +136,19 @@ func EncryptSecrets(ctx context.Context, secretsPath, keyPath, secrets string) e return nil } +func DecryptSecrets(ctx context.Context, secretsPath, keyPath string) (string, error) { + if err := kairoerrors.CheckContext(ctx); err != nil { + return "", err + } + + var buf bytes.Buffer + if err := decryptToBuffer(ctx, secretsPath, keyPath, &buf); err != nil { + return "", err + } + + return buf.String(), nil +} + func ClearMemory(b []byte) { for i := range b { b[i] = 0 diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go index 2436c9d..a55db48 100644 --- a/internal/crypto/crypto_test.go +++ b/internal/crypto/crypto_test.go @@ -51,13 +51,12 @@ MINIMAX_API_KEY=sk-another-key t.Fatalf("EncryptSecrets(context.Background(), ) error = %v", err) } - decrypted, err := DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { - t.Fatalf("DecryptSecretsBytes(context.Background(), ) error = %v", err) + t.Fatalf("DecryptSecrets(context.Background(), ) error = %v", err) } - defer ClearMemory(decrypted) - if string(decrypted) != secrets { + if decrypted != secrets { t.Errorf("decrypted = %q, want %q", decrypted, secrets) } } @@ -72,9 +71,9 @@ func TestDecryptInvalidFile(t *testing.T) { } secretsPath := filepath.Join(tmpDir, "nonexistent.age") - _, err = DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + _, err = DecryptSecrets(context.Background(), secretsPath, keyPath) if err == nil { - t.Error("DecryptSecretsBytes(context.Background(), ) should error on nonexistent file") + t.Error("DecryptSecrets(context.Background(), ) should error on nonexistent file") } } @@ -292,9 +291,9 @@ func TestDecryptSecretsWithInvalidKeyPath(t *testing.T) { tmpDir := t.TempDir() secretsPath := filepath.Join(tmpDir, "secrets.age") - _, err := DecryptSecretsBytes(context.Background(), secretsPath, "/nonexistent/path/key") + _, err := DecryptSecrets(context.Background(), secretsPath, "/nonexistent/path/key") if err == nil { - t.Error("DecryptSecretsBytes(context.Background(), ) should error on invalid key path") + t.Error("DecryptSecrets(context.Background(), ) should error on invalid key path") } } @@ -346,9 +345,9 @@ func TestDecryptCorruptedFile(t *testing.T) { t.Fatal(err) } - _, err = DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + _, err = DecryptSecrets(context.Background(), secretsPath, keyPath) if err == nil { - t.Error("DecryptSecretsBytes(context.Background(), ) should error on corrupted file") + t.Error("DecryptSecrets(context.Background(), ) should error on corrupted file") } } @@ -379,9 +378,9 @@ func TestDecryptTruncatedFile(t *testing.T) { t.Fatal(err) } - _, err = DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + _, err = DecryptSecrets(context.Background(), secretsPath, keyPath) if err == nil { - t.Error("DecryptSecretsBytes(context.Background(), ) should error on truncated file") + t.Error("DecryptSecrets(context.Background(), ) should error on truncated file") } } @@ -404,9 +403,9 @@ func TestDecryptRandomData(t *testing.T) { t.Fatal(err) } - _, err = DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + _, err = DecryptSecrets(context.Background(), secretsPath, keyPath) if err == nil { - t.Error("DecryptSecretsBytes(context.Background(), ) should error on random data") + t.Error("DecryptSecrets(context.Background(), ) should error on random data") } } @@ -622,9 +621,9 @@ func TestDecryptSecrets_OpenError(t *testing.T) { } defer os.Chmod(secretsPath, 0644) // Clean up - _, err := DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + _, err := DecryptSecrets(context.Background(), secretsPath, keyPath) if err == nil { - t.Error("DecryptSecretsBytes(context.Background(), ) should error when secrets file is unreadable") + t.Error("DecryptSecrets(context.Background(), ) should error when secrets file is unreadable") } } diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 686e839..67a0280 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -49,7 +49,6 @@ func (e *KairoError) Error() string { } b.WriteString(")") } - return b.String() } @@ -62,7 +61,6 @@ func (e *KairoError) Is(target error) bool { if !ok { return false } - return e.Type == t.Type } @@ -86,7 +84,6 @@ func (e *KairoError) WithContext(key, value string) *KairoError { e.Context = make(map[string]string) } e.Context[key] = value - return e } diff --git a/internal/providers/registry.go b/internal/providers/registry.go index 7576df1..849b8d2 100644 --- a/internal/providers/registry.go +++ b/internal/providers/registry.go @@ -59,13 +59,11 @@ type ProviderDefinition struct { func IsBuiltInProvider(name string) bool { _, ok := BuiltInProviders[name] - return ok } func GetBuiltInProvider(name string) (ProviderDefinition, bool) { def, ok := BuiltInProviders[name] - return def, ok } @@ -80,6 +78,5 @@ func RequiresAPIKey(name string) bool { if !ok { return true } - return def.RequiresAPIKey } diff --git a/internal/ui/prompt.go b/internal/ui/prompt.go index 0f6f6fa..12dab60 100644 --- a/internal/ui/prompt.go +++ b/internal/ui/prompt.go @@ -34,53 +34,30 @@ func ClearScreen() { _ = cmd.Run() } -var enableANSI = supportsANSI() - func PrintSuccess(msg string) { - if enableANSI { - fmt.Printf("\n%s✓%s %s%s\n", Green, Reset, msg, Reset) - } else { - fmt.Printf("\n✓ %s\n", msg) - } + fmt.Printf("\n%s✓%s %s%s\n", Green, Reset, msg, Reset) } func PrintWarn(msg string) { - if enableANSI { - fmt.Printf("%s⚠%s %s%s\n", Yellow, Reset, msg, Reset) - } else { - fmt.Printf("⚠ %s\n", msg) - } + fmt.Printf("%s⚠%s %s%s\n", Yellow, Reset, msg, Reset) } func PrintError(msg string) { - if enableANSI { - fmt.Fprintf(os.Stderr, "%s✗%s %s%s\n", Red, Reset, msg, Reset) - } else { - fmt.Fprintf(os.Stderr, "✗ %s\n", msg) - } + fmt.Fprintf(os.Stderr, "%s✗%s %s%s\n", Red, Reset, msg, Reset) } func PrintInfo(msg string) { - if enableANSI { - fmt.Printf("%s%s\n", Blue, msg) - } else { - fmt.Printf(" %s\n", msg) - } + fmt.Printf("%s%s\n", Blue, msg) } func PrintWhite(msg string) { - if enableANSI { - fmt.Printf("%s%s%s\n", White, msg, Reset) - } else { - fmt.Printf("%s\n", msg) - } + fmt.Printf("%s%s%s\n", White, msg, Reset) } func isInterrupted(err error) bool { if err == nil { return false } - return errors.Is(err, os.ErrClosed) || errors.Is(err, io.EOF) || strings.Contains(err.Error(), "interrupted") } @@ -88,17 +65,12 @@ func isEmptyInput(err error) bool { if err == nil { return false } - return !errors.Is(err, io.EOF) && !isInterrupted(err) } func PrintBanner(version, modelName, providerName string) { banner := fmt.Sprintf("kairo %s · %s · %s", version, modelName, providerName) - if enableANSI { - fmt.Printf("%s%s%s\n\n", Gray, banner, Reset) - } else { - fmt.Printf("%s\n\n", banner) - } + fmt.Printf("%s%s%s\n\n", Gray, banner, Reset) } func Confirm(prompt string) (bool, error) { @@ -112,10 +84,8 @@ func Confirm(prompt string) (bool, error) { if errors.Is(err, io.EOF) || isInterrupted(err) { return false, kairoerrors.ErrUserCancelled } - return false, err } input = strings.TrimSpace(strings.ToLower(input)) - return input == "y" || input == "yes", nil } diff --git a/internal/ui/prompt_ansi.go b/internal/ui/prompt_ansi.go deleted file mode 100644 index dbb3d2c..0000000 --- a/internal/ui/prompt_ansi.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !windows - -package ui - -import "os" - -func supportsANSI() bool { - return os.Getenv("TERM") != "dumb" -} diff --git a/internal/ui/prompt_ansi_windows.go b/internal/ui/prompt_ansi_windows.go deleted file mode 100644 index 049e5f5..0000000 --- a/internal/ui/prompt_ansi_windows.go +++ /dev/null @@ -1,30 +0,0 @@ -package ui - -import ( - "os" - - "golang.org/x/sys/windows" -) - -func supportsANSI() bool { - if os.Getenv("TERM") == "dumb" { - return false - } - - handle := windows.GetStdHandle(windows.STD_OUTPUT_HANDLE) - if handle == windows.InvalidHandle { - return false - } - - var mode uint32 - if err := windows.GetConsoleMode(handle, &mode); err != nil { - return false - } - - mode |= windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING - if err := windows.SetConsoleMode(handle, mode); err != nil { - return false - } - - return true -} diff --git a/internal/validate/api_key.go b/internal/validate/api_key.go index cb04443..b6f3373 100644 --- a/internal/validate/api_key.go +++ b/internal/validate/api_key.go @@ -2,7 +2,6 @@ package validate import ( "fmt" - "log" "net" "net/url" "regexp" @@ -71,7 +70,7 @@ var ( func mustParseCIDR(s string) net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { - log.Fatalf("invalid CIDR %s: %v", s, err) + panic(fmt.Sprintf("invalid CIDR %s: %v", s, err)) } return *ipnet diff --git a/internal/wrapper/wrapper.go b/internal/wrapper/wrapper.go index baf9bd6..497d216 100644 --- a/internal/wrapper/wrapper.go +++ b/internal/wrapper/wrapper.go @@ -3,6 +3,7 @@ package wrapper import ( "fmt" "os" + "os/exec" "runtime" "strings" @@ -185,3 +186,7 @@ func generateUnixScript(envVar string, cfg ScriptConfig) string { return sb.String() } + +func ExecCommand(name string, arg ...string) *exec.Cmd { + return exec.Command(name, arg...) +} diff --git a/internal/wrapper/wrapper_test.go b/internal/wrapper/wrapper_test.go index 928cb0b..aa782cf 100644 --- a/internal/wrapper/wrapper_test.go +++ b/internal/wrapper/wrapper_test.go @@ -551,6 +551,16 @@ func TestGenerateWrapperScript_ControlCharacterEscaping(t *testing.T) { } } +func TestExecCommand(t *testing.T) { + cmd := ExecCommand("echo", "test") + if cmd == nil { + t.Fatal("ExecCommand() should return a valid command") + } + if len(cmd.Args) != 2 { + t.Errorf("Expected 2 args, got %d", len(cmd.Args)) + } +} + func TestGenerateWrapperScript_WithArgs(t *testing.T) { authDir := t.TempDir() tokenPath := filepath.Join(authDir, "token") diff --git a/tests/integration/full_workflow_test.go b/tests/integration/full_workflow_test.go index 687c8e3..f6f4d63 100644 --- a/tests/integration/full_workflow_test.go +++ b/tests/integration/full_workflow_test.go @@ -110,15 +110,14 @@ func TestFullWorkflowSetupConfigAndSwitch(t *testing.T) { t.Errorf("default provider = %q, want 'zai'", loadedCfg.DefaultProvider) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { t.Fatalf("failed to decrypt secrets: %v", err) } - defer crypto.ClearMemory(decrypted) - if !strings.Contains(string(decrypted), "ZAI_API_KEY") { + if !strings.Contains(decrypted, "ZAI_API_KEY") { t.Error("secrets should contain ZAI_API_KEY") } - if !strings.Contains(string(decrypted), "MINIMAX_API_KEY") { + if !strings.Contains(decrypted, "MINIMAX_API_KEY") { t.Error("secrets should contain MINIMAX_API_KEY") } } @@ -223,13 +222,12 @@ DEEPSEEK_API_KEY=TEST-KEY-DO-NOT-USE-list-deepseek t.Errorf("default provider = %q, want 'zai'", loadedCfg.DefaultProvider) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { t.Fatalf("failed to decrypt secrets: %v", err) } - defer crypto.ClearMemory(decrypted) - parsedSecrets := secrets.Parse(string(decrypted)) + parsedSecrets := secrets.Parse(decrypted) if _, exists := parsedSecrets["ZAI_API_KEY"]; !exists { t.Error("ZAI_API_KEY should exist") } @@ -323,15 +321,14 @@ func TestFullWorkflowCustomProvider(t *testing.T) { t.Errorf("custom provider model = %q, want 'custom-model-v1'", customProvider.Model) } - decrypted, err := crypto.DecryptSecretsBytes(context.Background(), secretsPath, keyPath) + decrypted, err := crypto.DecryptSecrets(context.Background(), secretsPath, keyPath) if err != nil { t.Fatalf("failed to decrypt secrets: %v", err) } - defer crypto.ClearMemory(decrypted) - if !strings.Contains(string(decrypted), "MYCUSTOM_API_KEY") { + if !strings.Contains(decrypted, "MYCUSTOM_API_KEY") { t.Error("secrets should contain MYCUSTOM_API_KEY") } - if !strings.Contains(string(decrypted), "TEST-KEY-DO-NOT-USE-custom-provider") { + if !strings.Contains(decrypted, "TEST-KEY-DO-NOT-USE-custom-provider") { t.Error("secrets should contain the custom API key") } }