Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ var DeviceLoginFunc = func(host string, ios *iostreams.IOStreams) (string, error
return internalauth.DeviceLogin(host, ios)
}

// SetTokenFunc is the function used to store a token in the keyring.
// Tests may override this to avoid touching the real keyring.
// SetTokenFunc is the function used to store a token in hosts.yml.
// Tests may override this to avoid touching the real config.
var SetTokenFunc = func(host, token string) error {
return internalauth.SetToken(host, token)
return config.SetHostToken(host, token)
}

// FetchTokenInfoFunc is the function used to fetch session details from the server.
Expand Down
1 change: 1 addition & 0 deletions cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func TestLoginCmd_DefaultDeviceFlow(t *testing.T) {
}

func TestLoginCmd_WithTokenFlag(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", t.TempDir())
ios, buf, _, _ := iostreams.Test()
ios.In = strings.NewReader("my-token-from-stdin\n")

Expand Down
2 changes: 1 addition & 1 deletion internal/auth/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func DeviceLogin(host string, ios *iostreams.IOStreams) (string, error) {
return "", err
}

if storeErr := SetToken(host, token); storeErr != nil {
if storeErr := config.SetHostToken(host, token); storeErr != nil {
return "", fmt.Errorf("storing token: %w", storeErr)
}

Expand Down
17 changes: 9 additions & 8 deletions internal/auth/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,29 @@ import (

// ResolveToken resolves the auth token for host using the priority chain:
// 1. KH_API_KEY environment variable
// 2. OS keyring
// 3. hosts.yml token field
// 2. hosts.yml token field
// 3. OS keyring (legacy fallback -- tokens are migrated to hosts.yml on login)
// Returns a ResolvedToken with Method set to AuthMethodNone if no token found.
func ResolveToken(host string) (ResolvedToken, error) {
if apiKey := os.Getenv("KH_API_KEY"); apiKey != "" {
return ResolvedToken{Token: apiKey, Method: AuthMethodAPIKey, Host: host}, nil
}

token, err := GetToken(host)
hosts, err := config.ReadHosts()
if err != nil {
return ResolvedToken{}, err
}
if token != "" {
return ResolvedToken{Token: token, Method: AuthMethodToken, Host: host}, nil
if entry, ok := hosts.HostEntry(host); ok && entry.Token != "" {
return ResolvedToken{Token: entry.Token, Method: AuthMethodToken, Host: host}, nil
}

hosts, err := config.ReadHosts()
// Legacy fallback: check OS keyring for tokens stored before the hosts.yml migration.
token, err := GetToken(host)
if err != nil {
return ResolvedToken{}, err
}
if entry, ok := hosts.HostEntry(host); ok && entry.Token != "" {
return ResolvedToken{Token: entry.Token, Method: AuthMethodToken, Host: host}, nil
if token != "" {
return ResolvedToken{Token: token, Method: AuthMethodToken, Host: host}, nil
}

return ResolvedToken{Method: AuthMethodNone, Host: host}, nil
Expand Down
45 changes: 25 additions & 20 deletions internal/auth/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ func writeHostsFile(t *testing.T, host, token string) {
content := "hosts:\n " + host + ":\n token: " + token + "\n"
require.NoError(t, os.WriteFile(hostsFile, []byte(content), 0o600))

origHome := os.Getenv("XDG_CONFIG_HOME")
require.NoError(t, os.Setenv("XDG_CONFIG_HOME", dir))
t.Cleanup(func() {
if origHome == "" {
os.Unsetenv("XDG_CONFIG_HOME")
} else {
os.Setenv("XDG_CONFIG_HOME", origHome)
}
})
t.Setenv("XDG_CONFIG_HOME", dir)
}

// isolateHostsFile sets XDG_CONFIG_HOME to an empty temp dir so
// ResolveToken does not read the user's real hosts.yml.
func isolateHostsFile(t *testing.T) {
t.Helper()
t.Setenv("XDG_CONFIG_HOME", t.TempDir())
}

func TestResolveToken_EnvVar(t *testing.T) {
isolateHostsFile(t)
setupEmptyKeyring(t)
t.Setenv("KH_API_KEY", "kh_test123")

Expand All @@ -71,28 +71,31 @@ func TestResolveToken_EnvVar(t *testing.T) {
require.Equal(t, testHost, rt.Host)
}

func TestResolveToken_Keyring(t *testing.T) {
setupKeyringWithToken(t, testHost, "keyring_token_abc")
func TestResolveToken_HostsYML(t *testing.T) {
setupEmptyKeyring(t)
t.Setenv("KH_API_KEY", "")
writeHostsFile(t, testHost, "hosts_yml_token")

rt, err := ResolveToken(testHost)
require.NoError(t, err)
require.Equal(t, "keyring_token_abc", rt.Token)
require.Equal(t, "hosts_yml_token", rt.Token)
require.Equal(t, AuthMethodToken, rt.Method)
}

func TestResolveToken_HostsYML(t *testing.T) {
setupEmptyKeyring(t)
func TestResolveToken_KeyringLegacyFallback(t *testing.T) {
// hosts.yml has no token for this host, so the keyring fallback is used.
isolateHostsFile(t)
setupKeyringWithToken(t, testHost, "keyring_token_abc")
t.Setenv("KH_API_KEY", "")
writeHostsFile(t, testHost, "hosts_yml_token")

rt, err := ResolveToken(testHost)
require.NoError(t, err)
require.Equal(t, "hosts_yml_token", rt.Token)
require.Equal(t, "keyring_token_abc", rt.Token)
require.Equal(t, AuthMethodToken, rt.Method)
}

func TestResolveToken_None(t *testing.T) {
isolateHostsFile(t)
setupEmptyKeyring(t)
t.Setenv("KH_API_KEY", "")

Expand All @@ -103,6 +106,7 @@ func TestResolveToken_None(t *testing.T) {
}

func TestResolveToken_EnvVarPriority(t *testing.T) {
isolateHostsFile(t)
setupKeyringWithToken(t, testHost, "keyring_token")
t.Setenv("KH_API_KEY", "env_wins")

Expand All @@ -112,13 +116,14 @@ func TestResolveToken_EnvVarPriority(t *testing.T) {
require.Equal(t, AuthMethodAPIKey, rt.Method)
}

func TestResolveToken_KeyringOverHostsYML(t *testing.T) {
setupKeyringWithToken(t, testHost, "keyring_takes_precedence")
func TestResolveToken_HostsYMLOverKeyring(t *testing.T) {
// hosts.yml now takes priority over the legacy keyring.
setupKeyringWithToken(t, testHost, "keyring_token")
t.Setenv("KH_API_KEY", "")
writeHostsFile(t, testHost, "hosts_yml_token")
writeHostsFile(t, testHost, "hosts_yml_wins")

rt, err := ResolveToken(testHost)
require.NoError(t, err)
require.Equal(t, "keyring_takes_precedence", rt.Token)
require.Equal(t, "hosts_yml_wins", rt.Token)
require.Equal(t, AuthMethodToken, rt.Method)
}
49 changes: 48 additions & 1 deletion internal/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"

khhttp "github.com/keeperhub/cli/internal/http"
Expand Down Expand Up @@ -62,8 +63,16 @@ type orgMembership struct {
}

// FetchTokenInfo queries the server for session details using the given token.
// Returns TokenInfo on success, or an error if the token is invalid.
// For session tokens, calls /api/auth/get-session.
// For API keys (kh_ prefix), validates via /api/workflows and returns basic info.
func FetchTokenInfo(host, token string) (TokenInfo, error) {
if strings.HasPrefix(token, "kh_") {
return fetchAPIKeyInfo(host, token)
}
return fetchSessionInfo(host, token)
}

func fetchSessionInfo(host, token string) (TokenInfo, error) {
client := &http.Client{Timeout: 10 * time.Second}

req, err := http.NewRequest(http.MethodGet, khhttp.BuildBaseURL(host)+"/api/auth/get-session", nil)
Expand Down Expand Up @@ -122,6 +131,44 @@ func FetchTokenInfo(host, token string) (TokenInfo, error) {
return info, nil
}

// fetchAPIKeyInfo validates an API key by calling an authenticated endpoint.
// API keys don't have session data, so we probe /api/workflows to confirm
// the key is accepted and extract org info from the response context.
func fetchAPIKeyInfo(host, token string) (TokenInfo, error) {
client := &http.Client{Timeout: 10 * time.Second}

req, err := http.NewRequest(http.MethodGet, khhttp.BuildBaseURL(host)+"/api/workflows?limit=1", nil)
if err != nil {
return TokenInfo{}, fmt.Errorf("creating validation request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)

resp, err := client.Do(req)
if err != nil {
return TokenInfo{}, fmt.Errorf("validating API key: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusUnauthorized {
return TokenInfo{}, fmt.Errorf("API key is invalid or revoked")
}
if resp.StatusCode != http.StatusOK {
return TokenInfo{}, fmt.Errorf("API key validation failed (status %d)", resp.StatusCode)
}

prefix := token
if len(prefix) > 14 {
prefix = prefix[:14]
}

return TokenInfo{
Email: prefix + "...",
Name: "API Key",
Method: AuthMethodAPIKey,
Role: "api-key",
}, nil
}

func fetchOrgDetails(client *http.Client, host, token, orgID string) (string, string) {
req, err := http.NewRequest(http.MethodGet, khhttp.BuildBaseURL(host)+"/api/organizations/"+orgID, nil)
if err != nil {
Expand Down
58 changes: 50 additions & 8 deletions internal/config/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package config
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"

"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -83,14 +85,52 @@ func (h *HostsConfig) ActiveHost(flagHost, envHost string) string {
}

// HostEntry looks up the HostConfig for the given hostname.
// Returns the entry and true if found, or an empty HostConfig and false otherwise.
// It tries the raw value first, then falls back to a bare hostname (scheme stripped)
// so that --host https://app-staging.keeperhub.com matches a hosts.yml key of
// app-staging.keeperhub.com.
func (h *HostsConfig) HostEntry(hostname string) (HostConfig, bool) {
entry, ok := h.Hosts[hostname]
return entry, ok
if entry, ok := h.Hosts[hostname]; ok {
return entry, true
}
bare := stripScheme(hostname)
if bare != hostname {
if entry, ok := h.Hosts[bare]; ok {
return entry, true
}
}
return HostConfig{}, false
}

// stripScheme removes the URL scheme (http:// or https://) and any trailing
// slash from a host string, returning just the hostname[:port].
func stripScheme(host string) string {
if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
if u, err := url.Parse(host); err == nil {
return u.Host
}
}
return strings.TrimRight(host, "/")
}

// resolveHostKey returns the canonical key for a host in the hosts map.
// If a bare-hostname entry already exists for a full-URL host, it returns the
// bare hostname so the token is stored alongside existing headers (e.g. CF-Access).
func resolveHostKey(hosts map[string]HostConfig, host string) string {
if _, ok := hosts[host]; ok {
return host
}
bare := stripScheme(host)
if bare != host {
if _, ok := hosts[bare]; ok {
return bare
}
}
// No existing entry -- prefer bare hostname for cleanliness.
return bare
}

// SetHostToken updates the token for a specific host in the hosts config.
// If the host entry doesn't exist, it creates one.
// If a bare-hostname entry already exists, the token is merged into it.
func SetHostToken(host, token string) error {
hosts, err := ReadHosts()
if err != nil {
Expand All @@ -99,9 +139,10 @@ func SetHostToken(host, token string) error {
if hosts.Hosts == nil {
hosts.Hosts = make(map[string]HostConfig)
}
entry := hosts.Hosts[host]
key := resolveHostKey(hosts.Hosts, host)
entry := hosts.Hosts[key]
entry.Token = token
hosts.Hosts[host] = entry
hosts.Hosts[key] = entry
return WriteHosts(hosts)
}

Expand All @@ -112,9 +153,10 @@ func ClearHostToken(host string) error {
if err != nil {
return err
}
if entry, ok := hosts.Hosts[host]; ok {
key := resolveHostKey(hosts.Hosts, host)
if entry, ok := hosts.Hosts[key]; ok {
entry.Token = ""
hosts.Hosts[host] = entry
hosts.Hosts[key] = entry
return WriteHosts(hosts)
}
return nil
Expand Down
40 changes: 40 additions & 0 deletions internal/config/hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,46 @@ func TestHostEntryLookup(t *testing.T) {
}
}

// TestHostEntrySchemeStripping verifies that HostEntry falls back to a bare
// hostname when the caller passes a full URL (e.g. --host https://staging.example.com).
func TestHostEntrySchemeStripping(t *testing.T) {
h := config.HostsConfig{
Hosts: map[string]config.HostConfig{
"app-staging.keeperhub.com": {
Headers: map[string]string{
"CF-Access-Client-Id": "abc",
"CF-Access-Client-Secret": "def",
},
},
},
}

// Full https:// URL should match bare hostname key
entry, ok := h.HostEntry("https://app-staging.keeperhub.com")
if !ok {
t.Fatal("expected HostEntry to match via scheme stripping")
}
if entry.Headers["CF-Access-Client-Id"] != "abc" {
t.Errorf("CF-Access-Client-Id: got %q, want %q", entry.Headers["CF-Access-Client-Id"], "abc")
}

// Full http:// URL should also match
h.Hosts["localhost:3000"] = config.HostConfig{Headers: map[string]string{"X-Test": "1"}}
entry, ok = h.HostEntry("http://localhost:3000")
if !ok {
t.Fatal("expected HostEntry to match http:// via scheme stripping")
}
if entry.Headers["X-Test"] != "1" {
t.Errorf("X-Test: got %q, want %q", entry.Headers["X-Test"], "1")
}

// Bare hostname should still work directly
entry, ok = h.HostEntry("app-staging.keeperhub.com")
if !ok {
t.Fatal("expected direct bare hostname match")
}
}

func TestHostsYAMLCFAccessHeaders(t *testing.T) {
rawYAML := `hosts:
app-staging.keeperhub.com:
Expand Down
Loading