diff --git a/CHANGELOG.md b/CHANGELOG.md index 59906913a..3b6f1eb5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,8 @@ ## Release (2025-XX-YY) +- `core`: + - [v0.21.0](core/CHANGELOG.md#v0210) + - **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` + - **Feature:** Support Workload Identity Federation flow - `sfs`: - [v0.2.0](services/sfs/CHANGELOG.md) - **Breaking change:** Remove region configuration in `APIClient` diff --git a/README.md b/README.md index 69d23ae86..9ca8dcace 100644 --- a/README.md +++ b/README.md @@ -234,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information. ## License -Apache 2.0 +Apache 2.0 \ No newline at end of file diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 8b1d2fb86..1e8466cac 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,7 @@ +## v0.21.0 +- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token` +- **Feature:** Support Workload Identity Federation flow + ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key diff --git a/core/VERSION b/core/VERSION index 2c80271d5..759e855fb 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.20.1 +v0.21.0 diff --git a/core/auth/auth.go b/core/auth/auth.go index 568847aea..450361c60 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -51,6 +51,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { return nil, fmt.Errorf("configuring no auth client: %w", err) } return noAuthRoundTripper, nil + } else if cfg.WorkloadIdentityFederation { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" { keyRoundTripper, err := KeyAuth(cfg) if err != nil { @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { cfg = &config.Configuration{} } - // Key flow - rt, err = KeyAuth(cfg) + // WIF flow + rt, err = WorkloadIdentityFederationAuth(cfg) if err != nil { - keyFlowErr := err - // Token flow - rt, err = TokenAuth(cfg) + // Key flow + rt, err = KeyAuth(cfg) if err != nil { - return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + keyFlowErr := err + // Token flow + rt, err = TokenAuth(cfg) + if err != nil { + return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + } } } return rt, nil @@ -221,6 +231,30 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { return client, nil } +// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper +// that can be used to make authenticated requests using an access token +func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) { + wifConfig := clients.WorkloadIdentityFederationFlowConfig{ + TokenUrl: cfg.TokenCustomUrl, + BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, + ClientID: cfg.ServiceAccountEmail, + FederatedTokenFilePath: cfg.ServiceAccountFederatedTokenPath, + TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration, + FederatedToken: cfg.ServiceAccountFederatedToken, + } + + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + wifConfig.HTTPTransport = cfg.HTTPClient.Transport + } + + client := &clients.WorkloadIdentityFederationFlow{} + if err := client.Init(&wifConfig); err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return client, nil +} + // readCredentialsFile reads the credentials file from the specified path and returns Credentials func readCredentialsFile(path string) (*Credentials, error) { if path == "" { diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index a7c776946..b861bf581 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stackitcloud/stackit-sdk-go/core/clients" "github.com/stackitcloud/stackit-sdk-go/core/config" @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) { } }() + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -147,12 +174,19 @@ func TestSetupAuth(t *testing.T) { desc string config *config.Configuration setToken bool + setWorkloadIdentity bool setKeys bool setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool isValid bool }{ + { + desc: "wif_config", + config: nil, + setWorkloadIdentity: true, + isValid: true, + }, { desc: "token_config", config: nil, @@ -241,6 +275,12 @@ func TestSetupAuth(t *testing.T) { t.Setenv("STACKIT_CREDENTIALS_PATH", "") } + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") authRoundTripper, err := SetupAuth(test.config) @@ -253,7 +293,7 @@ func TestSetupAuth(t *testing.T) { t.Fatalf("Test didn't return error on invalid test case") } - if test.isValid && authRoundTripper == nil { + if authRoundTripper == nil && test.isValid { t.Fatalf("Roundtripper returned is nil for valid test case") } }) @@ -381,6 +421,32 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Writing private key to temporary file: %s", err) } + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -409,6 +475,7 @@ func TestDefaultAuth(t *testing.T) { setKeyPaths bool setKeys bool setCredentialsFilePathKey bool + setWorkloadIdentity bool isValid bool expectedFlow string }{ @@ -418,6 +485,14 @@ func TestDefaultAuth(t *testing.T) { isValid: true, expectedFlow: "token", }, + { + desc: "wif_precedes_key_precedes_token", + setToken: true, + setKeyPaths: true, + setWorkloadIdentity: true, + isValid: true, + expectedFlow: "wif", + }, { desc: "key_precedes_token", setToken: true, @@ -475,6 +550,13 @@ func TestDefaultAuth(t *testing.T) { } else { t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "") } + + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") // Get the default authentication client and ensure that it's not nil @@ -501,6 +583,10 @@ func TestDefaultAuth(t *testing.T) { if _, ok := authClient.(*clients.KeyFlow); !ok { t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) } + case "wif": + if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok { + t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) + } } } }) diff --git a/core/clients/auth_flow.go b/core/clients/auth_flow.go new file mode 100644 index 000000000..141d75489 --- /dev/null +++ b/core/clients/auth_flow.go @@ -0,0 +1,84 @@ +package clients + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +const ( + defaultTokenExpirationLeeway = time.Second * 5 +) + +type AuthFlow interface { + RoundTrip(req *http.Request) (*http.Response, error) + GetAccessToken() (string, error) + GetBackgroundTokenRefreshContext() context.Context +} + +// TokenResponseBody is the API response +// when requesting a new token +type TokenResponseBody struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) { + if res == nil { + return nil, fmt.Errorf("received bad response from API") + } + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + // Fail silently, omit body from error + // We're trying to show error details, so it's unnecessary to fail because of this err + body = []byte{} + } + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: res.StatusCode, + Body: body, + } + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + token := &TokenResponseBody{} + err = json.Unmarshal(body, token) + if err != nil { + return nil, fmt.Errorf("unmarshal token response: %w", err) + } + return token, nil +} + +func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { + if token == "" { + return true, nil + } + + // We can safely use ParseUnverified because we are not authenticating the user at this point. + // We're just checking the expiration time + tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + if err != nil { + return false, fmt.Errorf("parse token: %w", err) + } + + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() + if err != nil { + return false, fmt.Errorf("get expiration timestamp: %w", err) + } + + // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring + // between retrieving the token and upstream systems validating it. + now := time.Now().Add(tokenExpirationLeeway) + return now.After(expirationTimestampNumeric.Time), nil +} diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 589774314..d18d4f0bf 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -4,11 +4,9 @@ import ( "context" "crypto/rsa" "crypto/x509" - "encoding/json" "encoding/pem" "errors" "fmt" - "io" "net/http" "net/url" "regexp" @@ -30,12 +28,10 @@ const ( ServiceAccountKeyPath = "STACKIT_SERVICE_ACCOUNT_KEY_PATH" PrivateKeyPath = "STACKIT_PRIVATE_KEY_PATH" tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive - defaultTokenType = "Bearer" - defaultScope = "" - - defaultTokenExpirationLeeway = time.Second * 5 ) +var _ AuthFlow = &KeyFlow{} + // KeyFlow handles auth with SA key type KeyFlow struct { rt http.RoundTripper @@ -65,16 +61,6 @@ type KeyFlowConfig struct { AuthHTTPClient *http.Client } -// TokenResponseBody is the API response -// when requesting a new token -type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` -} - // ServiceAccountKeyResponse is the API response // when creating a new SA key type ServiceAccountKeyResponse struct { @@ -114,6 +100,7 @@ func (c *KeyFlow) GetServiceAccountEmail() string { } // GetToken returns the token field +// Deprecated: Use GetAccessToken instead func (c *KeyFlow) GetToken() TokenResponseBody { c.tokenMutex.RLock() defer c.tokenMutex.RUnlock() @@ -160,6 +147,7 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // SetToken can be used to set an access and refresh token manually in the client. // The other fields in the token field are determined by inspecting the token or setting default values. +// Deprecated func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { // We can safely use ParseUnverified because we are not authenticating the user, // We are parsing the token just to get the expiration time claim @@ -174,11 +162,10 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { c.tokenMutex.Lock() c.token = &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, - RefreshToken: refreshToken, - TokenType: defaultTokenType, + AccessToken: accessToken, + ExpiresIn: int(exp.Time.Unix()), + Scope: "", + TokenType: "Bearer", } c.tokenMutex.Unlock() return nil @@ -198,12 +185,11 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) { return c.rt.RoundTrip(req) } -// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field +// GetAccessToken returns a short-lived access token and saves the access token in the token field func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") } - var accessToken string c.tokenMutex.RLock() @@ -219,7 +205,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if !accessTokenExpired { return accessToken, nil } - if err = c.recreateAccessToken(); err != nil { + if err = c.createAccessToken(); err != nil { var oapiErr *oapierror.GenericOpenAPIError if ok := errors.As(err, &oapiErr); ok { reg := regexp.MustCompile("Key with kid .*? was not found") @@ -237,6 +223,10 @@ func (c *KeyFlow) GetAccessToken() (string, error) { return accessToken, nil } +func (c *KeyFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + // validate the client is configured well func (c *KeyFlow) validate() error { if c.config.ServiceAccountKey == nil { @@ -269,27 +259,6 @@ func (c *KeyFlow) validate() error { // Flow auth functions -// recreateAccessToken is used to create a new access token -// when the existing one isn't valid anymore -func (c *KeyFlow) recreateAccessToken() error { - var refreshToken string - - c.tokenMutex.RLock() - if c.token != nil { - refreshToken = c.token.RefreshToken - } - c.tokenMutex.RUnlock() - - refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway) - if err != nil { - return err - } - if !refreshTokenExpired { - return c.createAccessTokenWithRefreshToken() - } - return c.createAccessToken() -} - // createAccessToken creates an access token using self signed JWT func (c *KeyFlow) createAccessToken() (err error) { grant := "urn:ietf:params:oauth:grant-type:jwt-bearer" @@ -307,27 +276,14 @@ func (c *KeyFlow) createAccessToken() (err error) { err = fmt.Errorf("close request access token response: %w", tempErr) } }() - return c.parseTokenResponse(res) -} - -// createAccessTokenWithRefreshToken creates an access token using -// an existing pre-validated refresh token -func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { - c.tokenMutex.RLock() - refreshToken := c.token.RefreshToken - c.tokenMutex.RUnlock() - - res, err := c.requestToken("refresh_token", refreshToken) + token, err := parseTokenResponse(res) if err != nil { return err } - defer func() { - tempErr := res.Body.Close() - if tempErr != nil && err == nil { - err = fmt.Errorf("close request access token with refresh token response: %w", tempErr) - } - }() - return c.parseTokenResponse(res) + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil } // generateSelfSignedJWT generates JWT token @@ -338,7 +294,7 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { "jti": uuid.New(), "aud": c.key.Credentials.Aud, "iat": jwt.NewNumericDate(time.Now()), - "exp": jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), } token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims) token.Header["kid"] = c.key.Credentials.Kid @@ -353,11 +309,8 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) { body := url.Values{} body.Set("grant_type", grant) - if grant == "refresh_token" { - body.Set("refresh_token", assertion) - } else { - body.Set("assertion", assertion) - } + body.Set("assertion", assertion) + payload := strings.NewReader(body.Encode()) req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) if err != nil { @@ -367,60 +320,3 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) return c.authClient.Do(req) } - -// parseTokenResponse parses the response from the server -func (c *KeyFlow) parseTokenResponse(res *http.Response) error { - if res == nil { - return fmt.Errorf("received bad response from API") - } - if res.StatusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - // Fail silently, omit body from error - // We're trying to show error details, so it's unnecessary to fail because of this err - body = []byte{} - } - return &oapierror.GenericOpenAPIError{ - StatusCode: res.StatusCode, - Body: body, - } - } - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{} - err = json.Unmarshal(body, c.token) - c.tokenMutex.Unlock() - if err != nil { - return fmt.Errorf("unmarshal token response: %w", err) - } - - return nil -} - -func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { - if token == "" { - return true, nil - } - - // We can safely use ParseUnverified because we are not authenticating the user at this point. - // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) - if err != nil { - return false, fmt.Errorf("parse token: %w", err) - } - - expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() - if err != nil { - return false, fmt.Errorf("get expiration timestamp: %w", err) - } - - // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring - // between retrieving the token and upstream systems validating it. - now := time.Now().Add(tokenExpirationLeeway) - - return now.After(expirationTimestampNumeric.Time), nil -} diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/key_flow_continuous_refresh.go index f5129aa02..702b3695c 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/key_flow_continuous_refresh.go @@ -20,9 +20,9 @@ var ( // Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. // // To terminate this routine, close the context in keyFlow.config.BackgroundTokenRefreshContext. -func continuousRefreshToken(keyflow *KeyFlow) { +func continuousRefreshToken(flow AuthFlow) { refresher := &continuousTokenRefresher{ - keyFlow: keyflow, + flow: flow, timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration, timeBetweenContextCheck: defaultTimeBetweenContextCheck, timeBetweenTries: defaultTimeBetweenTries, @@ -32,7 +32,7 @@ func continuousRefreshToken(keyflow *KeyFlow) { } type continuousTokenRefresher struct { - keyFlow *KeyFlow + flow AuthFlow // Token refresh tries start at [Access token expiration timestamp] - [This duration] timeStartBeforeTokenExpiration time.Duration timeBetweenContextCheck time.Duration @@ -46,22 +46,12 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { // Compute timestamp where we'll refresh token // Access token may be empty at this point, we have to check it var startRefreshTimestamp time.Time - var accessToken string - refresher.keyFlow.tokenMutex.RLock() - if refresher.keyFlow.token != nil { - accessToken = refresher.keyFlow.token.AccessToken - } - refresher.keyFlow.tokenMutex.RUnlock() - if accessToken == "" { - startRefreshTimestamp = time.Now() - } else { - expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() - if err != nil { - return fmt.Errorf("get access token expiration timestamp: %w", err) - } - startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) } + startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) for { err := refresher.waitUntilTimestamp(startRefreshTimestamp) @@ -69,7 +59,7 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { return err } - err = refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err = refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -92,13 +82,14 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { } func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) { - refresher.keyFlow.tokenMutex.RLock() - token := refresher.keyFlow.token.AccessToken - refresher.keyFlow.tokenMutex.RUnlock() + accessToken, err := refresher.flow.GetAccessToken() + if err != nil { + return nil, err + } // We can safely use ParseUnverified because we are not doing authentication of any kind // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + tokenParsed, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } @@ -111,7 +102,7 @@ func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() ( func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error { for time.Now().Before(timestamp) { - err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err := refresher.flow.GetBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -125,7 +116,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.recreateAccessToken() + _, err := refresher.flow.GetAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/key_flow_continuous_refresh_test.go index 7c7ee9565..cfd50e763 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/key_flow_continuous_refresh_test.go @@ -1,18 +1,13 @@ package clients import ( - "bytes" "context" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "testing" "time" "github.com/golang-jwt/jwt/v5" - "github.com/stackitcloud/stackit-sdk-go/core/oapierror" ) @@ -22,9 +17,9 @@ func TestContinuousRefreshToken(t *testing.T) { jwt.TimePrecision = time.Millisecond // Refresher settings - timeStartBeforeTokenExpiration := 500 * time.Millisecond - timeBetweenContextCheck := 10 * time.Millisecond - timeBetweenTries := 100 * time.Millisecond + timeStartBeforeTokenExpiration := 0 * time.Second + timeBetweenContextCheck := 50 * time.Millisecond + timeBetweenTries := 500 * time.Millisecond // All generated acess tokens will have this time to live accessTokensTimeToLive := 1 * time.Second @@ -34,16 +29,20 @@ func TestContinuousRefreshToken(t *testing.T) { contextClosesIn time.Duration doError error expectedNumberDoCalls int - expectedCallRange []int // Optional: for tests that can have variable call counts }{ + { + desc: "update access token never", + contextClosesIn: 900 * time.Millisecond, // Should allow no refresh + expectedNumberDoCalls: 0, + }, { desc: "update access token once", - contextClosesIn: 700 * time.Millisecond, // Should allow one refresh + contextClosesIn: 1900 * time.Millisecond, // Should allow one refresh expectedNumberDoCalls: 1, }, { desc: "update access token twice", - contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes + contextClosesIn: 2900 * time.Millisecond, // Should allow two refreshes expectedNumberDoCalls: 2, }, { @@ -62,14 +61,14 @@ func TestContinuousRefreshToken(t *testing.T) { expectedNumberDoCalls: 0, }, { - desc: "refresh token fails - non-API error", - contextClosesIn: 700 * time.Millisecond, + desc: "refresh token fails - error", + contextClosesIn: 1900 * time.Millisecond, doError: fmt.Errorf("something went wrong"), expectedNumberDoCalls: 1, }, { desc: "refresh token fails - API non-5xx error", - contextClosesIn: 700 * time.Millisecond, + contextClosesIn: 1900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusBadRequest, }, @@ -77,92 +76,35 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - API 5xx error", - contextClosesIn: 800 * time.Millisecond, + contextClosesIn: 2900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusInternalServerError, }, - expectedNumberDoCalls: 3, - expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition + expectedNumberDoCalls: 4, }, } for _, tt := range tests { + tt := tt t.Run(tt.desc, func(t *testing.T) { - accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create access token: %v", err) - } - - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) + t.Parallel() + accessToken, err := signToken(accessTokensTimeToLive) if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - - numberDoCalls := 0 - mockDo := func(_ *http.Request) (resp *http.Response, err error) { - numberDoCalls++ // count refresh attempts - if tt.doError != nil { - return nil, tt.doError - } - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("Do call: failed to create access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil + t.Fatalf("failed to sign access token: %v", err) } - ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) defer cancel() - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, - BackgroundTokenRefreshContext: nil, + authFlow := &fakeAuthFlow{ + backgroundTokenRefreshContext: ctx, + doError: tt.doError, + accessTokensTimeToLive: accessTokensTimeToLive, + accessToken: accessToken, } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessToken, refreshToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, + flow: authFlow, timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, timeBetweenContextCheck: timeBetweenContextCheck, timeBetweenTries: timeBetweenTries, @@ -172,315 +114,56 @@ func TestContinuousRefreshToken(t *testing.T) { if err == nil { t.Fatalf("routine finished with non-nil error") } - - // Check if we have a range of expected calls (for timing-sensitive tests) - if tt.expectedCallRange != nil { - if !contains(tt.expectedCallRange, numberDoCalls) { - t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls) - } - } else if numberDoCalls != tt.expectedNumberDoCalls { + numberDoCalls := authFlow.getTokenCalls() + if numberDoCalls != tt.expectedNumberDoCalls { t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) } }) } } -// Tests if -// - continuousRefreshToken() updates access token using the refresh token -// - The access token can be accessed while continuousRefreshToken() is trying to update it -func TestContinuousRefreshTokenConcurrency(t *testing.T) { - // The times here are in the order of miliseconds (so they run faster) - // For this to work, we need to increase precision of the expiration timestamps - jwt.TimePrecision = time.Millisecond - - // Test plan: - // 1) continuousRefreshToken() will trigger a token update. It will be blocked in the mockDo() routine (defined below) - // 2) After continuousRefreshToken() is blocked, a request will be made using the key flow. That request should carry the access token (shouldn't be blocked just because continuousRefreshToken() is trying to refresh the token) - // 3) After the request is successful, continuousRefreshToken() will be unblocked - // 4) After waiting a bit, a new request will be made using the key flow. That request should carry the new access token - - // Where we're at in the test plan: - // - Starts at 0 - // - Is set to 1 before continuousRefreshToken() is called - // - Is set to 2 once the continuousRefreshToken() is blocked - // - Is set to 3 once the first request goes through and is checked - // - Is set to 4 after a small wait after continuousRefreshToken() is unblocked - currentTestPhase := 0 - - // Used to signal continuousRefreshToken() has been blocked - chanBlockContinuousRefreshToken := make(chan bool) - - // Used to signal continuousRefreshToken() should be unblocked - chanUnblockContinuousRefreshToken := make(chan bool) - - // The access token at the start - accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), - }).SignedString([]byte("token-first")) - if err != nil { - t.Fatalf("failed to create first access token: %v", err) - } - - // The access token that will replace accessTokenFirst - // Has a much longer expiration timestamp - accessTokenSecond, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("token-second")) - if err != nil { - t.Fatalf("failed to create second access token: %v", err) - } - - if accessTokenFirst == accessTokenSecond { - t.Fatalf("created tokens are equal") - } - - // The refresh token used to update the access token - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), +func signToken(expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create refresh token: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() // This cancels the refresher goroutine - - // Extract host from tokenAPI constant for consistency - tokenURL, _ := url.Parse(tokenAPI) - tokenHost := tokenURL.Host - - // The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests - // The bools are used to make sure only one request goes through on each test phase - doTestPhase1RequestDone := false - doTestPhase2RequestDone := false - doTestPhase4RequestDone := false - mockDo := func(req *http.Request) (resp *http.Response, err error) { - // Handle auth requests (token refresh) - if req.URL.Host == tokenHost { - switch currentTestPhase { - default: - // After phase 1, allow additional auth requests but don't fail the test - // This handles the continuous nature of the refresh routine - if currentTestPhase > 1 { - // Return a valid response for any additional auth requests - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("additional-token")) - if err != nil { - t.Fatalf("Do call: failed to create additional access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal additional response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 1: // Call by continuousRefreshToken() - if doTestPhase1RequestDone { - t.Fatalf("Do call: multiple requests during test phase 1") - } - doTestPhase1RequestDone = true - - currentTestPhase = 2 - chanBlockContinuousRefreshToken <- true - - // Wait until continuousRefreshToken() is to be unblocked - <-chanUnblockContinuousRefreshToken - - if currentTestPhase != 3 { - t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase) - } - - // Check required fields are passed - err = req.ParseForm() - if err != nil { - t.Fatalf("Do call: failed to parse body form: %v", err) - } - reqGrantType := req.Form.Get("grant_type") - if reqGrantType != "refresh_token" { - t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType) - } - reqRefreshToken := req.Form.Get("refresh_token") - if reqRefreshToken != refreshToken { - t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set") - } - - // Return response with accessTokenSecond - responseBodyStruct := TokenResponseBody{ - AccessToken: accessTokenSecond, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - } - - // Handle regular HTTP requests - switch currentTestPhase { - default: - t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase) - return nil, nil - case 2: // Call by tokenFlow, first request - if doTestPhase2RequestDone { - t.Fatalf("Do call: multiple requests during test phase 2") - } - doTestPhase2RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "first-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst) - if authHeader != expectedAuthHeader { - t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader) - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - case 4: // Call by tokenFlow, second request - if doTestPhase4RequestDone { - t.Fatalf("Do call: multiple requests during test phase 4") - } - doTestPhase4RequestDone = true - - // Check host and access token - host := req.Host - expectedHost := "second-request-url.com" - if host != expectedHost { - t.Fatalf("Do call: second request expected to have host %q, found %q", expectedHost, host) - } - authHeader := req.Header.Get("Authorization") - if authHeader != fmt.Sprintf("Bearer %s", accessTokenSecond) { - t.Fatalf("Do call: second request didn't carry second access token") - } - - // Return empty response - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte{})), - } - return response, nil - } - } - - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests - // Don't start continuous refresh automatically - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessTokenFirst, refreshToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) - } - - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - - // Create a custom refresher with shorter timing for the test - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, - timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration - timeBetweenContextCheck: 5 * time.Millisecond, - timeBetweenTries: 40 * time.Millisecond, - } - - // TEST START - currentTestPhase = 1 - // Ignore returned error as expected in test - go func() { - _ = refresher.continuousRefreshToken() - }() +} - // Wait until continuousRefreshToken() is blocked - <-chanBlockContinuousRefreshToken +var _ AuthFlow = &fakeAuthFlow{} - if currentTestPhase != 2 { - t.Fatalf("Unexpected test phase %d after continuousRefreshToken() was blocked", currentTestPhase) - } +type fakeAuthFlow struct { + backgroundTokenRefreshContext context.Context + tokenCounter int + doError error + accessTokensTimeToLive time.Duration + accessToken string +} - // Perform first request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://first-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create first request failed: %v", err) - } - resp, err := keyFlow.RoundTrip(req) +func (f *fakeAuthFlow) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, nil +} +func (f *fakeAuthFlow) GetAccessToken() (string, error) { + expired, err := tokenExpired(f.accessToken, 0) if err != nil { - t.Fatalf("Perform first request failed: %v", err) + return "", err } - err = resp.Body.Close() - if err != nil { - t.Fatalf("First request body failed to close: %v", err) + if !expired { + return f.accessToken, nil } - - // Unblock continuousRefreshToken() - currentTestPhase = 3 - chanUnblockContinuousRefreshToken <- true - - // Wait for a bit - time.Sleep(10 * time.Millisecond) - currentTestPhase = 4 - - // Perform second request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, "http://second-request-url.com", http.NoBody) - if err != nil { - t.Fatalf("Create second request failed: %v", err) - } - resp, err = keyFlow.RoundTrip(req) - if err != nil { - t.Fatalf("Second request failed: %v", err) + f.tokenCounter++ + if f.doError != nil { + return "", f.doError } - err = resp.Body.Close() + accessToken, err := signToken(f.accessTokensTimeToLive) if err != nil { - t.Fatalf("Second request body failed to close: %v", err) + return "", f.doError } + f.accessToken = accessToken + return accessToken, nil +} +func (f *fakeAuthFlow) GetBackgroundTokenRefreshContext() context.Context { + return f.backgroundTokenRefreshContext } -func contains(arr []int, val int) bool { - for _, v := range arr { - if v == val { - return true - } - } - return false +func (f *fakeAuthFlow) getTokenCalls() int { + return f.tokenCounter } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 9803f24ee..7c094331e 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -175,11 +175,10 @@ func TestSetToken(t *testing.T) { } if err == nil { expectedKeyFlowToken := &TokenResponseBody{ - AccessToken: accessToken, - ExpiresIn: int(timestamp.Unix()), - RefreshToken: tt.refreshToken, - Scope: defaultScope, - TokenType: defaultTokenType, + AccessToken: accessToken, + ExpiresIn: int(timestamp.Unix()), + Scope: "", + TokenType: "Bearer", } if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) @@ -194,25 +193,25 @@ func TestTokenExpired(t *testing.T) { tests := []struct { desc string tokenInvalid bool - tokenExpiresAt time.Time + tokenDuration time.Duration expectedErr bool expectedIsExpired bool }{ { desc: "token valid", - tokenExpiresAt: time.Now().Add(time.Hour), + tokenDuration: time.Hour, expectedErr: false, expectedIsExpired: false, }, { desc: "token expired", - tokenExpiresAt: time.Now().Add(-time.Hour), + tokenDuration: -time.Hour, expectedErr: false, expectedIsExpired: true, }, { desc: "token almost expired", - tokenExpiresAt: time.Now().Add(tokenExpirationLeeway), + tokenDuration: tokenExpirationLeeway, expectedErr: false, expectedIsExpired: true, }, @@ -228,9 +227,7 @@ func TestTokenExpired(t *testing.T) { var err error token := "foo" if !tt.tokenInvalid { - token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt), - }).SignedString([]byte("test")) + token, err = signToken(tt.tokenDuration) if err != nil { t.Fatalf("failed to create token: %v", err) } @@ -442,10 +439,9 @@ func TestKeyFlow_Do(t *testing.T) { res.Header().Set("Content-Type", "application/json") token := &TokenResponseBody{ - AccessToken: testBearerToken, - ExpiresIn: 2147483647, - RefreshToken: testBearerToken, - TokenType: "Bearer", + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + TokenType: "Bearer", } if err := json.NewEncoder(res.Body).Encode(token); err != nil { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go new file mode 100644 index 000000000..0046ec864 --- /dev/null +++ b/core/clients/workload_identity_flow.go @@ -0,0 +1,250 @@ +package clients + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + clientIDEnv = "STACKIT_SERVICE_ACCOUNT_EMAIL" + FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" + wifTokenEndpointEnv = "STACKIT_IDP_ENDPOINT" + wifTokenExpirationEnv = "STACKIT_IDP_EXPIRATION_SECONDS" + + wifClientAssertionType = "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" + wifGrantType = "client_credentials" + defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" + defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" + defaultWifExpirationToken = "1h" +) + +var ( + _ = getEnvOrDefault(wifTokenExpirationEnv, defaultWifExpirationToken) // Not used yet +) + +func getEnvOrDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +var _ AuthFlow = &WorkloadIdentityFederationFlow{} + +// WorkloadIdentityFlow handles auth with Workload Identity Federation +type WorkloadIdentityFederationFlow struct { + rt http.RoundTripper + authClient *http.Client + config *WorkloadIdentityFederationFlowConfig + + tokenMutex sync.RWMutex + token *TokenResponseBody + + parser *jwt.Parser + + // If the current access token would expire in less than TokenExpirationLeeway, + // the client will refresh it early to prevent clock skew or other timing issues. + tokenExpirationLeeway time.Duration +} + +// KeyFlowConfig is the flow config +type WorkloadIdentityFederationFlowConfig struct { + TokenUrl string + ClientID string + FederatedToken string // Static token string. This is optional, if not set the token will be read from file. + FederatedTokenFilePath string + TokenExpiration string // Not supported yet + BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil + HTTPTransport http.RoundTripper + AuthHTTPClient *http.Client +} + +// GetConfig returns the flow configuration +func (c *WorkloadIdentityFederationFlow) GetConfig() WorkloadIdentityFederationFlowConfig { + if c.config == nil { + return WorkloadIdentityFederationFlowConfig{} + } + return *c.config +} + +// GetAccessToken implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetAccessToken() (string, error) { + if c.rt == nil { + return "", fmt.Errorf("nil http round tripper, please run Init()") + } + var accessToken string + + c.tokenMutex.RLock() + if c.token != nil { + accessToken = c.token.AccessToken + } + c.tokenMutex.RUnlock() + + accessTokenExpired, err := tokenExpired(accessToken, c.tokenExpirationLeeway) + if err != nil { + return "", fmt.Errorf("check access token is expired: %w", err) + } + if !accessTokenExpired { + return accessToken, nil + } + if err = c.createAccessToken(); err != nil { + return "", fmt.Errorf("get new access token: %w", err) + } + + c.tokenMutex.RLock() + accessToken = c.token.AccessToken + c.tokenMutex.RUnlock() + + return accessToken, nil +} + +// RoundTrip implements the http.RoundTripper interface. +// It gets a token, adds it to the request's authorization header, and performs the request. +func (c *WorkloadIdentityFederationFlow) RoundTrip(req *http.Request) (*http.Response, error) { + if c.rt == nil { + return nil, fmt.Errorf("please run Init()") + } + + accessToken, err := c.GetAccessToken() + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return c.rt.RoundTrip(req) +} + +// GetBackgroundTokenRefreshContext implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + +func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlowConfig) error { + // No concurrency at this point, so no mutex check needed + c.token = &TokenResponseBody{} + c.config = cfg + c.parser = jwt.NewParser() + + if c.config.TokenUrl == "" { + c.config.TokenUrl = getEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint) + } + + if c.config.ClientID == "" { + c.config.ClientID = getEnvOrDefault(clientIDEnv, "") + } + + if c.config.FederatedToken == "" && c.config.FederatedTokenFilePath == "" { + c.config.FederatedTokenFilePath = getEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath) + } + + c.tokenExpirationLeeway = defaultTokenExpirationLeeway + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil { + c.authClient = &http.Client{ + Transport: c.rt, + Timeout: DefaultClientTimeout, + } + } + + err := c.validate() + if err != nil { + return err + } + + if c.config.BackgroundTokenRefreshContext != nil { + go continuousRefreshToken(c) + } + return nil +} + +// validate the client is configured well +func (c *WorkloadIdentityFederationFlow) validate() error { + if c.config.ClientID == "" { + return fmt.Errorf("client ID cannot be empty") + } + if c.config.TokenUrl == "" { + return fmt.Errorf("token URL cannot be empty") + } + if c.config.FederatedToken == "" { + if _, err := c.readJWTFromFileSystem(c.config.FederatedTokenFilePath); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } + } + if c.tokenExpirationLeeway < 0 { + return fmt.Errorf("token expiration leeway cannot be negative") + } + + return nil +} + +// createAccessToken creates an access token using self signed JWT +func (c *WorkloadIdentityFederationFlow) createAccessToken() error { + clientAssertion := c.config.FederatedToken + if clientAssertion == "" { + var err error + clientAssertion, err = c.readJWTFromFileSystem(c.config.FederatedTokenFilePath) + if err != nil { + return fmt.Errorf("error reading service account assertion - %w", err) + } + } + + res, err := c.requestToken(c.config.ClientID, clientAssertion) + if err != nil { + return err + } + defer func() { + tempErr := res.Body.Close() + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token response: %w", tempErr) + } + }() + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil +} + +func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string) (*http.Response, error) { + body := url.Values{} + body.Set("grant_type", wifGrantType) + body.Set("client_assertion_type", wifClientAssertionType) + body.Set("client_assertion", assertion) + body.Set("client_id", clientID) + + payload := strings.NewReader(body.Encode()) + req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + return c.authClient.Do(req) +} + +func (c *WorkloadIdentityFederationFlow) readJWTFromFileSystem(tokenFilePath string) (string, error) { + token, err := os.ReadFile(tokenFilePath) + if err != nil { + return "", err + } + tokenStr := string(token) + _, _, err = c.parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", err + } + return tokenStr, nil +} diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go new file mode 100644 index 000000000..4a9e07161 --- /dev/null +++ b/core/clients/workload_identity_flow_test.go @@ -0,0 +1,292 @@ +package clients + +import ( + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func TestWorkloadIdentityFlowInit(t *testing.T) { + tests := []struct { + name string + clientID string + clientIDAsEnv bool + customTokenUrl string + customTokenUrlEnv bool + tokenExpiration string + validAssertion bool + tokenFilePathAsEnv bool + missingTokenFilePath bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "missing client id", + validAssertion: true, + wantErr: true, + }, + { + name: "missing assertion", + clientID: "test@stackit.cloud", + missingTokenFilePath: true, + wantErr: true, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + if tt.customTokenUrl != "" { + if tt.customTokenUrlEnv { + t.Setenv("STACKIT_IDP_ENDPOINT", tt.customTokenUrl) + } else { + flowConfig.TokenUrl = tt.customTokenUrl + } + } + + if tt.clientID != "" { + if tt.clientIDAsEnv { + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", tt.clientID) + } else { + flowConfig.ClientID = tt.clientID + } + } + if tt.tokenExpiration != "" { + flowConfig.TokenExpiration = tt.tokenExpiration + } + + if !tt.missingTokenFilePath { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + if tt.validAssertion { + token, err := signTokenWithSubject("subject", time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } + if tt.tokenFilePathAsEnv { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) + } else { + flowConfig.FederatedTokenFilePath = file.Name() + } + } + + if err := flow.Init(flowConfig); (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr) + } + if flow.config == nil { + t.Error("config is nil") + } + + if flow.config.ClientID != tt.clientID { + t.Errorf("clientID mismatch, want %s, got %s", tt.clientID, flow.config.ClientID) + } + + if tt.customTokenUrl != "" && flow.config.TokenUrl != tt.customTokenUrl { + t.Errorf("tokenUrl mismatch, want %s, got %s", tt.customTokenUrl, flow.config.TokenUrl) + } + + if tt.customTokenUrl == "" && flow.config.TokenUrl != "https://accounts.stackit.cloud/oauth/v2/token" { + t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl) + } + + if tt.missingTokenFilePath && flow.config.FederatedTokenFilePath != "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want %s, got %s", "/var/run/secrets/stackit.cloud/serviceaccount/token", flow.config.FederatedTokenFilePath) + } + + if !tt.missingTokenFilePath && flow.config.FederatedTokenFilePath == "/var/run/secrets/stackit.cloud/serviceaccount/token" { + t.Errorf("clientID mismatch, want different from %s", flow.config.FederatedTokenFilePath) + } + + if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration { + t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration) + } + }) + } +} + +func signTokenWithSubject(sub string, expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + Subject: sub, + }).SignedString([]byte("test")) +} + +func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { + validSub := "valid-sub" + serviceAccountSub := "sa-sub" + tests := []struct { + name string + clientID string + validAssertion bool + injectToken bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "injected token ok", + clientID: "test@stackit.cloud", + validAssertion: true, + injectToken: true, + wantErr: false, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + validAssertion: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertionType := r.PostForm.Get("client_assertion_type") + if assertionType != "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" { + t.Fatalf("invalid assertion type: %s", assertionType) + } + grantType := r.PostForm.Get("grant_type") + if grantType != "client_credentials" { + t.Fatalf("invalid grant type: %s", assertionType) + } + context, _, err := jwt.NewParser().ParseUnverified(r.PostForm.Get("client_assertion"), jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != validSub { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := signTokenWithSubject(serviceAccountSub, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + tokenResponse := &TokenResponseBody{ + AccessToken: token, + ExpiresIn: 60, + TokenType: "Bearer", + } + + payload, err := json.Marshal(tokenResponse) + if err != nil { + t.Fatalf("failed to create token payload: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(payload) + })) + t.Cleanup(authServer.Close) + + protectedResource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context, _, err := jwt.NewParser().ParseUnverified(strings.Fields(r.Header.Get("Authorization"))[1], jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != serviceAccountSub { + t.Fatalf("invalid token on protected resource: %v", err) + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(protectedResource.Close) + + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + flowConfig.TokenUrl = authServer.URL + + flowConfig.ClientID = tt.clientID + + subject := "wrong" + if tt.validAssertion { + subject = validSub + } + token, err := signTokenWithSubject(subject, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + if tt.injectToken { + flowConfig.FederatedToken = token + } else { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + flowConfig.FederatedTokenFilePath = file.Name() + os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + } + + if err := flow.Init(flowConfig); err != nil { + t.Errorf("KeyFlow.Init() error = %v", err) + } + if flow.config == nil { + t.Error("config is nil") + } + + client := http.Client{ + Transport: flow, + } + resp, err := client.Get(protectedResource.URL) + if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { + t.Fatalf("failed request to protected resource: %v", err) + } + }) + } +} diff --git a/core/config/config.go b/core/config/config.go index 93002c02a..dd9dd98f4 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -75,26 +75,30 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` // Deprecated: ServiceAccountEmail is not required and will be removed after 12th June 2025. - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"` + ServiceAccountFederatedToken string `json:"serviceAccountFederatedToken,omitempty"` + ServiceAccountFederatedTokenPath string `json:"serviceAccountFederatedTokenPath,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -176,8 +180,6 @@ func WithTokenEndpoint(url string) ConfigurationOption { } // WithServiceAccountEmail returns a ConfigurationOption that sets the service account email -// -// Deprecated: WithServiceAccountEmail is not required and will be removed after 12th June 2025. func WithServiceAccountEmail(serviceAccountEmail string) ConfigurationOption { return func(config *Configuration) error { config.ServiceAccountEmail = serviceAccountEmail @@ -237,6 +239,30 @@ func WithToken(token string) ConfigurationOption { } } +// WithWorkloadIdentityFederationAuth returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationAuth() ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederation = true + return nil + } +} + +// WithWorkloadIdentityFederation returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationTokenPath(path string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenPath = path + return nil + } +} + +// WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow +func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenExpiration = expiration + return nil + } +} + // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. This option has no effect, and will be removed in a later update func WithMaxRetries(_ int) ConfigurationOption { return func(_ *Configuration) error {