From f8a99c43ada9d236670b459c63b74a32db6ad2a0 Mon Sep 17 00:00:00 2001 From: kanywst Date: Thu, 19 Mar 2026 02:38:25 +0900 Subject: [PATCH 1/9] feat(oauth2): implement ID-JAG token issuance via Token Exchange (DEP #4600) Add support for Identity Assertion JWT Authorization Grant (draft-ietf-oauth-identity-assertion-authz-grant-02) as a new requested_token_type in the RFC 8693 Token Exchange flow. - ID-JAG JWT with typ=oauth-id-jag+jwt, required claims (iss, sub, aud, client_id, jti, exp, iat) and optional claims (resource, scope) - Per-client policy with default-deny: allowedAudiences, allowedScopes - Scope filtering per Section 4.3.2 (granted scopes may differ from requested; response includes scope when modified) - subject_token audience validation against client_id (Section 4.3) - Public client rejection (Section 8.1) - OIDC Discovery: identity_chaining_requested_token_types_supported - SignWithType on signer interface for custom JWT typ header - Prometheus counters: dex_id_jag_requests_total, dex_id_jag_policy_rejections_total, dex_id_jag_scope_modifications_total - Structured logging on issuance and rejection with decision context Signed-off-by: Takuma Niwa Signed-off-by: kanywst --- cmd/dex/config.go | 19 +- cmd/dex/config_test.go | 108 ++++++- cmd/dex/serve.go | 36 ++- config.yaml.dist | 24 ++ server/handlers.go | 240 +++++++++++++++- server/handlers_test.go | 609 ++++++++++++++++++++++++++++++++++++++++ server/oauth2.go | 58 ++++ server/policy.go | 98 +++++++ server/policy_test.go | 127 +++++++++ server/server.go | 54 ++++ server/signer/local.go | 18 ++ server/signer/mock.go | 4 + server/signer/signer.go | 2 + server/signer/utils.go | 17 ++ server/signer/vault.go | 82 ++++++ 15 files changed, 1470 insertions(+), 26 deletions(-) create mode 100644 server/policy.go create mode 100644 server/policy_test.go diff --git a/cmd/dex/config.go b/cmd/dex/config.go index e80c1c5783..e30ad6aff8 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -56,7 +56,7 @@ type Config struct { // StaticClients cause the server to use this list of clients rather than // querying the storage. Write operations, like creating a client, will fail. - StaticClients []storage.Client `json:"staticClients"` + StaticClients []staticClient `json:"staticClients"` // If enabled, the server will maintain a list of passwords which can be used // to identify a user. @@ -229,6 +229,18 @@ func (p *password) UnmarshalJSON(b []byte) error { return nil } +// staticClient wraps storage.Client with optional per-client ID-JAG policy. +type staticClient struct { + storage.Client + IDJAGPolicies *IDJAGClientPolicy `json:"idJAGPolicies,omitempty"` +} + +// IDJAGClientPolicy configures allowed audiences and scopes for ID-JAG exchange. +type IDJAGClientPolicy struct { + AllowedAudiences []string `json:"allowedAudiences"` + AllowedScopes []string `json:"allowedScopes"` +} + // OAuth2 describes enabled OAuth2 extensions. type OAuth2 struct { // list of allowed grant types, @@ -245,6 +257,8 @@ type OAuth2 struct { PasswordConnector string `json:"passwordConnector"` // PKCE configuration PKCE PKCE `json:"pkce"` + // TokenExchange configures Token Exchange support. + TokenExchange server.TokenExchangeConfig `json:"tokenExchange"` } // PKCE holds the PKCE (Proof Key for Code Exchange) configuration. @@ -641,6 +655,9 @@ type Expiry struct { // IdTokens defines the duration of time for which the IdTokens will be valid. IDTokens string `json:"idTokens"` + // IDJAGTokens defines the duration of time for which ID-JAG tokens will be valid. + IDJAGTokens string `json:"idJAGTokens"` + // AuthRequests defines the duration of time for which the AuthRequests will be valid. AuthRequests string `json:"authRequests"` diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 2a08ab1eb5..5678ffd123 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -184,15 +184,15 @@ additionalFeatures: [ "foo": "bar", }, }, - StaticClients: []storage.Client{ - { + StaticClients: []staticClient{ + {Client: storage.Client{ ID: "example-app", Secret: "ZXhhbXBsZS1hcHAtc2VjcmV0", Name: "Example App", RedirectURIs: []string{ "http://127.0.0.1:5555/callback", }, - }, + }}, }, OAuth2: OAuth2{ AlwaysShowLoginScreen: true, @@ -413,15 +413,15 @@ logger: "foo": "bar", }, }, - StaticClients: []storage.Client{ - { + StaticClients: []staticClient{ + {Client: storage.Client{ ID: "example-app", Secret: "ZXhhbXBsZS1hcHAtc2VjcmV0", Name: "Example App", RedirectURIs: []string{ "http://127.0.0.1:5555/callback", }, - }, + }}, }, OAuth2: OAuth2{ AlwaysShowLoginScreen: true, @@ -675,3 +675,99 @@ enablePasswordDB: true }) } } + +func TestUnmarshalConfigWithIDJAGPolicies(t *testing.T) { + rawConfig := []byte(` +issuer: http://127.0.0.1:5556/dex +storage: + type: memory +web: + http: 0.0.0.0:5556 + +oauth2: + grantTypes: + - authorization_code + - "urn:ietf:params:oauth:grant-type:token-exchange" + tokenExchange: + tokenTypes: + - "urn:ietf:params:oauth:token-type:id_token" + - "urn:ietf:params:oauth:token-type:id-jag" + +expiry: + idJAGTokens: "10m" + +staticClients: + - id: wiki-app + secret: wiki-secret + name: "Wiki Application" + redirectURIs: + - "https://wiki.example/callback" + idJAGPolicies: + allowedAudiences: + - "https://chat.example/" + - "https://calendar.example/" + allowedScopes: + - "chat.read" + - "calendar.read" + - id: plain-app + secret: plain-secret + name: "Plain Application" + redirectURIs: + - "https://plain.example/callback" + +enablePasswordDB: true +`) + + var c Config + data, err := yaml.YAMLToJSON(rawConfig) + if err != nil { + t.Fatalf("failed to convert yaml to json: %v", err) + } + if err := json.Unmarshal(data, &c); err != nil { + t.Fatalf("failed to unmarshal config: %v", err) + } + + // Verify tokenExchange config. + if len(c.OAuth2.TokenExchange.TokenTypes) != 2 { + t.Fatalf("expected 2 token types, got %d", len(c.OAuth2.TokenExchange.TokenTypes)) + } + if !c.OAuth2.TokenExchange.IDJAGEnabled() { + t.Fatal("expected ID-JAG to be enabled") + } + + // Verify expiry. + if c.Expiry.IDJAGTokens != "10m" { + t.Errorf("expected IDJAGTokens=10m, got %q", c.Expiry.IDJAGTokens) + } + + // Verify static clients with idJAGPolicies. + if len(c.StaticClients) != 2 { + t.Fatalf("expected 2 static clients, got %d", len(c.StaticClients)) + } + + wikiClient := c.StaticClients[0] + if wikiClient.Client.ID != "wiki-app" { + t.Errorf("expected wiki-app, got %q", wikiClient.Client.ID) + } + if wikiClient.IDJAGPolicies == nil { + t.Fatal("expected idJAGPolicies for wiki-app, got nil") + } + if len(wikiClient.IDJAGPolicies.AllowedAudiences) != 2 { + t.Fatalf("expected 2 allowed audiences, got %d", len(wikiClient.IDJAGPolicies.AllowedAudiences)) + } + if wikiClient.IDJAGPolicies.AllowedAudiences[0] != "https://chat.example/" { + t.Errorf("expected first audience https://chat.example/, got %q", wikiClient.IDJAGPolicies.AllowedAudiences[0]) + } + if len(wikiClient.IDJAGPolicies.AllowedScopes) != 2 { + t.Fatalf("expected 2 allowed scopes, got %d", len(wikiClient.IDJAGPolicies.AllowedScopes)) + } + if wikiClient.IDJAGPolicies.AllowedScopes[0] != "chat.read" { + t.Errorf("expected first scope chat.read, got %q", wikiClient.IDJAGPolicies.AllowedScopes[0]) + } + + // Client without idJAGPolicies. + plainClient := c.StaticClients[1] + if plainClient.IDJAGPolicies != nil { + t.Errorf("expected nil idJAGPolicies for plain-app, got %+v", plainClient.IDJAGPolicies) + } +} diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 0c26006d39..7602f71665 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -214,7 +214,9 @@ func runServe(options serveOptions) error { logger.Info("config storage", "storage_type", c.Storage.Type) if len(c.StaticClients) > 0 { - for i, client := range c.StaticClients { + storageClients := make([]storage.Client, len(c.StaticClients)) + for i, sc := range c.StaticClients { + client := sc.Client if client.Name == "" { return fmt.Errorf("invalid config: Name field is required for a client") } @@ -225,7 +227,7 @@ func runServe(options serveOptions) error { if client.ID != "" { return fmt.Errorf("invalid config: ID and IDEnv fields are exclusive for client %q", client.ID) } - c.StaticClients[i].ID = os.Getenv(client.IDEnv) + client.ID = os.Getenv(client.IDEnv) } if client.Secret == "" && client.SecretEnv == "" && !client.Public { return fmt.Errorf("invalid config: Secret or SecretEnv field is required for client %q", client.ID) @@ -234,11 +236,12 @@ func runServe(options serveOptions) error { if client.Secret != "" { return fmt.Errorf("invalid config: Secret and SecretEnv fields are exclusive for client %q", client.ID) } - c.StaticClients[i].Secret = os.Getenv(client.SecretEnv) + client.Secret = os.Getenv(client.SecretEnv) } logger.Info("config static client", "client_name", client.Name) + storageClients[i] = client } - s = storage.WithStaticClients(s, c.StaticClients) + s = storage.WithStaticClients(s, storageClients) } if len(c.StaticPasswords) > 0 { passwords := make([]storage.Password, len(c.StaticPasswords)) @@ -387,6 +390,7 @@ func runServe(options serveOptions) error { IDTokensValidFor: idTokensValidFor, MFAProviders: buildMFAProviders(c.MFA.Authenticators, logger), DefaultMFAChain: c.MFA.DefaultMFAChain, + TokenExchange: c.OAuth2.TokenExchange, } if c.Expiry.AuthRequests != "" { @@ -405,6 +409,30 @@ func runServe(options serveOptions) error { logger.Info("config device requests", "valid_for", deviceRequests) serverConfig.DeviceRequestsValidFor = deviceRequests } + if c.Expiry.IDJAGTokens != "" { + idJAGTokens, err := time.ParseDuration(c.Expiry.IDJAGTokens) + if err != nil { + return fmt.Errorf("invalid config value %q for ID-JAG token expiry: %v", c.Expiry.IDJAGTokens, err) + } + logger.Info("config ID-JAG tokens", "valid_for", idJAGTokens) + serverConfig.IDJAGTokensValidFor = idJAGTokens + } + + // Build per-client ID-JAG policies from static client config. + for _, sc := range c.StaticClients { + if sc.IDJAGPolicies != nil { + clientID := sc.Client.ID + if clientID == "" && sc.Client.IDEnv != "" { + clientID = os.Getenv(sc.Client.IDEnv) + } + serverConfig.IDJAGPolicies = append(serverConfig.IDJAGPolicies, server.TokenExchangePolicy{ + ClientID: clientID, + AllowedAudiences: sc.IDJAGPolicies.AllowedAudiences, + AllowedScopes: sc.IDJAGPolicies.AllowedScopes, + }) + } + } + refreshTokenPolicy, err := server.NewRefreshTokenPolicy( logger, c.Expiry.RefreshTokens.DisableRotation, diff --git a/config.yaml.dist b/config.yaml.dist index 807448de11..a6157ef29c 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -89,6 +89,7 @@ web: # deviceRequests: "5m" # signingKeys: "6h" # deprecated, use signer.config.keysRotationPeriod # idTokens: "24h" +# idJAGTokens: "5m" # default: 5m; independent of idTokens # refreshTokens: # disableRotation: false # reuseInterval: "3s" @@ -124,6 +125,14 @@ web: # enforce: false # # Supported code challenge methods. Defaults to ["S256", "plain"]. # codeChallengeMethodsSupported: ["S256", "plain"] +# +# # Token Exchange configuration +# tokenExchange: +# # List of token types enabled for exchange. Adding id-jag enables ID-JAG support. +# # Omitting it (default) disables ID-JAG without affecting other token exchange flows. +# tokenTypes: +# - urn:ietf:params:oauth:token-type:id_token +# - urn:ietf:params:oauth:token-type:id-jag # Static clients registered in Dex by default. # @@ -159,6 +168,21 @@ web: # allowedConnectors: # - github # - google +# +# # Example of a client with ID-JAG token exchange policy +# - id: wiki-app +# secret: wiki-secret +# redirectURIs: +# - 'https://wiki.example/callback' +# name: 'Wiki Application' +# # Per-client ID-JAG policy. Clients without this section cannot obtain ID-JAG tokens. +# idJAGPolicies: +# allowedAudiences: +# - "https://chat.example/" +# - "https://calendar.example/" +# allowedScopes: +# - "chat.read" +# - "calendar.read" # Connectors are used to authenticate users against upstream identity providers. # diff --git a/server/handlers.go b/server/handlers.go index 7bc04c48b0..3802142a8f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -74,21 +74,23 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { } type discovery struct { - Issuer string `json:"issuer"` - Auth string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` - Keys string `json:"jwks_uri"` - UserInfo string `json:"userinfo_endpoint"` - DeviceEndpoint string `json:"device_authorization_endpoint"` - Introspect string `json:"introspection_endpoint"` - GrantTypes []string `json:"grant_types_supported"` - ResponseTypes []string `json:"response_types_supported"` - Subjects []string `json:"subject_types_supported"` - IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` - CodeChallengeAlgs []string `json:"code_challenge_methods_supported"` - Scopes []string `json:"scopes_supported"` - AuthMethods []string `json:"token_endpoint_auth_methods_supported"` - Claims []string `json:"claims_supported"` + Issuer string `json:"issuer"` + Auth string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` + Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` + DeviceEndpoint string `json:"device_authorization_endpoint"` + Introspect string `json:"introspection_endpoint"` + GrantTypes []string `json:"grant_types_supported"` + ResponseTypes []string `json:"response_types_supported"` + Subjects []string `json:"subject_types_supported"` + IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` + CodeChallengeAlgs []string `json:"code_challenge_methods_supported"` + Scopes []string `json:"scopes_supported"` + AuthMethods []string `json:"token_endpoint_auth_methods_supported"` + Claims []string `json:"claims_supported"` + IDJAGSigningAlgs []string `json:"id_jag_signing_alg_values_supported,omitempty"` + IdentityChainingTokenTypes []string `json:"identity_chaining_requested_token_types_supported,omitempty"` } func (s *Server) discoveryHandler(ctx context.Context) (http.HandlerFunc, error) { @@ -134,6 +136,11 @@ func (s *Server) constructDiscovery(ctx context.Context) discovery { d.IDTokenAlgs = []string{string(signingAlg)} } + if s.enableIDJAG { + d.IDJAGSigningAlgs = d.IDTokenAlgs + d.IdentityChainingTokenTypes = []string{tokenTypeIDJAG} + } + for responseType := range s.supportedResponseTypes { d.ResponseTypes = append(d.ResponseTypes, responseType) } @@ -873,6 +880,13 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, userIdentity = &ui } + // Skip approval if user already consented to the requested scopes for this client. + if !authReq.ForceApprovalPrompt && userIdentity != nil { + if scopesCoveredByConsent(userIdentity.Consents[authReq.ClientID], authReq.Scopes) { + return "", true, nil + } + } + // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original // flow would be unable to poll for the result at the /approval endpoint h := hmac.New(sha256.New, authReq.HMACKey) @@ -1798,6 +1812,15 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli return } + if requestedTokenType == tokenTypeIDJAG { + if !s.enableIDJAG { + s.tokenErrHelper(w, errRequestNotSupported, "ID-JAG token exchange is not enabled on this server.", http.StatusBadRequest) + return + } + s.handleIDJAGExchange(w, r, client, subjectToken, subjectTokenType, connID, scopes) + return + } + conn, err := s.getConnector(ctx, connID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err) @@ -1858,6 +1881,193 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli json.NewEncoder(w).Encode(resp) } +// handleIDJAGExchange handles a Token Exchange request with requested_token_type=ID-JAG. +// See: https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/ +func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, client storage.Client, subjectToken, subjectTokenType string, connectorID string, scopes []string) { + ctx := r.Context() + q := r.Form + + audience := q.Get("audience") + resource := q.Get("resource") + requestedScope := strings.Join(scopes, " ") + + // Reject public clients (Section 8.1). + if client.Public { + s.idJAGReject(ctx, w, "rejected", errUnauthorizedClient, "Public clients cannot use ID-JAG token exchange.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "public_client") + return + } + + // connector_id is required for identifying the upstream connector. + if connectorID == "" { + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Missing required parameter connector_id for ID-JAG token exchange.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "missing_connector_id") + return + } + + if _, err := s.getConnector(ctx, connectorID); err != nil { + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "connector_not_found") + return + } + + // audience is required. + if audience == "" { + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Missing required parameter audience for ID-JAG token exchange.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "missing_audience") + return + } + + // subject_token_type must be id_token. + if subjectTokenType != tokenTypeID { + s.idJAGReject(ctx, w, "rejected", errRequestNotSupported, "ID-JAG token exchange requires subject_token_type=urn:ietf:params:oauth:token-type:id_token.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "invalid_subject_token_type") + return + } + + // Extract sub and aud from the subject_token. + sub, tokenAud, err := extractJWTSubAndAud(subjectToken) + if err != nil { + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Invalid subject_token: could not parse JWT claims.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "invalid_subject_token") + return + } + if sub == "" { + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "subject_token missing required sub claim.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "missing_sub") + return + } + + // Validate that the subject_token audience matches the requesting client (Section 4.3). + if !audContains(tokenAud, client.ID) { + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "subject_token audience does not match client_id.", http.StatusBadRequest, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "audience_mismatch") + return + } + + policyResult, policyErr := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes) + if policyResult.Denied { + if s.idJAGPolicyRejectionsTotal != nil { + s.idJAGPolicyRejectionsTotal.WithLabelValues(string(policyResult.DenialReason)).Inc() + } + s.idJAGReject(ctx, w, "denied", errAccessDenied, "", http.StatusForbidden, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", string(policyResult.DenialReason)) + return + } + if policyErr != nil { + s.idJAGReject(ctx, w, "rejected", errServerError, "", http.StatusInternalServerError, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "policy_error") + return + } + + grantedScopes := policyResult.GrantedScopes + grantedScope := strings.Join(grantedScopes, " ") + + scopeModified := requestedScope != grantedScope + if scopeModified && s.idJAGScopeModificationsTotal != nil { + s.idJAGScopeModificationsTotal.Inc() + } + + idJAGToken, jti, expiry, err := s.newIDJAG(ctx, client.ID, sub, audience, resource, grantedScopes) + if err != nil { + s.logger.ErrorContext(ctx, "failed to create ID-JAG token", "err", err) + s.idJAGReject(ctx, w, "rejected", errServerError, "", http.StatusInternalServerError, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "token_creation_failed") + return + } + + s.logger.InfoContext(ctx, "ID-JAG token issued", + "client_id", client.ID, + "connector_id", connectorID, + "audience", audience, + "resource", resource, + "requested_scope", requestedScope, + "granted_scope", grantedScope, + "sub", sub, + "jti", jti, + "decision", "approved", + ) + + if s.idJAGRequestsTotal != nil { + s.idJAGRequestsTotal.WithLabelValues("issued").Inc() + } + + // RFC 8693 §2.2.1: token_type is "N_A" for non-access tokens. + resp := accessTokenResponse{ + AccessToken: idJAGToken, + IssuedTokenType: tokenTypeIDJAG, + TokenType: "N_A", + ExpiresIn: int(time.Until(expiry).Seconds()), + } + + // Per Section 4.3.2: include scope in response if granted scope differs from requested. + if scopeModified { + resp.Scope = grantedScope + } + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// idJAGReject is a helper for ID-JAG rejection responses with structured logging and metrics. +func (s *Server) idJAGReject(ctx context.Context, w http.ResponseWriter, result string, errType, errDesc string, status int, logKVs ...any) { + logKVs = append(logKVs, "decision", "denied") + s.logger.InfoContext(ctx, "ID-JAG token exchange rejected", logKVs...) + + if s.idJAGRequestsTotal != nil { + s.idJAGRequestsTotal.WithLabelValues(result).Inc() + } + + s.tokenErrHelper(w, errType, errDesc, status) +} + +// extractJWTSubAndAud extracts the "sub" and "aud" claims from a JWT without +// verifying the signature. The aud claim may be a string or []string. +func extractJWTSubAndAud(token string) (sub string, aud []string, err error) { + parts := strings.SplitN(token, ".", 3) + if len(parts) != 3 { + return "", nil, fmt.Errorf("malformed JWT: expected 3 parts, got %d", len(parts)) + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", nil, fmt.Errorf("failed to decode JWT payload: %v", err) + } + + var claims struct { + Sub string `json:"sub"` + Aud json.RawMessage `json:"aud"` + } + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return "", nil, fmt.Errorf("failed to unmarshal JWT payload: %v", err) + } + + if len(claims.Aud) > 0 { + var single string + if err := json.Unmarshal(claims.Aud, &single); err == nil { + aud = []string{single} + } else { + var multi []string + if err := json.Unmarshal(claims.Aud, &multi); err == nil { + aud = multi + } + } + } + + return claims.Sub, aud, nil +} + +// audContains reports whether target is in aud. +func audContains(aud []string, target string) bool { + for _, a := range aud { + if a == target { + return true + } + } + return false +} + func (s *Server) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { ctx := r.Context() diff --git a/server/handlers_test.go b/server/handlers_test.go index 64f4035629..e71a966eb5 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -137,6 +137,49 @@ func TestHandleDiscoveryWithES256LocalSigner(t *testing.T) { require.Equal(t, []string{string(jose.ES256)}, res.IDTokenAlgs) } +// TestHandleDiscovery_IDJAG verifies OIDC discovery includes ID-JAG metadata when enabled. +func TestHandleDiscovery_IDJAG(t *testing.T) { + httpServer, server := newTestServer(t, func(c *Config) { + c.TokenExchange = TokenExchangeConfig{ + TokenTypes: []string{tokenTypeIDJAG}, + } + }) + defer httpServer.Close() + + rr := httptest.NewRecorder() + server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)) + require.Equal(t, http.StatusOK, rr.Code) + + var res discovery + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + + // Section 7: identity_chaining_requested_token_types_supported + require.Equal(t, []string{tokenTypeIDJAG}, res.IdentityChainingTokenTypes, + "discovery must include identity_chaining_requested_token_types_supported when ID-JAG is enabled") + + // id_jag_signing_alg_values_supported must match ID token signing algs. + require.Equal(t, res.IDTokenAlgs, res.IDJAGSigningAlgs, + "discovery must include id_jag_signing_alg_values_supported matching ID token algs") +} + +// TestHandleDiscovery_IDJAGDisabled verifies OIDC discovery omits ID-JAG metadata when disabled. +func TestHandleDiscovery_IDJAGDisabled(t *testing.T) { + httpServer, server := newTestServer(t, nil) // ID-JAG not enabled + defer httpServer.Close() + + rr := httptest.NewRecorder() + server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)) + require.Equal(t, http.StatusOK, rr.Code) + + var res discovery + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + + require.Empty(t, res.IdentityChainingTokenTypes, + "discovery must NOT include identity_chaining_requested_token_types_supported when ID-JAG is disabled") + require.Empty(t, res.IDJAGSigningAlgs, + "discovery must NOT include id_jag_signing_alg_values_supported when ID-JAG is disabled") +} + func TestHandleHealthFailure(t *testing.T) { httpServer, server := newTestServer(t, func(c *Config) { c.HealthChecker = gosundheit.New() @@ -1761,6 +1804,572 @@ func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scop return m.refreshIdentity, nil } +// makeTestJWT builds a minimal JWT with the given sub for testing. +// The audience defaults to "client_1". +func makeTestJWT(sub string) string { + return makeTestJWTWithClaims(sub, "client_1") +} + +// makeTestJWTWithClaims builds a JWT with configurable sub and aud. +func makeTestJWTWithClaims(sub, aud string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + claimsJSON := fmt.Sprintf(`{"sub":"%s","iss":"https://issuer.example","aud":"%s","exp":9999999999}`, sub, aud) + payload := base64.RawURLEncoding.EncodeToString([]byte(claimsJSON)) + return header + "." + payload + ".fakesig" +} + +// decodeJWTPayload decodes the payload section of a compact JWT (without signature verification). +func decodeJWTPayload(t *testing.T, token string) map[string]interface{} { + t.Helper() + parts := strings.Split(token, ".") + require.Equal(t, 3, len(parts), "expected compact JWT with 3 parts") + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + var claims map[string]interface{} + require.NoError(t, json.Unmarshal(payloadBytes, &claims)) + return claims +} + +// decodeJWTHeader decodes the header section of a compact JWT. +func decodeJWTHeader(t *testing.T, token string) map[string]interface{} { + t.Helper() + parts := strings.Split(token, ".") + require.Equal(t, 3, len(parts), "expected compact JWT with 3 parts") + headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) + require.NoError(t, err) + var header map[string]interface{} + require.NoError(t, json.Unmarshal(headerBytes, &header)) + return header +} + +// TestExtractJWTSubAndAud tests extractJWTSubAndAud. +func TestExtractJWTSubAndAud(t *testing.T) { + tests := []struct { + name string + token string + wantSub string + wantAud []string + wantErr bool + }{ + { + name: "valid JWT returns sub and aud", + token: makeTestJWT("user-abc-123"), + wantSub: "user-abc-123", + wantAud: []string{"client_1"}, + }, + { + name: "not a JWT (no dots)", + token: "notajwt", + wantErr: true, + }, + { + name: "invalid base64 payload", + token: "aGVhZGVy.!!!.c2ln", + wantErr: true, + }, + { + name: "valid JWT without sub returns empty string", + token: func() string { + h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + p := base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"https://issuer.example"}`)) + return h + "." + p + ".sig" + }(), + wantSub: "", + wantAud: nil, + wantErr: false, + }, + { + name: "aud as array", + token: func() string { + h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + p := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"u1","aud":["a","b"]}`)) + return h + "." + p + ".sig" + }(), + wantSub: "u1", + wantAud: []string{"a", "b"}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sub, aud, err := extractJWTSubAndAud(tc.token) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.wantSub, sub) + require.Equal(t, tc.wantAud, aud) + }) + } +} + +// TestHandleIDJAGExchange_JWTClaims verifies the issued ID-JAG JWT contains all +// required claims per the spec (iss, sub, aud, client_id, jti, exp, iat) and +// uses the correct typ header (oauth-id-jag+jwt). +func TestHandleIDJAGExchange_JWTClaims(t *testing.T) { + subjectToken := makeTestJWT("user-123") + + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{ + TokenTypes: []string{tokenTypeIDJAG}, + } + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://resource-as.example.com"}}, + } + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://resource-as.example.com") + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var res accessTokenResponse + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + + // Response-level checks. + require.Equal(t, "N_A", res.TokenType) + require.Equal(t, tokenTypeIDJAG, res.IssuedTokenType) + require.NotEmpty(t, res.AccessToken) + + // Verify JWT header. + header := decodeJWTHeader(t, res.AccessToken) + require.Equal(t, "oauth-id-jag+jwt", header["typ"], "JWT typ header must be oauth-id-jag+jwt") + require.Equal(t, "RS256", header["alg"]) + + // Verify JWT payload claims. + claims := decodeJWTPayload(t, res.AccessToken) + require.Equal(t, httpServer.URL, claims["iss"], "iss must match server issuer") + require.Equal(t, "user-123", claims["sub"], "sub must be preserved from subject_token") + require.Equal(t, "https://resource-as.example.com", claims["aud"], "aud must be the requested audience") + require.Equal(t, "client_1", claims["client_id"], "client_id must be the requesting client") + require.NotEmpty(t, claims["jti"], "jti must be present") + require.NotZero(t, claims["exp"], "exp must be set") + require.NotZero(t, claims["iat"], "iat must be set") + + // Verify expires_in is approximately 5 minutes (default). + require.InDelta(t, 300, res.ExpiresIn, 5, "expires_in should be ~300s (5m default)") +} + +// TestHandleIDJAGExchange_ResourceAndScope verifies that the resource parameter +// and scopes are correctly passed through to the JWT claims, and that scope +// reduction by policy produces the scope field in the response. +func TestHandleIDJAGExchange_ResourceAndScope(t *testing.T) { + subjectToken := makeTestJWT("user-456") + + t.Run("resource parameter appears in JWT", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://chat.example/"}}, + } + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://chat.example/") + vals.Set("resource", "https://chat.example/api/v1") + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var res accessTokenResponse + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + claims := decodeJWTPayload(t, res.AccessToken) + require.Equal(t, "https://chat.example/api/v1", claims["resource"], "resource claim must match request") + }) + + t.Run("scope in JWT and response when all scopes allowed", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://chat.example/"}, AllowedScopes: []string{"chat.read", "chat.write"}}, + } + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://chat.example/") + vals.Set("scope", "chat.read chat.write") + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var res accessTokenResponse + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + + claims := decodeJWTPayload(t, res.AccessToken) + require.Equal(t, "chat.read chat.write", claims["scope"], "scope claim must contain granted scopes") + // When all requested scopes are granted, scope should NOT appear in response. + require.Empty(t, res.Scope, "scope in response should be empty when identical to requested") + }) + + t.Run("policy reduces scopes: scope in response and JWT reflect granted only", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + // Policy only allows chat.read, not chat.write. + {ClientID: "client_1", AllowedAudiences: []string{"https://chat.example/"}, AllowedScopes: []string{"chat.read"}}, + } + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://chat.example/") + vals.Set("scope", "chat.read chat.write") // request both; only chat.read should be granted + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var res accessTokenResponse + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + + // Response must include scope field when granted != requested (Section 4.3.2). + require.Equal(t, "chat.read", res.Scope, "response scope must contain only granted scopes") + + claims := decodeJWTPayload(t, res.AccessToken) + require.Equal(t, "chat.read", claims["scope"], "JWT scope claim must contain only granted scopes") + }) +} + +// TestHandleIDJAGExchange_SecurityBoundaries verifies security-critical rejection paths. +func TestHandleIDJAGExchange_SecurityBoundaries(t *testing.T) { + t.Run("public client rejected (Section 8.1)", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "public_client", + Public: true, + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "public_client", AllowedAudiences: []string{"https://resource.example.com"}}, + } + }) + defer httpServer.Close() + + subjectToken := makeTestJWTWithClaims("user-1", "public_client") + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://resource.example.com") + vals.Set("client_id", "public_client") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusBadRequest, rr.Code) + require.Contains(t, rr.Body.String(), "unauthorized_client") + }) + + t.Run("subject_token audience mismatch with client_id", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://resource.example.com"}}, + } + }) + defer httpServer.Close() + + // Subject token has aud="other_client", but we authenticate as client_1. + subjectToken := makeTestJWTWithClaims("user-1", "other_client") + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://resource.example.com") + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusBadRequest, rr.Code, "body: %s", rr.Body.String()) + require.Contains(t, rr.Body.String(), "invalid_request") + }) + + t.Run("default-deny: no policy configured returns 403", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + // No IDJAGPolicies — should be denied. + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", makeTestJWT("user-1")) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://resource.example.com") + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusForbidden, rr.Code) + }) + + t.Run("policy denies audience", func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://other.example.com"}}, + } + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", makeTestJWT("user-1")) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://resource-as.example.com") // not in allowed list + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + require.Equal(t, http.StatusForbidden, rr.Code) + }) +} + +// TestHandleIDJAGExchange_ValidationErrors verifies parameter validation. +func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { + subjectToken := makeTestJWT("user-123") + + tests := []struct { + name string + audience string + connectorID string + subjectTokenType string + subjectToken string + enableIDJAG bool + wantCode int + wantErrContains string + }{ + { + name: "missing audience returns 400", + audience: "", + connectorID: "mock", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + enableIDJAG: true, + wantCode: http.StatusBadRequest, + }, + { + name: "wrong subject_token_type returns 400", + audience: "https://resource.example.com", + connectorID: "mock", + subjectTokenType: tokenTypeAccess, + subjectToken: subjectToken, + enableIDJAG: true, + wantCode: http.StatusBadRequest, + }, + { + name: "missing connector_id returns 400", + audience: "https://resource.example.com", + connectorID: "", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + enableIDJAG: true, + wantCode: http.StatusBadRequest, + wantErrContains: "connector_id", + }, + { + name: "nonexistent connector_id returns 400", + audience: "https://resource.example.com", + connectorID: "nonexistent", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + enableIDJAG: true, + wantCode: http.StatusBadRequest, + }, + { + name: "ID-JAG disabled returns 400", + audience: "https://resource.example.com", + connectorID: "mock", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + enableIDJAG: false, + wantCode: http.StatusBadRequest, + wantErrContains: "not enabled", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + if tc.enableIDJAG { + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://resource.example.com"}}, + } + } + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tc.subjectTokenType) + vals.Set("subject_token", tc.subjectToken) + if tc.connectorID != "" { + vals.Set("connector_id", tc.connectorID) + } + if tc.audience != "" { + vals.Set("audience", tc.audience) + } + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + + require.Equal(t, tc.wantCode, rr.Code, "body: %s", rr.Body.String()) + if tc.wantErrContains != "" { + require.Contains(t, rr.Body.String(), tc.wantErrContains) + } + }) + } +} + +// TestHandleIDJAGExchange_CustomExpiry verifies that IDJAGTokensValidFor is honored. +func TestHandleIDJAGExchange_CustomExpiry(t *testing.T) { + subjectToken := makeTestJWT("user-789") + + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{TokenTypes: []string{tokenTypeIDJAG}} + c.IDJAGPolicies = []TokenExchangePolicy{ + {ClientID: "client_1", AllowedAudiences: []string{"https://resource.example.com"}}, + } + c.IDJAGTokensValidFor = 10 * time.Minute // custom: 10 minutes instead of default 5 + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tokenTypeID) + vals.Set("subject_token", subjectToken) + vals.Set("connector_id", "mock") + vals.Set("audience", "https://resource.example.com") + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String()) + + var res accessTokenResponse + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + require.InDelta(t, 600, res.ExpiresIn, 5, "expires_in should be ~600s (10m custom)") +} + func TestFilterConnectors(t *testing.T) { connectors := []storage.Connector{ {ID: "github", Type: "github", Name: "GitHub"}, diff --git a/server/oauth2.go b/server/oauth2.go index 40f8063b3f..fcc394033d 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -182,6 +182,8 @@ const ( tokenTypeSAML1 = "urn:ietf:params:oauth:token-type:saml1" tokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2" tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" + // https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/ + tokenTypeIDJAG = "urn:ietf:params:oauth:token-type:id-jag" ) const ( @@ -300,6 +302,62 @@ func (s *Server) newAccessToken(ctx context.Context, clientID string, claims sto return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID, authTime) } +// idJAGTyp is the JWT "typ" header value for ID-JAG tokens. +const idJAGTyp = "oauth-id-jag+jwt" + +// idJAGClaims is the JWT payload for an ID-JAG token. +type idJAGClaims struct { + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience string `json:"aud"` + ClientID string `json:"client_id"` + JTI string `json:"jti"` + Expiry int64 `json:"exp"` + IssuedAt int64 `json:"iat"` + + Resource string `json:"resource,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// newIDJAG creates an ID-JAG token with the given subject and audience. +func (s *Server) newIDJAG( + ctx context.Context, + clientID string, + subject string, + audience string, + resource string, + scopes []string, +) (token string, jti string, expiry time.Time, err error) { + issuedAt := s.now() + expiry = issuedAt.Add(s.idJAGTokensValidFor) + + jti = storage.NewID() + claims := idJAGClaims{ + Issuer: s.issuerURL.String(), + Subject: subject, + Audience: audience, + ClientID: clientID, + JTI: jti, + Expiry: expiry.Unix(), + IssuedAt: issuedAt.Unix(), + Resource: resource, + } + + if len(scopes) > 0 { + claims.Scope = strings.Join(scopes, " ") + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", "", expiry, fmt.Errorf("could not serialize ID-JAG claims: %v", err) + } + + if token, err = s.signer.SignWithType(ctx, payload, idJAGTyp); err != nil { + return "", "", expiry, fmt.Errorf("failed to sign ID-JAG payload: %v", err) + } + return token, jti, expiry, nil +} + func getClientID(aud audience, azp string) (string, error) { switch len(aud) { case 0: diff --git a/server/policy.go b/server/policy.go new file mode 100644 index 0000000000..577023e166 --- /dev/null +++ b/server/policy.go @@ -0,0 +1,98 @@ +package server + +import ( + "fmt" +) + +// PolicyDenialReason categorizes why an ID-JAG policy check failed. +type PolicyDenialReason string + +const ( + PolicyDenialClientHasNoPolicy PolicyDenialReason = "client_has_no_policy" + PolicyDenialAudienceNotAllowed PolicyDenialReason = "audience_not_allowed" +) + +// PolicyResult holds the outcome of an ID-JAG policy evaluation. +type PolicyResult struct { + Denied bool + DenialReason PolicyDenialReason + // GrantedScopes is the set of scopes that passed policy evaluation. + // May be smaller than the requested scopes if policy restricts them. + GrantedScopes []string +} + +// TokenExchangePolicy defines per-client access control for ID-JAG token exchange. +type TokenExchangePolicy struct { + // ClientID is the client this policy applies to. Use "*" for a default policy. + ClientID string `json:"clientID"` + AllowedAudiences []string `json:"allowedAudiences"` + AllowedScopes []string `json:"allowedScopes"` +} + +// evaluateIDJAGPolicy checks whether the client is permitted to obtain an ID-JAG +// for the given audience and scopes. Clients without a matching policy are denied +// by default (default-deny). +func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience string, scopes []string) (PolicyResult, error) { + // Find the most-specific policy for this client: exact match first, then wildcard. + var matched *TokenExchangePolicy + for i := range policies { + p := &policies[i] + if p.ClientID == clientID { + matched = p + break + } + if p.ClientID == "*" && matched == nil { + matched = p + } + } + + if matched == nil { + return PolicyResult{ + Denied: true, + DenialReason: PolicyDenialClientHasNoPolicy, + }, fmt.Errorf("no ID-JAG policy found for client %q: access_denied", clientID) + } + + // Check audience. + if !audienceAllowed(matched.AllowedAudiences, audience) { + return PolicyResult{ + Denied: true, + DenialReason: PolicyDenialAudienceNotAllowed, + }, fmt.Errorf("audience %q is not allowed for client %q: access_denied", audience, clientID) + } + + // Filter scopes: if the policy restricts scopes, only grant those that are allowed. + grantedScopes := scopes + if len(matched.AllowedScopes) > 0 && len(scopes) > 0 { + var filtered []string + for _, scope := range scopes { + if scopeAllowed(matched.AllowedScopes, scope) { + filtered = append(filtered, scope) + } + } + grantedScopes = filtered + } + + return PolicyResult{ + Denied: false, + GrantedScopes: grantedScopes, + }, nil +} + +func audienceAllowed(allowedAudiences []string, audience string) bool { + for _, a := range allowedAudiences { + if a == audience { + return true + } + } + return false +} + +func scopeAllowed(allowedScopes []string, scope string) bool { + for _, s := range allowedScopes { + if s == scope { + return true + } + } + return false +} diff --git a/server/policy_test.go b/server/policy_test.go new file mode 100644 index 0000000000..39e35ba979 --- /dev/null +++ b/server/policy_test.go @@ -0,0 +1,127 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEvaluateIDJAGPolicy(t *testing.T) { + tests := []struct { + name string + policies []TokenExchangePolicy + clientID string + audience string + scopes []string + wantErr bool + wantDenialReason PolicyDenialReason + wantGrantedScopes []string + }{ + { + name: "no policies: default-deny", + policies: nil, + clientID: "any-client", + audience: "https://resource.example.com", + wantErr: true, + wantDenialReason: PolicyDenialClientHasNoPolicy, + }, + { + name: "exact match allowed", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "audience not allowed", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://other.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + wantErr: true, + wantDenialReason: PolicyDenialAudienceNotAllowed, + }, + { + name: "client not found: denied", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "unknown-client", + audience: "https://resource.example.com", + wantErr: true, + wantDenialReason: PolicyDenialClientHasNoPolicy, + }, + { + name: "wildcard client matches", + policies: []TokenExchangePolicy{ + {ClientID: "*", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "any-client", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "exact match takes priority over wildcard", + policies: []TokenExchangePolicy{ + {ClientID: "*", AllowedAudiences: []string{"https://other.example.com"}}, + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "scope filtered by policy", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}, AllowedScopes: []string{"read"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + scopes: []string{"read", "admin"}, + wantErr: false, + wantGrantedScopes: []string{"read"}, + }, + { + name: "allowed scope passes", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}, AllowedScopes: []string{"read", "write"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + scopes: []string{"read"}, + wantErr: false, + wantGrantedScopes: []string{"read"}, + }, + { + name: "no scope restriction: all scopes granted", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + scopes: []string{"anything"}, + wantErr: false, + wantGrantedScopes: []string{"anything"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := evaluateIDJAGPolicy(tc.policies, tc.clientID, tc.audience, tc.scopes) + if tc.wantErr { + require.Error(t, err) + require.True(t, result.Denied) + require.Equal(t, tc.wantDenialReason, result.DenialReason) + } else { + require.NoError(t, err) + require.False(t, result.Denied) + if tc.wantGrantedScopes != nil { + require.Equal(t, tc.wantGrantedScopes, result.GrantedScopes) + } + } + }) + } +} diff --git a/server/server.go b/server/server.go index 4d55d5eeff..e57bbae1fd 100644 --- a/server/server.go +++ b/server/server.go @@ -107,6 +107,7 @@ type Config struct { AlwaysShowLoginScreen bool IDTokensValidFor time.Duration // Defaults to 24 hours + IDJAGTokensValidFor time.Duration // Defaults to 5 minutes AuthRequestsValidFor time.Duration // Defaults to 24 hours DeviceRequestsValidFor time.Duration // Defaults to 5 minutes @@ -139,6 +140,11 @@ type Config struct { // This allows the server to operate with a subset of connectors if some are misconfigured. ContinueOnConnectorFailure bool + // TokenExchange configures Token Exchange support. + TokenExchange TokenExchangeConfig + + IDJAGPolicies []TokenExchangePolicy + // SessionConfig holds session settings. Nil when sessions are disabled. SessionConfig *SessionConfig @@ -149,6 +155,21 @@ type Config struct { DefaultMFAChain []string } +// TokenExchangeConfig holds configuration for Token Exchange support. +type TokenExchangeConfig struct { + TokenTypes []string `json:"tokenTypes"` +} + +// IDJAGEnabled reports whether the ID-JAG token type is enabled. +func (c TokenExchangeConfig) IDJAGEnabled() bool { + for _, t := range c.TokenTypes { + if t == "urn:ietf:params:oauth:token-type:id-jag" { + return true + } + } + return false +} + // SessionConfig holds resolved session configuration. type SessionConfig struct { CookieName string @@ -246,6 +267,15 @@ type Server struct { signer signer.Signer + enableIDJAG bool + idJAGTokensValidFor time.Duration + tokenExchangePolicies []TokenExchangePolicy + + // ID-JAG Prometheus metrics (nil when PrometheusRegistry is not set). + idJAGRequestsTotal *prometheus.CounterVec + idJAGPolicyRejectionsTotal *prometheus.CounterVec + idJAGScopeModificationsTotal prometheus.Counter + sessionConfig *SessionConfig mfaProviders map[string]MFAProvider @@ -355,6 +385,8 @@ func newServer(ctx context.Context, c Config) (*Server, error) { now = time.Now } + idJAGTokensValidFor := value(c.IDJAGTokensValidFor, 5*time.Minute) + s := &Server{ issuerURL: *issuerURL, connectors: make(map[string]Connector), @@ -373,6 +405,9 @@ func newServer(ctx context.Context, c Config) (*Server, error) { passwordConnector: c.PasswordConnector, logger: c.Logger, signer: c.Signer, + enableIDJAG: c.TokenExchange.IDJAGEnabled(), + idJAGTokensValidFor: idJAGTokensValidFor, + tokenExchangePolicies: c.IDJAGPolicies, sessionConfig: c.SessionConfig, mfaProviders: c.MFAProviders, defaultMFAChain: c.DefaultMFAChain, @@ -433,6 +468,25 @@ func newServer(ctx context.Context, c Config) (*Server, error) { c.PrometheusRegistry.MustRegister(requestCounter, durationHist, sizeHist) + // ID-JAG metrics. + s.idJAGRequestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "dex_id_jag_requests_total", + Help: "Total number of ID-JAG token exchange requests.", + }, []string{"result"}) + s.idJAGPolicyRejectionsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "dex_id_jag_policy_rejections_total", + Help: "Total number of ID-JAG policy rejections by reason.", + }, []string{"reason"}) + s.idJAGScopeModificationsTotal = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "dex_id_jag_scope_modifications_total", + Help: "Total number of ID-JAG requests where policy reduced the requested scopes.", + }) + c.PrometheusRegistry.MustRegister( + s.idJAGRequestsTotal, + s.idJAGPolicyRejectionsTotal, + s.idJAGScopeModificationsTotal, + ) + instrumentHandler = func(handlerName string, handler http.Handler) http.HandlerFunc { return promhttp.InstrumentHandlerDuration(durationHist.MustCurryWith(prometheus.Labels{"handler": handlerName}), promhttp.InstrumentHandlerCounter(requestCounter.MustCurryWith(prometheus.Labels{"handler": handlerName}), diff --git a/server/signer/local.go b/server/signer/local.go index 30fd37d31e..afddb80b37 100644 --- a/server/signer/local.go +++ b/server/signer/local.go @@ -115,6 +115,24 @@ func (l *localSigner) Sign(ctx context.Context, payload []byte) (string, error) return signPayload(signingKey, signingAlg, payload) } +func (l *localSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) { + keys, err := l.storage.GetKeys(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys: %v", err) + } + + signingKey := keys.SigningKey + if signingKey == nil { + return "", fmt.Errorf("no key to sign payload with") + } + signingAlg, err := signatureAlgorithm(signingKey) + if err != nil { + return "", err + } + + return signPayloadWithType(signingKey, signingAlg, payload, tokenType) +} + func (l *localSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { keys, err := l.storage.GetKeys(ctx) if err != nil { diff --git a/server/signer/mock.go b/server/signer/mock.go index 832a9be87c..cfdf00471b 100644 --- a/server/signer/mock.go +++ b/server/signer/mock.go @@ -59,6 +59,10 @@ func (m *mockSigner) Sign(_ context.Context, payload []byte) (string, error) { return signPayload(m.key, jose.RS256, payload) } +func (m *mockSigner) SignWithType(_ context.Context, payload []byte, tokenType string) (string, error) { + return signPayloadWithType(m.key, jose.RS256, payload, tokenType) +} + func (m *mockSigner) ValidationKeys(_ context.Context) ([]*jose.JSONWebKey, error) { return []*jose.JSONWebKey{m.pubKey}, nil } diff --git a/server/signer/signer.go b/server/signer/signer.go index 1e15bbd196..801aab9fca 100644 --- a/server/signer/signer.go +++ b/server/signer/signer.go @@ -10,6 +10,8 @@ import ( type Signer interface { // Sign signs the provided payload. Sign(ctx context.Context, payload []byte) (string, error) + // SignWithType signs the provided payload with a custom JWT "typ" header. + SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) // ValidationKeys returns the current public keys used for signature validation. ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) // Algorithm returns the signing algorithm used by this signer. diff --git a/server/signer/utils.go b/server/signer/utils.go index 92926d5b57..111c865e15 100644 --- a/server/signer/utils.go +++ b/server/signer/utils.go @@ -72,3 +72,20 @@ func signPayload(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []by } return signature.CompactSerialize() } + +func signPayloadWithType(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []byte, tokenType string) (jws string, err error) { + signingKey := jose.SigningKey{Key: key, Algorithm: alg} + + opts := &jose.SignerOptions{} + opts.WithType(jose.ContentType(tokenType)) + + signer, err := jose.NewSigner(signingKey, opts) + if err != nil { + return "", fmt.Errorf("new signer: %v", err) + } + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + return signature.CompactSerialize() +} diff --git a/server/signer/vault.go b/server/signer/vault.go index ba175f2775..a073e8639c 100644 --- a/server/signer/vault.go +++ b/server/signer/vault.go @@ -179,6 +179,88 @@ func (v *vaultSigner) Sign(ctx context.Context, payload []byte) (string, error) return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil } +func (v *vaultSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) { + // 1. Fetch keys to determine the key to use (latest version) and its ID. + keysMap, latestVersion, err := v.getTransitKeysMap(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys for signing context: %v", err) + } + + // Determine the key version and ID to use + signingJWK, ok := keysMap[latestVersion] + if !ok { + return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion) + } + + // 2. Construct JWS Header with custom typ and Payload first (Signing Input) + header := map[string]interface{}{ + "alg": signingJWK.Algorithm, + "kid": signingJWK.KeyID, + "typ": tokenType, + } + + headerBytes, err := json.Marshal(header) + if err != nil { + return "", fmt.Errorf("failed to marshal header: %v", err) + } + + headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes) + payloadB64 := base64.RawURLEncoding.EncodeToString(payload) + + // The input to the signature is "header.payload" + signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64) + + // 3. Sign the signingInput using Vault + var vaultInput string + data := map[string]interface{}{} + + // Determine Vault params based on JWS algorithm + params, err := getVaultParams(signingJWK.Algorithm) + if err != nil { + return "", err + } + + // Apply params to data map + for k, v := range params.extraParams { + data[k] = v + } + + // Hash input if needed + if params.hasher != nil { + params.hasher.Write([]byte(signingInput)) + hash := params.hasher.Sum(nil) + vaultInput = base64.StdEncoding.EncodeToString(hash) + } else { + // No pre-hashing (EdDSA) + vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput)) + } + data["input"] = vaultInput + + signPath := fmt.Sprintf("transit/sign/%s", v.keyName) + signSecret, err := v.client.Logical().WriteWithContext(ctx, signPath, data) + if err != nil { + return "", fmt.Errorf("vault sign: %v", err) + } + + signatureString, ok := signSecret.Data["signature"].(string) + if !ok { + return "", fmt.Errorf("vault response missing signature") + } + + // Parse vault signature: "vault:v1:base64sig" + var signatureB64 []byte + if len(signatureString) > 8 && signatureString[:6] == "vault:" { + parts := splitVaultSignature(signatureString) + if len(parts) == 3 { + signatureB64 = []byte(parts[2]) + } + } else { + return "", fmt.Errorf("unexpected signature format: %s", signatureString) + } + + return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil +} + func (v *vaultSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { keysMap, _, err := v.getTransitKeysMap(ctx) if err != nil { From 76191d7b9780b6d357a63bf5fb08f4419f65b763 Mon Sep 17 00:00:00 2001 From: kanywst Date: Tue, 24 Mar 2026 19:51:01 +0900 Subject: [PATCH 2/9] fix(oauth2): verify subject_token signature and expiry in ID-JAG exchange Signed-off-by: kanywst --- server/handlers.go | 80 +++--------------------- server/handlers_test.go | 131 ++++++++++++---------------------------- 2 files changed, 49 insertions(+), 162 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index 3802142a8f..496cb97c2d 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -880,13 +880,6 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, userIdentity = &ui } - // Skip approval if user already consented to the requested scopes for this client. - if !authReq.ForceApprovalPrompt && userIdentity != nil { - if scopesCoveredByConsent(userIdentity.Consents[authReq.ClientID], authReq.Scopes) { - return "", true, nil - } - } - // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original // flow would be unable to poll for the result at the /approval endpoint h := hmac.New(sha256.New, authReq.HMACKey) @@ -1925,27 +1918,22 @@ func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, cli return } - // Extract sub and aud from the subject_token. - sub, tokenAud, err := extractJWTSubAndAud(subjectToken) + // Verify the subject_token signature and expiry against this server's signing keys. + verifier := oidc.NewVerifier(s.issuerURL.String(), &signerKeySet{s.signer}, &oidc.Config{ClientID: client.ID}) + idToken, err := verifier.Verify(ctx, subjectToken) if err != nil { - s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Invalid subject_token: could not parse JWT claims.", http.StatusBadRequest, + s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Invalid subject_token.", http.StatusBadRequest, "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "invalid_subject_token") return } - if sub == "" { - s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "subject_token missing required sub claim.", http.StatusBadRequest, - "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "missing_sub") - return - } + sub := idToken.Subject - // Validate that the subject_token audience matches the requesting client (Section 4.3). - if !audContains(tokenAud, client.ID) { - s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "subject_token audience does not match client_id.", http.StatusBadRequest, - "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "audience_mismatch") + policyResult, err := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes) + if err != nil { + s.idJAGReject(ctx, w, "rejected", errServerError, "", http.StatusInternalServerError, + "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "policy_error") return } - - policyResult, policyErr := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes) if policyResult.Denied { if s.idJAGPolicyRejectionsTotal != nil { s.idJAGPolicyRejectionsTotal.WithLabelValues(string(policyResult.DenialReason)).Inc() @@ -1954,11 +1942,6 @@ func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, cli "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", string(policyResult.DenialReason)) return } - if policyErr != nil { - s.idJAGReject(ctx, w, "rejected", errServerError, "", http.StatusInternalServerError, - "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "policy_error") - return - } grantedScopes := policyResult.GrantedScopes grantedScope := strings.Join(grantedScopes, " ") @@ -2023,51 +2006,6 @@ func (s *Server) idJAGReject(ctx context.Context, w http.ResponseWriter, result s.tokenErrHelper(w, errType, errDesc, status) } -// extractJWTSubAndAud extracts the "sub" and "aud" claims from a JWT without -// verifying the signature. The aud claim may be a string or []string. -func extractJWTSubAndAud(token string) (sub string, aud []string, err error) { - parts := strings.SplitN(token, ".", 3) - if len(parts) != 3 { - return "", nil, fmt.Errorf("malformed JWT: expected 3 parts, got %d", len(parts)) - } - payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return "", nil, fmt.Errorf("failed to decode JWT payload: %v", err) - } - - var claims struct { - Sub string `json:"sub"` - Aud json.RawMessage `json:"aud"` - } - if err := json.Unmarshal(payloadBytes, &claims); err != nil { - return "", nil, fmt.Errorf("failed to unmarshal JWT payload: %v", err) - } - - if len(claims.Aud) > 0 { - var single string - if err := json.Unmarshal(claims.Aud, &single); err == nil { - aud = []string{single} - } else { - var multi []string - if err := json.Unmarshal(claims.Aud, &multi); err == nil { - aud = multi - } - } - } - - return claims.Sub, aud, nil -} - -// audContains reports whether target is in aud. -func audContains(aud []string, target string) bool { - for _, a := range aud { - if a == target { - return true - } - } - return false -} - func (s *Server) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { ctx := r.Context() diff --git a/server/handlers_test.go b/server/handlers_test.go index e71a966eb5..3b5b975d88 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -1804,18 +1804,34 @@ func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scop return m.refreshIdentity, nil } -// makeTestJWT builds a minimal JWT with the given sub for testing. -// The audience defaults to "client_1". -func makeTestJWT(sub string) string { - return makeTestJWTWithClaims(sub, "client_1") -} +// makeTestJWT builds a properly signed ID token JWT for testing. +// The token is signed with testKey and has aud=clientID, iss=issuerURL. +func makeTestJWT(t *testing.T, issuerURL, sub, clientID string) string { + t.Helper() + claims := struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Aud string `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + }{ + Iss: issuerURL, + Sub: sub, + Aud: clientID, + Exp: time.Now().Add(time.Hour).Unix(), + Iat: time.Now().Unix(), + } + payload, err := json.Marshal(claims) + require.NoError(t, err) -// makeTestJWTWithClaims builds a JWT with configurable sub and aud. -func makeTestJWTWithClaims(sub, aud string) string { - header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) - claimsJSON := fmt.Sprintf(`{"sub":"%s","iss":"https://issuer.example","aud":"%s","exp":9999999999}`, sub, aud) - payload := base64.RawURLEncoding.EncodeToString([]byte(claimsJSON)) - return header + "." + payload + ".fakesig" + key := &jose.JSONWebKey{Key: testKey, Algorithm: "RS256"} + s, err := jose.NewSigner(jose.SigningKey{Key: key, Algorithm: jose.RS256}, &jose.SignerOptions{}) + require.NoError(t, err) + jws, err := s.Sign(payload) + require.NoError(t, err) + token, err := jws.CompactSerialize() + require.NoError(t, err) + return token } // decodeJWTPayload decodes the payload section of a compact JWT (without signature verification). @@ -1842,73 +1858,10 @@ func decodeJWTHeader(t *testing.T, token string) map[string]interface{} { return header } -// TestExtractJWTSubAndAud tests extractJWTSubAndAud. -func TestExtractJWTSubAndAud(t *testing.T) { - tests := []struct { - name string - token string - wantSub string - wantAud []string - wantErr bool - }{ - { - name: "valid JWT returns sub and aud", - token: makeTestJWT("user-abc-123"), - wantSub: "user-abc-123", - wantAud: []string{"client_1"}, - }, - { - name: "not a JWT (no dots)", - token: "notajwt", - wantErr: true, - }, - { - name: "invalid base64 payload", - token: "aGVhZGVy.!!!.c2ln", - wantErr: true, - }, - { - name: "valid JWT without sub returns empty string", - token: func() string { - h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) - p := base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"https://issuer.example"}`)) - return h + "." + p + ".sig" - }(), - wantSub: "", - wantAud: nil, - wantErr: false, - }, - { - name: "aud as array", - token: func() string { - h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) - p := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"u1","aud":["a","b"]}`)) - return h + "." + p + ".sig" - }(), - wantSub: "u1", - wantAud: []string{"a", "b"}, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - sub, aud, err := extractJWTSubAndAud(tc.token) - if tc.wantErr { - require.Error(t, err) - return - } - require.NoError(t, err) - require.Equal(t, tc.wantSub, sub) - require.Equal(t, tc.wantAud, aud) - }) - } -} - // TestHandleIDJAGExchange_JWTClaims verifies the issued ID-JAG JWT contains all // required claims per the spec (iss, sub, aud, client_id, jti, exp, iat) and // uses the correct typ header (oauth-id-jag+jwt). func TestHandleIDJAGExchange_JWTClaims(t *testing.T) { - subjectToken := makeTestJWT("user-123") - ctx := t.Context() httpServer, s := newTestServer(t, func(c *Config) { require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ @@ -1924,6 +1877,8 @@ func TestHandleIDJAGExchange_JWTClaims(t *testing.T) { }) defer httpServer.Close() + subjectToken := makeTestJWT(t, httpServer.URL, "user-123", "client_1") + vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) @@ -1972,8 +1927,6 @@ func TestHandleIDJAGExchange_JWTClaims(t *testing.T) { // and scopes are correctly passed through to the JWT claims, and that scope // reduction by policy produces the scope field in the response. func TestHandleIDJAGExchange_ResourceAndScope(t *testing.T) { - subjectToken := makeTestJWT("user-456") - t.Run("resource parameter appears in JWT", func(t *testing.T) { ctx := t.Context() httpServer, s := newTestServer(t, func(c *Config) { @@ -1988,6 +1941,7 @@ func TestHandleIDJAGExchange_ResourceAndScope(t *testing.T) { }) defer httpServer.Close() + subjectToken := makeTestJWT(t, httpServer.URL, "user-456", "client_1") vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) @@ -2025,6 +1979,7 @@ func TestHandleIDJAGExchange_ResourceAndScope(t *testing.T) { }) defer httpServer.Close() + subjectToken := makeTestJWT(t, httpServer.URL, "user-456", "client_1") vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) @@ -2066,6 +2021,7 @@ func TestHandleIDJAGExchange_ResourceAndScope(t *testing.T) { }) defer httpServer.Close() + subjectToken := makeTestJWT(t, httpServer.URL, "user-456", "client_1") vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) @@ -2110,7 +2066,7 @@ func TestHandleIDJAGExchange_SecurityBoundaries(t *testing.T) { }) defer httpServer.Close() - subjectToken := makeTestJWTWithClaims("user-1", "public_client") + subjectToken := makeTestJWT(t, httpServer.URL, "user-1", "public_client") vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) @@ -2143,7 +2099,7 @@ func TestHandleIDJAGExchange_SecurityBoundaries(t *testing.T) { defer httpServer.Close() // Subject token has aud="other_client", but we authenticate as client_1. - subjectToken := makeTestJWTWithClaims("user-1", "other_client") + subjectToken := makeTestJWT(t, httpServer.URL, "user-1", "other_client") vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) @@ -2178,7 +2134,7 @@ func TestHandleIDJAGExchange_SecurityBoundaries(t *testing.T) { vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) vals.Set("subject_token_type", tokenTypeID) - vals.Set("subject_token", makeTestJWT("user-1")) + vals.Set("subject_token", makeTestJWT(t, httpServer.URL, "user-1", "client_1")) vals.Set("connector_id", "mock") vals.Set("audience", "https://resource.example.com") vals.Set("client_id", "client_1") @@ -2209,7 +2165,7 @@ func TestHandleIDJAGExchange_SecurityBoundaries(t *testing.T) { vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) vals.Set("subject_token_type", tokenTypeID) - vals.Set("subject_token", makeTestJWT("user-1")) + vals.Set("subject_token", makeTestJWT(t, httpServer.URL, "user-1", "client_1")) vals.Set("connector_id", "mock") vals.Set("audience", "https://resource-as.example.com") // not in allowed list vals.Set("client_id", "client_1") @@ -2224,15 +2180,13 @@ func TestHandleIDJAGExchange_SecurityBoundaries(t *testing.T) { } // TestHandleIDJAGExchange_ValidationErrors verifies parameter validation. +// All these cases are rejected before subject_token verification is reached. func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { - subjectToken := makeTestJWT("user-123") - tests := []struct { name string audience string connectorID string subjectTokenType string - subjectToken string enableIDJAG bool wantCode int wantErrContains string @@ -2242,7 +2196,6 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { audience: "", connectorID: "mock", subjectTokenType: tokenTypeID, - subjectToken: subjectToken, enableIDJAG: true, wantCode: http.StatusBadRequest, }, @@ -2251,7 +2204,6 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { audience: "https://resource.example.com", connectorID: "mock", subjectTokenType: tokenTypeAccess, - subjectToken: subjectToken, enableIDJAG: true, wantCode: http.StatusBadRequest, }, @@ -2260,7 +2212,6 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { audience: "https://resource.example.com", connectorID: "", subjectTokenType: tokenTypeID, - subjectToken: subjectToken, enableIDJAG: true, wantCode: http.StatusBadRequest, wantErrContains: "connector_id", @@ -2270,7 +2221,6 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { audience: "https://resource.example.com", connectorID: "nonexistent", subjectTokenType: tokenTypeID, - subjectToken: subjectToken, enableIDJAG: true, wantCode: http.StatusBadRequest, }, @@ -2279,7 +2229,6 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { audience: "https://resource.example.com", connectorID: "mock", subjectTokenType: tokenTypeID, - subjectToken: subjectToken, enableIDJAG: false, wantCode: http.StatusBadRequest, wantErrContains: "not enabled", @@ -2307,7 +2256,7 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) vals.Set("subject_token_type", tc.subjectTokenType) - vals.Set("subject_token", tc.subjectToken) + vals.Set("subject_token", "placeholder") if tc.connectorID != "" { vals.Set("connector_id", tc.connectorID) } @@ -2332,8 +2281,6 @@ func TestHandleIDJAGExchange_ValidationErrors(t *testing.T) { // TestHandleIDJAGExchange_CustomExpiry verifies that IDJAGTokensValidFor is honored. func TestHandleIDJAGExchange_CustomExpiry(t *testing.T) { - subjectToken := makeTestJWT("user-789") - ctx := t.Context() httpServer, s := newTestServer(t, func(c *Config) { require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ @@ -2348,6 +2295,8 @@ func TestHandleIDJAGExchange_CustomExpiry(t *testing.T) { }) defer httpServer.Close() + subjectToken := makeTestJWT(t, httpServer.URL, "user-789", "client_1") + vals := url.Values{} vals.Set("grant_type", grantTypeTokenExchange) vals.Set("requested_token_type", tokenTypeIDJAG) From c3a6afe66f655e21bc0768f92b83c681598683da Mon Sep 17 00:00:00 2001 From: kanywst Date: Tue, 24 Mar 2026 19:51:19 +0900 Subject: [PATCH 3/9] refactor(oauth2): separate policy denial from error in evaluateIDJAGPolicy Signed-off-by: kanywst --- server/policy.go | 8 ++------ server/policy_test.go | 19 ++++++------------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/server/policy.go b/server/policy.go index 577023e166..460ed84a2d 100644 --- a/server/policy.go +++ b/server/policy.go @@ -1,9 +1,5 @@ package server -import ( - "fmt" -) - // PolicyDenialReason categorizes why an ID-JAG policy check failed. type PolicyDenialReason string @@ -50,7 +46,7 @@ func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience stri return PolicyResult{ Denied: true, DenialReason: PolicyDenialClientHasNoPolicy, - }, fmt.Errorf("no ID-JAG policy found for client %q: access_denied", clientID) + }, nil } // Check audience. @@ -58,7 +54,7 @@ func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience stri return PolicyResult{ Denied: true, DenialReason: PolicyDenialAudienceNotAllowed, - }, fmt.Errorf("audience %q is not allowed for client %q: access_denied", audience, clientID) + }, nil } // Filter scopes: if the policy restricts scopes, only grant those that are allowed. diff --git a/server/policy_test.go b/server/policy_test.go index 39e35ba979..5ecc12279d 100644 --- a/server/policy_test.go +++ b/server/policy_test.go @@ -13,7 +13,7 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { clientID string audience string scopes []string - wantErr bool + wantDenied bool wantDenialReason PolicyDenialReason wantGrantedScopes []string }{ @@ -22,7 +22,7 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { policies: nil, clientID: "any-client", audience: "https://resource.example.com", - wantErr: true, + wantDenied: true, wantDenialReason: PolicyDenialClientHasNoPolicy, }, { @@ -32,7 +32,6 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { }, clientID: "client-a", audience: "https://resource.example.com", - wantErr: false, }, { name: "audience not allowed", @@ -41,7 +40,7 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { }, clientID: "client-a", audience: "https://resource.example.com", - wantErr: true, + wantDenied: true, wantDenialReason: PolicyDenialAudienceNotAllowed, }, { @@ -51,7 +50,7 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { }, clientID: "unknown-client", audience: "https://resource.example.com", - wantErr: true, + wantDenied: true, wantDenialReason: PolicyDenialClientHasNoPolicy, }, { @@ -61,7 +60,6 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { }, clientID: "any-client", audience: "https://resource.example.com", - wantErr: false, }, { name: "exact match takes priority over wildcard", @@ -71,7 +69,6 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { }, clientID: "client-a", audience: "https://resource.example.com", - wantErr: false, }, { name: "scope filtered by policy", @@ -81,7 +78,6 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { clientID: "client-a", audience: "https://resource.example.com", scopes: []string{"read", "admin"}, - wantErr: false, wantGrantedScopes: []string{"read"}, }, { @@ -92,7 +88,6 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { clientID: "client-a", audience: "https://resource.example.com", scopes: []string{"read"}, - wantErr: false, wantGrantedScopes: []string{"read"}, }, { @@ -103,7 +98,6 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { clientID: "client-a", audience: "https://resource.example.com", scopes: []string{"anything"}, - wantErr: false, wantGrantedScopes: []string{"anything"}, }, } @@ -111,12 +105,11 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { result, err := evaluateIDJAGPolicy(tc.policies, tc.clientID, tc.audience, tc.scopes) - if tc.wantErr { - require.Error(t, err) + require.NoError(t, err) + if tc.wantDenied { require.True(t, result.Denied) require.Equal(t, tc.wantDenialReason, result.DenialReason) } else { - require.NoError(t, err) require.False(t, result.Denied) if tc.wantGrantedScopes != nil { require.Equal(t, tc.wantGrantedScopes, result.GrantedScopes) From cc9efef36f68f1a48ddf43b38b9c128731f6dcf2 Mon Sep 17 00:00:00 2001 From: kanywst Date: Tue, 24 Mar 2026 19:57:49 +0900 Subject: [PATCH 4/9] refactor(signer): deduplicate Sign and SignWithType into common method Signed-off-by: kanywst --- server/signer/local.go | 25 ++++------ server/signer/mock.go | 11 ++++- server/signer/vault.go | 101 ++++------------------------------------- 3 files changed, 26 insertions(+), 111 deletions(-) diff --git a/server/signer/local.go b/server/signer/local.go index afddb80b37..fb20801905 100644 --- a/server/signer/local.go +++ b/server/signer/local.go @@ -98,24 +98,14 @@ func (l *localSigner) logRotateError(err error) { } func (l *localSigner) Sign(ctx context.Context, payload []byte) (string, error) { - keys, err := l.storage.GetKeys(ctx) - if err != nil { - return "", fmt.Errorf("failed to get keys: %v", err) - } - - signingKey := keys.SigningKey - if signingKey == nil { - return "", fmt.Errorf("no key to sign payload with") - } - signingAlg, err := signatureAlgorithm(signingKey) - if err != nil { - return "", err - } - - return signPayload(signingKey, signingAlg, payload) + return l.sign(ctx, payload, "") } func (l *localSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) { + return l.sign(ctx, payload, tokenType) +} + +func (l *localSigner) sign(ctx context.Context, payload []byte, tokenType string) (string, error) { keys, err := l.storage.GetKeys(ctx) if err != nil { return "", fmt.Errorf("failed to get keys: %v", err) @@ -130,7 +120,10 @@ func (l *localSigner) SignWithType(ctx context.Context, payload []byte, tokenTyp return "", err } - return signPayloadWithType(signingKey, signingAlg, payload, tokenType) + if tokenType != "" { + return signPayloadWithType(signingKey, signingAlg, payload, tokenType) + } + return signPayload(signingKey, signingAlg, payload) } func (l *localSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { diff --git a/server/signer/mock.go b/server/signer/mock.go index cfdf00471b..3031207e77 100644 --- a/server/signer/mock.go +++ b/server/signer/mock.go @@ -56,11 +56,18 @@ type mockSigner struct { } func (m *mockSigner) Sign(_ context.Context, payload []byte) (string, error) { - return signPayload(m.key, jose.RS256, payload) + return m.sign(payload, "") } func (m *mockSigner) SignWithType(_ context.Context, payload []byte, tokenType string) (string, error) { - return signPayloadWithType(m.key, jose.RS256, payload, tokenType) + return m.sign(payload, tokenType) +} + +func (m *mockSigner) sign(payload []byte, tokenType string) (string, error) { + if tokenType != "" { + return signPayloadWithType(m.key, jose.RS256, payload, tokenType) + } + return signPayload(m.key, jose.RS256, payload) } func (m *mockSigner) ValidationKeys(_ context.Context) ([]*jose.JSONWebKey, error) { diff --git a/server/signer/vault.go b/server/signer/vault.go index a073e8639c..29eff48b02 100644 --- a/server/signer/vault.go +++ b/server/signer/vault.go @@ -95,108 +95,30 @@ func (v *vaultSigner) Start(_ context.Context) { } func (v *vaultSigner) Sign(ctx context.Context, payload []byte) (string, error) { - // 1. Fetch keys to determine the key to use (latest version) and its ID. - keysMap, latestVersion, err := v.getTransitKeysMap(ctx) - if err != nil { - return "", fmt.Errorf("failed to get keys for signing context: %v", err) - } - - // Determine the key version and ID to use - // We use the latest version by default - signingJWK, ok := keysMap[latestVersion] - if !ok { - return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion) - } - - // 2. Construct JWS Header and Payload first (Signing Input) - header := map[string]interface{}{ - "alg": signingJWK.Algorithm, - "kid": signingJWK.KeyID, - } - - headerBytes, err := json.Marshal(header) - if err != nil { - return "", fmt.Errorf("failed to marshal header: %v", err) - } - - headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes) - payloadB64 := base64.RawURLEncoding.EncodeToString(payload) - - // The input to the signature is "header.payload" - signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64) - - // 3. Sign the signingInput using Vault - var vaultInput string - data := map[string]interface{}{} - - // Determine Vault params based on JWS algorithm - params, err := getVaultParams(signingJWK.Algorithm) - if err != nil { - return "", err - } - - // Apply params to data map - for k, v := range params.extraParams { - data[k] = v - } - - // Hash input if needed - if params.hasher != nil { - params.hasher.Write([]byte(signingInput)) - hash := params.hasher.Sum(nil) - vaultInput = base64.StdEncoding.EncodeToString(hash) - } else { - // No pre-hashing (EdDSA) - vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput)) - } - data["input"] = vaultInput - - signPath := fmt.Sprintf("transit/sign/%s", v.keyName) - signSecret, err := v.client.Logical().WriteWithContext(ctx, signPath, data) - if err != nil { - return "", fmt.Errorf("vault sign: %v", err) - } - - signatureString, ok := signSecret.Data["signature"].(string) - if !ok { - return "", fmt.Errorf("vault response missing signature") - } - - // Parse vault signature: "vault:v1:base64sig" - var signatureB64 []byte - if len(signatureString) > 8 && signatureString[:6] == "vault:" { - parts := splitVaultSignature(signatureString) - if len(parts) == 3 { - // part 1 is "vault", part 2 is "v1", part 3 is signature - // The signature is already base64 encoded, decoding it is not needed and - // will make the code failing. - signatureB64 = []byte(parts[2]) - } - } else { - return "", fmt.Errorf("unexpected signature format: %s", signatureString) - } - - return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil + return v.sign(ctx, payload, "") } func (v *vaultSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) { - // 1. Fetch keys to determine the key to use (latest version) and its ID. + return v.sign(ctx, payload, tokenType) +} + +func (v *vaultSigner) sign(ctx context.Context, payload []byte, tokenType string) (string, error) { keysMap, latestVersion, err := v.getTransitKeysMap(ctx) if err != nil { return "", fmt.Errorf("failed to get keys for signing context: %v", err) } - // Determine the key version and ID to use signingJWK, ok := keysMap[latestVersion] if !ok { return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion) } - // 2. Construct JWS Header with custom typ and Payload first (Signing Input) header := map[string]interface{}{ "alg": signingJWK.Algorithm, "kid": signingJWK.KeyID, - "typ": tokenType, + } + if tokenType != "" { + header["typ"] = tokenType } headerBytes, err := json.Marshal(header) @@ -207,31 +129,25 @@ func (v *vaultSigner) SignWithType(ctx context.Context, payload []byte, tokenTyp headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes) payloadB64 := base64.RawURLEncoding.EncodeToString(payload) - // The input to the signature is "header.payload" signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64) - // 3. Sign the signingInput using Vault var vaultInput string data := map[string]interface{}{} - // Determine Vault params based on JWS algorithm params, err := getVaultParams(signingJWK.Algorithm) if err != nil { return "", err } - // Apply params to data map for k, v := range params.extraParams { data[k] = v } - // Hash input if needed if params.hasher != nil { params.hasher.Write([]byte(signingInput)) hash := params.hasher.Sum(nil) vaultInput = base64.StdEncoding.EncodeToString(hash) } else { - // No pre-hashing (EdDSA) vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput)) } data["input"] = vaultInput @@ -247,7 +163,6 @@ func (v *vaultSigner) SignWithType(ctx context.Context, payload []byte, tokenTyp return "", fmt.Errorf("vault response missing signature") } - // Parse vault signature: "vault:v1:base64sig" var signatureB64 []byte if len(signatureString) > 8 && signatureString[:6] == "vault:" { parts := splitVaultSignature(signatureString) From 1402367667d35fd17914d24ddbfe44ea31ba8221 Mon Sep 17 00:00:00 2001 From: kanywst Date: Tue, 24 Mar 2026 19:58:12 +0900 Subject: [PATCH 5/9] fix(oauth2): register ID-JAG metrics only when enabled Signed-off-by: kanywst --- server/server.go | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/server/server.go b/server/server.go index e57bbae1fd..c42d402361 100644 --- a/server/server.go +++ b/server/server.go @@ -468,24 +468,25 @@ func newServer(ctx context.Context, c Config) (*Server, error) { c.PrometheusRegistry.MustRegister(requestCounter, durationHist, sizeHist) - // ID-JAG metrics. - s.idJAGRequestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "dex_id_jag_requests_total", - Help: "Total number of ID-JAG token exchange requests.", - }, []string{"result"}) - s.idJAGPolicyRejectionsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "dex_id_jag_policy_rejections_total", - Help: "Total number of ID-JAG policy rejections by reason.", - }, []string{"reason"}) - s.idJAGScopeModificationsTotal = prometheus.NewCounter(prometheus.CounterOpts{ - Name: "dex_id_jag_scope_modifications_total", - Help: "Total number of ID-JAG requests where policy reduced the requested scopes.", - }) - c.PrometheusRegistry.MustRegister( - s.idJAGRequestsTotal, - s.idJAGPolicyRejectionsTotal, - s.idJAGScopeModificationsTotal, - ) + if s.enableIDJAG { + s.idJAGRequestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "dex_id_jag_requests_total", + Help: "Total number of ID-JAG token exchange requests.", + }, []string{"result"}) + s.idJAGPolicyRejectionsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "dex_id_jag_policy_rejections_total", + Help: "Total number of ID-JAG policy rejections by reason.", + }, []string{"reason"}) + s.idJAGScopeModificationsTotal = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "dex_id_jag_scope_modifications_total", + Help: "Total number of ID-JAG requests where policy reduced the requested scopes.", + }) + c.PrometheusRegistry.MustRegister( + s.idJAGRequestsTotal, + s.idJAGPolicyRejectionsTotal, + s.idJAGScopeModificationsTotal, + ) + } instrumentHandler = func(handlerName string, handler http.Handler) http.HandlerFunc { return promhttp.InstrumentHandlerDuration(durationHist.MustCurryWith(prometheus.Labels{"handler": handlerName}), From c83af6dc5f51d4aecd4f7bcde3ae375cb1588663 Mon Sep 17 00:00:00 2001 From: kanywst Date: Tue, 24 Mar 2026 19:59:14 +0900 Subject: [PATCH 6/9] docs(oauth2): note single audience per ID-JAG draft spec Signed-off-by: kanywst --- server/oauth2.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/oauth2.go b/server/oauth2.go index fcc394033d..807ffb032a 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -306,6 +306,7 @@ func (s *Server) newAccessToken(ctx context.Context, clientID string, claims sto const idJAGTyp = "oauth-id-jag+jwt" // idJAGClaims is the JWT payload for an ID-JAG token. +// Audience is a single string per draft-ietf-oauth-identity-assertion-authz-grant-02. type idJAGClaims struct { Issuer string `json:"iss"` Subject string `json:"sub"` From d1f8ab7f752a0547f67c6b264f907a9255dbdc65 Mon Sep 17 00:00:00 2001 From: kanywst Date: Mon, 30 Mar 2026 19:40:54 +0900 Subject: [PATCH 7/9] refactor(oauth2): drop unused error return from evaluateIDJAGPolicy No code path returns a non-nil error after the earlier refactor that moved policy denials into PolicyResult. Signed-off-by: kanywst --- server/handlers.go | 7 +------ server/policy.go | 8 ++++---- server/policy_test.go | 3 +-- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index 496cb97c2d..d9f8685448 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1928,12 +1928,7 @@ func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, cli } sub := idToken.Subject - policyResult, err := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes) - if err != nil { - s.idJAGReject(ctx, w, "rejected", errServerError, "", http.StatusInternalServerError, - "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "sub", sub, "reason", "policy_error") - return - } + policyResult := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes) if policyResult.Denied { if s.idJAGPolicyRejectionsTotal != nil { s.idJAGPolicyRejectionsTotal.WithLabelValues(string(policyResult.DenialReason)).Inc() diff --git a/server/policy.go b/server/policy.go index 460ed84a2d..cc87c052cb 100644 --- a/server/policy.go +++ b/server/policy.go @@ -28,7 +28,7 @@ type TokenExchangePolicy struct { // evaluateIDJAGPolicy checks whether the client is permitted to obtain an ID-JAG // for the given audience and scopes. Clients without a matching policy are denied // by default (default-deny). -func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience string, scopes []string) (PolicyResult, error) { +func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience string, scopes []string) PolicyResult { // Find the most-specific policy for this client: exact match first, then wildcard. var matched *TokenExchangePolicy for i := range policies { @@ -46,7 +46,7 @@ func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience stri return PolicyResult{ Denied: true, DenialReason: PolicyDenialClientHasNoPolicy, - }, nil + } } // Check audience. @@ -54,7 +54,7 @@ func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience stri return PolicyResult{ Denied: true, DenialReason: PolicyDenialAudienceNotAllowed, - }, nil + } } // Filter scopes: if the policy restricts scopes, only grant those that are allowed. @@ -72,7 +72,7 @@ func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience stri return PolicyResult{ Denied: false, GrantedScopes: grantedScopes, - }, nil + } } func audienceAllowed(allowedAudiences []string, audience string) bool { diff --git a/server/policy_test.go b/server/policy_test.go index 5ecc12279d..f92ea1cc38 100644 --- a/server/policy_test.go +++ b/server/policy_test.go @@ -104,8 +104,7 @@ func TestEvaluateIDJAGPolicy(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := evaluateIDJAGPolicy(tc.policies, tc.clientID, tc.audience, tc.scopes) - require.NoError(t, err) + result := evaluateIDJAGPolicy(tc.policies, tc.clientID, tc.audience, tc.scopes) if tc.wantDenied { require.True(t, result.Denied) require.Equal(t, tc.wantDenialReason, result.DenialReason) From 98ed3522ecaf72c91701bbb0d37773fb443c2f81 Mon Sep 17 00:00:00 2001 From: kanywst Date: Mon, 30 Mar 2026 19:41:09 +0900 Subject: [PATCH 8/9] fix(oauth2): use tokenTypeIDJAG const in IDJAGEnabled Signed-off-by: kanywst --- server/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index c42d402361..c69ca1fbb3 100644 --- a/server/server.go +++ b/server/server.go @@ -163,7 +163,7 @@ type TokenExchangeConfig struct { // IDJAGEnabled reports whether the ID-JAG token type is enabled. func (c TokenExchangeConfig) IDJAGEnabled() bool { for _, t := range c.TokenTypes { - if t == "urn:ietf:params:oauth:token-type:id-jag" { + if t == tokenTypeIDJAG { return true } } From f397972fa1b9da604e53f77372c8882f7da72e08 Mon Sep 17 00:00:00 2001 From: kanywst Date: Mon, 30 Mar 2026 19:41:17 +0900 Subject: [PATCH 9/9] docs(oauth2): clarify why getConnector result is unused in token exchange Signed-off-by: kanywst --- server/handlers.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/handlers.go b/server/handlers.go index d9f8685448..1c2febefde 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1898,6 +1898,7 @@ func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, cli return } + // Only checking existence; the connector value is not needed for token exchange. if _, err := s.getConnector(ctx, connectorID); err != nil { s.idJAGReject(ctx, w, "rejected", errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest, "client_id", client.ID, "connector_id", connectorID, "audience", audience, "resource", resource, "requested_scope", requestedScope, "reason", "connector_not_found")