Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"os"
"time"

"github.com/alecthomas/kong"
Expand Down Expand Up @@ -155,7 +156,10 @@ func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error {
}

// Resolve org from the API using the new token
client, err := buildkite.NewOpts(buildkite.WithTokenAuth(tokenResp.AccessToken))
client, err := buildkite.NewOpts(
buildkite.WithTokenAuth(tokenResp.AccessToken),
buildkite.WithBaseURL(f.Config.RESTAPIEndpoint()),
)
if err != nil {
return fmt.Errorf("failed to create API client: %w", err)
}
Expand All @@ -174,8 +178,36 @@ func (c *LoginCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error {
return err
}

// Store refresh token if the server issued one
if tokenResp.RefreshToken != "" {
kr := keyring.New()
if kr.IsAvailable() {
if err := kr.SetRefreshToken(org.Slug, tokenResp.RefreshToken); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to store refresh token: %v\n", err)
}
}
}

fmt.Printf("\n✅ Successfully authenticated with organization %q\n", org.Slug)
fmt.Printf(" Scopes: %s\n", tokenResp.Scope)
if tokenResp.RefreshToken != "" {
fmt.Printf(" Token expires in: %s (will refresh automatically)\n", formatDuration(tokenResp.ExpiresIn))
}

return nil
}

func formatDuration(seconds int) string {
if seconds <= 0 {
return "unknown"
}
d := time.Duration(seconds) * time.Second
if d >= time.Hour {
hours := int(d.Hours())
if hours == 1 {
return "1 hour"
}
return fmt.Sprintf("%d hours", hours)
}
return fmt.Sprintf("%d minutes", int(d.Minutes()))
}
2 changes: 2 additions & 0 deletions cmd/auth/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (c *LogoutCmd) logoutAll(f *factory.Factory) error {
if err := kr.Delete(org); err != nil {
fmt.Printf("Warning: could not remove token from keychain for %q: %v\n", org, err)
}
_ = kr.DeleteRefreshToken(org)
}
}

Expand Down Expand Up @@ -64,6 +65,7 @@ func (c *LogoutCmd) logoutOrg(f *factory.Factory) error {
} else {
fmt.Println("Token removed from system keychain.")
}
_ = kr.DeleteRefreshToken(org)
}

fmt.Printf("Logged out of organization %q\n", org)
Expand Down
19 changes: 19 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,25 @@ func (conf *Config) APITokenForOrg(org string) string {
return ""
}

// RefreshTokenForOrg gets the refresh token for a specific organization from the keyring.
func (conf *Config) RefreshTokenForOrg(org string) string {
if org == "" {
return ""
}
kr := keyring.New()
if kr.IsAvailable() {
if token, err := kr.GetRefreshToken(org); err == nil && token != "" {
return token
}
}
return ""
}

// RefreshToken gets the refresh token for the currently selected organization.
func (conf *Config) RefreshToken() string {
return conf.RefreshTokenForOrg(conf.OrganizationSlug())
}

// HasStoredTokenForOrg reports whether a token is stored for org in keyring
// or config files, excluding environment variable overrides.
func (conf *Config) HasStoredTokenForOrg(org string) bool {
Expand Down
225 changes: 225 additions & 0 deletions internal/http/refresh_transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package http

import (
"context"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"

"github.com/buildkite/cli/v3/pkg/keyring"
"github.com/buildkite/cli/v3/pkg/oauth"
)

// TokenSource provides thread-safe access to the current access token.
// It is shared between auth-injection points (REST, GraphQL) and
// RefreshTransport so that a refreshed token is immediately visible
// to all subsequent requests.
type TokenSource struct {
mu sync.RWMutex
token string
}

// NewTokenSource creates a TokenSource initialised with the given token.
func NewTokenSource(token string) *TokenSource {
return &TokenSource{token: token}
}

// Token returns the current access token.
func (ts *TokenSource) Token() string {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.token
}

// SetToken updates the current access token.
func (ts *TokenSource) SetToken(token string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.token = token
}

// AuthTransport injects the Authorization header from a TokenSource
// on every outgoing request. It should wrap the base transport so that
// RefreshTransport (which sits outside it) can override the header on
// retries.
type AuthTransport struct {
Base http.RoundTripper
TokenSource *TokenSource
UserAgent string
}

func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token := t.TokenSource.Token()
if token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}
if t.UserAgent != "" {
req.Header.Set("User-Agent", t.UserAgent)
}
base := t.Base
if base == nil {
base = http.DefaultTransport
}
return base.RoundTrip(req)
}

// RefreshTransport wraps an http.RoundTripper to automatically refresh
// expired OAuth access tokens using a stored refresh token.
//
// On a 401 response it:
// 1. Acquires a mutex to serialise concurrent refreshes.
// 2. Checks whether the token has already been refreshed by another
// goroutine (compare-after-lock).
// 3. If not, exchanges the refresh token for new tokens.
// 4. Persists the new tokens and updates the shared TokenSource.
// 5. Retries the original request with the new token.
type RefreshTransport struct {
Base http.RoundTripper
Org string
Keyring *keyring.Keyring
TokenSource *TokenSource

mu sync.Mutex
}

func (t *RefreshTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}

func (t *RefreshTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Buffer the request body so it can be replayed on retry.
// http.NewRequest sets GetBody for standard body types, but
// custom readers (e.g. from GraphQL clients) may not.
bufferRequestBody(req)

resp, err := t.base().RoundTrip(req)
if err != nil {
return resp, err
}

if resp.StatusCode != http.StatusUnauthorized {
return resp, nil
}

// Only attempt refresh if we have a refresh token
refreshToken, rtErr := t.Keyring.GetRefreshToken(t.Org)
if rtErr != nil || refreshToken == "" {
return resp, nil
}

// Extract the token that was used for the failed request so we can
// detect whether another goroutine already refreshed it.
failedToken := extractBearerToken(req.Header.Get("Authorization"))

// Attempt token refresh (serialised to prevent concurrent refreshes)
t.mu.Lock()
newToken, refreshErr := t.doRefresh(req.Context(), failedToken)
t.mu.Unlock()

if refreshErr != nil {
fmt.Fprintf(os.Stderr, "Warning: token refresh failed: %v\n", refreshErr)
return resp, nil
}

// Drain and close the original 401 response body
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()

// Clone the request with the new token and retry
retryReq := req.Clone(req.Context())
retryReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", newToken))

// Re-create the body for the retry
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
return nil, fmt.Errorf("failed to get request body for retry: %w", err)
}
retryReq.Body = body
}

return t.base().RoundTrip(retryReq)
}

func (t *RefreshTransport) doRefresh(ctx context.Context, failedToken string) (string, error) {
// Compare-after-lock: if the current token differs from the one that
// failed, another goroutine already refreshed successfully. Skip the
// refresh and use the new token.
currentToken := t.TokenSource.Token()
if currentToken != "" && currentToken != failedToken {
return currentToken, nil
}

// Re-read the refresh token under the lock — it may have been rotated
// by a concurrent refresh.
refreshToken, err := t.Keyring.GetRefreshToken(t.Org)
if err != nil || refreshToken == "" {
return "", fmt.Errorf("no refresh token available")
}

tokenResp, err := oauth.RefreshAccessToken(ctx, "", "", refreshToken)
if err != nil {
// Only clear the stored refresh token on explicit grant errors
// (invalid/expired/revoked). Transient failures (network, 5xx)
// should not destroy the user's session.
if isTerminalRefreshError(err) {
_ = t.Keyring.DeleteRefreshToken(t.Org)
}
return "", err
}

// Persist the new access token
if err := t.Keyring.Set(t.Org, tokenResp.AccessToken); err != nil {
return "", fmt.Errorf("failed to store refreshed access token: %w", err)
}
t.TokenSource.SetToken(tokenResp.AccessToken)

// Rotate the refresh token if a new one was issued
if tokenResp.RefreshToken != "" {
if err := t.Keyring.SetRefreshToken(t.Org, tokenResp.RefreshToken); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to store rotated refresh token: %v\n", err)
}
}

return tokenResp.AccessToken, nil
}

// isTerminalRefreshError returns true for OAuth errors that indicate the
// refresh token is permanently invalid and should be cleared.
func isTerminalRefreshError(err error) bool {
msg := err.Error()
return strings.Contains(msg, "invalid_grant") ||
strings.Contains(msg, "unauthorized_client") ||
strings.Contains(msg, "invalid_client")
}

// extractBearerToken extracts the token value from a "Bearer <token>" header.
func extractBearerToken(header string) string {
if strings.HasPrefix(header, "Bearer ") {
return header[len("Bearer "):]
}
return header
}

// bufferRequestBody ensures the request body can be replayed for retries.
// If the body is nil or already replayable (GetBody is set), this is a no-op.
func bufferRequestBody(req *http.Request) {
if req.Body == nil || req.GetBody != nil {
return
}
bodyBytes, err := io.ReadAll(req.Body)
_ = req.Body.Close()
if err != nil {
return
}
req.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(string(bodyBytes))), nil
}
}
Loading