From 62257b4814b3319856d361a5aa3f2903cb084bd3 Mon Sep 17 00:00:00 2001 From: Ivan Zvyagintsev Date: Mon, 13 Apr 2026 14:18:09 +0300 Subject: [PATCH 1/4] add saml slo Signed-off-by: Ivan Zvyagintsev --- connector/saml/saml.go | 218 +++++++++++++++- connector/saml/saml_test.go | 491 ++++++++++++++++++++++++++++++++++++ connector/saml/types.go | 38 ++- 3 files changed, 735 insertions(+), 12 deletions(-) diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 8ef434b62a..400c81f908 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -3,14 +3,19 @@ package saml import ( "bytes" + "compress/flate" "context" + "crypto/rand" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "encoding/xml" "fmt" + "io" "log/slog" + "net/http" + "net/url" "os" "strings" "sync" @@ -84,6 +89,15 @@ type Config struct { InsecureSkipSignatureValidation bool `json:"insecureSkipSignatureValidation"` + // SLOURL is the IdP's Single Logout Service URL (HTTP-Redirect binding). + // If empty, SLO is not available for this connector. + SLOURL string `json:"sloURL"` + + // InsecureSkipSLOSignatureValidation skips signature validation on SLO responses. + // This is insecure and should only be used for testing or when the IdP + // does not sign LogoutResponses. + InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"` + // Assertion attribute names to lookup various claims with. UsernameAttr string `json:"usernameAttr"` EmailAttr string `json:"emailAttr"` @@ -164,6 +178,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { logger: logger, nameIDPolicyFormat: c.NameIDPolicyFormat, + + sloURL: c.SLOURL, + insecureSkipSLOSignatureValidation: c.InsecureSkipSLOSignatureValidation, } if p.nameIDPolicyFormat == "" { @@ -189,7 +206,8 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { } } - if !c.InsecureSkipSignatureValidation { + needsSLOSigValidation := c.SLOURL != "" && !c.InsecureSkipSLOSignatureValidation + if !c.InsecureSkipSignatureValidation || needsSLOSigValidation { if (c.CA == "") == (c.CAData == nil) { return nil, errors.New("must provide either 'ca' or 'caData'") } @@ -233,8 +251,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { } var ( - _ connector.SAMLConnector = (*provider)(nil) - _ connector.RefreshConnector = (*provider)(nil) + _ connector.SAMLConnector = (*provider)(nil) + _ connector.RefreshConnector = (*provider)(nil) + _ connector.LogoutCallbackConnector = (*provider)(nil) ) type provider struct { @@ -259,12 +278,15 @@ type provider struct { nameIDPolicyFormat string + sloURL string + insecureSkipSLOSignatureValidation bool + logger *slog.Logger } -// cachedIdentity stores the identity from SAML assertion for refresh token support. -// Since SAML has no native refresh mechanism, we cache the identity obtained during -// the initial authentication and return it on subsequent refresh requests. +// cachedIdentity stores the identity from SAML assertion for refresh token support +// and SLO (Single Logout). The NameID/NameIDFormat/SessionIndex fields are used +// to build a SAML LogoutRequest when the user logs out. type cachedIdentity struct { UserID string `json:"userId"` Username string `json:"username"` @@ -272,10 +294,15 @@ type cachedIdentity struct { Email string `json:"email"` EmailVerified bool `json:"emailVerified"` Groups []string `json:"groups,omitempty"` + NameID string `json:"nameId,omitempty"` + NameIDFormat string `json:"nameIdFormat,omitempty"` + SessionIndex string `json:"sessionIndex,omitempty"` } -// marshalCachedIdentity serializes the identity into ConnectorData for refresh token support. -func marshalCachedIdentity(ident connector.Identity) (connector.Identity, error) { +// marshalCachedIdentity serializes the identity along with SAML-specific SLO +// fields into ConnectorData. The nameIDFormat and sessionIdx parameters come +// from the parsed SAML assertion and are needed to construct a LogoutRequest. +func marshalCachedIdentity(ident connector.Identity, nameIDFormat, sessionIdx string) (connector.Identity, error) { ci := cachedIdentity{ UserID: ident.UserID, Username: ident.Username, @@ -283,6 +310,9 @@ func marshalCachedIdentity(ident connector.Identity) (connector.Identity, error) Email: ident.Email, EmailVerified: ident.EmailVerified, Groups: ident.Groups, + NameID: ident.UserID, + NameIDFormat: nameIDFormat, + SessionIndex: sessionIdx, } connectorData, err := json.Marshal(ci) if err != nil { @@ -407,15 +437,22 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str } } + var nameIDFormat string switch { case subject.NameID != nil: if ident.UserID = subject.NameID.Value; ident.UserID == "" { return ident, fmt.Errorf("element NameID does not contain a value") } + nameIDFormat = subject.NameID.Format default: return ident, fmt.Errorf("subject does not contain an NameID element") } + var sessionIdx string + if len(assertion.AuthnStatements) > 0 { + sessionIdx = assertion.AuthnStatements[0].SessionIndex + } + // After verifying the assertion, map data in the attribute statements to // various user info. attributes := assertion.AttributeStatement @@ -442,7 +479,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str if len(p.allowedGroups) == 0 && (!s.Groups || p.groupsAttr == "") { // Groups not requested or not configured. We're done. - return marshalCachedIdentity(ident) + return marshalCachedIdentity(ident, nameIDFormat, sessionIdx) } if len(p.allowedGroups) > 0 && (!s.Groups || p.groupsAttr == "") { @@ -468,7 +505,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str if len(p.allowedGroups) == 0 { // No allowed groups set, just return the ident - return marshalCachedIdentity(ident) + return marshalCachedIdentity(ident, nameIDFormat, sessionIdx) } // Look for membership in one of the allowed groups @@ -484,7 +521,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str } // Otherwise, we're good - return marshalCachedIdentity(ident) + return marshalCachedIdentity(ident, nameIDFormat, sessionIdx) } // Refresh implements connector.RefreshConnector. @@ -711,3 +748,162 @@ func before(now, notBefore time.Time) bool { func after(now, notOnOrAfter time.Time) bool { return now.After(notOnOrAfter.Add(allowedClockDrift)) } + +// newRequestID generates a random ID suitable for SAML request IDs. +func newRequestID() string { + buf := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, buf); err != nil { + panic("crypto/rand failed: " + err.Error()) + } + return fmt.Sprintf("_%x", buf) +} + +// LogoutURL builds a SAML LogoutRequest and returns the IdP's SLO endpoint URL +// with the request encoded using HTTP-Redirect binding (deflate + base64). +// +// See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf +// "3.4 HTTP Redirect Binding" +func (p *provider) LogoutURL(_ context.Context, connectorData []byte, postLogoutRedirectURI string) (string, error) { + if p.sloURL == "" { + return "", nil + } + + var ci cachedIdentity + if len(connectorData) > 0 { + if err := json.Unmarshal(connectorData, &ci); err != nil { + return "", fmt.Errorf("saml: failed to unmarshal connector data for logout: %v", err) + } + } + + if ci.NameID == "" { + return "", nil + } + + req := &logoutRequest{ + ID: newRequestID(), + IssueInstant: xmlTime(p.now()), + Destination: p.sloURL, + NameID: nameID{ + Format: ci.NameIDFormat, + Value: ci.NameID, + }, + } + if p.entityIssuer != "" { + req.Issuer = &issuer{Issuer: p.entityIssuer} + } + if ci.SessionIndex != "" { + req.SessionIndex = []sessionIndex{{Value: ci.SessionIndex}} + } + + data, err := xml.Marshal(req) + if err != nil { + return "", fmt.Errorf("saml: failed to marshal LogoutRequest: %v", err) + } + + // HTTP-Redirect binding: deflate then base64-encode. + var buf bytes.Buffer + fw, err := flate.NewWriter(&buf, flate.DefaultCompression) + if err != nil { + return "", fmt.Errorf("saml: failed to create deflate writer: %v", err) + } + if _, err := fw.Write(data); err != nil { + return "", fmt.Errorf("saml: failed to deflate LogoutRequest: %v", err) + } + if err := fw.Close(); err != nil { + return "", fmt.Errorf("saml: failed to close deflate writer: %v", err) + } + + encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) + + u, err := url.Parse(p.sloURL) + if err != nil { + return "", fmt.Errorf("saml: failed to parse SLO URL: %v", err) + } + q := u.Query() + q.Set("SAMLRequest", encoded) + if postLogoutRedirectURI != "" { + q.Set("RelayState", postLogoutRedirectURI) + } + u.RawQuery = q.Encode() + + return u.String(), nil +} + +// HandleLogoutCallback validates the IdP's LogoutResponse received after +// an SP-initiated logout redirect. The response arrives as a SAMLResponse +// parameter via either GET query (HTTP-Redirect binding: deflated + base64) +// or POST form (HTTP-POST binding: base64 only). +func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) error { + var samlResponse string + if r.Method == http.MethodGet { + samlResponse = r.URL.Query().Get("SAMLResponse") + } else { + if err := r.ParseForm(); err != nil { + return fmt.Errorf("saml slo: failed to parse form: %v", err) + } + samlResponse = r.FormValue("SAMLResponse") + } + + if samlResponse == "" { + return nil + } + + compressed, err := base64.StdEncoding.DecodeString(samlResponse) + if err != nil { + return fmt.Errorf("saml slo: failed to decode SAMLResponse: %v", err) + } + + // HTTP-Redirect binding uses DEFLATE compression; HTTP-POST does not. + // Try to inflate; if it fails, treat the data as uncompressed XML. + rawResp, err := io.ReadAll(flate.NewReader(bytes.NewReader(compressed))) + if err != nil { + rawResp = compressed + } + + byteReader := bytes.NewReader(rawResp) + if xrvErr := xrv.Validate(byteReader); xrvErr != nil { + return fmt.Errorf("saml slo: %w", xrvErr) + } + + if p.validator != nil && !p.insecureSkipSLOSignatureValidation { + if _, err := p.validateSignature(rawResp); err != nil { + return fmt.Errorf("saml slo: %v", err) + } + } + + var resp logoutResponse + if err := xml.Unmarshal(rawResp, &resp); err != nil { + return fmt.Errorf("saml slo: failed to unmarshal LogoutResponse: %v", err) + } + + if resp.Status != nil { + if err := p.validateStatus(resp.Status); err != nil { + return fmt.Errorf("saml slo: %v", err) + } + } + + return nil +} + +// validateSignature validates the XML digital signature of the given raw XML. +func (p *provider) validateSignature(rawXML []byte) ([]byte, error) { + if p.validator == nil { + return nil, fmt.Errorf("signature validation unavailable (no validator configured)") + } + + doc := etree.NewDocument() + if err := doc.ReadFromBytes(rawXML); err != nil { + return nil, fmt.Errorf("failed to parse XML: %v", err) + } + + root := doc.Root() + if root == nil { + return nil, fmt.Errorf("empty XML document") + } + + if _, err := p.validator.Validate(root); err != nil { + return nil, fmt.Errorf("signature validation failed: %v", err) + } + + return rawXML, nil +} diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index 3eba5cf878..0a176b290d 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -1,18 +1,28 @@ package saml import ( + "bytes" + "compress/flate" "context" + "crypto/tls" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" + "encoding/xml" "errors" + "io" "log/slog" + "net/http" + "net/http/httptest" + "net/url" "os" "sort" + "strings" "testing" "time" + "github.com/beevik/etree" "github.com/kylelemons/godebug/pretty" dsig "github.com/russellhaering/goxmldsig" @@ -916,3 +926,484 @@ func TestSAMLRefresh(t *testing.T) { } }) } + +func TestHandlePOSTPopulatesSLOFields(t *testing.T) { + c := Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + GroupsAttr: "groups", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + } + + conn, err := c.openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z") + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return now } + + resp, err := os.ReadFile("testdata/good-resp.xml") + if err != nil { + t.Fatal(err) + } + samlResp := base64.StdEncoding.EncodeToString(resp) + + scopes := connector.Scopes{OfflineAccess: true, Groups: true} + ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m") + if err != nil { + t.Fatalf("HandlePOST failed: %v", err) + } + + if len(ident.ConnectorData) == 0 { + t.Fatal("expected ConnectorData to be set") + } + + var ci cachedIdentity + if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil { + t.Fatalf("failed to unmarshal ConnectorData: %v", err) + } + + if ci.NameID == "" { + t.Error("expected NameID to be populated in ConnectorData") + } + if ci.NameID != ident.UserID { + t.Errorf("NameID should match UserID: got %q, want %q", ci.NameID, ident.UserID) + } + if ci.NameIDFormat != "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" { + t.Errorf("unexpected NameIDFormat: %q", ci.NameIDFormat) + } + if ci.SessionIndex != "6zmm5mguyebwvajyf2sdwwcw6m" { + t.Errorf("unexpected SessionIndex: got %q, want %q", ci.SessionIndex, "6zmm5mguyebwvajyf2sdwwcw6m") + } +} + +// decodeSAMLRequest decodes a SAMLRequest query parameter value +// (base64 → inflate → XML) into a logoutRequest struct. +func decodeSAMLRequest(t *testing.T, encoded string) logoutRequest { + t.Helper() + compressed, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatalf("failed to base64 decode SAMLRequest: %v", err) + } + inflated, err := io.ReadAll(flate.NewReader(bytes.NewReader(compressed))) + if err != nil { + t.Fatalf("failed to inflate SAMLRequest: %v", err) + } + var req logoutRequest + if err := xml.Unmarshal(inflated, &req); err != nil { + t.Fatalf("failed to unmarshal LogoutRequest: %v", err) + } + return req +} + +// successLogoutResponseXML returns a minimal SAML LogoutResponse with Success status. +const successLogoutResponseXML = ` + https://idp.example.com + + + +` + +const failedLogoutResponseXML = ` + https://idp.example.com + + + Logout failed + +` + +func TestLogoutURL(t *testing.T) { + connNoSLO, err := (&Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + connSLO, err := (&Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + SLOURL: "http://idp.example.com/slo", + EntityIssuer: "http://127.0.0.1:5556/dex", + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("SLONotConfigured", func(t *testing.T) { + connData, _ := json.Marshal(cachedIdentity{ + NameID: "user@example.com", + NameIDFormat: nameIDFormatEmailAddress, + }) + + u, err := connNoSLO.LogoutURL(context.Background(), connData, "https://app.example.com/done") + if err != nil { + t.Fatalf("LogoutURL error: %v", err) + } + if u != "" { + t.Errorf("expected empty URL when SLO not configured, got %q", u) + } + }) + + t.Run("EmptyNameID", func(t *testing.T) { + connData, _ := json.Marshal(cachedIdentity{}) + + u, err := connSLO.LogoutURL(context.Background(), connData, "https://app.example.com/done") + if err != nil { + t.Fatalf("LogoutURL error: %v", err) + } + if u != "" { + t.Errorf("expected empty URL when NameID is empty, got %q", u) + } + }) + + t.Run("NilConnectorData", func(t *testing.T) { + u, err := connSLO.LogoutURL(context.Background(), nil, "https://app.example.com/done") + if err != nil { + t.Fatalf("LogoutURL error: %v", err) + } + if u != "" { + t.Errorf("expected empty URL with nil connector data, got %q", u) + } + }) + + t.Run("ValidLogoutRequest", func(t *testing.T) { + connData, _ := json.Marshal(cachedIdentity{ + NameID: "user@example.com", + NameIDFormat: nameIDFormatEmailAddress, + SessionIndex: "session-abc-123", + }) + + u, err := connSLO.LogoutURL(context.Background(), connData, "https://app.example.com/done") + if err != nil { + t.Fatalf("LogoutURL error: %v", err) + } + if u == "" { + t.Fatal("expected non-empty URL") + } + + parsed, err := url.Parse(u) + if err != nil { + t.Fatalf("failed to parse returned URL: %v", err) + } + + if parsed.Host != "idp.example.com" { + t.Errorf("unexpected host: %q", parsed.Host) + } + if parsed.Path != "/slo" { + t.Errorf("unexpected path: %q", parsed.Path) + } + if parsed.Query().Get("RelayState") != "https://app.example.com/done" { + t.Errorf("unexpected RelayState: %q", parsed.Query().Get("RelayState")) + } + + req := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) + + if req.NameID.Value != "user@example.com" { + t.Errorf("NameID mismatch: got %q", req.NameID.Value) + } + if req.NameID.Format != nameIDFormatEmailAddress { + t.Errorf("NameID Format mismatch: got %q", req.NameID.Format) + } + if req.Destination != "http://idp.example.com/slo" { + t.Errorf("Destination mismatch: got %q", req.Destination) + } + if req.Issuer == nil || req.Issuer.Issuer != "http://127.0.0.1:5556/dex" { + t.Errorf("Issuer mismatch: %+v", req.Issuer) + } + if len(req.SessionIndex) != 1 || req.SessionIndex[0].Value != "session-abc-123" { + t.Errorf("SessionIndex mismatch: %+v", req.SessionIndex) + } + if req.ID == "" { + t.Error("expected non-empty request ID") + } + }) + + t.Run("NoSessionIndex", func(t *testing.T) { + connData, _ := json.Marshal(cachedIdentity{ + NameID: "user@example.com", + NameIDFormat: nameIDFormatEmailAddress, + }) + + u, err := connSLO.LogoutURL(context.Background(), connData, "") + if err != nil { + t.Fatalf("LogoutURL error: %v", err) + } + + parsed, err := url.Parse(u) + if err != nil { + t.Fatalf("failed to parse URL: %v", err) + } + + if parsed.Query().Get("RelayState") != "" { + t.Error("expected no RelayState when postLogoutRedirectURI is empty") + } + + req := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) + if len(req.SessionIndex) != 0 { + t.Errorf("expected no SessionIndex, got %+v", req.SessionIndex) + } + }) +} + +func TestHandleLogoutCallback(t *testing.T) { + conn, err := (&Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSLOSignatureValidation: true, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("EmptySAMLResponse", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/logout/callback", nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + t.Errorf("expected nil error for empty SAMLResponse, got: %v", err) + } + }) + + t.Run("ValidLogoutResponse", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + t.Errorf("expected no error for valid response, got: %v", err) + } + }) + + t.Run("FailedStatus", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte(failedLogoutResponseXML)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error for failed status") + } + }) + + t.Run("InvalidBase64", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse=not-valid-base64!!!", nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error for invalid base64") + } + }) + + t.Run("InvalidXML", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("not xml at all")) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error for invalid XML") + } + }) + + t.Run("POSTBinding", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) + form := url.Values{"SAMLResponse": {encoded}} + req := httptest.NewRequest(http.MethodPost, "/logout/callback", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + t.Errorf("expected no error for POST binding, got: %v", err) + } + }) + + t.Run("DeflatedResponse", func(t *testing.T) { + // HTTP-Redirect binding: response is deflated + base64 encoded + var buf bytes.Buffer + fw, err := flate.NewWriter(&buf, flate.DefaultCompression) + if err != nil { + t.Fatal(err) + } + if _, err := fw.Write([]byte(successLogoutResponseXML)); err != nil { + t.Fatal(err) + } + if err := fw.Close(); err != nil { + t.Fatal(err) + } + + encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + t.Errorf("expected no error for deflated response, got: %v", err) + } + }) +} + +// signXMLDocument signs an etree document using the test CA key/cert pair +// and returns the resulting XML bytes. +func signXMLDocument(t *testing.T, doc *etree.Document) []byte { + t.Helper() + tlsCert, err := tls.LoadX509KeyPair("testdata/ca.crt", "testdata/ca.key") + if err != nil { + t.Fatalf("failed to load test key pair: %v", err) + } + keyStore := dsig.TLSCertKeyStore(tlsCert) + sigCtx := dsig.NewDefaultSigningContext(keyStore) + + signed, err := sigCtx.SignEnveloped(doc.Root()) + if err != nil { + t.Fatalf("failed to sign XML: %v", err) + } + + signedDoc := etree.NewDocument() + signedDoc.SetRoot(signed) + out, err := signedDoc.WriteToBytes() + if err != nil { + t.Fatalf("failed to serialize signed XML: %v", err) + } + return out +} + +func TestHandleLogoutCallbackSignatureValidation(t *testing.T) { + conn, err := (&Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSLOSignatureValidation: false, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("ValidSignature", func(t *testing.T) { + doc := etree.NewDocument() + if err := doc.ReadFromString(successLogoutResponseXML); err != nil { + t.Fatal(err) + } + signedXML := signXMLDocument(t, doc) + + encoded := base64.StdEncoding.EncodeToString(signedXML) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + t.Errorf("expected no error for validly signed response, got: %v", err) + } + }) + + t.Run("InvalidSignature", func(t *testing.T) { + // Use unsigned XML — should fail signature validation + encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error for unsigned response when signature validation is enabled") + } + }) + + t.Run("WrongCA", func(t *testing.T) { + connBadCA, err := (&Config{ + CA: "testdata/bad-ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSLOSignatureValidation: false, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + doc := etree.NewDocument() + if err := doc.ReadFromString(successLogoutResponseXML); err != nil { + t.Fatal(err) + } + signedXML := signXMLDocument(t, doc) + + encoded := base64.StdEncoding.EncodeToString(signedXML) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := connBadCA.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error when response is signed with different CA") + } + }) +} + +func TestSLOEndToEnd(t *testing.T) { + c := Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + GroupsAttr: "groups", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + SLOURL: "http://idp.example.com/slo", + InsecureSkipSLOSignatureValidation: true, + } + + conn, err := c.openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + // Step 1: HandlePOST — simulate login, extract ConnectorData + now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z") + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return now } + + resp, err := os.ReadFile("testdata/good-resp.xml") + if err != nil { + t.Fatal(err) + } + samlResp := base64.StdEncoding.EncodeToString(resp) + + scopes := connector.Scopes{OfflineAccess: true, Groups: true} + ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m") + if err != nil { + t.Fatalf("HandlePOST failed: %v", err) + } + + if len(ident.ConnectorData) == 0 { + t.Fatal("expected ConnectorData after HandlePOST") + } + + // Step 2: LogoutURL — build logout redirect URL from ConnectorData + logoutURL, err := conn.LogoutURL(context.Background(), ident.ConnectorData, "https://app.example.com/done") + if err != nil { + t.Fatalf("LogoutURL failed: %v", err) + } + if logoutURL == "" { + t.Fatal("expected non-empty logout URL") + } + + parsed, err := url.Parse(logoutURL) + if err != nil { + t.Fatalf("failed to parse logout URL: %v", err) + } + + // Verify the LogoutRequest contains the same NameID from HandlePOST + req := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) + if req.NameID.Value != ident.UserID { + t.Errorf("LogoutRequest NameID should match HandlePOST UserID: got %q, want %q", req.NameID.Value, ident.UserID) + } + if req.NameID.Format != "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" { + t.Errorf("LogoutRequest NameID format mismatch: got %q", req.NameID.Format) + } + if len(req.SessionIndex) != 1 || req.SessionIndex[0].Value != "6zmm5mguyebwvajyf2sdwwcw6m" { + t.Errorf("LogoutRequest SessionIndex mismatch: %+v", req.SessionIndex) + } + if req.Issuer != nil { + t.Errorf("expected no Issuer when EntityIssuer is not configured, got: %+v", req.Issuer) + } + + // Step 3: HandleLogoutCallback — simulate IdP response + encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) + callbackReq := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), callbackReq); err != nil { + t.Fatalf("HandleLogoutCallback failed: %v", err) + } +} diff --git a/connector/saml/types.go b/connector/saml/types.go index c8d7e7f3b3..2717002fb7 100644 --- a/connector/saml/types.go +++ b/connector/saml/types.go @@ -80,7 +80,7 @@ type subject struct { type nameID struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` - Format string `xml:"Format,omitempty"` + Format string `xml:"Format,attr,omitempty"` Value string `xml:",chardata"` } @@ -191,9 +191,15 @@ type assertion struct { Conditions *conditions `xml:"Conditions"` + AuthnStatements []authnStatement `xml:"AuthnStatement,omitempty"` AttributeStatement *attributeStatement `xml:"AttributeStatement,omitempty"` } +type authnStatement struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"` + SessionIndex string `xml:"SessionIndex,attr,omitempty"` +} + type attributeStatement struct { XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"` @@ -275,3 +281,33 @@ func (a attribute) String() string { // "groups" = ["engineering", "docs"] return fmt.Sprintf("%q = %q", a.Name, values) } + +type logoutRequest struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutRequest"` + + ID string `xml:"ID,attr"` + Version samlVersion `xml:"Version,attr"` + IssueInstant xmlTime `xml:"IssueInstant,attr"` + Destination string `xml:"Destination,attr,omitempty"` + + Issuer *issuer `xml:"Issuer,omitempty"` + NameID nameID `xml:"NameID"` + SessionIndex []sessionIndex `xml:"SessionIndex,omitempty"` +} + +type sessionIndex struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol SessionIndex"` + Value string `xml:",chardata"` +} + +type logoutResponse struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutResponse"` + + ID string `xml:"ID,attr"` + InResponseTo string `xml:"InResponseTo,attr,omitempty"` + Version samlVersion `xml:"Version,attr"` + Destination string `xml:"Destination,attr,omitempty"` + + Issuer *issuer `xml:"Issuer,omitempty"` + Status *status `xml:"Status"` +} From 50fa0cd90f86d4b248dee36423ed6ceb476cde6c Mon Sep 17 00:00:00 2001 From: Ivan Zvyagintsev Date: Mon, 13 Apr 2026 14:33:33 +0300 Subject: [PATCH 2/4] add saml slo Signed-off-by: Ivan Zvyagintsev --- connector/saml/saml.go | 115 ++++++++++++++++++++++++++++++++- connector/saml/saml_test.go | 124 +++++++++++++++++++++++++++++++++--- connector/saml/types.go | 1 + 3 files changed, 227 insertions(+), 13 deletions(-) diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 400c81f908..778c182e13 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -5,7 +5,9 @@ import ( "bytes" "compress/flate" "context" + "crypto" "crypto/rand" + "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/json" @@ -246,6 +248,7 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { return nil, errors.New("no certificates found in ca data") } p.validator = dsig.NewDefaultValidationContext(certStore{certs}) + p.certs = certs } return p, nil } @@ -265,6 +268,9 @@ type provider struct { // If nil, don't do signature validation. validator *dsig.ValidationContext + // Stored separately for HTTP-Redirect binding signature verification, + // which uses raw RSA/ECDSA over query string rather than XML digital signatures. + certs []*x509.Certificate // Attribute mappings usernameAttr string @@ -865,9 +871,17 @@ func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) erro return fmt.Errorf("saml slo: %w", xrvErr) } - if p.validator != nil && !p.insecureSkipSLOSignatureValidation { - if _, err := p.validateSignature(rawResp); err != nil { - return fmt.Errorf("saml slo: %v", err) + if !p.insecureSkipSLOSignatureValidation { + if r.Method == http.MethodGet && len(p.certs) > 0 { + // HTTP-Redirect binding: signature is in query parameters. + if err := p.validateRedirectSignature(r); err != nil { + return fmt.Errorf("saml slo: %v", err) + } + } else if r.Method != http.MethodGet && p.validator != nil { + // HTTP-POST binding: signature is embedded in XML. + if _, err := p.validateSignature(rawResp); err != nil { + return fmt.Errorf("saml slo: %v", err) + } } } @@ -885,6 +899,101 @@ func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) erro return nil } +// redirectSigAlgToHash maps XML Signature algorithm URIs used in SAML HTTP-Redirect +// binding to Go crypto.Hash values. Only RSA algorithms are supported. +// See: https://www.w3.org/TR/xmldsig-core1/#sec-AlgID +var redirectSigAlgToHash = map[string]crypto.Hash{ + "http://www.w3.org/2000/09/xmldsig#rsa-sha1": crypto.SHA1, + "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256": crypto.SHA256, + "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384": crypto.SHA384, + "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512": crypto.SHA512, +} + +// rawQueryParam extracts the raw (still URL-encoded) value of a query parameter +// from a raw query string. This is needed for SAML HTTP-Redirect binding signature +// validation, which signs over the URL-encoded parameter values. +func rawQueryParam(rawQuery, key string) (string, bool) { + prefix := key + "=" + for rawQuery != "" { + var pair string + if i := strings.IndexByte(rawQuery, '&'); i >= 0 { + pair, rawQuery = rawQuery[:i], rawQuery[i+1:] + } else { + pair, rawQuery = rawQuery, "" + } + if strings.HasPrefix(pair, prefix) { + return pair[len(prefix):], true + } + } + return "", false +} + +// validateRedirectSignature verifies the query-string signature used in SAML +// HTTP-Redirect binding. Unlike HTTP-POST where the signature is embedded in +// the XML (), HTTP-Redirect carries it as Signature and SigAlg +// query parameters. The signed content is reconstructed per SAML 2.0 Bindings +// Section 3.4.4.1: SAMLResponse=value&RelayState=value&SigAlg=value (using +// the original URL-encoded values). +func (p *provider) validateRedirectSignature(r *http.Request) error { + rawQuery := r.URL.RawQuery + + sigEncoded, ok := rawQueryParam(rawQuery, "Signature") + if !ok || sigEncoded == "" { + return fmt.Errorf("missing Signature query parameter") + } + + sigAlgEncoded, ok := rawQueryParam(rawQuery, "SigAlg") + if !ok || sigAlgEncoded == "" { + return fmt.Errorf("missing SigAlg query parameter") + } + + sigAlg, err := url.QueryUnescape(sigAlgEncoded) + if err != nil { + return fmt.Errorf("failed to decode SigAlg: %v", err) + } + + hashAlg, ok := redirectSigAlgToHash[sigAlg] + if !ok { + return fmt.Errorf("unsupported signature algorithm: %s", sigAlg) + } + + // Reconstruct the signed content in the spec-mandated order. + var parts []string + if v, ok := rawQueryParam(rawQuery, "SAMLResponse"); ok { + parts = append(parts, "SAMLResponse="+v) + } + if v, ok := rawQueryParam(rawQuery, "RelayState"); ok { + parts = append(parts, "RelayState="+v) + } + parts = append(parts, "SigAlg="+sigAlgEncoded) + signedContent := strings.Join(parts, "&") + + sigB64, err := url.QueryUnescape(sigEncoded) + if err != nil { + return fmt.Errorf("failed to URL-decode Signature: %v", err) + } + sig, err := base64.StdEncoding.DecodeString(sigB64) + if err != nil { + return fmt.Errorf("failed to base64-decode Signature: %v", err) + } + + h := hashAlg.New() + h.Write([]byte(signedContent)) + hashed := h.Sum(nil) + + for _, cert := range p.certs { + rsaPub, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + continue + } + if rsa.VerifyPKCS1v15(rsaPub, hashAlg, hashed, sig) == nil { + return nil + } + } + + return fmt.Errorf("redirect binding signature validation failed") +} + // validateSignature validates the XML digital signature of the given raw XML. func (p *provider) validateSignature(rawXML []byte) ([]byte, error) { if p.validator == nil { diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index 0a176b290d..4a61ca74d3 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -4,6 +4,9 @@ import ( "bytes" "compress/flate" "context" + "crypto" + "crypto/rand" + "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/base64" @@ -1268,7 +1271,14 @@ func signXMLDocument(t *testing.T, doc *etree.Document) []byte { return out } -func TestHandleLogoutCallbackSignatureValidation(t *testing.T) { +func postSAMLResponse(encoded string) *http.Request { + form := url.Values{"SAMLResponse": {encoded}} + req := httptest.NewRequest(http.MethodPost, "/logout/callback", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req +} + +func TestHandleLogoutCallbackPOSTSignatureValidation(t *testing.T) { conn, err := (&Config{ CA: "testdata/ca.crt", UsernameAttr: "Name", @@ -1287,19 +1297,16 @@ func TestHandleLogoutCallbackSignatureValidation(t *testing.T) { t.Fatal(err) } signedXML := signXMLDocument(t, doc) - encoded := base64.StdEncoding.EncodeToString(signedXML) - req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + + if err := conn.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err != nil { t.Errorf("expected no error for validly signed response, got: %v", err) } }) t.Run("InvalidSignature", func(t *testing.T) { - // Use unsigned XML — should fail signature validation encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) - req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err == nil { t.Error("expected error for unsigned response when signature validation is enabled") } }) @@ -1322,15 +1329,112 @@ func TestHandleLogoutCallbackSignatureValidation(t *testing.T) { t.Fatal(err) } signedXML := signXMLDocument(t, doc) - encoded := base64.StdEncoding.EncodeToString(signedXML) - req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := connBadCA.HandleLogoutCallback(context.Background(), req); err == nil { + + if err := connBadCA.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err == nil { t.Error("expected error when response is signed with different CA") } }) } +// signRedirectBinding builds a complete URL for a GET LogoutResponse with +// SAML HTTP-Redirect binding signature. The XML is deflated, base64-encoded, +// and a query-string RSA-SHA256 signature is appended. +func signRedirectBinding(t *testing.T, xmlPayload string, keyFile, certFile string) string { + t.Helper() + + var buf bytes.Buffer + fw, err := flate.NewWriter(&buf, flate.DefaultCompression) + if err != nil { + t.Fatalf("deflate writer: %v", err) + } + if _, err := fw.Write([]byte(xmlPayload)); err != nil { + t.Fatalf("deflate write: %v", err) + } + if err := fw.Close(); err != nil { + t.Fatalf("deflate close: %v", err) + } + samlResp := base64.StdEncoding.EncodeToString(buf.Bytes()) + + sigAlg := "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" + signedContent := "SAMLResponse=" + url.QueryEscape(samlResp) + + "&SigAlg=" + url.QueryEscape(sigAlg) + + tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + t.Fatalf("load key pair: %v", err) + } + rsaKey, ok := tlsCert.PrivateKey.(*rsa.PrivateKey) + if !ok { + t.Fatal("test key is not RSA") + } + + h := crypto.SHA256.New() + h.Write([]byte(signedContent)) + sig, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, h.Sum(nil)) + if err != nil { + t.Fatalf("sign: %v", err) + } + sigB64 := base64.StdEncoding.EncodeToString(sig) + + return "/logout/callback?" + signedContent + + "&Signature=" + url.QueryEscape(sigB64) +} + +func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) { + conn, err := (&Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSLOSignatureValidation: false, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("ValidSignature", func(t *testing.T) { + u := signRedirectBinding(t, successLogoutResponseXML, "testdata/ca.key", "testdata/ca.crt") + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("MissingSignature", func(t *testing.T) { + var buf bytes.Buffer + fw, _ := flate.NewWriter(&buf, flate.DefaultCompression) + fw.Write([]byte(successLogoutResponseXML)) + fw.Close() + encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) + + req := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error for missing Signature parameter") + } + }) + + t.Run("WrongCA", func(t *testing.T) { + u := signRedirectBinding(t, successLogoutResponseXML, "testdata/bad-ca.key", "testdata/bad-ca.crt") + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error when signed with wrong CA") + } + }) + + t.Run("TamperedPayload", func(t *testing.T) { + u := signRedirectBinding(t, successLogoutResponseXML, "testdata/ca.key", "testdata/ca.crt") + // Replace part of the SAMLResponse value to simulate tampering. + u = strings.Replace(u, "SAMLResponse=", "SAMLResponse=AAAA", 1) + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + t.Error("expected error for tampered payload") + } + }) +} + func TestSLOEndToEnd(t *testing.T) { c := Config{ CA: "testdata/ca.crt", diff --git a/connector/saml/types.go b/connector/saml/types.go index 2717002fb7..3997e60f40 100644 --- a/connector/saml/types.go +++ b/connector/saml/types.go @@ -306,6 +306,7 @@ type logoutResponse struct { ID string `xml:"ID,attr"` InResponseTo string `xml:"InResponseTo,attr,omitempty"` Version samlVersion `xml:"Version,attr"` + IssueInstant xmlTime `xml:"IssueInstant,attr,omitempty"` Destination string `xml:"Destination,attr,omitempty"` Issuer *issuer `xml:"Issuer,omitempty"` From 9c88d96f7910c076be827e6b95b52767d4f4aed6 Mon Sep 17 00:00:00 2001 From: Ivan Zvyagintsev Date: Tue, 5 May 2026 14:43:36 +0300 Subject: [PATCH 3/4] apply suggestions from review Signed-off-by: Ivan Zvyagintsev --- connector/connector.go | 23 + connector/saml/saml.go | 327 ++++++++++--- connector/saml/saml_test.go | 549 +++++++++++++++++++--- server/logout.go | 63 ++- storage/conformance/conformance.go | 86 ++++ storage/ent/client/authsession.go | 34 +- storage/ent/client/types.go | 7 + storage/ent/db/authsession.go | 17 +- storage/ent/db/authsession/authsession.go | 3 + storage/ent/db/authsession/where.go | 55 +++ storage/ent/db/authsession_create.go | 10 + storage/ent/db/authsession_update.go | 36 ++ storage/ent/db/migrate/schema.go | 1 + storage/ent/db/mutation.go | 80 +++- storage/ent/schema/authsession.go | 3 + storage/etcd/types.go | 3 + storage/storage.go | 7 + 17 files changed, 1136 insertions(+), 168 deletions(-) diff --git a/connector/connector.go b/connector/connector.go index 9ab19cb96d..017859d21f 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -136,3 +136,26 @@ type LogoutCallbackConnector interface { // return nil. HandleLogoutCallback(ctx context.Context, r *http.Request) error } + +// StatefulLogoutCallbackConnector is an optional capability for connectors +// whose logout flow needs server-side correlation state to be carried from +// the outgoing logout request to the inbound logout response. The server +// persists the opaque state alongside the logout session and hands it back +// on the callback, allowing the connector to enforce one-shot, replay-proof +// checks (e.g. SAML's InResponseTo). +// +// Connectors that don't need correlation state should implement the simpler +// LogoutCallbackConnector instead. The server prefers this interface over +// LogoutCallbackConnector when both are implemented. +type StatefulLogoutCallbackConnector interface { + // LogoutURLWithState returns the upstream provider's logout URL plus an + // opaque connector-specific state to be persisted by the server and + // passed back to HandleLogoutCallbackWithState. Returning empty url means + // upstream logout is not available; in that case state must be nil. + LogoutURLWithState(ctx context.Context, connectorData []byte, postLogoutRedirectURI string) (logoutURL string, state []byte, err error) + + // HandleLogoutCallbackWithState validates the upstream provider's logout + // response received in the callback request. state is the value returned + // by the matching LogoutURLWithState call. + HandleLogoutCallbackWithState(ctx context.Context, r *http.Request, state []byte) error +} diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 778c182e13..4ebc7e66d4 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -55,6 +55,9 @@ const ( // allowed clock drift for timestamp validation allowedClockDrift = time.Duration(30) * time.Second + + // Default RSA algorithm for SAML HTTP-Redirect query-string signatures (SP logout). + defaultRedirectSigAlg = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" ) var ( @@ -95,10 +98,15 @@ type Config struct { // If empty, SLO is not available for this connector. SLOURL string `json:"sloURL"` - // InsecureSkipSLOSignatureValidation skips signature validation on SLO responses. - // This is insecure and should only be used for testing or when the IdP - // does not sign LogoutResponses. - InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"` + // SLOSigningKey and SLOSigningKeyData are a PEM-encoded RSA private key used to + // sign SP-initiated SAML LogoutRequests on the HTTP-Redirect binding (SigAlg + + // Signature query parameters per SAML 2.0 Bindings §3.4.4). Optional: when both + // are empty, LogoutRequest is sent unsigned. + // + // This is not the same as ca/caData: those are public certificates for verifying + // the IdP; signing requires the SP's own key pair registered with the IdP. + SLOSigningKey string `json:"sloSigningKey"` + SLOSigningKeyData []byte `json:"sloSigningKeyData"` // Assertion attribute names to lookup various claims with. UsernameAttr string `json:"usernameAttr"` @@ -181,8 +189,46 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { nameIDPolicyFormat: c.NameIDPolicyFormat, - sloURL: c.SLOURL, - insecureSkipSLOSignatureValidation: c.InsecureSkipSLOSignatureValidation, + sloURL: c.SLOURL, + } + + if c.SLOSigningKey != "" && len(c.SLOSigningKeyData) > 0 { + return nil, errors.New("saml: specify at most one of sloSigningKey and sloSigningKeyData") + } + if c.SLOURL != "" && c.EntityIssuer == "" { + // Single Logout Profile (§4.4.4.1) requires on every LogoutRequest; + // without entityIssuer we have no SP entityID to populate it with, and + // most production IdPs (Keycloak, ADFS, Okta, ...) reject Issuer-less + // LogoutRequests outright. + return nil, errors.New("saml: entityIssuer is required when sloURL is set") + } + if c.SLOSigningKey != "" || len(c.SLOSigningKeyData) > 0 { + if c.SLOURL == "" { + return nil, errors.New("saml: sloSigningKey or sloSigningKeyData requires sloURL") + } + var keyPEM []byte + if c.SLOSigningKey != "" { + data, err := os.ReadFile(c.SLOSigningKey) + if err != nil { + return nil, fmt.Errorf("saml: read sloSigningKey: %v", err) + } + keyPEM = data + } else { + keyPEM = c.SLOSigningKeyData + } + sloKey, err := parseRSAPrivateKeyPEM(keyPEM) + if err != nil { + return nil, fmt.Errorf("saml: parse sloSigningKey: %v", err) + } + p.sloSignKey = sloKey + } else if c.SLOURL != "" { + // Per SAML 2.0 Profiles §4.4.3.4 / §4.4.4.1, LogoutRequests sent over + // HTTP-Redirect or HTTP-POST MUST be signed. We allow the unsigned + // configuration (some test setups need it), but warn loudly so the + // operator notices before the IdP rejects every logout. + logger.Warn("saml: sloURL configured without sloSigningKey/sloSigningKeyData; " + + "LogoutRequest will be sent unsigned, which violates SAML 2.0 Profiles §4.4.3.4 " + + "and is rejected by most production IdPs") } if p.nameIDPolicyFormat == "" { @@ -208,8 +254,7 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { } } - needsSLOSigValidation := c.SLOURL != "" && !c.InsecureSkipSLOSignatureValidation - if !c.InsecureSkipSignatureValidation || needsSLOSigValidation { + if !c.InsecureSkipSignatureValidation { if (c.CA == "") == (c.CAData == nil) { return nil, errors.New("must provide either 'ca' or 'caData'") } @@ -253,10 +298,40 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { return p, nil } +// parseRSAPrivateKeyPEM loads the first RSA private key from PEM data (PKCS#1 or PKCS#8). +func parseRSAPrivateKeyPEM(pemData []byte) (*rsa.PrivateKey, error) { + for len(pemData) > 0 { + block, rest := pem.Decode(pemData) + if block == nil { + break + } + switch block.Type { + case "RSA PRIVATE KEY": + k, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return k, nil + case "PRIVATE KEY": + k, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + rsaK, ok := k.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is %T, want RSA", k) + } + return rsaK, nil + } + pemData = rest + } + return nil, errors.New("no RSA private key found in PEM data") +} + var ( - _ connector.SAMLConnector = (*provider)(nil) - _ connector.RefreshConnector = (*provider)(nil) - _ connector.LogoutCallbackConnector = (*provider)(nil) + _ connector.SAMLConnector = (*provider)(nil) + _ connector.RefreshConnector = (*provider)(nil) + _ connector.StatefulLogoutCallbackConnector = (*provider)(nil) ) type provider struct { @@ -284,8 +359,10 @@ type provider struct { nameIDPolicyFormat string - sloURL string - insecureSkipSLOSignatureValidation bool + sloURL string + + // If non-nil, SP-initiated LogoutRequests use HTTP-Redirect query signing. + sloSignKey *rsa.PrivateKey logger *slog.Logger } @@ -756,90 +833,178 @@ func after(now, notOnOrAfter time.Time) bool { } // newRequestID generates a random ID suitable for SAML request IDs. -func newRequestID() string { +func newRequestID() (string, error) { buf := make([]byte, 16) if _, err := io.ReadFull(rand.Reader, buf); err != nil { - panic("crypto/rand failed: " + err.Error()) + return "", fmt.Errorf("crypto/rand failed: %v", err) + } + return fmt.Sprintf("_%x", buf), nil +} + +// sloCallbackURLFromRequest reconstructs the absolute URL of this HTTP request (SLO callback). +// Used to validate LogoutResponse Destination per SAML 2.0 Bindings §3.4.5.2. When Host is +// missing (e.g. some unit tests), returns empty and Destination checking is skipped. +func sloCallbackURLFromRequest(r *http.Request) string { + host := r.Host + if h := r.Header.Get("X-Forwarded-Host"); h != "" { + if i := strings.IndexByte(h, ','); i >= 0 { + h = strings.TrimSpace(h[:i]) + } else { + h = strings.TrimSpace(h) + } + host = h + } + if host == "" { + return "" + } + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if p := r.Header.Get("X-Forwarded-Proto"); p != "" { + if i := strings.IndexByte(p, ','); i >= 0 { + p = strings.TrimSpace(p[:i]) + } else { + p = strings.TrimSpace(p) + } + scheme = p } - return fmt.Sprintf("_%x", buf) + return scheme + "://" + host + r.URL.Path } -// LogoutURL builds a SAML LogoutRequest and returns the IdP's SLO endpoint URL -// with the request encoded using HTTP-Redirect binding (deflate + base64). +// sloURLsMatch compares two URLs ignoring trailing slashes, with case-insensitive +// scheme/host comparison per RFC 3986 §3.1/§3.2.2. Default-port normalization +// (e.g. https://x → https://x:443) is intentionally NOT performed: SAML IdPs +// typically echo the exact Destination they received, and pretending two URLs +// are equal when an operator wrote them differently in their config tends to +// hide misconfigurations rather than fix them. +func sloURLsMatch(a, b string) bool { + pa, errA := url.Parse(strings.TrimSpace(a)) + pb, errB := url.Parse(strings.TrimSpace(b)) + if errA != nil || errB != nil { + return strings.TrimSuffix(strings.TrimSpace(a), "/") == strings.TrimSuffix(strings.TrimSpace(b), "/") + } + if !strings.EqualFold(pa.Scheme, pb.Scheme) { + return false + } + if !strings.EqualFold(pa.Host, pb.Host) { + return false + } + return strings.TrimSuffix(pa.Path, "/") == strings.TrimSuffix(pb.Path, "/") && + pa.RawQuery == pb.RawQuery +} + +// LogoutURLWithState builds a SAML LogoutRequest and returns the IdP's SLO +// endpoint URL with the request encoded using HTTP-Redirect binding +// (deflate + base64). The second return value is the outgoing LogoutRequest +// ID, which the server persists in storage.LogoutState.ConnectorState and +// hands back to HandleLogoutCallbackWithState so InResponseTo can be matched +// against a server-side, one-shot value (defeats replay of captured +// LogoutResponses). +// +// postLogoutRedirectURI is Dex's own /logout/callback URL; SAML doesn't carry +// it in the request (the IdP knows where to send the LogoutResponse via its +// configured SP metadata), so it is intentionally ignored here. // -// See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf -// "3.4 HTTP Redirect Binding" -func (p *provider) LogoutURL(_ context.Context, connectorData []byte, postLogoutRedirectURI string) (string, error) { +// See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf §3.4 +// and https://docs.oasis-open.org/security/saml/v2.0/saml-profiles-2.0-os.pdf §4.4. +func (p *provider) LogoutURLWithState(_ context.Context, connectorData []byte, _ string) (string, []byte, error) { if p.sloURL == "" { - return "", nil + return "", nil, nil } var ci cachedIdentity if len(connectorData) > 0 { if err := json.Unmarshal(connectorData, &ci); err != nil { - return "", fmt.Errorf("saml: failed to unmarshal connector data for logout: %v", err) + return "", nil, fmt.Errorf("saml: failed to unmarshal connector data for logout: %v", err) } } if ci.NameID == "" { - return "", nil + return "", nil, nil + } + + reqID, err := newRequestID() + if err != nil { + return "", nil, fmt.Errorf("saml: %v", err) } req := &logoutRequest{ - ID: newRequestID(), + ID: reqID, IssueInstant: xmlTime(p.now()), Destination: p.sloURL, + Issuer: &issuer{Issuer: p.entityIssuer}, // §4.4.4.1: Issuer is REQUIRED NameID: nameID{ Format: ci.NameIDFormat, Value: ci.NameID, }, } - if p.entityIssuer != "" { - req.Issuer = &issuer{Issuer: p.entityIssuer} - } if ci.SessionIndex != "" { req.SessionIndex = []sessionIndex{{Value: ci.SessionIndex}} } data, err := xml.Marshal(req) if err != nil { - return "", fmt.Errorf("saml: failed to marshal LogoutRequest: %v", err) + return "", nil, fmt.Errorf("saml: failed to marshal LogoutRequest: %v", err) } // HTTP-Redirect binding: deflate then base64-encode. var buf bytes.Buffer fw, err := flate.NewWriter(&buf, flate.DefaultCompression) if err != nil { - return "", fmt.Errorf("saml: failed to create deflate writer: %v", err) + return "", nil, fmt.Errorf("saml: failed to create deflate writer: %v", err) } if _, err := fw.Write(data); err != nil { - return "", fmt.Errorf("saml: failed to deflate LogoutRequest: %v", err) + return "", nil, fmt.Errorf("saml: failed to deflate LogoutRequest: %v", err) } if err := fw.Close(); err != nil { - return "", fmt.Errorf("saml: failed to close deflate writer: %v", err) + return "", nil, fmt.Errorf("saml: failed to close deflate writer: %v", err) } encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) u, err := url.Parse(p.sloURL) if err != nil { - return "", fmt.Errorf("saml: failed to parse SLO URL: %v", err) - } - q := u.Query() - q.Set("SAMLRequest", encoded) - if postLogoutRedirectURI != "" { - q.Set("RelayState", postLogoutRedirectURI) + return "", nil, fmt.Errorf("saml: failed to parse SLO URL: %v", err) + } + + // We do not emit RelayState. SAML 2.0 Bindings §3.4.3 limits it to 80 + // bytes and Dex correlates the SLO flow through the session cookie + + // server-side LogoutState; InResponseTo is matched against the request + // ID we hand back to the server below. + samlReqEscaped := url.QueryEscape(encoded) + baseQuery := "SAMLRequest=" + samlReqEscaped + + if p.sloSignKey != nil { + sigAlgEscaped := url.QueryEscape(defaultRedirectSigAlg) + signedContent := baseQuery + "&SigAlg=" + sigAlgEscaped + h := crypto.SHA256.New() + h.Write([]byte(signedContent)) + sig, err := rsa.SignPKCS1v15(rand.Reader, p.sloSignKey, crypto.SHA256, h.Sum(nil)) + if err != nil { + return "", nil, fmt.Errorf("saml: sign LogoutRequest redirect binding: %v", err) + } + sigB64 := base64.StdEncoding.EncodeToString(sig) + u.RawQuery = signedContent + "&Signature=" + url.QueryEscape(sigB64) + } else { + u.RawQuery = baseQuery } - u.RawQuery = q.Encode() - return u.String(), nil + return u.String(), []byte(reqID), nil } -// HandleLogoutCallback validates the IdP's LogoutResponse received after -// an SP-initiated logout redirect. The response arrives as a SAMLResponse -// parameter via either GET query (HTTP-Redirect binding: deflated + base64) -// or POST form (HTTP-POST binding: base64 only). -func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) error { +// HandleLogoutCallbackWithState validates the IdP's LogoutResponse received +// after an SP-initiated logout redirect. The response arrives as a +// SAMLResponse parameter via either GET query (HTTP-Redirect binding: +// deflated + base64) or POST form (HTTP-POST binding: base64 only). +// +// state is the value previously returned by LogoutURLWithState — for SAML, +// the bytes of the outgoing LogoutRequest ID. The server retrieves it from +// storage.LogoutState.ConnectorState; using a server-side, single-use value +// (instead of an IdP-echoed RelayState) makes InResponseTo replay-resistant +// even when the LogoutResponse is HTTP-POST and RelayState isn't covered by +// the signature. +func (p *provider) HandleLogoutCallbackWithState(_ context.Context, r *http.Request, state []byte) error { var samlResponse string if r.Method == http.MethodGet { samlResponse = r.URL.Query().Get("SAMLResponse") @@ -871,17 +1036,13 @@ func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) erro return fmt.Errorf("saml slo: %w", xrvErr) } - if !p.insecureSkipSLOSignatureValidation { - if r.Method == http.MethodGet && len(p.certs) > 0 { - // HTTP-Redirect binding: signature is in query parameters. - if err := p.validateRedirectSignature(r); err != nil { - return fmt.Errorf("saml slo: %v", err) - } - } else if r.Method != http.MethodGet && p.validator != nil { - // HTTP-POST binding: signature is embedded in XML. - if _, err := p.validateSignature(rawResp); err != nil { - return fmt.Errorf("saml slo: %v", err) - } + if r.Method == http.MethodGet && len(p.certs) > 0 { + if err := p.validateRedirectSignature(r, "SAMLResponse"); err != nil { + return fmt.Errorf("saml slo: %v", err) + } + } else if r.Method != http.MethodGet && p.validator != nil { + if _, err := p.validateSignature(rawResp); err != nil { + return fmt.Errorf("saml slo: %v", err) } } @@ -896,6 +1057,50 @@ func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) erro } } + // §4.4.4.2: MUST be present in LogoutResponse. When ssoIssuer is + // configured, treat a missing Issuer as a rejection too. + if p.ssoIssuer != "" { + if resp.Issuer == nil { + return fmt.Errorf("saml slo: LogoutResponse is missing required Issuer element (expected %q)", p.ssoIssuer) + } + if resp.Issuer.Issuer != p.ssoIssuer { + return fmt.Errorf("saml slo: expected Issuer value %q, got %q", p.ssoIssuer, resp.Issuer.Issuer) + } + } + + issueInstant := time.Time(resp.IssueInstant) + if !issueInstant.IsZero() { + now := p.now() + if before(now, issueInstant) { + return fmt.Errorf("saml slo: LogoutResponse IssueInstant %s is in the future (now: %s)", issueInstant, now) + } + const maxAge = 5 * time.Minute + if now.After(issueInstant.Add(maxAge + allowedClockDrift)) { + return fmt.Errorf("saml slo: LogoutResponse IssueInstant %s is too old (now: %s)", issueInstant, now) + } + } + + if resp.Destination != "" { + if recv := sloCallbackURLFromRequest(r); recv != "" { + if !sloURLsMatch(resp.Destination, recv) { + return fmt.Errorf("saml slo: expected Destination %q, callback URL was %q", resp.Destination, recv) + } + } + } + + // Match InResponseTo against the one-shot request ID the server kept in + // LogoutState.ConnectorState. An empty state means the server didn't have + // one (e.g. legacy session predating this change) — skip the check rather + // than break upgrade flows, but log so it's noticed. + if len(state) > 0 { + expectedReqID := string(state) + if resp.InResponseTo != expectedReqID { + return fmt.Errorf("saml slo: InResponseTo mismatch: expected %q, got %q", expectedReqID, resp.InResponseTo) + } + } else if p.logger != nil { + p.logger.Warn("saml slo: no server-side request ID for InResponseTo check; replay protection disabled for this callback") + } + return nil } @@ -932,9 +1137,9 @@ func rawQueryParam(rawQuery, key string) (string, bool) { // HTTP-Redirect binding. Unlike HTTP-POST where the signature is embedded in // the XML (), HTTP-Redirect carries it as Signature and SigAlg // query parameters. The signed content is reconstructed per SAML 2.0 Bindings -// Section 3.4.4.1: SAMLResponse=value&RelayState=value&SigAlg=value (using -// the original URL-encoded values). -func (p *provider) validateRedirectSignature(r *http.Request) error { +// Section 3.4.4.1: SAMLRequest or SAMLResponse, then optional RelayState, then +// SigAlg (using the original URL-encoded values). +func (p *provider) validateRedirectSignature(r *http.Request, samlMsgParam string) error { rawQuery := r.URL.RawQuery sigEncoded, ok := rawQueryParam(rawQuery, "Signature") @@ -959,8 +1164,8 @@ func (p *provider) validateRedirectSignature(r *http.Request) error { // Reconstruct the signed content in the spec-mandated order. var parts []string - if v, ok := rawQueryParam(rawQuery, "SAMLResponse"); ok { - parts = append(parts, "SAMLResponse="+v) + if v, ok := rawQueryParam(rawQuery, samlMsgParam); ok { + parts = append(parts, samlMsgParam+"="+v) } if v, ok := rawQueryParam(rawQuery, "RelayState"); ok { parts = append(parts, "RelayState="+v) diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index 4a61ca74d3..e6055f9487 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -14,6 +14,7 @@ import ( "encoding/pem" "encoding/xml" "errors" + "fmt" "io" "log/slog" "net/http" @@ -1052,35 +1053,44 @@ func TestLogoutURL(t *testing.T) { NameIDFormat: nameIDFormatEmailAddress, }) - u, err := connNoSLO.LogoutURL(context.Background(), connData, "https://app.example.com/done") + u, state, err := connNoSLO.LogoutURLWithState(context.Background(), connData, "https://app.example.com/done") if err != nil { t.Fatalf("LogoutURL error: %v", err) } if u != "" { t.Errorf("expected empty URL when SLO not configured, got %q", u) } + if state != nil { + t.Errorf("expected nil state when SLO not configured, got %q", state) + } }) t.Run("EmptyNameID", func(t *testing.T) { connData, _ := json.Marshal(cachedIdentity{}) - u, err := connSLO.LogoutURL(context.Background(), connData, "https://app.example.com/done") + u, state, err := connSLO.LogoutURLWithState(context.Background(), connData, "https://app.example.com/done") if err != nil { t.Fatalf("LogoutURL error: %v", err) } if u != "" { t.Errorf("expected empty URL when NameID is empty, got %q", u) } + if state != nil { + t.Errorf("expected nil state when NameID is empty, got %q", state) + } }) t.Run("NilConnectorData", func(t *testing.T) { - u, err := connSLO.LogoutURL(context.Background(), nil, "https://app.example.com/done") + u, state, err := connSLO.LogoutURLWithState(context.Background(), nil, "https://app.example.com/done") if err != nil { t.Fatalf("LogoutURL error: %v", err) } if u != "" { t.Errorf("expected empty URL with nil connector data, got %q", u) } + if state != nil { + t.Errorf("expected nil state with nil connector data, got %q", state) + } }) t.Run("ValidLogoutRequest", func(t *testing.T) { @@ -1090,7 +1100,7 @@ func TestLogoutURL(t *testing.T) { SessionIndex: "session-abc-123", }) - u, err := connSLO.LogoutURL(context.Background(), connData, "https://app.example.com/done") + u, state, err := connSLO.LogoutURLWithState(context.Background(), connData, "https://app.example.com/done") if err != nil { t.Fatalf("LogoutURL error: %v", err) } @@ -1109,12 +1119,19 @@ func TestLogoutURL(t *testing.T) { if parsed.Path != "/slo" { t.Errorf("unexpected path: %q", parsed.Path) } - if parsed.Query().Get("RelayState") != "https://app.example.com/done" { - t.Errorf("unexpected RelayState: %q", parsed.Query().Get("RelayState")) - } - req := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) + // State is the bytes of the outgoing LogoutRequest ID; the server + // persists it in storage.LogoutState.ConnectorState so InResponseTo + // can be matched in HandleLogoutCallback. + if string(state) != req.ID { + t.Errorf("expected state to be request ID %q, got %q", req.ID, state) + } + // We intentionally do not emit RelayState (correlation is server-side). + if rs := parsed.Query().Get("RelayState"); rs != "" { + t.Errorf("expected no RelayState, got %q", rs) + } + if req.NameID.Value != "user@example.com" { t.Errorf("NameID mismatch: got %q", req.NameID.Value) } @@ -1141,7 +1158,7 @@ func TestLogoutURL(t *testing.T) { NameIDFormat: nameIDFormatEmailAddress, }) - u, err := connSLO.LogoutURL(context.Background(), connData, "") + u, state, err := connSLO.LogoutURLWithState(context.Background(), connData, "") if err != nil { t.Fatalf("LogoutURL error: %v", err) } @@ -1151,33 +1168,128 @@ func TestLogoutURL(t *testing.T) { t.Fatalf("failed to parse URL: %v", err) } - if parsed.Query().Get("RelayState") != "" { - t.Error("expected no RelayState when postLogoutRedirectURI is empty") - } - req := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) + + if string(state) != req.ID { + t.Errorf("expected state to be request ID %q, got %q", req.ID, state) + } if len(req.SessionIndex) != 0 { t.Errorf("expected no SessionIndex, got %+v", req.SessionIndex) } }) } +func TestSLOSigningKeyConfigErrors(t *testing.T) { + logger := slog.New(slog.DiscardHandler) + base := Config{ + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSignatureValidation: true, + } + t.Run("BothKeyAndData", func(t *testing.T) { + c := base + c.SLOSigningKey = "testdata/ca.key" + c.SLOSigningKeyData = []byte("dummy") + _, err := c.Open("saml", logger) + if err == nil { + t.Fatal("expected error when both sloSigningKey and sloSigningKeyData are set") + } + }) + t.Run("KeyWithoutSLOURL", func(t *testing.T) { + c := base + c.SLOSigningKey = "testdata/ca.key" + _, err := c.Open("saml", logger) + if err == nil { + t.Fatal("expected error when sloSigningKey is set without sloURL") + } + }) + t.Run("SLOURLWithoutEntityIssuer", func(t *testing.T) { + // Profile §4.4.4.1: MUST be present in LogoutRequest, so + // entityIssuer must be configured whenever SLO is enabled. + c := base + c.SLOURL = "http://idp.example.com/slo" + _, err := c.Open("saml", logger) + if err == nil { + t.Fatal("expected error when sloURL is set without entityIssuer") + } + }) +} + +func TestLogoutURLRedirectSigning(t *testing.T) { + keyPEM, err := os.ReadFile("testdata/ca.key") + if err != nil { + t.Fatal(err) + } + conn, err := (&Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + SLOURL: "http://idp.example.com/slo", + SLOSigningKeyData: keyPEM, + EntityIssuer: "http://127.0.0.1:5556/dex", + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + connData, _ := json.Marshal(cachedIdentity{ + NameID: "user@example.com", + NameIDFormat: nameIDFormatEmailAddress, + }) + logoutURL, state, err := conn.LogoutURLWithState(context.Background(), connData, "https://app.example.com/done") + if err != nil { + t.Fatal(err) + } + if len(state) == 0 { + t.Fatal("expected non-empty state (LogoutRequest ID)") + } + parsed, err := url.Parse(logoutURL) + if err != nil { + t.Fatal(err) + } + if parsed.Query().Get("Signature") == "" { + t.Fatal("expected Signature query parameter") + } + if parsed.Query().Get("SigAlg") == "" { + t.Fatal("expected SigAlg query parameter") + } + req := httptest.NewRequest(http.MethodGet, logoutURL, nil) + if err := conn.validateRedirectSignature(req, "SAMLRequest"); err != nil { + t.Fatalf("signed LogoutURL failed verification: %v", err) + } + t.Run("TamperedQueryRejected", func(t *testing.T) { + bad := strings.Replace(logoutURL, "SAMLRequest=", "SAMLRequest=AAAA", 1) + badReq := httptest.NewRequest(http.MethodGet, bad, nil) + if err := conn.validateRedirectSignature(badReq, "SAMLRequest"); err == nil { + t.Error("expected verification error for tampered SAMLRequest") + } + }) +} + func TestHandleLogoutCallback(t *testing.T) { conn, err := (&Config{ - CA: "testdata/ca.crt", - UsernameAttr: "Name", - EmailAttr: "email", - RedirectURI: "http://127.0.0.1:5556/dex/callback", - SSOURL: "http://foo.bar/", - InsecureSkipSLOSignatureValidation: true, + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSignatureValidation: true, }).openConnector(slog.New(slog.DiscardHandler)) if err != nil { t.Fatal(err) } + // Match the IssueInstant in successLogoutResponseXML / failedLogoutResponseXML. + respTime, _ := time.Parse(timeFormat, "2024-01-01T00:00:00Z") + conn.now = func() time.Time { return respTime } + + // successLogoutResponseXML's InResponseTo is "_req456"; pass matching state. + successState := []byte("_req456") t.Run("EmptySAMLResponse", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/logout/callback", nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { t.Errorf("expected nil error for empty SAMLResponse, got: %v", err) } }) @@ -1185,7 +1297,7 @@ func TestHandleLogoutCallback(t *testing.T) { t.Run("ValidLogoutResponse", func(t *testing.T) { encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, successState); err != nil { t.Errorf("expected no error for valid response, got: %v", err) } }) @@ -1193,14 +1305,14 @@ func TestHandleLogoutCallback(t *testing.T) { t.Run("FailedStatus", func(t *testing.T) { encoded := base64.StdEncoding.EncodeToString([]byte(failedLogoutResponseXML)) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for failed status") } }) t.Run("InvalidBase64", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse=not-valid-base64!!!", nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for invalid base64") } }) @@ -1208,7 +1320,7 @@ func TestHandleLogoutCallback(t *testing.T) { t.Run("InvalidXML", func(t *testing.T) { encoded := base64.StdEncoding.EncodeToString([]byte("not xml at all")) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for invalid XML") } }) @@ -1219,13 +1331,12 @@ func TestHandleLogoutCallback(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/logout/callback", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, successState); err != nil { t.Errorf("expected no error for POST binding, got: %v", err) } }) t.Run("DeflatedResponse", func(t *testing.T) { - // HTTP-Redirect binding: response is deflated + base64 encoded var buf bytes.Buffer fw, err := flate.NewWriter(&buf, flate.DefaultCompression) if err != nil { @@ -1240,12 +1351,257 @@ func TestHandleLogoutCallback(t *testing.T) { encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, successState); err != nil { t.Errorf("expected no error for deflated response, got: %v", err) } }) } +func TestHandleLogoutCallbackIssuerValidation(t *testing.T) { + conn, err := (&Config{ + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + SSOIssuer: "https://correct-idp.example.com", + InsecureSkipSignatureValidation: true, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("MatchingIssuer", func(t *testing.T) { + xml := fmt.Sprintf(` + https://correct-idp.example.com + + `, time.Now().UTC().Format(timeFormat)) + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("MismatchedIssuer", func(t *testing.T) { + xml := fmt.Sprintf(` + https://evil-idp.example.com + + `, time.Now().UTC().Format(timeFormat)) + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { + t.Error("expected error for mismatched issuer") + } + }) + + t.Run("MissingIssuer", func(t *testing.T) { + // Profile §4.4.4.2 — missing Issuer must be rejected when ssoIssuer is configured. + xml := fmt.Sprintf(` + + `, time.Now().UTC().Format(timeFormat)) + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { + t.Error("expected error for missing Issuer when ssoIssuer is configured") + } + }) +} + +func TestHandleLogoutCallbackDestination(t *testing.T) { + conn, err := (&Config{ + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSignatureValidation: true, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + inst := time.Now().UTC().Format(timeFormat) + makeResp := func(dest string) string { + destAttr := "" + if dest != "" { + destAttr = fmt.Sprintf(` Destination="%s"`, dest) + } + return fmt.Sprintf(` + + `, inst, destAttr) + } + + t.Run("MatchingAbsoluteURL", func(t *testing.T) { + dest := "https://dex.example.com/logout/callback" + enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + + t.Run("TrailingSlashEquivalence", func(t *testing.T) { + dest := "https://dex.example.com/logout/callback/" + enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + + t.Run("CaseInsensitiveSchemeHost", func(t *testing.T) { + // RFC 3986 mandates case-insensitive comparison for scheme and host. + dest := "HTTPS://Dex.Example.COM/logout/callback" + enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected nil for case-only scheme/host difference, got %v", err) + } + }) + + t.Run("MismatchedDestination", func(t *testing.T) { + enc := base64.StdEncoding.EncodeToString([]byte(makeResp("https://evil.example.com/callback"))) + u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) + req := httptest.NewRequest(http.MethodGet, u, nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { + t.Error("expected error for wrong Destination") + } + }) + + t.Run("XForwardedProtoAndHost", func(t *testing.T) { + dest := "https://public.example.com/dex/logout/callback" + enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + u := "http://10.0.0.5/dex/logout/callback?SAMLResponse=" + url.QueryEscape(enc) + req := httptest.NewRequest(http.MethodGet, u, nil) + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "public.example.com") + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) +} + +func TestHandleLogoutCallbackIssueInstantFreshness(t *testing.T) { + conn, err := (&Config{ + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSignatureValidation: true, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("FreshResponse", func(t *testing.T) { + conn.now = func() time.Time { return time.Now() } + xml := fmt.Sprintf(` + + `, time.Now().UTC().Format(timeFormat)) + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected no error for fresh response, got: %v", err) + } + }) + + t.Run("StaleResponse", func(t *testing.T) { + conn.now = func() time.Time { return time.Now() } + stale := time.Now().Add(-10 * time.Minute).UTC().Format(timeFormat) + xml := fmt.Sprintf(` + + `, stale) + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { + t.Error("expected error for stale IssueInstant") + } + }) + + t.Run("FutureResponse", func(t *testing.T) { + conn.now = func() time.Time { return time.Now() } + future := time.Now().Add(5 * time.Minute).UTC().Format(timeFormat) + xml := fmt.Sprintf(` + + `, future) + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { + t.Error("expected error for future IssueInstant") + } + }) +} + +func TestHandleLogoutCallbackInResponseTo(t *testing.T) { + conn, err := (&Config{ + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSignatureValidation: true, + }).openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return time.Now() } + + makeResponse := func(inResponseTo string) string { + return fmt.Sprintf(` + + `, time.Now().UTC().Format(timeFormat), inResponseTo) + } + + // state is what server-side LogoutState.ConnectorState carries: the + // outgoing LogoutRequest ID that LogoutURL produced for this user. + t.Run("MatchingID", func(t *testing.T) { + xml := makeResponse("_abc123") + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, []byte("_abc123")); err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + + t.Run("MismatchedID", func(t *testing.T) { + xml := makeResponse("_wrong") + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, []byte("_abc123")); err == nil { + t.Error("expected error for InResponseTo mismatch") + } + }) + + t.Run("NilStateSkipsCheck", func(t *testing.T) { + // Legacy sessions persisted before this change won't have ConnectorState + // — we must not break the upgrade path. + xml := makeResponse("_anything") + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { + t.Errorf("expected no error when state is nil, got: %v", err) + } + }) + + t.Run("RelayStateIgnored", func(t *testing.T) { + // Replay defense: a captured LogoutResponse with an attacker-supplied + // RelayState must NOT be accepted just because RelayState matches — + // only server-side state counts. + xml := makeResponse("_attacker_chosen") + encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + req := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(encoded)+ + "&RelayState="+url.QueryEscape("_attacker_chosen"), + nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, []byte("_real_request_id")); err == nil { + t.Error("expected error: RelayState must not satisfy InResponseTo check") + } + }) +} + // signXMLDocument signs an etree document using the test CA key/cert pair // and returns the resulting XML bytes. func signXMLDocument(t *testing.T, doc *etree.Document) []byte { @@ -1280,16 +1636,17 @@ func postSAMLResponse(encoded string) *http.Request { func TestHandleLogoutCallbackPOSTSignatureValidation(t *testing.T) { conn, err := (&Config{ - CA: "testdata/ca.crt", - UsernameAttr: "Name", - EmailAttr: "email", - RedirectURI: "http://127.0.0.1:5556/dex/callback", - SSOURL: "http://foo.bar/", - InsecureSkipSLOSignatureValidation: false, + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", }).openConnector(slog.New(slog.DiscardHandler)) if err != nil { t.Fatal(err) } + respTime, _ := time.Parse(timeFormat, "2024-01-01T00:00:00Z") + conn.now = func() time.Time { return respTime } t.Run("ValidSignature", func(t *testing.T) { doc := etree.NewDocument() @@ -1299,26 +1656,25 @@ func TestHandleLogoutCallbackPOSTSignatureValidation(t *testing.T) { signedXML := signXMLDocument(t, doc) encoded := base64.StdEncoding.EncodeToString(signedXML) - if err := conn.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err != nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), postSAMLResponse(encoded), nil); err != nil { t.Errorf("expected no error for validly signed response, got: %v", err) } }) t.Run("InvalidSignature", func(t *testing.T) { encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) - if err := conn.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), postSAMLResponse(encoded), nil); err == nil { t.Error("expected error for unsigned response when signature validation is enabled") } }) t.Run("WrongCA", func(t *testing.T) { connBadCA, err := (&Config{ - CA: "testdata/bad-ca.crt", - UsernameAttr: "Name", - EmailAttr: "email", - RedirectURI: "http://127.0.0.1:5556/dex/callback", - SSOURL: "http://foo.bar/", - InsecureSkipSLOSignatureValidation: false, + CA: "testdata/bad-ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", }).openConnector(slog.New(slog.DiscardHandler)) if err != nil { t.Fatal(err) @@ -1331,7 +1687,7 @@ func TestHandleLogoutCallbackPOSTSignatureValidation(t *testing.T) { signedXML := signXMLDocument(t, doc) encoded := base64.StdEncoding.EncodeToString(signedXML) - if err := connBadCA.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err == nil { + if err := connBadCA.HandleLogoutCallbackWithState(context.Background(), postSAMLResponse(encoded), nil); err == nil { t.Error("expected error when response is signed with different CA") } }) @@ -1383,21 +1739,22 @@ func signRedirectBinding(t *testing.T, xmlPayload string, keyFile, certFile stri func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) { conn, err := (&Config{ - CA: "testdata/ca.crt", - UsernameAttr: "Name", - EmailAttr: "email", - RedirectURI: "http://127.0.0.1:5556/dex/callback", - SSOURL: "http://foo.bar/", - InsecureSkipSLOSignatureValidation: false, + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", }).openConnector(slog.New(slog.DiscardHandler)) if err != nil { t.Fatal(err) } + respTime, _ := time.Parse(timeFormat, "2024-01-01T00:00:00Z") + conn.now = func() time.Time { return respTime } t.Run("ValidSignature", func(t *testing.T) { u := signRedirectBinding(t, successLogoutResponseXML, "testdata/ca.key", "testdata/ca.crt") req := httptest.NewRequest(http.MethodGet, u, nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err != nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { t.Errorf("expected no error, got: %v", err) } }) @@ -1411,7 +1768,7 @@ func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for missing Signature parameter") } }) @@ -1419,7 +1776,7 @@ func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) { t.Run("WrongCA", func(t *testing.T) { u := signRedirectBinding(t, successLogoutResponseXML, "testdata/bad-ca.key", "testdata/bad-ca.crt") req := httptest.NewRequest(http.MethodGet, u, nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error when signed with wrong CA") } }) @@ -1429,7 +1786,7 @@ func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) { // Replace part of the SAMLResponse value to simulate tampering. u = strings.Replace(u, "SAMLResponse=", "SAMLResponse=AAAA", 1) req := httptest.NewRequest(http.MethodGet, u, nil) - if err := conn.HandleLogoutCallback(context.Background(), req); err == nil { + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for tampered payload") } }) @@ -1437,14 +1794,17 @@ func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) { func TestSLOEndToEnd(t *testing.T) { c := Config{ - CA: "testdata/ca.crt", - UsernameAttr: "Name", - EmailAttr: "email", - GroupsAttr: "groups", - RedirectURI: "http://127.0.0.1:5556/dex/callback", - SSOURL: "http://foo.bar/", - SLOURL: "http://idp.example.com/slo", - InsecureSkipSLOSignatureValidation: true, + UsernameAttr: "Name", + EmailAttr: "email", + GroupsAttr: "groups", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + SLOURL: "http://idp.example.com/slo", + // EntityIssuer is required when SLOURL is set (Profile §4.4.4.1). + // It also drives audience validation, so it must match good-resp.xml's + // http://127.0.0.1:5556/dex/callback. + EntityIssuer: "http://127.0.0.1:5556/dex/callback", + InsecureSkipSignatureValidation: true, } conn, err := c.openConnector(slog.New(slog.DiscardHandler)) @@ -1475,39 +1835,76 @@ func TestSLOEndToEnd(t *testing.T) { t.Fatal("expected ConnectorData after HandlePOST") } - // Step 2: LogoutURL — build logout redirect URL from ConnectorData - logoutURL, err := conn.LogoutURL(context.Background(), ident.ConnectorData, "https://app.example.com/done") + // Step 2: LogoutURL — build logout redirect URL + connector state from ConnectorData + conn.now = func() time.Time { return time.Now() } + logoutURL, connectorState, err := conn.LogoutURLWithState(context.Background(), ident.ConnectorData, "https://dex.example.com/logout/callback") if err != nil { t.Fatalf("LogoutURL failed: %v", err) } if logoutURL == "" { t.Fatal("expected non-empty logout URL") } + if len(connectorState) == 0 { + t.Fatal("expected non-empty connector state (LogoutRequest ID)") + } parsed, err := url.Parse(logoutURL) if err != nil { t.Fatalf("failed to parse logout URL: %v", err) } - // Verify the LogoutRequest contains the same NameID from HandlePOST - req := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) - if req.NameID.Value != ident.UserID { - t.Errorf("LogoutRequest NameID should match HandlePOST UserID: got %q, want %q", req.NameID.Value, ident.UserID) + logReq := decodeSAMLRequest(t, parsed.Query().Get("SAMLRequest")) + if logReq.NameID.Value != ident.UserID { + t.Errorf("LogoutRequest NameID should match HandlePOST UserID: got %q, want %q", logReq.NameID.Value, ident.UserID) } - if req.NameID.Format != "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" { - t.Errorf("LogoutRequest NameID format mismatch: got %q", req.NameID.Format) + if logReq.NameID.Format != "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" { + t.Errorf("LogoutRequest NameID format mismatch: got %q", logReq.NameID.Format) } - if len(req.SessionIndex) != 1 || req.SessionIndex[0].Value != "6zmm5mguyebwvajyf2sdwwcw6m" { - t.Errorf("LogoutRequest SessionIndex mismatch: %+v", req.SessionIndex) + if len(logReq.SessionIndex) != 1 || logReq.SessionIndex[0].Value != "6zmm5mguyebwvajyf2sdwwcw6m" { + t.Errorf("LogoutRequest SessionIndex mismatch: %+v", logReq.SessionIndex) } - if req.Issuer != nil { - t.Errorf("expected no Issuer when EntityIssuer is not configured, got: %+v", req.Issuer) + // §4.4.4.1: MUST be present. + if logReq.Issuer == nil || logReq.Issuer.Issuer != "http://127.0.0.1:5556/dex/callback" { + t.Errorf("LogoutRequest Issuer mismatch: %+v", logReq.Issuer) + } + + // Connector state must be the outgoing request ID; the server stores it + // in storage.LogoutState.ConnectorState and hands it back below. + if string(connectorState) != logReq.ID { + t.Errorf("connector state should equal request ID %q, got %q", logReq.ID, connectorState) } - // Step 3: HandleLogoutCallback — simulate IdP response - encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) - callbackReq := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) - if err := conn.HandleLogoutCallback(context.Background(), callbackReq); err != nil { + // Step 3: HandleLogoutCallback — simulate IdP response with matching InResponseTo. + // Note: no RelayState is sent or expected; correlation is purely server-side. + logoutResponseXML := fmt.Sprintf(` + https://idp.example.com + + + +`, time.Now().UTC().Format(timeFormat), logReq.ID) + + encoded := base64.StdEncoding.EncodeToString([]byte(logoutResponseXML)) + callbackReq := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), callbackReq, connectorState); err != nil { t.Fatalf("HandleLogoutCallback failed: %v", err) } + + // Step 4: InResponseTo mismatch must be detected even when an attacker + // pre-fills RelayState — only server-side state counts. + badResponseXML := fmt.Sprintf(` + https://idp.example.com + + + +`, time.Now().UTC().Format(timeFormat)) + + badEncoded := base64.StdEncoding.EncodeToString([]byte(badResponseXML)) + badCallbackReq := httptest.NewRequest(http.MethodGet, + "/logout/callback?SAMLResponse="+url.QueryEscape(badEncoded)+ + "&RelayState="+url.QueryEscape("_wrong_id"), + nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), badCallbackReq, connectorState); err == nil { + t.Error("expected error for InResponseTo mismatch") + } } diff --git a/server/logout.go b/server/logout.go index d48f7a59fe..2426b2afdd 100644 --- a/server/logout.go +++ b/server/logout.go @@ -22,7 +22,7 @@ import ( // 2. Extract user identity (subject) and client (audience/azp) from the token // 3. Validate post_logout_redirect_uri against the client's registered URIs // 4. Revoke refresh tokens for the user/connector pair -// 5. If the auth session exists and upstream connector implements LogoutCallbackConnector: +// 5. If the auth session exists and upstream connector implements (Stateful)LogoutCallbackConnector: // a. Store LogoutState + HMAC key in the session (not deleted yet) // b. Redirect to upstream logout with signed state // c. On callback: verify HMAC, read LogoutState from session, delete session, render page @@ -192,10 +192,19 @@ func (s *Server) handleLogoutCallback(w http.ResponseWriter, r *http.Request) { ls := session.LogoutState // Let the connector validate the upstream logout response if it supports it. + // Prefer StatefulLogoutCallbackConnector (replays ls.ConnectorState — e.g. + // SAML's outgoing LogoutRequest ID for InResponseTo correlation) and fall + // back to the simpler LogoutCallbackConnector for stateless connectors. if ls.ConnectorID != "" { conn, err := s.getConnector(ctx, ls.ConnectorID) if err == nil { - if logoutConn, ok := conn.Connector.(connector.LogoutCallbackConnector); ok { + switch logoutConn := conn.Connector.(type) { + case connector.StatefulLogoutCallbackConnector: + if err := logoutConn.HandleLogoutCallbackWithState(ctx, r, ls.ConnectorState); err != nil { + s.logger.ErrorContext(ctx, "logout: upstream logout response validation failed", + "connector_id", ls.ConnectorID, "err", err) + } + case connector.LogoutCallbackConnector: if err := logoutConn.HandleLogoutCallback(ctx, r); err != nil { s.logger.ErrorContext(ctx, "logout: upstream logout response validation failed", "connector_id", ls.ConnectorID, "err", err) @@ -245,8 +254,13 @@ func (s *Server) tryUpstreamLogout(ctx context.Context, userID, connectorID stri return "", false } - logoutConn, ok := conn.Connector.(connector.LogoutCallbackConnector) - if !ok { + // Connectors may implement either the basic LogoutCallbackConnector + // or its stateful variant. We probe for the richer interface first so + // connectors that need server-side correlation state (e.g. SAML's + // LogoutRequest ID for InResponseTo) can hand it back to us. + statefulConn, hasState := conn.Connector.(connector.StatefulLogoutCallbackConnector) + basicConn, hasBasic := conn.Connector.(connector.LogoutCallbackConnector) + if !hasState && !hasBasic { return "", false } @@ -258,22 +272,19 @@ func (s *Server) tryUpstreamLogout(ctx context.Context, userID, connectorID stri return "", false } - // Store logout parameters in the session. - if err := s.storage.UpdateAuthSession(ctx, userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) { - old.LogoutState = &storage.LogoutState{ - PostLogoutRedirectURI: postLogoutRedirectURI, - State: state, - ClientID: clientID, - ConnectorID: connectorID, - } - return old, nil - }); err != nil { - s.logger.ErrorContext(ctx, "logout: failed to save logout state", "err", err) - return "", false - } - + // Build the upstream URL first so any connector-produced state gets + // persisted alongside the rest of LogoutState in a single write. If the + // connector errors we don't want to leave stale logout state behind. callbackURI := s.absURL("/logout/callback") - upstreamURL, err := logoutConn.LogoutURL(ctx, connectorData, callbackURI) + var ( + upstreamURL string + connectorState []byte + ) + if hasState { + upstreamURL, connectorState, err = statefulConn.LogoutURLWithState(ctx, connectorData, callbackURI) + } else { + upstreamURL, err = basicConn.LogoutURL(ctx, connectorData, callbackURI) + } if err != nil { s.logger.ErrorContext(ctx, "logout: upstream connector error", "err", err) return "", false @@ -288,6 +299,20 @@ func (s *Server) tryUpstreamLogout(ctx context.Context, userID, connectorID stri return "", false } + if err := s.storage.UpdateAuthSession(ctx, userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) { + old.LogoutState = &storage.LogoutState{ + PostLogoutRedirectURI: postLogoutRedirectURI, + State: state, + ClientID: clientID, + ConnectorID: connectorID, + ConnectorState: connectorState, + } + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "logout: failed to save logout state", "err", err) + return "", false + } + return u.String(), true } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 33ec950cfd..28be39ef0d 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -59,6 +59,7 @@ func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { {"DeviceTokenCRUD", testDeviceTokenCRUD}, {"UserIdentityCRUD", testUserIdentityCRUD}, {"AuthSessionCRUD", testAuthSessionCRUD}, + {"AuthSessionLogoutState", testAuthSessionLogoutState}, }) } @@ -1462,3 +1463,88 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { _, err = s.GetAuthSession(ctx, session.UserID, session.ConnectorID) mustBeErrNotFound(t, "auth session", err) } + +// testAuthSessionLogoutState verifies that storage backends round-trip the +// LogoutState (including the opaque connector-specific ConnectorState used by +// SAML for InResponseTo correlation) and can clear it back to nil. +// +// Backends serialize LogoutState differently (etcd/kubernetes embed a struct, +// SQL stores a JSON blob, ent uses a nillable bytes column); without this +// test, missing a field in any of those mirrors would silently break SAML SLO +// at runtime with "No logout in progress." after the upstream redirect. +func testAuthSessionLogoutState(t *testing.T, s storage.Storage) { + ctx := t.Context() + + now := time.Now().UTC().Round(time.Millisecond) + + session := storage.AuthSession{ + UserID: "user-logout", + ConnectorID: "conn-logout", + Nonce: storage.NewID(), + ClientStates: map[string]*storage.ClientAuthState{}, + CreatedAt: now, + LastActivity: now, + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(1 * time.Hour), + } + + if err := s.CreateAuthSession(ctx, session); err != nil { + t.Fatalf("create auth session: %v", err) + } + t.Cleanup(func() { + _ = s.DeleteAuthSession(ctx, session.UserID, session.ConnectorID) + }) + + // Initially LogoutState must be nil. + got, err := s.GetAuthSession(ctx, session.UserID, session.ConnectorID) + if err != nil { + t.Fatalf("get auth session: %v", err) + } + if got.LogoutState != nil { + t.Fatalf("expected nil LogoutState on fresh session, got %+v", got.LogoutState) + } + + // Set LogoutState with a non-trivial ConnectorState (mimics SAML's + // outgoing LogoutRequest ID used for InResponseTo correlation). + want := &storage.LogoutState{ + PostLogoutRedirectURI: "https://app.example.com/done", + State: "client-state-xyz", + ClientID: "client1", + ConnectorID: session.ConnectorID, + ConnectorState: []byte("_saml_request_id_12345"), + } + if err := s.UpdateAuthSession(ctx, session.UserID, session.ConnectorID, func(old storage.AuthSession) (storage.AuthSession, error) { + old.LogoutState = want + return old, nil + }); err != nil { + t.Fatalf("update auth session with LogoutState: %v", err) + } + + got, err = s.GetAuthSession(ctx, session.UserID, session.ConnectorID) + if err != nil { + t.Fatalf("get auth session after LogoutState write: %v", err) + } + if got.LogoutState == nil { + t.Fatalf("expected LogoutState to round-trip, got nil") + } + if diff := pretty.Compare(want, got.LogoutState); diff != "" { + t.Errorf("LogoutState did not round-trip: %s", diff) + } + + // Clear LogoutState back to nil; the storage must persist nil, not an + // empty struct (server uses nil to mean "no logout in progress"). + if err := s.UpdateAuthSession(ctx, session.UserID, session.ConnectorID, func(old storage.AuthSession) (storage.AuthSession, error) { + old.LogoutState = nil + return old, nil + }); err != nil { + t.Fatalf("update auth session clearing LogoutState: %v", err) + } + + got, err = s.GetAuthSession(ctx, session.UserID, session.ConnectorID) + if err != nil { + t.Fatalf("get auth session after LogoutState clear: %v", err) + } + if got.LogoutState != nil { + t.Errorf("expected nil LogoutState after clear, got %+v", got.LogoutState) + } +} diff --git a/storage/ent/client/authsession.go b/storage/ent/client/authsession.go index b4cdfe8147..0299828872 100644 --- a/storage/ent/client/authsession.go +++ b/storage/ent/client/authsession.go @@ -19,7 +19,7 @@ func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSe } id := compositeKeyID(session.UserID, session.ConnectorID, d.hasher) - _, err = d.client.AuthSession.Create(). + create := d.client.AuthSession.Create(). SetID(id). SetUserID(session.UserID). SetConnectorID(session.ConnectorID). @@ -30,9 +30,17 @@ func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSe SetIPAddress(session.IPAddress). SetUserAgent(session.UserAgent). SetAbsoluteExpiry(session.AbsoluteExpiry.UTC()). - SetIdleExpiry(session.IdleExpiry.UTC()). - Save(ctx) - if err != nil { + SetIdleExpiry(session.IdleExpiry.UTC()) + + if session.LogoutState != nil { + encodedLogoutState, err := json.Marshal(session.LogoutState) + if err != nil { + return fmt.Errorf("encode logout state auth session: %w", err) + } + create.SetLogoutState(encodedLogoutState) + } + + if _, err := create.Save(ctx); err != nil { return convertDBError("create auth session: %w", err) } return nil @@ -99,15 +107,25 @@ func (d *Database) UpdateAuthSession(ctx context.Context, userID, connectorID st return rollback(tx, "encode client states auth session: %w", err) } - _, err = tx.AuthSession.UpdateOneID(id). + update := tx.AuthSession.UpdateOneID(id). SetClientStates(encodedStates). SetLastActivity(newSession.LastActivity). SetIPAddress(newSession.IPAddress). SetUserAgent(newSession.UserAgent). SetAbsoluteExpiry(newSession.AbsoluteExpiry.UTC()). - SetIdleExpiry(newSession.IdleExpiry.UTC()). - Save(ctx) - if err != nil { + SetIdleExpiry(newSession.IdleExpiry.UTC()) + + if newSession.LogoutState != nil { + encodedLogoutState, err := json.Marshal(newSession.LogoutState) + if err != nil { + return rollback(tx, "encode logout state auth session: %w", err) + } + update.SetLogoutState(encodedLogoutState) + } else { + update.ClearLogoutState() + } + + if _, err = update.Save(ctx); err != nil { return rollback(tx, "update auth session updating: %w", err) } diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 4a6e2bc740..ecd86c6b26 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -254,6 +254,13 @@ func toStorageAuthSession(s *db.AuthSession) storage.AuthSession { } else { result.ClientStates = make(map[string]*storage.ClientAuthState) } + + if s.LogoutState != nil && len(*s.LogoutState) > 0 && string(*s.LogoutState) != "null" { + result.LogoutState = new(storage.LogoutState) + if err := json.Unmarshal(*s.LogoutState, result.LogoutState); err != nil { + panic(err) + } + } return result } diff --git a/storage/ent/db/authsession.go b/storage/ent/db/authsession.go index 6ced0680fc..e706ff52da 100644 --- a/storage/ent/db/authsession.go +++ b/storage/ent/db/authsession.go @@ -36,7 +36,9 @@ type AuthSession struct { // AbsoluteExpiry holds the value of the "absolute_expiry" field. AbsoluteExpiry time.Time `json:"absolute_expiry,omitempty"` // IdleExpiry holds the value of the "idle_expiry" field. - IdleExpiry time.Time `json:"idle_expiry,omitempty"` + IdleExpiry time.Time `json:"idle_expiry,omitempty"` + // LogoutState holds the value of the "logout_state" field. + LogoutState *[]byte `json:"logout_state,omitempty"` selectValues sql.SelectValues } @@ -45,7 +47,7 @@ func (*AuthSession) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case authsession.FieldClientStates: + case authsession.FieldClientStates, authsession.FieldLogoutState: values[i] = new([]byte) case authsession.FieldID, authsession.FieldUserID, authsession.FieldConnectorID, authsession.FieldNonce, authsession.FieldIPAddress, authsession.FieldUserAgent: values[i] = new(sql.NullString) @@ -132,6 +134,12 @@ func (_m *AuthSession) assignValues(columns []string, values []any) error { } else if value.Valid { _m.IdleExpiry = value.Time } + case authsession.FieldLogoutState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field logout_state", values[i]) + } else if value != nil { + _m.LogoutState = value + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -197,6 +205,11 @@ func (_m *AuthSession) String() string { builder.WriteString(", ") builder.WriteString("idle_expiry=") builder.WriteString(_m.IdleExpiry.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.LogoutState; v != nil { + builder.WriteString("logout_state=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authsession/authsession.go b/storage/ent/db/authsession/authsession.go index fc1cd5ed3b..e1c7052a54 100644 --- a/storage/ent/db/authsession/authsession.go +++ b/storage/ent/db/authsession/authsession.go @@ -31,6 +31,8 @@ const ( FieldAbsoluteExpiry = "absolute_expiry" // FieldIdleExpiry holds the string denoting the idle_expiry field in the database. FieldIdleExpiry = "idle_expiry" + // FieldLogoutState holds the string denoting the logout_state field in the database. + FieldLogoutState = "logout_state" // Table holds the table name of the authsession in the database. Table = "auth_sessions" ) @@ -48,6 +50,7 @@ var Columns = []string{ FieldUserAgent, FieldAbsoluteExpiry, FieldIdleExpiry, + FieldLogoutState, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/storage/ent/db/authsession/where.go b/storage/ent/db/authsession/where.go index 193f1133e5..a1fa279852 100644 --- a/storage/ent/db/authsession/where.go +++ b/storage/ent/db/authsession/where.go @@ -114,6 +114,11 @@ func IdleExpiry(v time.Time) predicate.AuthSession { return predicate.AuthSession(sql.FieldEQ(FieldIdleExpiry, v)) } +// LogoutState applies equality check predicate on the "logout_state" field. It's identical to LogoutStateEQ. +func LogoutState(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldLogoutState, v)) +} + // UserIDEQ applies the EQ predicate on the "user_id" field. func UserIDEQ(v string) predicate.AuthSession { return predicate.AuthSession(sql.FieldEQ(FieldUserID, v)) @@ -639,6 +644,56 @@ func IdleExpiryLTE(v time.Time) predicate.AuthSession { return predicate.AuthSession(sql.FieldLTE(FieldIdleExpiry, v)) } +// LogoutStateEQ applies the EQ predicate on the "logout_state" field. +func LogoutStateEQ(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldLogoutState, v)) +} + +// LogoutStateNEQ applies the NEQ predicate on the "logout_state" field. +func LogoutStateNEQ(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldLogoutState, v)) +} + +// LogoutStateIn applies the In predicate on the "logout_state" field. +func LogoutStateIn(vs ...[]byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldLogoutState, vs...)) +} + +// LogoutStateNotIn applies the NotIn predicate on the "logout_state" field. +func LogoutStateNotIn(vs ...[]byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldLogoutState, vs...)) +} + +// LogoutStateGT applies the GT predicate on the "logout_state" field. +func LogoutStateGT(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldLogoutState, v)) +} + +// LogoutStateGTE applies the GTE predicate on the "logout_state" field. +func LogoutStateGTE(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldLogoutState, v)) +} + +// LogoutStateLT applies the LT predicate on the "logout_state" field. +func LogoutStateLT(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldLogoutState, v)) +} + +// LogoutStateLTE applies the LTE predicate on the "logout_state" field. +func LogoutStateLTE(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldLogoutState, v)) +} + +// LogoutStateIsNil applies the IsNil predicate on the "logout_state" field. +func LogoutStateIsNil() predicate.AuthSession { + return predicate.AuthSession(sql.FieldIsNull(FieldLogoutState)) +} + +// LogoutStateNotNil applies the NotNil predicate on the "logout_state" field. +func LogoutStateNotNil() predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotNull(FieldLogoutState)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthSession) predicate.AuthSession { return predicate.AuthSession(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authsession_create.go b/storage/ent/db/authsession_create.go index 0dc99e7615..188077f42b 100644 --- a/storage/ent/db/authsession_create.go +++ b/storage/ent/db/authsession_create.go @@ -96,6 +96,12 @@ func (_c *AuthSessionCreate) SetIdleExpiry(v time.Time) *AuthSessionCreate { return _c } +// SetLogoutState sets the "logout_state" field. +func (_c *AuthSessionCreate) SetLogoutState(v []byte) *AuthSessionCreate { + _c.mutation.SetLogoutState(v) + return _c +} + // SetID sets the "id" field. func (_c *AuthSessionCreate) SetID(v string) *AuthSessionCreate { _c.mutation.SetID(v) @@ -274,6 +280,10 @@ func (_c *AuthSessionCreate) createSpec() (*AuthSession, *sqlgraph.CreateSpec) { _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) _node.IdleExpiry = value } + if value, ok := _c.mutation.LogoutState(); ok { + _spec.SetField(authsession.FieldLogoutState, field.TypeBytes, value) + _node.LogoutState = &value + } return _node, _spec } diff --git a/storage/ent/db/authsession_update.go b/storage/ent/db/authsession_update.go index d80e682b91..cb1544ee7c 100644 --- a/storage/ent/db/authsession_update.go +++ b/storage/ent/db/authsession_update.go @@ -160,6 +160,18 @@ func (_u *AuthSessionUpdate) SetNillableIdleExpiry(v *time.Time) *AuthSessionUpd return _u } +// SetLogoutState sets the "logout_state" field. +func (_u *AuthSessionUpdate) SetLogoutState(v []byte) *AuthSessionUpdate { + _u.mutation.SetLogoutState(v) + return _u +} + +// ClearLogoutState clears the value of the "logout_state" field. +func (_u *AuthSessionUpdate) ClearLogoutState() *AuthSessionUpdate { + _u.mutation.ClearLogoutState() + return _u +} + // Mutation returns the AuthSessionMutation object of the builder. func (_u *AuthSessionUpdate) Mutation() *AuthSessionMutation { return _u.mutation @@ -254,6 +266,12 @@ func (_u *AuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) if value, ok := _u.mutation.IdleExpiry(); ok { _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) } + if value, ok := _u.mutation.LogoutState(); ok { + _spec.SetField(authsession.FieldLogoutState, field.TypeBytes, value) + } + if _u.mutation.LogoutStateCleared() { + _spec.ClearField(authsession.FieldLogoutState, field.TypeBytes) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authsession.Label} @@ -406,6 +424,18 @@ func (_u *AuthSessionUpdateOne) SetNillableIdleExpiry(v *time.Time) *AuthSession return _u } +// SetLogoutState sets the "logout_state" field. +func (_u *AuthSessionUpdateOne) SetLogoutState(v []byte) *AuthSessionUpdateOne { + _u.mutation.SetLogoutState(v) + return _u +} + +// ClearLogoutState clears the value of the "logout_state" field. +func (_u *AuthSessionUpdateOne) ClearLogoutState() *AuthSessionUpdateOne { + _u.mutation.ClearLogoutState() + return _u +} + // Mutation returns the AuthSessionMutation object of the builder. func (_u *AuthSessionUpdateOne) Mutation() *AuthSessionMutation { return _u.mutation @@ -530,6 +560,12 @@ func (_u *AuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *AuthSession if value, ok := _u.mutation.IdleExpiry(); ok { _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) } + if value, ok := _u.mutation.LogoutState(); ok { + _spec.SetField(authsession.FieldLogoutState, field.TypeBytes, value) + } + if _u.mutation.LogoutStateCleared() { + _spec.ClearField(authsession.FieldLogoutState, field.TypeBytes) + } _node = &AuthSession{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index a6050cb333..5a5f881a6e 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -82,6 +82,7 @@ var ( {Name: "user_agent", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "absolute_expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "idle_expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "logout_state", Type: field.TypeBytes, Nullable: true}, } // AuthSessionsTable holds the schema information for the "auth_sessions" table. AuthSessionsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index a21c65765c..5137aa148c 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -3154,6 +3154,7 @@ type AuthSessionMutation struct { user_agent *string absolute_expiry *time.Time idle_expiry *time.Time + logout_state *[]byte clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthSession, error) @@ -3624,6 +3625,55 @@ func (m *AuthSessionMutation) ResetIdleExpiry() { m.idle_expiry = nil } +// SetLogoutState sets the "logout_state" field. +func (m *AuthSessionMutation) SetLogoutState(b []byte) { + m.logout_state = &b +} + +// LogoutState returns the value of the "logout_state" field in the mutation. +func (m *AuthSessionMutation) LogoutState() (r []byte, exists bool) { + v := m.logout_state + if v == nil { + return + } + return *v, true +} + +// OldLogoutState returns the old "logout_state" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldLogoutState(ctx context.Context) (v *[]byte, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLogoutState is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLogoutState requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLogoutState: %w", err) + } + return oldValue.LogoutState, nil +} + +// ClearLogoutState clears the value of the "logout_state" field. +func (m *AuthSessionMutation) ClearLogoutState() { + m.logout_state = nil + m.clearedFields[authsession.FieldLogoutState] = struct{}{} +} + +// LogoutStateCleared returns if the "logout_state" field was cleared in this mutation. +func (m *AuthSessionMutation) LogoutStateCleared() bool { + _, ok := m.clearedFields[authsession.FieldLogoutState] + return ok +} + +// ResetLogoutState resets all changes to the "logout_state" field. +func (m *AuthSessionMutation) ResetLogoutState() { + m.logout_state = nil + delete(m.clearedFields, authsession.FieldLogoutState) +} + // Where appends a list predicates to the AuthSessionMutation builder. func (m *AuthSessionMutation) Where(ps ...predicate.AuthSession) { m.predicates = append(m.predicates, ps...) @@ -3658,7 +3708,7 @@ func (m *AuthSessionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthSessionMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 11) if m.user_id != nil { fields = append(fields, authsession.FieldUserID) } @@ -3689,6 +3739,9 @@ func (m *AuthSessionMutation) Fields() []string { if m.idle_expiry != nil { fields = append(fields, authsession.FieldIdleExpiry) } + if m.logout_state != nil { + fields = append(fields, authsession.FieldLogoutState) + } return fields } @@ -3717,6 +3770,8 @@ func (m *AuthSessionMutation) Field(name string) (ent.Value, bool) { return m.AbsoluteExpiry() case authsession.FieldIdleExpiry: return m.IdleExpiry() + case authsession.FieldLogoutState: + return m.LogoutState() } return nil, false } @@ -3746,6 +3801,8 @@ func (m *AuthSessionMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldAbsoluteExpiry(ctx) case authsession.FieldIdleExpiry: return m.OldIdleExpiry(ctx) + case authsession.FieldLogoutState: + return m.OldLogoutState(ctx) } return nil, fmt.Errorf("unknown AuthSession field %s", name) } @@ -3825,6 +3882,13 @@ func (m *AuthSessionMutation) SetField(name string, value ent.Value) error { } m.SetIdleExpiry(v) return nil + case authsession.FieldLogoutState: + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLogoutState(v) + return nil } return fmt.Errorf("unknown AuthSession field %s", name) } @@ -3854,7 +3918,11 @@ func (m *AuthSessionMutation) AddField(name string, value ent.Value) error { // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *AuthSessionMutation) ClearedFields() []string { - return nil + var fields []string + if m.FieldCleared(authsession.FieldLogoutState) { + fields = append(fields, authsession.FieldLogoutState) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was @@ -3867,6 +3935,11 @@ func (m *AuthSessionMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *AuthSessionMutation) ClearField(name string) error { + switch name { + case authsession.FieldLogoutState: + m.ClearLogoutState() + return nil + } return fmt.Errorf("unknown AuthSession nullable field %s", name) } @@ -3904,6 +3977,9 @@ func (m *AuthSessionMutation) ResetField(name string) error { case authsession.FieldIdleExpiry: m.ResetIdleExpiry() return nil + case authsession.FieldLogoutState: + m.ResetLogoutState() + return nil } return fmt.Errorf("unknown AuthSession field %s", name) } diff --git a/storage/ent/schema/authsession.go b/storage/ent/schema/authsession.go index 0b641b7f7a..76fee7f5cd 100644 --- a/storage/ent/schema/authsession.go +++ b/storage/ent/schema/authsession.go @@ -41,6 +41,9 @@ func (AuthSession) Fields() []ent.Field { SchemaType(timeSchema), field.Time("idle_expiry"). SchemaType(timeSchema), + field.Bytes("logout_state"). + Nillable(). + Optional(), } } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index aabb16f6c1..20db9b63ef 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -336,6 +336,7 @@ type AuthSession struct { UserAgent string `json:"user_agent,omitempty"` AbsoluteExpiry time.Time `json:"absolute_expiry"` IdleExpiry time.Time `json:"idle_expiry"` + LogoutState *storage.LogoutState `json:"logout_state,omitempty"` } func fromStorageAuthSession(s storage.AuthSession) AuthSession { @@ -350,6 +351,7 @@ func fromStorageAuthSession(s storage.AuthSession) AuthSession { UserAgent: s.UserAgent, AbsoluteExpiry: s.AbsoluteExpiry, IdleExpiry: s.IdleExpiry, + LogoutState: s.LogoutState, } } @@ -365,6 +367,7 @@ func toStorageAuthSession(s AuthSession) storage.AuthSession { UserAgent: s.UserAgent, AbsoluteExpiry: s.AbsoluteExpiry, IdleExpiry: s.IdleExpiry, + LogoutState: s.LogoutState, } if result.ClientStates == nil { result.ClientStates = make(map[string]*storage.ClientAuthState) diff --git a/storage/storage.go b/storage/storage.go index 6d5f40b427..bc65ee27ac 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -427,6 +427,13 @@ type LogoutState struct { State string // RP's opaque state parameter ClientID string ConnectorID string + + // ConnectorState is opaque bytes returned by LogoutCallbackConnector.LogoutURL + // and handed back to HandleLogoutCallback. Used by the SAML connector to + // remember the outgoing LogoutRequest ID so it can validate InResponseTo + // against a server-side, one-shot value (defense against replay of captured + // LogoutResponses). Nil for connectors that don't need correlation state. + ConnectorState []byte } // AuthSession represents a user's authentication session from a specific connector. From 482026250b7d8179ebd8ca74247671d0dc43f0db Mon Sep 17 00:00:00 2001 From: Ivan Zvyagintsev Date: Tue, 19 May 2026 11:33:48 +0300 Subject: [PATCH 4/4] apply suggestions from review Signed-off-by: Ivan Zvyagintsev --- connector/saml/saml.go | 18 +++++--- connector/saml/saml_test.go | 83 +++++++++++++++++++++---------------- server/logout.go | 28 ++++++++++++- server/logout_test.go | 69 ++++++++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 41 deletions(-) diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 4ebc7e66d4..f762a025a6 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -1016,7 +1016,7 @@ func (p *provider) HandleLogoutCallbackWithState(_ context.Context, r *http.Requ } if samlResponse == "" { - return nil + return fmt.Errorf("saml slo: missing SAMLResponse parameter") } compressed, err := base64.StdEncoding.DecodeString(samlResponse) @@ -1024,10 +1024,18 @@ func (p *provider) HandleLogoutCallbackWithState(_ context.Context, r *http.Requ return fmt.Errorf("saml slo: failed to decode SAMLResponse: %v", err) } - // HTTP-Redirect binding uses DEFLATE compression; HTTP-POST does not. - // Try to inflate; if it fails, treat the data as uncompressed XML. - rawResp, err := io.ReadAll(flate.NewReader(bytes.NewReader(compressed))) - if err != nil { + // Per SAML 2.0 Bindings: + // §3.4 HTTP-Redirect: SAMLResponse MUST be DEFLATE-encoded, then base64. + // §3.5 HTTP-POST: SAMLResponse is base64 of the raw XML, no DEFLATE. + // Mixing the two would let a malformed response slip through one path while + // pretending to satisfy the other, so we treat the binding strictly. + var rawResp []byte + if r.Method == http.MethodGet { + rawResp, err = io.ReadAll(flate.NewReader(bytes.NewReader(compressed))) + if err != nil { + return fmt.Errorf("saml slo: failed to inflate SAMLResponse (HTTP-Redirect binding requires DEFLATE): %w", err) + } + } else { rawResp = compressed } diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index e6055f9487..1dff5489ba 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -33,6 +33,22 @@ import ( "github.com/dexidp/dex/connector" ) +func redirectBindingEncode(t *testing.T, xmlPayload string) string { + t.Helper() + var buf bytes.Buffer + fw, err := flate.NewWriter(&buf, flate.DefaultCompression) + if err != nil { + t.Fatalf("deflate writer: %v", err) + } + if _, err := fw.Write([]byte(xmlPayload)); err != nil { + t.Fatalf("deflate write: %v", err) + } + if err := fw.Close(); err != nil { + t.Fatalf("deflate close: %v", err) + } + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + // responseTest maps a SAML 2.0 response object to a set of expected values. // // Tests are defined in the "testdata" directory and are self-signed using xmlsec1. @@ -1289,13 +1305,13 @@ func TestHandleLogoutCallback(t *testing.T) { t.Run("EmptySAMLResponse", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/logout/callback", nil) - if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { - t.Errorf("expected nil error for empty SAMLResponse, got: %v", err) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { + t.Error("expected error for empty SAMLResponse") } }) t.Run("ValidLogoutResponse", func(t *testing.T) { - encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) + encoded := redirectBindingEncode(t, successLogoutResponseXML) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, successState); err != nil { t.Errorf("expected no error for valid response, got: %v", err) @@ -1303,7 +1319,7 @@ func TestHandleLogoutCallback(t *testing.T) { }) t.Run("FailedStatus", func(t *testing.T) { - encoded := base64.StdEncoding.EncodeToString([]byte(failedLogoutResponseXML)) + encoded := redirectBindingEncode(t, failedLogoutResponseXML) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for failed status") @@ -1317,8 +1333,17 @@ func TestHandleLogoutCallback(t *testing.T) { } }) + t.Run("GETRequiresDeflate", func(t *testing.T) { + // HTTP-Redirect uses DEFLATE; raw base64(XML) without DEFLATE must be rejected. + encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML)) + req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) + if err := conn.HandleLogoutCallbackWithState(context.Background(), req, successState); err == nil { + t.Error("expected error for non-deflated GET SAMLResponse") + } + }) + t.Run("InvalidXML", func(t *testing.T) { - encoded := base64.StdEncoding.EncodeToString([]byte("not xml at all")) + encoded := redirectBindingEncode(t, "not xml at all") req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for invalid XML") @@ -1337,19 +1362,7 @@ func TestHandleLogoutCallback(t *testing.T) { }) t.Run("DeflatedResponse", func(t *testing.T) { - var buf bytes.Buffer - fw, err := flate.NewWriter(&buf, flate.DefaultCompression) - if err != nil { - t.Fatal(err) - } - if _, err := fw.Write([]byte(successLogoutResponseXML)); err != nil { - t.Fatal(err) - } - if err := fw.Close(); err != nil { - t.Fatal(err) - } - - encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) + encoded := redirectBindingEncode(t, successLogoutResponseXML) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, successState); err != nil { t.Errorf("expected no error for deflated response, got: %v", err) @@ -1375,7 +1388,7 @@ func TestHandleLogoutCallbackIssuerValidation(t *testing.T) { https://correct-idp.example.com `, time.Now().UTC().Format(timeFormat)) - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { t.Errorf("expected no error, got: %v", err) @@ -1387,7 +1400,7 @@ func TestHandleLogoutCallbackIssuerValidation(t *testing.T) { https://evil-idp.example.com `, time.Now().UTC().Format(timeFormat)) - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for mismatched issuer") @@ -1399,7 +1412,7 @@ func TestHandleLogoutCallbackIssuerValidation(t *testing.T) { xml := fmt.Sprintf(` `, time.Now().UTC().Format(timeFormat)) - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for missing Issuer when ssoIssuer is configured") @@ -1431,7 +1444,7 @@ func TestHandleLogoutCallbackDestination(t *testing.T) { t.Run("MatchingAbsoluteURL", func(t *testing.T) { dest := "https://dex.example.com/logout/callback" - enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + enc := redirectBindingEncode(t, makeResp(dest)) u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) req := httptest.NewRequest(http.MethodGet, u, nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { @@ -1441,7 +1454,7 @@ func TestHandleLogoutCallbackDestination(t *testing.T) { t.Run("TrailingSlashEquivalence", func(t *testing.T) { dest := "https://dex.example.com/logout/callback/" - enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + enc := redirectBindingEncode(t, makeResp(dest)) u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) req := httptest.NewRequest(http.MethodGet, u, nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { @@ -1452,7 +1465,7 @@ func TestHandleLogoutCallbackDestination(t *testing.T) { t.Run("CaseInsensitiveSchemeHost", func(t *testing.T) { // RFC 3986 mandates case-insensitive comparison for scheme and host. dest := "HTTPS://Dex.Example.COM/logout/callback" - enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + enc := redirectBindingEncode(t, makeResp(dest)) u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) req := httptest.NewRequest(http.MethodGet, u, nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { @@ -1461,7 +1474,7 @@ func TestHandleLogoutCallbackDestination(t *testing.T) { }) t.Run("MismatchedDestination", func(t *testing.T) { - enc := base64.StdEncoding.EncodeToString([]byte(makeResp("https://evil.example.com/callback"))) + enc := redirectBindingEncode(t, makeResp("https://evil.example.com/callback")) u := "https://dex.example.com/logout/callback?SAMLResponse=" + url.QueryEscape(enc) req := httptest.NewRequest(http.MethodGet, u, nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { @@ -1471,7 +1484,7 @@ func TestHandleLogoutCallbackDestination(t *testing.T) { t.Run("XForwardedProtoAndHost", func(t *testing.T) { dest := "https://public.example.com/dex/logout/callback" - enc := base64.StdEncoding.EncodeToString([]byte(makeResp(dest))) + enc := redirectBindingEncode(t, makeResp(dest)) u := "http://10.0.0.5/dex/logout/callback?SAMLResponse=" + url.QueryEscape(enc) req := httptest.NewRequest(http.MethodGet, u, nil) req.Header.Set("X-Forwarded-Proto", "https") @@ -1499,7 +1512,7 @@ func TestHandleLogoutCallbackIssueInstantFreshness(t *testing.T) { xml := fmt.Sprintf(` `, time.Now().UTC().Format(timeFormat)) - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { t.Errorf("expected no error for fresh response, got: %v", err) @@ -1512,7 +1525,7 @@ func TestHandleLogoutCallbackIssueInstantFreshness(t *testing.T) { xml := fmt.Sprintf(` `, stale) - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for stale IssueInstant") @@ -1525,7 +1538,7 @@ func TestHandleLogoutCallbackIssueInstantFreshness(t *testing.T) { xml := fmt.Sprintf(` `, future) - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err == nil { t.Error("expected error for future IssueInstant") @@ -1556,7 +1569,7 @@ func TestHandleLogoutCallbackInResponseTo(t *testing.T) { // outgoing LogoutRequest ID that LogoutURL produced for this user. t.Run("MatchingID", func(t *testing.T) { xml := makeResponse("_abc123") - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, []byte("_abc123")); err != nil { @@ -1566,7 +1579,7 @@ func TestHandleLogoutCallbackInResponseTo(t *testing.T) { t.Run("MismatchedID", func(t *testing.T) { xml := makeResponse("_wrong") - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, []byte("_abc123")); err == nil { @@ -1578,7 +1591,7 @@ func TestHandleLogoutCallbackInResponseTo(t *testing.T) { // Legacy sessions persisted before this change won't have ConnectorState // — we must not break the upgrade path. xml := makeResponse("_anything") - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), req, nil); err != nil { @@ -1591,7 +1604,7 @@ func TestHandleLogoutCallbackInResponseTo(t *testing.T) { // RelayState must NOT be accepted just because RelayState matches — // only server-side state counts. xml := makeResponse("_attacker_chosen") - encoded := base64.StdEncoding.EncodeToString([]byte(xml)) + encoded := redirectBindingEncode(t, xml) req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded)+ "&RelayState="+url.QueryEscape("_attacker_chosen"), @@ -1883,7 +1896,7 @@ func TestSLOEndToEnd(t *testing.T) { `, time.Now().UTC().Format(timeFormat), logReq.ID) - encoded := base64.StdEncoding.EncodeToString([]byte(logoutResponseXML)) + encoded := redirectBindingEncode(t, logoutResponseXML) callbackReq := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil) if err := conn.HandleLogoutCallbackWithState(context.Background(), callbackReq, connectorState); err != nil { @@ -1899,7 +1912,7 @@ func TestSLOEndToEnd(t *testing.T) { `, time.Now().UTC().Format(timeFormat)) - badEncoded := base64.StdEncoding.EncodeToString([]byte(badResponseXML)) + badEncoded := redirectBindingEncode(t, badResponseXML) badCallbackReq := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(badEncoded)+ "&RelayState="+url.QueryEscape("_wrong_id"), diff --git a/server/logout.go b/server/logout.go index 2426b2afdd..9d17f8be40 100644 --- a/server/logout.go +++ b/server/logout.go @@ -195,14 +195,23 @@ func (s *Server) handleLogoutCallback(w http.ResponseWriter, r *http.Request) { // Prefer StatefulLogoutCallbackConnector (replays ls.ConnectorState — e.g. // SAML's outgoing LogoutRequest ID for InResponseTo correlation) and fall // back to the simpler LogoutCallbackConnector for stateless connectors. + var statefulLogoutErr error if ls.ConnectorID != "" { conn, err := s.getConnector(ctx, ls.ConnectorID) - if err == nil { + if err != nil { + // The upstream connector vanished between the outgoing logout + // redirect and this callback (deleted/reconfigured). We can't + // validate the response and must not silently swallow that — + // surface it in logs so the operator can investigate. + s.logger.ErrorContext(ctx, "logout: failed to resolve connector for callback validation", + "connector_id", ls.ConnectorID, "err", err) + } else { switch logoutConn := conn.Connector.(type) { case connector.StatefulLogoutCallbackConnector: if err := logoutConn.HandleLogoutCallbackWithState(ctx, r, ls.ConnectorState); err != nil { s.logger.ErrorContext(ctx, "logout: upstream logout response validation failed", "connector_id", ls.ConnectorID, "err", err) + statefulLogoutErr = err } case connector.LogoutCallbackConnector: if err := logoutConn.HandleLogoutCallback(ctx, r); err != nil { @@ -213,6 +222,23 @@ func (s *Server) handleLogoutCallback(w http.ResponseWriter, r *http.Request) { } } + // Stateful connectors (e.g. SAML) perform cryptographic validation; do not + // complete Dex logout if that fails — otherwise a forged GET with a valid + // session cookie could clear the session without a valid IdP response. + if statefulLogoutErr != nil { + // Clear LogoutState so a new logout can start from scratch — the current + // one-shot request ID has been used and must not be replayed. + if err := s.storage.UpdateAuthSession(ctx, userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) { + old.LogoutState = nil + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "logout: failed to clear LogoutState after failed validation", + "connector_id", ls.ConnectorID, "err", err) + } + s.renderError(r, w, http.StatusBadRequest, "Upstream logout response validation failed.") + return + } + // Session kept alive until now — delete it and clear the cookie. s.deleteAuthSession(ctx, userID, connectorID) s.clearSessionCookie(w) diff --git a/server/logout_test.go b/server/logout_test.go index 0554cb0ff1..f69dde7bf1 100644 --- a/server/logout_test.go +++ b/server/logout_test.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -12,9 +13,27 @@ import ( "github.com/stretchr/testify/require" + "github.com/dexidp/dex/connector" "github.com/dexidp/dex/storage" ) +// fakeStatefulLogoutConnector always reports a validation failure from +// HandleLogoutCallbackWithState. Used to exercise the failure branch of +// /logout/callback without depending on a real SAML setup. +type fakeStatefulLogoutConnector struct{} + +func (fakeStatefulLogoutConnector) LoginURL(connector.Scopes, string, string) (string, error) { + return "", nil +} + +func (fakeStatefulLogoutConnector) LogoutURLWithState(_ context.Context, _ []byte, _ string) (string, []byte, error) { + return "https://idp.example.com/slo", []byte("_req_id"), nil +} + +func (fakeStatefulLogoutConnector) HandleLogoutCallbackWithState(_ context.Context, _ *http.Request, _ []byte) error { + return errors.New("forced validation failure for tests") +} + func TestHandleLogoutNoSessions(t *testing.T) { httpServer, server := newTestServer(t, nil) defer httpServer.Close() @@ -332,6 +351,56 @@ func TestHandleLogoutFromCookie(t *testing.T) { } } +// TestLogoutCallbackStatefulFailureKeepsSessionClearsLogoutState verifies that +// a failed stateful (e.g. SAML) logout validation: +// - returns HTTP 400, +// - leaves the auth session intact (so a forged GET with only a session +// cookie cannot complete logout), and +// - clears LogoutState so the one-shot correlation state cannot be replayed. +func TestLogoutCallbackStatefulFailureKeepsSessionClearsLogoutState(t *testing.T) { + httpServer, server := newTestServerWithSessions(t, nil) + defer httpServer.Close() + + ctx := t.Context() + userID := "test-user" + connectorID := "stateful-fake" + nonce := "testnonce" + + registerTestConnector(t, server, connectorID, fakeStatefulLogoutConnector{}) + + require.NoError(t, server.storage.CreateAuthSession(ctx, storage.AuthSession{ + UserID: userID, + ConnectorID: connectorID, + Nonce: nonce, + CreatedAt: time.Now(), + LastActivity: time.Now(), + LogoutState: &storage.LogoutState{ + ConnectorID: connectorID, + ConnectorState: []byte("_req_id"), + }, + })) + + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/logout/callback", nil) + req.AddCookie(&http.Cookie{ + Name: "dex_session", + Value: sessionCookieValue(userID, connectorID, nonce, server.sessionConfig.CookieEncryptionKey), + }) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + + got, err := server.storage.GetAuthSession(ctx, userID, connectorID) + require.NoError(t, err, "session must survive failed stateful logout validation") + require.Nil(t, got.LogoutState, "LogoutState must be cleared after failed validation") + + for _, c := range rr.Result().Cookies() { + if c.Name == "dex_session" { + require.NotEqual(t, -1, c.MaxAge, "session cookie must NOT be cleared on failure") + } + } +} + // TestLogoutCallbackWithExpiredSession tests that /logout/callback // returns an error when the session has expired or been deleted. func TestLogoutCallbackWithExpiredSession(t *testing.T) {