diff --git a/db/migrations/000008_add_nonce.down.sql b/db/migrations/000008_add_nonce.down.sql new file mode 100644 index 0000000..1ab2e12 --- /dev/null +++ b/db/migrations/000008_add_nonce.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE oauth_authorization_codes DROP COLUMN nonce; +ALTER TABLE auth_mfa_pending DROP COLUMN nonce; diff --git a/db/migrations/000008_add_nonce.up.sql b/db/migrations/000008_add_nonce.up.sql new file mode 100644 index 0000000..4cf2b6d --- /dev/null +++ b/db/migrations/000008_add_nonce.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE oauth_authorization_codes ADD COLUMN nonce text; +ALTER TABLE auth_mfa_pending ADD COLUMN nonce text; diff --git a/db/queries/auth.sql b/db/queries/auth.sql index 15ef475..0475025 100644 --- a/db/queries/auth.sql +++ b/db/queries/auth.sql @@ -74,9 +74,10 @@ INSERT INTO oauth_authorization_codes ( scope, code_challenge, code_challenge_method, - expires_at + expires_at, + nonce ) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8); +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9); -- name: GetAuthorizationCode :one SELECT * @@ -222,8 +223,8 @@ SET mfa_enabled = false, mfa_secret = NULL, mfa_verified_at = NULL, updated_at = WHERE id = $1; -- name: CreateMFAPending :exec -INSERT INTO auth_mfa_pending (id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, expires_at) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9); +INSERT INTO auth_mfa_pending (id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, expires_at, nonce) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10); -- name: GetMFAPending :one SELECT * FROM auth_mfa_pending diff --git a/pkg/db/auth.sql.go b/pkg/db/auth.sql.go index 5c40660..fc31cc6 100644 --- a/pkg/db/auth.sql.go +++ b/pkg/db/auth.sql.go @@ -59,8 +59,8 @@ func (q *Queries) CreateEmailToken(ctx context.Context, arg CreateEmailTokenPara } const createMFAPending = `-- name: CreateMFAPending :exec -INSERT INTO auth_mfa_pending (id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, expires_at) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +INSERT INTO auth_mfa_pending (id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, expires_at, nonce) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ` type CreateMFAPendingParams struct { @@ -73,6 +73,7 @@ type CreateMFAPendingParams struct { CodeChallenge sql.NullString `json:"code_challenge"` CodeChallengeMethod sql.NullString `json:"code_challenge_method"` ExpiresAt time.Time `json:"expires_at"` + Nonce sql.NullString `json:"nonce"` } func (q *Queries) CreateMFAPending(ctx context.Context, arg CreateMFAPendingParams) error { @@ -86,6 +87,7 @@ func (q *Queries) CreateMFAPending(ctx context.Context, arg CreateMFAPendingPara arg.CodeChallenge, arg.CodeChallengeMethod, arg.ExpiresAt, + arg.Nonce, ) return err } @@ -322,7 +324,7 @@ func (q *Queries) EnableMFA(ctx context.Context, arg EnableMFAParams) error { } const getAuthorizationCode = `-- name: GetAuthorizationCode :one -SELECT code, user_id, client_id, redirect_uri, scope, code_challenge, code_challenge_method, expires_at, consumed_at, created_at +SELECT code, user_id, client_id, redirect_uri, scope, code_challenge, code_challenge_method, expires_at, consumed_at, created_at, nonce FROM oauth_authorization_codes WHERE code = $1 ` @@ -341,6 +343,7 @@ func (q *Queries) GetAuthorizationCode(ctx context.Context, code string) (OauthA &i.ExpiresAt, &i.ConsumedAt, &i.CreatedAt, + &i.Nonce, ) return i, err } @@ -371,7 +374,7 @@ func (q *Queries) GetEmailToken(ctx context.Context, arg GetEmailTokenParams) (A } const getMFAPending = `-- name: GetMFAPending :one -SELECT id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, created_at, expires_at FROM auth_mfa_pending +SELECT id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, created_at, expires_at, nonce FROM auth_mfa_pending WHERE id = $1 AND expires_at > now() ` @@ -389,6 +392,7 @@ func (q *Queries) GetMFAPending(ctx context.Context, id string) (AuthMfaPending, &i.CodeChallengeMethod, &i.CreatedAt, &i.ExpiresAt, + &i.Nonce, ) return i, err } @@ -742,9 +746,10 @@ INSERT INTO oauth_authorization_codes ( scope, code_challenge, code_challenge_method, - expires_at + expires_at, + nonce ) -VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ` type InsertAuthorizationCodeParams struct { @@ -756,6 +761,7 @@ type InsertAuthorizationCodeParams struct { CodeChallenge sql.NullString `json:"code_challenge"` CodeChallengeMethod sql.NullString `json:"code_challenge_method"` ExpiresAt time.Time `json:"expires_at"` + Nonce sql.NullString `json:"nonce"` } func (q *Queries) InsertAuthorizationCode(ctx context.Context, arg InsertAuthorizationCodeParams) error { @@ -768,6 +774,7 @@ func (q *Queries) InsertAuthorizationCode(ctx context.Context, arg InsertAuthori arg.CodeChallenge, arg.CodeChallengeMethod, arg.ExpiresAt, + arg.Nonce, ) return err } diff --git a/pkg/db/models.go b/pkg/db/models.go index 977a51c..c53d880 100644 --- a/pkg/db/models.go +++ b/pkg/db/models.go @@ -32,6 +32,7 @@ type AuthMfaPending struct { CodeChallengeMethod sql.NullString `json:"code_challenge_method"` CreatedAt time.Time `json:"created_at"` ExpiresAt time.Time `json:"expires_at"` + Nonce sql.NullString `json:"nonce"` } type AuthSession struct { @@ -72,6 +73,7 @@ type OauthAuthorizationCode struct { ExpiresAt time.Time `json:"expires_at"` ConsumedAt sql.NullTime `json:"consumed_at"` CreatedAt time.Time `json:"created_at"` + Nonce sql.NullString `json:"nonce"` } type OauthClient struct { diff --git a/pkg/httpserver/credentials.go b/pkg/httpserver/credentials.go index b5020ee..2d9c079 100644 --- a/pkg/httpserver/credentials.go +++ b/pkg/httpserver/credentials.go @@ -183,7 +183,7 @@ func normalizeCodeChallenge(challenge string) string { } // generateAuthorizationCode generates a new authorization code and stores it in the database -func (s *Server) generateAuthorizationCode(ctx context.Context, userID uuid.UUID, clientID uuid.UUID, redirectURI string, scope []string, codeChallenge string, codeChallengeMethod string) (string, error) { +func (s *Server) generateAuthorizationCode(ctx context.Context, userID uuid.UUID, clientID uuid.UUID, redirectURI string, scope []string, codeChallenge string, codeChallengeMethod string, nonce string) (string, error) { code, err := generateRandomString(32) if err != nil { return "", err @@ -202,6 +202,7 @@ func (s *Server) generateAuthorizationCode(ctx context.Context, userID uuid.UUID CodeChallenge: sql.NullString{String: normalizedChallenge, Valid: codeChallenge != ""}, CodeChallengeMethod: sql.NullString{String: codeChallengeMethod, Valid: codeChallengeMethod != ""}, ExpiresAt: time.Now().Add(authorizationCodeExpiresIn), + Nonce: sql.NullString{String: nonce, Valid: nonce != ""}, }) if err != nil { return "", err @@ -241,7 +242,7 @@ func (s *Server) createSession(ctx context.Context, userID uuid.UUID) (Session, // generateTokens creates new access and refresh tokens and stores them in the database. // Returns the token pair on success. -func (s *Server) generateTokens(ctx context.Context, clientID uuid.UUID, userID uuid.UUID, scope []string) (TokenPair, error) { +func (s *Server) generateTokens(ctx context.Context, clientID uuid.UUID, userID uuid.UUID, scope []string, nonce string) (TokenPair, error) { // Fetch user information for JWT claims user, err := s.datastore.Q.GetUserByID(ctx, userID) if err != nil { @@ -298,7 +299,7 @@ func (s *Server) generateTokens(ctx context.Context, clientID uuid.UUID, userID // Generate OIDC ID token when openid scope is requested var idToken string if slices.Contains(scope, "openid") { - idClaims := jwtpkg.IDTokenClaims{} + idClaims := jwtpkg.IDTokenClaims{Nonce: nonce} if slices.Contains(scope, "email") { idClaims.Email = user.Email idClaims.EmailVerified = &user.EmailVerified diff --git a/pkg/httpserver/login.go b/pkg/httpserver/login.go index be4ab45..3e7dcb8 100644 --- a/pkg/httpserver/login.go +++ b/pkg/httpserver/login.go @@ -109,6 +109,7 @@ func (s *Server) HandleLoginPost(w http.ResponseWriter, r *http.Request) { scope := strings.Split(r.FormValue("scope"), " ") codeChallenge := r.FormValue("code_challenge") codeChallengeMethod := r.FormValue("code_challenge_method") + nonce := r.FormValue("nonce") log.Printf("[DEBUG] HandleLoginPost: username=%s, clientID=%s, redirectURI=%s", username, clientID, redirectURI) @@ -120,6 +121,7 @@ func (s *Server) HandleLoginPost(w http.ResponseWriter, r *http.Request) { Scope: scope, CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, + Nonce: nonce, } // Validate credentials - use the function that includes inactive users // so we can handle deactivated users appropriately based on the login type @@ -217,7 +219,7 @@ func (s *Server) HandleLoginPost(w http.ResponseWriter, r *http.Request) { // Generate and store authorization code log.Printf("[DEBUG] HandleLoginPost: Generating authorization code") - authorizationCode, err := s.generateAuthorizationCode(r.Context(), user.ID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod) + authorizationCode, err := s.generateAuthorizationCode(r.Context(), user.ID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod, nonce) if err != nil { log.Printf("[ERROR] HandleLoginPost: Failed to generate authorization code: %v", err) s.renderLoginError(w, http.StatusInternalServerError, "An error occurred", oauthParams) diff --git a/pkg/httpserver/mfa.go b/pkg/httpserver/mfa.go index 6552b31..1072c99 100644 --- a/pkg/httpserver/mfa.go +++ b/pkg/httpserver/mfa.go @@ -126,6 +126,7 @@ func (s *Server) HandleMFAPost(w http.ResponseWriter, r *http.Request) { scope := pending.Scope codeChallenge := pending.CodeChallenge.String codeChallengeMethod := pending.CodeChallengeMethod.String + nonce := pending.Nonce.String // Validate OAuth client client, err := s.validateOAuthClient(r.Context(), clientID, redirectURI, scope) @@ -141,7 +142,7 @@ func (s *Server) HandleMFAPost(w http.ResponseWriter, r *http.Request) { } // Generate authorization code - authorizationCode, err := s.generateAuthorizationCode(r.Context(), pending.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod) + authorizationCode, err := s.generateAuthorizationCode(r.Context(), pending.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod, nonce) if err != nil { log.Printf("[ERROR] HandleMFAPost: Failed to generate authorization code: %v", err) http.Redirect(w, r, "/oauth/login", http.StatusFound) @@ -180,6 +181,7 @@ func (s *Server) createMFAPendingSession(r *http.Request, userID uuid.UUID, oaut CodeChallenge: sql.NullString{String: oauthParams.CodeChallenge, Valid: oauthParams.CodeChallenge != ""}, CodeChallengeMethod: sql.NullString{String: oauthParams.CodeChallengeMethod, Valid: oauthParams.CodeChallengeMethod != ""}, ExpiresAt: time.Now().Add(mfaPendingExpiresIn), + Nonce: sql.NullString{String: oauthParams.Nonce, Valid: oauthParams.Nonce != ""}, } if err := s.datastore.Q.CreateMFAPending(r.Context(), params); err != nil { diff --git a/pkg/httpserver/oauth.go b/pkg/httpserver/oauth.go index bd8676a..72ab8ef 100644 --- a/pkg/httpserver/oauth.go +++ b/pkg/httpserver/oauth.go @@ -57,6 +57,7 @@ func (s *Server) HandleOauthAuthorize(w http.ResponseWriter, r *http.Request) { codeChallenge := r.URL.Query().Get("code_challenge") codeChallengeMethod := r.URL.Query().Get("code_challenge_method") state := r.URL.Query().Get("state") + nonce := r.URL.Query().Get("nonce") // Phase 1: Validate client_id and redirect_uri first. // Per RFC 6749 4.1.2.1, if these are invalid we MUST NOT redirect — show error directly. @@ -127,7 +128,7 @@ func (s *Server) HandleOauthAuthorize(w http.ResponseWriter, r *http.Request) { } // Phase 4: Generate authorization code and redirect. - authorizationCode, err := s.generateAuthorizationCode(r.Context(), session.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod) + authorizationCode, err := s.generateAuthorizationCode(r.Context(), session.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod, nonce) if err != nil { redirectError("server_error", "An error occurred") return @@ -271,8 +272,12 @@ func (s *Server) handleAuthorizationCodeGrant(w http.ResponseWriter, r *http.Req return } - // Generate tokens - s.writeTokenResponse(w, r, client.ID, authCode.UserID, authCode.Scope) + // Generate tokens, passing nonce for inclusion in the ID token + nonce := "" + if authCode.Nonce.Valid { + nonce = authCode.Nonce.String + } + s.writeTokenResponse(w, r, client.ID, authCode.UserID, authCode.Scope, nonce) } // handleRefreshTokenGrant exchanges a refresh token for new tokens @@ -323,17 +328,18 @@ func (s *Server) handleRefreshTokenGrant(w http.ResponseWriter, r *http.Request, return } - // Issue new tokens with the same user and scope + // Issue new tokens with the same user and scope (no nonce on refresh) var userID uuid.UUID if token.UserID.Valid { userID = token.UserID.UUID } - s.writeTokenResponse(w, r, client.ID, userID, token.Scope) + s.writeTokenResponse(w, r, client.ID, userID, token.Scope, "") } -// writeTokenResponse generates tokens and writes the JSON response -func (s *Server) writeTokenResponse(w http.ResponseWriter, r *http.Request, clientID uuid.UUID, userID uuid.UUID, scope []string) { - tokens, err := s.generateTokens(r.Context(), clientID, userID, scope) +// writeTokenResponse generates tokens and writes the JSON response. +// nonce is included in the ID token if non-empty (per OIDC Core 3.1.2.1). +func (s *Server) writeTokenResponse(w http.ResponseWriter, r *http.Request, clientID uuid.UUID, userID uuid.UUID, scope []string, nonce string) { + tokens, err := s.generateTokens(r.Context(), clientID, userID, scope, nonce) if err != nil { s.writeTokenError(w, "server_error", "Failed to generate tokens") return diff --git a/pkg/httpserver/oauth_integration_test.go b/pkg/httpserver/oauth_integration_test.go index 6d9c24b..2fef326 100644 --- a/pkg/httpserver/oauth_integration_test.go +++ b/pkg/httpserver/oauth_integration_test.go @@ -495,6 +495,121 @@ func (s *OAuthFlowSuite) TestTokenResponseIncludesIDToken() { s.Equal(result.User.Username, claims["preferred_username"], "preferred_username should match") } +// TestIDTokenIncludesNonce verifies that when a nonce is sent in the authorize request, +// it appears in the ID token per OIDC Core Section 3.1.2.1. +func (s *OAuthFlowSuite) TestIDTokenIncludesNonce() { + client := s.mustRegisterOAuthClient(db.CreateOAuthClientParams{ + ClientID: s.mustGenerateRandomString(8), + ClientSecret: sql.NullString{String: "", Valid: false}, + Name: s.mustGenerateRandomString(8), + RedirectUris: []string{"http://localhost:8080/callback"}, + AllowedScopes: []string{"openid", "profile", "email"}, + IsConfidential: false, + Audience: "http://localhost:8080", + }) + username := s.mustGenerateRandomString(8) + password := s.mustGenerateRandomString(16) + s.mustRegisterUser(username, password, fmt.Sprintf("%s@example.com", username)) + scv := s.mustCreateStateAndCodeVerifier() + nonce := s.mustGenerateRandomString(32) + + // Create a client with cookie jar to maintain session + jar, err := cookiejar.New(nil) + s.Require().NoError(err) + httpClient := &http.Client{ + Jar: jar, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Login to get a session + resp, err := httpClient.PostForm("http://localhost:8080/oauth/login", url.Values{ + "username": {username}, + "password": {password}, + "client_id": {client.ClientID}, + "redirect_uri": {"http://localhost:8080/callback"}, + "state": {scv.State}, + "scope": {"openid profile email"}, + "code_challenge": {scv.CodeChallenge}, + "code_challenge_method": {scv.CodeChallengeMethod}, + "nonce": {nonce}, + }) + s.Require().NoError(err) + defer resp.Body.Close() + s.Require().Equal(http.StatusFound, resp.StatusCode) + + // Extract the authorization code from the redirect + location := resp.Header.Get("Location") + redirectURL, err := url.Parse(location) + s.Require().NoError(err) + authorizationCode := redirectURL.Query().Get("code") + s.Require().NotEmpty(authorizationCode, "authorization code should be present") + + // Exchange the code for tokens + tokenForm := url.Values{ + "grant_type": {"authorization_code"}, + "code": {authorizationCode}, + "redirect_uri": {"http://localhost:8080/callback"}, + "client_id": {client.ClientID}, + "code_verifier": {scv.CodeVerifier}, + } + resp, err = httpClient.PostForm("http://localhost:8080/oauth/token", tokenForm) + s.Require().NoError(err) + defer resp.Body.Close() + s.Require().Equal(http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + s.Require().NoError(err) + + var tokenResponse TokenResponse + err = json.Unmarshal(body, &tokenResponse) + s.Require().NoError(err) + s.Require().NotEmpty(tokenResponse.IDToken, "id_token should be present") + + // Parse the ID token and verify the nonce claim + parts := strings.Split(tokenResponse.IDToken, ".") + s.Require().Len(parts, 3, "id_token should be a valid JWT") + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + s.Require().NoError(err) + + var claims map[string]interface{} + err = json.Unmarshal(payload, &claims) + s.Require().NoError(err) + + s.Equal(nonce, claims["nonce"], "nonce in ID token should match the nonce sent in authorize request") +} + +// TestIDTokenOmitsNonceWhenNotProvided verifies that when no nonce is sent, +// the ID token does not include a nonce claim. +func (s *OAuthFlowSuite) TestIDTokenOmitsNonceWhenNotProvided() { + result := s.mustCompleteOAuthFlow(db.CreateOAuthClientParams{ + ClientID: s.mustGenerateRandomString(8), + ClientSecret: sql.NullString{String: "", Valid: false}, + Name: s.mustGenerateRandomString(8), + RedirectUris: []string{"http://localhost:8080/callback"}, + AllowedScopes: []string{"openid", "profile", "email"}, + IsConfidential: false, + Audience: "http://localhost:8080", + }) + + s.Require().NotEmpty(result.TokenResponse.IDToken) + + parts := strings.Split(result.TokenResponse.IDToken, ".") + s.Require().Len(parts, 3) + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + s.Require().NoError(err) + + var claims map[string]interface{} + err = json.Unmarshal(payload, &claims) + s.Require().NoError(err) + + _, hasNonce := claims["nonce"] + s.False(hasNonce, "nonce should not be present when not sent in authorize request") +} + // TestTokenResponseIDTokenAbsentWithoutOpenID verifies that no id_token is returned // when the openid scope is not requested. func (s *OAuthFlowSuite) TestTokenResponseIDTokenAbsentWithoutOpenID() { diff --git a/pkg/httpserver/templates.go b/pkg/httpserver/templates.go index 6151780..6f2dcad 100644 --- a/pkg/httpserver/templates.go +++ b/pkg/httpserver/templates.go @@ -9,6 +9,7 @@ type LoginPageData struct { Scope []string CodeChallenge string CodeChallengeMethod string + Nonce string } // RegisterPageData holds the data needed to render the registration page template. diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index 139a67e..5425761 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -149,6 +149,7 @@ type IDTokenClaims struct { FamilyName string `json:"family_name,omitempty"` Picture string `json:"picture,omitempty"` AtHash string `json:"at_hash,omitempty"` + Nonce string `json:"nonce,omitempty"` } // GenerateIDToken creates a signed OIDC ID token per OIDC Core Section 3.1.3.3. diff --git a/templates/login.html b/templates/login.html index 90b9701..7bbe02c 100644 --- a/templates/login.html +++ b/templates/login.html @@ -66,6 +66,7 @@

Sign In

+ @@ -76,7 +77,7 @@

Sign In

- Don't have an account? Sign up + Don't have an account? Sign up