diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 4fa15084..4f679354 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "time" "github.com/alecthomas/kong" @@ -155,7 +156,10 @@ func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { } // Resolve org from the API using the new token - client, err := buildkite.NewOpts(buildkite.WithTokenAuth(tokenResp.AccessToken)) + client, err := buildkite.NewOpts( + buildkite.WithTokenAuth(tokenResp.AccessToken), + buildkite.WithBaseURL(f.Config.RESTAPIEndpoint()), + ) if err != nil { return fmt.Errorf("failed to create API client: %w", err) } @@ -174,8 +178,36 @@ func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { return err } + // Store refresh token if the server issued one + if tokenResp.RefreshToken != "" { + kr := keyring.New() + if kr.IsAvailable() { + if err := kr.SetRefreshToken(org.Slug, tokenResp.RefreshToken); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to store refresh token: %v\n", err) + } + } + } + fmt.Printf("\n✅ Successfully authenticated with organization %q\n", org.Slug) fmt.Printf(" Scopes: %s\n", tokenResp.Scope) + if tokenResp.RefreshToken != "" { + fmt.Printf(" Token expires in: %s (will refresh automatically)\n", formatDuration(tokenResp.ExpiresIn)) + } return nil } + +func formatDuration(seconds int) string { + if seconds <= 0 { + return "unknown" + } + d := time.Duration(seconds) * time.Second + if d >= time.Hour { + hours := int(d.Hours()) + if hours == 1 { + return "1 hour" + } + return fmt.Sprintf("%d hours", hours) + } + return fmt.Sprintf("%d minutes", int(d.Minutes())) +} diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go index b05b47fc..8e389c9a 100644 --- a/cmd/auth/logout.go +++ b/cmd/auth/logout.go @@ -36,6 +36,7 @@ func (c *LogoutCmd) logoutAll(f *factory.Factory) error { if err := kr.Delete(org); err != nil { fmt.Printf("Warning: could not remove token from keychain for %q: %v\n", org, err) } + _ = kr.DeleteRefreshToken(org) } } @@ -64,6 +65,7 @@ func (c *LogoutCmd) logoutOrg(f *factory.Factory) error { } else { fmt.Println("Token removed from system keychain.") } + _ = kr.DeleteRefreshToken(org) } fmt.Printf("Logged out of organization %q\n", org) diff --git a/internal/config/config.go b/internal/config/config.go index 6af0b9d2..41d011f8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -158,6 +158,25 @@ func (conf *Config) APITokenForOrg(org string) string { return "" } +// RefreshTokenForOrg gets the refresh token for a specific organization from the keyring. +func (conf *Config) RefreshTokenForOrg(org string) string { + if org == "" { + return "" + } + kr := keyring.New() + if kr.IsAvailable() { + if token, err := kr.GetRefreshToken(org); err == nil && token != "" { + return token + } + } + return "" +} + +// RefreshToken gets the refresh token for the currently selected organization. +func (conf *Config) RefreshToken() string { + return conf.RefreshTokenForOrg(conf.OrganizationSlug()) +} + // HasStoredTokenForOrg reports whether a token is stored for org in keyring // or config files, excluding environment variable overrides. func (conf *Config) HasStoredTokenForOrg(org string) bool { diff --git a/internal/http/refresh_transport.go b/internal/http/refresh_transport.go new file mode 100644 index 00000000..34f67ca8 --- /dev/null +++ b/internal/http/refresh_transport.go @@ -0,0 +1,225 @@ +package http + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + + "github.com/buildkite/cli/v3/pkg/keyring" + "github.com/buildkite/cli/v3/pkg/oauth" +) + +// TokenSource provides thread-safe access to the current access token. +// It is shared between auth-injection points (REST, GraphQL) and +// RefreshTransport so that a refreshed token is immediately visible +// to all subsequent requests. +type TokenSource struct { + mu sync.RWMutex + token string +} + +// NewTokenSource creates a TokenSource initialised with the given token. +func NewTokenSource(token string) *TokenSource { + return &TokenSource{token: token} +} + +// Token returns the current access token. +func (ts *TokenSource) Token() string { + ts.mu.RLock() + defer ts.mu.RUnlock() + return ts.token +} + +// SetToken updates the current access token. +func (ts *TokenSource) SetToken(token string) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.token = token +} + +// AuthTransport injects the Authorization header from a TokenSource +// on every outgoing request. It should wrap the base transport so that +// RefreshTransport (which sits outside it) can override the header on +// retries. +type AuthTransport struct { + Base http.RoundTripper + TokenSource *TokenSource + UserAgent string +} + +func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + token := t.TokenSource.Token() + if token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + } + if t.UserAgent != "" { + req.Header.Set("User-Agent", t.UserAgent) + } + base := t.Base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} + +// RefreshTransport wraps an http.RoundTripper to automatically refresh +// expired OAuth access tokens using a stored refresh token. +// +// On a 401 response it: +// 1. Acquires a mutex to serialise concurrent refreshes. +// 2. Checks whether the token has already been refreshed by another +// goroutine (compare-after-lock). +// 3. If not, exchanges the refresh token for new tokens. +// 4. Persists the new tokens and updates the shared TokenSource. +// 5. Retries the original request with the new token. +type RefreshTransport struct { + Base http.RoundTripper + Org string + Keyring *keyring.Keyring + TokenSource *TokenSource + + mu sync.Mutex +} + +func (t *RefreshTransport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *RefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Buffer the request body so it can be replayed on retry. + // http.NewRequest sets GetBody for standard body types, but + // custom readers (e.g. from GraphQL clients) may not. + bufferRequestBody(req) + + resp, err := t.base().RoundTrip(req) + if err != nil { + return resp, err + } + + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + + // Only attempt refresh if we have a refresh token + refreshToken, rtErr := t.Keyring.GetRefreshToken(t.Org) + if rtErr != nil || refreshToken == "" { + return resp, nil + } + + // Extract the token that was used for the failed request so we can + // detect whether another goroutine already refreshed it. + failedToken := extractBearerToken(req.Header.Get("Authorization")) + + // Attempt token refresh (serialised to prevent concurrent refreshes) + t.mu.Lock() + newToken, refreshErr := t.doRefresh(req.Context(), failedToken) + t.mu.Unlock() + + if refreshErr != nil { + fmt.Fprintf(os.Stderr, "Warning: token refresh failed: %v\n", refreshErr) + return resp, nil + } + + // Drain and close the original 401 response body + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + // Clone the request with the new token and retry + retryReq := req.Clone(req.Context()) + retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", newToken)) + + // Re-create the body for the retry + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("failed to get request body for retry: %w", err) + } + retryReq.Body = body + } + + return t.base().RoundTrip(retryReq) +} + +func (t *RefreshTransport) doRefresh(ctx context.Context, failedToken string) (string, error) { + // Compare-after-lock: if the current token differs from the one that + // failed, another goroutine already refreshed successfully. Skip the + // refresh and use the new token. + currentToken := t.TokenSource.Token() + if currentToken != "" && currentToken != failedToken { + return currentToken, nil + } + + // Re-read the refresh token under the lock — it may have been rotated + // by a concurrent refresh. + refreshToken, err := t.Keyring.GetRefreshToken(t.Org) + if err != nil || refreshToken == "" { + return "", fmt.Errorf("no refresh token available") + } + + tokenResp, err := oauth.RefreshAccessToken(ctx, "", "", refreshToken) + if err != nil { + // Only clear the stored refresh token on explicit grant errors + // (invalid/expired/revoked). Transient failures (network, 5xx) + // should not destroy the user's session. + if isTerminalRefreshError(err) { + _ = t.Keyring.DeleteRefreshToken(t.Org) + } + return "", err + } + + // Persist the new access token + if err := t.Keyring.Set(t.Org, tokenResp.AccessToken); err != nil { + return "", fmt.Errorf("failed to store refreshed access token: %w", err) + } + t.TokenSource.SetToken(tokenResp.AccessToken) + + // Rotate the refresh token if a new one was issued + if tokenResp.RefreshToken != "" { + if err := t.Keyring.SetRefreshToken(t.Org, tokenResp.RefreshToken); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to store rotated refresh token: %v\n", err) + } + } + + return tokenResp.AccessToken, nil +} + +// isTerminalRefreshError returns true for OAuth errors that indicate the +// refresh token is permanently invalid and should be cleared. +func isTerminalRefreshError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "invalid_grant") || + strings.Contains(msg, "unauthorized_client") || + strings.Contains(msg, "invalid_client") +} + +// extractBearerToken extracts the token value from a "Bearer " header. +func extractBearerToken(header string) string { + if strings.HasPrefix(header, "Bearer ") { + return header[len("Bearer "):] + } + return header +} + +// bufferRequestBody ensures the request body can be replayed for retries. +// If the body is nil or already replayable (GetBody is set), this is a no-op. +func bufferRequestBody(req *http.Request) { + if req.Body == nil || req.GetBody != nil { + return + } + bodyBytes, err := io.ReadAll(req.Body) + _ = req.Body.Close() + if err != nil { + return + } + req.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader(string(bodyBytes))), nil + } +} diff --git a/internal/http/refresh_transport_test.go b/internal/http/refresh_transport_test.go new file mode 100644 index 00000000..db1b9be0 --- /dev/null +++ b/internal/http/refresh_transport_test.go @@ -0,0 +1,356 @@ +package http + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/buildkite/cli/v3/pkg/keyring" +) + +func TestRefreshTransport_PassesThroughNon401(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + keyring.MockForTesting() + defer keyring.ResetForTesting() + + kr := keyring.New() + _ = kr.Set("test-org", "old-token") + _ = kr.SetRefreshToken("test-org", "refresh-token") + + ts := NewTokenSource("old-token") + + transport := &RefreshTransport{ + Base: http.DefaultTransport, + Org: "test-org", + Keyring: kr, + TokenSource: ts, + } + + req, _ := http.NewRequest("GET", server.URL+"/test", nil) + req.Header.Set("Authorization", "Bearer old-token") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } +} + +func TestRefreshTransport_NoRefreshToken_PassesThrough401(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"message":"unauthorized"}`)) + })) + defer server.Close() + + keyring.MockForTesting() + defer keyring.ResetForTesting() + + kr := keyring.New() + _ = kr.Set("test-org", "some-token") + // No refresh token set + + ts := NewTokenSource("some-token") + + transport := &RefreshTransport{ + Base: http.DefaultTransport, + Org: "test-org", + Keyring: kr, + TokenSource: ts, + } + + req, _ := http.NewRequest("GET", server.URL+"/test", nil) + req.Header.Set("Authorization", "Bearer some-token") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401 pass-through, got %d", resp.StatusCode) + } +} + +func TestRefreshTransport_CompareAfterLock_SkipsRedundantRefresh(t *testing.T) { + // This test uses t.Setenv so cannot be parallel. + + keyring.MockForTesting() + defer keyring.ResetForTesting() + + kr := keyring.New() + _ = kr.Set("test-org", "already-refreshed-token") + _ = kr.SetRefreshToken("test-org", "refresh-token") + + // TokenSource already has the new token (simulating another goroutine + // having refreshed it). + ts := NewTokenSource("already-refreshed-token") + + var apiCalls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls.Add(1) + auth := r.Header.Get("Authorization") + if auth == "Bearer already-refreshed-token" { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + return + } + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + // Point BUILDKITE_HOST at a dead port so that if doRefresh is + // incorrectly called, it fails fast instead of hitting a real server. + t.Setenv("BUILDKITE_HOST", "127.0.0.1:1") + + transport := &RefreshTransport{ + Base: http.DefaultTransport, + Org: "test-org", + Keyring: kr, + TokenSource: ts, + } + + // Request with a stale token that triggers 401 + req, _ := http.NewRequest("GET", server.URL+"/test", nil) + req.Header.Set("Authorization", "Bearer stale-token") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 after compare-after-lock skip, got %d", resp.StatusCode) + } + // Should have made exactly 2 API calls: the initial 401 + the retry + if got := apiCalls.Load(); got != 2 { + t.Fatalf("expected 2 API calls (initial + retry), got %d", got) + } +} + +func TestRefreshTransport_DoesNotDeleteRefreshTokenOnTransientError(t *testing.T) { + + keyring.MockForTesting() + defer keyring.ResetForTesting() + + kr := keyring.New() + _ = kr.Set("test-org", "old-token") + _ = kr.SetRefreshToken("test-org", "my-refresh-token") + + ts := NewTokenSource("old-token") + + // API server that always returns 401 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + // Set BUILDKITE_HOST to a non-existent host to simulate a network error + // during the refresh attempt + t.Setenv("BUILDKITE_HOST", "127.0.0.1:1") // connection refused + + transport := &RefreshTransport{ + Base: http.DefaultTransport, + Org: "test-org", + Keyring: kr, + TokenSource: ts, + } + + req, _ := http.NewRequest("GET", server.URL+"/test", nil) + req.Header.Set("Authorization", "Bearer old-token") + + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401 pass-through, got %d", resp.StatusCode) + } + + // The refresh token should NOT have been deleted (transient error) + rt, rtErr := kr.GetRefreshToken("test-org") + if rtErr != nil || rt != "my-refresh-token" { + t.Fatalf("expected refresh token to be preserved after transient error, got %q err=%v", rt, rtErr) + } +} + +func TestRefreshTransport_BuffersAndRetriesPostBody(t *testing.T) { + t.Parallel() + + keyring.MockForTesting() + defer keyring.ResetForTesting() + + kr := keyring.New() + _ = kr.Set("test-org", "old-token") + _ = kr.SetRefreshToken("test-org", "refresh-token") + + ts := NewTokenSource("old-token") + + var apiCalls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := apiCalls.Add(1) + if call == 1 { + w.WriteHeader(http.StatusUnauthorized) + return + } + // Verify body was replayed on retry + body, _ := io.ReadAll(r.Body) + _ = body + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + transport := &RefreshTransport{ + Base: http.DefaultTransport, + Org: "test-org", + Keyring: kr, + TokenSource: ts, + } + + // Simulate a POST with a body that doesn't have GetBody set + body := `{"query":"{ viewer { user { name } } }"}` + req, _ := http.NewRequest("POST", server.URL+"/graphql", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer old-token") + req.Header.Set("Content-Type", "application/json") + // Explicitly clear GetBody to simulate a custom reader + req.GetBody = nil + + // doRefresh will fail (no real token server), but we can verify + // that bufferRequestBody was called by checking the request has GetBody. + // Since the refresh will fail, the 401 is returned, but the body + // buffering is the important part to verify. + resp, _ := transport.RoundTrip(req) + _ = resp + + // Verify GetBody was set by bufferRequestBody + if req.GetBody == nil { + t.Fatal("expected GetBody to be set by bufferRequestBody") + } +} + +func TestRefreshTransport_ConcurrentRequestsOnlyRefreshOnce(t *testing.T) { + // This test uses t.Setenv so cannot be parallel. + + keyring.MockForTesting() + defer keyring.ResetForTesting() + + kr := keyring.New() + _ = kr.Set("test-org", "new-token") + _ = kr.SetRefreshToken("test-org", "refresh-token") + + // TokenSource already has the refreshed token (simulating the first + // goroutine having completed the refresh). + ts := NewTokenSource("new-token") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "Bearer stale-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if auth == "Bearer new-token" { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + return + } + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + // Point BUILDKITE_HOST at a dead port so that if doRefresh is + // incorrectly called (bypassing compare-after-lock), it fails. + t.Setenv("BUILDKITE_HOST", "127.0.0.1:1") + + transport := &RefreshTransport{ + Base: http.DefaultTransport, + Org: "test-org", + Keyring: kr, + TokenSource: ts, + } + + // N goroutines hit 401 with "stale-token" concurrently. + // All should use compare-after-lock to skip refresh and retry + // with the already-refreshed "new-token". + var wg sync.WaitGroup + results := make([]int, 5) + + for i := range 5 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req, _ := http.NewRequest("GET", server.URL+"/test", nil) + req.Header.Set("Authorization", "Bearer stale-token") + resp, err := transport.RoundTrip(req) + if err != nil { + results[idx] = -1 + return + } + results[idx] = resp.StatusCode + }(i) + } + + wg.Wait() + + for i, status := range results { + if status != http.StatusOK { + t.Errorf("goroutine %d: expected 200, got %d", i, status) + } + } +} + +func TestTokenSource_ThreadSafe(t *testing.T) { + t.Parallel() + + ts := NewTokenSource("initial") + + var wg sync.WaitGroup + for range 100 { + wg.Add(2) + go func() { + defer wg.Done() + ts.SetToken("updated") + }() + go func() { + defer wg.Done() + _ = ts.Token() + }() + } + wg.Wait() +} + +func TestIsTerminalRefreshError(t *testing.T) { + t.Parallel() + + tests := []struct { + err string + terminal bool + }{ + {"token refresh error: invalid_grant - Invalid refresh token", true}, + {"token refresh error: unauthorized_client - Client not configured", true}, + {"token refresh error: invalid_client - Invalid client", true}, + {"refresh token request failed: dial tcp: connection refused", false}, + {"refresh token request failed: timeout", false}, + {"failed to parse token response: unexpected end of JSON", false}, + } + + for _, tt := range tests { + got := isTerminalRefreshError(errors.New(tt.err)) + if got != tt.terminal { + t.Errorf("isTerminalRefreshError(%q) = %v, want %v", tt.err, got, tt.terminal) + } + } +} diff --git a/pkg/cmd/factory/factory.go b/pkg/cmd/factory/factory.go index a1666198..91324049 100644 --- a/pkg/cmd/factory/factory.go +++ b/pkg/cmd/factory/factory.go @@ -7,11 +7,14 @@ import ( "net/http" "net/http/httputil" "os" + "regexp" "strings" "github.com/Khan/genqlient/graphql" "github.com/buildkite/cli/v3/cmd/version" "github.com/buildkite/cli/v3/internal/config" + bkhttp "github.com/buildkite/cli/v3/internal/http" + "github.com/buildkite/cli/v3/pkg/keyring" buildkite "github.com/buildkite/go-buildkite/v4" git "github.com/go-git/go-git/v5" ) @@ -88,7 +91,7 @@ func (d *debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if dump, err := httputil.DumpRequestOut(reqCopy, true); err == nil { - fmt.Fprintf(os.Stderr, "DEBUG request uri=%s\n%s\n", req.URL, dump) + fmt.Fprintf(os.Stderr, "DEBUG request uri=%s\n%s\n", req.URL, redactBody(string(dump))) } resp, err := d.transport.RoundTrip(req) @@ -97,12 +100,31 @@ func (d *debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if dump, err := httputil.DumpResponse(resp, true); err == nil { - fmt.Fprintf(os.Stderr, "DEBUG response uri=%s\n%s\n", req.URL, dump) + fmt.Fprintf(os.Stderr, "DEBUG response uri=%s\n%s\n", req.URL, redactBody(string(dump))) } return resp, nil } +// sensitiveBodyPatterns matches token values in form-encoded request bodies +// and JSON response bodies that should be redacted in debug output. +var sensitiveBodyPatterns = regexp.MustCompile( + `((?:refresh_token|access_token|code|code_verifier)=)[^&\s]+` + + `|("(?:access_token|refresh_token|code)":\s*")[^"]+("?)`, +) + +// redactBody replaces sensitive token values in HTTP dumps. +func redactBody(dump string) string { + return sensitiveBodyPatterns.ReplaceAllStringFunc(dump, func(match string) string { + // Form-encoded: key=value + if idx := strings.IndexByte(match, '='); idx > 0 && !strings.HasPrefix(match, `"`) { + return match[:idx+1] + "[REDACTED]" + } + // JSON: "key": "value" + return sensitiveBodyPatterns.ReplaceAllString(match, `${1}[REDACTED]${2}`) + }) +} + // redactHeaders replaces sensitive header values with [REDACTED] func redactHeaders(headers http.Header) { for _, header := range sensitiveHeaders { @@ -121,7 +143,6 @@ func redactHeaders(headers http.Header) { type gqlHTTPClient struct { client *http.Client - token string } func init() { @@ -129,8 +150,8 @@ func init() { } func (a *gqlHTTPClient) Do(req *http.Request) (*http.Response, error) { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) - req.Header.Set("User-Agent", userAgent) + // Auth and User-Agent are injected by AuthTransport in the + // shared HTTP transport chain, so we don't set them here. return a.client.Do(req) } @@ -156,21 +177,54 @@ func New(opts ...FactoryOpt) (*Factory, error) { } } - // Build client options + // Build the HTTP transport chain. + // + // The chain is (outermost first): + // RefreshTransport → AuthTransport → debugTransport → DefaultTransport + // + // AuthTransport reads the current token from a shared TokenSource on + // every request, so after a refresh all subsequent requests (REST and + // GraphQL) immediately use the new token — no stale cached values. + var transport = http.DefaultTransport + + if cfg.debug { + transport = &debugTransport{transport: transport} + } + + tokenSource := bkhttp.NewTokenSource(token) + + transport = &bkhttp.AuthTransport{ + Base: transport, + TokenSource: tokenSource, + UserAgent: userAgent, + } + + // Add refresh transport if a refresh token is available for this org + org := conf.OrganizationSlug() + if cfg.orgOverride != "" { + org = cfg.orgOverride + } + + kr := keyring.New() + if refreshToken, err := kr.GetRefreshToken(org); err == nil && refreshToken != "" { + transport = &bkhttp.RefreshTransport{ + Base: transport, + Org: org, + Keyring: kr, + TokenSource: tokenSource, + } + } + + httpClient := &http.Client{Transport: transport} + + // go-buildkite still needs WithTokenAuth to satisfy its constructor + // requirement, but our AuthTransport is the canonical source of the + // Authorization header. clientOpts := []buildkite.ClientOpt{ buildkite.WithBaseURL(conf.RESTAPIEndpoint()), buildkite.WithTokenAuth(token), buildkite.WithUserAgent(userAgent), - } - - // Use our own debug transport with redacted headers instead of go-buildkite's built-in debug - if cfg.debug { - httpClient := &http.Client{ - Transport: &debugTransport{ - transport: http.DefaultTransport, - }, - } - clientOpts = append(clientOpts, buildkite.WithHTTPClient(httpClient)) + buildkite.WithHTTPClient(httpClient), } buildkiteClient, err := buildkite.NewOpts(clientOpts...) @@ -178,7 +232,7 @@ func New(opts ...FactoryOpt) (*Factory, error) { return nil, fmt.Errorf("creating buildkite client: %w", err) } - graphqlHTTPClient := &gqlHTTPClient{client: http.DefaultClient, token: token} + graphqlHTTPClient := &gqlHTTPClient{client: httpClient} return &Factory{ Config: conf, diff --git a/pkg/keyring/keyring.go b/pkg/keyring/keyring.go index a8b9e8a2..e4eff66f 100644 --- a/pkg/keyring/keyring.go +++ b/pkg/keyring/keyring.go @@ -10,7 +10,8 @@ import ( ) const ( - serviceName = "buildkite-cli" + serviceName = "buildkite-cli" + refreshServiceName = "buildkite-cli-refresh" ) var ( @@ -55,6 +56,30 @@ func (k *Keyring) Delete(org string) error { return keyring.Delete(serviceName, org) } +// SetRefreshToken stores a refresh token for the given organization +func (k *Keyring) SetRefreshToken(org, token string) error { + if !k.useKeyring { + return nil + } + return keyring.Set(refreshServiceName, org, token) +} + +// GetRefreshToken retrieves a refresh token for the given organization +func (k *Keyring) GetRefreshToken(org string) (string, error) { + if !k.useKeyring { + return "", keyring.ErrNotFound + } + return keyring.Get(refreshServiceName, org) +} + +// DeleteRefreshToken removes a refresh token for the given organization +func (k *Keyring) DeleteRefreshToken(org string) error { + if !k.useKeyring { + return nil + } + return keyring.Delete(refreshServiceName, org) +} + // IsAvailable returns true if the system keyring is available func (k *Keyring) IsAvailable() bool { return k.useKeyring diff --git a/pkg/oauth/oauth.go b/pkg/oauth/oauth.go index 73c9c99a..a584687a 100644 --- a/pkg/oauth/oauth.go +++ b/pkg/oauth/oauth.go @@ -155,11 +155,13 @@ type CallbackResult struct { // TokenResponse holds the token exchange response type TokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - Scope string `json:"scope"` - Error string `json:"error,omitempty"` - ErrorDesc string `json:"error_description,omitempty"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + Error string `json:"error,omitempty"` + ErrorDesc string `json:"error_description,omitempty"` } // Flow manages an OAuth authentication flow @@ -366,6 +368,63 @@ func (f *Flow) ExchangeCode(ctx context.Context, code string) (*TokenResponse, e return &tokenResp, nil } +// RefreshAccessToken exchanges a refresh token for a new access token and refresh token. +func RefreshAccessToken(ctx context.Context, host, clientID, refreshToken string) (*TokenResponse, error) { + if host == "" { + if envHost := os.Getenv("BUILDKITE_HOST"); envHost != "" { + host = envHost + } else { + host = DefaultHost + } + } + if clientID == "" { + clientID = DefaultClientID + } + + tokenURL := fmt.Sprintf("https://%s/oauth/token", host) + + data := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {clientID}, + } + + req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("refresh token request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var tokenResp TokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + if tokenResp.Error != "" { + return nil, fmt.Errorf("token refresh error: %s - %s", tokenResp.Error, tokenResp.ErrorDesc) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("no access token in refresh response") + } + + return &tokenResp, nil +} + // Close cleans up the OAuth flow resources func (f *Flow) Close() error { if f.listener != nil { diff --git a/pkg/oauth/refresh_test.go b/pkg/oauth/refresh_test.go new file mode 100644 index 00000000..6b589d90 --- /dev/null +++ b/pkg/oauth/refresh_test.go @@ -0,0 +1,97 @@ +package oauth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRefreshAccessToken_Success(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + if r.URL.Path != "/oauth/token" { + t.Errorf("expected /oauth/token, got %s", r.URL.Path) + } + + if err := r.ParseForm(); err != nil { + t.Fatalf("failed to parse form: %v", err) + } + + if got := r.FormValue("grant_type"); got != "refresh_token" { + t.Errorf("expected grant_type=refresh_token, got %s", got) + } + if got := r.FormValue("refresh_token"); got != "bkur_old_refresh_token" { + t.Errorf("expected refresh_token=bkur_old_refresh_token, got %s", got) + } + if got := r.FormValue("client_id"); got != "test-client" { + t.Errorf("expected client_id=test-client, got %s", got) + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "new_access_token", + "token_type": "Bearer", + "scope": "read_user read_organizations", + "refresh_token": "bkur_new_refresh_token", + "expires_in": 3600 + }`)) + })) + defer server.Close() + + // Override the default HTTP client to trust the test server's TLS cert + origTransport := http.DefaultTransport + http.DefaultTransport = server.Client().Transport + defer func() { http.DefaultTransport = origTransport }() + + // Extract host from the test server URL (strip https://) + host := server.URL[len("https://"):] + + resp, err := RefreshAccessToken(context.Background(), host, "test-client", "bkur_old_refresh_token") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resp.AccessToken != "new_access_token" { + t.Errorf("expected access_token=new_access_token, got %s", resp.AccessToken) + } + if resp.RefreshToken != "bkur_new_refresh_token" { + t.Errorf("expected refresh_token=bkur_new_refresh_token, got %s", resp.RefreshToken) + } + if resp.ExpiresIn != 3600 { + t.Errorf("expected expires_in=3600, got %d", resp.ExpiresIn) + } + if resp.Scope != "read_user read_organizations" { + t.Errorf("expected scope=read_user read_organizations, got %s", resp.Scope) + } +} + +func TestRefreshAccessToken_ErrorResponse(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{ + "error": "invalid_grant", + "error_description": "Invalid refresh token" + }`)) + })) + defer server.Close() + + origTransport := http.DefaultTransport + http.DefaultTransport = server.Client().Transport + defer func() { http.DefaultTransport = origTransport }() + + host := server.URL[len("https://"):] + + _, err := RefreshAccessToken(context.Background(), host, "test-client", "bad-token") + if err == nil { + t.Fatal("expected error, got nil") + } + + expected := "token refresh error: invalid_grant - Invalid refresh token" + if err.Error() != expected { + t.Errorf("expected error %q, got %q", expected, err.Error()) + } +}