diff --git a/cmd/auth/login.go b/cmd/auth/login.go index a7a9420..15ebe31 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -18,10 +18,10 @@ var DeviceLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error return internalauth.DeviceLogin(host, ios) } -// SetTokenFunc is the function used to store a token in the keyring. -// Tests may override this to avoid touching the real keyring. +// SetTokenFunc is the function used to store a token in hosts.yml. +// Tests may override this to avoid touching the real config. var SetTokenFunc = func(host, token string) error { - return internalauth.SetToken(host, token) + return config.SetHostToken(host, token) } // FetchTokenInfoFunc is the function used to fetch session details from the server. diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 97d59ca..5eb8eab 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -43,6 +43,7 @@ func TestLoginCmd_DefaultDeviceFlow(t *testing.T) { } func TestLoginCmd_WithTokenFlag(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) ios, buf, _, _ := iostreams.Test() ios.In = strings.NewReader("my-token-from-stdin\n") diff --git a/internal/auth/device.go b/internal/auth/device.go index 5d04fca..ec9d65f 100644 --- a/internal/auth/device.go +++ b/internal/auth/device.go @@ -63,7 +63,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) { return "", err } - if storeErr := SetToken(host, token); storeErr != nil { + if storeErr := config.SetHostToken(host, token); storeErr != nil { return "", fmt.Errorf("storing token: %w", storeErr) } diff --git a/internal/auth/resolver.go b/internal/auth/resolver.go index b1b038f..6f865b6 100644 --- a/internal/auth/resolver.go +++ b/internal/auth/resolver.go @@ -8,28 +8,29 @@ import ( // ResolveToken resolves the auth token for host using the priority chain: // 1. KH_API_KEY environment variable -// 2. OS keyring -// 3. hosts.yml token field +// 2. hosts.yml token field +// 3. OS keyring (legacy fallback -- tokens are migrated to hosts.yml on login) // Returns a ResolvedToken with Method set to AuthMethodNone if no token found. func ResolveToken(host string) (ResolvedToken, error) { if apiKey := os.Getenv("KH_API_KEY"); apiKey != "" { return ResolvedToken{Token: apiKey, Method: AuthMethodAPIKey, Host: host}, nil } - token, err := GetToken(host) + hosts, err := config.ReadHosts() if err != nil { return ResolvedToken{}, err } - if token != "" { - return ResolvedToken{Token: token, Method: AuthMethodToken, Host: host}, nil + if entry, ok := hosts.HostEntry(host); ok && entry.Token != "" { + return ResolvedToken{Token: entry.Token, Method: AuthMethodToken, Host: host}, nil } - hosts, err := config.ReadHosts() + // Legacy fallback: check OS keyring for tokens stored before the hosts.yml migration. + token, err := GetToken(host) if err != nil { return ResolvedToken{}, err } - if entry, ok := hosts.HostEntry(host); ok && entry.Token != "" { - return ResolvedToken{Token: entry.Token, Method: AuthMethodToken, Host: host}, nil + if token != "" { + return ResolvedToken{Token: token, Method: AuthMethodToken, Host: host}, nil } return ResolvedToken{Method: AuthMethodNone, Host: host}, nil diff --git a/internal/auth/resolver_test.go b/internal/auth/resolver_test.go index 8e9e17d..f9de2f3 100644 --- a/internal/auth/resolver_test.go +++ b/internal/auth/resolver_test.go @@ -49,18 +49,18 @@ func writeHostsFile(t *testing.T, host, token string) { content := "hosts:\n " + host + ":\n token: " + token + "\n" require.NoError(t, os.WriteFile(hostsFile, []byte(content), 0o600)) - origHome := os.Getenv("XDG_CONFIG_HOME") - require.NoError(t, os.Setenv("XDG_CONFIG_HOME", dir)) - t.Cleanup(func() { - if origHome == "" { - os.Unsetenv("XDG_CONFIG_HOME") - } else { - os.Setenv("XDG_CONFIG_HOME", origHome) - } - }) + t.Setenv("XDG_CONFIG_HOME", dir) +} + +// isolateHostsFile sets XDG_CONFIG_HOME to an empty temp dir so +// ResolveToken does not read the user's real hosts.yml. +func isolateHostsFile(t *testing.T) { + t.Helper() + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) } func TestResolveToken_EnvVar(t *testing.T) { + isolateHostsFile(t) setupEmptyKeyring(t) t.Setenv("KH_API_KEY", "kh_test123") @@ -71,28 +71,31 @@ func TestResolveToken_EnvVar(t *testing.T) { require.Equal(t, testHost, rt.Host) } -func TestResolveToken_Keyring(t *testing.T) { - setupKeyringWithToken(t, testHost, "keyring_token_abc") +func TestResolveToken_HostsYML(t *testing.T) { + setupEmptyKeyring(t) t.Setenv("KH_API_KEY", "") + writeHostsFile(t, testHost, "hosts_yml_token") rt, err := ResolveToken(testHost) require.NoError(t, err) - require.Equal(t, "keyring_token_abc", rt.Token) + require.Equal(t, "hosts_yml_token", rt.Token) require.Equal(t, AuthMethodToken, rt.Method) } -func TestResolveToken_HostsYML(t *testing.T) { - setupEmptyKeyring(t) +func TestResolveToken_KeyringLegacyFallback(t *testing.T) { + // hosts.yml has no token for this host, so the keyring fallback is used. + isolateHostsFile(t) + setupKeyringWithToken(t, testHost, "keyring_token_abc") t.Setenv("KH_API_KEY", "") - writeHostsFile(t, testHost, "hosts_yml_token") rt, err := ResolveToken(testHost) require.NoError(t, err) - require.Equal(t, "hosts_yml_token", rt.Token) + require.Equal(t, "keyring_token_abc", rt.Token) require.Equal(t, AuthMethodToken, rt.Method) } func TestResolveToken_None(t *testing.T) { + isolateHostsFile(t) setupEmptyKeyring(t) t.Setenv("KH_API_KEY", "") @@ -103,6 +106,7 @@ func TestResolveToken_None(t *testing.T) { } func TestResolveToken_EnvVarPriority(t *testing.T) { + isolateHostsFile(t) setupKeyringWithToken(t, testHost, "keyring_token") t.Setenv("KH_API_KEY", "env_wins") @@ -112,13 +116,14 @@ func TestResolveToken_EnvVarPriority(t *testing.T) { require.Equal(t, AuthMethodAPIKey, rt.Method) } -func TestResolveToken_KeyringOverHostsYML(t *testing.T) { - setupKeyringWithToken(t, testHost, "keyring_takes_precedence") +func TestResolveToken_HostsYMLOverKeyring(t *testing.T) { + // hosts.yml now takes priority over the legacy keyring. + setupKeyringWithToken(t, testHost, "keyring_token") t.Setenv("KH_API_KEY", "") - writeHostsFile(t, testHost, "hosts_yml_token") + writeHostsFile(t, testHost, "hosts_yml_wins") rt, err := ResolveToken(testHost) require.NoError(t, err) - require.Equal(t, "keyring_takes_precedence", rt.Token) + require.Equal(t, "hosts_yml_wins", rt.Token) require.Equal(t, AuthMethodToken, rt.Method) } diff --git a/internal/auth/token.go b/internal/auth/token.go index 7a1bf84..e5fe684 100644 --- a/internal/auth/token.go +++ b/internal/auth/token.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "time" khhttp "github.com/keeperhub/cli/internal/http" @@ -62,8 +63,16 @@ type orgMembership struct { } // FetchTokenInfo queries the server for session details using the given token. -// Returns TokenInfo on success, or an error if the token is invalid. +// For session tokens, calls /api/auth/get-session. +// For API keys (kh_ prefix), validates via /api/workflows and returns basic info. func FetchTokenInfo(host, token string) (TokenInfo, error) { + if strings.HasPrefix(token, "kh_") { + return fetchAPIKeyInfo(host, token) + } + return fetchSessionInfo(host, token) +} + +func fetchSessionInfo(host, token string) (TokenInfo, error) { client := &http.Client{Timeout: 10 * time.Second} req, err := http.NewRequest(http.MethodGet, khhttp.BuildBaseURL(host)+"/api/auth/get-session", nil) @@ -122,6 +131,44 @@ func FetchTokenInfo(host, token string) (TokenInfo, error) { return info, nil } +// fetchAPIKeyInfo validates an API key by calling an authenticated endpoint. +// API keys don't have session data, so we probe /api/workflows to confirm +// the key is accepted and extract org info from the response context. +func fetchAPIKeyInfo(host, token string) (TokenInfo, error) { + client := &http.Client{Timeout: 10 * time.Second} + + req, err := http.NewRequest(http.MethodGet, khhttp.BuildBaseURL(host)+"/api/workflows?limit=1", nil) + if err != nil { + return TokenInfo{}, fmt.Errorf("creating validation request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := client.Do(req) + if err != nil { + return TokenInfo{}, fmt.Errorf("validating API key: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusUnauthorized { + return TokenInfo{}, fmt.Errorf("API key is invalid or revoked") + } + if resp.StatusCode != http.StatusOK { + return TokenInfo{}, fmt.Errorf("API key validation failed (status %d)", resp.StatusCode) + } + + prefix := token + if len(prefix) > 14 { + prefix = prefix[:14] + } + + return TokenInfo{ + Email: prefix + "...", + Name: "API Key", + Method: AuthMethodAPIKey, + Role: "api-key", + }, nil +} + func fetchOrgDetails(client *http.Client, host, token, orgID string) (string, string) { req, err := http.NewRequest(http.MethodGet, khhttp.BuildBaseURL(host)+"/api/organizations/"+orgID, nil) if err != nil { diff --git a/internal/config/hosts.go b/internal/config/hosts.go index 24c4b71..43ec416 100644 --- a/internal/config/hosts.go +++ b/internal/config/hosts.go @@ -3,8 +3,10 @@ package config import ( "errors" "fmt" + "net/url" "os" "path/filepath" + "strings" "gopkg.in/yaml.v3" ) @@ -83,14 +85,52 @@ func (h *HostsConfig) ActiveHost(flagHost, envHost string) string { } // HostEntry looks up the HostConfig for the given hostname. -// Returns the entry and true if found, or an empty HostConfig and false otherwise. +// It tries the raw value first, then falls back to a bare hostname (scheme stripped) +// so that --host https://app-staging.keeperhub.com matches a hosts.yml key of +// app-staging.keeperhub.com. func (h *HostsConfig) HostEntry(hostname string) (HostConfig, bool) { - entry, ok := h.Hosts[hostname] - return entry, ok + if entry, ok := h.Hosts[hostname]; ok { + return entry, true + } + bare := stripScheme(hostname) + if bare != hostname { + if entry, ok := h.Hosts[bare]; ok { + return entry, true + } + } + return HostConfig{}, false +} + +// stripScheme removes the URL scheme (http:// or https://) and any trailing +// slash from a host string, returning just the hostname[:port]. +func stripScheme(host string) string { + if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") { + if u, err := url.Parse(host); err == nil { + return u.Host + } + } + return strings.TrimRight(host, "/") +} + +// resolveHostKey returns the canonical key for a host in the hosts map. +// If a bare-hostname entry already exists for a full-URL host, it returns the +// bare hostname so the token is stored alongside existing headers (e.g. CF-Access). +func resolveHostKey(hosts map[string]HostConfig, host string) string { + if _, ok := hosts[host]; ok { + return host + } + bare := stripScheme(host) + if bare != host { + if _, ok := hosts[bare]; ok { + return bare + } + } + // No existing entry -- prefer bare hostname for cleanliness. + return bare } // SetHostToken updates the token for a specific host in the hosts config. -// If the host entry doesn't exist, it creates one. +// If a bare-hostname entry already exists, the token is merged into it. func SetHostToken(host, token string) error { hosts, err := ReadHosts() if err != nil { @@ -99,9 +139,10 @@ func SetHostToken(host, token string) error { if hosts.Hosts == nil { hosts.Hosts = make(map[string]HostConfig) } - entry := hosts.Hosts[host] + key := resolveHostKey(hosts.Hosts, host) + entry := hosts.Hosts[key] entry.Token = token - hosts.Hosts[host] = entry + hosts.Hosts[key] = entry return WriteHosts(hosts) } @@ -112,9 +153,10 @@ func ClearHostToken(host string) error { if err != nil { return err } - if entry, ok := hosts.Hosts[host]; ok { + key := resolveHostKey(hosts.Hosts, host) + if entry, ok := hosts.Hosts[key]; ok { entry.Token = "" - hosts.Hosts[host] = entry + hosts.Hosts[key] = entry return WriteHosts(hosts) } return nil diff --git a/internal/config/hosts_test.go b/internal/config/hosts_test.go index 8b79a64..c7d85d6 100644 --- a/internal/config/hosts_test.go +++ b/internal/config/hosts_test.go @@ -120,6 +120,46 @@ func TestHostEntryLookup(t *testing.T) { } } +// TestHostEntrySchemeStripping verifies that HostEntry falls back to a bare +// hostname when the caller passes a full URL (e.g. --host https://staging.example.com). +func TestHostEntrySchemeStripping(t *testing.T) { + h := config.HostsConfig{ + Hosts: map[string]config.HostConfig{ + "app-staging.keeperhub.com": { + Headers: map[string]string{ + "CF-Access-Client-Id": "abc", + "CF-Access-Client-Secret": "def", + }, + }, + }, + } + + // Full https:// URL should match bare hostname key + entry, ok := h.HostEntry("https://app-staging.keeperhub.com") + if !ok { + t.Fatal("expected HostEntry to match via scheme stripping") + } + if entry.Headers["CF-Access-Client-Id"] != "abc" { + t.Errorf("CF-Access-Client-Id: got %q, want %q", entry.Headers["CF-Access-Client-Id"], "abc") + } + + // Full http:// URL should also match + h.Hosts["localhost:3000"] = config.HostConfig{Headers: map[string]string{"X-Test": "1"}} + entry, ok = h.HostEntry("http://localhost:3000") + if !ok { + t.Fatal("expected HostEntry to match http:// via scheme stripping") + } + if entry.Headers["X-Test"] != "1" { + t.Errorf("X-Test: got %q, want %q", entry.Headers["X-Test"], "1") + } + + // Bare hostname should still work directly + entry, ok = h.HostEntry("app-staging.keeperhub.com") + if !ok { + t.Fatal("expected direct bare hostname match") + } +} + func TestHostsYAMLCFAccessHeaders(t *testing.T) { rawYAML := `hosts: app-staging.keeperhub.com: