diff --git a/pkg/httpserver/oauth_integration_test.go b/pkg/httpserver/oauth_integration_test.go index ee07215..1a6b86b 100644 --- a/pkg/httpserver/oauth_integration_test.go +++ b/pkg/httpserver/oauth_integration_test.go @@ -1517,6 +1517,91 @@ func (s *OAuthFlowSuite) TestOIDCDiscoveryCORSHeaders() { s.Equal("*", resp.Header.Get("Access-Control-Allow-Origin")) } +// Logout Tests + +func (s *OAuthFlowSuite) TestLogoutRedirectsToLoginByDefault() { + resp, err := s.httpClient.Post("http://localhost:8080/oauth/logout", "", nil) + s.Require().NoError(err) + defer resp.Body.Close() + + s.Equal(http.StatusFound, resp.StatusCode) + s.Equal("/oauth/login", resp.Header.Get("Location")) +} + +func (s *OAuthFlowSuite) TestLogoutWithValidPostLogoutRedirectURI() { + // Register a client with a known redirect URI + client := s.mustRegisterOAuthClient(db.CreateOAuthClientParams{ + ClientID: "logout-test-client", + Name: "Logout Test Client", + RedirectUris: []string{"http://example.com/callback"}, + AllowedScopes: []string{"openid"}, + IsConfidential: false, + Audience: "test-audience", + }) + + // Logout with a valid post_logout_redirect_uri and client_id + resp, err := s.httpClient.Post( + "http://localhost:8080/oauth/logout?post_logout_redirect_uri=http://example.com/callback&client_id="+client.ClientID, + "", nil, + ) + s.Require().NoError(err) + defer resp.Body.Close() + + s.Equal(http.StatusFound, resp.StatusCode) + s.Equal("http://example.com/callback", resp.Header.Get("Location")) +} + +func (s *OAuthFlowSuite) TestLogoutRejectsUnregisteredPostLogoutRedirectURI() { + // Register a client with a known redirect URI + s.mustRegisterOAuthClient(db.CreateOAuthClientParams{ + ClientID: "logout-reject-client", + Name: "Logout Reject Client", + RedirectUris: []string{"http://example.com/callback"}, + AllowedScopes: []string{"openid"}, + IsConfidential: false, + Audience: "test-audience", + }) + + // Logout with a redirect URI that is NOT in the client's registered URIs + resp, err := s.httpClient.Post( + "http://localhost:8080/oauth/logout?post_logout_redirect_uri=https://evil.com/phishing&client_id=logout-reject-client", + "", nil, + ) + s.Require().NoError(err) + defer resp.Body.Close() + + // Should ignore the invalid URI and redirect to login + s.Equal(http.StatusFound, resp.StatusCode) + s.Equal("/oauth/login", resp.Header.Get("Location")) +} + +func (s *OAuthFlowSuite) TestLogoutRejectsPostLogoutRedirectURIWithUnknownClientID() { + // Providing a client_id that doesn't exist should fall back to login + resp, err := s.httpClient.Post( + "http://localhost:8080/oauth/logout?post_logout_redirect_uri=http://example.com/callback&client_id=nonexistent-client", + "", nil, + ) + s.Require().NoError(err) + defer resp.Body.Close() + + s.Equal(http.StatusFound, resp.StatusCode) + s.Equal("/oauth/login", resp.Header.Get("Location")) +} + +func (s *OAuthFlowSuite) TestLogoutRejectsPostLogoutRedirectURIWithoutClientID() { + // Providing post_logout_redirect_uri without client_id should fall back to login + resp, err := s.httpClient.Post( + "http://localhost:8080/oauth/logout?post_logout_redirect_uri=http://example.com/callback", + "", nil, + ) + s.Require().NoError(err) + defer resp.Body.Close() + + // Without client_id, can't validate the URI, so redirect to login + s.Equal(http.StatusFound, resp.StatusCode) + s.Equal("/oauth/login", resp.Header.Get("Location")) +} + // Token Introspection Tests func (s *OAuthFlowSuite) TestTokenIntrospectionAccessToken() { diff --git a/pkg/httpserver/routes.go b/pkg/httpserver/routes.go index 6b0949f..78c90aa 100644 --- a/pkg/httpserver/routes.go +++ b/pkg/httpserver/routes.go @@ -2,6 +2,7 @@ package httpserver import ( "net/http" + "slices" _ "github.com/eswan18/identity/docs" "github.com/go-chi/chi/v5" @@ -199,13 +200,22 @@ func (s *Server) HandleLogout(w http.ResponseWriter, r *http.Request) { SameSite: http.SameSiteLaxMode, }) - // Redirect to post_logout_redirect_uri if provided, otherwise to login page - redirectURI := r.URL.Query().Get("post_logout_redirect_uri") - if redirectURI == "" { - redirectURI = r.FormValue("post_logout_redirect_uri") + // Redirect to post_logout_redirect_uri if provided and valid, otherwise to login page. + // Per OIDC RP-Initiated Logout, the URI must be validated against the client's + // registered redirect URIs. A client_id parameter is required to identify the client. + redirectURI := "/oauth/login" + postLogoutURI := r.URL.Query().Get("post_logout_redirect_uri") + if postLogoutURI == "" { + postLogoutURI = r.FormValue("post_logout_redirect_uri") } - if redirectURI == "" { - redirectURI = "/oauth/login" + if postLogoutURI != "" { + clientID := r.URL.Query().Get("client_id") + if clientID != "" { + client, err := s.datastore.Q.GetOAuthClientByClientID(r.Context(), clientID) + if err == nil && slices.Contains(client.RedirectUris, postLogoutURI) { + redirectURI = postLogoutURI + } + } } http.Redirect(w, r, redirectURI, http.StatusFound)