diff --git a/cmd/api/api.go b/cmd/api/api.go index f56b4f5e..b872044c 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -116,7 +116,7 @@ func (c *ApiCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { // Create an HTTP client with appropriate configuration client := httpClient.NewClient( - f.Config.APIToken(), + f.Token, httpClient.WithBaseURL(f.RestAPIClient.BaseURL.String()), httpClient.WithMaxRetries(3), httpClient.WithMaxRetryDelay(60*time.Second), diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 40e7f5ea..2c2c67d8 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -83,8 +83,37 @@ func LoginWithToken(f *factory.Factory, org, token string) error { return nil } +// LoginWithSession stores an OAuth session for an organization in the system keychain. +func LoginWithSession(f *factory.Factory, org string, session *oauth.Session) error { + if org == "" { + return errors.New("organization cannot be empty") + } + if session == nil || session.AccessToken == "" { + return errors.New("oauth session must include an access token") + } + + kr := keyring.New() + if !kr.IsAvailable() { + return errors.New("system keychain is not available; cannot store token") + } + if err := kr.SetSession(org, session); err != nil { + return fmt.Errorf("failed to store token in keychain: %w", err) + } + fmt.Println("Token stored securely in system keychain.") + + if err := f.Config.EnsureOrganization(org); err != nil { + return fmt.Errorf("failed to register organization in config: %w", err) + } + + if err := f.Config.SelectOrganization(org, f.GitRepository != nil); err != nil { + return fmt.Errorf("failed to select organization: %w", err) + } + + return nil +} + func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } @@ -102,16 +131,15 @@ func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { return errors.New("--org requires --token. Use `bk auth login` for OAuth or `bk auth login --org --token ` for token login") } - // Resolve scope groups (e.g., "read_only" → individual read_* scopes). - // When --scopes is empty, no scope parameter is sent and the token - // inherits the user's full Buildkite permissions. + // Resolve scope groups (e.g. "read_only" to individual read_* scopes). + // When --scopes is empty, NewFlow defaults to requesting the full known + // scope set and Buildkite grants the subset the user can actually use. resolvedScopes := oauth.ResolveScopes(c.Scopes) // Create OAuth flow cfg := &oauth.Config{ // Host default handled via NewFlow, omitted to allow usage of BUILDKITE_HOST - ClientID: oauth.DefaultClientID, - Scopes: resolvedScopes, + Scopes: resolvedScopes, } flow, err := oauth.NewFlow(cfg) @@ -150,28 +178,89 @@ func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { return fmt.Errorf("token exchange failed: %w", err) } - // Resolve org from the API using the new token - client, err := buildkite.NewOpts(buildkite.WithTokenAuth(tokenResp.AccessToken)) + orgs, err := resolveOrganizationsFromToken(ctx, f.Config.RESTAPIEndpoint(), tokenResp.AccessToken) if err != nil { - return fmt.Errorf("failed to create API client: %w", err) + return err + } + + session := tokenResp.Session(cfg.Host, cfg.ClientID, time.Now()) + if err := storeSessionForOrganizations(f, orgs, session); err != nil { + return err } - orgs, _, err := client.Organizations.List(ctx, nil) + fmt.Printf("\nāœ… Successfully authenticated with organization %q\n", orgs[0].Slug) + fmt.Printf(" Scopes: %s\n", tokenResp.Scope) + + return nil +} + +func resolveOrganizationsFromToken(ctx context.Context, baseURL, token string) ([]buildkite.Organization, error) { + client, err := buildkite.NewOpts( + buildkite.WithBaseURL(baseURL), + buildkite.WithTokenAuth(token), + ) if err != nil { - return fmt.Errorf("failed to list organizations: %w", err) + return nil, fmt.Errorf("failed to create API client: %w", err) } - if len(orgs) == 0 { - return fmt.Errorf("no organizations found for this token") + + var allOrgs []buildkite.Organization + page := 1 + for { + orgs, resp, err := client.Organizations.List(ctx, &buildkite.OrganizationListOptions{ + ListOptions: buildkite.ListOptions{Page: page}, + }) + if err != nil { + return nil, fmt.Errorf("failed to list organizations: %w", err) + } + allOrgs = append(allOrgs, orgs...) + if resp == nil || resp.NextPage == 0 { + break + } + page = resp.NextPage } - org := orgs[0] + if len(allOrgs) == 0 { + return nil, fmt.Errorf("no organizations found for this token") + } - if err := LoginWithToken(f, org.Slug, tokenResp.AccessToken); err != nil { + return allOrgs, nil +} + +func resolveOrganizationFromToken(ctx context.Context, baseURL, token string) (*buildkite.Organization, error) { + orgs, err := resolveOrganizationsFromToken(ctx, baseURL, token) + if err != nil { + return nil, err + } + + return &orgs[0], nil +} + +func storeSessionForOrganizations(f *factory.Factory, orgs []buildkite.Organization, session *oauth.Session) error { + if len(orgs) == 0 { + return errors.New("no organizations found for this token") + } + if err := LoginWithSession(f, orgs[0].Slug, session); err != nil { return err } - fmt.Printf("\nāœ… Successfully authenticated with organization %q\n", org.Slug) - fmt.Printf(" Scopes: %s\n", tokenResp.Scope) + kr := keyring.New() + seen := map[string]struct{}{orgs[0].Slug: {}} + for _, org := range orgs[1:] { + if org.Slug == "" { + continue + } + if _, exists := seen[org.Slug]; exists { + continue + } + seen[org.Slug] = struct{}{} + + if err := kr.SetSession(org.Slug, session); err != nil { + return fmt.Errorf("failed to store token in keychain: %w", err) + } + if err := f.Config.EnsureOrganization(org.Slug); err != nil { + return fmt.Errorf("failed to register organization in config: %w", err) + } + } return nil } diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go new file mode 100644 index 00000000..bb9a5cb4 --- /dev/null +++ b/cmd/auth/login_test.go @@ -0,0 +1,112 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/buildkite/cli/v3/internal/config" + "github.com/buildkite/cli/v3/pkg/cmd/factory" + "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" + buildkite "github.com/buildkite/go-buildkite/v4" + "github.com/spf13/afero" +) + +func TestResolveOrganizationFromTokenUsesConfiguredBaseURL(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer bkua_test_token" { + t.Fatalf("Authorization = %q, want Bearer bkua_test_token", got) + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode([]map[string]any{{"slug": "test-org"}}); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + org, err := resolveOrganizationFromToken(context.Background(), server.URL, "bkua_test_token") + if err != nil { + t.Fatalf("resolveOrganizationFromToken returned error: %v", err) + } + if org == nil { + t.Fatal("resolveOrganizationFromToken returned nil organization") + } + if org.Slug != "test-org" { + t.Fatalf("Slug = %q, want test-org", org.Slug) + } +} + +func TestStoreSessionForOrganizationsStoresAllAccessibleOrgs(t *testing.T) { + keyring.MockForTesting() + + f := &factory.Factory{ + Config: config.New(afero.NewMemMapFs(), nil), + } + session := &oauth.Session{ + Version: oauth.SessionVersion, + AccessToken: "bkua_access", + TokenType: "Bearer", + } + + orgs := []buildkite.Organization{ + {Slug: "test-org"}, + {Slug: "other-org"}, + {Slug: "other-org"}, + } + + if err := storeSessionForOrganizations(f, orgs, session); err != nil { + t.Fatalf("storeSessionForOrganizations returned error: %v", err) + } + + kr := keyring.New() + for _, slug := range []string{"test-org", "other-org"} { + storedSession, err := kr.GetSession(slug) + if err != nil { + t.Fatalf("GetSession(%q) returned error: %v", slug, err) + } + if storedSession.AccessToken != "bkua_access" { + t.Fatalf("stored access token for %q = %q, want bkua_access", slug, storedSession.AccessToken) + } + } + + if got := f.Config.OrganizationSlug(); got != "test-org" { + t.Fatalf("OrganizationSlug() = %q, want test-org", got) + } +} + +func TestResolveOrganizationsFromTokenPaginates(t *testing.T) { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + w.Header().Set("Content-Type", "application/json") + switch page { + case "", "1": + w.Header().Set("Link", `<`+server.URL+`/v2/organizations?page=2>; rel="next"`) + if err := json.NewEncoder(w).Encode([]map[string]any{{"slug": "org-one"}}); err != nil { + t.Fatalf("Encode page 1 returned error: %v", err) + } + case "2": + if err := json.NewEncoder(w).Encode([]map[string]any{{"slug": "org-two"}}); err != nil { + t.Fatalf("Encode page 2 returned error: %v", err) + } + default: + t.Fatalf("unexpected page query %q", page) + } + })) + defer server.Close() + + orgs, err := resolveOrganizationsFromToken(context.Background(), server.URL, "bkua_test_token") + if err != nil { + t.Fatalf("resolveOrganizationsFromToken returned error: %v", err) + } + if len(orgs) != 2 { + t.Fatalf("len(orgs) = %d, want 2", len(orgs)) + } + if orgs[0].Slug != "org-one" || orgs[1].Slug != "org-two" { + t.Fatalf("org slugs = [%q %q], want [org-one org-two]", orgs[0].Slug, orgs[1].Slug) + } +} diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go index b05b47fc..9805bfce 100644 --- a/cmd/auth/logout.go +++ b/cmd/auth/logout.go @@ -7,6 +7,7 @@ import ( "github.com/buildkite/cli/v3/internal/cli" "github.com/buildkite/cli/v3/pkg/cmd/factory" "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" ) type LogoutCmd struct { @@ -15,7 +16,7 @@ type LogoutCmd struct { } func (c *LogoutCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } @@ -59,9 +60,12 @@ func (c *LogoutCmd) logoutOrg(f *factory.Factory) error { kr := keyring.New() if kr.IsAvailable() { + var currentSession *oauth.Session + currentSession, _ = kr.GetSession(org) if err := kr.Delete(org); err != nil { fmt.Printf("Warning: could not remove token from keychain: %v\n", err) } else { + c.deleteSiblingOAuthSessions(f, kr, org, currentSession) fmt.Println("Token removed from system keychain.") } } @@ -69,3 +73,28 @@ func (c *LogoutCmd) logoutOrg(f *factory.Factory) error { fmt.Printf("Logged out of organization %q\n", org) return nil } + +func (c *LogoutCmd) deleteSiblingOAuthSessions(f *factory.Factory, kr *keyring.Keyring, org string, session *oauth.Session) { + if session == nil || session.RefreshToken == "" { + return + } + + for _, sibling := range f.Config.ConfiguredOrganizations() { + if sibling == "" || sibling == org { + continue + } + + siblingSession, err := kr.GetSession(sibling) + if err != nil || siblingSession == nil { + continue + } + if siblingSession.Host != session.Host || siblingSession.ClientID != session.ClientID { + continue + } + if siblingSession.RefreshToken != session.RefreshToken || siblingSession.AccessToken != session.AccessToken { + continue + } + + _ = kr.Delete(sibling) + } +} diff --git a/cmd/auth/logout_test.go b/cmd/auth/logout_test.go new file mode 100644 index 00000000..ef7a994a --- /dev/null +++ b/cmd/auth/logout_test.go @@ -0,0 +1,56 @@ +package auth + +import ( + "testing" + + "github.com/buildkite/cli/v3/internal/config" + "github.com/buildkite/cli/v3/pkg/cmd/factory" + "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" + "github.com/spf13/afero" +) + +func TestLogoutOrgDeletesSiblingOAuthAliases(t *testing.T) { + keyring.MockForTesting() + + conf := config.New(afero.NewMemMapFs(), nil) + if err := conf.EnsureOrganization("org-a"); err != nil { + t.Fatalf("EnsureOrganization org-a returned error: %v", err) + } + if err := conf.EnsureOrganization("org-b"); err != nil { + t.Fatalf("EnsureOrganization org-b returned error: %v", err) + } + if err := conf.SelectOrganization("org-a", false); err != nil { + t.Fatalf("SelectOrganization returned error: %v", err) + } + + session := &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + ClientID: "buildkite-cli", + AccessToken: "bkua_access", + RefreshToken: "bkrt_refresh", + TokenType: "Bearer", + } + + kr := keyring.New() + if err := kr.SetSession("org-a", session); err != nil { + t.Fatalf("SetSession org-a returned error: %v", err) + } + if err := kr.SetSession("org-b", session); err != nil { + t.Fatalf("SetSession org-b returned error: %v", err) + } + + cmd := &LogoutCmd{Org: "org-a"} + f := &factory.Factory{Config: conf} + if err := cmd.logoutOrg(f); err != nil { + t.Fatalf("logoutOrg returned error: %v", err) + } + + if _, err := kr.GetSession("org-a"); err == nil { + t.Fatal("expected org-a session to be deleted") + } + if _, err := kr.GetSession("org-b"); err == nil { + t.Fatal("expected org-b sibling session to be deleted") + } +} diff --git a/cmd/auth/switch.go b/cmd/auth/switch.go index cf2dc015..2ee029f1 100644 --- a/cmd/auth/switch.go +++ b/cmd/auth/switch.go @@ -26,7 +26,7 @@ Examples: } func (c *SwitchCmd) Run(globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/config/get.go b/cmd/config/get.go index 32b0f2e5..7da9dd69 100644 --- a/cmd/config/get.go +++ b/cmd/config/get.go @@ -36,7 +36,7 @@ func (c *GetCmd) Run() error { return err } - f, err := factory.New() + f, err := factory.New(factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/config/list.go b/cmd/config/list.go index 283dbd20..c8d9c7b4 100644 --- a/cmd/config/list.go +++ b/cmd/config/list.go @@ -24,7 +24,7 @@ Examples: } func (c *ListCmd) Run() error { - f, err := factory.New() + f, err := factory.New(factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/config/set.go b/cmd/config/set.go index 87341643..047d3903 100644 --- a/cmd/config/set.go +++ b/cmd/config/set.go @@ -57,7 +57,7 @@ func (c *SetCmd) Run() error { return fmt.Errorf("%s can only be set in user config (not --local)", key) } - f, err := factory.New() + f, err := factory.New(factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/config/unset.go b/cmd/config/unset.go index 4fb4881c..f5b6c02d 100644 --- a/cmd/config/unset.go +++ b/cmd/config/unset.go @@ -36,7 +36,7 @@ func (c *UnsetCmd) Run() error { return fmt.Errorf("%s can only be unset from user config (not --local)", key) } - f, err := factory.New() + f, err := factory.New(factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/configure/configure.go b/cmd/configure/configure.go index 74eca56c..507f3c74 100644 --- a/cmd/configure/configure.go +++ b/cmd/configure/configure.go @@ -50,7 +50,7 @@ Examples: } func (c *ConfigureCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } @@ -68,7 +68,7 @@ func (c *ConfigureCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error if targetOrg == "" { targetOrg = f.Config.OrganizationSlug() } - if !c.Force && targetOrg != "" && f.Config.APITokenForOrg(targetOrg) != "" { + if !c.Force && targetOrg != "" && hasStoredCredentialsForOrg(f, targetOrg) { return fmt.Errorf("API token already configured for organization %q. Use --force to overwrite", targetOrg) } } @@ -99,8 +99,7 @@ func ConfigureRun(f *factory.Factory, org string) error { } // Check if token already exists for this organization. // Use resolved token lookup so keychain-backed entries are detected. - existingToken := getTokenForOrg(f, org) - if existingToken != "" { + if hasStoredCredentialsForOrg(f, org) { fmt.Printf("Using existing API token for organization: %s\n", org) return f.Config.SelectOrganization(org, f.GitRepository != nil) } @@ -123,6 +122,10 @@ func getTokenForOrg(f *factory.Factory, org string) string { return f.Config.APITokenForOrg(org) } +func hasStoredCredentialsForOrg(f *factory.Factory, org string) bool { + return f.Config.HasStoredTokenForOrg(org) +} + // promptForInput handles terminal input with optional password masking func promptForInput(prompt string, isPassword bool) (string, error) { fmt.Print(prompt) diff --git a/cmd/organization/list.go b/cmd/organization/list.go index 47f279d2..289ba55b 100644 --- a/cmd/organization/list.go +++ b/cmd/organization/list.go @@ -34,7 +34,7 @@ Examples: } func (c *ListCmd) Run(globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/pipeline/convert.go b/cmd/pipeline/convert.go index 537a896b..07502e65 100644 --- a/cmd/pipeline/convert.go +++ b/cmd/pipeline/convert.go @@ -87,7 +87,7 @@ Examples: } func (c *ConvertCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/pipeline/copy.go b/cmd/pipeline/copy.go index 7ae4442b..2ce8a8b9 100644 --- a/cmd/pipeline/copy.go +++ b/cmd/pipeline/copy.go @@ -306,7 +306,7 @@ func (c *CopyCmd) runCopy(kongCtx *kong.Context, f *factory.Factory, source *bui // getClientForOrg creates a Buildkite client authenticated for the specified organization func (c *CopyCmd) getClientForOrg(f *factory.Factory, org string) (*buildkite.Client, error) { - token := f.Config.APITokenForOrg(org) + token := f.Config.RefreshedAPITokenForOrg(org) if token == "" { return nil, fmt.Errorf("no API token configured for organization %q. Run 'bk configure' to add it", org) } diff --git a/cmd/pipeline/validate.go b/cmd/pipeline/validate.go index 1f986934..5a526a2c 100644 --- a/cmd/pipeline/validate.go +++ b/cmd/pipeline/validate.go @@ -69,7 +69,7 @@ Examples: } func (c *ValidateCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } diff --git a/cmd/use/use.go b/cmd/use/use.go index 29670fa0..795d462a 100644 --- a/cmd/use/use.go +++ b/cmd/use/use.go @@ -26,7 +26,7 @@ Examples: } func (c *UseCmd) Run(globals cli.GlobalFlags) error { - f, err := factory.New(factory.WithDebug(globals.EnableDebug())) + f, err := factory.New(factory.WithDebug(globals.EnableDebug()), factory.WithoutAPIClients()) if err != nil { return err } diff --git a/internal/config/config.go b/internal/config/config.go index 6af0b9d2..8d3b83f6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,10 +6,12 @@ package config import ( + "context" "errors" "fmt" "io" "maps" + "net/url" "os" "path/filepath" "runtime" @@ -17,9 +19,11 @@ import ( "strconv" "strings" "sync" + "time" "github.com/buildkite/cli/v3/internal/pipeline" "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" buildkite "github.com/buildkite/go-buildkite/v4" git "github.com/go-git/go-git/v5" "github.com/goccy/go-yaml" @@ -122,14 +126,32 @@ func (conf *Config) SelectOrganization(org string, inGitRepo bool) error { } // APIToken gets the API token configured for the currently selected organization. -// Precedence: environment variable > keyring > config file (legacy, read-only with warning) +// Precedence: environment variable > keyring > config file (legacy, read-only with warning). +// This is a side-effect-free lookup and does not refresh OAuth sessions. func (conf *Config) APIToken() string { - return conf.APITokenForOrg(conf.OrganizationSlug()) + return conf.apiTokenForOrg(conf.OrganizationSlug(), false) } // APITokenForOrg gets the API token for a specific organization. -// Precedence: environment variable > keyring > config file (legacy, read-only with warning) +// Precedence: environment variable > keyring > config file (legacy, read-only with warning). +// This is a side-effect-free lookup and does not refresh OAuth sessions. func (conf *Config) APITokenForOrg(org string) string { + return conf.apiTokenForOrg(org, false) +} + +// RefreshedAPIToken gets the API token for the currently selected organization, +// refreshing an OAuth session first when needed. +func (conf *Config) RefreshedAPIToken() string { + return conf.apiTokenForOrg(conf.OrganizationSlug(), true) +} + +// RefreshedAPITokenForOrg gets the API token for a specific organization, +// refreshing an OAuth session first when needed. +func (conf *Config) RefreshedAPITokenForOrg(org string) string { + return conf.apiTokenForOrg(org, true) +} + +func (conf *Config) apiTokenForOrg(org string, refresh bool) string { if token := os.Getenv("BUILDKITE_API_TOKEN"); token != "" { envTokenWarningOnce.Do(func() { fmt.Fprintln(os.Stderr, "Warning: using BUILDKITE_API_TOKEN environment variable for authentication.") @@ -139,8 +161,29 @@ func (conf *Config) APITokenForOrg(org string) string { kr := keyring.New() if kr.IsAvailable() { - if token, err := kr.Get(org); err == nil && token != "" { - return token + if session, err := kr.GetSession(org); err == nil && session != nil && session.AccessToken != "" { + if refresh { + now := time.Now() + if !session.CanRefresh() { + if !session.ExpiresAt.IsZero() && !now.Before(session.ExpiresAt) { + return "" + } + return session.AccessToken + } + refreshedSession, refreshErr := conf.refreshOAuthSession(org, kr, session, now) + if refreshedSession != nil && refreshedSession.AccessToken != "" { + if refreshErr != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to refresh OAuth token for %q: %v\n", org, refreshErr) + } + return refreshedSession.AccessToken + } + } else { + now := time.Now() + if !session.ExpiresAt.IsZero() && !now.Before(session.ExpiresAt) { + return "" + } + return session.AccessToken + } } } @@ -158,6 +201,14 @@ func (conf *Config) APITokenForOrg(org string) string { return "" } +func (conf *Config) ShouldFallbackToSelectedOrg(org string) bool { + if org == "" || org == conf.OrganizationSlug() { + return false + } + + return !conf.HasStoredTokenForOrg(org) +} + // HasStoredTokenForOrg reports whether a token is stored for org in keyring // or config files, excluding environment variable overrides. func (conf *Config) HasStoredTokenForOrg(org string) bool { @@ -215,12 +266,33 @@ func (conf *Config) GetGraphQLEndpoint() string { func (conf *Config) RESTAPIEndpoint() string { value := os.Getenv("BUILDKITE_REST_API_ENDPOINT") if value != "" { - return value + return normaliseRESTAPIEndpoint(value) } return buildkite.DefaultBaseURL } +func normaliseRESTAPIEndpoint(value string) string { + parsed, err := url.Parse(value) + if err != nil { + return value + } + + trimmedPath := strings.TrimRight(parsed.Path, "/") + if !strings.HasSuffix(trimmedPath, "/v2") { + return value + } + + parsed.Path = strings.TrimSuffix(trimmedPath, "/v2") + if parsed.Path == "" { + parsed.Path = "/" + } else { + parsed.Path += "/" + } + + return parsed.String() +} + func (conf *Config) PagerDisabled() bool { if v, ok := lookupBoolEnv("BUILDKITE_NO_PAGER"); ok { return v @@ -544,3 +616,49 @@ func (conf *Config) writeUser() error { func (conf *Config) writeLocal() error { return writeFileConfig(conf.fs, conf.localPath, conf.local) } + +func (conf *Config) refreshOAuthSession(org string, kr *keyring.Keyring, session *oauth.Session, now time.Time) (*oauth.Session, error) { + if session == nil || session.AccessToken == "" { + return nil, nil + } + if !session.NeedsRefresh(now) { + return session, nil + } + + refreshedToken, err := oauth.RefreshAccessToken(context.Background(), &oauth.Config{Host: session.Host, ClientID: session.ClientID}, session.RefreshToken, session.Scope) + if err != nil { + if !session.ExpiresAt.IsZero() && !now.Before(session.ExpiresAt) { + return nil, err + } + return session, err + } + + refreshedSession := session.Update(refreshedToken, now) + if err := kr.SetSession(org, refreshedSession); err != nil { + return refreshedSession, err + } + conf.propagateOAuthSessionUpdate(kr, org, session, refreshedSession) + + return refreshedSession, nil +} + +func (conf *Config) propagateOAuthSessionUpdate(kr *keyring.Keyring, sourceOrg string, previous, updated *oauth.Session) { + for _, org := range conf.ConfiguredOrganizations() { + if org == "" || org == sourceOrg { + continue + } + + sibling, err := kr.GetSession(org) + if err != nil || sibling == nil { + continue + } + if sibling.Host != previous.Host || sibling.ClientID != previous.ClientID { + continue + } + if sibling.RefreshToken != previous.RefreshToken || sibling.AccessToken != previous.AccessToken { + continue + } + + _ = kr.SetSession(org, updated) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 43a75fd2..6db64060 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,10 +1,14 @@ package config import ( + "context" + "net/http" + "net/http/httptest" "os" "path/filepath" "testing" + buildkite "github.com/buildkite/go-buildkite/v4" "github.com/spf13/afero" ) @@ -110,6 +114,33 @@ func TestConfig(t *testing.T) { } }) + t.Run("RESTAPIEndpoint strips version suffix from env override", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v2/organizations" { + t.Fatalf("request path = %q, want /v2/organizations", r.URL.Path) + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[]`)) + })) + defer server.Close() + + setEnv(t, "BUILDKITE_REST_API_ENDPOINT", server.URL+"/v2") + + conf := New(afero.NewMemMapFs(), nil) + client, err := buildkite.NewOpts( + buildkite.WithBaseURL(conf.RESTAPIEndpoint()), + buildkite.WithTokenAuth("bkua_test_token"), + ) + if err != nil { + t.Fatalf("NewOpts returned error: %v", err) + } + + if _, _, err := client.Organizations.List(context.Background(), nil); err != nil { + t.Fatalf("Organizations.List returned error: %v", err) + } + }) + t.Run("loadFileConfig returns error on invalid yaml", func(t *testing.T) { fs := afero.NewMemMapFs() path := filepath.Join(t.TempDir(), "bk.yaml") diff --git a/internal/config/oauth_refresh_test.go b/internal/config/oauth_refresh_test.go new file mode 100644 index 00000000..19e6fe86 --- /dev/null +++ b/internal/config/oauth_refresh_test.go @@ -0,0 +1,333 @@ +package config + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" + "github.com/spf13/afero" +) + +func TestRefreshedAPITokenForOrgRefreshesStoredOAuthSession(t *testing.T) { + keyring.MockForTesting() + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + var requests int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + + if r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected token path %q", r.URL.Path) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm returned error: %v", err) + } + if got := r.Form.Get("grant_type"); got != "refresh_token" { + t.Fatalf("grant_type = %q, want refresh_token", got) + } + if got := r.Form.Get("refresh_token"); got != "bkrt_old_refresh" { + t.Fatalf("refresh_token = %q, want bkrt_old_refresh", got) + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "bkua_refreshed_access", + RefreshToken: "bkrt_rotated_refresh", + TokenType: "Bearer", + Scope: "read_user read_organizations", + ExpiresIn: 3600, + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_old_refresh", + TokenType: "Bearer", + Scope: "read_user read_organizations", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + token := conf.RefreshedAPITokenForOrg("test-org") + if token != "bkua_refreshed_access" { + t.Fatalf("RefreshedAPITokenForOrg() = %q, want bkua_refreshed_access", token) + } + if requests != 1 { + t.Fatalf("refresh requests = %d, want 1", requests) + } + + session, err := kr.GetSession("test-org") + if err != nil { + t.Fatalf("GetSession returned error: %v", err) + } + if session.AccessToken != "bkua_refreshed_access" { + t.Fatalf("stored AccessToken = %q, want bkua_refreshed_access", session.AccessToken) + } + if session.RefreshToken != "bkrt_rotated_refresh" { + t.Fatalf("stored RefreshToken = %q, want bkrt_rotated_refresh", session.RefreshToken) + } + if !session.ExpiresAt.After(time.Now()) { + t.Fatalf("stored ExpiresAt = %s, want a future timestamp", session.ExpiresAt) + } +} + +func TestAPITokenForOrgDoesNotRefreshStoredOAuthSession(t *testing.T) { + keyring.MockForTesting() + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + Error: "invalid_grant", + ErrorDesc: "should not be called", + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + ClientID: "stored-client", + AccessToken: "bkua_current_access", + RefreshToken: "bkrt_old_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + if token := conf.APITokenForOrg("test-org"); token != "bkua_current_access" { + t.Fatalf("APITokenForOrg() = %q, want bkua_current_access", token) + } + if requests != 0 { + t.Fatalf("refresh requests = %d, want 0", requests) + } +} + +func TestAPITokenForOrgDoesNotReturnExpiredNonRefreshableToken(t *testing.T) { + keyring.MockForTesting() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + AccessToken: "bkua_expired_access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + if token := conf.APITokenForOrg("test-org"); token != "" { + t.Fatalf("APITokenForOrg() = %q, want empty token", token) + } +} + +func TestAPITokenForOrgDoesNotReturnExpiredRefreshableToken(t *testing.T) { + keyring.MockForTesting() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + ClientID: "buildkite-cli", + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_refresh", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + if token := conf.APITokenForOrg("test-org"); token != "" { + t.Fatalf("APITokenForOrg() = %q, want empty token", token) + } +} + +func TestRefreshedAPITokenForOrgRefreshUsesStoredClientID(t *testing.T) { + keyring.MockForTesting() + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm returned error: %v", err) + } + if got := r.Form.Get("client_id"); got != "stored-client" { + t.Fatalf("client_id = %q, want stored-client", got) + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "bkua_refreshed_access", + RefreshToken: "bkrt_rotated_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresIn: 3600, + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + ClientID: "stored-client", + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_old_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + if token := conf.RefreshedAPITokenForOrg("test-org"); token != "bkua_refreshed_access" { + t.Fatalf("RefreshedAPITokenForOrg() = %q, want bkua_refreshed_access", token) + } +} + +func TestRefreshedAPITokenForOrgDoesNotReturnExpiredTokenWhenRefreshFails(t *testing.T) { + keyring.MockForTesting() + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + Error: "invalid_grant", + ErrorDesc: "refresh token is invalid", + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + ClientID: "stored-client", + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_old_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + if token := conf.RefreshedAPITokenForOrg("test-org"); token != "" { + t.Fatalf("RefreshedAPITokenForOrg() = %q, want empty token", token) + } +} + +func TestRefreshedAPITokenForOrgPropagatesRotatedSessionToSiblingOrganizations(t *testing.T) { + keyring.MockForTesting() + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "bkua_refreshed_access", + RefreshToken: "bkrt_rotated_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresIn: 3600, + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + conf := New(afero.NewMemMapFs(), nil) + if err := conf.EnsureOrganization("org-a"); err != nil { + t.Fatalf("EnsureOrganization org-a returned error: %v", err) + } + if err := conf.EnsureOrganization("org-b"); err != nil { + t.Fatalf("EnsureOrganization org-b returned error: %v", err) + } + + kr := keyring.New() + original := &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + ClientID: "stored-client", + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_old_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresAt: time.Now().Add(-time.Minute), + } + if err := kr.SetSession("org-a", original); err != nil { + t.Fatalf("SetSession org-a returned error: %v", err) + } + if err := kr.SetSession("org-b", original); err != nil { + t.Fatalf("SetSession org-b returned error: %v", err) + } + + if token := conf.RefreshedAPITokenForOrg("org-a"); token != "bkua_refreshed_access" { + t.Fatalf("RefreshedAPITokenForOrg() = %q, want bkua_refreshed_access", token) + } + + for _, org := range []string{"org-a", "org-b"} { + session, err := kr.GetSession(org) + if err != nil { + t.Fatalf("GetSession(%q) returned error: %v", org, err) + } + if session.AccessToken != "bkua_refreshed_access" { + t.Fatalf("stored AccessToken for %q = %q, want bkua_refreshed_access", org, session.AccessToken) + } + if session.RefreshToken != "bkrt_rotated_refresh" { + t.Fatalf("stored RefreshToken for %q = %q, want bkrt_rotated_refresh", org, session.RefreshToken) + } + } +} + +func TestRefreshedAPITokenForOrgDoesNotReturnExpiredNonRefreshableToken(t *testing.T) { + keyring.MockForTesting() + + conf := New(afero.NewMemMapFs(), nil) + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + AccessToken: "bkua_expired_access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + if token := conf.RefreshedAPITokenForOrg("test-org"); token != "" { + t.Fatalf("RefreshedAPITokenForOrg() = %q, want empty token", token) + } +} diff --git a/pkg/cmd/factory/factory.go b/pkg/cmd/factory/factory.go index a1666198..ac867ec2 100644 --- a/pkg/cmd/factory/factory.go +++ b/pkg/cmd/factory/factory.go @@ -21,6 +21,7 @@ var userAgent string type Factory struct { Config *config.Config GitRepository *git.Repository + Token string GraphQLClient graphql.Client RestAPIClient *buildkite.Client Version string @@ -35,8 +36,9 @@ type Factory struct { type FactoryOpt func(*factoryConfig) type factoryConfig struct { - debug bool - orgOverride string + debug bool + orgOverride string + withoutAPIClient bool } // WithDebug enables debug output for REST API calls @@ -55,6 +57,14 @@ func WithOrgOverride(org string) FactoryOpt { } } +// WithoutAPIClients skips token lookup and API client construction. +// Use this for commands that only need local config or repository context. +func WithoutAPIClients() FactoryOpt { + return func(c *factoryConfig) { + c.withoutAPIClient = true + } +} + // debugTransport wraps an http.RoundTripper and logs requests/responses with sensitive headers redacted type debugTransport struct { transport http.RoundTripper @@ -148,12 +158,26 @@ func New(opts ...FactoryOpt) (*Factory, error) { } conf := config.New(nil, repo) + if cfg.withoutAPIClient { + return &Factory{ + Config: conf, + GitRepository: repo, + Version: version.Version, + NoPager: conf.PagerDisabled(), + Quiet: conf.Quiet(), + NoInput: conf.NoInput(), + Debug: cfg.debug, + }, nil + } - token := conf.APIToken() + token := "" if cfg.orgOverride != "" { - if t := conf.APITokenForOrg(cfg.orgOverride); t != "" { - token = t + token = conf.RefreshedAPITokenForOrg(cfg.orgOverride) + if token == "" && conf.ShouldFallbackToSelectedOrg(cfg.orgOverride) { + token = conf.RefreshedAPIToken() } + } else { + token = conf.RefreshedAPIToken() } // Build client options @@ -183,6 +207,7 @@ func New(opts ...FactoryOpt) (*Factory, error) { return &Factory{ Config: conf, GitRepository: repo, + Token: token, GraphQLClient: graphql.NewClient(conf.GetGraphQLEndpoint(), graphqlHTTPClient), RestAPIClient: buildkiteClient, Version: version.Version, diff --git a/pkg/cmd/factory/factory_test.go b/pkg/cmd/factory/factory_test.go index 3c9b9102..173cf91f 100644 --- a/pkg/cmd/factory/factory_test.go +++ b/pkg/cmd/factory/factory_test.go @@ -1,11 +1,16 @@ package factory import ( + "encoding/json" "io" "net/http" "net/http/httptest" "strings" "testing" + "time" + + "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" ) func TestRedactHeaders(t *testing.T) { @@ -152,3 +157,163 @@ func TestDebugTransportHandlesNilBody(t *testing.T) { t.Errorf("expected status 200, got %d", resp.StatusCode) } } + +func TestNewWithoutAPIClientsDoesNotRefreshTokens(t *testing.T) { + keyring.MockForTesting() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("BUILDKITE_ORGANIZATION_SLUG", "test-org") + t.Setenv("BUILDKITE_API_TOKEN", "") + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + Error: "invalid_grant", + ErrorDesc: "should not be called", + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + kr := keyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + ClientID: "stored-client", + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_old_refresh", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + f, err := New(WithoutAPIClients()) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + if f.Config == nil { + t.Fatal("expected config to be initialised") + } + if f.RestAPIClient != nil { + t.Fatal("expected RestAPIClient to be nil when API clients are disabled") + } + if requests != 0 { + t.Fatalf("refresh requests = %d, want 0", requests) + } +} + +func TestNewWithOrgOverrideRefreshesOnlyOverrideOrg(t *testing.T) { + keyring.MockForTesting() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("BUILDKITE_ORGANIZATION_SLUG", "current-org") + t.Setenv("BUILDKITE_API_TOKEN", "") + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + currentRequests := 0 + currentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + currentRequests++ + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + Error: "invalid_grant", + ErrorDesc: "should not be called", + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer currentServer.Close() + + overrideRequests := 0 + overrideServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + overrideRequests++ + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "bkua_override_access", + RefreshToken: "bkrt_override_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresIn: 3600, + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer overrideServer.Close() + + kr := keyring.New() + if err := kr.SetSession("current-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: currentServer.URL, + ClientID: "current-client", + AccessToken: "bkua_current_access", + RefreshToken: "bkrt_current_refresh", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession current-org returned error: %v", err) + } + if err := kr.SetSession("override-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: overrideServer.URL, + ClientID: "override-client", + AccessToken: "bkua_old_override_access", + RefreshToken: "bkrt_old_override_refresh", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession override-org returned error: %v", err) + } + + f, err := New(WithOrgOverride("override-org")) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + if f.Token != "bkua_override_access" { + t.Fatalf("Factory token = %q, want bkua_override_access", f.Token) + } + if currentRequests != 0 { + t.Fatalf("current-org refresh requests = %d, want 0", currentRequests) + } + if overrideRequests != 1 { + t.Fatalf("override-org refresh requests = %d, want 1", overrideRequests) + } +} + +func TestNewWithOrgOverrideDoesNotFallbackWhenOverrideOrgCredentialsAreExpired(t *testing.T) { + keyring.MockForTesting() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("BUILDKITE_ORGANIZATION_SLUG", "current-org") + t.Setenv("BUILDKITE_API_TOKEN", "") + + kr := keyring.New() + if err := kr.SetSession("current-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + AccessToken: "bkua_current_access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SetSession current-org returned error: %v", err) + } + if err := kr.SetSession("override-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + AccessToken: "bkua_expired_override_access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession override-org returned error: %v", err) + } + + f, err := New(WithOrgOverride("override-org")) + if err != nil { + t.Fatalf("New returned error: %v", err) + } + if f.Token != "" { + t.Fatalf("Factory token = %q, want empty token for expired override credentials", f.Token) + } +} diff --git a/pkg/cmd/validation/config.go b/pkg/cmd/validation/config.go index 36154de8..65a76791 100644 --- a/pkg/cmd/validation/config.go +++ b/pkg/cmd/validation/config.go @@ -26,31 +26,31 @@ func ValidateConfigurationForOrg(conf *config.Config, commandPath, org string) e } func validateConfiguration(conf *config.Config, commandPath, orgOverride string) error { + // Skip token checks for configure commands before consulting auth state. + if strings.HasPrefix(commandPath, "configure") { + return nil + } + + // Skip token checks for commands that don't need it. + for _, exemptCmd := range CommandsNotRequiringToken { + if strings.HasSuffix(commandPath, exemptCmd) { + return nil + } + } + org := conf.OrganizationSlug() token := conf.APIToken() if orgOverride != "" { org = orgOverride - if t := conf.APITokenForOrg(org); t != "" { - token = t + token = conf.APITokenForOrg(org) + if token == "" && conf.ShouldFallbackToSelectedOrg(org) { + token = conf.APIToken() } } missingToken := token == "" missingOrg := org == "" - // Skip token check for all configure commands - if strings.HasPrefix(commandPath, "configure") { - return nil - } - - // Skip token check for commands that don't need it - for _, exemptCmd := range CommandsNotRequiringToken { - // Check if the command path ends with the exempt command pattern - if strings.HasSuffix(commandPath, exemptCmd) { - return nil // Skip validation for exempt commands - } - } - switch { case missingToken && missingOrg: return errors.New("you are not authenticated. Run bk auth login to authenticate, or run bk use to select a configured organization") diff --git a/pkg/cmd/validation/config_test.go b/pkg/cmd/validation/config_test.go index 59969754..82781d3b 100644 --- a/pkg/cmd/validation/config_test.go +++ b/pkg/cmd/validation/config_test.go @@ -1,10 +1,15 @@ package validation import ( + "encoding/json" + "net/http" + "net/http/httptest" "testing" + "time" "github.com/buildkite/cli/v3/internal/config" bkKeyring "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" ) func TestValidateConfiguration_ExemptCommands(t *testing.T) { @@ -25,6 +30,49 @@ func TestValidateConfiguration_ExemptCommands(t *testing.T) { } } +func TestValidateConfiguration_ExemptCommandsDoNotRefreshTokens(t *testing.T) { + bkKeyring.MockForTesting() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("BUILDKITE_API_TOKEN", "") + t.Setenv("BUILDKITE_ORGANIZATION_SLUG", "test-org") + t.Setenv(oauth.EnvClientID, "env-client") + t.Setenv(oauth.LegacyEnvClientID, "") + + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oauth.TokenResponse{ + Error: "invalid_grant", + ErrorDesc: "should not be called", + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + kr := bkKeyring.New() + if err := kr.SetSession("test-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: server.URL, + ClientID: "stored-client", + AccessToken: "bkua_expired_access", + RefreshToken: "bkrt_old_refresh", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession returned error: %v", err) + } + + conf := config.New(nil, nil) + if err := ValidateConfiguration(conf, "configure default"); err != nil { + t.Fatalf("expected no error for configure default, got %v", err) + } + if requests != 0 { + t.Fatalf("refresh requests = %d, want 0", requests) + } +} + func TestValidateConfiguration_MissingValues(t *testing.T) { t.Run("missing token and org", func(t *testing.T) { t.Setenv("BUILDKITE_API_TOKEN", "") @@ -54,6 +102,54 @@ func TestValidateConfiguration_MissingValues(t *testing.T) { }) } +func TestValidateConfigurationForOrgRequiresCredentialsForOverrideOrg(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("BUILDKITE_API_TOKEN", "") + t.Setenv("BUILDKITE_ORGANIZATION_SLUG", "current-org") + + conf := newTestConfig(t) + if err := conf.EnsureOrganization("current-org"); err != nil { + t.Fatalf("EnsureOrganization returned error: %v", err) + } + + if err := ValidateConfigurationForOrg(conf, "pipeline list", "override-org"); err == nil { + t.Fatal("expected missing credentials error for override org") + } +} + +func TestValidateConfigurationForOrgDoesNotFallbackWhenOverrideOrgCredentialsAreExpired(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("BUILDKITE_API_TOKEN", "") + t.Setenv("BUILDKITE_ORGANIZATION_SLUG", "current-org") + + conf := newTestConfig(t) + kr := bkKeyring.New() + if err := kr.SetSession("current-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + AccessToken: "bkua_current_access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), + }); err != nil { + t.Fatalf("SetSession current-org returned error: %v", err) + } + if err := kr.SetSession("override-org", &oauth.Session{ + Version: oauth.SessionVersion, + Host: "buildkite.localhost", + AccessToken: "bkua_expired_override_access", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Minute), + }); err != nil { + t.Fatalf("SetSession override-org returned error: %v", err) + } + + if err := ValidateConfigurationForOrg(conf, "pipeline list", "override-org"); err == nil { + t.Fatal("expected missing credentials error for override org with expired stored credentials") + } +} + func newTestConfig(t *testing.T) *config.Config { t.Helper() t.Setenv("HOME", t.TempDir()) diff --git a/pkg/keyring/keyring.go b/pkg/keyring/keyring.go index 8ea8db03..0293e239 100644 --- a/pkg/keyring/keyring.go +++ b/pkg/keyring/keyring.go @@ -1,11 +1,12 @@ // Package keyring provides secure credential storage using the OS keychain. -// It falls back to file-based storage when the keychain is unavailable (e.g., in CI environments). package keyring import ( + "encoding/json" "os" "sync" + "github.com/buildkite/cli/v3/pkg/oauth" "github.com/zalando/go-keyring" ) @@ -18,7 +19,7 @@ var ( keyringAvailable bool ) -// Keyring provides secure credential storage with fallback support +// Keyring provides secure credential storage. type Keyring struct { useKeyring bool } @@ -34,17 +35,59 @@ func New() *Keyring { // Set stores a token for the given organization func (k *Keyring) Set(org, token string) error { if !k.useKeyring { - return nil // Fallback handled by config file + return nil } return keyring.Set(serviceName, org, token) } // Get retrieves a token for the given organization func (k *Keyring) Get(org string) (string, error) { + session, err := k.GetSession(org) + if err != nil { + return "", err + } + return session.AccessToken, nil +} + +// SetSession stores an OAuth session for the given organization. +func (k *Keyring) SetSession(org string, session *oauth.Session) error { + if !k.useKeyring { + return nil + } + + encoded, err := json.Marshal(session) + if err != nil { + return err + } + + return keyring.Set(serviceName, org, string(encoded)) +} + +// GetSession retrieves an OAuth session for the given organization. +// Legacy plaintext tokens are returned as access-token-only sessions. +func (k *Keyring) GetSession(org string) (*oauth.Session, error) { if !k.useKeyring { - return "", keyring.ErrNotFound + return nil, keyring.ErrNotFound + } + + stored, err := keyring.Get(serviceName, org) + if err != nil { + return nil, err } - return keyring.Get(serviceName, org) + + var session oauth.Session + if err := json.Unmarshal([]byte(stored), &session); err == nil && session.AccessToken != "" { + if session.Version == 0 { + session.Version = oauth.SessionVersion + } + return &session, nil + } + + return &oauth.Session{ + Version: oauth.SessionVersion, + AccessToken: stored, + TokenType: "Bearer", + }, nil } // Delete removes a token for the given organization diff --git a/pkg/oauth/oauth.go b/pkg/oauth/oauth.go index 73c9c99a..b57688fe 100644 --- a/pkg/oauth/oauth.go +++ b/pkg/oauth/oauth.go @@ -7,6 +7,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "net" @@ -19,6 +20,10 @@ import ( const ( DefaultHost = "buildkite.com" + EnvClientID = "BUILDKITE_CLI_CLIENT_ID" + // LegacyEnvClientID preserves the older runtime override while builds migrate. + LegacyEnvClientID = "BUILDKITE_OAUTH_CLIENT_ID" + SessionVersion = 1 ) // AllScopes is the complete set of Buildkite API token scopes. When no --scopes @@ -27,6 +32,8 @@ const ( // // Reference: https://buildkite.com/docs/apis/managing-api-tokens var AllScopes = []string{ + "graphql", + // CI/CD "read_agents", "read_artifacts", @@ -34,6 +41,7 @@ var AllScopes = []string{ "read_builds", "read_clusters", "read_job_env", + "read_notification_services", "read_pipeline_templates", "read_pipelines", "read_rules", @@ -42,6 +50,7 @@ var AllScopes = []string{ "write_build_logs", "write_builds", "write_clusters", + "write_notification_services", "write_pipeline_templates", "write_pipelines", "write_rules", @@ -155,11 +164,82 @@ type CallbackResult struct { // TokenResponse holds the token exchange response type TokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - Scope string `json:"scope"` - Error string `json:"error,omitempty"` - ErrorDesc string `json:"error_description,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + ExpiresIn int `json:"expires_in,omitempty"` + Error string `json:"error,omitempty"` + ErrorDesc string `json:"error_description,omitempty"` +} + +// Session holds a refreshable OAuth session persisted in the keychain. +type Session struct { + Version int `json:"version"` + Host string `json:"host,omitempty"` + ClientID string `json:"client_id,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + Scope string `json:"scope,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` +} + +// Session converts a token response into a persisted session. +func (r *TokenResponse) Session(host, clientID string, now time.Time) *Session { + session := &Session{ + Version: SessionVersion, + Host: host, + ClientID: clientID, + AccessToken: r.AccessToken, + RefreshToken: r.RefreshToken, + TokenType: r.TokenType, + Scope: r.Scope, + } + if r.ExpiresIn > 0 { + session.ExpiresAt = now.UTC().Add(time.Duration(r.ExpiresIn) * time.Second) + } + return session +} + +// Update returns a new session with the latest token response applied. +func (s *Session) Update(tokenResp *TokenResponse, now time.Time) *Session { + refreshToken := s.RefreshToken + if tokenResp.RefreshToken != "" { + refreshToken = tokenResp.RefreshToken + } + + updated := &Session{ + Version: SessionVersion, + Host: s.Host, + ClientID: s.ClientID, + AccessToken: tokenResp.AccessToken, + RefreshToken: refreshToken, + TokenType: firstNonEmpty(tokenResp.TokenType, s.TokenType), + Scope: firstNonEmpty(tokenResp.Scope, s.Scope), + } + if tokenResp.ExpiresIn > 0 { + updated.ExpiresAt = now.UTC().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + return updated +} + +// CanRefresh reports whether the session has enough information to perform a refresh grant. +func (s *Session) CanRefresh() bool { + return s != nil && s.RefreshToken != "" +} + +// NeedsRefresh reports whether the access token should be refreshed before use. +func (s *Session) NeedsRefresh(now time.Time) bool { + if !s.CanRefresh() { + return false + } + if s.ExpiresAt.IsZero() { + return true + } + + // Refresh slightly early so commands don't race against expiry mid-request. + return !now.Before(s.ExpiresAt.Add(-time.Minute)) } // Flow manages an OAuth authentication flow @@ -172,17 +252,13 @@ type Flow struct { // NewFlow creates a new OAuth flow func NewFlow(cfg *Config) (*Flow, error) { - if cfg.Host == "" { - // Allow override via environment variable for local development - if envHost := os.Getenv("BUILDKITE_HOST"); envHost != "" { - cfg.Host = envHost - } else { - cfg.Host = DefaultHost - } - } - if cfg.ClientID == "" { - cfg.ClientID = DefaultClientID + cfg.Host = resolveHost(cfg.Host) + + clientID, err := resolveClientID(cfg.ClientID) + if err != nil { + return nil, err } + cfg.ClientID = clientID if cfg.Scopes == "" { cfg.Scopes = strings.Join(AllScopes, " ") } @@ -236,7 +312,7 @@ func (f *Flow) AuthorizationURL() string { params.Set("organization_uuid", f.config.OrgUUID) } - return fmt.Sprintf("https://%s/oauth/authorize?%s", f.config.Host, params.Encode()) + return fmt.Sprintf("%s/oauth/authorize?%s", baseURL(f.config.Host), params.Encode()) } // WaitForCallback waits for the OAuth callback and returns the authorization code @@ -320,8 +396,6 @@ func (f *Flow) WaitForCallback(ctx context.Context) (*CallbackResult, error) { // ExchangeCode exchanges the authorization code for an access token func (f *Flow) ExchangeCode(ctx context.Context, code string) (*TokenResponse, error) { - tokenURL := fmt.Sprintf("https://%s/oauth/token", f.config.Host) - data := url.Values{ "grant_type": {"authorization_code"}, "code": {code}, @@ -330,6 +404,36 @@ func (f *Flow) ExchangeCode(ctx context.Context, code string) (*TokenResponse, e "code_verifier": {f.codeVerifier}, } + return exchangeToken(ctx, f.config, data) +} + +// RefreshAccessToken exchanges a refresh token for a new access token. +func RefreshAccessToken(ctx context.Context, cfg *Config, refreshToken, scope string) (*TokenResponse, error) { + resolvedHost := resolveHost(cfg.Host) + clientID, err := resolveClientID(cfg.ClientID) + if err != nil { + return nil, err + } + + data := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {clientID}, + } + if scope != "" { + data.Set("scope", scope) + } + + requestCfg := *cfg + requestCfg.Host = resolvedHost + requestCfg.ClientID = clientID + + return exchangeToken(ctx, &requestCfg, data) +} + +func exchangeToken(ctx context.Context, cfg *Config, data url.Values) (*TokenResponse, error) { + tokenURL := fmt.Sprintf("%s/oauth/token", baseURL(cfg.Host)) + req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, err @@ -366,6 +470,53 @@ func (f *Flow) ExchangeCode(ctx context.Context, code string) (*TokenResponse, e return &tokenResp, nil } +func resolveHost(host string) string { + if host != "" { + return host + } + + if envHost := os.Getenv("BUILDKITE_HOST"); envHost != "" { + return envHost + } + + return DefaultHost +} + +func resolveClientID(explicit string) (string, error) { + if explicit != "" { + return explicit, nil + } + if envClientID := os.Getenv(EnvClientID); envClientID != "" { + return envClientID, nil + } + if legacyEnvClientID := os.Getenv(LegacyEnvClientID); legacyEnvClientID != "" { + return legacyEnvClientID, nil + } + if DefaultClientID != "" { + return DefaultClientID, nil + } + + return "", errors.New("oauth client ID is not configured") +} + +func baseURL(host string) string { + if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") { + return strings.TrimRight(host, "/") + } + + return "https://" + host +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + + return "" +} + // Close cleans up the OAuth flow resources func (f *Flow) Close() error { if f.listener != nil { diff --git a/pkg/oauth/oauth_test.go b/pkg/oauth/oauth_test.go index bd2b5a54..ef0574cd 100644 --- a/pkg/oauth/oauth_test.go +++ b/pkg/oauth/oauth_test.go @@ -3,6 +3,7 @@ package oauth import ( "strings" "testing" + "time" ) func TestResolveScopes(t *testing.T) { @@ -76,7 +77,6 @@ func TestNewFlow_DefaultsToAllScopes(t *testing.T) { t.Fatal("expected scope parameter in URL") } - // Verify all scopes are present for _, s := range AllScopes { if !strings.Contains(authURL, s) { t.Errorf("expected scope %q in URL, got: %s", s, authURL) @@ -101,8 +101,178 @@ func TestAuthorizationURL_UsesProvidedScopes(t *testing.T) { if !strings.Contains(authURL, "scope=") { t.Fatal("expected scope parameter in URL") } - // Should use the provided scopes, not all scopes if strings.Contains(authURL, "write_builds") { t.Errorf("expected only provided scopes, but found write_builds in URL: %s", authURL) } } + +func TestNewFlowUsesExplicitClientID(t *testing.T) { + t.Setenv(EnvClientID, "env-client") + t.Setenv(LegacyEnvClientID, "legacy-client") + + flow, err := NewFlow(&Config{ + ClientID: "explicit-client", + CallbackURL: "http://127.0.0.1/callback", + }) + if err != nil { + t.Fatalf("NewFlow returned error: %v", err) + } + + if flow.config.ClientID != "explicit-client" { + t.Fatalf("expected explicit client ID, got %q", flow.config.ClientID) + } +} + +func TestNewFlowUsesRuntimeClientIDOverride(t *testing.T) { + originalDefault := DefaultClientID + DefaultClientID = "" + t.Cleanup(func() { + DefaultClientID = originalDefault + }) + + t.Setenv(EnvClientID, "env-client") + + flow, err := NewFlow(&Config{ + CallbackURL: "http://127.0.0.1/callback", + }) + if err != nil { + t.Fatalf("NewFlow returned error: %v", err) + } + + if flow.config.ClientID != "env-client" { + t.Fatalf("expected runtime client ID override, got %q", flow.config.ClientID) + } +} + +func TestNewFlowUsesLegacyRuntimeClientIDFallback(t *testing.T) { + originalDefault := DefaultClientID + DefaultClientID = "" + t.Cleanup(func() { + DefaultClientID = originalDefault + }) + + t.Setenv(EnvClientID, "") + t.Setenv(LegacyEnvClientID, "legacy-client") + + flow, err := NewFlow(&Config{ + CallbackURL: "http://127.0.0.1/callback", + }) + if err != nil { + t.Fatalf("NewFlow returned error: %v", err) + } + + if flow.config.ClientID != "legacy-client" { + t.Fatalf("expected legacy runtime client ID fallback, got %q", flow.config.ClientID) + } +} + +func TestNewFlowUsesLinkerInjectedDefaultClientID(t *testing.T) { + originalDefault := DefaultClientID + DefaultClientID = "linked-client" + t.Cleanup(func() { + DefaultClientID = originalDefault + }) + + t.Setenv(EnvClientID, "") + t.Setenv(LegacyEnvClientID, "") + + flow, err := NewFlow(&Config{ + CallbackURL: "http://127.0.0.1/callback", + }) + if err != nil { + t.Fatalf("NewFlow returned error: %v", err) + } + + if flow.config.ClientID != "linked-client" { + t.Fatalf("expected linker-injected client ID, got %q", flow.config.ClientID) + } +} + +func TestNewFlowErrorsWhenNoClientIDIsConfigured(t *testing.T) { + originalDefault := DefaultClientID + DefaultClientID = "" + t.Cleanup(func() { + DefaultClientID = originalDefault + }) + + t.Setenv(EnvClientID, "") + t.Setenv(LegacyEnvClientID, "") + + if _, err := NewFlow(&Config{ + CallbackURL: "http://127.0.0.1/callback", + }); err == nil { + t.Fatal("expected error when client ID is not configured") + } +} + +func TestTokenResponseSessionPersistsClientID(t *testing.T) { + now := time.Date(2026, time.March, 28, 12, 0, 0, 0, time.UTC) + + session := (&TokenResponse{ + AccessToken: "bkua_access", + RefreshToken: "bkrt_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresIn: 3600, + }).Session("buildkite.localhost", "buildkite-cli", now) + + if session.ClientID != "buildkite-cli" { + t.Fatalf("ClientID = %q, want buildkite-cli", session.ClientID) + } + if got, want := session.ExpiresAt, now.Add(time.Hour); !got.Equal(want) { + t.Fatalf("ExpiresAt = %s, want %s", got, want) + } +} + +func TestSessionUpdateKeepsSessionRefreshableWhenExpiresInIsOmitted(t *testing.T) { + now := time.Date(2026, time.March, 28, 12, 0, 0, 0, time.UTC) + originalExpiry := now.Add(-2 * time.Hour) + + updated := (&Session{ + Version: SessionVersion, + Host: "buildkite.localhost", + ClientID: "buildkite-cli", + AccessToken: "bkua_old_access", + RefreshToken: "bkrt_old_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresAt: originalExpiry, + }).Update(&TokenResponse{ + AccessToken: "bkua_new_access", + }, now) + + if updated.ClientID != "buildkite-cli" { + t.Fatalf("ClientID = %q, want buildkite-cli", updated.ClientID) + } + if updated.RefreshToken != "bkrt_old_refresh" { + t.Fatalf("RefreshToken = %q, want bkrt_old_refresh", updated.RefreshToken) + } + if updated.Scope != "read_user" { + t.Fatalf("Scope = %q, want read_user", updated.Scope) + } + if updated.TokenType != "Bearer" { + t.Fatalf("TokenType = %q, want Bearer", updated.TokenType) + } + if !updated.ExpiresAt.IsZero() { + t.Fatalf("ExpiresAt = %s, want zero value when expires_in is omitted", updated.ExpiresAt) + } + if !updated.CanRefresh() { + t.Fatal("expected updated session to remain refreshable") + } + if !updated.NeedsRefresh(now) { + t.Fatal("expected updated session to refresh again when expiry is unknown") + } +} + +func TestSessionWithoutExpiryStillRefreshes(t *testing.T) { + session := &Session{ + RefreshToken: "bkrt_refresh", + } + + if !session.CanRefresh() { + t.Fatal("expected session with refresh token to be refreshable") + } + if !session.NeedsRefresh(time.Now()) { + t.Fatal("expected session without expiry to refresh before use") + } +} diff --git a/pkg/oauth/refresh_test.go b/pkg/oauth/refresh_test.go new file mode 100644 index 00000000..f9b047f9 --- /dev/null +++ b/pkg/oauth/refresh_test.go @@ -0,0 +1,69 @@ +package oauth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRefreshAccessToken(t *testing.T) { + originalDefault := DefaultClientID + DefaultClientID = "" + t.Cleanup(func() { + DefaultClientID = originalDefault + }) + + t.Setenv(EnvClientID, "env-client") + t.Setenv(LegacyEnvClientID, "") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + t.Fatalf("unexpected token path %q", r.URL.Path) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm returned error: %v", err) + } + + if got := r.Form.Get("grant_type"); got != "refresh_token" { + t.Fatalf("grant_type = %q, want refresh_token", got) + } + if got := r.Form.Get("refresh_token"); got != "bkrt_old_refresh" { + t.Fatalf("refresh_token = %q, want bkrt_old_refresh", got) + } + if got := r.Form.Get("client_id"); got != "env-client" { + t.Fatalf("client_id = %q, want env-client", got) + } + if got := r.Form.Get("scope"); got != "read_user" { + t.Fatalf("scope = %q, want read_user", got) + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "bkua_new_access", + RefreshToken: "bkrt_new_refresh", + TokenType: "Bearer", + Scope: "read_user", + ExpiresIn: 3600, + }); err != nil { + t.Fatalf("Encode returned error: %v", err) + } + })) + defer server.Close() + + tokenResp, err := RefreshAccessToken(context.Background(), &Config{Host: server.URL}, "bkrt_old_refresh", "read_user") + if err != nil { + t.Fatalf("RefreshAccessToken returned error: %v", err) + } + + if tokenResp.AccessToken != "bkua_new_access" { + t.Fatalf("AccessToken = %q, want bkua_new_access", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "bkrt_new_refresh" { + t.Fatalf("RefreshToken = %q, want bkrt_new_refresh", tokenResp.RefreshToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Fatalf("ExpiresIn = %d, want 3600", tokenResp.ExpiresIn) + } +}