Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions db/migrations/000008_add_nonce.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE oauth_authorization_codes DROP COLUMN nonce;
ALTER TABLE auth_mfa_pending DROP COLUMN nonce;
2 changes: 2 additions & 0 deletions db/migrations/000008_add_nonce.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE oauth_authorization_codes ADD COLUMN nonce text;
ALTER TABLE auth_mfa_pending ADD COLUMN nonce text;
9 changes: 5 additions & 4 deletions db/queries/auth.sql
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ INSERT INTO oauth_authorization_codes (
scope,
code_challenge,
code_challenge_method,
expires_at
expires_at,
nonce
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8);
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9);

-- name: GetAuthorizationCode :one
SELECT *
Expand Down Expand Up @@ -222,8 +223,8 @@ SET mfa_enabled = false, mfa_secret = NULL, mfa_verified_at = NULL, updated_at =
WHERE id = $1;

-- name: CreateMFAPending :exec
INSERT INTO auth_mfa_pending (id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9);
INSERT INTO auth_mfa_pending (id, user_id, client_id, redirect_uri, state, scope, code_challenge, code_challenge_method, expires_at, nonce)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10);

-- name: GetMFAPending :one
SELECT * FROM auth_mfa_pending
Expand Down
19 changes: 13 additions & 6 deletions pkg/db/auth.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions pkg/httpserver/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func normalizeCodeChallenge(challenge string) string {
}

// generateAuthorizationCode generates a new authorization code and stores it in the database
func (s *Server) generateAuthorizationCode(ctx context.Context, userID uuid.UUID, clientID uuid.UUID, redirectURI string, scope []string, codeChallenge string, codeChallengeMethod string) (string, error) {
func (s *Server) generateAuthorizationCode(ctx context.Context, userID uuid.UUID, clientID uuid.UUID, redirectURI string, scope []string, codeChallenge string, codeChallengeMethod string, nonce string) (string, error) {
code, err := generateRandomString(32)
if err != nil {
return "", err
Expand All @@ -202,6 +202,7 @@ func (s *Server) generateAuthorizationCode(ctx context.Context, userID uuid.UUID
CodeChallenge: sql.NullString{String: normalizedChallenge, Valid: codeChallenge != ""},
CodeChallengeMethod: sql.NullString{String: codeChallengeMethod, Valid: codeChallengeMethod != ""},
ExpiresAt: time.Now().Add(authorizationCodeExpiresIn),
Nonce: sql.NullString{String: nonce, Valid: nonce != ""},
})
if err != nil {
return "", err
Expand Down Expand Up @@ -241,7 +242,7 @@ func (s *Server) createSession(ctx context.Context, userID uuid.UUID) (Session,

// generateTokens creates new access and refresh tokens and stores them in the database.
// Returns the token pair on success.
func (s *Server) generateTokens(ctx context.Context, clientID uuid.UUID, userID uuid.UUID, scope []string) (TokenPair, error) {
func (s *Server) generateTokens(ctx context.Context, clientID uuid.UUID, userID uuid.UUID, scope []string, nonce string) (TokenPair, error) {
// Fetch user information for JWT claims
user, err := s.datastore.Q.GetUserByID(ctx, userID)
if err != nil {
Expand Down Expand Up @@ -298,7 +299,7 @@ func (s *Server) generateTokens(ctx context.Context, clientID uuid.UUID, userID
// Generate OIDC ID token when openid scope is requested
var idToken string
if slices.Contains(scope, "openid") {
idClaims := jwtpkg.IDTokenClaims{}
idClaims := jwtpkg.IDTokenClaims{Nonce: nonce}
if slices.Contains(scope, "email") {
idClaims.Email = user.Email
idClaims.EmailVerified = &user.EmailVerified
Expand Down
4 changes: 3 additions & 1 deletion pkg/httpserver/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func (s *Server) HandleLoginPost(w http.ResponseWriter, r *http.Request) {
scope := strings.Split(r.FormValue("scope"), " ")
codeChallenge := r.FormValue("code_challenge")
codeChallengeMethod := r.FormValue("code_challenge_method")
nonce := r.FormValue("nonce")

log.Printf("[DEBUG] HandleLoginPost: username=%s, clientID=%s, redirectURI=%s", username, clientID, redirectURI)

Expand All @@ -120,6 +121,7 @@ func (s *Server) HandleLoginPost(w http.ResponseWriter, r *http.Request) {
Scope: scope,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
Nonce: nonce,
}
// Validate credentials - use the function that includes inactive users
// so we can handle deactivated users appropriately based on the login type
Expand Down Expand Up @@ -217,7 +219,7 @@ func (s *Server) HandleLoginPost(w http.ResponseWriter, r *http.Request) {

// Generate and store authorization code
log.Printf("[DEBUG] HandleLoginPost: Generating authorization code")
authorizationCode, err := s.generateAuthorizationCode(r.Context(), user.ID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod)
authorizationCode, err := s.generateAuthorizationCode(r.Context(), user.ID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod, nonce)
if err != nil {
log.Printf("[ERROR] HandleLoginPost: Failed to generate authorization code: %v", err)
s.renderLoginError(w, http.StatusInternalServerError, "An error occurred", oauthParams)
Expand Down
4 changes: 3 additions & 1 deletion pkg/httpserver/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func (s *Server) HandleMFAPost(w http.ResponseWriter, r *http.Request) {
scope := pending.Scope
codeChallenge := pending.CodeChallenge.String
codeChallengeMethod := pending.CodeChallengeMethod.String
nonce := pending.Nonce.String

// Validate OAuth client
client, err := s.validateOAuthClient(r.Context(), clientID, redirectURI, scope)
Expand All @@ -141,7 +142,7 @@ func (s *Server) HandleMFAPost(w http.ResponseWriter, r *http.Request) {
}

// Generate authorization code
authorizationCode, err := s.generateAuthorizationCode(r.Context(), pending.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod)
authorizationCode, err := s.generateAuthorizationCode(r.Context(), pending.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod, nonce)
if err != nil {
log.Printf("[ERROR] HandleMFAPost: Failed to generate authorization code: %v", err)
http.Redirect(w, r, "/oauth/login", http.StatusFound)
Expand Down Expand Up @@ -180,6 +181,7 @@ func (s *Server) createMFAPendingSession(r *http.Request, userID uuid.UUID, oaut
CodeChallenge: sql.NullString{String: oauthParams.CodeChallenge, Valid: oauthParams.CodeChallenge != ""},
CodeChallengeMethod: sql.NullString{String: oauthParams.CodeChallengeMethod, Valid: oauthParams.CodeChallengeMethod != ""},
ExpiresAt: time.Now().Add(mfaPendingExpiresIn),
Nonce: sql.NullString{String: oauthParams.Nonce, Valid: oauthParams.Nonce != ""},
}

if err := s.datastore.Q.CreateMFAPending(r.Context(), params); err != nil {
Expand Down
22 changes: 14 additions & 8 deletions pkg/httpserver/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (s *Server) HandleOauthAuthorize(w http.ResponseWriter, r *http.Request) {
codeChallenge := r.URL.Query().Get("code_challenge")
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
state := r.URL.Query().Get("state")
nonce := r.URL.Query().Get("nonce")

// Phase 1: Validate client_id and redirect_uri first.
// Per RFC 6749 4.1.2.1, if these are invalid we MUST NOT redirect — show error directly.
Expand Down Expand Up @@ -127,7 +128,7 @@ func (s *Server) HandleOauthAuthorize(w http.ResponseWriter, r *http.Request) {
}

// Phase 4: Generate authorization code and redirect.
authorizationCode, err := s.generateAuthorizationCode(r.Context(), session.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod)
authorizationCode, err := s.generateAuthorizationCode(r.Context(), session.UserID, client.ID, redirectURI, scope, codeChallenge, codeChallengeMethod, nonce)
if err != nil {
redirectError("server_error", "An error occurred")
return
Expand Down Expand Up @@ -271,8 +272,12 @@ func (s *Server) handleAuthorizationCodeGrant(w http.ResponseWriter, r *http.Req
return
}

// Generate tokens
s.writeTokenResponse(w, r, client.ID, authCode.UserID, authCode.Scope)
// Generate tokens, passing nonce for inclusion in the ID token
nonce := ""
if authCode.Nonce.Valid {
nonce = authCode.Nonce.String
}
s.writeTokenResponse(w, r, client.ID, authCode.UserID, authCode.Scope, nonce)
}

// handleRefreshTokenGrant exchanges a refresh token for new tokens
Expand Down Expand Up @@ -323,17 +328,18 @@ func (s *Server) handleRefreshTokenGrant(w http.ResponseWriter, r *http.Request,
return
}

// Issue new tokens with the same user and scope
// Issue new tokens with the same user and scope (no nonce on refresh)
var userID uuid.UUID
if token.UserID.Valid {
userID = token.UserID.UUID
}
s.writeTokenResponse(w, r, client.ID, userID, token.Scope)
s.writeTokenResponse(w, r, client.ID, userID, token.Scope, "")
}

// writeTokenResponse generates tokens and writes the JSON response
func (s *Server) writeTokenResponse(w http.ResponseWriter, r *http.Request, clientID uuid.UUID, userID uuid.UUID, scope []string) {
tokens, err := s.generateTokens(r.Context(), clientID, userID, scope)
// writeTokenResponse generates tokens and writes the JSON response.
// nonce is included in the ID token if non-empty (per OIDC Core 3.1.2.1).
func (s *Server) writeTokenResponse(w http.ResponseWriter, r *http.Request, clientID uuid.UUID, userID uuid.UUID, scope []string, nonce string) {
tokens, err := s.generateTokens(r.Context(), clientID, userID, scope, nonce)
if err != nil {
s.writeTokenError(w, "server_error", "Failed to generate tokens")
return
Expand Down
115 changes: 115 additions & 0 deletions pkg/httpserver/oauth_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,121 @@ func (s *OAuthFlowSuite) TestTokenResponseIncludesIDToken() {
s.Equal(result.User.Username, claims["preferred_username"], "preferred_username should match")
}

// TestIDTokenIncludesNonce verifies that when a nonce is sent in the authorize request,
// it appears in the ID token per OIDC Core Section 3.1.2.1.
func (s *OAuthFlowSuite) TestIDTokenIncludesNonce() {
client := s.mustRegisterOAuthClient(db.CreateOAuthClientParams{
ClientID: s.mustGenerateRandomString(8),
ClientSecret: sql.NullString{String: "", Valid: false},
Name: s.mustGenerateRandomString(8),
RedirectUris: []string{"http://localhost:8080/callback"},
AllowedScopes: []string{"openid", "profile", "email"},
IsConfidential: false,
Audience: "http://localhost:8080",
})
username := s.mustGenerateRandomString(8)
password := s.mustGenerateRandomString(16)
s.mustRegisterUser(username, password, fmt.Sprintf("%s@example.com", username))
scv := s.mustCreateStateAndCodeVerifier()
nonce := s.mustGenerateRandomString(32)

// Create a client with cookie jar to maintain session
jar, err := cookiejar.New(nil)
s.Require().NoError(err)
httpClient := &http.Client{
Jar: jar,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}

// Login to get a session
resp, err := httpClient.PostForm("http://localhost:8080/oauth/login", url.Values{
"username": {username},
"password": {password},
"client_id": {client.ClientID},
"redirect_uri": {"http://localhost:8080/callback"},
"state": {scv.State},
"scope": {"openid profile email"},
"code_challenge": {scv.CodeChallenge},
"code_challenge_method": {scv.CodeChallengeMethod},
"nonce": {nonce},
})
s.Require().NoError(err)
defer resp.Body.Close()
s.Require().Equal(http.StatusFound, resp.StatusCode)

// Extract the authorization code from the redirect
location := resp.Header.Get("Location")
redirectURL, err := url.Parse(location)
s.Require().NoError(err)
authorizationCode := redirectURL.Query().Get("code")
s.Require().NotEmpty(authorizationCode, "authorization code should be present")

// Exchange the code for tokens
tokenForm := url.Values{
"grant_type": {"authorization_code"},
"code": {authorizationCode},
"redirect_uri": {"http://localhost:8080/callback"},
"client_id": {client.ClientID},
"code_verifier": {scv.CodeVerifier},
}
resp, err = httpClient.PostForm("http://localhost:8080/oauth/token", tokenForm)
s.Require().NoError(err)
defer resp.Body.Close()
s.Require().Equal(http.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
s.Require().NoError(err)

var tokenResponse TokenResponse
err = json.Unmarshal(body, &tokenResponse)
s.Require().NoError(err)
s.Require().NotEmpty(tokenResponse.IDToken, "id_token should be present")

// Parse the ID token and verify the nonce claim
parts := strings.Split(tokenResponse.IDToken, ".")
s.Require().Len(parts, 3, "id_token should be a valid JWT")

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
s.Require().NoError(err)

var claims map[string]interface{}
err = json.Unmarshal(payload, &claims)
s.Require().NoError(err)

s.Equal(nonce, claims["nonce"], "nonce in ID token should match the nonce sent in authorize request")
}

// TestIDTokenOmitsNonceWhenNotProvided verifies that when no nonce is sent,
// the ID token does not include a nonce claim.
func (s *OAuthFlowSuite) TestIDTokenOmitsNonceWhenNotProvided() {
result := s.mustCompleteOAuthFlow(db.CreateOAuthClientParams{
ClientID: s.mustGenerateRandomString(8),
ClientSecret: sql.NullString{String: "", Valid: false},
Name: s.mustGenerateRandomString(8),
RedirectUris: []string{"http://localhost:8080/callback"},
AllowedScopes: []string{"openid", "profile", "email"},
IsConfidential: false,
Audience: "http://localhost:8080",
})

s.Require().NotEmpty(result.TokenResponse.IDToken)

parts := strings.Split(result.TokenResponse.IDToken, ".")
s.Require().Len(parts, 3)

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
s.Require().NoError(err)

var claims map[string]interface{}
err = json.Unmarshal(payload, &claims)
s.Require().NoError(err)

_, hasNonce := claims["nonce"]
s.False(hasNonce, "nonce should not be present when not sent in authorize request")
}

// TestTokenResponseIDTokenAbsentWithoutOpenID verifies that no id_token is returned
// when the openid scope is not requested.
func (s *OAuthFlowSuite) TestTokenResponseIDTokenAbsentWithoutOpenID() {
Expand Down
Loading