Skip to content
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
## Release (2025-XX-YY)
- `core`:
- [v0.21.0](core/CHANGELOG.md#v0210)
- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token`
- **Feature:** Support Workload Identity Federation flow
- `sfs`:
- [v0.2.0](services/sfs/CHANGELOG.md)
- **Breaking change:** Remove region configuration in `APIClient`
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information.

## License

Apache 2.0
Apache 2.0
4 changes: 4 additions & 0 deletions core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v0.21.0
- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token`
- **Feature:** Support Workload Identity Federation flow

## v0.20.1
- **Improvement:** Improve error message when passing a PEM encoded file to as service account key

Expand Down
2 changes: 1 addition & 1 deletion core/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v0.20.1
v0.21.0
45 changes: 39 additions & 6 deletions core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
return nil, fmt.Errorf("configuring no auth client: %w", err)
}
return noAuthRoundTripper, nil
} else if cfg.WorkloadIdentityFederation {
wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg)
if err != nil {
return nil, fmt.Errorf("configuring no auth client: %w", err)
}
return wifRoundTripper, nil
} else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" {
keyRoundTripper, err := KeyAuth(cfg)
if err != nil {
Expand Down Expand Up @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) {
cfg = &config.Configuration{}
}

// Key flow
rt, err = KeyAuth(cfg)
// WIF flow
rt, err = WorkloadIdentityFederationAuth(cfg)
if err != nil {
keyFlowErr := err
// Token flow
rt, err = TokenAuth(cfg)
// Key flow
rt, err = KeyAuth(cfg)
if err != nil {
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
keyFlowErr := err
// Token flow
rt, err = TokenAuth(cfg)
if err != nil {
return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err)
}
}
}
return rt, nil
Expand Down Expand Up @@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
return client, nil
}

// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper
// that can be used to make authenticated requests using an access token
func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) {
wifConfig := clients.WorkloadIdentityFederationFlowConfig{
TokenUrl: cfg.TokenCustomUrl,
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
ClientID: cfg.ServiceAccountEmail,
FederatedTokenFilePath: cfg.WorkloadIdentityFederationFederatedTokenPath,
TokenExpiration: cfg.WorkloadIdentityFederationTokenExpiration,
}

if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil {
wifConfig.HTTPTransport = cfg.HTTPClient.Transport
}

client := &clients.WorkloadIdentityFederationFlow{}
if err := client.Init(&wifConfig); err != nil {
return nil, fmt.Errorf("error initializing client: %w", err)
}

return client, nil
}

// readCredentialsFile reads the credentials file from the specified path and returns Credentials
func readCredentialsFile(path string) (*Credentials, error) {
if path == "" {
Expand Down
88 changes: 87 additions & 1 deletion core/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stackitcloud/stackit-sdk-go/core/clients"
"github.com/stackitcloud/stackit-sdk-go/core/config"
Expand Down Expand Up @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) {
}
}()

// create a wif assertion file
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = wifAssertionFile.Close()
err := os.Remove(wifAssertionFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}
}()

token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
Subject: "sub",
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}

_, errs = wifAssertionFile.WriteString(string(token))
if errs != nil {
t.Fatalf("Writing wif assertion to temporary file: %s", err)
}

// create a credentials file with saKey and private key
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
Expand All @@ -147,12 +174,19 @@ func TestSetupAuth(t *testing.T) {
desc string
config *config.Configuration
setToken bool
setWorkloadIdentity bool
setKeys bool
setKeyPaths bool
setCredentialsFilePathToken bool
setCredentialsFilePathKey bool
isValid bool
}{
{
desc: "wif_config",
config: nil,
setWorkloadIdentity: true,
isValid: true,
},
{
desc: "token_config",
config: nil,
Expand Down Expand Up @@ -241,6 +275,12 @@ func TestSetupAuth(t *testing.T) {
t.Setenv("STACKIT_CREDENTIALS_PATH", "")
}

if test.setWorkloadIdentity {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
} else {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
}

t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")

authRoundTripper, err := SetupAuth(test.config)
Expand All @@ -253,7 +293,7 @@ func TestSetupAuth(t *testing.T) {
t.Fatalf("Test didn't return error on invalid test case")
}

if test.isValid && authRoundTripper == nil {
if authRoundTripper == nil && test.isValid {
t.Fatalf("Roundtripper returned is nil for valid test case")
}
})
Expand Down Expand Up @@ -381,6 +421,32 @@ func TestDefaultAuth(t *testing.T) {
t.Fatalf("Writing private key to temporary file: %s", err)
}

// create a wif assertion file
wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
t.Fatalf("Creating temporary file: %s", err)
}
defer func() {
_ = wifAssertionFile.Close()
err := os.Remove(wifAssertionFile.Name())
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}
}()

token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
Subject: "sub",
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("Removing temporary file: %s", err)
}

_, errs = wifAssertionFile.WriteString(string(token))
if errs != nil {
t.Fatalf("Writing wif assertion to temporary file: %s", err)
}

// create a credentials file with saKey and private key
credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt")
if errs != nil {
Expand Down Expand Up @@ -409,6 +475,7 @@ func TestDefaultAuth(t *testing.T) {
setKeyPaths bool
setKeys bool
setCredentialsFilePathKey bool
setWorkloadIdentity bool
isValid bool
expectedFlow string
}{
Expand All @@ -418,6 +485,14 @@ func TestDefaultAuth(t *testing.T) {
isValid: true,
expectedFlow: "token",
},
{
desc: "wif_precedes_key_precedes_token",
setToken: true,
setKeyPaths: true,
setWorkloadIdentity: true,
isValid: true,
expectedFlow: "wif",
},
{
desc: "key_precedes_token",
setToken: true,
Expand Down Expand Up @@ -475,6 +550,13 @@ func TestDefaultAuth(t *testing.T) {
} else {
t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "")
}

if test.setWorkloadIdentity {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name())
} else {
t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "")
}

t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email")

// Get the default authentication client and ensure that it's not nil
Expand All @@ -501,6 +583,10 @@ func TestDefaultAuth(t *testing.T) {
if _, ok := authClient.(*clients.KeyFlow); !ok {
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
}
case "wif":
if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok {
t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient))
}
}
}
})
Expand Down
84 changes: 84 additions & 0 deletions core/clients/auth_flow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package clients

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/stackitcloud/stackit-sdk-go/core/oapierror"
)

const (
defaultTokenExpirationLeeway = time.Second * 5
)

type AuthFlow interface {
RoundTrip(req *http.Request) (*http.Response, error)
GetAccessToken() (string, error)
GetBackgroundTokenRefreshContext() context.Context
}

// TokenResponseBody is the API response
// when requesting a new token
type TokenResponseBody struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}

func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) {
if res == nil {
return nil, fmt.Errorf("received bad response from API")
}
if res.StatusCode != http.StatusOK {
body, err := io.ReadAll(res.Body)
if err != nil {
// Fail silently, omit body from error
// We're trying to show error details, so it's unnecessary to fail because of this err
body = []byte{}
}
return nil, &oapierror.GenericOpenAPIError{
StatusCode: res.StatusCode,
Body: body,
}
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}

token := &TokenResponseBody{}
err = json.Unmarshal(body, token)
if err != nil {
return nil, fmt.Errorf("unmarshal token response: %w", err)
}
return token, nil
}

func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) {
if token == "" {
return true, nil
}

// We can safely use ParseUnverified because we are not authenticating the user at this point.
// We're just checking the expiration time
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
if err != nil {
return false, fmt.Errorf("parse token: %w", err)
}

expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
if err != nil {
return false, fmt.Errorf("get expiration timestamp: %w", err)
}

// Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring
// between retrieving the token and upstream systems validating it.
now := time.Now().Add(tokenExpirationLeeway)
return now.After(expirationTimestampNumeric.Time), nil
}
Loading